Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update curation module to use sorting argument rather than parent_sorting #2922

Merged
merged 5 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/spikeinterface/curation/curationsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class CurationSorting:

Parameters
----------
parent_sorting : Recording
The recording object
sorting: BaseSorting
The sorting object
properties_policy : "keep" | "remove", default: "keep"
Policy used to propagate properties after split and merge operation. If "keep" the properties will be
passed to the new units (if the original units have the same value). If "remove" the new units will have
Expand All @@ -32,12 +32,13 @@ class CurationSorting:
Sorting object with the selected units merged
"""

def __init__(self, parent_sorting, make_graph=False, properties_policy="keep"):
def __init__(self, sorting, make_graph=False, properties_policy="keep"):

# to allow undo and redo, a list of sorting extractors is kept
self._sorting_stages = [parent_sorting]
self._sorting_stages = [sorting]
self._sorting_stages_i = 0
self._properties_policy = properties_policy
parent_units = parent_sorting.get_unit_ids()
parent_units = sorting.get_unit_ids()
self._make_graph = make_graph
if make_graph:
# to easily allow undo and redo, a list of graphs with the history of the curation is kept
Expand All @@ -52,7 +53,7 @@ def __init__(self, parent_sorting, make_graph=False, properties_policy="keep"):
else:
self.max_used_id = max(parent_units) if len(parent_units) > 0 else 0

self._kwargs = dict(parent_sorting=parent_sorting, make_graph=make_graph, properties_policy=properties_policy)
self._kwargs = dict(sorting=sorting, make_graph=make_graph, properties_policy=properties_policy)

def _get_unused_id(self, n=1):
# check units in the graph to the next unused unit id
Expand Down Expand Up @@ -121,7 +122,7 @@ def merge(self, units_to_merge, new_unit_id=None, delta_time_ms=0.4):
elif new_unit_id not in units_to_merge:
assert new_unit_id not in current_sorting.unit_ids, f"new_unit_id already exists!"
new_sorting = MergeUnitsSorting(
parent_sorting=current_sorting,
sorting=current_sorting,
units_to_merge=units_to_merge,
new_unit_ids=[new_unit_id],
delta_time_ms=delta_time_ms,
Expand Down
28 changes: 14 additions & 14 deletions src/spikeinterface/curation/mergeunitssorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class MergeUnitsSorting(BaseSorting):

Parameters
----------
parent_sorting : Recording
sorting: BaseSorting
The sorting object
units_to_merge : list/tuple of lists/tuples
A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
Expand All @@ -32,17 +32,17 @@ class MergeUnitsSorting(BaseSorting):
Sorting object with the selected units merged
"""

def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4):
self._parent_sorting = parent_sorting
def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4):
self._parent_sorting = sorting

if not isinstance(units_to_merge[0], (list, tuple)):
# keep backward compatibility : the previous behavior was only one merge
units_to_merge = [units_to_merge]

num_merge = len(units_to_merge)

parents_unit_ids = parent_sorting.unit_ids
sampling_frequency = parent_sorting.get_sampling_frequency()
parents_unit_ids = sorting.unit_ids
sampling_frequency = sorting.get_sampling_frequency()

all_removed_ids = []
for ids in units_to_merge:
Expand Down Expand Up @@ -93,25 +93,25 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties
sub_segment = MergeUnitsSortingSegment(parent_segment, units_to_merge, new_unit_ids, rm_dup_delta)
self.add_sorting_segment(sub_segment)

ann_keys = parent_sorting._annotations.keys()
self._annotations = deepcopy({k: parent_sorting._annotations[k] for k in ann_keys})
ann_keys = sorting._annotations.keys()
self._annotations = deepcopy({k: sorting._annotations[k] for k in ann_keys})

# copy properties for unchanged units, and check if unit properties are the same
keep_parent_inds = parent_sorting.ids_to_indices(keep_unit_ids)
keep_parent_inds = sorting.ids_to_indices(keep_unit_ids)
# ~ all_removed_inds = parent_sorting.ids_to_indices(all_removed_ids)
keep_inds = self.ids_to_indices(keep_unit_ids)
# ~ merge_inds = self.ids_to_indices(new_unit_ids)
prop_keys = parent_sorting.get_property_keys()
prop_keys = sorting.get_property_keys()
for key in prop_keys:
parent_values = parent_sorting.get_property(key)
parent_values = sorting.get_property(key)

