Skip to content

Commit

Permalink
Merge pull request SpikeInterface#3519 from alejoe91/fix-decimate-times
Browse files Browse the repository at this point in the history
Don't let decimate mess with times and skim tests
  • Loading branch information
alejoe91 authored Nov 7, 2024
2 parents e525d85 + 2d843f8 commit 8a7895e
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 30 deletions.
27 changes: 14 additions & 13 deletions src/spikeinterface/preprocessing/decimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,15 @@ def __init__(
f"Consider combining DecimateRecording with FrameSliceRecording for fine control on the recording start/end frames."
)
self._decimation_offset = decimation_offset
resample_rate = self._orig_samp_freq / self._decimation_factor
decimated_sampling_frequency = self._orig_samp_freq / self._decimation_factor

BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate)
BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency)

# in case there was a time_vector, it will be dropped for sanity.
# This is not necessary but consistent with ResampleRecording
for parent_segment in recording._recording_segments:
parent_segment.time_vector = None
self.add_recording_segment(
DecimateRecordingSegment(
parent_segment,
resample_rate,
decimated_sampling_frequency,
self._orig_samp_freq,
decimation_factor,
decimation_offset,
Expand All @@ -93,22 +90,26 @@ class DecimateRecordingSegment(BaseRecordingSegment):
def __init__(
self,
parent_recording_segment,
resample_rate,
decimated_sampling_frequency,
parent_rate,
decimation_factor,
decimation_offset,
dtype,
):
if parent_recording_segment.t_start is None:
new_t_start = None
if parent_recording_segment.time_vector is not None:
time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor]
decimated_sampling_frequency = None
t_start = None
else:
new_t_start = parent_recording_segment.t_start + decimation_offset / parent_rate
time_vector = None
if parent_recording_segment.t_start is None:
t_start = None
else:
t_start = parent_recording_segment.t_start + (decimation_offset / parent_rate)

# Do not use BasePreprocessorSegment bcause we have to reset the sampling rate!
BaseRecordingSegment.__init__(
self,
sampling_frequency=resample_rate,
t_start=new_t_start,
self, sampling_frequency=decimated_sampling_frequency, t_start=t_start, time_vector=time_vector
)
self._parent_segment = parent_recording_segment
self._decimation_factor = decimation_factor
Expand Down
69 changes: 52 additions & 17 deletions src/spikeinterface/preprocessing/tests/test_decimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,74 @@
import numpy as np


@pytest.mark.parametrize("N_segments", [1, 2])
@pytest.mark.parametrize("decimation_offset", [0, 1, 9, 10, 11, 100, 101])
@pytest.mark.parametrize("decimation_factor", [1, 9, 10, 11, 100, 101])
@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000])
@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000])
def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, end_frame):
rec = generate_recording()

segment_num_samps = [101 + i for i in range(N_segments)]

@pytest.mark.parametrize("num_segments", [1, 2])
@pytest.mark.parametrize("decimation_offset", [0, 1, 5, 21, 101])
@pytest.mark.parametrize("decimation_factor", [1, 7, 50])
def test_decimate(num_segments, decimation_offset, decimation_factor):
segment_num_samps = [20000, 40000]
rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1)

parent_traces = [rec.get_traces(i) for i in range(N_segments)]
parent_traces = [rec.get_traces(i) for i in range(num_segments)]

if decimation_offset >= min(segment_num_samps) or decimation_offset >= decimation_factor:
with pytest.raises(ValueError):
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
return

decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(N_segments)]
decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)]

if start_frame is None:
start_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments))
if end_frame is None:
end_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments))
for start_frame in [0, 1, 5, None, 1000]:
for end_frame in [0, 1, 5, None, 1000]:
if start_frame is None:
start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))
if end_frame is None:
end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments))

for i in range(N_segments):
for i in range(num_segments):
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0]
assert np.all(
decimated_rec.get_traces(i, start_frame, end_frame)
== decimated_parent_traces[i][start_frame:end_frame]
)

for i in range(num_segments):
assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0]
assert np.all(
decimated_rec.get_traces(i, start_frame, end_frame) == decimated_parent_traces[i][start_frame:end_frame]
)


def test_decimate_with_times():
rec = generate_recording(durations=[5, 10])

# test with times
times = [rec.get_times(0) + 10, rec.get_times(1) + 20]
for i, t in enumerate(times):
rec.set_times(t, i)

decimation_factor = 2
decimation_offset = 1
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)

for segment_index in range(rec.get_num_segments()):
assert np.allclose(
decimated_rec.get_times(segment_index),
rec.get_times(segment_index)[decimation_offset::decimation_factor],
)

# test with t_start
rec = generate_recording(durations=[5, 10])
t_starts = [10, 20]
for t_start, rec_segment in zip(t_starts, rec._recording_segments):
rec_segment.t_start = t_start
decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset)
for segment_index in range(rec.get_num_segments()):
assert np.allclose(
decimated_rec.get_times(segment_index),
rec.get_times(segment_index)[decimation_offset::decimation_factor],
)


if __name__ == "__main__":
test_decimate()

0 comments on commit 8a7895e

Please sign in to comment.