diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index a7c01afab6..81eb86a8c4 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -434,9 +434,9 @@ def _select_extension_data(self, unit_ids): new_data[key] = arr[keep_unit_indices, :, :] return new_data - + def _merge_extension_data(self, merges, former_unit_ids): - + new_unit_ids = self.sorting_analyzer._get_ids_after_merging(merges) new_data = dict() for key, arr in self.data.items(): @@ -608,7 +608,7 @@ def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None): def _select_extension_data(self, unit_ids): # this do not depend on units return self.data - + def _merge_extension_data(self, merges): # this do not depend on units return self.data diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index e76a018318..d1629a6f57 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -210,6 +210,7 @@ def random_spikes_selection( return random_spikes_indices + def apply_merges_to_sorting(sorting, merges, censor_ms=None): """ Function to apply a resolved representation of the merges to a sorting object. If censor_ms is not None, @@ -245,6 +246,7 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=None): ) from spikeinterface.core import NumpySorting + times_list = [] labels_list = [] for segment_index in range(sorting.get_num_segments()): @@ -259,4 +261,4 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=None): labels_list += [sorting.unit_ids[labels]] sorting = NumpySorting.from_times_labels(times_list, labels_list, sorting.sampling_frequency) - return sorting \ No newline at end of file + return sorting diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 51745dfe47..78f69f8919 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -613,7 +613,9 @@ def _get_ids_after_merging(self, merges): new_unit_ids.discard(id) return list(new_unit_ids) - def _save_or_select_or_merge(self, format="binary_folder", folder=None, unit_ids=None, merges=None) -> "SortingAnalyzer": + def _save_or_select_or_merge( + self, format="binary_folder", folder=None, unit_ids=None, merges=None + ) -> "SortingAnalyzer": """ Internal used by both save_as(), copy() and select_units() which are more or less the same. """ @@ -655,6 +657,7 @@ def _save_or_select_or_merge(self, format="binary_folder", folder=None, unit_ids sorting_provenance = sorting_provenance.select_units(unit_ids) elif merges is not None: from spikeinterface.core.sorting_tools import apply_merges_to_sorting + sorting_provenance = apply_merges_to_sorting(sorting_provenance, merges) if format == "memory": diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 615d434dfa..d6be0d4028 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -113,7 +113,7 @@ def _select_extension_data(self, unit_ids): if self.params["handle_collisions"]: new_data["collision_mask"] = self.data["collision_mask"][keep_spike_mask] return new_data - + def _merge_extension_data(self, merges): # keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))