Skip to content

Commit

Permalink
Merge pull request #2207 from h-mayorquin/add_rename_units
Browse files Browse the repository at this point in the history
Add `rename_units` method in sorting
  • Loading branch information
alejoe91 authored Nov 15, 2023
2 parents 0aeed04 + 97f46c3 commit 0ed608f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 14 deletions.
52 changes: 38 additions & 14 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ def add_sorting_segment(self, sorting_segment):
self._sorting_segments.append(sorting_segment)
sorting_segment.set_parent_extractor(self)

def get_sampling_frequency(self):
def get_sampling_frequency(self) -> float:
return self._sampling_frequency

def get_num_segments(self):
def get_num_segments(self) -> int:
return len(self._sorting_segments)

def get_num_samples(self, segment_index=None):
def get_num_samples(self, segment_index=None) -> int:
"""Returns the number of samples of the associated recording for a segment.
Parameters
Expand All @@ -82,7 +82,7 @@ def get_num_samples(self, segment_index=None):
), "This methods requires an associated recording. Call self.register_recording() first."
return self._recording.get_num_samples(segment_index=segment_index)

def get_total_samples(self):
def get_total_samples(self) -> int:
"""Returns the total number of samples of the associated recording.
Returns
Expand Down Expand Up @@ -299,9 +299,11 @@ def count_num_spikes_per_unit(self) -> dict:

return num_spikes

def count_total_num_spikes(self):
def count_total_num_spikes(self) -> int:
"""
Get total number of spikes summed across segment and units.
Get total number of spikes in the sorting.
This is the sum of all spikes in all segments across all units.
Returns
-------
Expand All @@ -310,9 +312,10 @@ def count_total_num_spikes(self):
"""
return self.to_spike_vector().size

def select_units(self, unit_ids, renamed_unit_ids=None):
def select_units(self, unit_ids, renamed_unit_ids=None) -> BaseSorting:
"""
Selects a subset of units
Returns a new sorting object which contains only a selected subset of units.
Parameters
----------
Expand All @@ -331,9 +334,30 @@ def select_units(self, unit_ids, renamed_unit_ids=None):
sub_sorting = UnitsSelectionSorting(self, unit_ids, renamed_unit_ids=renamed_unit_ids)
return sub_sorting

def remove_units(self, remove_unit_ids):
def rename_units(self, new_unit_ids: np.ndarray | list) -> BaseSorting:
"""
Removes a subset of units
Returns a new sorting object with renamed units.
Parameters
----------
new_unit_ids : numpy.array or list
List of new names for unit ids.
They should map positionally to the existing unit ids.
Returns
-------
BaseSorting
Sorting object with renamed units
"""
from spikeinterface import UnitsSelectionSorting

sub_sorting = UnitsSelectionSorting(self, renamed_unit_ids=new_unit_ids)
return sub_sorting

def remove_units(self, remove_unit_ids) -> BaseSorting:
"""
Returns a new sorting object with contains only a selected subset of units.
Parameters
----------
Expand All @@ -343,7 +367,7 @@ def remove_units(self, remove_unit_ids):
Returns
-------
BaseSorting
Sorting object without removed units
Sorting without the removed units
"""
from spikeinterface import UnitsSelectionSorting

Expand All @@ -353,7 +377,8 @@ def remove_units(self, remove_unit_ids):

def remove_empty_units(self):
"""
Removes units with empty spike trains
Returns a new sorting object which contains only units with at least one spike.
Returns
-------
Expand Down Expand Up @@ -389,7 +414,7 @@ def get_all_spike_trains(self, outputs="unit_id"):
"""
Return all spike trains concatenated.
This is deprecated use sorting.to_spike_vector() instead
This is deprecated and will be removed in spikeinterface 0.102 use sorting.to_spike_vector() instead
"""

warnings.warn(
Expand Down Expand Up @@ -429,7 +454,6 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac
Construct a unique structured numpy vector concatenating all spikes
with several fields: sample_index, unit_index, segment_index.
See also `get_all_spike_trains()`
Parameters
----------
Expand Down
13 changes: 13 additions & 0 deletions src/spikeinterface/core/tests/test_basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from spikeinterface.core.base import BaseExtractor
from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal
from spikeinterface.core.generate import generate_sorting

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
Expand Down Expand Up @@ -169,6 +170,18 @@ def test_npy_sorting():
assert_raises(Exception, sorting.register_recording, rec)


def test_rename_units_method():
num_units = 2
durations = [1.0, 1.0]

sorting = generate_sorting(num_units=num_units, durations=durations)

new_unit_ids = ["a", "b"]
new_sorting = sorting.rename_units(new_unit_ids=new_unit_ids)

assert np.array_equal(new_sorting.get_unit_ids(), new_unit_ids)


def test_empty_sorting():
sorting = NumpySorting.from_unit_dict({}, 30000)

Expand Down

0 comments on commit 0ed608f

Please sign in to comment.