From 8c3699e9ea5dcb7f16487011ba8da3cf3e555346 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 18:18:34 +0000 Subject: [PATCH 01/11] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/pre-commit-hooks: v4.6.0 → v5.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.6.0...v5.0.0) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4c4bd68be4..1e133694ba 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-yaml - id: end-of-file-fixer From 1ffe6ccf8898d237519c36a8ce439b6b9db896e7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 9 Oct 2024 12:49:15 +0200 Subject: [PATCH 02/11] oups --- src/spikeinterface/benchmark/tests/test_benchmark_sorter.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py index 2564d58d52..db48d32fde 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py @@ -64,10 +64,10 @@ def test_SorterStudy(setup_module): print(study) # # this run the sorters - # study.run() + study.run() # # this run comparisons - # study.compute_results() + study.compute_results() print(study) # this is from the base class @@ -84,5 +84,7 @@ def test_SorterStudy(setup_module): if __name__ == "__main__": study_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy" + if study_folder.exists(): + shutil.rmtree(study_folder) create_a_study(study_folder) test_SorterStudy(study_folder) From 275e5017c978a7b5a5ffb29b914c8398658d1954 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 9 Oct 2024 15:34:40 +0200 Subject: [PATCH 03/11] Sc2 fixes * Fixes * Patches * Fixes for SC2 and for split clustering * debugging clustering * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * WIP * Default params * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Adding gather_func to find_spikes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Gathering mode more explicit for matching * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * WIP * Fixes for SC2 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * Simplifications * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Naming for Sam * Optimize circus matching engine * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Optimizations * Remove the limit to chunk sizes in circus-omp-svd * Naming * Patch imports --------- Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../sorters/internal/spyking_circus2.py | 41 ++-- .../sortingcomponents/clustering/circus.py | 18 +- .../sortingcomponents/clustering/split.py | 22 +- .../sortingcomponents/matching/circus.py | 192 +++++++++--------- .../sortingcomponents/peak_detection.py | 13 +- .../sortingcomponents/peak_localization.py | 2 +- 6 files changed, 153 insertions(+), 135 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index c3b3099535..211adba990 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -24,9 +24,10 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75}, "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, + "whitening": {"mode": "local", "regularize": True}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { "method": "uniform", @@ -36,7 +37,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "seed": 42, }, "apply_motion_correction": True, - "motion_correction": {"preset": "nonrigid_fast_and_accurate"}, + "motion_correction": {"preset": "dredge_fast"}, "merging": { "similarity_kwargs": {"method": "cosine", "support": "union", "max_lag_ms": 0.2}, "correlograms_kwargs": {}, @@ -46,7 +47,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): }, }, "clustering": {"legacy": True}, - "matching": {"method": "wobble"}, + "matching": {"method": "circus-omp-svd"}, "apply_preprocessing": True, "matched_filtering": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, @@ -62,6 +63,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): and also the radius_um used to be considered during clustering", "sparsity": "A dictionary to be passed to all the calls to sparsify the templates", "filtering": "A dictionary for the high_pass filter to be used during preprocessing", + "whitening": "A dictionary for the whitening option to be used during preprocessing", "detection": "A dictionary for the peak detection node (locally_exclusive)", "selection": "A dictionary for the peak selection node. 
Default is to use smart_sampling_amplitudes, with a minimum of 20000 peaks\ and 5000 peaks per electrode on average.", @@ -109,8 +111,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction from spikeinterface.sortingcomponents.tools import get_prototype_spike - job_kwargs = params["job_kwargs"] - job_kwargs = fix_job_kwargs(job_kwargs) + job_kwargs = fix_job_kwargs(params["job_kwargs"]) job_kwargs.update({"progress_bar": verbose}) recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) @@ -119,7 +120,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): num_channels = recording.get_num_channels() ms_before = params["general"].get("ms_before", 2) ms_after = params["general"].get("ms_after", 2) - radius_um = params["general"].get("radius_um", 100) + radius_um = params["general"].get("radius_um", 75) exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after) / 2) ## First, we are filtering the data @@ -143,14 +144,19 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): print("Motion correction activated (probe geometry compatible)") motion_folder = sorter_output_folder / "motion" params["motion_correction"].update({"folder": motion_folder}) - recording_f = correct_motion(recording_f, **params["motion_correction"]) + recording_f = correct_motion(recording_f, **params["motion_correction"], **job_kwargs) else: motion_folder = None ## We need to whiten before the template matching step, to boost the results # TODO add , regularize=True chen ready - recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True) + whitening_kwargs = params["whitening"].copy() + whitening_kwargs["dtype"] = "float32" + whitening_kwargs["radius_um"] = radius_um + if num_channels == 1: + whitening_kwargs["regularize"] = False + recording_w = whiten(recording_f, **whitening_kwargs) noise_levels = get_noise_levels(recording_w, return_scaled=False) if recording_w.check_serializability("json"): @@ -172,20 +178,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nbefore = int(ms_before * fs / 1000.0) nafter = int(ms_after * fs / 1000.0) - peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params) - if params["matched_filtering"]: + peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, skip_after_n_peaks=5000) prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs) detection_params["prototype"] = prototype detection_params["ms_before"] = ms_before - - for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: - if value in detection_params: - detection_params.pop(value) - - detection_params["chunk_duration"] = "100ms" - peaks = detect_peaks(recording_w, "matched_filtering", **detection_params) + else: + peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params) if verbose: print("We found %d peaks in total" % len(peaks)) @@ -196,7 +196,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We subselect a subset of all the peaks, by making the distributions os SNRs over all ## channels as flat as possible selection_params = params["selection"] - selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels + selection_params["n_peaks"] = min(len(peaks), selection_params["n_peaks_per_channel"] * 
num_channels) selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"]) selection_params.update({"noise_levels": noise_levels}) @@ -281,11 +281,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_job_params = job_kwargs.copy() if matching_method is not None: - for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: - if value in matching_job_params: - matching_job_params[value] = None - matching_job_params["chunk_duration"] = "100ms" - spikes = find_spikes_from_templates( recording_w, matching_method, method_kwargs=matching_params, **matching_job_params ) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index b08ee4d9cb..b7e71d3b45 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -136,6 +136,7 @@ def main_function(cls, recording, peaks, params): pipeline_nodes = [node0, node1, node2] if len(params["recursive_kwargs"]) == 0: + from sklearn.decomposition import PCA all_pc_data = run_node_pipeline( recording, @@ -152,9 +153,9 @@ def main_function(cls, recording, peaks, params): sub_data = sub_data.reshape(len(sub_data), -1) if all_pc_data.shape[1] > params["n_svd"][1]: - tsvd = TruncatedSVD(params["n_svd"][1]) + tsvd = PCA(params["n_svd"][1], whiten=True) else: - tsvd = TruncatedSVD(all_pc_data.shape[1]) + tsvd = PCA(all_pc_data.shape[1], whiten=True) hdbscan_data = tsvd.fit_transform(sub_data) try: @@ -184,7 +185,7 @@ def main_function(cls, recording, peaks, params): ) sparse_mask = node1.neighbours_mask - neighbours_mask = get_channel_distances(recording) < radius_um + neighbours_mask = get_channel_distances(recording) <= radius_um # np.save(features_folder / "sparse_mask.npy", sparse_mask) np.save(features_folder / "peaks.npy", peaks) @@ -192,6 +193,8 @@ def main_function(cls, recording, peaks, params): original_labels = peaks["channel_index"] from spikeinterface.sortingcomponents.clustering.split import split_clusters + min_size = params["hdbscan_kwargs"].get("min_cluster_size", 50) + peak_labels, _ = split_clusters( original_labels, recording, @@ -202,7 +205,7 @@ def main_function(cls, recording, peaks, params): feature_name="sparse_tsvd", neighbours_mask=neighbours_mask, waveforms_sparse_mask=sparse_mask, - min_size_split=50, + min_size_split=min_size, clusterer_kwargs=d["hdbscan_kwargs"], n_pca_features=params["n_svd"][1], scale_n_pca_by_depth=True, @@ -233,7 +236,7 @@ def main_function(cls, recording, peaks, params): if d["rank"] is not None: from spikeinterface.sortingcomponents.matching.circus import compress_templates - _, _, _, templates_array = compress_templates(templates_array, 5) + _, _, _, templates_array = compress_templates(templates_array, d["rank"]) templates = Templates( templates_array=templates_array, @@ -258,13 +261,8 @@ def main_function(cls, recording, peaks, params): print("We found %d raw clusters, starting to clean with matching..." 
% (len(templates.unit_ids))) cleaning_matching_params = params["job_kwargs"].copy() - for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: - if value in cleaning_matching_params: - cleaning_matching_params.pop(value) - cleaning_matching_params["chunk_duration"] = "100ms" cleaning_matching_params["n_jobs"] = 1 cleaning_matching_params["progress_bar"] = False - cleaning_params = params["cleaning_kwargs"].copy() labels, peak_labels = remove_duplicates_via_matching( diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 5934bdfbbb..15917934a8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -24,7 +24,7 @@ def split_clusters( peak_labels, recording, features_dict_or_folder, - method="hdbscan_on_local_pca", + method="local_feature_clustering", method_kwargs={}, recursive=False, recursive_depth=None, @@ -81,7 +81,6 @@ def split_clusters( ) as pool: labels_set = np.setdiff1d(peak_labels, [-1]) current_max_label = np.max(labels_set) + 1 - jobs = [] for label in labels_set: peak_indices = np.flatnonzero(peak_labels == label) @@ -95,15 +94,14 @@ def split_clusters( for res in iterator: is_split, local_labels, peak_indices = res.result() + # print(is_split, local_labels, peak_indices) if not is_split: continue mask = local_labels >= 0 peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label peak_labels[peak_indices[~mask]] = local_labels[~mask] - split_count[peak_indices] += 1 - current_max_label += np.max(local_labels[mask]) + 1 if recursive: @@ -120,6 +118,7 @@ def split_clusters( for label in new_labels_set: peak_indices = np.flatnonzero(peak_labels == label) if peak_indices.size > 0: + # print('Relaunched', label, len(peak_indices), recursion_level) jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level)) if progress_bar: iterator.total += 1 @@ -187,7 +186,7 @@ def split( min_size_split=25, n_pca_features=2, scale_n_pca_by_depth=False, - minimum_common_channels=2, + minimum_overlap_ratio=0.25, ): local_labels = np.zeros(peak_indices.size, dtype=np.int64) @@ -199,19 +198,22 @@ def split( # target channel subset is done intersect local channels + neighbours local_chans = np.unique(peaks["channel_index"][peak_indices]) - target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) + target_intersection_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) + target_union_channels = np.flatnonzero(np.any(neighbours_mask[local_chans, :], axis=0)) + num_intersection = len(target_intersection_channels) + num_union = len(target_union_channels) # TODO fix this a better way, this when cluster have too few overlapping channels - if target_channels.size < minimum_common_channels: + if (num_intersection / num_union) < minimum_overlap_ratio: return False, None aligned_wfs, dont_have_channels = aggregate_sparse_features( - peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_channels + peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_intersection_channels ) local_labels[dont_have_channels] = -2 kept = np.flatnonzero(~dont_have_channels) - + # print(recursion_level, kept.size, min_size_split) if kept.size < min_size_split: return False, None @@ -222,6 +224,8 @@ def split( if flatten_features.shape[1] > n_pca_features: from sklearn.decomposition import PCA + # from sklearn.decomposition import TruncatedSVD + if 
scale_n_pca_by_depth: # tsvd = TruncatedSVD(n_pca_features * recursion_level) tsvd = PCA(n_pca_features * recursion_level, whiten=True) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index a3624f4296..d1b2139c5b 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -9,6 +9,7 @@ from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel from spikeinterface.core.template import Templates + spike_dtype = [ ("sample_index", "int64"), ("channel_index", "int64"), @@ -140,12 +141,12 @@ def __init__( templates=None, amplitudes=[0.6, np.inf], stop_criteria="max_failures", - max_failures=10, + max_failures=5, omp_min_sps=0.1, relative_error=5e-5, rank=5, ignore_inds=[], - vicinity=3, + vicinity=2, precomputed=None, ): @@ -181,16 +182,16 @@ def __init__( self.unit_overlaps_tables[i] = np.zeros(self.num_templates, dtype=int) self.unit_overlaps_tables[i][self.unit_overlaps_indices[i]] = np.arange(len(self.unit_overlaps_indices[i])) - if self.vicinity > 0: - self.margin = self.vicinity - else: - self.margin = 2 * self.num_samples + self.margin = 2 * self.num_samples def _prepare_templates(self): assert self.stop_criteria in ["max_failures", "omp_min_sps", "relative_error"] - sparsity = self.templates.sparsity.mask + if self.templates.sparsity is None: + sparsity = np.ones((self.num_templates, self.num_channels), dtype=bool) + else: + sparsity = self.templates.sparsity.mask units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) self.units_overlaps = units_overlaps > 0 @@ -265,6 +266,7 @@ def get_trace_margin(self): def compute_matching(self, traces, start_frame, end_frame, segment_index): import scipy.spatial import scipy + from scipy import ndimage (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32) @@ -316,8 +318,6 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): full_sps = scalar_products.copy() - neighbors = {} - all_amplitudes = np.zeros(0, dtype=np.float32) is_in_vicinity = np.zeros(0, dtype=np.int32) @@ -336,100 +336,113 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): do_loop = True while do_loop: - best_amplitude_ind = scalar_products.argmax() - best_cluster_ind, peak_index = np.unravel_index(best_amplitude_ind, scalar_products.shape) - if num_selection > 0: - delta_t = selection[1] - peak_index - idx = np.where((delta_t < num_samples) & (delta_t > -num_samples))[0] - myline = neighbor_window + delta_t[idx] - myindices = selection[0, idx] + best_cluster_inds = np.argmax(scalar_products, axis=0, keepdims=True) + products = np.take_along_axis(scalar_products, best_cluster_inds, axis=0) - local_overlaps = overlaps_array[best_cluster_ind] - overlapping_templates = self.unit_overlaps_indices[best_cluster_ind] - table = self.unit_overlaps_tables[best_cluster_ind] + result = ndimage.maximum_filter(products[0], size=self.vicinity, mode="constant", cval=0) + cond_1 = products[0] / self.norms[best_cluster_inds[0]] > 0.25 + cond_2 = np.abs(products[0] - result) < 1e-9 + peak_indices = np.flatnonzero(cond_1 * cond_2) - if num_selection == M.shape[0]: - Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) - Z[:num_selection, :num_selection] = M - M = Z + if len(peak_indices) == 0: + break - mask = np.isin(myindices, overlapping_templates) - a, b = myindices[mask], myline[mask] - 
M[num_selection, idx[mask]] = local_overlaps[table[a], b] + for peak_index in peak_indices: - if self.vicinity == 0: - scipy.linalg.solve_triangular( - M[:num_selection, :num_selection], - M[num_selection, :num_selection], - trans=0, - lower=1, - overwrite_b=True, - check_finite=False, - ) - - v = nrm2(M[num_selection, :num_selection]) ** 2 - Lkk = 1 - v - if Lkk <= omp_tol: # selected atoms are dependent - break - M[num_selection, num_selection] = np.sqrt(Lkk) - else: - is_in_vicinity = np.where(np.abs(delta_t) < self.vicinity)[0] + best_cluster_ind = best_cluster_inds[0, peak_index] + + if num_selection > 0: + delta_t = selection[1] - peak_index + idx = np.flatnonzero((delta_t < num_samples) & (delta_t > -num_samples)) + myline = neighbor_window + delta_t[idx] + myindices = selection[0, idx] + + local_overlaps = overlaps_array[best_cluster_ind] + overlapping_templates = self.unit_overlaps_indices[best_cluster_ind] + table = self.unit_overlaps_tables[best_cluster_ind] - if len(is_in_vicinity) > 0: - L = M[is_in_vicinity, :][:, is_in_vicinity] + if num_selection == M.shape[0]: + Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) + Z[:num_selection, :num_selection] = M + M = Z - M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( - L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False + mask = np.isin(myindices, overlapping_templates) + a, b = myindices[mask], myline[mask] + M[num_selection, idx[mask]] = local_overlaps[table[a], b] + + if self.vicinity == 0: + scipy.linalg.solve_triangular( + M[:num_selection, :num_selection], + M[num_selection, :num_selection], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, ) - v = nrm2(M[num_selection, is_in_vicinity]) ** 2 + v = nrm2(M[num_selection, :num_selection]) ** 2 Lkk = 1 - v if Lkk <= omp_tol: # selected atoms are dependent break M[num_selection, num_selection] = np.sqrt(Lkk) else: - M[num_selection, num_selection] = 1.0 - else: - M[0, 0] = 1 - - all_selections[:, num_selection] = [best_cluster_ind, peak_index] - num_selection += 1 - - selection = all_selections[:, :num_selection] - res_sps = full_sps[selection[0], selection[1]] - - if self.vicinity == 0: - all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) - all_amplitudes /= self.norms[selection[0]] - else: - is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) - all_amplitudes = np.append(all_amplitudes, np.float32(1)) - L = M[is_in_vicinity, :][:, is_in_vicinity] - all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) - all_amplitudes[is_in_vicinity] /= self.norms[selection[0][is_in_vicinity]] - - diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] - modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] - final_amplitudes[selection[0], selection[1]] = all_amplitudes - - for i in modified: - tmp_best, tmp_peak = selection[:, i] - diff_amp = diff_amplitudes[i] * self.norms[tmp_best] - - local_overlaps = overlaps_array[tmp_best] - overlapping_templates = self.units_overlaps[tmp_best] + is_in_vicinity = np.flatnonzero(np.abs(delta_t) < self.vicinity) + + if len(is_in_vicinity) > 0: + L = M[is_in_vicinity, :][:, is_in_vicinity] + + M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( + L, + M[num_selection, is_in_vicinity], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, + ) + + v = nrm2(M[num_selection, is_in_vicinity]) ** 2 + Lkk = 1 - v + if 
Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + M[num_selection, num_selection] = 1.0 + else: + M[0, 0] = 1 - if not tmp_peak in neighbors.keys(): - idx = [max(0, tmp_peak - neighbor_window), min(num_peaks, tmp_peak + num_samples)] - tdx = [neighbor_window + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak - 1] - neighbors[tmp_peak] = {"idx": idx, "tdx": tdx} + all_selections[:, num_selection] = [best_cluster_ind, peak_index] + num_selection += 1 - idx = neighbors[tmp_peak]["idx"] - tdx = neighbors[tmp_peak]["tdx"] + selection = all_selections[:, :num_selection] + res_sps = full_sps[selection[0], selection[1]] - to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] - scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add + if self.vicinity == 0: + new_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) + sub_selection = selection + new_amplitudes /= self.norms[sub_selection[0]] + else: + is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) + all_amplitudes = np.append(all_amplitudes, np.float32(1)) + L = M[is_in_vicinity, :][:, is_in_vicinity] + new_amplitudes, _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) + sub_selection = selection[:, is_in_vicinity] + new_amplitudes /= self.norms[sub_selection[0]] + + diff_amplitudes = new_amplitudes - final_amplitudes[sub_selection[0], sub_selection[1]] + modified = np.flatnonzero(np.abs(diff_amplitudes) > omp_tol) + final_amplitudes[sub_selection[0], sub_selection[1]] = new_amplitudes + + for i in modified: + tmp_best, tmp_peak = sub_selection[:, i] + diff_amp = diff_amplitudes[i] * self.norms[tmp_best] + local_overlaps = overlaps_array[tmp_best] + overlapping_templates = self.units_overlaps[tmp_best] + tmp = tmp_peak - neighbor_window + idx = [max(0, tmp), min(num_peaks, tmp_peak + num_samples)] + tdx = [idx[0] - tmp, idx[1] - tmp] + to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] + scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add # We stop when updates do not modify the chosen spikes anymore if self.stop_criteria == "omp_min_sps": @@ -462,12 +475,9 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): spikes["cluster_index"][:num_spikes] = valid_indices[0] spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] - print("yep0", spikes.size, num_spikes, spikes.shape, spikes.dtype) spikes = spikes[:num_spikes] - print("yep1", spikes.size, spikes.shape, spikes.dtype) - if spikes.size > 0: - order = np.argsort(spikes["sample_index"]) - spikes = spikes[order] + order = np.argsort(spikes["sample_index"]) + spikes = spikes[order] return spikes diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index ad8897df91..d608c5d105 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -50,7 +50,14 @@ def detect_peaks( - recording, method="locally_exclusive", pipeline_nodes=None, gather_mode="memory", folder=None, names=None, **kwargs + recording, + method="locally_exclusive", + pipeline_nodes=None, + gather_mode="memory", + folder=None, + names=None, + skip_after_n_peaks=None, + **kwargs, ): """Peak detection based on threshold crossing in term of k x MAD. @@ -73,6 +80,9 @@ def detect_peaks( If gather_mode is "npy", the folder where the files are created. 
names : list List of strings with file stems associated with returns. + skip_after_n_peaks : None | int + Skip the computation after n_peaks. + This is not an exact because internally this skip is done per worker in average. {method_doc} {job_doc} @@ -124,6 +134,7 @@ def detect_peaks( squeeze_output=squeeze_output, folder=folder, names=names, + skip_after_n_peaks=skip_after_n_peaks, ) return outs diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index ddc8add995..08bcabf5e5 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -135,7 +135,7 @@ def __init__(self, recording, return_output=True, parents=None, radius_um=75.0): self.radius_um = radius_um self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance < radius_um + self.neighbours_mask = self.channel_distance <= radius_um self._kwargs["radius_um"] = radius_um def get_dtype(self): From bbf7daf6fb34c831dafc6111e9f51221b028b396 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Wed, 9 Oct 2024 17:40:29 -0400 Subject: [PATCH 04/11] Add neuronexus allego recording Extractor (#3235) * add neuronexus allego * add tests * fix neuronexus name * Heberto feedback * Fix capitalization * oops * add assert messaging * Update src/spikeinterface/extractors/neoextractors/neuronexus.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --------- Co-authored-by: Heberto Mayorquin --- .../extractors/neoextractors/__init__.py | 2 + .../extractors/neoextractors/neuronexus.py | 66 +++++++++++++++++++ .../extractors/tests/common_tests.py | 7 +- .../extractors/tests/test_neoextractors.py | 8 +++ 4 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 src/spikeinterface/extractors/neoextractors/neuronexus.py diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index bf52de7c1d..03d517b46e 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -9,6 +9,7 @@ from .mearec import MEArecRecordingExtractor, MEArecSortingExtractor, read_mearec from .mcsraw import MCSRawRecordingExtractor, read_mcsraw from .neuralynx import NeuralynxRecordingExtractor, NeuralynxSortingExtractor, read_neuralynx, read_neuralynx_sorting +from .neuronexus import NeuroNexusRecordingExtractor, read_neuronexus from .neuroscope import ( NeuroScopeRecordingExtractor, NeuroScopeSortingExtractor, @@ -54,6 +55,7 @@ MCSRawRecordingExtractor, NeuralynxRecordingExtractor, NeuroScopeRecordingExtractor, + NeuroNexusRecordingExtractor, NixRecordingExtractor, OpenEphysBinaryRecordingExtractor, OpenEphysLegacyRecordingExtractor, diff --git a/src/spikeinterface/extractors/neoextractors/neuronexus.py b/src/spikeinterface/extractors/neoextractors/neuronexus.py new file mode 100644 index 0000000000..dca482b28a --- /dev/null +++ b/src/spikeinterface/extractors/neoextractors/neuronexus.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from pathlib import Path + +from spikeinterface.core.core_tools import define_function_from_class + +from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor + + +class NeuroNexusRecordingExtractor(NeoBaseRecordingExtractor): + """ + Class for reading data 
from NeuroNexus Allego. + + Based on :py:class:`neo.rawio.NeuronexusRawIO` + + Parameters + ---------- + file_path : str | Path + The file path to the metadata .xdat.json file of an Allego session + stream_id : str | None, default: None + If there are several streams, specify the stream id you want to load. + stream_name : str | None, default: None + If there are several streams, specify the stream name you want to load. + all_annotations : bool, default: False + Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. + + In Neuronexus the ids provided by NeoRawIO are the hardware channel ids stored as `ntv_chan_name` within + the metada and the names are the `chan_names` + + + """ + + NeoRawIOClass = "NeuroNexusRawIO" + + def __init__( + self, + file_path: str | Path, + stream_id: str | None = None, + stream_name: str | None = None, + all_annotations: bool = False, + use_names_as_ids: bool = False, + ): + neo_kwargs = self.map_to_neo_kwargs(file_path) + NeoBaseRecordingExtractor.__init__( + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, + ) + + self._kwargs.update(dict(file_path=str(Path(file_path).resolve()))) + + @classmethod + def map_to_neo_kwargs(cls, file_path): + + neo_kwargs = {"filename": str(file_path)} + + return neo_kwargs + + +read_neuronexus = define_function_from_class(source_class=NeuroNexusRecordingExtractor, name="read_neuronexus") diff --git a/src/spikeinterface/extractors/tests/common_tests.py b/src/spikeinterface/extractors/tests/common_tests.py index 5432efa9f3..61cfc2a153 100644 --- a/src/spikeinterface/extractors/tests/common_tests.py +++ b/src/spikeinterface/extractors/tests/common_tests.py @@ -52,8 +52,11 @@ def test_open(self): num_samples = rec.get_num_samples(segment_index=segment_index) full_traces = rec.get_traces(segment_index=segment_index) - assert full_traces.shape == (num_samples, num_chans) - assert full_traces.dtype == dtype + assert full_traces.shape == ( + num_samples, + num_chans, + ), f"{full_traces.shape} != {(num_samples, num_chans)}" + assert full_traces.dtype == dtype, f"{full_traces.dtype} != {dtype=}" traces_sample_first = rec.get_traces(segment_index=segment_index, start_frame=0, end_frame=1) assert traces_sample_first.shape == (1, num_chans) diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 3f73161218..fcdd766f4f 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -181,6 +181,14 @@ class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase): ] +class NeuroNexusRecordingTest(RecordingCommonTestSuite, unittest.TestCase): + ExtractorClass = NeuroNexusRecordingExtractor + downloads = ["neuronexus"] + entities = [ + ("neuronexus/allego_1/allego_2__uid0701-13-04-49.xdat.json", {"stream_id": "0"}), + ] + + class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonRecordingExtractor downloads = ["plexon"] From 0ae32e729acb0be01d1cee28453e3ad3503877fb Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Fri, 11 Oct 2024 11:32:27 +0200 Subject: [PATCH 05/11] Tdc peeler (#3466) Improving the Peeler --- 
.../benchmark/benchmark_matching.py | 73 +- .../benchmark/benchmark_plot_tools.py | 64 +- .../sortingcomponents/matching/tdc.py | 686 +++++++++++++----- .../sortingcomponents/matching/wobble.py | 2 +- .../tests/test_template_matching.py | 40 +- 5 files changed, 608 insertions(+), 257 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index c53567f460..3799fa19b3 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -33,7 +33,7 @@ def run(self, **job_kwargs): sorting["unit_index"] = spikes["cluster_index"] sorting["segment_index"] = spikes["segment_index"] sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids) - self.result = {"sorting": sorting} + self.result = {"sorting": sorting, "spikes": spikes} self.result["templates"] = self.templates def compute_result(self, with_collision=False, **result_params): @@ -45,6 +45,7 @@ def compute_result(self, with_collision=False, **result_params): _run_key_saved = [ ("sorting", "sorting"), + ("spikes", "npy"), ("templates", "zarr_templates"), ] _result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")] @@ -71,6 +72,11 @@ def plot_performances_vs_snr(self, **kwargs): return plot_performances_vs_snr(self, **kwargs) + def plot_performances_comparison(self, **kwargs): + from .benchmark_plot_tools import plot_performances_comparison + + return plot_performances_comparison(self, **kwargs) + def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -90,70 +96,6 @@ def plot_collisions(self, case_keys=None, figsize=None): return fig - def plot_comparison_matching( - self, - case_keys=None, - performance_names=["accuracy", "recall", "precision"], - colors=["g", "b", "r"], - ylim=(-0.1, 1.1), - figsize=None, - ): - - if case_keys is None: - case_keys = list(self.cases.keys()) - - num_methods = len(case_keys) - import pylab as plt - - fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10)) - for i, key1 in enumerate(case_keys): - for j, key2 in enumerate(case_keys): - if len(axs.shape) > 1: - ax = axs[i, j] - else: - ax = axs[j] - comp1 = self.get_result(key1)["gt_comparison"] - comp2 = self.get_result(key2)["gt_comparison"] - if i <= j: - for performance, color in zip(performance_names, colors): - perf1 = comp1.get_performance()[performance] - perf2 = comp2.get_performance()[performance] - ax.plot(perf2, perf1, ".", label=performance, color=color) - - ax.plot([0, 1], [0, 1], "k--", alpha=0.5) - ax.set_ylim(ylim) - ax.set_xlim(ylim) - ax.spines[["right", "top"]].set_visible(False) - ax.set_aspect("equal") - - label1 = self.cases[key1]["label"] - label2 = self.cases[key2]["label"] - if j == i: - ax.set_ylabel(f"{label1}") - else: - ax.set_yticks([]) - if i == j: - ax.set_xlabel(f"{label2}") - else: - ax.set_xticks([]) - if i == num_methods - 1 and j == num_methods - 1: - patches = [] - import matplotlib.patches as mpatches - - for color, name in zip(colors, performance_names): - patches.append(mpatches.Patch(color=color, label=name)) - ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) - else: - ax.spines["bottom"].set_visible(False) - ax.spines["left"].set_visible(False) - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.set_xticks([]) - ax.set_yticks([]) - plt.tight_layout(h_pad=0, w_pad=0) - - return fig - def 
get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): import pandas as pd @@ -196,6 +138,7 @@ def plot_unit_counts(self, case_keys=None, figsize=None): plot_study_unit_counts(self, case_keys, figsize=figsize) def plot_unit_losses(self, before, after, metric=["precision"], figsize=None): + import matplotlib.pyplot as plt fig, axs = plt.subplots(ncols=1, nrows=len(metric), figsize=figsize, squeeze=False) diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index a6e9b6dacc..e15636ebaf 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -235,9 +235,71 @@ def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accu ax.scatter(x, y, marker=".", label=label) ax.set_title(k) - ax.set_ylim(0, 1.05) + ax.set_ylim(-0.05, 1.05) if count == 2: ax.legend() return fig + + +def plot_performances_comparison( + study, + case_keys=None, + figsize=None, + metrics=["accuracy", "recall", "precision"], + colors=["g", "b", "r"], + ylim=(-0.1, 1.1), +): + import matplotlib.pyplot as plt + + if case_keys is None: + case_keys = list(study.cases.keys()) + + num_methods = len(case_keys) + assert num_methods >= 2, "plot_performances_comparison need at least 2 cases!" + + fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=(10, 10), squeeze=False) + for i, key1 in enumerate(case_keys): + for j, key2 in enumerate(case_keys): + + if i < j: + ax = axs[i, j - 1] + + comp1 = study.get_result(key1)["gt_comparison"] + comp2 = study.get_result(key2)["gt_comparison"] + + for performance, color in zip(metrics, colors): + perf1 = comp1.get_performance()[performance] + perf2 = comp2.get_performance()[performance] + ax.scatter(perf2, perf1, marker=".", label=performance, color=color) + + ax.plot([0, 1], [0, 1], "k--", alpha=0.5) + ax.set_ylim(ylim) + ax.set_xlim(ylim) + ax.spines[["right", "top"]].set_visible(False) + ax.set_aspect("equal") + + label1 = study.cases[key1]["label"] + label2 = study.cases[key2]["label"] + + if i == j - 1: + ax.set_xlabel(label2) + ax.set_ylabel(label1) + + else: + if j >= 1 and i < num_methods - 1: + ax = axs[i, j - 1] + ax.spines[["right", "top", "left", "bottom"]].set_visible(False) + ax.set_xticks([]) + ax.set_yticks([]) + + ax = axs[num_methods - 2, 0] + patches = [] + from matplotlib.patches import Patch + + for color, name in zip(colors, metrics): + patches.append(Patch(color=color, label=name)) + ax.legend(handles=patches) + fig.tight_layout() + return fig diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 56457fe2fa..125baa3bda 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -2,15 +2,11 @@ import numpy as np from spikeinterface.core import ( - get_noise_levels, get_channel_distances, - compute_sparsity, get_template_extremum_channel, ) -from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive -from spikeinterface.core.template import Templates - +from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive, DetectPeakMatchedFiltering from .base import BaseTemplateMatching, _base_matching_dtype @@ -25,7 +21,7 @@ class TridesclousPeeler(BaseTemplateMatching): """ - Template-matching ported from Tridesclous sorter. 
+ Template-matching used by Tridesclous sorter. The idea of this peeler is pretty simple. 1. Find peaks @@ -34,8 +30,10 @@ class TridesclousPeeler(BaseTemplateMatching): 4. remove it from traces. 5. in the residual find peaks again - This method is quite fast but don't give exelent results to resolve - spike collision when templates have high similarity. + Contrary tp circus_peeler or wobble, this template matching is working directly one the waveforms. + There is no SVD decomposition + + """ def __init__( @@ -45,26 +43,29 @@ def __init__( parents=None, templates=None, peak_sign="neg", + exclude_sweep_ms=0.5, peak_shift_ms=0.2, detect_threshold=5, noise_levels=None, - radius_um=100.0, - num_closest=5, - sample_shift=3, - ms_before=0.8, - ms_after=1.2, - num_peeler_loop=2, - num_template_try=1, + use_fine_detector=True, + # TODO optimize theses radius + detection_radius_um=80.0, + cluster_radius_um=150.0, + amplitude_fitting_radius_um=150.0, + sample_shift=2, + ms_before=0.5, + ms_after=0.8, + max_peeler_loop=2, + amplitude_limits=(0.7, 1.4), ): BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) - # maybe in base? - self.templates_array = templates.get_dense_templates() - unit_ids = templates.unit_ids channel_ids = recording.channel_ids + num_templates = unit_ids.size + sr = recording.sampling_frequency self.nbefore = templates.nbefore @@ -82,8 +83,9 @@ def __init__( s1 = -(templates.nafter - nafter_short) if s1 == 0: s1 = None + # TODO check with out copy - self.templates_short = self.templates_array[:, slice(s0, s1), :].copy() + self.sparse_templates_array_short = templates.templates_array[:, slice(s0, s1), :].copy() self.peak_shift = int(peak_shift_ms / 1000 * sr) @@ -92,12 +94,12 @@ def __init__( self.abs_thresholds = noise_levels * detect_threshold channel_distance = get_channel_distances(recording) - self.neighbours_mask = channel_distance < radius_um + self.neighbours_mask = channel_distance <= detection_radius_um if templates.sparsity is not None: - self.template_sparsity = templates.sparsity.mask + self.sparsity_mask = templates.sparsity.mask else: - self.template_sparsity = np.ones((unit_ids.size, channel_ids.size), dtype=bool) + self.sparsity_mask = np.ones((unit_ids.size, channel_ids.size), dtype=bool) extremum_chan = get_template_extremum_channel(templates, peak_sign=peak_sign, outputs="index") # as numpy vector @@ -109,72 +111,108 @@ def __init__( # distance between units import scipy - unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") - - # seach for closet units and unitary discriminant vector - closest_units = [] - for unit_ind, unit_id in enumerate(unit_ids): - order = np.argsort(unit_distances[unit_ind, :]) - closest_u = np.arange(unit_ids.size)[order].tolist() - closest_u.remove(unit_ind) - closest_u = np.array(closest_u[:num_closest]) - - # compute unitary discriminent vector - (chans,) = np.nonzero(self.template_sparsity[unit_ind, :]) - template_sparse = self.templates_array[unit_ind, :, :][:, chans] - closest_vec = [] - # against N closets - for u in closest_u: - vec = self.templates_array[u, :, :][:, chans] - template_sparse - vec /= np.sum(vec**2) - closest_vec.append((u, vec)) - # against noise - closest_vec.append((None, -template_sparse / np.sum(template_sparse**2))) - - closest_units.append(closest_vec) - - self.closest_units = closest_units - - # distance channel from unit - import scipy - - distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, 
metric="euclidean") - near_cluster_mask = distances < radius_um - # nearby cluster for each channel + distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean") + near_cluster_mask = distances <= cluster_radius_um self.possible_clusters_by_channel = [] for channel_index in range(distances.shape[0]): (cluster_inds,) = np.nonzero(near_cluster_mask[channel_index, :]) self.possible_clusters_by_channel.append(cluster_inds) + # precompute template norms ons sparse channels + self.template_norms = np.zeros(num_templates, dtype="float32") + for i in range(unit_ids.size): + chan_mask = self.sparsity_mask[i, :] + n = np.sum(chan_mask) + template = templates.templates_array[i, :, :n] + self.template_norms[i] = np.sum(template**2) + + # + distances = scipy.spatial.distance.cdist(channel_locations, channel_locations, metric="euclidean") + self.near_chan_mask = distances <= amplitude_fitting_radius_um + self.possible_shifts = np.arange(-sample_shift, sample_shift + 1, dtype="int64") - self.num_peeler_loop = num_peeler_loop - self.num_template_try = num_template_try + self.max_peeler_loop = max_peeler_loop + self.amplitude_limits = amplitude_limits + + self.fast_spike_detector = DetectPeakLocallyExclusive( + recording=recording, + peak_sign=peak_sign, + detect_threshold=detect_threshold, + exclude_sweep_ms=exclude_sweep_ms, + radius_um=detection_radius_um, + noise_levels=noise_levels, + ) - self.margin = max(self.nbefore, self.nafter) * 2 + ##get prototype from best channel of each template + prototype = np.zeros(self.nbefore + self.nafter, dtype="float32") + for i in range(num_templates): + template = templates.templates_array[i, :, :] + chan_ind = np.argmax(np.abs(template[self.nbefore, :])) + if template[self.nbefore, chan_ind] != 0: + prototype += template[:, chan_ind] / np.abs(template[self.nbefore, chan_ind]) + prototype /= np.abs(prototype[self.nbefore]) + + # import matplotlib.pyplot as plt + # fig,ax = plt.subplots() + # ax.plot(prototype) + # plt.show() + + self.use_fine_detector = use_fine_detector + if self.use_fine_detector: + self.fine_spike_detector = DetectPeakMatchedFiltering( + recording=recording, + prototype=prototype, + ms_before=templates.nbefore / sr * 1000.0, + peak_sign="neg", + detect_threshold=detect_threshold, + exclude_sweep_ms=exclude_sweep_ms, + radius_um=detection_radius_um, + weight_method=dict( + z_list_um=np.array([50.0]), + sigma_3d=2.5, + mode="exponential_3d", + ), + noise_levels=None, + ) + + self.detector_margin0 = self.fast_spike_detector.get_trace_margin() + self.detector_margin1 = self.fine_spike_detector.get_trace_margin() if use_fine_detector else 0 + self.peeler_margin = max(self.nbefore, self.nafter) * 2 + self.margin = max(self.peeler_margin, self.detector_margin0, self.detector_margin1) def get_trace_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): - traces = traces.copy() + + # TODO check if this is usefull + residuals = traces.copy() all_spikes = [] level = 0 + spikes_prev_loop = np.zeros(0, dtype=_base_matching_dtype) + use_fine_detector_level = False while True: - # spikes = _tdc_find_spikes(traces, d, level=level) - spikes = self._find_spikes_one_level(traces, level=level) - keep = spikes["cluster_index"] >= 0 - - if not np.any(keep): - break - all_spikes.append(spikes[keep]) + # print('level', level) + spikes = self._find_spikes_one_level(residuals, spikes_prev_loop, use_fine_detector_level, level) + if spikes.size > 0: + all_spikes.append(spikes) level += 
1 - if level == self.num_peeler_loop: - break + # TODO concatenate all spikes for this instead of prev loop + spikes_prev_loop = spikes + + if (spikes.size == 0) or (level == self.max_peeler_loop): + if self.use_fine_detector and not use_fine_detector_level: + # extra loop with fine detector + use_fine_detector_level = True + level = self.max_peeler_loop - 1 + continue + else: + break if len(all_spikes) > 0: all_spikes = np.concatenate(all_spikes) @@ -185,13 +223,34 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): return all_spikes - def _find_spikes_one_level(self, traces, level=0): + def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, level): - peak_traces = traces[self.margin // 2 : -self.margin // 2, :] - peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( - peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask - ) - peak_sample_ind += self.margin // 2 + # print(use_fine_detector, level) + + # TODO change the threhold dynaically depending the level + # peak_traces = traces[self.detector_margin : -self.detector_margin, :] + + # peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( + # peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask + # ) + + if use_fine_detector: + peak_detector = self.fine_spike_detector + else: + peak_detector = self.fast_spike_detector + + detector_margin = peak_detector.get_trace_margin() + if self.peeler_margin > detector_margin: + margin_shift = self.peeler_margin - detector_margin + sl = slice(margin_shift, -margin_shift) + else: + sl = slice(None) + margin_shift = 0 + peak_traces = traces[sl, :] + (peaks,) = peak_detector.compute(peak_traces, None, None, 0, self.margin) + peak_sample_ind = peaks["sample_index"] + peak_chan_ind = peaks["channel_index"] + peak_sample_ind += margin_shift peak_amplitude = traces[peak_sample_ind, peak_chan_ind] order = np.argsort(np.abs(peak_amplitude))[::-1] @@ -200,153 +259,438 @@ def _find_spikes_one_level(self, traces, level=0): spikes = np.zeros(peak_sample_ind.size, dtype=_base_matching_dtype) spikes["sample_index"] = peak_sample_ind - spikes["channel_index"] = peak_chan_ind # TODO need to put the channel from template + spikes["channel_index"] = peak_chan_ind - possible_shifts = self.possible_shifts - distances_shift = np.zeros(possible_shifts.size) + distances_shift = np.zeros(self.possible_shifts.size) - for i in range(peak_sample_ind.size): + delta_sample = max(self.nbefore, self.nafter) # TODO check this maybe add margin + # neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) + + # neighbors in actual and previous level + neighbors_spikes_inds = get_neighbors_spikes( + np.concatenate([spikes["sample_index"], spikes_prev_loop["sample_index"]]), + np.concatenate([spikes["channel_index"], spikes_prev_loop["channel_index"]]), + delta_sample, + self.near_chan_mask, + ) + + for i in range(spikes.size): sample_index = peak_sample_ind[i] chan_ind = peak_chan_ind[i] possible_clusters = self.possible_clusters_by_channel[chan_ind] if possible_clusters.size > 0: - # ~ s0 = sample_index - d['nbefore'] - # ~ s1 = sample_index + d['nafter'] + cluster_index = get_most_probable_cluster( + traces, + self.sparse_templates_array_short, + possible_clusters, + sample_index, + chan_ind, + self.nbefore_short, + self.nafter_short, + self.sparsity_mask, + ) + + # import matplotlib.pyplot as plt + # fig, 
ax = plt.subplots() + # chans = np.any(self.sparsity_mask[possible_clusters, :], axis=0) + # wf = traces[sample_index - self.nbefore : sample_index + self.nafter][:, chans] + # ax.plot(wf.T.flatten(), color='k') + # dense_templates_array = self.templates.get_dense_templates() + # for c_ind in possible_clusters: + # template = dense_templates_array[c_ind, :, :][:, chans] + # ax.plot(template.T.flatten()) + # if c_ind == cluster_index: + # ax.plot(template.T.flatten(), color='m', ls='--') + # ax.set_title(f"use_fine_detector{use_fine_detector} level{level}") + # plt.show() + + chan_sparsity_mask = self.sparsity_mask[cluster_index, :] + + # find best shift + numba_best_shift_sparse( + traces, + self.sparse_templates_array_short[cluster_index, :, :], + sample_index, + self.nbefore_short, + self.possible_shifts, + distances_shift, + chan_sparsity_mask, + ) + + ind_shift = np.argmin(distances_shift) + shift = self.possible_shifts[ind_shift] + + # TODO DEBUG shift later + spikes["sample_index"][i] += shift + + spikes["cluster_index"][i] = cluster_index + + # check that the the same cluster is not already detected at same place + # this can happen for small template the substract forvever the traces + outer_neighbors_inds = [ind for ind in neighbors_spikes_inds[i] if ind > i and ind >= spikes.size] + is_valid = True + for b in outer_neighbors_inds: + b = b - spikes.size + if (spikes[i]["sample_index"] == spikes_prev_loop[b]["sample_index"]) and ( + spikes[i]["cluster_index"] == spikes_prev_loop[b]["cluster_index"] + ): + is_valid = False + + if is_valid: + # temporary assign a cluster to neighbors if not done yet + inner_neighbors_inds = [ind for ind in neighbors_spikes_inds[i] if (ind > i and ind < spikes.size)] + for b in inner_neighbors_inds: + spikes["cluster_index"][b] = get_most_probable_cluster( + traces, + self.sparse_templates_array_short, + possible_clusters, + spikes["sample_index"][b], + spikes["channel_index"][b], + self.nbefore_short, + self.nafter_short, + self.sparsity_mask, + ) + + amp = fit_one_amplitude_with_neighbors( + spikes[i], + spikes[inner_neighbors_inds], + traces, + self.sparsity_mask, + self.templates.templates_array, + self.template_norms, + self.nbefore, + self.nafter, + ) - # ~ wf = traces[s0:s1, :] + low_lim, up_lim = self.amplitude_limits + if low_lim <= amp <= up_lim: + spikes["amplitude"][i] = amp + wanted_channel_mask = np.ones(traces.shape[1], dtype=bool) # TODO move this before the loop + construct_prediction_sparse( + spikes[i : i + 1], + traces, + self.templates.templates_array, + self.sparsity_mask, + wanted_channel_mask, + self.nbefore, + additive=False, + ) + elif low_lim > amp: + # print("bad amp", amp) + spikes["cluster_index"][i] = -1 + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # sample_ind = spikes["sample_index"][i] + # print(chan_sparsity_mask) + # wf = traces[sample_ind - self.nbefore : sample_ind + self.nafter][:, chan_sparsity_mask] + # dense_templates_array = self.templates.get_dense_templates() + # template = dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] + # ax.plot(wf.T.flatten()) + # ax.plot(template.T.flatten()) + # ax.plot(template.T.flatten() * amp) + # ax.set_title(f"amp{amp} use_fine_detector{use_fine_detector} level{level}") + # plt.show() + else: + # amp > up_lim + # TODO should try other cluster for the fit!! 
+ # spikes["cluster_index"][i] = -1 + + # force amplitude to be one and need a fiting at next level + spikes["amplitude"][i] = 1 + + # print(amp) + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # sample_ind = spikes["sample_index"][i] + # wf = traces[sample_ind - self.nbefore : sample_ind + self.nafter][:, chan_sparsity_mask] + # dense_templates_array = self.templates.get_dense_templates() + # template = dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] + # ax.plot(wf.T.flatten()) + # ax.plot(template.T.flatten()) + # ax.plot(template.T.flatten() * amp) + # ax.set_title(f"amp{amp} use_fine_detector{use_fine_detector} level{level}") + # plt.show() + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # chans = np.any(self.sparsity_mask[possible_clusters, :], axis=0) + # wf = traces[sample_index - self.nbefore : sample_index + self.nafter][:, chans] + # ax.plot(wf.T.flatten(), color='k') + # dense_templates_array = self.templates.get_dense_templates() + # for c_ind in possible_clusters: + # template = dense_templates_array[c_ind, :, :][:, chans] + # ax.plot(template.T.flatten()) + # if c_ind == cluster_index: + # ax.plot(template.T.flatten(), color='m', ls='--') + # ax.set_title(f"use_fine_detector{use_fine_detector} level{level}") + # plt.show() - s0 = sample_index - self.nbefore_short - s1 = sample_index + self.nafter_short - wf_short = traces[s0:s1, :] + else: + # not valid because already detected + spikes["cluster_index"][i] = -1 - ## pure numpy with cluster spasity - # distances = np.sum(np.sum((templates[possible_clusters, :, :] - wf[None, : , :])**2, axis=1), axis=1) + else: + # no possible cluster in neighborhood for this channel + spikes["cluster_index"][i] = -1 - ## pure numpy with cluster+channel spasity - # union_channels, = np.nonzero(np.any(d['template_sparsity'][possible_clusters, :], axis=0)) - # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) + # delta_sample = self.nbefore + self.nafter + # # TODO benchmark this and make this faster + # neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) + # for i in range(spikes.size): + # amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[neighbors_spikes_inds[i]], traces, + # self.sparsity_mask, self.templates.templates_array, self.nbefore, self.nafter) + # spikes["amplitude"][i] = amp - ## numba with cluster+channel spasity - union_channels = np.any(self.template_sparsity[possible_clusters, :], axis=0) - # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters) - distances = numba_sparse_dist(wf_short, self.templates_short, union_channels, possible_clusters) + keep = spikes["cluster_index"] >= 0 + spikes = spikes[keep] - # DEBUG - # ~ ind = np.argmin(distances) - # ~ cluster_index = possible_clusters[ind] + # keep = (spikes["amplitude"] >= 0.7) & (spikes["amplitude"] <= 1.4) + # spikes = spikes[keep] - for ind in np.argsort(distances)[: self.num_template_try]: - cluster_index = possible_clusters[ind] + # sparse_templates_array = self.templates.templates_array + # wanted_channel_mask = np.ones(traces.shape[1], dtype=bool) + # assert np.sum(wanted_channel_mask) == traces.shape[1] # TODO remove this DEBUG later + # construct_prediction_sparse(spikes, traces, sparse_templates_array, self.sparsity_mask, wanted_channel_mask, self.nbefore, additive=False) - chan_sparsity = 
self.template_sparsity[cluster_index, :] - template_sparse = self.templates_array[cluster_index, :, :][:, chan_sparsity] + return spikes - # find best shift - ## pure numpy version - # for s, shift in enumerate(possible_shifts): - # wf_shift = traces[s0 + shift: s1 + shift, chan_sparsity] - # distances_shift[s] = np.sum((template_sparse - wf_shift)**2) - # ind_shift = np.argmin(distances_shift) - # shift = possible_shifts[ind_shift] +def get_most_probable_cluster( + traces, + sparse_templates_array, + possible_clusters, + sample_index, + chan_ind, + nbefore_short, + nafter_short, + template_sparsity_mask, +): + s0 = sample_index - nbefore_short + s1 = sample_index + nafter_short + wf_short = traces[s0:s1, :] - ## numba version - numba_best_shift( - traces, - self.templates_array[cluster_index, :, :], - sample_index, - self.nbefore, - possible_shifts, - distances_shift, - chan_sparsity, - ) - ind_shift = np.argmin(distances_shift) - shift = possible_shifts[ind_shift] - - sample_index = sample_index + shift - s0 = sample_index - self.nbefore - s1 = sample_index + self.nafter - wf_sparse = traces[s0:s1, chan_sparsity] - - # accept or not - - centered = wf_sparse - template_sparse - accepted = True - for other_ind, other_vector in self.closest_units[cluster_index]: - v = np.sum(centered * other_vector) - if np.abs(v) > 0.5: - accepted = False - break - - if accepted: - # ~ if ind != np.argsort(distances)[0]: - # ~ print('not first one', np.argsort(distances), ind) - break - - if accepted: - amplitude = 1.0 - - # remove template - template = self.templates_array[cluster_index, :, :] - s0 = sample_index - self.nbefore - s1 = sample_index + self.nafter - traces[s0:s1, :] -= template * amplitude + ## numba with cluster+channel spasity + union_channels = np.any(template_sparsity_mask[possible_clusters, :], axis=0) + distances = numba_sparse_distance( + wf_short, sparse_templates_array, template_sparsity_mask, union_channels, possible_clusters + ) - else: - cluster_index = -1 - amplitude = 0.0 + ind = np.argmin(distances) + cluster_index = possible_clusters[ind] - else: - cluster_index = -1 - amplitude = 0.0 + return cluster_index - spikes["cluster_index"][i] = cluster_index - spikes["amplitude"][i] = amplitude - return spikes +def get_neighbors_spikes(sample_inds, chan_inds, delta_sample, near_chan_mask): + + neighbors_spikes_inds = [] + for i in range(sample_inds.size): + + inds = np.flatnonzero(np.abs(sample_inds - sample_inds[i]) < delta_sample) + neighb = [] + for ind in inds: + if near_chan_mask[chan_inds[i], chan_inds[ind]] and i != ind: + neighb.append(ind) + neighbors_spikes_inds.append(neighb) + + return neighbors_spikes_inds + + +def fit_one_amplitude_with_neighbors( + spike, neighbors_spikes, traces, template_sparsity_mask, sparse_templates_array, template_norms, nbefore, nafter +): + """ + Fit amplitude one spike of one spike with/without neighbors + + """ + + import scipy.linalg + + cluster_index = spike["cluster_index"] + sample_index = spike["sample_index"] + chan_sparsity_mask = template_sparsity_mask[cluster_index, :] + num_chans = np.sum(chan_sparsity_mask) + if num_chans == 0: + # protect against empty template because too sparse + return 0.0 + start, stop = sample_index - nbefore, sample_index + nafter + if neighbors_spikes is None or (neighbors_spikes.size == 0): + template = sparse_templates_array[cluster_index, :, :num_chans] + wf = traces[start:stop, :][:, chan_sparsity_mask] + # TODO precompute template norms + amplitude = np.sum(template.flatten() * wf.flatten()) / 
template_norms[cluster_index] + else: + + lim0 = min(start, np.min(neighbors_spikes["sample_index"]) - nbefore) + lim1 = max(stop, np.max(neighbors_spikes["sample_index"]) + nafter) + + local_traces = traces[lim0:lim1, :][:, chan_sparsity_mask] + mask_not_fitted = (neighbors_spikes["amplitude"] == 0.0) & (neighbors_spikes["cluster_index"] >= 0) + local_spike = spike.copy() + local_spike["sample_index"] -= lim0 + local_spike["amplitude"] = 1.0 + + local_neighbors_spikes = neighbors_spikes.copy() + local_neighbors_spikes["sample_index"] -= lim0 + local_neighbors_spikes["amplitude"][:] = 1.0 + + num_spikes_to_fit = 1 + np.sum(mask_not_fitted) + x = np.zeros((lim1 - lim0, num_chans, num_spikes_to_fit), dtype="float32") + wanted_channel_mask = chan_sparsity_mask + construct_prediction_sparse( + np.array([local_spike]), + x[:, :, 0], + sparse_templates_array, + template_sparsity_mask, + chan_sparsity_mask, + nbefore, + True, + ) + + j = 1 + for i in range(neighbors_spikes.size): + if mask_not_fitted[i]: + # add to one regressor + construct_prediction_sparse( + local_neighbors_spikes[i : i + 1], + x[:, :, j], + sparse_templates_array, + template_sparsity_mask, + chan_sparsity_mask, + nbefore, + True, + ) + j += 1 + elif local_neighbors_spikes[neighbors_spikes[i]]["sample_index"] >= 0: + # remove from traces + construct_prediction_sparse( + local_neighbors_spikes[i : i + 1], + local_traces, + sparse_templates_array, + template_sparsity_mask, + chan_sparsity_mask, + nbefore, + False, + ) + # else: + # pass + + x = x.reshape(-1, num_spikes_to_fit) + y = local_traces.flatten() + + res = scipy.linalg.lstsq(x, y, cond=None, lapack_driver="gelsd") + amplitudes = res[0] + amplitude = amplitudes[0] + + # import matplotlib.pyplot as plt + # x_plot = x.reshape((lim1 - lim0, num_chans, num_spikes_to_fit)).swapaxes(0, 1).reshape(-1, num_spikes_to_fit) + # pred = x @ amplitudes + # pred_plot = pred.reshape(-1, num_chans).T.flatten() + # y_plot = y.reshape(-1, num_chans).T.flatten() + # fig, ax = plt.subplots() + # ax.plot(x_plot, color='b') + # print(x_plot.shape, y_plot.shape) + # ax.plot(y_plot, color='g') + # ax.plot(pred_plot , color='r') + # ax.set_title(f"{amplitudes}") + # # ax.set_title(f"{amplitudes} {amp_dot}") + # plt.show() + + return amplitude if HAVE_NUMBA: @jit(nopython=True) - def numba_sparse_dist(wf, templates, union_channels, possible_clusters): + def construct_prediction_sparse( + spikes, traces, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, additive + ): + # must have np.sum(wanted_channel_mask) == traces.shape[0] + total_chans = wanted_channel_mask.shape[0] + for spike in spikes: + ind0 = spike["sample_index"] - nbefore + ind1 = ind0 + sparse_templates_array.shape[1] + cluster_index = spike["cluster_index"] + amplitude = spike["amplitude"] + chan_in_template = 0 + chan_in_trace = 0 + for chan in range(total_chans): + if wanted_channel_mask[chan]: + if template_sparsity_mask[cluster_index, chan]: + if additive: + traces[ind0:ind1, chan_in_trace] += ( + sparse_templates_array[cluster_index, :, chan_in_template] * amplitude + ) + else: + traces[ind0:ind1, chan_in_trace] -= ( + sparse_templates_array[cluster_index, :, chan_in_template] * amplitude + ) + chan_in_template += 1 + chan_in_trace += 1 + else: + if template_sparsity_mask[cluster_index, chan]: + chan_in_template += 1 + + @jit(nopython=True) + def numba_sparse_distance( + wf, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, possible_clusters + ): """ numba implementation that compute 
distance from template with sparsity - handle by two separate vectors + + wf is dense + sparse_templates_array is sparse with the template_sparsity_mask """ - total_cluster, width, num_chan = templates.shape + width, total_chans = wf.shape num_cluster = possible_clusters.shape[0] distances = np.zeros((num_cluster,), dtype=np.float32) for i in prange(num_cluster): cluster_index = possible_clusters[i] sum_dist = 0.0 - for chan_ind in range(num_chan): - if union_channels[chan_ind]: - for s in range(width): - v = wf[s, chan_ind] - t = templates[cluster_index, s, chan_ind] - sum_dist += (v - t) ** 2 + chan_in_template = 0 + for chan in range(total_chans): + if wanted_channel_mask[chan]: + if template_sparsity_mask[cluster_index, chan]: + for s in range(width): + v = wf[s, chan] + t = sparse_templates_array[cluster_index, s, chan_in_template] + sum_dist += (v - t) ** 2 + chan_in_template += 1 + else: + for s in range(width): + v = wf[s, chan] + t = 0 + sum_dist += (v - t) ** 2 + else: + if template_sparsity_mask[cluster_index, chan]: + chan_in_template += 1 distances[i] = sum_dist return distances @jit(nopython=True) - def numba_best_shift(traces, template, sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity): + def numba_best_shift_sparse( + traces, sparse_template, sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity + ): """ numba implementation to compute several sample shift before template substraction """ - width, num_chan = template.shape + width = sparse_template.shape[0] + total_chans = traces.shape[1] n_shift = possible_shifts.size for i in range(n_shift): shift = possible_shifts[i] sum_dist = 0.0 - for chan_ind in range(num_chan): - if chan_sparsity[chan_ind]: + chan_in_template = 0 + for chan in range(total_chans): + if chan_sparsity[chan]: for s in range(width): - v = traces[sample_index - nbefore + s + shift, chan_ind] - t = template[s, chan_ind] + v = traces[sample_index - nbefore + s + shift, chan] + t = sparse_template[s, chan_in_template] sum_dist += (v - t) ** 2 + chan_in_template += 1 distances_shift[i] = sum_dist return distances_shift diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 2531a922da..3099448b11 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -348,7 +348,7 @@ def __init__( BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) - templates_array = templates.get_dense_templates().astype(np.float32, casting="safe") + templates_array = templates.get_dense_templates().astype(np.float32) # Aggregate useful parameters/variables for handy access in downstream functions params = WobbleParameters(**parameters) diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index cbf1d29932..7cd899a3bb 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -9,8 +9,8 @@ from spikeinterface.sortingcomponents.tests.common import make_dataset -job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True) -# job_kwargs = dict(n_jobs=1, chunk_duration="500ms", progress_bar=True) +# job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True) +job_kwargs = dict(n_jobs=1, chunk_duration="500ms", progress_bar=True) def 
get_sorting_analyzer(): @@ -45,7 +45,7 @@ def test_find_spikes_from_templates(method, sorting_analyzer): "templates": templates, } method_kwargs = {} - if method in ("naive", "tdc-peeler", "circus"): + if method in ("naive", "tdc-peeler", "circus", "tdc-peeler2"): method_kwargs["noise_levels"] = noise_levels # method_kwargs["wobble"] = { @@ -61,26 +61,28 @@ def test_find_spikes_from_templates(method, sorting_analyzer): # print(info) - # DEBUG = True + DEBUG = True - # if DEBUG: - # import matplotlib.pyplot as plt - # import spikeinterface.full as si + if DEBUG: + import matplotlib.pyplot as plt + import spikeinterface.full as si - # sorting_analyzer.compute("waveforms") - # sorting_analyzer.compute("templates") + sorting_analyzer.compute("waveforms") + sorting_analyzer.compute("templates") - # gt_sorting = sorting_analyzer.sorting + gt_sorting = sorting_analyzer.sorting - # sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], recording.sampling_frequency) + sorting = NumpySorting.from_times_labels( + spikes["sample_index"], spikes["cluster_index"], recording.sampling_frequency + ) - # ##metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) + ##metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) - # fig, ax = plt.subplots() - # comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting) - # si.plot_agreement_matrix(comp, ax=ax) - # ax.set_title(method) - # plt.show() + # fig, ax = plt.subplots() + # comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting) + # si.plot_agreement_matrix(comp, ax=ax) + # ax.set_title(method) + # plt.show() if __name__ == "__main__": @@ -88,6 +90,6 @@ def test_find_spikes_from_templates(method, sorting_analyzer): # method = "naive" # method = "tdc-peeler" # method = "circus" - method = "circus-omp-svd" - # method = "wobble" + # method = "circus-omp-svd" + method = "wobble" test_find_spikes_from_templates(method, sorting_analyzer) From 957861fd43861a124880c41c5cbcc8921db30889 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 14 Oct 2024 17:19:07 +0200 Subject: [PATCH 06/11] Sparsify the weights --- src/spikeinterface/sortingcomponents/peak_detection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index d608c5d105..51b3e4dc77 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -631,7 +631,7 @@ def __init__( weight_method={}, ): PeakDetector.__init__(self, recording, return_output=True) - + import scipy if not HAVE_NUMBA: raise ModuleNotFoundError('matched_filtering" needs numba which is not installed') @@ -664,7 +664,7 @@ def __init__( self.num_templates *= 2 self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1) - + self.weights = scipy.sparse.csr_matrix(self.weights) random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs) conv_random_data = self.get_convolved_traces(random_data) medians = np.median(conv_random_data, axis=1) @@ -737,7 +737,7 @@ def get_convolved_traces(self, traces): import scipy.signal tmp = scipy.signal.oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") - scalar_products = np.dot(self.weights, tmp) + scalar_products = self.weights.dot(tmp) return scalar_products From 5568e1a3cd98f6f9c77953c294fdd558c4457e6c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Oct 2024 15:23:19 +0000 Subject: [PATCH 07/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/peak_detection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 51b3e4dc77..2961f11981 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -632,6 +632,7 @@ def __init__( ): PeakDetector.__init__(self, recording, return_output=True) import scipy + if not HAVE_NUMBA: raise ModuleNotFoundError('matched_filtering" needs numba which is not installed') From ba847a8720c54558f94d918ee2e6a8797713729d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Oct 2024 17:52:59 +0000 Subject: [PATCH 08/11] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/psf/black: 24.8.0 → 24.10.0](https://github.com/psf/black/compare/24.8.0...24.10.0) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1e133694ba..4c36d6fb86 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 24.8.0 + rev: 24.10.0 hooks: - id: black files: ^src/ From 14278161efe44c5955dc2072a5354de73dcf6bb3 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 14 Oct 2024 23:36:43 +0200 Subject: [PATCH 09/11] Imports --- src/spikeinterface/sortingcomponents/peak_detection.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 2961f11981..d2d1afaafb 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -631,7 +631,7 @@ def __init__( weight_method={}, ): PeakDetector.__init__(self, recording, return_output=True) - import scipy + from scipy.sparse import csr_matrix if not HAVE_NUMBA: raise ModuleNotFoundError('matched_filtering" needs numba which is not installed') @@ -665,7 +665,7 @@ def __init__( self.num_templates *= 2 self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1) - self.weights = scipy.sparse.csr_matrix(self.weights) + self.weights = csr_matrix(self.weights) random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs) conv_random_data = self.get_convolved_traces(random_data) medians = np.median(conv_random_data, axis=1) @@ -735,9 +735,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (local_peaks,) def get_convolved_traces(self, traces): - import scipy.signal - - tmp = scipy.signal.oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") + from scipy.signal import oaconvolve + tmp = oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") scalar_products = self.weights.dot(tmp) return scalar_products From b9f2cc803b295097a6cf4ae95eee5f82d5be222f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Oct 
2024 21:37:04 +0000 Subject: [PATCH 10/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/peak_detection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index d2d1afaafb..134481289e 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -736,6 +736,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): def get_convolved_traces(self, traces): from scipy.signal import oaconvolve + tmp = oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") scalar_products = self.weights.dot(tmp) return scalar_products From 3e608c60b96a33f7cccdabb1adcf0a92ab3ada78 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 15 Oct 2024 12:15:37 +0200 Subject: [PATCH 11/11] Torch support for matching engines circus and OMP * Fixes * Patches * Fixes for SC2 and for split clustering * debugging clustering * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Torch for convolutions * Forcing data structures to be float32 * Device and wobble * WIP * Speeding up wobble * WIP * WIP * Troch * WIP torch * WIP * WIP * Addition of a detection node for coherence * Doc * WIP * Default params * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * Handling context with torch on the fly * Dealing with torch * Adding support for torch in matching engines * Automatic handling of torch * Default back * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Adding gather_func to find_spikes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Gathering mode more explicit for matching * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * WIP * Fixes for SC2 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * Simplifications * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Naming for Sam * Optimize circus matching engine * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Optimizations * Remove the limit to chunk sizes in circus-omp-svd * WIP * Wobble also * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Wobble also * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Oups * WIP * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixes * Backward compatibility* * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Naming * Cleaning * Bringing back context for peak detectors * Update src/spikeinterface/benchmark/benchmark_matching.py * Update src/spikeinterface/sortingcomponents/matching/circus.py * WIP * Patch imports * WIP * WIP * [pre-commit.ci] 
auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixing tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * KSPeeler * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Moving KS in a new PR * Moving KS in a new PR * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Allow spawn and cuda for circus * Add push_to_torch to allow pickling of objects * Default * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleaning docs * WIP --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Garcia Samuel --- .../benchmark/benchmark_matching.py | 1 + .../sorters/internal/spyking_circus2.py | 8 +- .../sortingcomponents/clustering/circus.py | 6 +- .../sortingcomponents/matching/circus.py | 134 ++++++++++++------ .../sortingcomponents/matching/wobble.py | 111 +++++++++++---- .../sortingcomponents/peak_detection.py | 11 +- .../sortingcomponents/tests/test_wobble.py | 53 +++++-- 7 files changed, 235 insertions(+), 89 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index 3799fa19b3..1934b65ef4 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -80,6 +80,7 @@ def plot_performances_comparison(self, **kwargs): def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) + import matplotlib.pyplot as plt fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 211adba990..eed693b343 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -27,7 +27,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75}, "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, - "whitening": {"mode": "local", "regularize": True}, + "whitening": {"mode": "local", "regularize": False}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { "method": "uniform", @@ -100,6 +100,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): except: HAVE_HDBSCAN = False + try: + import torch + except ImportError: + HAVE_TORCH = False + print("spykingcircus2 could benefit from using torch. 
Consider installing it") + assert HAVE_HDBSCAN, "spykingcircus2 needs hdbscan to be installed" # this is importanted only on demand because numba import are too heavy diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index b7e71d3b45..99c59f493e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -92,7 +92,7 @@ def main_function(cls, recording, peaks, params): # SVD for time compression few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter)) few_wfs = extract_waveform_at_max_channel( - recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"] + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) wfs = few_wfs[:, :, 0] @@ -141,7 +141,7 @@ def main_function(cls, recording, peaks, params): all_pc_data = run_node_pipeline( recording, pipeline_nodes, - params["job_kwargs"], + job_kwargs, job_name="extracting features", ) @@ -176,7 +176,7 @@ def main_function(cls, recording, peaks, params): _ = run_node_pipeline( recording, pipeline_nodes, - params["job_kwargs"], + job_kwargs, job_name="extracting features", gather_mode="npy", gather_kwargs=dict(exist_ok=True), diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index d1b2139c5b..3b97f2dc6a 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -18,6 +18,15 @@ ("segment_index", "int64"), ] +try: + import torch + import torch.nn.functional as F + + HAVE_TORCH = True + from torch.nn.functional import conv1d +except ImportError: + HAVE_TORCH = False + from .base import BaseTemplateMatching @@ -43,9 +52,9 @@ def compress_templates( temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False) # Keep only the strongest components - temporal = temporal[:, :, :approx_rank] - singular = singular[:, :approx_rank] - spatial = spatial[:, :approx_rank, :] + temporal = temporal[:, :, :approx_rank].astype(np.float32) + singular = singular[:, :approx_rank].astype(np.float32) + spatial = spatial[:, :approx_rank, :].astype(np.float32) if return_new_templates: templates_array = np.matmul(temporal * singular[:, np.newaxis, :], spatial) @@ -107,18 +116,22 @@ class CircusOMPSVDPeeler(BaseTemplateMatching): Parameters ---------- - amplitude: tuple + amplitude : tuple (Minimal, Maximal) amplitudes allowed for every template - max_failures: int + max_failures : int Stopping criteria of the OMP algorithm, as number of retry while updating amplitudes - sparse_kwargs: dict + sparse_kwargs : dict Parameters to extract a sparsity mask from the waveform_extractor, if not already sparse. - rank: int, default: 5 + rank : int, default: 5 Number of components used internally by the SVD - vicinity: int + vicinity : int Size of the area surrounding a spike to perform modification (expressed in terms of template temporal width) + engine : string in ["numpy", "torch", "auto"]. Default "auto" + The engine to use for the convolutions + torch_device : string in ["cpu", "cuda", None]. 
Default "cpu" + Controls torch device if the torch engine is selected ----- """ @@ -148,6 +161,8 @@ def __init__( ignore_inds=[], vicinity=2, precomputed=None, + engine="numpy", + torch_device="cpu", ): BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) @@ -158,6 +173,19 @@ def __init__( self.nafter = templates.nafter self.sampling_frequency = recording.get_sampling_frequency() self.vicinity = vicinity * self.num_samples + assert engine in ["numpy", "torch", "auto"], "engine should be numpy, torch or auto" + if engine == "auto": + if HAVE_TORCH: + self.engine = "torch" + else: + self.engine = "numpy" + else: + if engine == "torch": + assert HAVE_TORCH, "please install torch to use the torch engine" + self.engine = engine + + assert torch_device in ["cuda", "cpu", None] + self.torch_device = torch_device self.amplitudes = amplitudes self.stop_criteria = stop_criteria @@ -183,6 +211,7 @@ def __init__( self.unit_overlaps_tables[i][self.unit_overlaps_indices[i]] = np.arange(len(self.unit_overlaps_indices[i])) self.margin = 2 * self.num_samples + self.is_pushed = False def _prepare_templates(self): @@ -254,6 +283,14 @@ def _prepare_templates(self): self.temporal = np.moveaxis(self.temporal, [0, 1, 2], [1, 2, 0]) self.singular = self.singular.T[:, :, np.newaxis] + def _push_to_torch(self): + if self.engine == "torch": + self.spatial = torch.as_tensor(self.spatial, device=self.torch_device) + self.singular = torch.as_tensor(self.singular, device=self.torch_device) + self.temporal = torch.as_tensor(self.temporal.copy(), device=self.torch_device).swapaxes(0, 1) + self.temporal = torch.flip(self.temporal, (2,)) + self.is_pushed = True + def get_extra_outputs(self): output = {} for key in self._more_output_keys: @@ -268,15 +305,15 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): import scipy from scipy import ndimage - (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32) + if not self.is_pushed: + self._push_to_torch() + (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32) (nrm2,) = scipy.linalg.get_blas_funcs(("nrm2",), dtype=np.float32) - overlaps_array = self.overlaps - omp_tol = np.finfo(np.float32).eps - num_samples = self.nafter + self.nbefore - neighbor_window = num_samples - 1 + neighbor_window = self.num_samples - 1 + if isinstance(self.amplitudes, list): min_amplitude, max_amplitude = self.amplitudes else: @@ -284,27 +321,36 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): min_amplitude = min_amplitude[:, np.newaxis] max_amplitude = max_amplitude[:, np.newaxis] - num_timesteps = len(traces) + if self.engine == "torch": + blank = np.zeros((neighbor_window, self.num_channels), dtype=np.float32) + traces = np.vstack((blank, traces, blank)) + num_timesteps = traces.shape[0] + torch_traces = torch.as_tensor(traces.T[np.newaxis, :, :], device=self.torch_device) + num_templates, num_channels = self.temporal.shape[0], self.temporal.shape[1] + spatially_filtered_data = torch.matmul(self.spatial, torch_traces) + scaled_filtered_data = (spatially_filtered_data * self.singular).swapaxes(0, 1) + scaled_filtered_data_ = scaled_filtered_data.reshape(1, num_templates * num_channels, num_timesteps) + scalar_products = conv1d(scaled_filtered_data_, self.temporal, groups=num_templates, padding="valid") + scalar_products = scalar_products.cpu().numpy()[0, :, self.num_samples - 1 : -neighbor_window] + else: + num_timesteps = traces.shape[0] + num_peaks = num_timesteps - 
neighbor_window + conv_shape = (self.num_templates, num_peaks) + scalar_products = np.zeros(conv_shape, dtype=np.float32) + # Filter using overlap-and-add convolution + spatially_filtered_data = np.matmul(self.spatial, traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * self.singular + from scipy import signal - num_peaks = num_timesteps - num_samples + 1 - conv_shape = (self.num_templates, num_peaks) - scalar_products = np.zeros(conv_shape, dtype=np.float32) + objective_by_rank = signal.oaconvolve(scaled_filtered_data, self.temporal, axes=2, mode="valid") + scalar_products += np.sum(objective_by_rank, axis=0) + + num_peaks = scalar_products.shape[1] # Filter using overlap-and-add convolution if len(self.ignore_inds) > 0: - not_ignored = ~np.isin(np.arange(self.num_templates), self.ignore_inds) - spatially_filtered_data = np.matmul(self.spatial[:, not_ignored, :], traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * self.singular[:, not_ignored, :] - objective_by_rank = scipy.signal.oaconvolve( - scaled_filtered_data, self.temporal[:, not_ignored, :], axes=2, mode="valid" - ) - scalar_products[not_ignored] += np.sum(objective_by_rank, axis=0) scalar_products[self.ignore_inds] = -np.inf - else: - spatially_filtered_data = np.matmul(self.spatial, traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * self.singular - objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, self.temporal, axes=2, mode="valid") - scalar_products += np.sum(objective_by_rank, axis=0) + not_ignored = ~np.isin(np.arange(self.num_templates), self.ignore_inds) num_spikes = 0 @@ -322,7 +368,7 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): is_in_vicinity = np.zeros(0, dtype=np.int32) if self.stop_criteria == "omp_min_sps": - stop_criteria = self.omp_min_sps * np.maximum(self.norms, np.sqrt(self.num_channels * num_samples)) + stop_criteria = self.omp_min_sps * np.maximum(self.norms, np.sqrt(self.num_channels * self.num_samples)) elif self.stop_criteria == "max_failures": num_valids = 0 nb_failures = self.max_failures @@ -354,11 +400,11 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): if num_selection > 0: delta_t = selection[1] - peak_index - idx = np.flatnonzero((delta_t < num_samples) & (delta_t > -num_samples)) + idx = np.flatnonzero((delta_t < self.num_samples) & (delta_t > -self.num_samples)) myline = neighbor_window + delta_t[idx] myindices = selection[0, idx] - local_overlaps = overlaps_array[best_cluster_ind] + local_overlaps = self.overlaps[best_cluster_ind] overlapping_templates = self.unit_overlaps_indices[best_cluster_ind] table = self.unit_overlaps_tables[best_cluster_ind] @@ -436,10 +482,10 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): for i in modified: tmp_best, tmp_peak = sub_selection[:, i] diff_amp = diff_amplitudes[i] * self.norms[tmp_best] - local_overlaps = overlaps_array[tmp_best] + local_overlaps = self.overlaps[tmp_best] overlapping_templates = self.units_overlaps[tmp_best] tmp = tmp_peak - neighbor_window - idx = [max(0, tmp), min(num_peaks, tmp_peak + num_samples)] + idx = [max(0, tmp), min(num_peaks, tmp_peak + self.num_samples)] tdx = [idx[0] - tmp, idx[1] - tmp] to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add @@ -500,27 +546,27 @@ class CircusPeeler(BaseTemplateMatching): Parameters ---------- - peak_sign: str + peak_sign : str Sign of the peak 
(neg, pos, or both) - exclude_sweep_ms: float + exclude_sweep_ms : float The number of samples before/after to classify a peak (should be low) - jitter: int + jitter : int The number of samples considered before/after every peak to search for matches - detect_threshold: int + detect_threshold : int The detection threshold - noise_levels: array + noise_levels : array The noise levels, for every channels - random_chunk_kwargs: dict + random_chunk_kwargs : dict Parameters for computing noise levels, if not provided (sub optimal) - max_amplitude: float + max_amplitude : float Maximal amplitude allowed for every template - min_amplitude: float + min_amplitude : float Minimal amplitude allowed for every template - use_sparse_matrix_threshold: float + use_sparse_matrix_threshold : float If density of the templates is below a given threshold, sparse matrix are used (memory efficient) - sparse_kwargs: dict + sparse_kwargs : dict Parameters to extract a sparsity mask from the waveform_extractor, if not already sparse. ----- diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 3099448b11..59e171fe52 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -8,6 +8,15 @@ from .base import BaseTemplateMatching, _base_matching_dtype from spikeinterface.core.template import Templates +try: + import torch + import torch.nn.functional as F + + HAVE_TORCH = True + from torch.nn.functional import conv1d +except ImportError: + HAVE_TORCH = False + @dataclass class WobbleParameters: @@ -41,6 +50,10 @@ class WobbleParameters: Maximum value for ampltiude scaling of templates. scale_amplitudes : bool If True, scale amplitudes of templates to match spikes. + engine : string in ["numpy", "torch", "auto"]. Default "auto" + The engine to use for the convolutions + torch_device : string in ["cpu", "cuda", None]. 
Default "cpu" + Controls torch device if the torch engine is selected Notes ----- @@ -62,6 +75,8 @@ class WobbleParameters: scale_min: float = 0 scale_max: float = np.inf scale_amplitudes: bool = False + engine: str = "numpy" + torch_device: str = "cpu" def __post_init__(self): assert self.amplitude_variance >= 0, "amplitude_variance must be a non-negative scalar" @@ -344,6 +359,8 @@ def __init__( parents=None, templates=None, parameters={}, + engine="numpy", + torch_device="cpu", ): BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) @@ -352,6 +369,21 @@ def __init__( # Aggregate useful parameters/variables for handy access in downstream functions params = WobbleParameters(**parameters) + + assert engine in ["numpy", "torch", "auto"], "engine should be numpy, torch or auto" + if engine == "auto": + if HAVE_TORCH: + self.engine = "torch" + else: + self.engine = "numpy" + else: + if engine == "torch": + assert HAVE_TORCH, "please install torch to use the torch engine" + self.engine = engine + + assert torch_device in ["cuda", "cpu", None] + self.torch_device = torch_device + template_meta = TemplateMetadata.from_parameters_and_templates(params, templates_array) if not templates.are_templates_sparse(): sparsity = WobbleSparsity.from_parameters_and_templates(params, templates_array) @@ -366,13 +398,21 @@ def __init__( pairwise_convolution = convolve_templates( compressed_templates, params.jitter_factor, params.approx_rank, template_meta.jittered_indices, sparsity ) + norm_squared = compute_template_norm(sparsity.visible_channels, templates_array) + + spatial = np.moveaxis(spatial, [0, 1, 2], [1, 0, 2]) + temporal = np.moveaxis(temporal, [0, 1, 2], [1, 2, 0]) + singular = singular.T[:, :, np.newaxis] + + compressed_templates = (temporal, singular, spatial, temporal_jittered) template_data = TemplateData( compressed_templates=compressed_templates, pairwise_convolution=pairwise_convolution, norm_squared=norm_squared, ) + self.is_pushed = False self.params = params self.template_meta = template_meta self.sparsity = sparsity @@ -384,11 +424,24 @@ def __init__( # self.margin = int(buffer_ms*1e-3 * recording.sampling_frequency) self.margin = 300 # To ensure equivalence with spike-psvae version of the algorithm + def _push_to_torch(self): + if self.engine == "torch": + temporal, singular, spatial, temporal_jittered = self.template_data.compressed_templates + spatial = torch.as_tensor(spatial, device=self.torch_device) + singular = torch.as_tensor(singular, device=self.torch_device) + temporal = torch.as_tensor(temporal.copy(), device=self.torch_device).swapaxes(0, 1) + temporal = torch.flip(temporal, (2,)) + self.template_data.compressed_templates = (temporal, singular, spatial, temporal_jittered) + self.is_pushed = True + def get_trace_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): + if not self.is_pushed: + self._push_to_torch() + # Unpack method_kwargs # nbefore, nafter = method_kwargs["nbefore"], method_kwargs["nafter"] # template_meta = method_kwargs["template_meta"] @@ -400,7 +453,9 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): assert traces.dtype == np.float32, "traces must be specified as np.float32" # Compute objective - objective = compute_objective(traces, self.template_data, self.params.approx_rank) + objective = compute_objective( + traces, self.template_data, self.params.approx_rank, self.engine, self.torch_device + ) objective_normalized = 2 * objective - 
self.template_data.norm_squared[:, np.newaxis] # Compute spike train @@ -786,10 +841,11 @@ def compress_templates(templates, approx_rank) -> tuple[np.ndarray, np.ndarray, temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) # Keep only the strongest components - temporal = temporal[:, :, :approx_rank] + temporal = temporal[:, :, :approx_rank].astype(np.float32) temporal = np.flip(temporal, axis=1) - singular = singular[:, :approx_rank] - spatial = spatial[:, :approx_rank, :] + singular = singular[:, :approx_rank].astype(np.float32) + spatial = spatial[:, :approx_rank, :].astype(np.float32) + return temporal, singular, spatial @@ -827,7 +883,6 @@ def upsample_and_jitter(temporal, jitter_factor, num_samples): shape_temporal_jittered = (-1, num_samples, approx_rank) temporal_jittered = np.reshape(temporal_jittered[:, shifted_index, :], shape_temporal_jittered) - temporal_jittered = np.flip(temporal_jittered, axis=1) return temporal_jittered @@ -889,7 +944,7 @@ def convolve_templates(compressed_templates, jitter_factor, approx_rank, jittere return pairwise_convolution -def compute_objective(traces, template_data, approx_rank) -> np.ndarray: +def compute_objective(traces, template_data, approx_rank, engine="numpy", torch_device=None) -> np.ndarray: """Compute objective by convolving templates with voltage traces. Parameters @@ -898,31 +953,39 @@ def compute_objective(traces, template_data, approx_rank) -> np.ndarray: Voltage traces for a chunk of the recording. template_data : TemplateData Dataclass object for aggregating template data together. - approx_rank : int - Rank of the compressed template matrices. Returns ------- objective : ndarray (template_meta.num_templates, traces.shape[0]+template_meta.num_samples-1) Template matching objective for each template. 
""" - temporal, singular, spatial, temporal_jittered = template_data.compressed_templates - num_templates = temporal.shape[0] - num_samples = temporal.shape[1] - objective_len = get_convolution_len(traces.shape[0], num_samples) - conv_shape = (num_templates, objective_len) - objective = np.zeros(conv_shape, dtype=np.float32) - spatial_filters = np.moveaxis(spatial[:, :approx_rank, :], [0, 1, 2], [1, 0, 2]) - temporal_filters = np.moveaxis(temporal[:, :, :approx_rank], [0, 1, 2], [1, 2, 0]) - singular_filters = singular.T[:, :, np.newaxis] - - # Filter using overlap-and-add convolution - spatially_filtered_data = np.matmul(spatial_filters, traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * singular_filters - from scipy import signal + temporal, singular, spatial, _ = template_data.compressed_templates + if engine == "torch": + nt = temporal.shape[2] - 1 + num_channels = traces.shape[1] + blank = np.zeros((nt, num_channels), dtype=np.float32) + traces = np.vstack((blank, traces, blank)) + torch_traces = torch.as_tensor(traces.T[None, :, :], device=torch_device) + num_templates, num_channels = temporal.shape[0], temporal.shape[1] + num_timesteps = torch_traces.shape[2] + spatially_filtered_data = torch.matmul(spatial, torch_traces) + scaled_filtered_data = (spatially_filtered_data * singular).swapaxes(0, 1) + scaled_filtered_data_ = scaled_filtered_data.reshape(1, num_templates * num_channels, num_timesteps) + objective = conv1d(scaled_filtered_data_, temporal, groups=num_templates, padding="valid") + objective = objective.cpu().numpy()[0, :, :] + elif engine == "numpy": + num_channels, num_templates = temporal.shape[0], temporal.shape[1] + num_timesteps = temporal.shape[2] + objective_len = get_convolution_len(traces.shape[0], num_timesteps) + conv_shape = (num_templates, objective_len) + objective = np.zeros(conv_shape, dtype=np.float32) + # Filter using overlap-and-add convolution + spatially_filtered_data = np.matmul(spatial, traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * singular + from scipy import signal - objective_by_rank = signal.oaconvolve(scaled_filtered_data, temporal_filters, axes=2, mode="full") - objective += np.sum(objective_by_rank, axis=0) + objective_by_rank = signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="full") + objective += np.sum(objective_by_rank, axis=0) return objective diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 134481289e..5b1d33b334 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -603,13 +603,13 @@ class DetectPeakMatchedFiltering(PeakDetector): params_doc = ( DetectPeakByChannel.params_doc + """ - radius_um: float + radius_um : float The radius to use to select neighbour channels for locally exclusive detection. - prototype: array + prototype : array The canonical waveform of action potentials - rank : int (default 1) - The rank for SVD convolution of spatiotemporal templates with the traces - weight_method: dict + ms_before : float + The time in ms before the maximial value of the absolute prototype + weight_method : dict Parameter that should be provided to the get_convolution_weights() function in order to know how to estimate the positions. 
One argument is mode that could be either gaussian_2d (KS like) or exponential_3d (default) @@ -625,7 +625,6 @@ def __init__( detect_threshold=5, exclude_sweep_ms=0.1, radius_um=50, - rank=1, noise_levels=None, random_chunk_kwargs={"num_chunks_per_segment": 5}, weight_method={}, diff --git a/src/spikeinterface/sortingcomponents/tests/test_wobble.py b/src/spikeinterface/sortingcomponents/tests/test_wobble.py index d6d1e1e0b9..0d46b790ad 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_wobble.py +++ b/src/spikeinterface/sortingcomponents/tests/test_wobble.py @@ -44,7 +44,7 @@ def test_compress_templates(): elif test_case == "num_channels == num_samples": num_channels = rng.integers(1, 100) num_samples = num_channels - templates = rng.random((num_templates, num_samples, num_channels)) + templates = rng.random((num_templates, num_samples, num_channels), dtype=np.float32) full_rank = np.minimum(num_samples, num_channels) approx_rank = rng.integers(1, full_rank) @@ -66,15 +66,31 @@ def test_compress_templates(): assert np.all(singular_full >= 0) # check that svd matrices are orthonormal if applicable if num_channels > num_samples: - assert np.allclose(np.matmul(temporal_full, temporal_full.transpose(0, 2, 1)), np.eye(num_samples)) + assert np.allclose( + np.matmul(temporal_full, temporal_full.transpose(0, 2, 1)), + np.eye(num_samples, dtype=np.float32), + atol=1e-3, + ) elif num_samples > num_channels: - assert np.allclose(np.matmul(spatial_full, spatial_full.transpose(0, 2, 1)), np.eye(num_channels)) + assert np.allclose( + np.matmul(spatial_full, spatial_full.transpose(0, 2, 1)), + np.eye(num_channels, dtype=np.float32), + atol=1e-3, + ) elif num_channels == num_samples: - assert np.allclose(np.matmul(temporal_full, temporal_full.transpose(0, 2, 1)), np.eye(num_samples)) - assert np.allclose(np.matmul(spatial_full, spatial_full.transpose(0, 2, 1)), np.eye(num_channels)) + assert np.allclose( + np.matmul(temporal_full, temporal_full.transpose(0, 2, 1)), + np.eye(num_samples, dtype=np.float32), + atol=1e-3, + ) + assert np.allclose( + np.matmul(spatial_full, spatial_full.transpose(0, 2, 1)), + np.eye(num_channels, dtype=np.float32), + atol=1e-3, + ) # check that the full rank svd matrices reconstruct the original templates reconstructed_templates = np.matmul(temporal_full * singular_full[:, np.newaxis, :], spatial_full) - assert np.allclose(reconstructed_templates, templates) + assert np.allclose(reconstructed_templates, templates, atol=1e-3) def test_upsample_and_jitter(): @@ -211,18 +227,33 @@ def test_compute_objective(): approx_rank = rng.integers(1, num_samples) num_channels = rng.integers(1, 100) chunk_len = rng.integers(num_samples * 2, num_samples * 10) - traces = rng.random((chunk_len, num_channels)) + traces = rng.random((chunk_len, num_channels), dtype=np.float32) temporal = rng.random((num_templates, num_samples, approx_rank)) singular = rng.random((num_templates, approx_rank)) spatial = rng.random((num_templates, approx_rank, num_channels)) - compressed_templates = (temporal, singular, spatial, temporal) + + spatial_transformed = np.moveaxis(spatial, [0, 1, 2], [1, 0, 2]) + temporal_transformed = np.moveaxis(temporal, [0, 1, 2], [1, 2, 0]) + singular_transformed = singular.T[:, :, np.newaxis] + + compressed_templates_transformed = ( + temporal_transformed, + singular_transformed, + spatial_transformed, + temporal_transformed, + ) norm_squared = np.random.rand(num_templates) + + template_data_transformed = wobble.TemplateData( + 
compressed_templates=compressed_templates_transformed, pairwise_convolution=[], norm_squared=norm_squared + ) + # Act: run compute_objective + objective = wobble.compute_objective(traces, template_data_transformed, approx_rank, engine="numpy") + + compressed_templates = (temporal, singular, spatial, temporal) template_data = wobble.TemplateData( compressed_templates=compressed_templates, pairwise_convolution=[], norm_squared=norm_squared ) - - # Act: run compute_objective - objective = wobble.compute_objective(traces, template_data, approx_rank) expected_objective = compute_objective_loopy(traces, template_data, approx_rank) # Assert: check shape and equivalence to expected_objective
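    # --- editor's aside: illustrative sketch, not part of the test ---
    # The factorization used by compute_objective is that the per-template convolution of
    # the traces can be split across the template's SVD rank components: project the traces
    # onto each spatial component, scale by the singular value, run a cheap 1-D convolution
    # with the temporal component, and sum over ranks. A minimal self-contained check of the
    # equivalent correlation identity (hypothetical shapes and names):
    #
    #   import numpy as np
    #   from scipy.signal import correlate
    #   rng = np.random.default_rng(0)
    #   template = rng.random((20, 4)).astype("float32")   # (num_samples, num_channels)
    #   traces = rng.random((200, 4)).astype("float32")    # (chunk_len, num_channels)
    #   direct = sum(correlate(traces[:, c], template[:, c], mode="full") for c in range(4))
    #   u, s, vt = np.linalg.svd(template, full_matrices=False)
    #   factored = sum(s[r] * correlate(traces @ vt[r], u[:, r], mode="full") for r in range(len(s)))
    #   assert np.allclose(direct, factored, atol=1e-3)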