From 303211251210dc4093919eef2222f8e110e71950 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 1 Oct 2024 15:20:29 +0200 Subject: [PATCH 1/2] fix random_spikes_selection() --- src/spikeinterface/core/sorting_tools.py | 14 +++++++++----- .../core/tests/test_sorting_tools.py | 6 +++--- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 5f33350820..575c7f67e9 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -197,17 +197,21 @@ def random_spikes_selection( cum_sizes = np.cumsum([0] + [s.size for s in spikes]) # this fast when numba - spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids) + spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids, absolute_index=False) random_spikes_indices = [] for unit_index, unit_id in enumerate(sorting.unit_ids): all_unit_indices = [] for segment_index in range(sorting.get_num_segments()): - inds_in_seg = spike_indices[segment_index][unit_id] + cum_sizes[segment_index] + # this is local index + inds_in_seg = spike_indices[segment_index][unit_id] if margin_size is not None: - inds_in_seg = inds_in_seg[inds_in_seg >= margin_size] - inds_in_seg = inds_in_seg[inds_in_seg < (num_samples[segment_index] - margin_size)] - all_unit_indices.append(inds_in_seg) + local_spikes = spikes[segment_index][inds_in_seg] + mask = (local_spikes["sample_index"] >= margin_size) & (local_spikes["sample_index"] < (num_samples[segment_index] - margin_size)) + inds_in_seg = inds_in_seg[mask] + # go back to absolut index + inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index] + all_unit_indices.append(inds_in_seg_abs) all_unit_indices = np.concatenate(all_unit_indices) selected_unit_indices = rng.choice( all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 34bb3a221d..7d26773ac3 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -162,8 +162,8 @@ def test_generate_unit_ids_for_merge_group(): if __name__ == "__main__": # test_spike_vector_to_spike_trains() # test_spike_vector_to_indices() - # test_random_spikes_selection() + test_random_spikes_selection() - test_apply_merges_to_sorting() - test_get_ids_after_merging() + # test_apply_merges_to_sorting() + # test_get_ids_after_merging() # test_generate_unit_ids_for_merge_group() From 036691bb04ed079d5736a53808d4a7e8edb375da Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 1 Oct 2024 13:24:14 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sorting_tools.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 575c7f67e9..213968a80b 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -207,7 +207,9 @@ def random_spikes_selection( inds_in_seg = spike_indices[segment_index][unit_id] if margin_size is not None: local_spikes = spikes[segment_index][inds_in_seg] - mask = (local_spikes["sample_index"] >= margin_size) & (local_spikes["sample_index"] < (num_samples[segment_index] - margin_size)) + mask = (local_spikes["sample_index"] >= margin_size) & ( + local_spikes["sample_index"] < (num_samples[segment_index] - margin_size) + ) inds_in_seg = inds_in_seg[mask] # go back to absolut index inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index]