From 116af531a34cc3ac5841d8d89cc0a1e0d8cb3e14 Mon Sep 17 00:00:00 2001 From: Jiaao Zhang Date: Wed, 15 Nov 2023 11:48:44 -0600 Subject: [PATCH 1/4] Fix: save a copy of group ids in CommonReferenceRecording --- .../preprocessing/common_reference.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 219854f340..5578637346 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -98,7 +98,9 @@ def __init__( # tranforms groups (ids) to groups (indices) if groups is not None: - groups = [self.ids_to_indices(g) for g in groups] + groups_inds = [self.ids_to_indices(g) for g in groups] + else: + groups_inds = None if ref_channel_ids is not None: ref_channel_inds = self.ids_to_indices(ref_channel_ids) else: @@ -106,7 +108,7 @@ def __init__( for parent_segment in recording._recording_segments: rec_segment = CommonReferenceRecordingSegment( - parent_segment, reference, operator, groups, ref_channel_inds, local_radius, neighbors, dtype_ + parent_segment, reference, operator, groups_inds, ref_channel_inds, local_radius, neighbors, dtype_ ) self.add_recording_segment(rec_segment) @@ -123,13 +125,13 @@ def __init__( class CommonReferenceRecordingSegment(BasePreprocessorSegment): def __init__( - self, parent_recording_segment, reference, operator, groups, ref_channel_inds, local_radius, neighbors, dtype + self, parent_recording_segment, reference, operator, groups_inds, ref_channel_inds, local_radius, neighbors, dtype ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.reference = reference self.operator = operator - self.groups = groups + self.groups_inds = groups_inds self.ref_channel_inds = ref_channel_inds self.local_radius = local_radius self.neighbors = neighbors @@ -175,8 +177,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): def _groups(self, channel_indices): selected_groups = [] selected_channels = [] - if self.groups: - for chan_inds in self.groups: + if self.groups_inds: + for chan_inds in self.groups_inds: sel_inds = [ind for ind in channel_indices if ind in chan_inds] # if no channels are in a group, do not return the group if len(sel_inds) > 0: From ac3ae7c31729dbbc4192a8fa55a6cf7ad9d4b6fc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Nov 2023 18:13:51 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/common_reference.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 5578637346..316eb9c5c0 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -125,7 +125,15 @@ def __init__( class CommonReferenceRecordingSegment(BasePreprocessorSegment): def __init__( - self, parent_recording_segment, reference, operator, groups_inds, ref_channel_inds, local_radius, neighbors, dtype + self, + parent_recording_segment, + reference, + operator, + groups_inds, + ref_channel_inds, + local_radius, + neighbors, + dtype, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) From 89acc5945b5067848f86f57f4a00d0c85bb3c2dc Mon Sep 17 00:00:00 2001 From: Jiaao Zhang <40973006+oaaij-gnahz@users.noreply.github.com> Date: Tue, 21 Nov 2023 11:37:35 -0600 Subject: [PATCH 3/4] Update src/spikeinterface/preprocessing/common_reference.py Apply suggestion Co-authored-by: Alessio Buccino --- src/spikeinterface/preprocessing/common_reference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 316eb9c5c0..a5c9638896 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -98,7 +98,7 @@ def __init__( # tranforms groups (ids) to groups (indices) if groups is not None: - groups_inds = [self.ids_to_indices(g) for g in groups] + group_indices = [self.ids_to_indices(g) for g in groups] else: groups_inds = None if ref_channel_ids is not None: From 1badc49b853ffb8e0bc436f53ba5e0d4fdb5d572 Mon Sep 17 00:00:00 2001 From: Jiaao Zhang <40973006+oaaij-gnahz@users.noreply.github.com> Date: Tue, 21 Nov 2023 12:38:57 -0600 Subject: [PATCH 4/4] Fix variable naming inconsistency Apply suggested changes --- src/spikeinterface/preprocessing/common_reference.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index a5c9638896..c40aa11767 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -100,7 +100,7 @@ def __init__( if groups is not None: group_indices = [self.ids_to_indices(g) for g in groups] else: - groups_inds = None + group_indices = None if ref_channel_ids is not None: ref_channel_inds = self.ids_to_indices(ref_channel_ids) else: @@ -108,7 +108,7 @@ def __init__( for parent_segment in recording._recording_segments: rec_segment = CommonReferenceRecordingSegment( - parent_segment, reference, operator, groups_inds, ref_channel_inds, local_radius, neighbors, dtype_ + parent_segment, reference, operator, group_indices, ref_channel_inds, local_radius, neighbors, dtype_ ) self.add_recording_segment(rec_segment) @@ -129,7 +129,7 @@ def __init__( parent_recording_segment, reference, operator, - groups_inds, + group_indices, ref_channel_inds, local_radius, neighbors, @@ -139,7 +139,7 @@ def __init__( self.reference = reference self.operator = operator - self.groups_inds = groups_inds + self.group_indices = group_indices self.ref_channel_inds = ref_channel_inds self.local_radius = local_radius self.neighbors = neighbors @@ -185,8 +185,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): def _groups(self, channel_indices): selected_groups = [] selected_channels = [] - if self.groups_inds: - for chan_inds in self.groups_inds: + if self.group_indices: + for chan_inds in self.group_indices: sel_inds = [ind for ind in channel_indices if ind in chan_inds] # if no channels are in a group, do not return the group if len(sel_inds) > 0: