From 8cf91e0bbaf630cf97c614f3fc2c42eb15d9f28b Mon Sep 17 00:00:00 2001 From: SimpingOjou Date: Thu, 4 Jun 2026 15:08:19 +0200 Subject: [PATCH] Optimize expand_index_in_trial in Dataset class for improved performance --- cebra/data/base.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/cebra/data/base.py b/cebra/data/base.py index f5491e51..c3ca46ae 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -154,8 +154,6 @@ def expand_index_in_trial(self, index, trial_ids, trial_borders): trial_ids is in size of a length of self.index and indicate the trial id of the index belong to. trial_borders is in size of a length of self.idnex and indicate the border of each trial. - Todo: - - rewrite """ # TODO(stes) potential room for speed improvements by pre-allocating these tensors/ @@ -163,16 +161,15 @@ def expand_index_in_trial(self, index, trial_ids, trial_borders): offset = torch.arange(-self.offset.left, self.offset.right, device=index.device) - index = torch.tensor( - [ - torch.clamp( - i, - trial_borders[trial_ids[i]] + self.offset.left, - trial_borders[trial_ids[i] + 1] - self.offset.right, - ) for i in index - ], - device=self.device, - ) + + # Vectorized lookup and boundary calculation + batch_trial_ids = trial_ids[index] + min_borders = trial_borders[batch_trial_ids] + self.offset.left + max_borders = trial_borders[batch_trial_ids + 1] - self.offset.right + + # Fast C-level clamp + index = torch.clamp(index, min=min_borders, max=max_borders) + return index[:, None] + offset[None, :] @abc.abstractmethod