Skip to content

Commit

Permalink
parent_sorting-> sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed May 29, 2024
1 parent 790715c commit 8f06759
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 34 deletions.
12 changes: 7 additions & 5 deletions src/spikeinterface/curation/curationsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections import namedtuple
from collections.abc import Iterable
from warnings import warn

import numpy as np

Expand All @@ -18,7 +19,7 @@ class CurationSorting:
Parameters
----------
parent_sorting: Recording
sorting: Recording
The recording object
properties_policy: "keep" | "remove", default: "keep"
Policy used to propagate properties after split and merge operation. If "keep" the properties will be
Expand All @@ -32,12 +33,13 @@ class CurationSorting:
Sorting object with the selected units merged
"""

def __init__(self, parent_sorting, make_graph=False, properties_policy="keep"):
def __init__(self, sorting, make_graph=False, properties_policy="keep"):

# to allow undo and redo a list of sortingextractors is keep
self._sorting_stages = [parent_sorting]
self._sorting_stages = [sorting]
self._sorting_stages_i = 0
self._properties_policy = properties_policy
parent_units = parent_sorting.get_unit_ids()
parent_units = sorting.get_unit_ids()
self._make_graph = make_graph
if make_graph:
# to easily allow undo and redo a list of graphs with the history of the curation is keep
Expand All @@ -52,7 +54,7 @@ def __init__(self, parent_sorting, make_graph=False, properties_policy="keep"):
else:
self.max_used_id = max(parent_units) if len(parent_units) > 0 else 0

self._kwargs = dict(parent_sorting=parent_sorting, make_graph=make_graph, properties_policy=properties_policy)
self._kwargs = dict(parent_sorting=sorting, make_graph=make_graph, properties_policy=properties_policy)

def _get_unused_id(self, n=1):
# check units in the graph to the next unused unit id
Expand Down
28 changes: 14 additions & 14 deletions src/spikeinterface/curation/mergeunitssorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class MergeUnitsSorting(BaseSorting):
Parameters
----------
parent_sorting: Recording
sorting: Recording
The sorting object
units_to_merge: list/tuple of lists/tuples
A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
Expand All @@ -32,17 +32,17 @@ class MergeUnitsSorting(BaseSorting):
Sorting object with the selected units merged
"""

def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4):
self._parent_sorting = parent_sorting
def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy="keep", delta_time_ms=0.4):
self._parent_sorting = sorting

if not isinstance(units_to_merge[0], (list, tuple)):
# keep backward compatibility : the previous behavior was only one merge
units_to_merge = [units_to_merge]

num_merge = len(units_to_merge)

parents_unit_ids = parent_sorting.unit_ids
sampling_frequency = parent_sorting.get_sampling_frequency()
parents_unit_ids = sorting.unit_ids
sampling_frequency = sorting.get_sampling_frequency()

all_removed_ids = []
for ids in units_to_merge:
Expand Down Expand Up @@ -93,25 +93,25 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties
sub_segment = MergeUnitsSortingSegment(parent_segment, units_to_merge, new_unit_ids, rm_dup_delta)
self.add_sorting_segment(sub_segment)

ann_keys = parent_sorting._annotations.keys()
self._annotations = deepcopy({k: parent_sorting._annotations[k] for k in ann_keys})
ann_keys = sorting._annotations.keys()
self._annotations = deepcopy({k: sorting._annotations[k] for k in ann_keys})

# copy properties for unchanged units, and check if units propierties are the same
keep_parent_inds = parent_sorting.ids_to_indices(keep_unit_ids)
keep_parent_inds = sorting.ids_to_indices(keep_unit_ids)
# ~ all_removed_inds = parent_sorting.ids_to_indices(all_removed_ids)
keep_inds = self.ids_to_indices(keep_unit_ids)
# ~ merge_inds = self.ids_to_indices(new_unit_ids)
prop_keys = parent_sorting.get_property_keys()
prop_keys = sorting.get_property_keys()
for key in prop_keys:
parent_values = parent_sorting.get_property(key)
parent_values = sorting.get_property(key)

if properties_policy == "keep":
# propagate keep values
shape = (len(unit_ids),) + parent_values.shape[1:]
new_values = np.empty(shape=shape, dtype=parent_values.dtype)
new_values[keep_inds] = parent_values[keep_parent_inds]
for new_id, ids in zip(new_unit_ids, units_to_merge):
removed_inds = parent_sorting.ids_to_indices(ids)
removed_inds = sorting.ids_to_indices(ids)
merge_values = parent_values[removed_inds]

