From dedc3e03aacb5efaf14b0c9565578bf881a4d290 Mon Sep 17 00:00:00 2001 From: Roberto <37729096+RobertoDF@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:56:38 +0200 Subject: [PATCH] Update get_channel_locations in sortinganalyzer.py --- src/spikeinterface/core/sortinganalyzer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 817c453a97..915884b7cb 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1100,9 +1100,16 @@ def get_probe(self): def get_channel_locations(self) -> np.ndarray: # important note : contrary to recording # this give all channel locations, so no kwargs like channel_ids and axes + channel_indices = self.channel_ids_to_indices(self.channel_ids) all_probes = self.get_probegroup().probes all_positions = np.vstack([probe.contact_positions for probe in all_probes]) - return all_positions + probes_channel_indices = np.concatenate([probe.device_channel_indices for probe in all_probes]) + + sorted_probe_idx = np.argsort(probes_channel_indices) + sorted_positions_idx = np.searchsorted(probes_channel_indices[sorted_probe_idx], channel_indices) + + positions = all_positions[sorted_probe_idx[sorted_positions_idx]] + return positions def channel_ids_to_indices(self, channel_ids) -> np.ndarray: all_channel_ids = list(self.rec_attributes["channel_ids"])