Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add time vector case to get_durations. #3118

12 changes: 9 additions & 3 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,14 @@ def get_duration(self, segment_index=None) -> float:
The duration in seconds
"""
segment_index = self._check_segment_index(segment_index)
segment_num_samples = self.get_num_samples(segment_index=segment_index)
segment_duration = segment_num_samples / self.get_sampling_frequency()

if self.has_time_vector(segment_index):
times = self.get_times(segment_index)
segment_duration = times[-1] - times[0] + (1 / self.get_sampling_frequency())
h-mayorquin marked this conversation as resolved.
Show resolved Hide resolved
else:
segment_num_samples = self.get_num_samples(segment_index=segment_index)
segment_duration = segment_num_samples / self.get_sampling_frequency()

return segment_duration

def get_total_duration(self) -> float:
Expand All @@ -246,7 +252,7 @@ def get_total_duration(self) -> float:
float
The duration in seconds
"""
duration = self.get_total_samples() / self.get_sampling_frequency()
duration = sum([self.get_duration(idx) for idx in range(self.get_num_segments())])
return duration

def get_memory_size(self, segment_index=None) -> int:
Expand Down
5 changes: 4 additions & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,10 @@ def get_total_samples(self) -> int:
return s

def get_total_duration(self) -> float:
duration = self.get_total_samples() / self.sampling_frequency
if self.has_recording() or self.has_temporary_recording():
duration = self.recording.get_total_duration()
else:
duration = self.get_total_samples() / self.sampling_frequency
return duration

def get_num_channels(self) -> int:
Expand Down
67 changes: 67 additions & 0 deletions src/spikeinterface/core/tests/test_time_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,73 @@ def test_slice_recording(self, time_type, bounds):

assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)

def test_get_durations(self, time_vector_recording, t_start_recording):
"""
Test the `get_durations` functions that return the total duration
for a segment. Test that it is correct after adding both `t_start`
or `time_vector` to the recording.
"""
raw_recording, tvector_recording, all_time_vectors = time_vector_recording
_, tstart_recording, all_t_starts = t_start_recording

ts = 1 / raw_recording.get_sampling_frequency()

all_raw_durations = []
all_vector_durations = []
for segment_index in range(raw_recording.get_num_segments()):

# Test before `t_start` and `t_start` (`t_start` is just an offset,
# should not affect duration).
raw_duration = all_t_starts[segment_index][-1] - all_t_starts[segment_index][0] + ts

assert np.isclose(raw_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)
assert np.isclose(tstart_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)

# Test the duration from the time vector.
vector_duration = all_time_vectors[segment_index][-1] - all_time_vectors[segment_index][0] + ts

assert tvector_recording.get_duration(segment_index) == vector_duration

all_raw_durations.append(raw_duration)
all_vector_durations.append(vector_duration)

# Finally test the total recording duration
assert np.isclose(tstart_recording.get_total_duration(), sum(all_raw_durations), rtol=0, atol=1e-8)
assert np.isclose(tvector_recording.get_total_duration(), sum(all_vector_durations), rtol=0, atol=1e-8)

def test_sorting_analyzer_get_durations_from_recording(self, time_vector_recording):
"""
Test that when a recording is set on `sorting_analyzer`, the
total duration is propagated from the recording to the
`sorting_analyzer.get_total_duration()` function.
"""
_, times_recording, _ = time_vector_recording

sorting = si.generate_sorting(
durations=[times_recording.get_duration(s) for s in range(times_recording.get_num_segments())]
)
sorting_analyzer = si.create_sorting_analyzer(sorting, recording=times_recording)

assert np.array_equal(sorting_analyzer.get_total_duration(), times_recording.get_total_duration())

def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording):
"""
Test when the `sorting_analzyer` does not have a recording set,
the total duration is calculated on the fly from num samples and
sampling frequency (thus matching `raw_recording` with no times set
that uses the same method to calculate the total duration).
"""
raw_recording, _, _ = time_vector_recording

sorting = si.generate_sorting(
durations=[raw_recording.get_duration(s) for s in range(raw_recording.get_num_segments())]
)
sorting_analyzer = si.create_sorting_analyzer(sorting, recording=raw_recording)

sorting_analyzer._recording = None

assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration())

# Helpers ####
def _check_times_match(self, recording, all_times):
"""
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/generation/hybrid_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def generate_hybrid_recording(
num_segments = recording.get_num_segments()
dtype = recording.dtype
durations = np.array([recording.get_duration(segment_index) for segment_index in range(num_segments)])
num_samples = np.array([recording.get_num_samples(segment_index) for segment_index in range(num_segments)])
channel_locations = probe.contact_positions

assert (
Expand Down Expand Up @@ -548,7 +549,7 @@ def generate_hybrid_recording(
displacement_vectors=displacement_vectors,
displacement_sampling_frequency=displacement_sampling_frequency,
displacement_unit_factor=displacement_unit_factor,
num_samples=(np.array(durations) * sampling_frequency).astype("int64"),
num_samples=num_samples.astype("int64"),
amplitude_factor=amplitude_factor,
)

Expand Down
Loading