From 8f067599974e00543d894662466ca2df27b88857 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 29 May 2024 13:27:54 +0100 Subject: [PATCH 1/3] parent_sorting-> sorting --- .../curation/curationsorting.py | 12 ++++---- .../curation/mergeunitssorting.py | 28 ++++++++--------- .../curation/splitunitsorting.py | 30 +++++++++---------- 3 files changed, 36 insertions(+), 34 deletions(-) diff --git a/src/spikeinterface/curation/curationsorting.py b/src/spikeinterface/curation/curationsorting.py index 1635a915fe..6011c9ccee 100644 --- a/src/spikeinterface/curation/curationsorting.py +++ b/src/spikeinterface/curation/curationsorting.py @@ -2,6 +2,7 @@ from collections import namedtuple from collections.abc import Iterable +from warnings import warn import numpy as np @@ -18,7 +19,7 @@ class CurationSorting: Parameters ---------- - parent_sorting: Recording + sorting: Recording The recording object properties_policy: "keep" | "remove", default: "keep" Policy used to propagate properties after split and merge operation. If "keep" the properties will be @@ -32,12 +33,13 @@ class CurationSorting: Sorting object with the selected units merged """ - def __init__(self, parent_sorting, make_graph=False, properties_policy="keep"): + def __init__(self, sorting, make_graph=False, properties_policy="keep"): + # to allow undo and redo a list of sortingextractors is keep - self._sorting_stages = [parent_sorting] + self._sorting_stages = [sorting] self._sorting_stages_i = 0 self._properties_policy = properties_policy - parent_units = parent_sorting.get_unit_ids() + parent_units = sorting.get_unit_ids() self._make_graph = make_graph if make_graph: # to easily allow undo and redo a list of graphs with the history of the curation is keep @@ -52,7 +54,7 @@ def __init__(self, parent_sorting, make_graph=False, properties_policy="keep"): else: self.max_used_id = max(parent_units) if len(parent_units) > 0 else 0 - self._kwargs = dict(parent_sorting=parent_sorting, make_graph=make_graph, properties_policy=properties_policy) + self._kwargs = dict(parent_sorting=sorting, make_graph=make_graph, properties_policy=properties_policy) def _get_unused_id(self, n=1): # check units in the graph to the next unused unit id diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index d32f3ef9b3..4921afa793 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -12,7 +12,7 @@ class MergeUnitsSorting(BaseSorting): Parameters ---------- - parent_sorting: Recording + sorting: Recording The sorting object units_to_merge: list/tuple of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), @@ -32,8 +32,8 @@ class MergeUnitsSorting(BaseSorting): Sorting object with the selected units merged """ - def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4): - self._parent_sorting = parent_sorting + def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4): + self._parent_sorting = sorting if not isinstance(units_to_merge[0], (list, tuple)): # keep backward compatibility : the previous behavior was only one merge @@ -41,8 +41,8 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties num_merge = len(units_to_merge) - parents_unit_ids = parent_sorting.unit_ids - sampling_frequency = parent_sorting.get_sampling_frequency() + parents_unit_ids = sorting.unit_ids + sampling_frequency = sorting.get_sampling_frequency() all_removed_ids = [] for ids in units_to_merge: @@ -93,17 +93,17 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties sub_segment = MergeUnitsSortingSegment(parent_segment, units_to_merge, new_unit_ids, rm_dup_delta) self.add_sorting_segment(sub_segment) - ann_keys = parent_sorting._annotations.keys() - self._annotations = deepcopy({k: parent_sorting._annotations[k] for k in ann_keys}) + ann_keys = sorting._annotations.keys() + self._annotations = deepcopy({k: sorting._annotations[k] for k in ann_keys}) # copy properties for unchanged units, and check if units propierties are the same - keep_parent_inds = parent_sorting.ids_to_indices(keep_unit_ids) + keep_parent_inds = sorting.ids_to_indices(keep_unit_ids) # ~ all_removed_inds = parent_sorting.ids_to_indices(all_removed_ids) keep_inds = self.ids_to_indices(keep_unit_ids) # ~ merge_inds = self.ids_to_indices(new_unit_ids) - prop_keys = parent_sorting.get_property_keys() + prop_keys = sorting.get_property_keys() for key in prop_keys: - parent_values = parent_sorting.get_property(key) + parent_values = sorting.get_property(key) if properties_policy == "keep": # propagate keep values @@ -111,7 +111,7 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties new_values = np.empty(shape=shape, dtype=parent_values.dtype) new_values[keep_inds] = parent_values[keep_parent_inds] for new_id, ids in zip(new_unit_ids, units_to_merge): - removed_inds = parent_sorting.ids_to_indices(ids) + removed_inds = sorting.ids_to_indices(ids) merge_values = parent_values[removed_inds] same_property_values = np.all([np.array_equal(m, merge_values[0]) for m in merge_values[1:]]) @@ -133,13 +133,13 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties elif properties_policy == "remove": self.set_property(key, parent_values[keep_parent_inds], keep_unit_ids) - if parent_sorting.has_recording(): - self.register_recording(parent_sorting._recording) + if sorting.has_recording(): + self.register_recording(sorting._recording) # make it jsonable units_to_merge = [list(e) for e in units_to_merge] self._kwargs = dict( - parent_sorting=parent_sorting, + parent_sorting=sorting, units_to_merge=units_to_merge, new_unit_ids=new_unit_ids, properties_policy=properties_policy, diff --git a/src/spikeinterface/curation/splitunitsorting.py b/src/spikeinterface/curation/splitunitsorting.py index 5854d1b64a..eaf9e736cb 100644 --- a/src/spikeinterface/curation/splitunitsorting.py +++ b/src/spikeinterface/curation/splitunitsorting.py @@ -13,7 +13,7 @@ class SplitUnitSorting(BaseSorting): Parameters ---------- - parent_sorting: Recording + sorting: Recording The recording object parent_unit_id: int Unit id of the unit to split @@ -34,11 +34,11 @@ class SplitUnitSorting(BaseSorting): Sorting object with the selected units split """ - def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=None, properties_policy="keep"): + def __init__(self, sorting, split_unit_id, indices_list, new_unit_ids=None, properties_policy="keep"): if type(indices_list) is not list: indices_list = [indices_list] - parents_unit_ids = parent_sorting.unit_ids - assert parent_sorting.get_num_segments() == len( + parents_unit_ids = sorting.unit_ids + assert sorting.get_num_segments() == len( indices_list ), "The length of indices_list must be the same as parent_sorting.get_num_segments" split_unit_indices = np.unique([np.unique(v) for v in indices_list]) @@ -70,10 +70,10 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non np.isin(new_unit_ids, unchanged_units) ), "new_unit_ids should be new unit ids or no more than one unit id can be found in split_unit_id" - sampling_frequency = parent_sorting.get_sampling_frequency() + sampling_frequency = sorting.get_sampling_frequency() units_ids = np.concatenate([unchanged_units, new_unit_ids]) - self._parent_sorting = parent_sorting + self._parent_sorting = sorting BaseSorting.__init__(self, sampling_frequency, units_ids) assert all( @@ -85,18 +85,18 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non self.add_sorting_segment(sub_segment) # copy properties - ann_keys = parent_sorting._annotations.keys() - self._annotations = deepcopy({k: parent_sorting._annotations[k] for k in ann_keys}) + ann_keys = sorting._annotations.keys() + self._annotations = deepcopy({k: sorting._annotations[k] for k in ann_keys}) # copy properties for unchanged units, and check if units propierties - keep_parent_inds = parent_sorting.ids_to_indices(unchanged_units) - split_unit_id_ind = parent_sorting.id_to_index(split_unit_id) + keep_parent_inds = sorting.ids_to_indices(unchanged_units) + split_unit_id_ind = sorting.id_to_index(split_unit_id) keep_units_inds = self.ids_to_indices(unchanged_units) split_unit_ind = self.ids_to_indices(new_unit_ids) # copy properties from original units to split ones - prop_keys = parent_sorting._properties.keys() + prop_keys = sorting._properties.keys() for k in prop_keys: - values = parent_sorting._properties[k] + values = sorting._properties[k] if properties_policy == "keep": new_values = np.empty_like(values, shape=len(units_ids)) new_values[keep_units_inds] = values[keep_parent_inds] @@ -105,11 +105,11 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non continue self.set_property(k, values[keep_parent_inds], unchanged_units) - if parent_sorting.has_recording(): - self.register_recording(parent_sorting._recording) + if sorting.has_recording(): + self.register_recording(sorting._recording) self._kwargs = dict( - parent_sorting=parent_sorting, + parent_sorting=sorting, split_unit_id=split_unit_id, indices_list=indices_list, new_unit_ids=new_unit_ids, From c64e39cef8433a284cec8478c3a0f564f507bb06 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 29 May 2024 13:58:52 +0100 Subject: [PATCH 2/3] another instance of parent_sorting --- src/spikeinterface/curation/curationsorting.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/curation/curationsorting.py b/src/spikeinterface/curation/curationsorting.py index 6011c9ccee..4d83998bde 100644 --- a/src/spikeinterface/curation/curationsorting.py +++ b/src/spikeinterface/curation/curationsorting.py @@ -2,7 +2,6 @@ from collections import namedtuple from collections.abc import Iterable -from warnings import warn import numpy as np @@ -123,7 +122,7 @@ def merge(self, units_to_merge, new_unit_id=None, delta_time_ms=0.4): elif new_unit_id not in units_to_merge: assert new_unit_id not in current_sorting.unit_ids, f"new_unit_id already exists!" new_sorting = MergeUnitsSorting( - parent_sorting=current_sorting, + sorting=current_sorting, units_to_merge=units_to_merge, new_unit_ids=[new_unit_id], delta_time_ms=delta_time_ms, From 6594baf0d26a7e63f0514456e60415f163851b91 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 31 May 2024 10:02:39 +0100 Subject: [PATCH 3/3] final parent_sorting -> sorting --- src/spikeinterface/curation/curationsorting.py | 2 +- src/spikeinterface/curation/mergeunitssorting.py | 2 +- src/spikeinterface/curation/splitunitsorting.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/curationsorting.py b/src/spikeinterface/curation/curationsorting.py index 4d83998bde..d7043c13cb 100644 --- a/src/spikeinterface/curation/curationsorting.py +++ b/src/spikeinterface/curation/curationsorting.py @@ -53,7 +53,7 @@ def __init__(self, sorting, make_graph=False, properties_policy="keep"): else: self.max_used_id = max(parent_units) if len(parent_units) > 0 else 0 - self._kwargs = dict(parent_sorting=sorting, make_graph=make_graph, properties_policy=properties_policy) + self._kwargs = dict(sorting=sorting, make_graph=make_graph, properties_policy=properties_policy) def _get_unused_id(self, n=1): # check units in the graph to the next unused unit id diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index 4921afa793..cc4fabc1b4 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -139,7 +139,7 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy # make it jsonable units_to_merge = [list(e) for e in units_to_merge] self._kwargs = dict( - parent_sorting=sorting, + sorting=sorting, units_to_merge=units_to_merge, new_unit_ids=new_unit_ids, properties_policy=properties_policy, diff --git a/src/spikeinterface/curation/splitunitsorting.py b/src/spikeinterface/curation/splitunitsorting.py index eaf9e736cb..7ecc6dece3 100644 --- a/src/spikeinterface/curation/splitunitsorting.py +++ b/src/spikeinterface/curation/splitunitsorting.py @@ -109,7 +109,7 @@ def __init__(self, sorting, split_unit_id, indices_list, new_unit_ids=None, prop self.register_recording(sorting._recording) self._kwargs = dict( - parent_sorting=sorting, + sorting=sorting, split_unit_id=split_unit_id, indices_list=indices_list, new_unit_ids=new_unit_ids,