From 2ba37a8b3990af3919a3c1b294700909d144a457 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 5 Nov 2024 15:58:45 +0100 Subject: [PATCH 1/3] Don't let decimate mess with times and skim tests --- src/spikeinterface/preprocessing/decimate.py | 26 +++++++++---------- .../preprocessing/tests/test_decimate.py | 20 +++++++------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 334ebb02d2..c1b1cd9f80 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -63,18 +63,15 @@ def __init__( f"Consider combining DecimateRecording with FrameSliceRecording for fine control on the recording start/end frames." ) self._decimation_offset = decimation_offset - resample_rate = self._orig_samp_freq / self._decimation_factor + decimated_sampling_frequency = self._orig_samp_freq / self._decimation_factor - BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate) + BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency) - # in case there was a time_vector, it will be dropped for sanity. - # This is not necessary but consistent with ResampleRecording for parent_segment in recording._recording_segments: - parent_segment.time_vector = None self.add_recording_segment( DecimateRecordingSegment( parent_segment, - resample_rate, + decimated_sampling_frequency, self._orig_samp_freq, decimation_factor, decimation_offset, @@ -93,22 +90,25 @@ class DecimateRecordingSegment(BaseRecordingSegment): def __init__( self, parent_recording_segment, - resample_rate, + decimated_sampling_frequency, parent_rate, decimation_factor, decimation_offset, dtype, ): - if parent_recording_segment.t_start is None: - new_t_start = None + if parent_recording_segment.time_vector is not None: + time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor] + decimated_sampling_frequency = None else: - new_t_start = parent_recording_segment.t_start + decimation_offset / parent_rate + time_vector = None + if parent_recording_segment.t_start is None: + t_start = None + else: + t_start = parent_recording_segment.t_start + decimation_offset / parent_rate # Do not use BasePreprocessorSegment bcause we have to reset the sampling rate! BaseRecordingSegment.__init__( - self, - sampling_frequency=resample_rate, - t_start=new_t_start, + self, sampling_frequency=decimated_sampling_frequency, t_start=t_start, time_vector=time_vector ) self._parent_segment = parent_recording_segment self._decimation_factor = decimation_factor diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index 100972f762..adfcbd0d4a 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -8,19 +8,19 @@ import numpy as np -@pytest.mark.parametrize("N_segments", [1, 2]) -@pytest.mark.parametrize("decimation_offset", [0, 1, 9, 10, 11, 100, 101]) -@pytest.mark.parametrize("decimation_factor", [1, 9, 10, 11, 100, 101]) +@pytest.mark.parametrize("num_segments", [1, 2]) +@pytest.mark.parametrize("decimation_offset", [0, 5, 21, 101]) +@pytest.mark.parametrize("decimation_factor", [1, 7, 50]) @pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000]) @pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000]) -def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, end_frame): +def test_decimate(num_segments, decimation_offset, decimation_factor, start_frame, end_frame): rec = generate_recording() - segment_num_samps = [101 + i for i in range(N_segments)] + segment_num_samps = [101 + i for i in range(num_segments)] rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1) - parent_traces = [rec.get_traces(i) for i in range(N_segments)] + parent_traces = [rec.get_traces(i) for i in range(num_segments)] if decimation_offset >= min(segment_num_samps) or decimation_offset >= decimation_factor: with pytest.raises(ValueError): @@ -28,14 +28,14 @@ def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, return decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) - decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(N_segments)] + decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)] if start_frame is None: - start_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) + start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) if end_frame is None: - end_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) + end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) - for i in range(N_segments): + for i in range(num_segments): assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] assert np.all( decimated_rec.get_traces(i, start_frame, end_frame) == decimated_parent_traces[i][start_frame:end_frame] From f90011803da2327d7ace74ff2a35b91b30c70d32 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 5 Nov 2024 16:35:27 +0100 Subject: [PATCH 2/3] More skimming and test decimate with times --- src/spikeinterface/preprocessing/decimate.py | 1 + .../preprocessing/tests/test_decimate.py | 57 +++++++++++++++---- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index c1b1cd9f80..2b47601fc2 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -99,6 +99,7 @@ def __init__( if parent_recording_segment.time_vector is not None: time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor] decimated_sampling_frequency = None + t_start = None else: time_vector = None if parent_recording_segment.t_start is None: diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index adfcbd0d4a..dd521cbe9b 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -11,13 +11,8 @@ @pytest.mark.parametrize("num_segments", [1, 2]) @pytest.mark.parametrize("decimation_offset", [0, 5, 21, 101]) @pytest.mark.parametrize("decimation_factor", [1, 7, 50]) -@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000]) -@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000]) -def test_decimate(num_segments, decimation_offset, decimation_factor, start_frame, end_frame): - rec = generate_recording() - - segment_num_samps = [101 + i for i in range(num_segments)] - +def test_decimate(num_segments, decimation_offset, decimation_factor): + segment_num_samps = [20000, 40000] rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1) parent_traces = [rec.get_traces(i) for i in range(num_segments)] @@ -30,10 +25,19 @@ def test_decimate(num_segments, decimation_offset, decimation_factor, start_fram decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)] - if start_frame is None: - start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) - if end_frame is None: - end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) + for start_frame in [0, 1, 5, None, 1000]: + for end_frame in [0, 1, 5, None, 1000]: + if start_frame is None: + start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) + if end_frame is None: + end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) + + for i in range(num_segments): + assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] + assert np.all( + decimated_rec.get_traces(i, start_frame, end_frame) + == decimated_parent_traces[i][start_frame:end_frame] + ) for i in range(num_segments): assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] @@ -42,5 +46,36 @@ def test_decimate(num_segments, decimation_offset, decimation_factor, start_fram ) +def test_decimate_with_times(): + rec = generate_recording(durations=[5, 10]) + + # test with times + times = [rec.get_times(0) + 10, rec.get_times(1) + 20] + for i, t in enumerate(times): + rec.set_times(t, i) + + decimation_factor = 2 + decimation_offset = 1 + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + + for segment_index in range(rec.get_num_segments()): + assert np.allclose( + decimated_rec.get_times(segment_index), + rec.get_times(segment_index)[decimation_offset::decimation_factor], + ) + + # test with t_start + rec = generate_recording(durations=[5, 10]) + t_starts = [10, 20] + for t_start, rec_segment in zip(t_starts, rec._recording_segments): + rec_segment.t_start = t_start + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + for segment_index in range(rec.get_num_segments()): + assert np.allclose( + decimated_rec.get_times(segment_index), + rec.get_times(segment_index)[decimation_offset::decimation_factor], + ) + + if __name__ == "__main__": test_decimate() From 2d843f8770a8587c32920d3af4dcc54bb8c05411 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 6 Nov 2024 10:09:56 +0100 Subject: [PATCH 3/3] Zach's comments --- src/spikeinterface/preprocessing/decimate.py | 2 +- src/spikeinterface/preprocessing/tests/test_decimate.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 2b47601fc2..d5fc9d2025 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -105,7 +105,7 @@ def __init__( if parent_recording_segment.t_start is None: t_start = None else: - t_start = parent_recording_segment.t_start + decimation_offset / parent_rate + t_start = parent_recording_segment.t_start + (decimation_offset / parent_rate) # Do not use BasePreprocessorSegment bcause we have to reset the sampling rate! BaseRecordingSegment.__init__( diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index dd521cbe9b..aab17560a6 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize("num_segments", [1, 2]) -@pytest.mark.parametrize("decimation_offset", [0, 5, 21, 101]) +@pytest.mark.parametrize("decimation_offset", [0, 1, 5, 21, 101]) @pytest.mark.parametrize("decimation_factor", [1, 7, 50]) def test_decimate(num_segments, decimation_offset, decimation_factor): segment_num_samps = [20000, 40000]