diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d468bd90ab..96e01a68c4 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -59,9 +59,6 @@ class ComputeSpikeLocations(AnalyzerExtension): def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) - extremum_channel_inds = get_template_extremum_channel(self.sorting_analyzer, outputs="index") - self.spikes = self.sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) - def _set_params( self, ms_before=0.5, @@ -89,8 +86,9 @@ def _set_params( def _select_extension_data(self, unit_ids): old_unit_ids = self.sorting_analyzer.unit_ids unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) + spikes = self.sorting_analyzer.sorting.to_spike_vector() - spike_mask = np.isin(self.spikes["unit_index"], unit_inds) + spike_mask = np.isin(spikes["unit_index"], unit_inds) new_spike_locations = self.data["spike_locations"][spike_mask] return dict(spike_locations=new_spike_locations)