diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py
index f2ef38e691..fdedf37266 100644
--- a/src/spikeinterface/core/frameslicerecording.py
+++ b/src/spikeinterface/core/frameslicerecording.py
@@ -34,7 +34,9 @@ def __init__(self, parent_recording, start_frame=None, end_frame=None):
         if start_frame is None:
             start_frame = 0
         else:
-            assert 0 <= start_frame < parent_size
+            assert (
+                0 <= start_frame < parent_size
+            ), f"`start_frame` must be less than the number of samples in the parent: {parent_size}"
 
         if end_frame is None:
             end_frame = parent_size
diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py
index 86f2350a85..da649fd76a 100644
--- a/src/spikeinterface/widgets/traces.py
+++ b/src/spikeinterface/widgets/traces.py
@@ -125,8 +125,10 @@ def __init__(
         if not rec0.has_time_vector(segment_index=segment_index):
             times = None
-            t_start = 0
-            t_end = rec0.get_duration(segment_index=segment_index)
+            t_start = rec0.sample_index_to_time(0, segment_index=segment_index)
+            t_end = rec0.sample_index_to_time(
+                rec0.get_num_samples(segment_index=segment_index), segment_index=segment_index
+            )
         else:
             times = rec0.get_times(segment_index=segment_index)
             t_start = times[0]
@@ -149,6 +151,11 @@ def __init__(
                 )
                 time_range[1] = t_end
 
+        if time_range[0] < t_start or time_range[1] < t_start:
+            raise ValueError(f"All `time_range` values must be greater than or equal to {t_start}")
+        if time_range[1] <= time_range[0]:
+            raise ValueError("`time_range[1]` must be greater than `time_range[0]`")
+
         assert mode in ("auto", "line", "map"), 'Mode must be one of "auto","line", "map"'
         if mode == "auto":
             if len(channel_ids) <= 64:
@@ -158,7 +165,7 @@ def __init__(
             mode = mode
         cmap = cmap
 
-        times_in_range, list_traces, frame_range, channel_ids = _get_trace_list(
+        times_in_range, list_traces, sample_range, channel_ids = _get_trace_list(
             recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled, times=times
         )
 
@@ -250,7 +257,7 @@ def __init__(
             channel_ids=channel_ids,
             channel_locations=channel_locations,
             time_range=time_range,
-            frame_range=frame_range,
+            sample_range=sample_range,
             times_in_range=times_in_range,
             layer_keys=layer_keys,
             list_traces=list_traces,
@@ -536,7 +543,7 @@ def _retrieve_traces(self, change=None):
             time_range = np.array([times[start_frame], times[end_frame]])
 
         self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()}
-        times_in_range, list_traces, frame_range, channel_ids = _get_trace_list(
+        times_in_range, list_traces, sample_range, channel_ids = _get_trace_list(
             self._selected_recordings,
             channel_ids,
             time_range,
@@ -549,7 +556,7 @@ def _retrieve_traces(self, change=None):
         self._list_traces = list_traces
         self._times_in_range = times_in_range
         self._time_range = time_range
-        self._frame_range = (start_frame, end_frame)
+        self._sample_range = (start_frame, end_frame)
         self._segment_index = segment_index
         self._update_plot()
 
@@ -562,7 +569,7 @@ def _update_plot(self, change=None):
         layer_keys = self._get_layers()
 
         data_plot["mode"] = mode
-        data_plot["frame_range"] = self._frame_range
+        data_plot["sample_range"] = self._sample_range
         data_plot["time_range"] = self._time_range
         if self.colorbar.value:
             data_plot["with_colorbar"] = True
@@ -673,26 +680,29 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_s
     assert all(
         rec.has_scaleable_traces() for rec in recordings.values()
     ), "Some recording layers do not have scaled traces. Use `return_scaled=False`"
+    sample_range = np.array(
+        [
+            rec0.time_to_sample_index(time_range[0], segment_index=segment_index),
+            rec0.time_to_sample_index(time_range[1], segment_index=segment_index),
+        ]
+    )
 
     if times is not None:
-        frame_range = np.searchsorted(times, time_range)
-        times = times[frame_range[0] : frame_range[1]]
+        times = times[sample_range[0] : sample_range[1]]
     else:
-        frame_range = (time_range * fs).astype("int64", copy=False)
-        a_max = rec0.get_num_frames(segment_index=segment_index)
-        frame_range = np.clip(frame_range, 0, a_max)
-        time_range = frame_range / fs
-        times = np.arange(frame_range[0], frame_range[1]) / fs
+        num_samples = rec0.get_num_samples(segment_index=segment_index)
+        sample_range = np.clip(sample_range, 0, num_samples)
+        times = np.arange(sample_range[0], sample_range[1]) / fs
 
     list_traces = []
-    for rec_name, rec in recordings.items():
+    for _, rec in recordings.items():
         traces = rec.get_traces(
             segment_index=segment_index,
             channel_ids=channel_ids,
-            start_frame=frame_range[0],
-            end_frame=frame_range[1],
+            start_frame=sample_range[0],
+            end_frame=sample_range[1],
             return_scaled=return_scaled,
         )
         list_traces.append(traces)
 
-    return times, list_traces, frame_range, channel_ids
+    return times, list_traces, sample_range, channel_ids
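
Illustrative usage (not part of the patch): a minimal sketch of how the two user-facing changes above surface, assuming the public spikeinterface API (`generate_recording`, `frame_slice`, and `plot_traces` with the matplotlib backend). The toy recording and the try/except blocks are only for demonstration.

from spikeinterface.core import generate_recording
import spikeinterface.widgets as sw

# Toy single-segment recording: 4 channels, 2 s at 30 kHz.
recording = generate_recording(num_channels=4, sampling_frequency=30000.0, durations=[2.0])
num_samples = recording.get_num_samples(segment_index=0)

# frame_slice: an out-of-range start_frame now fails with an explicit message
# instead of a bare AssertionError.
try:
    recording.frame_slice(start_frame=num_samples + 10, end_frame=num_samples + 20)
except AssertionError as err:
    print(err)

# plot_traces: a reversed (or before-start) time_range is now rejected up front
# with a ValueError instead of producing an empty or misleading plot.
try:
    sw.plot_traces(recording, time_range=(1.5, 0.5), backend="matplotlib")
except ValueError as err:
    print(err)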