diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 9987edadc6..718f27f688 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -20,6 +20,10 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) channel_ids = parent_recording.get_channel_ids() if renamed_channel_ids is None: renamed_channel_ids = channel_ids + else: + assert len(renamed_channel_ids) == len( + np.unique(renamed_channel_ids) + ), "renamed_channel_ids must be unique!" self._parent_recording = parent_recording self._channel_ids = np.asarray(channel_ids) diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index 38916e5bf1..379575cbfa 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -18,6 +18,7 @@ def __init__(self, parent_sorting, unit_ids=None, renamed_unit_ids=None): unit_ids = parent_sorting.get_unit_ids() if renamed_unit_ids is None: renamed_unit_ids = unit_ids + assert len(renamed_unit_ids) == len(np.unique(renamed_unit_ids)), "renamed_unit_ids must be unique!" self._parent_sorting = parent_sorting self._unit_ids = np.asarray(unit_ids) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 2d691a456c..b0d470fe40 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -153,7 +153,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): for k in ["ms_before", "ms_after"]: waveforms_params[k] = params["general"][k] - if params["shared_memory"]: + if params["shared_memory"] and not params["debug"]: mode = "memory" waveforms_folder = None else: diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index cfdca6f612..f22c3e3399 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -494,8 +494,8 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): ---------- amplitude: tuple (Minimal, Maximal) amplitudes allowed for every template - omp_min_sps: float - Stopping criteria of the OMP algorithm, as relative error + max_failures: int + Stopping criteria of the OMP algorithm, as number of retry while updating amplitudes sparse_kwargs: dict Parameters to extract a sparsity mask from the waveform_extractor, if not already sparse. @@ -508,8 +508,11 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): """ _default_params = { - "amplitudes": [0.6, 1.4], - "omp_min_sps": 5e-5, + "amplitudes": [0.6, 2], + "stop_criteria": "max_failures", + "max_failures": 20, + "omp_min_sps": 0.1, + "relative_error": 5e-5, "waveform_extractor": None, "rank": 5, "sparse_kwargs": {"method": "ptp", "threshold": 1}, @@ -522,6 +525,8 @@ def _prepare_templates(cls, d): waveform_extractor = d["waveform_extractor"] num_templates = len(d["waveform_extractor"].sorting.unit_ids) + assert d["stop_criteria"] in ["max_failures", "omp_min_sps", "relative_error"] + if not waveform_extractor.is_sparse(): sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask else: @@ -598,11 +603,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d = cls._default_params.copy() d.update(kwargs) - # assert isinstance(d['waveform_extractor'], WaveformExtractor) - - for v in ["omp_min_sps"]: - assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" - d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() d["num_samples"] = d["waveform_extractor"].nsamples d["nbefore"] = d["waveform_extractor"].nbefore @@ -632,7 +632,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int) d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i])) - d["stop_criteria"] = d["omp_min_sps"] return d @classmethod @@ -666,7 +665,6 @@ def main_function(cls, traces, d): neighbor_window = num_samples - 1 min_amplitude, max_amplitude = d["amplitudes"] ignored_ids = d["ignored_ids"] - stop_criteria = d["stop_criteria"] vicinity = d["vicinity"] rank = d["rank"] @@ -709,13 +707,22 @@ def main_function(cls, traces, d): all_amplitudes = np.zeros(0, dtype=np.float32) is_in_vicinity = np.zeros(0, dtype=np.int32) - if len(ignored_ids) > 0: - new_error = np.linalg.norm(scalar_products[not_ignored]) - else: - new_error = np.linalg.norm(scalar_products) - delta_error = np.inf - while delta_error > stop_criteria: + if d["stop_criteria"] == "omp_min_sps": + stop_criteria = d["omp_min_sps"] * np.maximum(d["norms"], np.sqrt(num_channels * num_samples)) + elif d["stop_criteria"] == "max_failures": + nb_valids = 0 + nb_failures = d["max_failures"] + elif d["stop_criteria"] == "relative_error": + if len(ignored_ids) > 0: + new_error = np.linalg.norm(scalar_products[not_ignored]) + else: + new_error = np.linalg.norm(scalar_products) + delta_error = np.inf + + do_loop = True + + while do_loop: best_amplitude_ind = scalar_products.argmax() best_cluster_ind, peak_index = np.unravel_index(best_amplitude_ind, scalar_products.shape) @@ -812,12 +819,25 @@ def main_function(cls, traces, d): to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add - previous_error = new_error - if len(ignored_ids) > 0: - new_error = np.linalg.norm(scalar_products[not_ignored]) - else: - new_error = np.linalg.norm(scalar_products) - delta_error = np.abs(new_error / previous_error - 1) + # We stop when updates do not modify the chosen spikes anymore + if d["stop_criteria"] == "omp_min_sps": + is_valid = scalar_products > stop_criteria[:, np.newaxis] + do_loop = np.any(is_valid) + elif d["stop_criteria"] == "max_failures": + is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) + new_nb_valids = np.sum(is_valid) + if (new_nb_valids - nb_valids) == 0: + nb_failures -= 1 + nb_valids = new_nb_valids + do_loop = nb_failures > 0 + elif d["stop_criteria"] == "relative_error": + previous_error = new_error + if len(ignored_ids) > 0: + new_error = np.linalg.norm(scalar_products[not_ignored]) + else: + new_error = np.linalg.norm(scalar_products) + delta_error = np.abs(new_error / previous_error - 1) + do_loop = delta_error > d["relative_error"] is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) valid_indices = np.where(is_valid)