Skip to content

Commit

Permalink
Merge pull request #2649 from chrishalcrow/fix_plot_traces_ch_ids
Browse files Browse the repository at this point in the history
Fix get_traces for a local common reference
  • Loading branch information
alejoe91 authored Apr 3, 2024
2 parents 8b7b51c + 7455456 commit 8e10eaf
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 10 deletions.
6 changes: 3 additions & 3 deletions src/spikeinterface/preprocessing/common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,12 @@ def get_traces(self, start_frame, end_frame, channel_indices):
shift = traces[:, self.ref_channel_indices]
re_referenced_traces = traces[:, channel_indices] - shift
else: # then it must be local
re_referenced_traces = np.zeros_like(traces[:, channel_indices])
channel_indices_array = np.arange(traces.shape[1])[channel_indices]
for channel_index in channel_indices_array:
re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32")
for i, channel_index in enumerate(channel_indices_array):
channel_neighborhood = self.neighbors[channel_index]
channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1)
re_referenced_traces[:, channel_index] = traces[:, channel_index] - channel_shift
re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift

return re_referenced_traces.astype(self.dtype, copy=False)

Expand Down
55 changes: 48 additions & 7 deletions src/spikeinterface/preprocessing/tests/test_common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def _generate_test_recording():
recording = generate_recording(durations=[5.0], num_channels=4)
recording = generate_recording(durations=[1.0], num_channels=4)
recording = recording.channel_slice(recording.channel_ids, np.array(["a", "b", "c", "d"]))
return recording

Expand Down Expand Up @@ -46,11 +46,15 @@ def test_common_reference(recording):
def test_common_reference_channel_slicing(recording):
recording_cmr = common_reference(recording, reference="global", operator="median")
recording_car = common_reference(recording, reference="global", operator="average")
recording_single_reference = common_reference(recording, reference="single", ref_channel_ids=["a"])
recording_single_reference = common_reference(recording, reference="single", ref_channel_ids=["b"])
recording_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median")

channel_ids = ["a", "b"]
indices = recording.ids_to_indices(["a", "b"])
channel_ids = ["b", "d"]
indices = recording.ids_to_indices(channel_ids)

all_channel_ids = recording.channel_ids
all_indices = recording.ids_to_indices(all_channel_ids)

original_traces = recording.get_traces()

cmr_trace = recording_cmr.get_traces(channel_ids=channel_ids)
Expand All @@ -62,13 +66,50 @@ def test_common_reference_channel_slicing(recording):
assert np.allclose(car_trace, expected_trace, atol=0.01)

single_reference_trace = recording_single_reference.get_traces(channel_ids=channel_ids)
single_reference_index = recording.ids_to_indices(["a"])
single_reference_index = recording.ids_to_indices(["b"])
expected_trace = original_traces[:, indices] - original_traces[:, single_reference_index]

assert np.allclose(single_reference_trace, expected_trace, atol=0.01)

# local car
local_trace = recording_local_car.get_traces(channel_ids=channel_ids)
local_trace = recording_local_car.get_traces(channel_ids=all_channel_ids)
local_trace_sub = recording_local_car.get_traces(channel_ids=channel_ids)

assert np.all(local_trace[:, indices] == local_trace_sub)

# test segment slicing

start_frame = 0
end_frame = 10

recording_segment_cmr = recording_cmr._recording_segments[0]
traces_cmr_all = recording_segment_cmr.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices
)
traces_cmr_sub = recording_segment_cmr.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=indices
)

assert np.all(traces_cmr_all[:, indices] == traces_cmr_sub)

recording_segment_car = recording_car._recording_segments[0]
traces_car_all = recording_segment_car.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices
)
traces_car_sub = recording_segment_car.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=indices
)

assert np.all(traces_car_all[:, indices] == traces_car_sub)

recording_segment_local = recording_local_car._recording_segments[0]
traces_local_all = recording_segment_local.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices
)
traces_local_sub = recording_segment_local.get_traces(
start_frame=start_frame, end_frame=end_frame, channel_indices=indices
)

assert np.all(traces_local_all[:, indices] == traces_local_sub)


def test_common_reference_groups(recording):
Expand Down

0 comments on commit 8e10eaf

Please sign in to comment.