Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Regularised the API from parent_recording to recording in zero_channel_pad #2923

Merged
merged 3 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/spikeinterface/preprocessing/tests/test_zero_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_trace_padded_recording_full_trace(recording, padding_start, padding_end
num_samples = recording.get_num_samples()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_trace_padded_recording_full_trace_with_channel_indices(recording, paddi
num_samples = recording.get_num_samples()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_trace_padded_recording_retrieve_original_trace(recording, padding_start
num_samples = recording.get_num_samples()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand All @@ -129,7 +129,7 @@ def test_trace_padded_recording_retrieve_partial_original_trace(recording, paddi
num_samples = recording.get_num_samples()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand All @@ -156,7 +156,7 @@ def test_trace_padded_recording_retrieve_start_padding_and_partial_original_trac
num_channels = recording.get_num_channels()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down Expand Up @@ -188,7 +188,7 @@ def test_trace_padded_recording_retrieve_end_padding_and_partial_original_trace(
num_channels = recording.get_num_channels()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down Expand Up @@ -222,7 +222,7 @@ def test_trace_padded_recording_retrieve_traces_with_partial_padding(recording,
num_channels = recording.get_num_channels()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down Expand Up @@ -264,7 +264,7 @@ def test_trace_padded_recording_retrieve_only_start_padding(recording, padding_s
num_channels = recording.get_num_channels()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand All @@ -281,7 +281,7 @@ def test_trace_padded_recording_retrieve_only_end_padding(recording, padding_sta
num_channels = recording.get_num_channels()

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down Expand Up @@ -314,7 +314,7 @@ def test_trace_padded_recording_retrieve_only_end_padding_with_preprocessing(
recording = phase_shift(recording)

padded_recording = TracePaddedRecording(
parent_recording=recording,
recording=recording,
padding_start=padding_start,
padding_end=padding_end,
)
Expand Down
63 changes: 27 additions & 36 deletions src/spikeinterface/preprocessing/zero_channel_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TracePaddedRecording(BasePreprocessor):

Parameters
----------
parent_recording_segment : BaseRecording
recording_segment : BaseRecording
The parent recording segment from which the traces are to be retrieved.
padding_start : int, default: 0
The amount of padding to add to the left of the traces. It has to be non-negative.
Expand All @@ -29,19 +29,17 @@ class TracePaddedRecording(BasePreprocessor):
The value to pad with
"""

def __init__(
self, parent_recording: BaseRecording, padding_start: int = 0, padding_end: int = 0, fill_value: float = 0.0
):
def __init__(self, recording: BaseRecording, padding_start: int = 0, padding_end: int = 0, fill_value: float = 0.0):
assert padding_end >= 0 and padding_start >= 0, "Paddings must be >= 0"
super().__init__(recording=parent_recording)
super().__init__(recording=recording)

self.padding_start = padding_start
self.padding_end = padding_end
self.fill_value = fill_value
for segment in parent_recording._recording_segments:
for segment in recording._recording_segments:
recording_segment = TracePaddedRecordingSegment(
segment,
parent_recording.get_num_channels(),
recording.get_num_channels(),
self.dtype,
self.padding_start,
self.padding_end,
Expand All @@ -50,7 +48,7 @@ def __init__(
self.add_recording_segment(recording_segment)

self._kwargs = dict(
parent_recording=parent_recording,
parent_recording=recording,
padding_start=padding_start,
padding_end=padding_end,
fill_value=fill_value,
Expand All @@ -60,21 +58,21 @@ def __init__(
class TracePaddedRecordingSegment(BasePreprocessorSegment):
def __init__(
self,
parent_recording_segment: BaseRecordingSegment,
recording_segment: BaseRecordingSegment,
num_channels,
dtype,
paddign_left,
padding_left,
padding_end,
fill_value,
):
self.padding_start = paddign_left
self.padding_start = padding_left
self.padding_end = padding_end
self.fill_value = fill_value
self.num_channels = num_channels
self.num_samples_in_original_segment = parent_recording_segment.get_num_samples()
self.num_samples_in_original_segment = recording_segment.get_num_samples()
self.dtype = dtype

super().__init__(parent_recording_segment=parent_recording_segment)
super().__init__(parent_recording_segment=recording_segment)

def get_traces(self, start_frame, end_frame, channel_indices):
if start_frame is None:
Expand Down Expand Up @@ -146,12 +144,12 @@ class ZeroChannelPaddedRecording(BaseRecording):
name = "zero_channel_pad"
installed = True

def __init__(self, parent_recording: BaseRecording, num_channels: int, channel_mapping: Union[list, None] = None):
def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: Union[list, None] = None):
"""Pads a recording with channels that contain only zero.

Parameters
----------
parent_recording : BaseRecording
recording : BaseRecording
recording to zero-pad
num_channels : int
Total number of channels in the zero-channel-padded recording
Expand All @@ -160,51 +158,44 @@ def __init__(self, parent_recording: BaseRecording, num_channels: int, channel_m
If None, sorts the channel indices in ascending y channel location and puts them at the
beginning of the zero-channel-padded recording.
"""
BaseRecording.__init__(
self, parent_recording.get_sampling_frequency(), np.arange(num_channels), parent_recording.get_dtype()
)
BaseRecording.__init__(self, recording.get_sampling_frequency(), np.arange(num_channels), recording.get_dtype())

if channel_mapping is not None:
assert (
len(channel_mapping) == parent_recording.get_num_channels()
len(channel_mapping) == recording.get_num_channels()
), "The new mapping must be specified for all channels."
assert max(channel_mapping) < num_channels, (
"The new mapping cannot exceed total number of channels " "in the zero-chanenl-padded recording."
)
else:
if (
"locations" in parent_recording.get_property_keys()
or "contact_vector" in parent_recording.get_property_keys()
):
self.channel_mapping = np.argsort(parent_recording.get_channel_locations()[:, 1])
if "locations" in recording.get_property_keys() or "contact_vector" in recording.get_property_keys():
self.channel_mapping = np.argsort(recording.get_channel_locations()[:, 1])
else:
self.channel_mapping = np.arange(parent_recording.get_num_channels())
self.channel_mapping = np.arange(recording.get_num_channels())

self.parent_recording = parent_recording
self.parent_recording = recording
self.num_channels = num_channels
for segment in parent_recording._recording_segments:
for segment in recording._recording_segments:
recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.num_channels, self.channel_mapping)
self.add_recording_segment(recording_segment)

# only copy relevant metadata and properties
parent_recording.copy_metadata(self, only_main=True)
self._parent = parent_recording
prop_keys = parent_recording.get_property_keys()
recording.copy_metadata(self, only_main=True)
self._parent = recording
prop_keys = recording.get_property_keys()

for k in prop_keys:
values = self.get_property(k)
if values is not None:
self.set_property(k, values, ids=self.channel_ids[self.channel_mapping])

self._kwargs = dict(
parent_recording=parent_recording, num_channels=num_channels, channel_mapping=channel_mapping
)
self._kwargs = dict(parent_recording=recording, num_channels=num_channels, channel_mapping=channel_mapping)


class ZeroChannelPaddedRecordingSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment: BaseRecordingSegment, num_channels: int, channel_mapping: list):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.parent_recording_segment = parent_recording_segment
def __init__(self, recording_segment: BaseRecordingSegment, num_channels: int, channel_mapping: list):
BasePreprocessorSegment.__init__(self, recording_segment)
self.parent_recording_segment = recording_segment
self.num_channels = num_channels
self.channel_mapping = channel_mapping

Expand Down
Loading