Commit

Merge branch 'main' into fix-read-only
h-mayorquin authored Nov 24, 2023
2 parents ef9ecfd + fcbf422 commit 7879aa5
Showing 4 changed files with 49 additions and 24 deletions.
4 changes: 4 additions & 0 deletions src/spikeinterface/core/channelslice.py
@@ -20,6 +20,10 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None):
         channel_ids = parent_recording.get_channel_ids()
         if renamed_channel_ids is None:
             renamed_channel_ids = channel_ids
+        else:
+            assert len(renamed_channel_ids) == len(
+                np.unique(renamed_channel_ids)
+            ), "renamed_channel_ids must be unique!"

         self._parent_recording = parent_recording
         self._channel_ids = np.asarray(channel_ids)
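
Note: with this guard, slicing a recording with duplicate renamed channel ids fails fast instead of silently producing an ambiguous id mapping. A minimal sketch of the new behavior (generate_recording is just a stand-in for any recording; channel_slice is the public entry point that builds a ChannelSliceRecording):

    import spikeinterface.core as sc

    recording = sc.generate_recording(num_channels=4)
    ids = recording.get_channel_ids()

    # Unique renamed ids are accepted as before
    sliced = recording.channel_slice(ids[:2], renamed_channel_ids=["a", "b"])

    # Duplicates now raise AssertionError: "renamed_channel_ids must be unique!"
    recording.channel_slice(ids[:2], renamed_channel_ids=["a", "a"])
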
1 change: 1 addition & 0 deletions src/spikeinterface/core/unitsselectionsorting.py
@@ -18,6 +18,7 @@ def __init__(self, parent_sorting, unit_ids=None, renamed_unit_ids=None):
         unit_ids = parent_sorting.get_unit_ids()
         if renamed_unit_ids is None:
             renamed_unit_ids = unit_ids
+        assert len(renamed_unit_ids) == len(np.unique(renamed_unit_ids)), "renamed_unit_ids must be unique!"

         self._parent_sorting = parent_sorting
         self._unit_ids = np.asarray(unit_ids)
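
Note: the same uniqueness check now applies to sortings. A short sketch, assuming a toy sorting (select_units is the public entry point for UnitsSelectionSorting):

    import spikeinterface.core as sc

    sorting = sc.generate_sorting(num_units=3)
    units = sorting.get_unit_ids()

    # Raises AssertionError: "renamed_unit_ids must be unique!"
    sorting.select_units(units[:2], renamed_unit_ids=["u", "u"])
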
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -153,7 +153,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         for k in ["ms_before", "ms_after"]:
             waveforms_params[k] = params["general"][k]

-        if params["shared_memory"]:
+        if params["shared_memory"] and not params["debug"]:
             mode = "memory"
             waveforms_folder = None
         else:
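
Note: with this fix, running SpyKING CIRCUS 2 in debug mode keeps waveforms on disk even when shared memory is requested, so intermediate results stay inspectable after the run. A hedged sketch (recording is assumed to be loaded earlier; the keyword names follow the sorter's params dict used above):

    from spikeinterface.sorters import run_sorter

    # debug=True now overrides shared_memory and forces folder mode
    sorting = run_sorter(
        "spykingcircus2",
        recording,
        output_folder="sc2_debug",
        shared_memory=True,
        debug=True,
    )
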
66 changes: 43 additions & 23 deletions src/spikeinterface/sortingcomponents/matching/circus.py
@@ -494,8 +494,8 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine):
     ----------
     amplitude: tuple
         (Minimal, Maximal) amplitudes allowed for every template
-    omp_min_sps: float
-        Stopping criteria of the OMP algorithm, as relative error
+    max_failures: int
+        Stopping criterion of the OMP algorithm, as the number of retries while updating amplitudes
     sparse_kwargs: dict
         Parameters to extract a sparsity mask from the waveform_extractor, if not
         already sparse.
@@ -508,8 +508,11 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine):
"""

_default_params = {
"amplitudes": [0.6, 1.4],
"omp_min_sps": 5e-5,
"amplitudes": [0.6, 2],
"stop_criteria": "max_failures",
"max_failures": 20,
"omp_min_sps": 0.1,
"relative_error": 5e-5,
"waveform_extractor": None,
"rank": 5,
"sparse_kwargs": {"method": "ptp", "threshold": 1},
@@ -522,6 +525,8 @@ def _prepare_templates(cls, d):
         waveform_extractor = d["waveform_extractor"]
         num_templates = len(d["waveform_extractor"].sorting.unit_ids)

+        assert d["stop_criteria"] in ["max_failures", "omp_min_sps", "relative_error"]
+
         if not waveform_extractor.is_sparse():
             sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask
         else:
@@ -598,11 +603,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
         d = cls._default_params.copy()
         d.update(kwargs)

-        # assert isinstance(d['waveform_extractor'], WaveformExtractor)
-
-        for v in ["omp_min_sps"]:
-            assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]"
-
         d["num_channels"] = d["waveform_extractor"].recording.get_num_channels()
         d["num_samples"] = d["waveform_extractor"].nsamples
         d["nbefore"] = d["waveform_extractor"].nbefore
@@ -632,7 +632,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int)
d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i]))

d["stop_criteria"] = d["omp_min_sps"]
return d

@classmethod
@@ -666,7 +665,6 @@ def main_function(cls, traces, d):
         neighbor_window = num_samples - 1
         min_amplitude, max_amplitude = d["amplitudes"]
         ignored_ids = d["ignored_ids"]
-        stop_criteria = d["stop_criteria"]
         vicinity = d["vicinity"]
         rank = d["rank"]

@@ -709,13 +707,22 @@

         all_amplitudes = np.zeros(0, dtype=np.float32)
         is_in_vicinity = np.zeros(0, dtype=np.int32)
-        if len(ignored_ids) > 0:
-            new_error = np.linalg.norm(scalar_products[not_ignored])
-        else:
-            new_error = np.linalg.norm(scalar_products)
-        delta_error = np.inf

-        while delta_error > stop_criteria:
+        if d["stop_criteria"] == "omp_min_sps":
+            stop_criteria = d["omp_min_sps"] * np.maximum(d["norms"], np.sqrt(num_channels * num_samples))
+        elif d["stop_criteria"] == "max_failures":
+            nb_valids = 0
+            nb_failures = d["max_failures"]
+        elif d["stop_criteria"] == "relative_error":
+            if len(ignored_ids) > 0:
+                new_error = np.linalg.norm(scalar_products[not_ignored])
+            else:
+                new_error = np.linalg.norm(scalar_products)
+            delta_error = np.inf
+
+        do_loop = True
+
+        while do_loop:
             best_amplitude_ind = scalar_products.argmax()
             best_cluster_ind, peak_index = np.unravel_index(best_amplitude_ind, scalar_products.shape)

@@ -812,12 +819,25 @@
                     to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]]
                     scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add

-            previous_error = new_error
-            if len(ignored_ids) > 0:
-                new_error = np.linalg.norm(scalar_products[not_ignored])
-            else:
-                new_error = np.linalg.norm(scalar_products)
-            delta_error = np.abs(new_error / previous_error - 1)
+            # We stop when updates do not modify the chosen spikes anymore
+            if d["stop_criteria"] == "omp_min_sps":
+                is_valid = scalar_products > stop_criteria[:, np.newaxis]
+                do_loop = np.any(is_valid)
+            elif d["stop_criteria"] == "max_failures":
+                is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude)
+                new_nb_valids = np.sum(is_valid)
+                if (new_nb_valids - nb_valids) == 0:
+                    nb_failures -= 1
+                nb_valids = new_nb_valids
+                do_loop = nb_failures > 0
+            elif d["stop_criteria"] == "relative_error":
+                previous_error = new_error
+                if len(ignored_ids) > 0:
+                    new_error = np.linalg.norm(scalar_products[not_ignored])
+                else:
+                    new_error = np.linalg.norm(scalar_products)
+                delta_error = np.abs(new_error / previous_error - 1)
+                do_loop = delta_error > d["relative_error"]

         is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude)
         valid_indices = np.where(is_valid)
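
Note: the peeler now supports three interchangeable stopping rules for the OMP loop: "max_failures" stops once max_failures iterations have passed without increasing the number of amplitude-valid spikes, "omp_min_sps" keeps iterating while any scalar product exceeds a per-template threshold derived from omp_min_sps and the template norms, and "relative_error" is the legacy rule that stops when the relative change of the residual norm falls below relative_error. A minimal sketch of selecting one, assuming a recording and a fitted WaveformExtractor we (parameter names taken from _default_params above):

    from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

    # Default: stop after 20 iterations that add no new amplitude-valid spike
    spikes = find_spikes_from_templates(
        recording,
        method="circus-omp-svd",
        method_kwargs={"waveform_extractor": we, "stop_criteria": "max_failures", "max_failures": 20},
    )

    # Legacy behavior: stop when the relative residual error plateaus below 5e-5
    spikes = find_spikes_from_templates(
        recording,
        method="circus-omp-svd",
        method_kwargs={"waveform_extractor": we, "stop_criteria": "relative_error", "relative_error": 5e-5},
    )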