if properties_policy == "keep":
# propagate keep values
shape = (len(unit_ids),) + parent_values.shape[1:]
new_values = np.empty(shape=shape, dtype=parent_values.dtype)
new_values[keep_inds] = parent_values[keep_parent_inds]
for new_id, ids in zip(new_unit_ids, units_to_merge):
removed_inds = parent_sorting.ids_to_indices(ids)
removed_inds = sorting.ids_to_indices(ids)
merge_values = parent_values[removed_inds]

same_property_values = np.all([np.array_equal(m, merge_values[0]) for m in merge_values[1:]])
Expand All @@ -133,13 +133,13 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties
elif properties_policy == "remove":
self.set_property(key, parent_values[keep_parent_inds], keep_unit_ids)

if parent_sorting.has_recording():
self.register_recording(parent_sorting._recording)
if sorting.has_recording():
self.register_recording(sorting._recording)

# make it jsonable
units_to_merge = [list(e) for e in units_to_merge]
self._kwargs = dict(
parent_sorting=parent_sorting,
sorting=sorting,
units_to_merge=units_to_merge,
new_unit_ids=new_unit_ids,
properties_policy=properties_policy,
Expand Down
32 changes: 16 additions & 16 deletions src/spikeinterface/curation/splitunitsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class SplitUnitSorting(BaseSorting):

Parameters
----------
parent_sorting : Recording
The recording object
sorting: BaseSorting
The sorting object
parent_unit_id : int
Unit id of the unit to split
indices_list : list or np.array
Expand All @@ -34,11 +34,11 @@ class SplitUnitSorting(BaseSorting):
Sorting object with the selected units split
"""

def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=None, properties_policy="keep"):
def __init__(self, sorting, split_unit_id, indices_list, new_unit_ids=None, properties_policy="keep"):
if type(indices_list) is not list:
indices_list = [indices_list]
parents_unit_ids = parent_sorting.unit_ids
assert parent_sorting.get_num_segments() == len(
parents_unit_ids = sorting.unit_ids
assert sorting.get_num_segments() == len(
indices_list
), "The length of indices_list must be the same as parent_sorting.get_num_segments"
split_unit_indices = np.unique([np.unique(v) for v in indices_list])
Expand Down Expand Up @@ -70,10 +70,10 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
np.isin(new_unit_ids, unchanged_units)
), "new_unit_ids should be new unit ids or no more than one unit id can be found in split_unit_id"

sampling_frequency = parent_sorting.get_sampling_frequency()
sampling_frequency = sorting.get_sampling_frequency()
units_ids = np.concatenate([unchanged_units, new_unit_ids])

self._parent_sorting = parent_sorting
self._parent_sorting = sorting

BaseSorting.__init__(self, sampling_frequency, units_ids)
assert all(
Expand All @@ -85,18 +85,18 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
self.add_sorting_segment(sub_segment)

# copy properties
ann_keys = parent_sorting._annotations.keys()
self._annotations = deepcopy({k: parent_sorting._annotations[k] for k in ann_keys})
ann_keys = sorting._annotations.keys()
self._annotations = deepcopy({k: sorting._annotations[k] for k in ann_keys})

# copy properties for unchanged units, and check unit properties
keep_parent_inds = parent_sorting.ids_to_indices(unchanged_units)
split_unit_id_ind = parent_sorting.id_to_index(split_unit_id)
keep_parent_inds = sorting.ids_to_indices(unchanged_units)
split_unit_id_ind = sorting.id_to_index(split_unit_id)
keep_units_inds = self.ids_to_indices(unchanged_units)
split_unit_ind = self.ids_to_indices(new_unit_ids)
# copy properties from original units to split ones
prop_keys = parent_sorting._properties.keys()
prop_keys = sorting._properties.keys()
for k in prop_keys:
values = parent_sorting._properties[k]
values = sorting._properties[k]
if properties_policy == "keep":
new_values = np.empty_like(values, shape=len(units_ids))
new_values[keep_units_inds] = values[keep_parent_inds]
Expand All @@ -105,11 +105,11 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
continue
self.set_property(k, values[keep_parent_inds], unchanged_units)

if parent_sorting.has_recording():
self.register_recording(parent_sorting._recording)
if sorting.has_recording():
self.register_recording(sorting._recording)

self._kwargs = dict(
parent_sorting=parent_sorting,
sorting=sorting,
split_unit_id=split_unit_id,
indices_list=indices_list,
new_unit_ids=new_unit_ids,
Expand Down