diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index f23b524271..a2135777db 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -395,7 +395,9 @@ def get_time_info(self, segment_index=None) -> dict: return time_kwargs - def get_times(self, segment_index=None) -> np.ndarray: + def get_times( + self, segment_index: int | None = None, start_frame: int | None = None, end_frame: int | None = None + ) -> np.ndarray: """Get time vector for a recording segment. If the segment has a time_vector, then it is returned. Otherwise @@ -407,6 +409,10 @@ def get_times(self, segment_index=None) -> np.ndarray: ---------- segment_index : int or None, default: None The segment index (required for multi-segment) + start_frame : int or None, default: None + The start frame index. If None, it starts from the beginning of the segment. + end_frame : int or None, default: None + The end frame index. If None, it goes until the end of the segment. Returns ------- @@ -415,7 +421,7 @@ def get_times(self, segment_index=None) -> np.ndarray: """ segment_index = self._check_segment_index(segment_index) rs = self.segments[segment_index] - times = rs.get_times() + times = rs.get_times(start_frame=start_frame, end_frame=end_frame) return times def get_start_time(self, segment_index=None) -> float: @@ -913,12 +919,22 @@ def __init__(self, sampling_frequency=None, t_start=None, time_vector=None): BaseSegment.__init__(self) - def get_times(self) -> np.ndarray: + def get_times(self, start_frame: int | None = None, end_frame: int | None = None) -> np.ndarray: if self.time_vector is not None: - self.time_vector = np.asarray(self.time_vector) - return self.time_vector + # Cache full times as numpy if start_frame and end_frame are None. If the user passes start_frame and + # end_frame, we slice the time vector and return the sliced version as numpy array. + # This is useful for very long recordings, where the full time vector might be too large to fit in memory. + if start_frame is None and end_frame is None: + self.time_vector = np.asarray(self.time_vector) + return self.time_vector + else: + start_frame = int(start_frame) if start_frame is not None else 0 + end_frame = int(end_frame) if end_frame is not None else self.get_num_samples() + return np.asarray(self.time_vector[start_frame:end_frame]) else: - time_vector = np.arange(self.get_num_samples(), dtype="float64") + start_frame = int(start_frame) if start_frame is not None else 0 + end_frame = int(end_frame) if end_frame is not None else self.get_num_samples() + time_vector = np.arange(start_frame, end_frame, dtype="float64") time_vector /= self.sampling_frequency if self.t_start is not None: time_vector += self.t_start diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index cb68f3d455..19aced7ffc 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -347,7 +347,12 @@ def has_time_vector(self, segment_index: int | None = None) -> bool: else: return False - def get_times(self, segment_index=None): + def get_times( + self, + segment_index: int | None = None, + start_frame: int | None = None, + end_frame: int | None = None, + ): """ Get time vector for a registered recording segment. @@ -359,7 +364,7 @@ def get_times(self, segment_index=None): """ segment_index = self._check_segment_index(segment_index) if self.has_recording(): - return self._recording.get_times(segment_index=segment_index) + return self._recording.get_times(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame) else: return None diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 1ebeb677c6..bb6db4cb66 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -83,6 +83,8 @@ def test_BaseRecording(create_cache_folder): assert values.dtype.kind == "i" times0 = rec.get_times(segment_index=0) + times0_slice = rec.get_times(segment_index=0, start_frame=10, end_frame=20) + assert np.allclose(times0_slice, times0[10:20]) # dump/load dict d = rec.to_dict(include_annotations=True, include_properties=True) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 2f99399392..f0db2d5898 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -143,10 +143,6 @@ def __init__( raise ValueError('You must provide "segment_index" for multisegment recordings.') segment_index = 0 - if not rec0.has_time_vector(segment_index=segment_index): - times = None - else: - times = rec0.get_times(segment_index=segment_index) t_start = rec0.get_start_time(segment_index=segment_index) t_end = rec0.get_end_time(segment_index=segment_index) @@ -172,7 +168,7 @@ def __init__( cmap = cmap times_in_range, list_traces, frame_range, channel_ids = _get_trace_list( - recordings, channel_ids, time_range, segment_index, return_in_uV=return_in_uV, times=times + recordings, channel_ids, segment_index, time_range=time_range, return_in_uV=return_in_uV ) list_traces = [traces * scale for traces in list_traces] @@ -405,25 +401,12 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.figure.canvas.header_visible = False plt.show() - if not self.rec0.has_time_vector(segment_index=data_plot["segment_index"]): - times = None - t_starts = [ - rec0.get_start_time(segment_index=segment_index) for segment_index in range(rec0.get_num_segments()) - ] - else: - times = [ - np.array(self.rec0.get_times(segment_index=segment_index)) - for segment_index in range(self.rec0.get_num_segments()) - ] - t_starts = None - # some widgets self.time_slider = TimeSlider( durations=[rec0.get_duration(s) for s in range(rec0.get_num_segments())], sampling_frequency=rec0.sampling_frequency, - time_range=data_plot["time_range"], - times=times, - t_starts=t_starts, + frame_range=data_plot["frame_range"], + rec0=rec0, ) # handle times if data_plot["events"] is not None: @@ -559,24 +542,17 @@ def _retrieve_traces(self, change=None): start_frame, end_frame, segment_index = self.time_slider.value - if not self.rec0.has_time_vector(segment_index=segment_index): - times = None - time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency + self.rec0.get_start_time( - segment_index=segment_index - ) - else: - times = self.rec0.get_times(segment_index=segment_index) - time_range = np.array([times[start_frame], times[end_frame]]) + frame_range = np.array([start_frame, end_frame]) self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} times_in_range, list_traces, frame_range, channel_ids = _get_trace_list( self._selected_recordings, channel_ids, - time_range, segment_index, return_in_uV=self.return_in_uV, - times=times, + frame_range=frame_range, ) + time_range = np.array([times_in_range[0], times_in_range[-1]]) self._channel_ids = channel_ids self._list_traces = list_traces @@ -640,12 +616,11 @@ def plot_figpack(self, data_plot, **backend_kwargs): handle_display_and_url, import_figpack_or_sortingview, ) + import importlib.util use_sortingview = backend_kwargs.get("use_sortingview", False) vv_base, vv_views = import_figpack_or_sortingview(use_sortingview) - import importlib.util - spec = importlib.util.find_spec("pyvips") if spec is None: raise ImportError("To use `plot_traces()` in sortingview you need the pyvips package.") @@ -705,25 +680,28 @@ def plot_ephyviewer(self, data_plot, **backend_kwargs): app.exec() -def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_in_uV=False, times=None): +def _get_trace_list(recordings, channel_ids, segment_index, time_range=None, return_in_uV=False, frame_range=None): # function also used in ipywidgets plotter k0 = list(recordings.keys())[0] rec0 = recordings[k0] - fs = rec0.get_sampling_frequency() - if return_in_uV: assert all( rec.has_scaleable_traces() for rec in recordings.values() ), "Some recording layers do not have scaled traces. Use `return_in_uV=False`" - if times is not None: - frame_range = np.searchsorted(times, time_range) - times = times[frame_range[0] : frame_range[1]] - else: - frame_range = rec0.time_to_sample_index(time_range, segment_index=segment_index) + + assert time_range is not None or frame_range is not None, "You must provide either time_range or frame_range" + + if frame_range is None: + # use the sampling-frequency approximation to avoid loading the full time vector + t_start = rec0.get_start_time(segment_index=segment_index) + fs = rec0.get_sampling_frequency() + frame_range = np.round((np.asarray(time_range) - t_start) * fs).astype(np.int64) a_max = rec0.get_num_frames(segment_index=segment_index) frame_range = np.clip(frame_range, 0, a_max) - times = np.arange(frame_range[0], frame_range[1]) / fs + rec0.get_start_time(segment_index=segment_index) + + # lazily load only the needed time slice + times_in_range = rec0.get_times(segment_index=segment_index, start_frame=frame_range[0], end_frame=frame_range[1]) list_traces = [] for rec_name, rec in recordings.items(): @@ -737,4 +715,4 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_i list_traces.append(traces) - return times, list_traces, frame_range, channel_ids + return times_in_range, list_traces, frame_range, channel_ids diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 00ea954a15..25c67d36bd 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -16,29 +16,25 @@ def check_ipywidget_backend(): class TimeSlider(W.HBox): value = traitlets.Tuple(traitlets.Int(), traitlets.Int(), traitlets.Int()) - def __init__(self, durations, sampling_frequency, time_range, times=None, t_starts=None, **kwargs): + def __init__(self, durations, sampling_frequency, frame_range, rec0=None, t_starts=None, **kwargs): self.num_segments = len(durations) self.frame_limits = [int(sampling_frequency * d) for d in durations] self.sampling_frequency = sampling_frequency self.segment_index = 0 - if times is not None: - assert len(times) == len(durations), "times should be a list of arrays with one array per segment" - times_segment = times[self.segment_index] - start_frame, end_frame = np.searchsorted(times_segment, time_range) - self.times = times + start_frame, end_frame = int(frame_range[0]), int(frame_range[1]) + + if rec0 is not None: + self.rec0 = rec0 self.t_starts = None else: assert t_starts is not None - t_start_segment = t_starts[self.segment_index] - start_frame = int((time_range[0] - t_start_segment) * sampling_frequency) - end_frame = int((time_range[1] - t_start_segment) * sampling_frequency) - self.times = None + self.rec0 = None self.t_starts = t_starts self.frame_range = (start_frame, end_frame) - self.value = (int(start_frame), int(end_frame), self.segment_index) + self.value = (start_frame, end_frame, self.segment_index) layout = W.Layout(align_items="center", width="2.5cm", height="1.cm") but_left = W.Button(description="", disabled=False, button_style="", icon="arrow-left", layout=layout) @@ -63,8 +59,16 @@ def __init__(self, durations, sampling_frequency, time_range, times=None, t_star ) # DatetimePicker is only for ipywidget v8 (which is not working in vscode 2023-03) + if self.rec0 is not None: + initial_time = float( + self.rec0.get_times( + segment_index=self.segment_index, start_frame=start_frame, end_frame=start_frame + 1 + )[0] + ) + else: + initial_time = start_frame / sampling_frequency + self.t_starts[self.segment_index] self.time_label = W.Text( - value=f"{time_range[0]}", description="", disabled=False, layout=W.Layout(width="2.5cm") + value=f"{initial_time}", description="", disabled=False, layout=W.Layout(width="2.5cm") ) self.time_label.observe(self.time_label_changed, names="value", type="change") @@ -137,8 +141,10 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update if new_frame is None and new_time is None: start_frame = self.slider.value elif new_frame is None: - if self.times is not None: - start_frame = int(np.searchsorted(self.times[self.segment_index], [new_time])[0]) + if self.rec0 is not None: + # approximate via sampling frequency to avoid loading the full time vector + t_start = float(self.rec0.get_start_time(segment_index=self.segment_index)) + start_frame = int((new_time - t_start) * self.sampling_frequency) else: start_frame = int((new_time - self.t_starts[self.segment_index]) * self.sampling_frequency) else: @@ -153,8 +159,12 @@ def update_time(self, new_frame=None, new_time=None, update_slider=False, update end_frame = min(self.frame_limits[self.segment_index], end_frame) - if self.times is not None: - start_time = self.times[self.segment_index][start_frame] + if self.rec0 is not None: + start_time = float( + self.rec0.get_times( + segment_index=self.segment_index, start_frame=start_frame, end_frame=start_frame + 1 + )[0] + ) else: start_time = start_frame / self.sampling_frequency + self.t_starts[self.segment_index]