diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 817c453a97..915884b7cb 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1100,9 +1100,16 @@ def get_probe(self): def get_channel_locations(self) -> np.ndarray: # important note : contrary to recording # this give all channel locations, so no kwargs like channel_ids and axes + channel_indices = self.channel_ids_to_indices(self.channel_ids) all_probes = self.get_probegroup().probes all_positions = np.vstack([probe.contact_positions for probe in all_probes]) - return all_positions + probes_channel_indices = np.concatenate([probe.device_channel_indices for probe in all_probes]) + + sorted_probe_idx = np.argsort(probes_channel_indices) + sorted_positions_idx = np.searchsorted(probes_channel_indices[sorted_probe_idx], channel_indices) + + positions = all_positions[sorted_probe_idx[sorted_positions_idx]] + return positions def channel_ids_to_indices(self, channel_ids) -> np.ndarray: all_channel_ids = list(self.rec_attributes["channel_ids"])