From 712e576ddbb02acde5a29f23e405154390ff339c Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 30 Oct 2023 11:43:52 +0000 Subject: [PATCH] Fix `zero_channel_pad` case when selected frames are outside of original data region (#1979) * Add case in which start and end frame after outside of original data region. * Fix the new check for retriving chunks in the end-padding zone. * Add test case to reveal error. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add `phase_shift` test case. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../preprocessing/tests/test_zero_padding.py | 49 ++++++++++++++++++- .../preprocessing/zero_channel_pad.py | 10 +++- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_zero_padding.py b/src/spikeinterface/preprocessing/tests/test_zero_padding.py index 75d64b0088..5c2c00f9ca 100644 --- a/src/spikeinterface/preprocessing/tests/test_zero_padding.py +++ b/src/spikeinterface/preprocessing/tests/test_zero_padding.py @@ -6,7 +6,7 @@ from spikeinterface.core import generate_recording from spikeinterface.core.numpyextractors import NumpyRecording -from spikeinterface.preprocessing import zero_channel_pad +from spikeinterface.preprocessing import zero_channel_pad, bandpass_filter, phase_shift from spikeinterface.preprocessing.zero_channel_pad import TracePaddedRecording if hasattr(pytest, "global_test_folder"): @@ -39,7 +39,7 @@ def test_zero_padding_channel(): @pytest.fixture def recording(): num_channels = 4 - num_samples = 10 + num_samples = 10000 rng = np.random.default_rng(seed=0) traces = rng.random(size=(num_samples, num_channels)) traces_list = [traces] @@ -258,5 +258,50 @@ def test_trace_padded_recording_retrieve_traces_with_partial_padding(recording, assert np.allclose(padded_traces_end, expected_zeros) +@pytest.mark.parametrize("preprocessing", ["bandpass_filter", "phase_shift"]) +@pytest.mark.parametrize("padding_start, padding_end", [(5000, 5000), (5000, 0), (0, 5000)]) +def test_trace_padded_recording_retrieve_full_recording_with_preprocessing( + recording, padding_start, padding_end, preprocessing +): + num_samples = recording.get_num_samples() + num_channels = recording.get_num_channels() + + if preprocessing == "bandpass_filter": + recording = bandpass_filter(recording, freq_min=300, freq_max=6000) + else: + sample_shift_size = 0.4 + inter_sample_shift = np.arange(recording.get_num_channels()) * sample_shift_size + recording.set_property("inter_sample_shift", inter_sample_shift) + recording = phase_shift(recording) + + padded_recording = TracePaddedRecording( + parent_recording=recording, + padding_start=padding_start, + padding_end=padding_end, + ) + + # Cycle through the whole recording, using get_traces() to pull chunks of + # size `step`. This emulates the processing of writing to a binary file. + # Data that lie within the padding region should be fill value only, while + # data from original trace should match exactly. Note that the step + # size must be chosen to retreieve data that is purely padding or original data + step = 1000 + start_frames = np.arange(padded_recording.get_num_samples(), step=step) + end_frames = start_frames + step + + for start_frame, end_frame in zip(start_frames, end_frames): + padded_trace = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame) + + end_padding_region_first_idx = padding_start + num_samples + + if padding_start <= start_frame < end_padding_region_first_idx: + original_trace = recording.get_traces( + start_frame=start_frame - padding_start, end_frame=end_frame - padding_start + ) + assert np.allclose(padded_trace, original_trace, rtol=0, atol=1e-10) + else: + assert np.all(padded_trace == padded_recording.fill_value) + + if __name__ == "__main__": test_zero_padding_channel() diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 124b2b080e..eaac91bb18 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -94,6 +94,13 @@ def get_traces(self, start_frame, end_frame, channel_indices): # Else, we start with the full padded traces and allocate the original traces in the middle output_traces = np.full(shape=(trace_size, num_channels), fill_value=self.fill_value, dtype=self.dtype) + # If start and end frame are outside of the original data region (e.g. for Kilosort), return only paddding + if ( + start_frame > self.num_samples_in_original_segment + self.padding_start + and end_frame > self.num_samples_in_original_segment + self.padding_start + ): + return output_traces + # After the padding, the original traces are placed in the middle until the end of the original traces if end_frame >= self.padding_start: original_traces = self.get_original_traces_shifted( @@ -119,11 +126,12 @@ def get_original_traces_shifted(self, start_frame, end_frame, channel_indices): """ original_start_frame = max(start_frame - self.padding_start, 0) original_end_frame = min(end_frame - self.padding_start, self.num_samples_in_original_segment) + original_traces = self.parent_recording_segment.get_traces( start_frame=original_start_frame, end_frame=original_end_frame, channel_indices=channel_indices, - ) + ) # BREAKPOINT HERE!! return original_traces