diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 219854f340..c40aa11767 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -98,7 +98,9 @@ def __init__( # tranforms groups (ids) to groups (indices) if groups is not None: - groups = [self.ids_to_indices(g) for g in groups] + group_indices = [self.ids_to_indices(g) for g in groups] + else: + group_indices = None if ref_channel_ids is not None: ref_channel_inds = self.ids_to_indices(ref_channel_ids) else: @@ -106,7 +108,7 @@ def __init__( for parent_segment in recording._recording_segments: rec_segment = CommonReferenceRecordingSegment( - parent_segment, reference, operator, groups, ref_channel_inds, local_radius, neighbors, dtype_ + parent_segment, reference, operator, group_indices, ref_channel_inds, local_radius, neighbors, dtype_ ) self.add_recording_segment(rec_segment) @@ -123,13 +125,21 @@ def __init__( class CommonReferenceRecordingSegment(BasePreprocessorSegment): def __init__( - self, parent_recording_segment, reference, operator, groups, ref_channel_inds, local_radius, neighbors, dtype + self, + parent_recording_segment, + reference, + operator, + group_indices, + ref_channel_inds, + local_radius, + neighbors, + dtype, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.reference = reference self.operator = operator - self.groups = groups + self.group_indices = group_indices self.ref_channel_inds = ref_channel_inds self.local_radius = local_radius self.neighbors = neighbors @@ -175,8 +185,8 @@ def get_traces(self, start_frame, end_frame, channel_indices): def _groups(self, channel_indices): selected_groups = [] selected_channels = [] - if self.groups: - for chan_inds in self.groups: + if self.group_indices: + for chan_inds in self.group_indices: sel_inds = [ind for ind in channel_indices if ind in chan_inds] # if no channels are in a group, do not return the group if len(sel_inds) > 0: