Skip to content

Commit

Permalink
Merge pull request #2922 from zm711/fix-parent-sorting
Browse files Browse the repository at this point in the history
Update curation module to use `sorting` argument rather than `parent_sorting`
  • Loading branch information
alejoe91 authored Jun 5, 2024
2 parents e09fe63 + 51e86d0 commit 3c58fca
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 37 deletions.
15 changes: 8 additions & 7 deletions src/spikeinterface/curation/curationsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class CurationSorting:
Parameters
----------
parent_sorting : Recording
The recording object
sorting: BaseSorting
The sorting object
properties_policy : "keep" | "remove", default: "keep"
Policy used to propagate properties after split and merge operation. If "keep" the properties will be
passed to the new units (if the original units have the same value). If "remove" the new units will have
Expand All @@ -32,12 +32,13 @@ class CurationSorting:
Sorting object with the selected units merged
"""

def __init__(self, parent_sorting, make_graph=False, properties_policy="keep"):
def __init__(self, sorting, make_graph=False, properties_policy="keep"):

# to allow undo and redo a list of sortingextractors is keep
self._sorting_stages = [parent_sorting]
self._sorting_stages = [sorting]
self._sorting_stages_i = 0
self._properties_policy = properties_policy
parent_units = parent_sorting.get_unit_ids()
parent_units = sorting.get_unit_ids()
self._make_graph = make_graph
if make_graph:
# to easily allow undo and redo a list of graphs with the history of the curation is keep
Expand All @@ -52,7 +53,7 @@ def __init__(self, parent_sorting, make_graph=False, properties_policy="keep"):
else:
self.max_used_id = max(parent_units) if len(parent_units) > 0 else 0

self._kwargs = dict(parent_sorting=parent_sorting, make_graph=make_graph, properties_policy=properties_policy)
self._kwargs = dict(sorting=sorting, make_graph=make_graph, properties_policy=properties_policy)

def _get_unused_id(self, n=1):
# check units in the graph to the next unused unit id
Expand Down Expand Up @@ -121,7 +122,7 @@ def merge(self, units_to_merge, new_unit_id=None, delta_time_ms=0.4):
elif new_unit_id not in units_to_merge:
assert new_unit_id not in current_sorting.unit_ids, f"new_unit_id already exists!"
new_sorting = MergeUnitsSorting(
parent_sorting=current_sorting,
sorting=current_sorting,
units_to_merge=units_to_merge,
new_unit_ids=[new_unit_id],
delta_time_ms=delta_time_ms,
Expand Down
28 changes: 14 additions & 14 deletions src/spikeinterface/curation/mergeunitssorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class MergeUnitsSorting(BaseSorting):
Parameters
----------
parent_sorting : Recording
sorting: BaseSorting
The sorting object
units_to_merge : list/tuple of lists/tuples
A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
Expand All @@ -32,17 +32,17 @@ class MergeUnitsSorting(BaseSorting):
Sorting object with the selected units merged
"""

def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4):
self._parent_sorting = parent_sorting
def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4):
self._parent_sorting = sorting

if not isinstance(units_to_merge[0], (list, tuple)):
# keep backward compatibility : the previous behavior was only one merge
units_to_merge = [units_to_merge]

num_merge = len(units_to_merge)

parents_unit_ids = parent_sorting.unit_ids
sampling_frequency = parent_sorting.get_sampling_frequency()
parents_unit_ids = sorting.unit_ids
sampling_frequency = sorting.get_sampling_frequency()

all_removed_ids = []
for ids in units_to_merge:
Expand Down Expand Up @@ -93,25 +93,25 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties
sub_segment = MergeUnitsSortingSegment(parent_segment, units_to_merge, new_unit_ids, rm_dup_delta)
self.add_sorting_segment(sub_segment)

ann_keys = parent_sorting._annotations.keys()
self._annotations = deepcopy({k: parent_sorting._annotations[k] for k in ann_keys})
ann_keys = sorting._annotations.keys()
self._annotations = deepcopy({k: sorting._annotations[k] for k in ann_keys})

# copy properties for unchanged units, and check if units propierties are the same
keep_parent_inds = parent_sorting.ids_to_indices(keep_unit_ids)
keep_parent_inds = sorting.ids_to_indices(keep_unit_ids)
# ~ all_removed_inds = parent_sorting.ids_to_indices(all_removed_ids)
keep_inds = self.ids_to_indices(keep_unit_ids)
# ~ merge_inds = self.ids_to_indices(new_unit_ids)
prop_keys = parent_sorting.get_property_keys()
prop_keys = sorting.get_property_keys()
for key in prop_keys:
parent_values = parent_sorting.get_property(key)
parent_values = sorting.get_property(key)

