Skip to content

Commit

Permalink
Merge pull request #2215 from oaaij-gnahz/dev_mine
Browse files Browse the repository at this point in the history
Fix: save a copy of group ids in CommonReferenceRecording
  • Loading branch information
alejoe91 authored Nov 22, 2023
2 parents 029c24a + 1badc49 commit 4d1796c
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/spikeinterface/preprocessing/common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,17 @@ def __init__(

# tranforms groups (ids) to groups (indices)
if groups is not None:
groups = [self.ids_to_indices(g) for g in groups]
group_indices = [self.ids_to_indices(g) for g in groups]
else:
group_indices = None
if ref_channel_ids is not None:
ref_channel_inds = self.ids_to_indices(ref_channel_ids)
else:
ref_channel_inds = None

for parent_segment in recording._recording_segments:
rec_segment = CommonReferenceRecordingSegment(
parent_segment, reference, operator, groups, ref_channel_inds, local_radius, neighbors, dtype_
parent_segment, reference, operator, group_indices, ref_channel_inds, local_radius, neighbors, dtype_
)
self.add_recording_segment(rec_segment)

Expand All @@ -123,13 +125,21 @@ def __init__(

class CommonReferenceRecordingSegment(BasePreprocessorSegment):
def __init__(
self, parent_recording_segment, reference, operator, groups, ref_channel_inds, local_radius, neighbors, dtype
self,
parent_recording_segment,
reference,
operator,
group_indices,
ref_channel_inds,
local_radius,
neighbors,
dtype,
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)

self.reference = reference
self.operator = operator
self.groups = groups
self.group_indices = group_indices
self.ref_channel_inds = ref_channel_inds
self.local_radius = local_radius
self.neighbors = neighbors
Expand Down Expand Up @@ -175,8 +185,8 @@ def get_traces(self, start_frame, end_frame, channel_indices):
def _groups(self, channel_indices):
selected_groups = []
selected_channels = []
if self.groups:
for chan_inds in self.groups:
if self.group_indices:
for chan_inds in self.group_indices:
sel_inds = [ind for ind in channel_indices if ind in chan_inds]
# if no channels are in a group, do not return the group
if len(sel_inds) > 0:
Expand Down

0 comments on commit 4d1796c

Please sign in to comment.