From 68fe2ba08f9c41b3feaf7866fee934291d78f7ea Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 26 Sep 2023 09:26:40 +0200 Subject: [PATCH 01/23] OMP with SVD decomposition --- .../sortingcomponents/matching/circus.py | 307 ++++++++++++++++++ .../sortingcomponents/matching/method_list.py | 3 +- 2 files changed, 309 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index a19e7b71b5..e86c913976 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -478,6 +478,313 @@ def main_function(cls, traces, d): return spikes +class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): + """ + Orthogonal Matching Pursuit inspired from Spyking Circus sorter + + https://elifesciences.org/articles/34518 + + This is an Orthogonal Template Matching algorithm. For speed and + memory optimization, templates are automatically sparsified. Signal + is convolved with the templates, and as long as some scalar products + are higher than a given threshold, we use a Cholesky decomposition + to compute the optimal amplitudes needed to reconstruct the signal. + + IMPORTANT NOTE: small chunks are more efficient for such Peeler, + consider using 100ms chunk + + Parameters + ---------- + amplitude: tuple + (Minimal, Maximal) amplitudes allowed for every template + omp_min_sps: float + Stopping criteria of the OMP algorithm, in percentage of the norm + noise_levels: array + The noise levels, for every channels. If None, they will be automatically + computed + random_chunk_kwargs: dict + Parameters for computing noise levels, if not provided (sub optimal) + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. + ----- + """ + + _default_params = { + "amplitudes": [0.6, 2], + "omp_min_sps": 0.1, + "waveform_extractor": None, + "templates": None, + "overlaps": None, + "norms": None, + "random_chunk_kwargs": {}, + "noise_levels": None, + "rank" : 3, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, + "ignored_ids": [], + "vicinity": 0, + } + + @classmethod + def _prepare_templates(cls, d): + waveform_extractor = d["waveform_extractor"] + num_templates = len(d["waveform_extractor"].sorting.unit_ids) + + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask + else: + sparsity = waveform_extractor.sparsity.mask + + templates = waveform_extractor.get_all_templates(mode="median").copy() + + temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) + + # Keep only the strongest components + rank = d['rank'] + d['templates'] = {} + d["norms"] = np.zeros(num_templates, dtype=np.float32) + d['sparsities'] = {} + d["norms"] = np.linalg.norm(templates, axis=(1, 2)) + for i in range(num_templates): + d['sparsities'][i] = np.arange(templates.shape[2]) + d['templates'][i] = templates[i] / d["norms"][i] + + temporal = temporal[:, :, :rank] + d["temporal"] = np.flip(temporal, axis=1) + d["singular"] = singular[:, :rank] + d["spatial"] = spatial[:, :rank, :] + + d['temporal'] /= d['norms'][:, np.newaxis, np.newaxis] + + d["spatial"] = np.moveaxis(d['spatial'][:, :rank, :], [0, 1, 2], [1, 0, 2]) + d['temporal'] = np.moveaxis(d['temporal'][:, :, :rank], [0, 1, 2], [1, 2, 0]) + d['singular'] = d['singular'].T[:, :, np.newaxis] + return d + + @classmethod + def initialize_and_check_kwargs(cls, recording, kwargs): + d = cls._default_params.copy() + d.update(kwargs) + + # assert isinstance(d['waveform_extractor'], WaveformExtractor) + + for v in ["omp_min_sps"]: + assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" + + d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() + d["num_samples"] = d["waveform_extractor"].nsamples + d["nbefore"] = d["waveform_extractor"].nbefore + d["nafter"] = d["waveform_extractor"].nafter + d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() + d["vicinity"] *= d["num_samples"] + + if d["noise_levels"] is None: + print("CircusOMPPeeler : noise should be computed outside") + d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) + + if d["templates"] is None: + d = cls._prepare_templates(d) + else: + for key in ["norms", "sparsities"]: + assert d[key] is not None, "If templates are provided, %d should also be there" % key + + d["num_templates"] = len(d["templates"]) + + if d["overlaps"] is None: + d["overlaps"] = compute_overlaps(d["templates"], d["num_samples"], d["num_channels"], d["sparsities"]) + + d["ignored_ids"] = np.array(d["ignored_ids"]) + + omp_min_sps = d["omp_min_sps"] + # nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) + d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) + + return d + + @classmethod + def serialize_method_kwargs(cls, kwargs): + kwargs = dict(kwargs) + # remove waveform_extractor + kwargs.pop("waveform_extractor") + return kwargs + + @classmethod + def unserialize_in_worker(cls, kwargs): + return kwargs + + @classmethod + def get_margin(cls, recording, kwargs): + margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) + return margin + + @classmethod + def main_function(cls, traces, d): + templates = d["templates"] + num_templates = d["num_templates"] + num_channels = d["num_channels"] + num_samples = d["num_samples"] + overlaps = d["overlaps"] + norms = d["norms"] + nbefore = d["nbefore"] + nafter = d["nafter"] + omp_tol = np.finfo(np.float32).eps + num_samples = d["nafter"] + d["nbefore"] + neighbor_window = num_samples - 1 + min_amplitude, max_amplitude = d["amplitudes"] + sparsities = d["sparsities"] + ignored_ids = d["ignored_ids"] + stop_criteria = d["stop_criteria"] + vicinity = d["vicinity"] + rank = d['rank'] + + num_timesteps = len(traces) + + num_peaks = num_timesteps - num_samples + 1 + conv_shape = (num_templates, num_peaks) + scalar_products = np.zeros(conv_shape, dtype=np.float32) + + # Filter using overlap-and-add convolution + spatially_filtered_data = np.matmul(d['spatial'], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d['singular'] + objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d['temporal'], axes=2, mode="valid") + scalar_products += np.sum(objective_by_rank, axis=0) + + if len(ignored_ids) > 0: + scalar_products[ignored_ids] = -np.inf + + num_spikes = 0 + + spikes = np.empty(scalar_products.size, dtype=spike_dtype) + idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) + + M = np.zeros((100, 100), dtype=np.float32) + + all_selections = np.empty((2, scalar_products.size), dtype=np.int32) + final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) + num_selection = 0 + + full_sps = scalar_products.copy() + + neighbors = {} + cached_overlaps = {} + + is_valid = scalar_products > stop_criteria + all_amplitudes = np.zeros(0, dtype=np.float32) + is_in_vicinity = np.zeros(0, dtype=np.int32) + + while np.any(is_valid): + best_amplitude_ind = scalar_products[is_valid].argmax() + best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) + + if num_selection > 0: + delta_t = selection[1] - peak_index + idx = np.where((delta_t < neighbor_window) & (delta_t > -num_samples))[0] + myline = num_samples + delta_t[idx] + + if not best_cluster_ind in cached_overlaps: + cached_overlaps[best_cluster_ind] = overlaps[best_cluster_ind].toarray() + + if num_selection == M.shape[0]: + Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) + Z[:num_selection, :num_selection] = M + M = Z + + M[num_selection, idx] = cached_overlaps[best_cluster_ind][selection[0, idx], myline] + + if vicinity == 0: + scipy.linalg.solve_triangular( + M[:num_selection, :num_selection], + M[num_selection, :num_selection], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, + ) + + v = nrm2(M[num_selection, :num_selection]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] + + if len(is_in_vicinity) > 0: + L = M[is_in_vicinity, :][:, is_in_vicinity] + + M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( + L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False + ) + + v = nrm2(M[num_selection, is_in_vicinity]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + M[num_selection, num_selection] = 1.0 + else: + M[0, 0] = 1 + + all_selections[:, num_selection] = [best_cluster_ind, peak_index] + num_selection += 1 + + selection = all_selections[:, :num_selection] + res_sps = full_sps[selection[0], selection[1]] + + if True: # vicinity == 0: + all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) + all_amplitudes /= norms[selection[0]] + else: + # This is not working, need to figure out why + is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) + all_amplitudes = np.append(all_amplitudes, np.float32(1)) + L = M[is_in_vicinity, :][:, is_in_vicinity] + all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) + all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] + + diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] + modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] + final_amplitudes[selection[0], selection[1]] = all_amplitudes + + for i in modified: + tmp_best, tmp_peak = selection[:, i] + diff_amp = diff_amplitudes[i] * norms[tmp_best] + + if not tmp_best in cached_overlaps: + cached_overlaps[tmp_best] = overlaps[tmp_best].toarray() + + if not tmp_peak in neighbors.keys(): + idx = [max(0, tmp_peak - num_samples), min(num_peaks, tmp_peak + neighbor_window)] + tdx = [num_samples + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak] + neighbors[tmp_peak] = {"idx": idx, "tdx": tdx} + + idx = neighbors[tmp_peak]["idx"] + tdx = neighbors[tmp_peak]["tdx"] + + to_add = diff_amp * cached_overlaps[tmp_best][:, tdx[0] : tdx[1]] + scalar_products[:, idx[0] : idx[1]] -= to_add + + is_valid = scalar_products > stop_criteria + + is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) + valid_indices = np.where(is_valid) + + num_spikes = len(valid_indices[0]) + spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"] + spikes["channel_index"][:num_spikes] = 0 + spikes["cluster_index"][:num_spikes] = valid_indices[0] + spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] + + spikes = spikes[:num_spikes] + order = np.argsort(spikes["sample_index"]) + spikes = spikes[order] + + return spikes + + + + class CircusPeeler(BaseTemplateMatchingEngine): """ diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index bedc04a9d5..46c4a53872 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -1,6 +1,6 @@ from .naive import NaiveMatching from .tdc import TridesclousPeeler -from .circus import CircusPeeler, CircusOMPPeeler +from .circus import CircusPeeler, CircusOMPPeeler, CircusOMPSVDPeeler from .wobble import WobbleMatch matching_methods = { @@ -8,5 +8,6 @@ "tridesclous": TridesclousPeeler, "circus": CircusPeeler, "circus-omp": CircusOMPPeeler, + 'circus-omp-svd' : CircusOMPSVDPeeler, "wobble": WobbleMatch, } From cc4720460127960d5d8cf16248690b3323c6c4a9 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 26 Sep 2023 10:49:57 +0200 Subject: [PATCH 02/23] Increase default rank --- .../sortingcomponents/matching/circus.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index e86c913976..bc378fb9a2 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -519,7 +519,7 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): "norms": None, "random_chunk_kwargs": {}, "noise_levels": None, - "rank" : 3, + "rank" : 10, "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], "vicinity": 0, @@ -537,17 +537,20 @@ def _prepare_templates(cls, d): templates = waveform_extractor.get_all_templates(mode="median").copy() - temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) - # Keep only the strongest components rank = d['rank'] d['templates'] = {} d["norms"] = np.zeros(num_templates, dtype=np.float32) d['sparsities'] = {} - d["norms"] = np.linalg.norm(templates, axis=(1, 2)) - for i in range(num_templates): - d['sparsities'][i] = np.arange(templates.shape[2]) - d['templates'][i] = templates[i] / d["norms"][i] + + for count in range(num_templates): + template = templates[count][:, sparsity[count]] + (d["sparsities"][count],) = np.nonzero(sparsity[count]) + d["norms"][count] = np.linalg.norm(template) + templates[count][:, ~sparsity[count]] = 0 + d["templates"][count] = template / d["norms"][count] + + temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) temporal = temporal[:, :, :rank] d["temporal"] = np.flip(temporal, axis=1) @@ -631,7 +634,6 @@ def main_function(cls, traces, d): num_samples = d["nafter"] + d["nbefore"] neighbor_window = num_samples - 1 min_amplitude, max_amplitude = d["amplitudes"] - sparsities = d["sparsities"] ignored_ids = d["ignored_ids"] stop_criteria = d["stop_criteria"] vicinity = d["vicinity"] From 10c33c1c8645aa7e144bdb8efbc06b993c79c4b0 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 26 Sep 2023 12:01:10 +0200 Subject: [PATCH 03/23] To be tried --- src/spikeinterface/sortingcomponents/matching/circus.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index bc378fb9a2..8c002a5cc7 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -601,6 +601,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): omp_min_sps = d["omp_min_sps"] # nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) + #d['stop_criteria'] = omp_min_sps * np.maximum(d['norms'], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) return d @@ -635,7 +636,7 @@ def main_function(cls, traces, d): neighbor_window = num_samples - 1 min_amplitude, max_amplitude = d["amplitudes"] ignored_ids = d["ignored_ids"] - stop_criteria = d["stop_criteria"] + stop_criteria = d["stop_criteria"]#[:, np.newaxis] vicinity = d["vicinity"] rank = d['rank'] From b2a9b70abeb1fccbfa73e51f604253c0f02c81c0 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 26 Sep 2023 16:33:17 +0200 Subject: [PATCH 04/23] WIP --- src/spikeinterface/sortingcomponents/matching/circus.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 8c002a5cc7..482d36956f 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -519,7 +519,7 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): "norms": None, "random_chunk_kwargs": {}, "noise_levels": None, - "rank" : 10, + "rank" : 5, "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], "vicinity": 0, @@ -599,9 +599,8 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["ignored_ids"] = np.array(d["ignored_ids"]) omp_min_sps = d["omp_min_sps"] - # nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) - d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) - #d['stop_criteria'] = omp_min_sps * np.maximum(d['norms'], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) + #d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) + d['stop_criteria'] = omp_min_sps * np.maximum(d['norms'], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) return d @@ -636,7 +635,7 @@ def main_function(cls, traces, d): neighbor_window = num_samples - 1 min_amplitude, max_amplitude = d["amplitudes"] ignored_ids = d["ignored_ids"] - stop_criteria = d["stop_criteria"]#[:, np.newaxis] + stop_criteria = d["stop_criteria"][:, np.newaxis] vicinity = d["vicinity"] rank = d['rank'] From 3c94594fdd5ee6a58c2635a2f9a8dba9c8ce500d Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 26 Sep 2023 17:01:51 +0200 Subject: [PATCH 05/23] Working with circus2 --- .../sorters/internal/spyking_circus2.py | 2 +- .../clustering/clustering_tools.py | 7 ++-- .../sortingcomponents/matching/circus.py | 37 ++++++++++--------- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index db3d88f116..7097b9e56b 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -152,7 +152,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_job_params["chunk_duration"] = "100ms" spikes = find_spikes_from_templates( - recording_f, method="circus-omp", method_kwargs=matching_params, **matching_job_params + recording_f, method="circus-omp-svd", method_kwargs=matching_params, **matching_job_params ) if verbose: diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index b87bbc7cee..99836fa293 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -602,8 +602,6 @@ def remove_duplicates_via_matching( "noise_levels": noise_levels, "amplitudes": [0.95, 1.05], "omp_min_sps": 0.1, - "templates": None, - "overlaps": None, } ) @@ -618,7 +616,7 @@ def remove_duplicates_via_matching( method_kwargs.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( - sub_recording, method="circus-omp", method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs + sub_recording, method="circus-omp-svd", method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs ) method_kwargs.update( { @@ -626,6 +624,9 @@ def remove_duplicates_via_matching( "templates": computed["templates"], "norms": computed["norms"], "sparsities": computed["sparsities"], + "temporal" : computed["temporal"], + "spatial" : computed["spatial"], + "singular" : computed["singular"], } ) valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 482d36956f..e955687ed7 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -514,9 +514,6 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): "amplitudes": [0.6, 2], "omp_min_sps": 0.1, "waveform_extractor": None, - "templates": None, - "overlaps": None, - "norms": None, "random_chunk_kwargs": {}, "noise_levels": None, "rank" : 5, @@ -537,28 +534,34 @@ def _prepare_templates(cls, d): templates = waveform_extractor.get_all_templates(mode="median").copy() - # Keep only the strongest components - rank = d['rank'] - d['templates'] = {} - d["norms"] = np.zeros(num_templates, dtype=np.float32) + #First, we set masked channels to 0 d['sparsities'] = {} - for count in range(num_templates): template = templates[count][:, sparsity[count]] (d["sparsities"][count],) = np.nonzero(sparsity[count]) - d["norms"][count] = np.linalg.norm(template) templates[count][:, ~sparsity[count]] = 0 - d["templates"][count] = template / d["norms"][count] + # Then we keep only the strongest components + rank = d['rank'] temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) - - temporal = temporal[:, :, :rank] - d["temporal"] = np.flip(temporal, axis=1) + d["temporal"] = temporal[:, :, :rank] d["singular"] = singular[:, :rank] d["spatial"] = spatial[:, :rank, :] - d['temporal'] /= d['norms'][:, np.newaxis, np.newaxis] + # We reconstruct the approximated templates + templates = np.matmul(d["temporal"] * d["singular"][:, np.newaxis, :], d["spatial"]) + + d["temporal"] = np.flip(temporal, axis=1) + d['templates'] = {} + d["norms"] = np.zeros(num_templates, dtype=np.float32) + + # And get the norms, saving compressed templates for CC matrix + for count in range(num_templates): + template = templates[count][:, sparsity[count]] + d["norms"][count] = np.linalg.norm(template) + d["templates"][count] = template / d["norms"][count] + d['temporal'] /= d['norms'][:, np.newaxis, np.newaxis] d["spatial"] = np.moveaxis(d['spatial'][:, :rank, :], [0, 1, 2], [1, 0, 2]) d['temporal'] = np.moveaxis(d['temporal'][:, :, :rank], [0, 1, 2], [1, 2, 0]) d['singular'] = d['singular'].T[:, :, np.newaxis] @@ -585,15 +588,15 @@ def initialize_and_check_kwargs(cls, recording, kwargs): print("CircusOMPPeeler : noise should be computed outside") d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - if d["templates"] is None: + if "templates" not in d: d = cls._prepare_templates(d) else: - for key in ["norms", "sparsities"]: + for key in ["norms", "sparsities", 'temporal', 'spatial', 'singular']: assert d[key] is not None, "If templates are provided, %d should also be there" % key d["num_templates"] = len(d["templates"]) - if d["overlaps"] is None: + if "overlaps" not in d: d["overlaps"] = compute_overlaps(d["templates"], d["num_samples"], d["num_channels"], d["sparsities"]) d["ignored_ids"] = np.array(d["ignored_ids"]) From 46149ef0730a8965f2ae612e9672419a18dc674c Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 26 Sep 2023 22:29:35 +0200 Subject: [PATCH 06/23] Put OMP with SVD as default --- .../sorters/internal/spyking_circus2.py | 2 +- .../clustering/clustering_tools.py | 2 +- .../sortingcomponents/matching/circus.py | 315 ------------------ .../sortingcomponents/matching/method_list.py | 1 - 4 files changed, 2 insertions(+), 318 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 7097b9e56b..db3d88f116 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -152,7 +152,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_job_params["chunk_duration"] = "100ms" spikes = find_spikes_from_templates( - recording_f, method="circus-omp-svd", method_kwargs=matching_params, **matching_job_params + recording_f, method="circus-omp", method_kwargs=matching_params, **matching_job_params ) if verbose: diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 99836fa293..7a2af09942 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -616,7 +616,7 @@ def remove_duplicates_via_matching( method_kwargs.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( - sub_recording, method="circus-omp-svd", method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs + sub_recording, method="circus-omp", method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs ) method_kwargs.update( { diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index e955687ed7..aeac69fc86 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -195,321 +195,6 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): ----- """ - _default_params = { - "amplitudes": [0.6, 2], - "omp_min_sps": 0.1, - "waveform_extractor": None, - "templates": None, - "overlaps": None, - "norms": None, - "random_chunk_kwargs": {}, - "noise_levels": None, - "sparse_kwargs": {"method": "ptp", "threshold": 1}, - "ignored_ids": [], - "vicinity": 0, - } - - @classmethod - def _prepare_templates(cls, d): - waveform_extractor = d["waveform_extractor"] - num_templates = len(d["waveform_extractor"].sorting.unit_ids) - - if not waveform_extractor.is_sparse(): - sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask - else: - sparsity = waveform_extractor.sparsity.mask - - templates = waveform_extractor.get_all_templates(mode="median").copy() - - d["sparsities"] = {} - d["templates"] = {} - d["norms"] = np.zeros(num_templates, dtype=np.float32) - - for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): - template = templates[count][:, sparsity[count]] - (d["sparsities"][count],) = np.nonzero(sparsity[count]) - d["norms"][count] = np.linalg.norm(template) - d["templates"][count] = template / d["norms"][count] - - return d - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - d = cls._default_params.copy() - d.update(kwargs) - - # assert isinstance(d['waveform_extractor'], WaveformExtractor) - - for v in ["omp_min_sps"]: - assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" - - d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() - d["num_samples"] = d["waveform_extractor"].nsamples - d["nbefore"] = d["waveform_extractor"].nbefore - d["nafter"] = d["waveform_extractor"].nafter - d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() - d["vicinity"] *= d["num_samples"] - - if d["noise_levels"] is None: - print("CircusOMPPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - - if d["templates"] is None: - d = cls._prepare_templates(d) - else: - for key in ["norms", "sparsities"]: - assert d[key] is not None, "If templates are provided, %d should also be there" % key - - d["num_templates"] = len(d["templates"]) - - if d["overlaps"] is None: - d["overlaps"] = compute_overlaps(d["templates"], d["num_samples"], d["num_channels"], d["sparsities"]) - - d["ignored_ids"] = np.array(d["ignored_ids"]) - - omp_min_sps = d["omp_min_sps"] - # nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) - d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) - - return d - - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - # remove waveform_extractor - kwargs.pop("waveform_extractor") - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs - - @classmethod - def get_margin(cls, recording, kwargs): - margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) - return margin - - @classmethod - def main_function(cls, traces, d): - templates = d["templates"] - num_templates = d["num_templates"] - num_channels = d["num_channels"] - num_samples = d["num_samples"] - overlaps = d["overlaps"] - norms = d["norms"] - nbefore = d["nbefore"] - nafter = d["nafter"] - omp_tol = np.finfo(np.float32).eps - num_samples = d["nafter"] + d["nbefore"] - neighbor_window = num_samples - 1 - min_amplitude, max_amplitude = d["amplitudes"] - sparsities = d["sparsities"] - ignored_ids = d["ignored_ids"] - stop_criteria = d["stop_criteria"] - vicinity = d["vicinity"] - - if "cached_fft_kernels" not in d: - d["cached_fft_kernels"] = {"fshape": 0} - - cached_fft_kernels = d["cached_fft_kernels"] - - num_timesteps = len(traces) - - num_peaks = num_timesteps - num_samples + 1 - - traces = traces.T - - dummy_filter = np.empty((num_channels, num_samples), dtype=np.float32) - dummy_traces = np.empty((num_channels, num_timesteps), dtype=np.float32) - - fshape, axes = get_scipy_shape(dummy_filter, traces, axes=1) - fft_cache = {"full": sp_fft.rfftn(traces, fshape, axes=axes)} - - scalar_products = np.empty((num_templates, num_peaks), dtype=np.float32) - - flagged_chunk = cached_fft_kernels["fshape"] != fshape[0] - - for i in range(num_templates): - if i not in ignored_ids: - if i not in cached_fft_kernels or flagged_chunk: - kernel_filter = np.ascontiguousarray(templates[i][::-1].T) - cached_fft_kernels.update({i: sp_fft.rfftn(kernel_filter, fshape, axes=axes)}) - cached_fft_kernels["fshape"] = fshape[0] - - fft_cache.update({"mask": sparsities[i], "template": cached_fft_kernels[i]}) - - convolution = fftconvolve_with_cache(dummy_filter, dummy_traces, fft_cache, axes=1, mode="valid") - if len(convolution) > 0: - scalar_products[i] = convolution.sum(0) - else: - scalar_products[i] = 0 - - if len(ignored_ids) > 0: - scalar_products[ignored_ids] = -np.inf - - num_spikes = 0 - - spikes = np.empty(scalar_products.size, dtype=spike_dtype) - idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) - - M = np.zeros((100, 100), dtype=np.float32) - - all_selections = np.empty((2, scalar_products.size), dtype=np.int32) - final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) - num_selection = 0 - - full_sps = scalar_products.copy() - - neighbors = {} - cached_overlaps = {} - - is_valid = scalar_products > stop_criteria - all_amplitudes = np.zeros(0, dtype=np.float32) - is_in_vicinity = np.zeros(0, dtype=np.int32) - - while np.any(is_valid): - best_amplitude_ind = scalar_products[is_valid].argmax() - best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) - - if num_selection > 0: - delta_t = selection[1] - peak_index - idx = np.where((delta_t < neighbor_window) & (delta_t > -num_samples))[0] - myline = num_samples + delta_t[idx] - - if not best_cluster_ind in cached_overlaps: - cached_overlaps[best_cluster_ind] = overlaps[best_cluster_ind].toarray() - - if num_selection == M.shape[0]: - Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) - Z[:num_selection, :num_selection] = M - M = Z - - M[num_selection, idx] = cached_overlaps[best_cluster_ind][selection[0, idx], myline] - - if vicinity == 0: - scipy.linalg.solve_triangular( - M[:num_selection, :num_selection], - M[num_selection, :num_selection], - trans=0, - lower=1, - overwrite_b=True, - check_finite=False, - ) - - v = nrm2(M[num_selection, :num_selection]) ** 2 - Lkk = 1 - v - if Lkk <= omp_tol: # selected atoms are dependent - break - M[num_selection, num_selection] = np.sqrt(Lkk) - else: - is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] - - if len(is_in_vicinity) > 0: - L = M[is_in_vicinity, :][:, is_in_vicinity] - - M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( - L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False - ) - - v = nrm2(M[num_selection, is_in_vicinity]) ** 2 - Lkk = 1 - v - if Lkk <= omp_tol: # selected atoms are dependent - break - M[num_selection, num_selection] = np.sqrt(Lkk) - else: - M[num_selection, num_selection] = 1.0 - else: - M[0, 0] = 1 - - all_selections[:, num_selection] = [best_cluster_ind, peak_index] - num_selection += 1 - - selection = all_selections[:, :num_selection] - res_sps = full_sps[selection[0], selection[1]] - - if True: # vicinity == 0: - all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) - all_amplitudes /= norms[selection[0]] - else: - # This is not working, need to figure out why - is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) - all_amplitudes = np.append(all_amplitudes, np.float32(1)) - L = M[is_in_vicinity, :][:, is_in_vicinity] - all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) - all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] - - diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] - modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] - final_amplitudes[selection[0], selection[1]] = all_amplitudes - - for i in modified: - tmp_best, tmp_peak = selection[:, i] - diff_amp = diff_amplitudes[i] * norms[tmp_best] - - if not tmp_best in cached_overlaps: - cached_overlaps[tmp_best] = overlaps[tmp_best].toarray() - - if not tmp_peak in neighbors.keys(): - idx = [max(0, tmp_peak - num_samples), min(num_peaks, tmp_peak + neighbor_window)] - tdx = [num_samples + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak] - neighbors[tmp_peak] = {"idx": idx, "tdx": tdx} - - idx = neighbors[tmp_peak]["idx"] - tdx = neighbors[tmp_peak]["tdx"] - - to_add = diff_amp * cached_overlaps[tmp_best][:, tdx[0] : tdx[1]] - scalar_products[:, idx[0] : idx[1]] -= to_add - - is_valid = scalar_products > stop_criteria - - is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) - valid_indices = np.where(is_valid) - - num_spikes = len(valid_indices[0]) - spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"] - spikes["channel_index"][:num_spikes] = 0 - spikes["cluster_index"][:num_spikes] = valid_indices[0] - spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] - - spikes = spikes[:num_spikes] - order = np.argsort(spikes["sample_index"]) - spikes = spikes[order] - - return spikes - - -class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): - """ - Orthogonal Matching Pursuit inspired from Spyking Circus sorter - - https://elifesciences.org/articles/34518 - - This is an Orthogonal Template Matching algorithm. For speed and - memory optimization, templates are automatically sparsified. Signal - is convolved with the templates, and as long as some scalar products - are higher than a given threshold, we use a Cholesky decomposition - to compute the optimal amplitudes needed to reconstruct the signal. - - IMPORTANT NOTE: small chunks are more efficient for such Peeler, - consider using 100ms chunk - - Parameters - ---------- - amplitude: tuple - (Minimal, Maximal) amplitudes allowed for every template - omp_min_sps: float - Stopping criteria of the OMP algorithm, in percentage of the norm - noise_levels: array - The noise levels, for every channels. If None, they will be automatically - computed - random_chunk_kwargs: dict - Parameters for computing noise levels, if not provided (sub optimal) - sparse_kwargs: dict - Parameters to extract a sparsity mask from the waveform_extractor, if not - already sparse. - ----- - """ - _default_params = { "amplitudes": [0.6, 2], "omp_min_sps": 0.1, diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index 46c4a53872..c00c0a1fd3 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -8,6 +8,5 @@ "tridesclous": TridesclousPeeler, "circus": CircusPeeler, "circus-omp": CircusOMPPeeler, - 'circus-omp-svd' : CircusOMPSVDPeeler, "wobble": WobbleMatch, } From f21d80bf3cb34e5f39d59a7692a0c594025ea7b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Sep 2023 20:32:10 +0000 Subject: [PATCH 07/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../clustering/clustering_tools.py | 6 +-- .../sortingcomponents/matching/circus.py | 44 +++++++++---------- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 7a2af09942..c1b635fdaf 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -624,9 +624,9 @@ def remove_duplicates_via_matching( "templates": computed["templates"], "norms": computed["norms"], "sparsities": computed["sparsities"], - "temporal" : computed["temporal"], - "spatial" : computed["spatial"], - "singular" : computed["singular"], + "temporal": computed["temporal"], + "spatial": computed["spatial"], + "singular": computed["singular"], } ) valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index aeac69fc86..d2b02ea15d 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -201,7 +201,7 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): "waveform_extractor": None, "random_chunk_kwargs": {}, "noise_levels": None, - "rank" : 5, + "rank": 5, "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], "vicinity": 0, @@ -219,37 +219,37 @@ def _prepare_templates(cls, d): templates = waveform_extractor.get_all_templates(mode="median").copy() - #First, we set masked channels to 0 - d['sparsities'] = {} + # First, we set masked channels to 0 + d["sparsities"] = {} for count in range(num_templates): template = templates[count][:, sparsity[count]] (d["sparsities"][count],) = np.nonzero(sparsity[count]) templates[count][:, ~sparsity[count]] = 0 # Then we keep only the strongest components - rank = d['rank'] + rank = d["rank"] temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) d["temporal"] = temporal[:, :, :rank] d["singular"] = singular[:, :rank] d["spatial"] = spatial[:, :rank, :] - + # We reconstruct the approximated templates templates = np.matmul(d["temporal"] * d["singular"][:, np.newaxis, :], d["spatial"]) d["temporal"] = np.flip(temporal, axis=1) - d['templates'] = {} + d["templates"] = {} d["norms"] = np.zeros(num_templates, dtype=np.float32) - + # And get the norms, saving compressed templates for CC matrix for count in range(num_templates): template = templates[count][:, sparsity[count]] d["norms"][count] = np.linalg.norm(template) - d["templates"][count] = template / d["norms"][count] - - d['temporal'] /= d['norms'][:, np.newaxis, np.newaxis] - d["spatial"] = np.moveaxis(d['spatial'][:, :rank, :], [0, 1, 2], [1, 0, 2]) - d['temporal'] = np.moveaxis(d['temporal'][:, :, :rank], [0, 1, 2], [1, 2, 0]) - d['singular'] = d['singular'].T[:, :, np.newaxis] + d["templates"][count] = template / d["norms"][count] + + d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] + d["spatial"] = np.moveaxis(d["spatial"][:, :rank, :], [0, 1, 2], [1, 0, 2]) + d["temporal"] = np.moveaxis(d["temporal"][:, :, :rank], [0, 1, 2], [1, 2, 0]) + d["singular"] = d["singular"].T[:, :, np.newaxis] return d @classmethod @@ -276,7 +276,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): if "templates" not in d: d = cls._prepare_templates(d) else: - for key in ["norms", "sparsities", 'temporal', 'spatial', 'singular']: + for key in ["norms", "sparsities", "temporal", "spatial", "singular"]: assert d[key] is not None, "If templates are provided, %d should also be there" % key d["num_templates"] = len(d["templates"]) @@ -287,8 +287,8 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["ignored_ids"] = np.array(d["ignored_ids"]) omp_min_sps = d["omp_min_sps"] - #d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) - d['stop_criteria'] = omp_min_sps * np.maximum(d['norms'], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) + # d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) + d["stop_criteria"] = omp_min_sps * np.maximum(d["norms"], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) return d @@ -325,18 +325,18 @@ def main_function(cls, traces, d): ignored_ids = d["ignored_ids"] stop_criteria = d["stop_criteria"][:, np.newaxis] vicinity = d["vicinity"] - rank = d['rank'] + rank = d["rank"] num_timesteps = len(traces) num_peaks = num_timesteps - num_samples + 1 conv_shape = (num_templates, num_peaks) scalar_products = np.zeros(conv_shape, dtype=np.float32) - + # Filter using overlap-and-add convolution - spatially_filtered_data = np.matmul(d['spatial'], traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * d['singular'] - objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d['temporal'], axes=2, mode="valid") + spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d["singular"] + objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d["temporal"], axes=2, mode="valid") scalar_products += np.sum(objective_by_rank, axis=0) if len(ignored_ids) > 0: @@ -473,8 +473,6 @@ def main_function(cls, traces, d): return spikes - - class CircusPeeler(BaseTemplateMatchingEngine): """ From a275bcaaf14819e64aa24a78a504b134f1d9288e Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 26 Sep 2023 22:32:57 +0200 Subject: [PATCH 08/23] Patch --- src/spikeinterface/sortingcomponents/matching/method_list.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index c00c0a1fd3..bedc04a9d5 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -1,6 +1,6 @@ from .naive import NaiveMatching from .tdc import TridesclousPeeler -from .circus import CircusPeeler, CircusOMPPeeler, CircusOMPSVDPeeler +from .circus import CircusPeeler, CircusOMPPeeler from .wobble import WobbleMatch matching_methods = { From 85eb432c16a0719520a8dcbb24d2c8bb2c804d60 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 26 Sep 2023 22:44:15 +0200 Subject: [PATCH 09/23] Cleaning useless functions --- .../clustering/clustering_tools.py | 6 -- .../sortingcomponents/matching/circus.py | 95 ------------------- 2 files changed, 101 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index c1b635fdaf..5ff74db3e7 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -546,7 +546,6 @@ def remove_duplicates_via_matching( from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms from spikeinterface.core import get_global_tmp_folder - from spikeinterface.sortingcomponents.matching.circus import get_scipy_shape import string, random, shutil, os from pathlib import Path @@ -591,11 +590,6 @@ def remove_duplicates_via_matching( chunk_size = duration + 3 * margin - dummy_filter = np.empty((num_chans, duration), dtype=np.float32) - dummy_traces = np.empty((num_chans, chunk_size), dtype=np.float32) - - fshape, axes = get_scipy_shape(dummy_filter, dummy_traces, axes=1) - method_kwargs.update( { "waveform_extractor": waveform_extractor, diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index d2b02ea15d..ec6ef3a292 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -35,101 +35,6 @@ ################# # Circus peeler # -################# - -from scipy.fft._helper import _init_nd_shape_and_axes - -try: - from scipy.signal.signaltools import _init_freq_conv_axes, _apply_conv_mode -except Exception: - from scipy.signal._signaltools import _init_freq_conv_axes, _apply_conv_mode -from scipy import linalg, fft as sp_fft - - -def get_scipy_shape(in1, in2, mode="full", axes=None, calc_fast_len=True): - in1 = np.asarray(in1) - in2 = np.asarray(in2) - - if in1.ndim == in2.ndim == 0: # scalar inputs - return in1 * in2 - elif in1.ndim != in2.ndim: - raise ValueError("in1 and in2 should have the same dimensionality") - elif in1.size == 0 or in2.size == 0: # empty arrays - return np.array([]) - - in1, in2, axes = _init_freq_conv_axes(in1, in2, mode, axes, sorted_axes=False) - - s1 = in1.shape - s2 = in2.shape - - shape = [max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1 for i in range(in1.ndim)] - - if not len(axes): - return in1 * in2 - - complex_result = in1.dtype.kind == "c" or in2.dtype.kind == "c" - - if calc_fast_len: - # Speed up FFT by padding to optimal size. - fshape = [sp_fft.next_fast_len(shape[a], not complex_result) for a in axes] - else: - fshape = shape - - return fshape, axes - - -def fftconvolve_with_cache(in1, in2, cache, mode="full", axes=None): - in1 = np.asarray(in1) - in2 = np.asarray(in2) - - if in1.ndim == in2.ndim == 0: # scalar inputs - return in1 * in2 - elif in1.ndim != in2.ndim: - raise ValueError("in1 and in2 should have the same dimensionality") - elif in1.size == 0 or in2.size == 0: # empty arrays - return np.array([]) - - in1, in2, axes = _init_freq_conv_axes(in1, in2, mode, axes, sorted_axes=False) - - s1 = in1.shape - s2 = in2.shape - - shape = [max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1 for i in range(in1.ndim)] - - ret = _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True) - - return _apply_conv_mode(ret, s1, s2, mode, axes) - - -def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True): - if not len(axes): - return in1 * in2 - - complex_result = in1.dtype.kind == "c" or in2.dtype.kind == "c" - - if calc_fast_len: - # Speed up FFT by padding to optimal size. - fshape = [sp_fft.next_fast_len(shape[a], not complex_result) for a in axes] - else: - fshape = shape - - if not complex_result: - fft, ifft = sp_fft.rfftn, sp_fft.irfftn - else: - fft, ifft = sp_fft.fftn, sp_fft.ifftn - - sp1 = cache["full"][cache["mask"]] - sp2 = cache["template"] - - # sp2 = fft(in2[cache['mask']], fshape, axes=axes) - ret = ifft(sp1 * sp2, fshape, axes=axes) - - if calc_fast_len: - fslice = tuple([slice(sz) for sz in shape]) - ret = ret[fslice] - - return ret - def compute_overlaps(templates, num_samples, num_channels, sparsities): num_templates = len(templates) From 15ae43215bf5a3b49a52081e18ad8ba3810bce15 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Sep 2023 20:44:37 +0000 Subject: [PATCH 10/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/matching/circus.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index ec6ef3a292..7bef8358de 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -36,6 +36,7 @@ ################# # Circus peeler # + def compute_overlaps(templates, num_samples, num_channels, sparsities): num_templates = len(templates) From 41155a1835f348d9181501d823cd78fca5cf6191 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 27 Sep 2023 13:36:15 +0200 Subject: [PATCH 11/23] Changing the internal representation of overlaps --- .../clustering/clustering_tools.py | 4 +- .../sortingcomponents/matching/circus.py | 78 +++++++++++++------ 2 files changed, 59 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 5ff74db3e7..032694a47e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -617,10 +617,12 @@ def remove_duplicates_via_matching( "overlaps": computed["overlaps"], "templates": computed["templates"], "norms": computed["norms"], - "sparsities": computed["sparsities"], "temporal": computed["temporal"], "spatial": computed["spatial"], "singular": computed["singular"], + "units_overlaps": computed["units_overlaps"], + "unit_overlaps_indices": computed["unit_overlaps_indices"], + "sparsity_mask": computed["sparsity_mask"], } ) valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index ec6ef3a292..ffc2a225e8 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -122,14 +122,20 @@ def _prepare_templates(cls, d): else: sparsity = waveform_extractor.sparsity.mask + d['sparsity_mask'] = sparsity + units_overlaps = np.sum( + np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2 + ) + d['units_overlaps'] = units_overlaps > 0 + d['unit_overlaps_indices'] = {} + for i in range(num_templates): + d['unit_overlaps_indices'][i], = np.nonzero(d['units_overlaps'][i]) + templates = waveform_extractor.get_all_templates(mode="median").copy() # First, we set masked channels to 0 - d["sparsities"] = {} for count in range(num_templates): - template = templates[count][:, sparsity[count]] - (d["sparsities"][count],) = np.nonzero(sparsity[count]) - templates[count][:, ~sparsity[count]] = 0 + templates[count][:, ~d['sparsity_mask'][count]] = 0 # Then we keep only the strongest components rank = d["rank"] @@ -141,19 +147,45 @@ def _prepare_templates(cls, d): # We reconstruct the approximated templates templates = np.matmul(d["temporal"] * d["singular"][:, np.newaxis, :], d["spatial"]) - d["temporal"] = np.flip(temporal, axis=1) d["templates"] = {} d["norms"] = np.zeros(num_templates, dtype=np.float32) # And get the norms, saving compressed templates for CC matrix for count in range(num_templates): - template = templates[count][:, sparsity[count]] + template = templates[count][:, d['sparsity_mask'][count]] d["norms"][count] = np.linalg.norm(template) d["templates"][count] = template / d["norms"][count] d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] - d["spatial"] = np.moveaxis(d["spatial"][:, :rank, :], [0, 1, 2], [1, 0, 2]) - d["temporal"] = np.moveaxis(d["temporal"][:, :, :rank], [0, 1, 2], [1, 2, 0]) + d["temporal"] = np.flip(d["temporal"], axis=1) + + d['overlaps'] = [] + for i in range(num_templates): + num_overlaps = np.sum(d['units_overlaps'][i]) + overlapping_units = np.where(d['units_overlaps'][i])[0] + + # Reconstruct unit template from SVD Matrices + data = d['temporal'][i] * d['singular'][i][np.newaxis, :] + template_i = np.matmul(data, d['spatial'][i, :, :]) + template_i = np.flipud(template_i) + + unit_overlaps = np.zeros([num_overlaps, 2*d['num_samples'] - 1], dtype=np.float32) + + for count, j in enumerate(overlapping_units): + overlapped_channels = d['sparsity_mask'][j] + visible_i = template_i[:, overlapped_channels] + + spatial_filters = d['spatial'][j, :, overlapped_channels] + spatially_filtered_template = np.matmul(visible_i, spatial_filters) + visible_i = spatially_filtered_template * d['singular'][j] + + for rank in range(visible_i.shape[1]): + unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d['temporal'][j][:, rank], mode='full') + + d['overlaps'].append(unit_overlaps) + + d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2]) + d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0]) d["singular"] = d["singular"].T[:, :, np.newaxis] return d @@ -181,14 +213,10 @@ def initialize_and_check_kwargs(cls, recording, kwargs): if "templates" not in d: d = cls._prepare_templates(d) else: - for key in ["norms", "sparsities", "temporal", "spatial", "singular"]: + for key in ["norms", "temporal", "spatial", "singular", "units_overlaps", "sparsity_mask", "unit_overlaps_indices"]: assert d[key] is not None, "If templates are provided, %d should also be there" % key d["num_templates"] = len(d["templates"]) - - if "overlaps" not in d: - d["overlaps"] = compute_overlaps(d["templates"], d["num_samples"], d["num_channels"], d["sparsities"]) - d["ignored_ids"] = np.array(d["ignored_ids"]) omp_min_sps = d["omp_min_sps"] @@ -252,7 +280,7 @@ def main_function(cls, traces, d): spikes = np.empty(scalar_products.size, dtype=spike_dtype) idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) - M = np.zeros((100, 100), dtype=np.float32) + M = np.zeros((num_templates, num_templates), dtype=np.float32) all_selections = np.empty((2, scalar_products.size), dtype=np.int32) final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) @@ -273,18 +301,24 @@ def main_function(cls, traces, d): if num_selection > 0: delta_t = selection[1] - peak_index - idx = np.where((delta_t < neighbor_window) & (delta_t > -num_samples))[0] + idx = np.where((delta_t < neighbor_window) & (delta_t >= -num_samples))[0] myline = num_samples + delta_t[idx] + myindices = selection[0, idx] - if not best_cluster_ind in cached_overlaps: - cached_overlaps[best_cluster_ind] = overlaps[best_cluster_ind].toarray() + local_overlaps = overlaps[best_cluster_ind] + overlapping_templates = d['unit_overlaps_indices'][best_cluster_ind] if num_selection == M.shape[0]: Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) Z[:num_selection, :num_selection] = M M = Z - M[num_selection, idx] = cached_overlaps[best_cluster_ind][selection[0, idx], myline] + mask = np.isin(myindices, overlapping_templates) + a, b = myindices[mask], myline[mask] + + table = np.zeros(num_templates, dtype=int) + table[overlapping_templates] = np.arange(len(overlapping_templates)) + M[num_selection, myindices[mask]] = local_overlaps[table[a], b] if vicinity == 0: scipy.linalg.solve_triangular( @@ -346,8 +380,8 @@ def main_function(cls, traces, d): tmp_best, tmp_peak = selection[:, i] diff_amp = diff_amplitudes[i] * norms[tmp_best] - if not tmp_best in cached_overlaps: - cached_overlaps[tmp_best] = overlaps[tmp_best].toarray() + local_overlaps = overlaps[tmp_best] + overlapping_templates = d['units_overlaps'][tmp_best] if not tmp_peak in neighbors.keys(): idx = [max(0, tmp_peak - num_samples), min(num_peaks, tmp_peak + neighbor_window)] @@ -357,8 +391,8 @@ def main_function(cls, traces, d): idx = neighbors[tmp_peak]["idx"] tdx = neighbors[tmp_peak]["tdx"] - to_add = diff_amp * cached_overlaps[tmp_best][:, tdx[0] : tdx[1]] - scalar_products[:, idx[0] : idx[1]] -= to_add + to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] + scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add is_valid = scalar_products > stop_criteria From 97aff7f6754e7c4d333b95629552fe37151bf24f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 11:36:51 +0000 Subject: [PATCH 12/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/matching/circus.py | 54 ++++++++++--------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index e047cbdd31..5924d3bc18 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -123,20 +123,18 @@ def _prepare_templates(cls, d): else: sparsity = waveform_extractor.sparsity.mask - d['sparsity_mask'] = sparsity - units_overlaps = np.sum( - np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2 - ) - d['units_overlaps'] = units_overlaps > 0 - d['unit_overlaps_indices'] = {} + d["sparsity_mask"] = sparsity + units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) + d["units_overlaps"] = units_overlaps > 0 + d["unit_overlaps_indices"] = {} for i in range(num_templates): - d['unit_overlaps_indices'][i], = np.nonzero(d['units_overlaps'][i]) + (d["unit_overlaps_indices"][i],) = np.nonzero(d["units_overlaps"][i]) templates = waveform_extractor.get_all_templates(mode="median").copy() # First, we set masked channels to 0 for count in range(num_templates): - templates[count][:, ~d['sparsity_mask'][count]] = 0 + templates[count][:, ~d["sparsity_mask"][count]] = 0 # Then we keep only the strongest components rank = d["rank"] @@ -153,37 +151,37 @@ def _prepare_templates(cls, d): # And get the norms, saving compressed templates for CC matrix for count in range(num_templates): - template = templates[count][:, d['sparsity_mask'][count]] + template = templates[count][:, d["sparsity_mask"][count]] d["norms"][count] = np.linalg.norm(template) d["templates"][count] = template / d["norms"][count] d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] d["temporal"] = np.flip(d["temporal"], axis=1) - d['overlaps'] = [] + d["overlaps"] = [] for i in range(num_templates): - num_overlaps = np.sum(d['units_overlaps'][i]) - overlapping_units = np.where(d['units_overlaps'][i])[0] + num_overlaps = np.sum(d["units_overlaps"][i]) + overlapping_units = np.where(d["units_overlaps"][i])[0] # Reconstruct unit template from SVD Matrices - data = d['temporal'][i] * d['singular'][i][np.newaxis, :] - template_i = np.matmul(data, d['spatial'][i, :, :]) + data = d["temporal"][i] * d["singular"][i][np.newaxis, :] + template_i = np.matmul(data, d["spatial"][i, :, :]) template_i = np.flipud(template_i) - unit_overlaps = np.zeros([num_overlaps, 2*d['num_samples'] - 1], dtype=np.float32) + unit_overlaps = np.zeros([num_overlaps, 2 * d["num_samples"] - 1], dtype=np.float32) for count, j in enumerate(overlapping_units): - overlapped_channels = d['sparsity_mask'][j] + overlapped_channels = d["sparsity_mask"][j] visible_i = template_i[:, overlapped_channels] - spatial_filters = d['spatial'][j, :, overlapped_channels] + spatial_filters = d["spatial"][j, :, overlapped_channels] spatially_filtered_template = np.matmul(visible_i, spatial_filters) - visible_i = spatially_filtered_template * d['singular'][j] - + visible_i = spatially_filtered_template * d["singular"][j] + for rank in range(visible_i.shape[1]): - unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d['temporal'][j][:, rank], mode='full') + unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d["temporal"][j][:, rank], mode="full") - d['overlaps'].append(unit_overlaps) + d["overlaps"].append(unit_overlaps) d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2]) d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0]) @@ -214,7 +212,15 @@ def initialize_and_check_kwargs(cls, recording, kwargs): if "templates" not in d: d = cls._prepare_templates(d) else: - for key in ["norms", "temporal", "spatial", "singular", "units_overlaps", "sparsity_mask", "unit_overlaps_indices"]: + for key in [ + "norms", + "temporal", + "spatial", + "singular", + "units_overlaps", + "sparsity_mask", + "unit_overlaps_indices", + ]: assert d[key] is not None, "If templates are provided, %d should also be there" % key d["num_templates"] = len(d["templates"]) @@ -307,7 +313,7 @@ def main_function(cls, traces, d): myindices = selection[0, idx] local_overlaps = overlaps[best_cluster_ind] - overlapping_templates = d['unit_overlaps_indices'][best_cluster_ind] + overlapping_templates = d["unit_overlaps_indices"][best_cluster_ind] if num_selection == M.shape[0]: Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) @@ -382,7 +388,7 @@ def main_function(cls, traces, d): diff_amp = diff_amplitudes[i] * norms[tmp_best] local_overlaps = overlaps[tmp_best] - overlapping_templates = d['units_overlaps'][tmp_best] + overlapping_templates = d["units_overlaps"][tmp_best] if not tmp_peak in neighbors.keys(): idx = [max(0, tmp_peak - num_samples), min(num_peaks, tmp_peak + neighbor_window)] From 8da6b79daa95bc4148123e76742607fb82b23fb3 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 27 Sep 2023 13:59:41 +0200 Subject: [PATCH 13/23] Keeping the two matching engines for more tests before merging and final decision --- .../clustering/clustering_tools.py | 39 +- .../sortingcomponents/matching/circus.py | 410 +++++++++++++++++- .../sortingcomponents/matching/method_list.py | 5 +- 3 files changed, 436 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 032694a47e..455af3ddfd 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -539,6 +539,7 @@ def remove_duplicates_via_matching( method_kwargs={}, job_kwargs={}, tmp_folder=None, + method='circus-omp' ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface import get_noise_levels @@ -610,21 +611,31 @@ def remove_duplicates_via_matching( method_kwargs.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( - sub_recording, method="circus-omp", method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs - ) - method_kwargs.update( - { - "overlaps": computed["overlaps"], - "templates": computed["templates"], - "norms": computed["norms"], - "temporal": computed["temporal"], - "spatial": computed["spatial"], - "singular": computed["singular"], - "units_overlaps": computed["units_overlaps"], - "unit_overlaps_indices": computed["unit_overlaps_indices"], - "sparsity_mask": computed["sparsity_mask"], - } + sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs ) + if method == 'circus-omp-vsd': + method_kwargs.update( + { + "overlaps": computed["overlaps"], + "templates": computed["templates"], + "norms": computed["norms"], + "temporal": computed["temporal"], + "spatial": computed["spatial"], + "singular": computed["singular"], + "units_overlaps": computed["units_overlaps"], + "unit_overlaps_indices": computed["unit_overlaps_indices"], + "sparsity_mask": computed["sparsity_mask"], + } + ) + elif method == 'circus-omp': + method_kwargs.update( + { + "overlaps": computed["overlaps"], + "templates": computed["templates"], + "norms": computed["norms"], + "sparsities": computed["sparsities"] + } + ) valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging) if np.sum(valid) > 0: if np.sum(valid) == 1: diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index e047cbdd31..08be0985f1 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -33,8 +33,100 @@ from .main import BaseTemplateMatchingEngine -################# -# Circus peeler # + +from scipy.fft._helper import _init_nd_shape_and_axes + +try: + from scipy.signal.signaltools import _init_freq_conv_axes, _apply_conv_mode +except Exception: + from scipy.signal._signaltools import _init_freq_conv_axes, _apply_conv_mode +from scipy import linalg, fft as sp_fft + + +def get_scipy_shape(in1, in2, mode="full", axes=None, calc_fast_len=True): + in1 = np.asarray(in1) + in2 = np.asarray(in2) + + if in1.ndim == in2.ndim == 0: # scalar inputs + return in1 * in2 + elif in1.ndim != in2.ndim: + raise ValueError("in1 and in2 should have the same dimensionality") + elif in1.size == 0 or in2.size == 0: # empty arrays + return np.array([]) + + in1, in2, axes = _init_freq_conv_axes(in1, in2, mode, axes, sorted_axes=False) + + s1 = in1.shape + s2 = in2.shape + + shape = [max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1 for i in range(in1.ndim)] + + if not len(axes): + return in1 * in2 + + complex_result = in1.dtype.kind == "c" or in2.dtype.kind == "c" + + if calc_fast_len: + # Speed up FFT by padding to optimal size. + fshape = [sp_fft.next_fast_len(shape[a], not complex_result) for a in axes] + else: + fshape = shape + + return fshape, axes + + +def fftconvolve_with_cache(in1, in2, cache, mode="full", axes=None): + in1 = np.asarray(in1) + in2 = np.asarray(in2) + + if in1.ndim == in2.ndim == 0: # scalar inputs + return in1 * in2 + elif in1.ndim != in2.ndim: + raise ValueError("in1 and in2 should have the same dimensionality") + elif in1.size == 0 or in2.size == 0: # empty arrays + return np.array([]) + + in1, in2, axes = _init_freq_conv_axes(in1, in2, mode, axes, sorted_axes=False) + + s1 = in1.shape + s2 = in2.shape + + shape = [max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1 for i in range(in1.ndim)] + + ret = _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True) + + return _apply_conv_mode(ret, s1, s2, mode, axes) + + +def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True): + if not len(axes): + return in1 * in2 + + complex_result = in1.dtype.kind == "c" or in2.dtype.kind == "c" + + if calc_fast_len: + # Speed up FFT by padding to optimal size. + fshape = [sp_fft.next_fast_len(shape[a], not complex_result) for a in axes] + else: + fshape = shape + + if not complex_result: + fft, ifft = sp_fft.rfftn, sp_fft.irfftn + else: + fft, ifft = sp_fft.fftn, sp_fft.ifftn + + sp1 = cache["full"][cache["mask"]] + sp2 = cache["template"] + + # sp2 = fft(in2[cache['mask']], fshape, axes=axes) + ret = ifft(sp1 * sp2, fshape, axes=axes) + + if calc_fast_len: + fslice = tuple([slice(sz) for sz in shape]) + ret = ret[fslice] + + return ret + def compute_overlaps(templates, num_samples, num_channels, sparsities): @@ -101,6 +193,320 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): ----- """ + _default_params = { + "amplitudes": [0.6, 2], + "omp_min_sps": 0.1, + "waveform_extractor": None, + "templates": None, + "overlaps": None, + "norms": None, + "random_chunk_kwargs": {}, + "noise_levels": None, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, + "ignored_ids": [], + "vicinity": 0, + } + + @classmethod + def _prepare_templates(cls, d): + waveform_extractor = d["waveform_extractor"] + num_templates = len(d["waveform_extractor"].sorting.unit_ids) + + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask + else: + sparsity = waveform_extractor.sparsity.mask + + templates = waveform_extractor.get_all_templates(mode="median").copy() + + d["sparsities"] = {} + d["templates"] = {} + d["norms"] = np.zeros(num_templates, dtype=np.float32) + + for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): + template = templates[count][:, sparsity[count]] + (d["sparsities"][count],) = np.nonzero(sparsity[count]) + d["norms"][count] = np.linalg.norm(template) + d["templates"][count] = template / d["norms"][count] + + return d + + @classmethod + def initialize_and_check_kwargs(cls, recording, kwargs): + d = cls._default_params.copy() + d.update(kwargs) + + # assert isinstance(d['waveform_extractor'], WaveformExtractor) + + for v in ["omp_min_sps"]: + assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" + + d["num_channels"] = d["waveform_extractor"].recording.get_num_channels() + d["num_samples"] = d["waveform_extractor"].nsamples + d["nbefore"] = d["waveform_extractor"].nbefore + d["nafter"] = d["waveform_extractor"].nafter + d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() + d["vicinity"] *= d["num_samples"] + + if d["noise_levels"] is None: + print("CircusOMPPeeler : noise should be computed outside") + d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) + + if d["templates"] is None: + d = cls._prepare_templates(d) + else: + for key in ["norms", "sparsities"]: + assert d[key] is not None, "If templates are provided, %d should also be there" % key + + d["num_templates"] = len(d["templates"]) + + if d["overlaps"] is None: + d["overlaps"] = compute_overlaps(d["templates"], d["num_samples"], d["num_channels"], d["sparsities"]) + + d["ignored_ids"] = np.array(d["ignored_ids"]) + + omp_min_sps = d["omp_min_sps"] + # nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) + d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) + + return d + + @classmethod + def serialize_method_kwargs(cls, kwargs): + kwargs = dict(kwargs) + # remove waveform_extractor + kwargs.pop("waveform_extractor") + return kwargs + + @classmethod + def unserialize_in_worker(cls, kwargs): + return kwargs + + @classmethod + def get_margin(cls, recording, kwargs): + margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) + return margin + + @classmethod + def main_function(cls, traces, d): + templates = d["templates"] + num_templates = d["num_templates"] + num_channels = d["num_channels"] + num_samples = d["num_samples"] + overlaps = d["overlaps"] + norms = d["norms"] + nbefore = d["nbefore"] + nafter = d["nafter"] + omp_tol = np.finfo(np.float32).eps + num_samples = d["nafter"] + d["nbefore"] + neighbor_window = num_samples - 1 + min_amplitude, max_amplitude = d["amplitudes"] + sparsities = d["sparsities"] + ignored_ids = d["ignored_ids"] + stop_criteria = d["stop_criteria"] + vicinity = d["vicinity"] + + if "cached_fft_kernels" not in d: + d["cached_fft_kernels"] = {"fshape": 0} + + cached_fft_kernels = d["cached_fft_kernels"] + + num_timesteps = len(traces) + + num_peaks = num_timesteps - num_samples + 1 + + traces = traces.T + + dummy_filter = np.empty((num_channels, num_samples), dtype=np.float32) + dummy_traces = np.empty((num_channels, num_timesteps), dtype=np.float32) + + fshape, axes = get_scipy_shape(dummy_filter, traces, axes=1) + fft_cache = {"full": sp_fft.rfftn(traces, fshape, axes=axes)} + + scalar_products = np.empty((num_templates, num_peaks), dtype=np.float32) + + flagged_chunk = cached_fft_kernels["fshape"] != fshape[0] + + for i in range(num_templates): + if i not in ignored_ids: + if i not in cached_fft_kernels or flagged_chunk: + kernel_filter = np.ascontiguousarray(templates[i][::-1].T) + cached_fft_kernels.update({i: sp_fft.rfftn(kernel_filter, fshape, axes=axes)}) + cached_fft_kernels["fshape"] = fshape[0] + + fft_cache.update({"mask": sparsities[i], "template": cached_fft_kernels[i]}) + + convolution = fftconvolve_with_cache(dummy_filter, dummy_traces, fft_cache, axes=1, mode="valid") + if len(convolution) > 0: + scalar_products[i] = convolution.sum(0) + else: + scalar_products[i] = 0 + + if len(ignored_ids) > 0: + scalar_products[ignored_ids] = -np.inf + + num_spikes = 0 + + spikes = np.empty(scalar_products.size, dtype=spike_dtype) + idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) + + M = np.zeros((100, 100), dtype=np.float32) + + all_selections = np.empty((2, scalar_products.size), dtype=np.int32) + final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) + num_selection = 0 + + full_sps = scalar_products.copy() + + neighbors = {} + cached_overlaps = {} + + is_valid = scalar_products > stop_criteria + all_amplitudes = np.zeros(0, dtype=np.float32) + is_in_vicinity = np.zeros(0, dtype=np.int32) + + while np.any(is_valid): + best_amplitude_ind = scalar_products[is_valid].argmax() + best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) + + if num_selection > 0: + delta_t = selection[1] - peak_index + idx = np.where((delta_t < neighbor_window) & (delta_t > -num_samples))[0] + myline = num_samples + delta_t[idx] + + if not best_cluster_ind in cached_overlaps: + cached_overlaps[best_cluster_ind] = overlaps[best_cluster_ind].toarray() + + if num_selection == M.shape[0]: + Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) + Z[:num_selection, :num_selection] = M + M = Z + + M[num_selection, idx] = cached_overlaps[best_cluster_ind][selection[0, idx], myline] + + if vicinity == 0: + scipy.linalg.solve_triangular( + M[:num_selection, :num_selection], + M[num_selection, :num_selection], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, + ) + + v = nrm2(M[num_selection, :num_selection]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] + + if len(is_in_vicinity) > 0: + L = M[is_in_vicinity, :][:, is_in_vicinity] + + M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( + L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False + ) + + v = nrm2(M[num_selection, is_in_vicinity]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + M[num_selection, num_selection] = 1.0 + else: + M[0, 0] = 1 + + all_selections[:, num_selection] = [best_cluster_ind, peak_index] + num_selection += 1 + + selection = all_selections[:, :num_selection] + res_sps = full_sps[selection[0], selection[1]] + + if True: # vicinity == 0: + all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) + all_amplitudes /= norms[selection[0]] + else: + # This is not working, need to figure out why + is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) + all_amplitudes = np.append(all_amplitudes, np.float32(1)) + L = M[is_in_vicinity, :][:, is_in_vicinity] + all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) + all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] + + diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] + modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] + final_amplitudes[selection[0], selection[1]] = all_amplitudes + + for i in modified: + tmp_best, tmp_peak = selection[:, i] + diff_amp = diff_amplitudes[i] * norms[tmp_best] + + if not tmp_best in cached_overlaps: + cached_overlaps[tmp_best] = overlaps[tmp_best].toarray() + + if not tmp_peak in neighbors.keys(): + idx = [max(0, tmp_peak - num_samples), min(num_peaks, tmp_peak + neighbor_window)] + tdx = [num_samples + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak] + neighbors[tmp_peak] = {"idx": idx, "tdx": tdx} + + idx = neighbors[tmp_peak]["idx"] + tdx = neighbors[tmp_peak]["tdx"] + + to_add = diff_amp * cached_overlaps[tmp_best][:, tdx[0] : tdx[1]] + scalar_products[:, idx[0] : idx[1]] -= to_add + + is_valid = scalar_products > stop_criteria + + is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) + valid_indices = np.where(is_valid) + + num_spikes = len(valid_indices[0]) + spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"] + spikes["channel_index"][:num_spikes] = 0 + spikes["cluster_index"][:num_spikes] = valid_indices[0] + spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] + + spikes = spikes[:num_spikes] + order = np.argsort(spikes["sample_index"]) + spikes = spikes[order] + + return spikes + +class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): + """ + Orthogonal Matching Pursuit inspired from Spyking Circus sorter + + https://elifesciences.org/articles/34518 + + This is an Orthogonal Template Matching algorithm. For speed and + memory optimization, templates are automatically sparsified. Signal + is convolved with the templates, and as long as some scalar products + are higher than a given threshold, we use a Cholesky decomposition + to compute the optimal amplitudes needed to reconstruct the signal. + + IMPORTANT NOTE: small chunks are more efficient for such Peeler, + consider using 100ms chunk + + Parameters + ---------- + amplitude: tuple + (Minimal, Maximal) amplitudes allowed for every template + omp_min_sps: float + Stopping criteria of the OMP algorithm, in percentage of the norm + noise_levels: array + The noise levels, for every channels. If None, they will be automatically + computed + random_chunk_kwargs: dict + Parameters for computing noise levels, if not provided (sub optimal) + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. + ----- + """ + _default_params = { "amplitudes": [0.6, 2], "omp_min_sps": 0.1, diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index bedc04a9d5..99c2817338 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -1,6 +1,6 @@ from .naive import NaiveMatching from .tdc import TridesclousPeeler -from .circus import CircusPeeler, CircusOMPPeeler +from .circus import CircusPeeler, CircusOMPPeeler, CircusOMPSVDPeeler from .wobble import WobbleMatch matching_methods = { @@ -8,5 +8,6 @@ "tridesclous": TridesclousPeeler, "circus": CircusPeeler, "circus-omp": CircusOMPPeeler, + 'circus-omp-svd' : CircusOMPSVDPeeler, "wobble": WobbleMatch, -} +} \ No newline at end of file From a6b4774000159f8db5439072acc8bdec4757d26b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 12:00:19 +0000 Subject: [PATCH 14/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../clustering/clustering_tools.py | 14 ++++---------- .../sortingcomponents/matching/circus.py | 2 +- .../sortingcomponents/matching/method_list.py | 4 ++-- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 455af3ddfd..17c38e2f8a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -533,13 +533,7 @@ def remove_duplicates( def remove_duplicates_via_matching( - waveform_extractor, - noise_levels, - peak_labels, - method_kwargs={}, - job_kwargs={}, - tmp_folder=None, - method='circus-omp' + waveform_extractor, noise_levels, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None, method="circus-omp" ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface import get_noise_levels @@ -613,7 +607,7 @@ def remove_duplicates_via_matching( spikes, computed = find_spikes_from_templates( sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs ) - if method == 'circus-omp-vsd': + if method == "circus-omp-vsd": method_kwargs.update( { "overlaps": computed["overlaps"], @@ -627,13 +621,13 @@ def remove_duplicates_via_matching( "sparsity_mask": computed["sparsity_mask"], } ) - elif method == 'circus-omp': + elif method == "circus-omp": method_kwargs.update( { "overlaps": computed["overlaps"], "templates": computed["templates"], "norms": computed["norms"], - "sparsities": computed["sparsities"] + "sparsities": computed["sparsities"], } ) valid = (spikes["sample_index"] >= half_marging) * (spikes["sample_index"] < duration + half_marging) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index e7bdcd161c..502c887ac4 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -128,7 +128,6 @@ def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True): return ret - def compute_overlaps(templates, num_samples, num_channels, sparsities): num_templates = len(templates) @@ -475,6 +474,7 @@ def main_function(cls, traces, d): return spikes + class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): """ Orthogonal Matching Pursuit inspired from Spyking Circus sorter diff --git a/src/spikeinterface/sortingcomponents/matching/method_list.py b/src/spikeinterface/sortingcomponents/matching/method_list.py index 99c2817338..d982943126 100644 --- a/src/spikeinterface/sortingcomponents/matching/method_list.py +++ b/src/spikeinterface/sortingcomponents/matching/method_list.py @@ -8,6 +8,6 @@ "tridesclous": TridesclousPeeler, "circus": CircusPeeler, "circus-omp": CircusOMPPeeler, - 'circus-omp-svd' : CircusOMPSVDPeeler, + "circus-omp-svd": CircusOMPSVDPeeler, "wobble": WobbleMatch, -} \ No newline at end of file +} From 257c74c856254f8ed31365f0629b53baf844fb74 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 27 Sep 2023 15:52:38 +0200 Subject: [PATCH 15/23] Slight misalignement --- .../sortingcomponents/matching/circus.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index e7bdcd161c..04d780bb6b 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -714,8 +714,8 @@ def main_function(cls, traces, d): if num_selection > 0: delta_t = selection[1] - peak_index - idx = np.where((delta_t < neighbor_window) & (delta_t >= -num_samples))[0] - myline = num_samples + delta_t[idx] + idx = np.where((delta_t < num_samples) & (delta_t > -num_samples))[0] + myline = neighbor_window + delta_t[idx] myindices = selection[0, idx] local_overlaps = overlaps[best_cluster_ind] @@ -731,7 +731,7 @@ def main_function(cls, traces, d): table = np.zeros(num_templates, dtype=int) table[overlapping_templates] = np.arange(len(overlapping_templates)) - M[num_selection, myindices[mask]] = local_overlaps[table[a], b] + M[num_selection, idx[mask]] = local_overlaps[table[a], b] if vicinity == 0: scipy.linalg.solve_triangular( @@ -797,8 +797,8 @@ def main_function(cls, traces, d): overlapping_templates = d["units_overlaps"][tmp_best] if not tmp_peak in neighbors.keys(): - idx = [max(0, tmp_peak - num_samples), min(num_peaks, tmp_peak + neighbor_window)] - tdx = [num_samples + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak] + idx = [max(0, tmp_peak - neighbor_window), min(num_peaks, tmp_peak + num_samples)] + tdx = [neighbor_window + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak - 1] neighbors[tmp_peak] = {"idx": idx, "tdx": tdx} idx = neighbors[tmp_peak]["idx"] From 0a2c0f618b11374558f536147845a1cbc6710661 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 27 Sep 2023 16:01:21 +0200 Subject: [PATCH 16/23] Default SVD Peeler is now good to go --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index db3d88f116..7097b9e56b 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -152,7 +152,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_job_params["chunk_duration"] = "100ms" spikes = find_spikes_from_templates( - recording_f, method="circus-omp", method_kwargs=matching_params, **matching_job_params + recording_f, method="circus-omp-svd", method_kwargs=matching_params, **matching_job_params ) if verbose: diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 17c38e2f8a..273b1402fe 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -533,7 +533,7 @@ def remove_duplicates( def remove_duplicates_via_matching( - waveform_extractor, noise_levels, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None, method="circus-omp" + waveform_extractor, noise_levels, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None, method="circus-omp-svd" ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface import get_noise_levels From 5fbc88d416f863784ee7ed890c45f04726d4dc5a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 14:01:43 +0000 Subject: [PATCH 17/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/clustering_tools.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 273b1402fe..af3a9cb86a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -533,7 +533,13 @@ def remove_duplicates( def remove_duplicates_via_matching( - waveform_extractor, noise_levels, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None, method="circus-omp-svd" + waveform_extractor, + noise_levels, + peak_labels, + method_kwargs={}, + job_kwargs={}, + tmp_folder=None, + method="circus-omp-svd", ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface import get_noise_levels From 9f45f2e5757e9f3dcb890a65d69bdecbca8c7eb6 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 27 Sep 2023 17:35:31 +0200 Subject: [PATCH 18/23] Enhance the clustering --- .../sorters/internal/spyking_circus2.py | 2 +- .../clustering/random_projections.py | 106 +++++++++--------- .../sortingcomponents/features_from_peaks.py | 27 +++-- 3 files changed, 71 insertions(+), 64 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 7097b9e56b..55a36d26d5 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -20,7 +20,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75}, + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, "filtering": {"dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index be8ecd6702..8c0cab07c6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -18,7 +18,9 @@ from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms -from spikeinterface.sortingcomponents.features_from_peaks import compute_features_from_peaks, EnergyFeature +from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser +from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature +from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever class RandomProjectionClustering: @@ -34,17 +36,17 @@ class RandomProjectionClustering: "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, + "waveforms" : {"ms_before" : 2, "ms_after" : 2, "max_spikes_per_unit": 100}, "radius_um": 100, - "max_spikes_per_unit": 200, "selection_method": "closest_to_centroid", - "nb_projections": {"ptp": 8, "energy": 2}, - "ms_before": 1.5, - "ms_after": 1.5, + "nb_projections": 10, + "ms_before": 1, + "ms_after": 1, "random_seed": 42, - "shared_memory": False, - "min_values": {"ptp": 0, "energy": 0}, + "smoothing_kwargs" : {"window_length_ms" : 1}, + "shared_memory": True, "tmp_folder": None, - "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "10M", "verbose": True, "progress_bar": True}, + "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, } @classmethod @@ -74,50 +76,52 @@ def main_function(cls, recording, peaks, params): np.random.seed(d["random_seed"]) - features_params = {} - features_list = [] - - noise_snippets = None - - for proj_type in ["ptp", "energy"]: - if d["nb_projections"][proj_type] > 0: - features_list += [f"random_projections_{proj_type}"] - - if d["min_values"][proj_type] == "auto": - if noise_snippets is None: - num_segments = recording.get_num_segments() - num_chunks = 3 * d["max_spikes_per_unit"] // num_segments - noise_snippets = get_random_data_chunks( - recording, num_chunks_per_segment=num_chunks, chunk_size=num_samples, seed=42 - ) - noise_snippets = noise_snippets.reshape(num_chunks, num_samples, num_chans) - - if proj_type == "energy": - data = np.linalg.norm(noise_snippets, axis=1) - min_values = np.median(data, axis=0) - elif proj_type == "ptp": - data = np.ptp(noise_snippets, axis=1) - min_values = np.median(data, axis=0) - elif d["min_values"][proj_type] > 0: - min_values = d["min_values"][proj_type] - else: - min_values = None - - projections = np.random.randn(num_chans, d["nb_projections"][proj_type]) - features_params[f"random_projections_{proj_type}"] = { - "radius_um": params["radius_um"], - "projections": projections, - "min_values": min_values, - } - - features_data = compute_features_from_peaks( - recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"] + if params["tmp_folder"] is None: + name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) + tmp_folder = get_global_tmp_folder() / name + else: + tmp_folder = Path(params["tmp_folder"]).absolute() + + ### Then we extract the SVD features + node0 = PeakRetriever(recording, peaks) + node1 = ExtractDenseWaveforms(recording, parents=[node0], return_output=False, + ms_before=params['ms_before'], + ms_after=params['ms_after'] ) - if len(features_data) > 1: - hdbscan_data = np.hstack((features_data[0], features_data[1])) - else: - hdbscan_data = features_data[0] + node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params['smoothing_kwargs']) + + projections = np.random.randn(num_chans, d["nb_projections"]) + projections -= projections.mean(0) + projections /= projections.std(0) + + nbefore = int(params['ms_before'] * fs / 1000) + nafter = int(params['ms_after'] * fs / 1000) + nsamples = nbefore + nafter + + import scipy + x = np.random.randn(100, nsamples, num_chans).astype(np.float32) + x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1) + + ptps = np.ptp(x, axis=1) + a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000)) + ydata = np.cumsum(a)/a.sum() + xdata = b[1:] + + from scipy.optimize import curve_fit + def sigmoid(x, L ,x0, k, b): + y = L / (1 + np.exp(-k*(x-x0))) + b + return (y) + + p0 = [max(ydata), np.median(xdata), 1, min(ydata)] # this is an mandatory initial guess + popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) + + node3 = RandomProjectionsFeature(recording, parents=[node0, node2], return_output=True, + projections=projections, radius_um=params['radius_um']) + + pipeline_nodes = [node0, node1, node2, node3] + + hdbscan_data = run_node_pipeline(recording, pipeline_nodes, params["job_kwargs"]) import sklearn @@ -132,7 +136,7 @@ def main_function(cls, recording, peaks, params): all_indices = np.arange(0, peak_labels.size) - max_spikes = params["max_spikes_per_unit"] + max_spikes = params['waveforms']["max_spikes_per_unit"] selection_method = params["selection_method"] for unit_ind in labels: diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index bd82ffa0a6..2f1acb6a19 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -184,41 +184,44 @@ def __init__( return_output=True, parents=None, projections=None, - radius_um=150.0, - min_values=None, + sigmoid=None, + radius_um=None ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) self.projections = projections - self.radius_um = radius_um - self.min_values = min_values - + self.sigmoid = sigmoid self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance < radius_um - - self._kwargs.update(dict(projections=projections, radius_um=radius_um, min_values=min_values)) - + self.radius_um = radius_um + self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um)) self._dtype = recording.get_dtype() def get_dtype(self): return self._dtype + def _sigmoid(self, x): + L, x0, k, b = self.sigmoid + y = L / (1 + np.exp(-k*(x-x0))) + b + return y + def compute(self, traces, peaks, waveforms): all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype) + for main_chan in np.unique(peaks["channel_index"]): (idx,) = np.nonzero(peaks["channel_index"] == main_chan) (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) local_projections = self.projections[chan_inds, :] - wf_ptp = (waveforms[idx][:, :, chan_inds]).ptp(axis=1) + wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) - if self.min_values is not None: - wf_ptp = (wf_ptp / self.min_values[chan_inds]) ** 4 + if self.sigmoid is not None: + wf_ptp *= self._sigmoid(wf_ptp) denom = np.sum(wf_ptp, axis=1) mask = denom != 0 - all_projections[idx[mask]] = np.dot(wf_ptp[mask], local_projections) / (denom[mask][:, np.newaxis]) + return all_projections From 3cbf8f8fc8267ff0bffd8c340514db983e059a0c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:36:51 +0000 Subject: [PATCH 19/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../clustering/random_projections.py | 38 +++++++++++-------- .../sortingcomponents/features_from_peaks.py | 8 ++-- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 8c0cab07c6..f8cad2cf3f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -36,14 +36,14 @@ class RandomProjectionClustering: "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, - "waveforms" : {"ms_before" : 2, "ms_after" : 2, "max_spikes_per_unit": 100}, + "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, "selection_method": "closest_to_centroid", "nb_projections": 10, "ms_before": 1, "ms_after": 1, "random_seed": 42, - "smoothing_kwargs" : {"window_length_ms" : 1}, + "smoothing_kwargs": {"window_length_ms": 1}, "shared_memory": True, "tmp_folder": None, "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, @@ -84,40 +84,46 @@ def main_function(cls, recording, peaks, params): ### Then we extract the SVD features node0 = PeakRetriever(recording, peaks) - node1 = ExtractDenseWaveforms(recording, parents=[node0], return_output=False, - ms_before=params['ms_before'], - ms_after=params['ms_after'] + node1 = ExtractDenseWaveforms( + recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"] ) - node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params['smoothing_kwargs']) + node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"]) projections = np.random.randn(num_chans, d["nb_projections"]) projections -= projections.mean(0) projections /= projections.std(0) - nbefore = int(params['ms_before'] * fs / 1000) - nafter = int(params['ms_after'] * fs / 1000) + nbefore = int(params["ms_before"] * fs / 1000) + nafter = int(params["ms_after"] * fs / 1000) nsamples = nbefore + nafter import scipy + x = np.random.randn(100, nsamples, num_chans).astype(np.float32) x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1) ptps = np.ptp(x, axis=1) a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000)) - ydata = np.cumsum(a)/a.sum() + ydata = np.cumsum(a) / a.sum() xdata = b[1:] from scipy.optimize import curve_fit - def sigmoid(x, L ,x0, k, b): - y = L / (1 + np.exp(-k*(x-x0))) + b - return (y) - p0 = [max(ydata), np.median(xdata), 1, min(ydata)] # this is an mandatory initial guess + def sigmoid(x, L, x0, k, b): + y = L / (1 + np.exp(-k * (x - x0))) + b + return y + + p0 = [max(ydata), np.median(xdata), 1, min(ydata)] # this is an mandatory initial guess popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) - node3 = RandomProjectionsFeature(recording, parents=[node0, node2], return_output=True, - projections=projections, radius_um=params['radius_um']) + node3 = RandomProjectionsFeature( + recording, + parents=[node0, node2], + return_output=True, + projections=projections, + radius_um=params["radius_um"], + ) pipeline_nodes = [node0, node1, node2, node3] @@ -136,7 +142,7 @@ def sigmoid(x, L ,x0, k, b): all_indices = np.arange(0, peak_labels.size) - max_spikes = params['waveforms']["max_spikes_per_unit"] + max_spikes = params["waveforms"]["max_spikes_per_unit"] selection_method = params["selection_method"] for unit_ind in labels: diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index 2f1acb6a19..b534c2356d 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -185,7 +185,7 @@ def __init__( parents=None, projections=None, sigmoid=None, - radius_um=None + radius_um=None, ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) @@ -203,12 +203,12 @@ def get_dtype(self): def _sigmoid(self, x): L, x0, k, b = self.sigmoid - y = L / (1 + np.exp(-k*(x-x0))) + b + y = L / (1 + np.exp(-k * (x - x0))) + b return y def compute(self, traces, peaks, waveforms): all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype) - + for main_chan in np.unique(peaks["channel_index"]): (idx,) = np.nonzero(peaks["channel_index"] == main_chan) (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) @@ -221,7 +221,7 @@ def compute(self, traces, peaks, waveforms): denom = np.sum(wf_ptp, axis=1) mask = denom != 0 all_projections[idx[mask]] = np.dot(wf_ptp[mask], local_projections) / (denom[mask][:, np.newaxis]) - + return all_projections From daddd8cef722a35233dbed530e14775de87b8caa Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 28 Sep 2023 09:16:51 +0200 Subject: [PATCH 20/23] Adding a lookup table --- .../sortingcomponents/matching/circus.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 5775589321..1d13eca1df 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -128,6 +128,7 @@ def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True): return ret + def compute_overlaps(templates, num_samples, num_channels, sparsities): num_templates = len(templates) @@ -474,7 +475,6 @@ def main_function(cls, traces, d): return spikes - class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): """ Orthogonal Matching Pursuit inspired from Spyking Circus sorter @@ -632,6 +632,12 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["num_templates"] = len(d["templates"]) d["ignored_ids"] = np.array(d["ignored_ids"]) + d["unit_overlaps_tables"] = {} + for i in range(d["num_templates"]): + d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int) + d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i])) + + omp_min_sps = d["omp_min_sps"] # d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) d["stop_criteria"] = omp_min_sps * np.maximum(d["norms"], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) @@ -720,6 +726,7 @@ def main_function(cls, traces, d): local_overlaps = overlaps[best_cluster_ind] overlapping_templates = d["unit_overlaps_indices"][best_cluster_ind] + table = d["unit_overlaps_tables"][best_cluster_ind] if num_selection == M.shape[0]: Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) @@ -728,9 +735,6 @@ def main_function(cls, traces, d): mask = np.isin(myindices, overlapping_templates) a, b = myindices[mask], myline[mask] - - table = np.zeros(num_templates, dtype=int) - table[overlapping_templates] = np.arange(len(overlapping_templates)) M[num_selection, idx[mask]] = local_overlaps[table[a], b] if vicinity == 0: From d7dcbe05f082f5ecd93d9233b9f5ca30ae51a8f4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Sep 2023 07:17:14 +0000 Subject: [PATCH 21/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/matching/circus.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 1d13eca1df..44c394aec9 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -128,7 +128,6 @@ def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True): return ret - def compute_overlaps(templates, num_samples, num_channels, sparsities): num_templates = len(templates) @@ -475,6 +474,7 @@ def main_function(cls, traces, d): return spikes + class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): """ Orthogonal Matching Pursuit inspired from Spyking Circus sorter @@ -637,7 +637,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int) d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i])) - omp_min_sps = d["omp_min_sps"] # d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) d["stop_criteria"] = omp_min_sps * np.maximum(d["norms"], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) From d623da38f38924b9c5857abdeccf16891c729bc7 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 28 Sep 2023 10:21:08 +0200 Subject: [PATCH 22/23] typos for cleaning via matching --- .../clustering/clustering_tools.py | 2 +- .../clustering/random_projections.py | 2 +- .../sortingcomponents/matching/circus.py | 15 ++++++++++----- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index af3a9cb86a..28a1a63065 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -613,7 +613,7 @@ def remove_duplicates_via_matching( spikes, computed = find_spikes_from_templates( sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs ) - if method == "circus-omp-vsd": + if method == "circus-omp-svd": method_kwargs.update( { "overlaps": computed["overlaps"], diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index f8cad2cf3f..df9290a1f5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -127,7 +127,7 @@ def sigmoid(x, L, x0, k, b): pipeline_nodes = [node0, node1, node2, node3] - hdbscan_data = run_node_pipeline(recording, pipeline_nodes, params["job_kwargs"]) + hdbscan_data = run_node_pipeline(recording, pipeline_nodes, params["job_kwargs"], job_name="extracting features") import sklearn diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 1d13eca1df..9e02aa4ff6 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -686,13 +686,18 @@ def main_function(cls, traces, d): scalar_products = np.zeros(conv_shape, dtype=np.float32) # Filter using overlap-and-add convolution - spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * d["singular"] - objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d["temporal"], axes=2, mode="valid") - scalar_products += np.sum(objective_by_rank, axis=0) - if len(ignored_ids) > 0: + mask = ~np.isin(np.arange(num_templates), ignored_ids) + spatially_filtered_data = np.matmul(d["spatial"][:, mask, :], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d["singular"][:, mask, :] + objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d["temporal"][:, mask, :], axes=2, mode="valid") + scalar_products[mask] += np.sum(objective_by_rank, axis=0) scalar_products[ignored_ids] = -np.inf + else: + spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d["singular"] + objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d["temporal"], axes=2, mode="valid") + scalar_products += np.sum(objective_by_rank, axis=0) num_spikes = 0 From fdb84668137ba71b1ca36787032551da52764842 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Sep 2023 08:21:36 +0000 Subject: [PATCH 23/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/random_projections.py | 4 +++- src/spikeinterface/sortingcomponents/matching/circus.py | 8 +++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index df9290a1f5..864548e7d4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -127,7 +127,9 @@ def sigmoid(x, L, x0, k, b): pipeline_nodes = [node0, node1, node2, node3] - hdbscan_data = run_node_pipeline(recording, pipeline_nodes, params["job_kwargs"], job_name="extracting features") + hdbscan_data = run_node_pipeline( + recording, pipeline_nodes, params["job_kwargs"], job_name="extracting features" + ) import sklearn diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index b963447ba2..358691cd25 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -687,16 +687,18 @@ def main_function(cls, traces, d): # Filter using overlap-and-add convolution if len(ignored_ids) > 0: mask = ~np.isin(np.arange(num_templates), ignored_ids) - spatially_filtered_data = np.matmul(d["spatial"][:, mask, :], traces.T[np.newaxis, :, :]) + spatially_filtered_data = np.matmul(d["spatial"][:, mask, :], traces.T[np.newaxis, :, :]) scaled_filtered_data = spatially_filtered_data * d["singular"][:, mask, :] - objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d["temporal"][:, mask, :], axes=2, mode="valid") + objective_by_rank = scipy.signal.oaconvolve( + scaled_filtered_data, d["temporal"][:, mask, :], axes=2, mode="valid" + ) scalar_products[mask] += np.sum(objective_by_rank, axis=0) scalar_products[ignored_ids] = -np.inf else: spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) scaled_filtered_data = spatially_filtered_data * d["singular"] objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d["temporal"], axes=2, mode="valid") - scalar_products += np.sum(objective_by_rank, axis=0) + scalar_products += np.sum(objective_by_rank, axis=0) num_spikes = 0