if properties_policy == "keep":
# propagate keep values
shape = (len(unit_ids),) + parent_values.shape[1:]
new_values = np.empty(shape=shape, dtype=parent_values.dtype)
new_values[keep_inds] = parent_values[keep_parent_inds]
for new_id, ids in zip(new_unit_ids, units_to_merge):
removed_inds = parent_sorting.ids_to_indices(ids)
removed_inds = sorting.ids_to_indices(ids)
merge_values = parent_values[removed_inds]

same_property_values = np.all([np.array_equal(m, merge_values[0]) for m in merge_values[1:]])
Expand All @@ -133,13 +133,13 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties
elif properties_policy == "remove":
self.set_property(key, parent_values[keep_parent_inds], keep_unit_ids)

if parent_sorting.has_recording():
self.register_recording(parent_sorting._recording)
if sorting.has_recording():
self.register_recording(sorting._recording)

# make it jsonable
units_to_merge = [list(e) for e in units_to_merge]
self._kwargs = dict(
parent_sorting=parent_sorting,
sorting=sorting,
units_to_merge=units_to_merge,
new_unit_ids=new_unit_ids,
properties_policy=properties_policy,
Expand Down
32 changes: 16 additions & 16 deletions src/spikeinterface/curation/splitunitsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class SplitUnitSorting(BaseSorting):
Parameters
----------
parent_sorting : Recording
The recording object
sorting: BaseSorting
The sorting object
parent_unit_id : int
Unit id of the unit to split
indices_list : list or np.array
Expand All @@ -34,11 +34,11 @@ class SplitUnitSorting(BaseSorting):
Sorting object with the selected units split
"""

def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=None, properties_policy="keep"):
def __init__(self, sorting, split_unit_id, indices_list, new_unit_ids=None, properties_policy="keep"):
if type(indices_list) is not list:
indices_list = [indices_list]
parents_unit_ids = parent_sorting.unit_ids
assert parent_sorting.get_num_segments() == len(
parents_unit_ids = sorting.unit_ids
assert sorting.get_num_segments() == len(
indices_list
), "The length of indices_list must be the same as parent_sorting.get_num_segments"
split_unit_indices = np.unique([np.unique(v) for v in indices_list])
Expand Down Expand Up @@ -70,10 +70,10 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
np.isin(new_unit_ids, unchanged_units)
), "new_unit_ids should be new unit ids or no more than one unit id can be found in split_unit_id"

sampling_frequency = parent_sorting.get_sampling_frequency()
sampling_frequency = sorting.get_sampling_frequency()
units_ids = np.concatenate([unchanged_units, new_unit_ids])

self._parent_sorting = parent_sorting
self._parent_sorting = sorting

BaseSorting.__init__(self, sampling_frequency, units_ids)
assert all(
Expand All @@ -85,18 +85,18 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
self.add_sorting_segment(sub_segment)

# copy properties
ann_keys = parent_sorting._annotations.keys()
self._annotations = deepcopy({k: parent_sorting._annotations[k] for k in ann_keys})
ann_keys = sorting._annotations.keys()
self._annotations = deepcopy({k: sorting._annotations[k] for k in ann_keys})

# copy properties for unchanged units, and check if units propierties
keep_parent_inds = parent_sorting.ids_to_indices(unchanged_units)
split_unit_id_ind = parent_sorting.id_to_index(split_unit_id)
keep_parent_inds = sorting.ids_to_indices(unchanged_units)
split_unit_id_ind = sorting.id_to_index(split_unit_id)
keep_units_inds = self.ids_to_indices(unchanged_units)
split_unit_ind = self.ids_to_indices(new_unit_ids)
# copy properties from original units to split ones
prop_keys = parent_sorting._properties.keys()
prop_keys = sorting._properties.keys()
for k in prop_keys:
values = parent_sorting._properties[k]
values = sorting._properties[k]
if properties_policy == "keep":
new_values = np.empty_like(values, shape=len(units_ids))
new_values[keep_units_inds] = values[keep_parent_inds]
Expand All @@ -105,11 +105,11 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
continue
self.set_property(k, values[keep_parent_inds], unchanged_units)

if parent_sorting.has_recording():
self.register_recording(parent_sorting._recording)
if sorting.has_recording():
self.register_recording(sorting._recording)

self._kwargs = dict(
parent_sorting=parent_sorting,
sorting=sorting,
split_unit_id=split_unit_id,
indices_list=indices_list,
new_unit_ids=new_unit_ids,
Expand Down

0 comments on commit 3c58fca

Please sign in to comment.