diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 7d3d890f7f..f1c2cb5a5b 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -156,7 +156,7 @@ def _get_waveforms(self, sorting_analyzer=None, unit_ids=None, verbose=False, ** recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - + some_spikes = sorting_analyzer.get_extension("random_spikes").get_random_spikes() if unit_ids is None: @@ -246,27 +246,29 @@ def _merge_extension_data(self, units_to_merge, new_unit_ids, new_sorting_analyz if new_sorting_analyzer.sparsity is not None: sparsity_mask = new_sorting_analyzer.sparsity.mask num_chans = int(max(np.sum(sparsity_mask, axis=1))) - old_num_chans = self.data['waveforms'].shape[2] + old_num_chans = self.data["waveforms"].shape[2] if num_chans == old_num_chans: - new_data['waveforms'] = self.data['waveforms'] + new_data["waveforms"] = self.data["waveforms"] else: - num_waveforms = len(self.data['waveforms']) - num_samples = self.data['waveforms'].shape[1] + num_waveforms = len(self.data["waveforms"]) + num_samples = self.data["waveforms"].shape[1] some_spikes = new_sorting_analyzer.get_extension("random_spikes").get_random_spikes() - new_data['waveforms'] = np.zeros((num_waveforms, num_samples, num_chans), dtype=self.data['waveforms'].dtype) + new_data["waveforms"] = np.zeros( + (num_waveforms, num_samples, num_chans), dtype=self.data["waveforms"].dtype + ) keep_unit_indices = np.flatnonzero(~np.isin(new_sorting_analyzer.unit_ids, new_unit_ids)) keep_spike_mask = np.isin(some_spikes["unit_index"], keep_unit_indices) - new_data['waveforms'][keep_spike_mask, :, :old_num_chans] = self.data['waveforms'][keep_spike_mask] + new_data["waveforms"][keep_spike_mask, :, :old_num_chans] = self.data["waveforms"][keep_spike_mask] - # We only recompute waveforms for new units that might have a new sparsity mask. Could be + # We only recompute waveforms for new units that might have a new sparsity mask. Could be # slightly optimized by checking exactly which merged units have a different mask updated_unit_indices = np.flatnonzero(np.isin(new_sorting_analyzer.unit_ids, new_unit_ids)) updated_spike_mask = np.isin(some_spikes["unit_index"], updated_unit_indices) new_waveforms = self._get_waveforms(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs) - new_data['waveforms'][updated_spike_mask] = new_waveforms + new_data["waveforms"][updated_spike_mask] = new_waveforms else: - new_data['waveforms'] = self.data['waveforms'] + new_data["waveforms"] = self.data["waveforms"] return new_data @@ -495,7 +497,6 @@ def _merge_extension_data(self, units_to_merge, new_unit_ids, new_sorting_analyz return new_data - def _get_data(self, operator="average", percentile=None, outputs="numpy"): if operator != "percentile": key = operator