same_property_values = np.all([np.array_equal(m, merge_values[0]) for m in merge_values[1:]])
Expand All @@ -133,13 +133,13 @@ def __init__(self, parent_sorting, units_to_merge, new_unit_ids=None, properties
elif properties_policy == "remove":
self.set_property(key, parent_values[keep_parent_inds], keep_unit_ids)

if parent_sorting.has_recording():
self.register_recording(parent_sorting._recording)
if sorting.has_recording():
self.register_recording(sorting._recording)

# make it jsonable
units_to_merge = [list(e) for e in units_to_merge]
self._kwargs = dict(
parent_sorting=parent_sorting,
parent_sorting=sorting,
units_to_merge=units_to_merge,
new_unit_ids=new_unit_ids,
properties_policy=properties_policy,
Expand Down
30 changes: 15 additions & 15 deletions src/spikeinterface/curation/splitunitsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class SplitUnitSorting(BaseSorting):
Parameters
----------
parent_sorting: Recording
sorting: Recording
The recording object
parent_unit_id: int
Unit id of the unit to split
Expand All @@ -34,11 +34,11 @@ class SplitUnitSorting(BaseSorting):
Sorting object with the selected units split
"""

def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=None, properties_policy="keep"):
def __init__(self, sorting, split_unit_id, indices_list, new_unit_ids=None, properties_policy="keep"):
if type(indices_list) is not list:
indices_list = [indices_list]
parents_unit_ids = parent_sorting.unit_ids
assert parent_sorting.get_num_segments() == len(
parents_unit_ids = sorting.unit_ids
assert sorting.get_num_segments() == len(
indices_list
), "The length of indices_list must be the same as parent_sorting.get_num_segments"
split_unit_indices = np.unique([np.unique(v) for v in indices_list])
Expand Down Expand Up @@ -70,10 +70,10 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
np.isin(new_unit_ids, unchanged_units)
), "new_unit_ids should be new unit ids or no more than one unit id can be found in split_unit_id"

sampling_frequency = parent_sorting.get_sampling_frequency()
sampling_frequency = sorting.get_sampling_frequency()
units_ids = np.concatenate([unchanged_units, new_unit_ids])

self._parent_sorting = parent_sorting
self._parent_sorting = sorting

BaseSorting.__init__(self, sampling_frequency, units_ids)
assert all(
Expand All @@ -85,18 +85,18 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
self.add_sorting_segment(sub_segment)

# copy properties
ann_keys = parent_sorting._annotations.keys()
self._annotations = deepcopy({k: parent_sorting._annotations[k] for k in ann_keys})
ann_keys = sorting._annotations.keys()
self._annotations = deepcopy({k: sorting._annotations[k] for k in ann_keys})

# copy properties for unchanged units, and check if units propierties
keep_parent_inds = parent_sorting.ids_to_indices(unchanged_units)
split_unit_id_ind = parent_sorting.id_to_index(split_unit_id)
keep_parent_inds = sorting.ids_to_indices(unchanged_units)
split_unit_id_ind = sorting.id_to_index(split_unit_id)
keep_units_inds = self.ids_to_indices(unchanged_units)
split_unit_ind = self.ids_to_indices(new_unit_ids)
# copy properties from original units to split ones
prop_keys = parent_sorting._properties.keys()
prop_keys = sorting._properties.keys()
for k in prop_keys:
values = parent_sorting._properties[k]
values = sorting._properties[k]
if properties_policy == "keep":
new_values = np.empty_like(values, shape=len(units_ids))
new_values[keep_units_inds] = values[keep_parent_inds]
Expand All @@ -105,11 +105,11 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
continue
self.set_property(k, values[keep_parent_inds], unchanged_units)

if parent_sorting.has_recording():
self.register_recording(parent_sorting._recording)
if sorting.has_recording():
self.register_recording(sorting._recording)

self._kwargs = dict(
parent_sorting=parent_sorting,
parent_sorting=sorting,
split_unit_id=split_unit_id,
indices_list=indices_list,
new_unit_ids=new_unit_ids,
Expand Down

0 comments on commit 8f06759

Please sign in to comment.