Skip to content

Commit

Permalink
Fix zero_channel_pad case when selected frames are outside of origi…
Browse files Browse the repository at this point in the history
…nal data region (#1979)

* Add case in which start and end frame after outside of original data region.

* Fix the new check for retriving chunks in the end-padding zone.

* Add test case to reveal error.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add `phase_shift` test case.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
JoeZiminski and pre-commit-ci[bot] authored Oct 30, 2023
1 parent 1a23d1d commit 712e576
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
49 changes: 47 additions & 2 deletions src/spikeinterface/preprocessing/tests/test_zero_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from spikeinterface.core import generate_recording
from spikeinterface.core.numpyextractors import NumpyRecording

from spikeinterface.preprocessing import zero_channel_pad
from spikeinterface.preprocessing import zero_channel_pad, bandpass_filter, phase_shift
from spikeinterface.preprocessing.zero_channel_pad import TracePaddedRecording

if hasattr(pytest, "global_test_folder"):
Expand Down Expand Up @@ -39,7 +39,7 @@ def test_zero_padding_channel():
@pytest.fixture
def recording():
num_channels = 4
num_samples = 10
num_samples = 10000
rng = np.random.default_rng(seed=0)
traces = rng.random(size=(num_samples, num_channels))
traces_list = [traces]
Expand Down Expand Up @@ -258,5 +258,50 @@ def test_trace_padded_recording_retrieve_traces_with_partial_padding(recording,
assert np.allclose(padded_traces_end, expected_zeros)


@pytest.mark.parametrize("preprocessing", ["bandpass_filter", "phase_shift"])
@pytest.mark.parametrize("padding_start, padding_end", [(5000, 5000), (5000, 0), (0, 5000)])
def test_trace_padded_recording_retrieve_full_recording_with_preprocessing(
recording, padding_start, padding_end, preprocessing
):
num_samples = recording.get_num_samples()
num_channels = recording.get_num_channels()

if preprocessing == "bandpass_filter":
recording = bandpass_filter(recording, freq_min=300, freq_max=6000)
else:
sample_shift_size = 0.4
inter_sample_shift = np.arange(recording.get_num_channels()) * sample_shift_size
recording.set_property("inter_sample_shift", inter_sample_shift)
recording = phase_shift(recording)

padded_recording = TracePaddedRecording(
parent_recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)

# Cycle through the whole recording, using get_traces() to pull chunks of
# size `step`. This emulates the processing of writing to a binary file.
# Data that lie within the padding region should be fill value only, while
# data from original trace should match exactly. Note that the step
# size must be chosen to retreieve data that is purely padding or original data
step = 1000
start_frames = np.arange(padded_recording.get_num_samples(), step=step)
end_frames = start_frames + step

for start_frame, end_frame in zip(start_frames, end_frames):
padded_trace = padded_recording.get_traces(start_frame=start_frame, end_frame=end_frame)

end_padding_region_first_idx = padding_start + num_samples

if padding_start <= start_frame < end_padding_region_first_idx:
original_trace = recording.get_traces(
start_frame=start_frame - padding_start, end_frame=end_frame - padding_start
)
assert np.allclose(padded_trace, original_trace, rtol=0, atol=1e-10)
else:
assert np.all(padded_trace == padded_recording.fill_value)


if __name__ == "__main__":
test_zero_padding_channel()
10 changes: 9 additions & 1 deletion src/spikeinterface/preprocessing/zero_channel_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ def get_traces(self, start_frame, end_frame, channel_indices):
# Else, we start with the full padded traces and allocate the original traces in the middle
output_traces = np.full(shape=(trace_size, num_channels), fill_value=self.fill_value, dtype=self.dtype)

# If start and end frame are outside of the original data region (e.g. for Kilosort), return only paddding
if (
start_frame > self.num_samples_in_original_segment + self.padding_start
and end_frame > self.num_samples_in_original_segment + self.padding_start
):
return output_traces

# After the padding, the original traces are placed in the middle until the end of the original traces
if end_frame >= self.padding_start:
original_traces = self.get_original_traces_shifted(
Expand All @@ -119,11 +126,12 @@ def get_original_traces_shifted(self, start_frame, end_frame, channel_indices):
"""
original_start_frame = max(start_frame - self.padding_start, 0)
original_end_frame = min(end_frame - self.padding_start, self.num_samples_in_original_segment)

original_traces = self.parent_recording_segment.get_traces(
start_frame=original_start_frame,
end_frame=original_end_frame,
channel_indices=channel_indices,
)
) # BREAKPOINT HERE!!

return original_traces

Expand Down

0 comments on commit 712e576

Please sign in to comment.