diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 0ea9426674..e65afabaca 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -233,8 +233,14 @@ def get_duration(self, segment_index=None) -> float: The duration in seconds """ segment_index = self._check_segment_index(segment_index) - segment_num_samples = self.get_num_samples(segment_index=segment_index) - segment_duration = segment_num_samples / self.get_sampling_frequency() + + if self.has_time_vector(segment_index): + times = self.get_times(segment_index) + segment_duration = times[-1] - times[0] + (1 / self.get_sampling_frequency()) + else: + segment_num_samples = self.get_num_samples(segment_index=segment_index) + segment_duration = segment_num_samples / self.get_sampling_frequency() + return segment_duration def get_total_duration(self) -> float: @@ -246,7 +252,7 @@ def get_total_duration(self) -> float: float The duration in seconds """ - duration = self.get_total_samples() / self.get_sampling_frequency() + duration = sum([self.get_duration(idx) for idx in range(self.get_num_segments())]) return duration def get_memory_size(self, segment_index=None) -> int: diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index af3b502f0f..89e9e2cf0f 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -821,7 +821,10 @@ def get_total_samples(self) -> int: return s def get_total_duration(self) -> float: - duration = self.get_total_samples() / self.sampling_frequency + if self.has_recording() or self.has_temporary_recording(): + duration = self.recording.get_total_duration() + else: + duration = self.get_total_samples() / self.sampling_frequency return duration def get_num_channels(self) -> int: diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 049d5ab6e5..1b570091be 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -243,6 +243,73 @@ def test_slice_recording(self, time_type, bounds): assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8) + def test_get_durations(self, time_vector_recording, t_start_recording): + """ + Test the `get_durations` functions that return the total duration + for a segment. Test that it is correct after adding both `t_start` + or `time_vector` to the recording. + """ + raw_recording, tvector_recording, all_time_vectors = time_vector_recording + _, tstart_recording, all_t_starts = t_start_recording + + ts = 1 / raw_recording.get_sampling_frequency() + + all_raw_durations = [] + all_vector_durations = [] + for segment_index in range(raw_recording.get_num_segments()): + + # Test before `t_start` and `t_start` (`t_start` is just an offset, + # should not affect duration). + raw_duration = all_t_starts[segment_index][-1] - all_t_starts[segment_index][0] + ts + + assert np.isclose(raw_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8) + assert np.isclose(tstart_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8) + + # Test the duration from the time vector. + vector_duration = all_time_vectors[segment_index][-1] - all_time_vectors[segment_index][0] + ts + + assert tvector_recording.get_duration(segment_index) == vector_duration + + all_raw_durations.append(raw_duration) + all_vector_durations.append(vector_duration) + + # Finally test the total recording duration + assert np.isclose(tstart_recording.get_total_duration(), sum(all_raw_durations), rtol=0, atol=1e-8) + assert np.isclose(tvector_recording.get_total_duration(), sum(all_vector_durations), rtol=0, atol=1e-8) + + def test_sorting_analyzer_get_durations_from_recording(self, time_vector_recording): + """ + Test that when a recording is set on `sorting_analyzer`, the + total duration is propagated from the recording to the + `sorting_analyzer.get_total_duration()` function. + """ + _, times_recording, _ = time_vector_recording + + sorting = si.generate_sorting( + durations=[times_recording.get_duration(s) for s in range(times_recording.get_num_segments())] + ) + sorting_analyzer = si.create_sorting_analyzer(sorting, recording=times_recording) + + assert np.array_equal(sorting_analyzer.get_total_duration(), times_recording.get_total_duration()) + + def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording): + """ + Test when the `sorting_analzyer` does not have a recording set, + the total duration is calculated on the fly from num samples and + sampling frequency (thus matching `raw_recording` with no times set + that uses the same method to calculate the total duration). + """ + raw_recording, _, _ = time_vector_recording + + sorting = si.generate_sorting( + durations=[raw_recording.get_duration(s) for s in range(raw_recording.get_num_segments())] + ) + sorting_analyzer = si.create_sorting_analyzer(sorting, recording=raw_recording) + + sorting_analyzer._recording = None + + assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration()) + # Helpers #### def _check_times_match(self, recording, all_times): """ diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 8f2ef0ec21..2806754c9d 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -400,6 +400,7 @@ def generate_hybrid_recording( num_segments = recording.get_num_segments() dtype = recording.dtype durations = np.array([recording.get_duration(segment_index) for segment_index in range(num_segments)]) + num_samples = np.array([recording.get_num_samples(segment_index) for segment_index in range(num_segments)]) channel_locations = probe.contact_positions assert ( @@ -548,7 +549,7 @@ def generate_hybrid_recording( displacement_vectors=displacement_vectors, displacement_sampling_frequency=displacement_sampling_frequency, displacement_unit_factor=displacement_unit_factor, - num_samples=(np.array(durations) * sampling_frequency).astype("int64"), + num_samples=num_samples.astype("int64"), amplitude_factor=amplitude_factor, )