Skip to content

Commit

Permalink
Merge branch 'merging_units' of github.com:yger/spikeinterface into m…
Browse files Browse the repository at this point in the history
…erging_units
  • Loading branch information
yger committed Jun 26, 2024
2 parents 7b88ff6 + d520b28 commit 4338fe3
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 8 deletions.
4 changes: 0 additions & 4 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ def _rearange_waveforms(self, new_sorting_analyzer, units_to_merge, unit_ids, ke
for unit_id1, to_be_merged in zip(unit_ids, units_to_merge):

chan_inds_new = new_sorting_analyzer.sparsity.unit_id_to_channel_indices[unit_id1]
print(chan_inds_new.size)

for unit_id2 in to_be_merged:
unit_ind2 = self.sorting_analyzer.sorting.id_to_index(unit_id2)
Expand Down Expand Up @@ -328,7 +327,6 @@ def _merge_extension_data(

keep_unit_indices = self.sorting_analyzer.sorting.ids_to_indices(keep_unit_ids)
keep_spike_mask = np.isin(some_spikes["unit_index"], keep_unit_indices)
print(new_data['waveforms'].shape, np.where(keep_spike_mask)[-1], len(keep_spike_mask), np.sum(keep_spike_mask))
new_data["waveforms"][keep_spike_mask, :, :old_num_chans] = waveforms[keep_spike_mask]

# We only recompute waveforms for new units that might have a new sparsity mask. Could be
Expand Down Expand Up @@ -356,7 +354,6 @@ def _merge_extension_data(
new_waveforms = self._get_waveforms(
new_sorting_analyzer, new_unit_ids_large, verbose, **job_kwargs
)
print(new_waveforms.shape, np.where(updated_spike_mask)[-1], len(updated_spike_mask), np.sum(updated_spike_mask))
new_data["waveforms"][updated_spike_mask] = new_waveforms

## For the units with smaller masks, we need to rearrange the waveforms
Expand All @@ -367,7 +364,6 @@ def _merge_extension_data(
new_waveforms = self._rearange_waveforms(
new_sorting_analyzer, old_units_to_merge_small, new_unit_ids_small, kept_indices, verbose, **job_kwargs
)
print(new_waveforms.shape, np.where(updated_spike_mask)[-1], len(updated_spike_mask), np.sum(updated_spike_mask))
new_data["waveforms"][updated_spike_mask] = new_waveforms

else:
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/core/sorting_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def random_spikes_selection(
def get_ids_after_merging(sorting, units_to_merge, new_unit_ids):

assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge"

merged_unit_ids = set(sorting.unit_ids)
for count in range(len(units_to_merge)):
assert len(units_to_merge[count]) > 1, "A merge should have at least two units"
Expand Down Expand Up @@ -269,6 +269,7 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m
to_keep = np.ones(len(spikes), dtype=bool)

from spikeinterface.curation.curation_tools import get_new_unit_ids_for_merges

new_unit_ids = get_new_unit_ids_for_merges(sorting, units_to_merge, new_unit_ids)

all_unit_ids = get_ids_after_merging(sorting, units_to_merge, new_unit_ids)
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,7 @@ def merge_units(
assert merging_mode in ["soft", "hard"], "Merging mode should be either soft or hard"

from spikeinterface.curation.curation_tools import get_new_unit_ids_for_merges

new_unit_ids = get_new_unit_ids_for_merges(self.sorting, units_to_merge, new_unit_ids)

if not isinstance(units_to_merge[0], (list, tuple)):
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/curation/curation_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def find_duplicated_spikes(
else:
raise ValueError(f"Method '{method}' isn't a valid method for find_duplicated_spikes. Use one of {_methods}")


def get_new_unit_ids_for_merges(sorting, units_to_merge, new_unit_ids):

all_removed_ids = []
Expand Down Expand Up @@ -166,5 +167,5 @@ def get_new_unit_ids_for_merges(sorting, units_to_merge, new_unit_ids):
else:
# dtype int
new_unit_ids = list(max(sorting.unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
return new_unit_ids

return new_unit_ids
2 changes: 1 addition & 1 deletion src/spikeinterface/curation/mergeunitssorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy
keep_unit_ids = [u for u in parents_unit_ids if u not in all_removed_ids]

from .curation_tools import get_new_unit_ids_for_merges
new_unit_ids = get_new_unit_ids_for_merges(sorting, units_to_merge, new_unit_ids)

new_unit_ids = get_new_unit_ids_for_merges(sorting, units_to_merge, new_unit_ids)

assert len(new_unit_ids) == num_merge, "new_unit_ids must have the same size as units_to_merge"

Expand Down

0 comments on commit 4338fe3

Please sign in to comment.