Skip to content

Commit

Permalink
Fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Jun 26, 2024
1 parent f265bd1 commit 3d73d80
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 27 deletions.
29 changes: 4 additions & 25 deletions src/spikeinterface/core/sorting_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from .basesorting import BaseSorting
import numpy as np
from spikeinterface.core import NumpySorting
from spikeinterface.core import NumpySorting



def spike_vector_to_spike_trains(spike_vector: list[np.array], unit_ids: np.array) -> dict[dict[str, np.array]]:
Expand Down Expand Up @@ -264,12 +266,8 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m
"""

spikes = sorting.to_spike_vector().copy()

if censor_ms is None:
to_keep = None
else:
to_keep = np.ones(len(spikes), dtype=bool)

to_keep = np.ones(len(spikes), dtype=bool)

new_unit_ids = get_new_unit_ids_for_merges(sorting, units_to_merge, new_unit_ids)

all_unit_ids = get_ids_after_merging(sorting, units_to_merge, new_unit_ids)
Expand Down Expand Up @@ -298,28 +296,9 @@ def apply_merges_to_sorting(sorting, units_to_merge, new_unit_ids=None, censor_m
(indices,) = s0 + np.nonzero(mask[s0:s1])
to_keep[indices[1:]] = np.diff(spikes[indices]["sample_index"]) > rpv

from spikeinterface.core import NumpySorting

# times_list = []
# labels_list = []
# for segment_index in range(sorting.get_num_segments()):
# s0, s1 = segment_slices[segment_index]
# if censor_ms is not None:
# times_list += [spikes["sample_index"][s0:s1][to_keep[s0:s1]]]
# labels = spikes["unit_index"][s0:s1][to_keep[s0:s1]]
# labels_list += [labels]
# else:
# times_list += [spikes["sample_index"][s0:s1]]
# labels = spikes["unit_index"][s0:s1]
# labels_list += [labels]

# sorting = NumpySorting.from_times_labels(times_list, labels_list, sorting.sampling_frequency)
# sorting = sorting.rename_units(all_unit_ids)

combined_ids = np.array(list(sorting.unit_ids) + list(new_unit_ids))
sorting = NumpySorting(spikes[to_keep], unit_ids=combined_ids, sampling_frequency=sorting.sampling_frequency)
sorting = sorting.select_units(all_unit_ids)

return sorting, to_keep


Expand Down
3 changes: 1 addition & 2 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):
else:
folder = None
sorting_analyzer4 = sorting_analyzer.merge_units(
units_to_merge=[[0, 1]], new_unit_ids=[50], format=format, folder=folder
units_to_merge=[[0, 1]], new_unit_ids=[50], format=format, folder=folder, mode='hard'
)

# test compute with extension-specific params
Expand Down Expand Up @@ -279,7 +279,6 @@ class DummyAnalyzerExtension(AnalyzerExtension):

def _set_params(self, param0="yep", param1=1.2, param2=[1, 2, 3.0]):
params = dict(param0=param0, param1=param1, param2=param2)
params["more_option"] = "yep"
return params

def _run(self, **kwargs):
Expand Down

0 comments on commit 3d73d80

Please sign in to comment.