diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index c3b3099535..211adba990 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -24,9 +24,10 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
     sorter_name = "spykingcircus2"

     _default_params = {
-        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
+        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
         "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
         "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2},
+        "whitening": {"mode": "local", "regularize": True},
         "detection": {"peak_sign": "neg", "detect_threshold": 4},
         "selection": {
             "method": "uniform",
@@ -36,7 +37,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
             "seed": 42,
         },
         "apply_motion_correction": True,
-        "motion_correction": {"preset": "nonrigid_fast_and_accurate"},
+        "motion_correction": {"preset": "dredge_fast"},
         "merging": {
             "similarity_kwargs": {"method": "cosine", "support": "union", "max_lag_ms": 0.2},
             "correlograms_kwargs": {},
@@ -46,7 +47,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
             },
         },
         "clustering": {"legacy": True},
-        "matching": {"method": "wobble"},
+        "matching": {"method": "circus-omp-svd"},
         "apply_preprocessing": True,
         "matched_filtering": True,
         "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
@@ -62,6 +63,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         and also the radius_um used to be considered during clustering",
         "sparsity": "A dictionary to be passed to all the calls to sparsify the templates",
         "filtering": "A dictionary for the high_pass filter to be used during preprocessing",
+        "whitening": "A dictionary for the whitening option to be used during preprocessing",
         "detection": "A dictionary for the peak detection node (locally_exclusive)",
         "selection": "A dictionary for the peak selection node. Default is to use smart_sampling_amplitudes, with a minimum of 20000 peaks\
         and 5000 peaks per electrode on average.",
@@ -109,8 +111,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction
         from spikeinterface.sortingcomponents.tools import get_prototype_spike

-        job_kwargs = params["job_kwargs"]
-        job_kwargs = fix_job_kwargs(job_kwargs)
+        job_kwargs = fix_job_kwargs(params["job_kwargs"])
         job_kwargs.update({"progress_bar": verbose})

         recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
@@ -119,7 +120,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         num_channels = recording.get_num_channels()
         ms_before = params["general"].get("ms_before", 2)
         ms_after = params["general"].get("ms_after", 2)
-        radius_um = params["general"].get("radius_um", 100)
+        radius_um = params["general"].get("radius_um", 75)
         exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after) / 2)

         ## First, we are filtering the data
@@ -143,14 +144,19 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 print("Motion correction activated (probe geometry compatible)")
                 motion_folder = sorter_output_folder / "motion"
                 params["motion_correction"].update({"folder": motion_folder})
-                recording_f = correct_motion(recording_f, **params["motion_correction"])
+                recording_f = correct_motion(recording_f, **params["motion_correction"], **job_kwargs)
         else:
             motion_folder = None

         ## We need to whiten before the template matching step, to boost the results
         # TODO add , regularize=True chen ready
-        recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True)
+        whitening_kwargs = params["whitening"].copy()
+        whitening_kwargs["dtype"] = "float32"
+        whitening_kwargs["radius_um"] = radius_um
+        if num_channels == 1:
+            whitening_kwargs["regularize"] = False
+        recording_w = whiten(recording_f, **whitening_kwargs)

         noise_levels = get_noise_levels(recording_w, return_scaled=False)

         if recording_w.check_serializability("json"):
@@ -172,20 +178,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         nbefore = int(ms_before * fs / 1000.0)
         nafter = int(ms_after * fs / 1000.0)

-        peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params)
-
         if params["matched_filtering"]:
+            peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, skip_after_n_peaks=5000)
             prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs)
             detection_params["prototype"] = prototype
             detection_params["ms_before"] = ms_before
-
-            for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
-                if value in detection_params:
-                    detection_params.pop(value)
-
-            detection_params["chunk_duration"] = "100ms"
-            peaks = detect_peaks(recording_w, "matched_filtering", **detection_params)
+        else:
+            peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params)

         if verbose:
             print("We found %d peaks in total" % len(peaks))
@@ -196,7 +196,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         ## We subselect a subset of all the peaks, by making the distributions os SNRs over all
         ## channels as flat as possible
         selection_params = params["selection"]
-        selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels
+        selection_params["n_peaks"] = min(len(peaks), selection_params["n_peaks_per_channel"] * num_channels)
         selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"])

         selection_params.update({"noise_levels": noise_levels})
@@ -281,11 +281,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         matching_job_params = job_kwargs.copy()

         if matching_method is not None:
-            for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
-                if value in matching_job_params:
-                    matching_job_params[value] = None
-            matching_job_params["chunk_duration"] = "100ms"
-
             spikes = find_spikes_from_templates(
                 recording_w, matching_method, method_kwargs=matching_params, **matching_job_params
             )
diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index b08ee4d9cb..b7e71d3b45 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -136,6 +136,7 @@ def main_function(cls, recording, peaks, params):
         pipeline_nodes = [node0, node1, node2]

         if len(params["recursive_kwargs"]) == 0:
+            from sklearn.decomposition import PCA

             all_pc_data = run_node_pipeline(
                 recording,
@@ -152,9 +153,9 @@ def main_function(cls, recording, peaks, params):
                 sub_data = sub_data.reshape(len(sub_data), -1)

             if all_pc_data.shape[1] > params["n_svd"][1]:
-                tsvd = TruncatedSVD(params["n_svd"][1])
+                tsvd = PCA(params["n_svd"][1], whiten=True)
             else:
-                tsvd = TruncatedSVD(all_pc_data.shape[1])
+                tsvd = PCA(all_pc_data.shape[1], whiten=True)

             hdbscan_data = tsvd.fit_transform(sub_data)
             try:
@@ -184,7 +185,7 @@ def main_function(cls, recording, peaks, params):
            )

            sparse_mask = node1.neighbours_mask
-            neighbours_mask = get_channel_distances(recording) < radius_um
+            neighbours_mask = get_channel_distances(recording) <= radius_um

            # np.save(features_folder / "sparse_mask.npy", sparse_mask)
            np.save(features_folder / "peaks.npy", peaks)
@@ -192,6 +193,8 @@ def main_function(cls, recording, peaks, params):
            original_labels = peaks["channel_index"]
            from spikeinterface.sortingcomponents.clustering.split import split_clusters

+            min_size = params["hdbscan_kwargs"].get("min_cluster_size", 50)
+
            peak_labels, _ = split_clusters(
                original_labels,
                recording,
@@ -202,7 +205,7 @@ def main_function(cls, recording, peaks, params):
                feature_name="sparse_tsvd",
                neighbours_mask=neighbours_mask,
                waveforms_sparse_mask=sparse_mask,
-                min_size_split=50,
+                min_size_split=min_size,
                clusterer_kwargs=d["hdbscan_kwargs"],
                n_pca_features=params["n_svd"][1],
                scale_n_pca_by_depth=True,
@@ -233,7 +236,7 @@ def main_function(cls, recording, peaks, params):
        if d["rank"] is not None:
            from spikeinterface.sortingcomponents.matching.circus import compress_templates

-            _, _, _, templates_array = compress_templates(templates_array, 5)
+            _, _, _, templates_array = compress_templates(templates_array, d["rank"])

        templates = Templates(
            templates_array=templates_array,
@@ -258,13 +261,8 @@ def main_function(cls, recording, peaks, params):
            print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids)))

        cleaning_matching_params = params["job_kwargs"].copy()
-        for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
-            if value in cleaning_matching_params:
-                cleaning_matching_params.pop(value)
-        cleaning_matching_params["chunk_duration"] = "100ms"
        cleaning_matching_params["n_jobs"] = 1
        cleaning_matching_params["progress_bar"] = False
-
        cleaning_params = params["cleaning_kwargs"].copy()

        labels, peak_labels = remove_duplicates_via_matching(
diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py
index 5934bdfbbb..15917934a8 100644
--- a/src/spikeinterface/sortingcomponents/clustering/split.py
+++ b/src/spikeinterface/sortingcomponents/clustering/split.py
@@ -24,7 +24,7 @@ def split_clusters(
     peak_labels,
     recording,
     features_dict_or_folder,
-    method="hdbscan_on_local_pca",
+    method="local_feature_clustering",
     method_kwargs={},
     recursive=False,
     recursive_depth=None,
@@ -81,7 +81,6 @@ def split_clusters(
     ) as pool:
         labels_set = np.setdiff1d(peak_labels, [-1])
         current_max_label = np.max(labels_set) + 1
-
         jobs = []
         for label in labels_set:
             peak_indices = np.flatnonzero(peak_labels == label)
@@ -95,15 +94,14 @@ def split_clusters(

         for res in iterator:
             is_split, local_labels, peak_indices = res.result()
+            # print(is_split, local_labels, peak_indices)
             if not is_split:
                 continue

             mask = local_labels >= 0
             peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label
             peak_labels[peak_indices[~mask]] = local_labels[~mask]
-
             split_count[peak_indices] += 1
-
             current_max_label += np.max(local_labels[mask]) + 1

             if recursive:
@@ -120,6 +118,7 @@ def split_clusters(
                 for label in new_labels_set:
                     peak_indices = np.flatnonzero(peak_labels == label)
                     if peak_indices.size > 0:
+                        # print('Relaunched', label, len(peak_indices), recursion_level)
                         jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level))
                         if progress_bar:
                             iterator.total += 1
@@ -187,7 +186,7 @@ def split(
        min_size_split=25,
        n_pca_features=2,
        scale_n_pca_by_depth=False,
-        minimum_common_channels=2,
+        minimum_overlap_ratio=0.25,
    ):

        local_labels = np.zeros(peak_indices.size, dtype=np.int64)
@@ -199,19 +198,22 @@ def split(
        # target channel subset is done intersect local channels + neighbours
        local_chans = np.unique(peaks["channel_index"][peak_indices])
-        target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0))
+        target_intersection_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0))
+        target_union_channels = np.flatnonzero(np.any(neighbours_mask[local_chans, :], axis=0))
+        num_intersection = len(target_intersection_channels)
+        num_union = len(target_union_channels)

        # TODO fix this a better way, this when cluster have too few overlapping channels
-        if target_channels.size < minimum_common_channels:
+        if (num_intersection / num_union) < minimum_overlap_ratio:
            return False, None

        aligned_wfs, dont_have_channels = aggregate_sparse_features(
-            peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_channels
+            peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_intersection_channels
        )

        local_labels[dont_have_channels] = -2
        kept = np.flatnonzero(~dont_have_channels)
-
+        # print(recursion_level, kept.size, min_size_split)
        if kept.size < min_size_split:
            return False, None
@@ -222,6 +224,8 @@ def split(
        if flatten_features.shape[1] > n_pca_features:
            from sklearn.decomposition import PCA

+            # from sklearn.decomposition import TruncatedSVD
+
            if scale_n_pca_by_depth:
                # tsvd = TruncatedSVD(n_pca_features * recursion_level)
                tsvd = PCA(n_pca_features * recursion_level, whiten=True)
diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py
index a3624f4296..d1b2139c5b 100644
--- a/src/spikeinterface/sortingcomponents/matching/circus.py
+++ b/src/spikeinterface/sortingcomponents/matching/circus.py
@@ -9,6 +9,7 @@
 from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel
 from spikeinterface.core.template import Templates

+
 spike_dtype = [
     ("sample_index", "int64"),
     ("channel_index", "int64"),
@@ -140,12 +141,12 @@ def __init__(
        templates=None,
        amplitudes=[0.6, np.inf],
        stop_criteria="max_failures",
-        max_failures=10,
+        max_failures=5,
        omp_min_sps=0.1,
        relative_error=5e-5,
        rank=5,
        ignore_inds=[],
-        vicinity=3,
+        vicinity=2,
        precomputed=None,
    ):
@@ -181,16 +182,16 @@ def __init__(
            self.unit_overlaps_tables[i] = np.zeros(self.num_templates, dtype=int)
            self.unit_overlaps_tables[i][self.unit_overlaps_indices[i]] = np.arange(len(self.unit_overlaps_indices[i]))

-        if self.vicinity > 0:
-            self.margin = self.vicinity
-        else:
-            self.margin = 2 * self.num_samples
+        self.margin = 2 * self.num_samples

    def _prepare_templates(self):

        assert self.stop_criteria in ["max_failures", "omp_min_sps", "relative_error"]

-        sparsity = self.templates.sparsity.mask
+        if self.templates.sparsity is None:
+            sparsity = np.ones((self.num_templates, self.num_channels), dtype=bool)
+        else:
+            sparsity = self.templates.sparsity.mask

        units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2)
        self.units_overlaps = units_overlaps > 0
@@ -265,6 +266,7 @@ def get_trace_margin(self):

    def compute_matching(self, traces, start_frame, end_frame, segment_index):
        import scipy.spatial
        import scipy
+        from scipy import ndimage

        (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32)
@@ -316,8 +318,6 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):

        full_sps = scalar_products.copy()

-        neighbors = {}
-
        all_amplitudes = np.zeros(0, dtype=np.float32)
        is_in_vicinity = np.zeros(0, dtype=np.int32)
@@ -336,100 +336,113 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
        do_loop = True

        while do_loop:
-            best_amplitude_ind = scalar_products.argmax()
-            best_cluster_ind, peak_index = np.unravel_index(best_amplitude_ind, scalar_products.shape)
-            if num_selection > 0:
-                delta_t = selection[1] - peak_index
-                idx = np.where((delta_t < num_samples) & (delta_t > -num_samples))[0]
-                myline = neighbor_window + delta_t[idx]
-                myindices = selection[0, idx]
+            best_cluster_inds = np.argmax(scalar_products, axis=0, keepdims=True)
+            products = np.take_along_axis(scalar_products, best_cluster_inds, axis=0)

-                local_overlaps = overlaps_array[best_cluster_ind]
-                overlapping_templates = self.unit_overlaps_indices[best_cluster_ind]
-                table = self.unit_overlaps_tables[best_cluster_ind]
+            result = ndimage.maximum_filter(products[0], size=self.vicinity, mode="constant", cval=0)
+            cond_1 = products[0] / self.norms[best_cluster_inds[0]] > 0.25
+            cond_2 = np.abs(products[0] - result) < 1e-9
+            peak_indices = np.flatnonzero(cond_1 * cond_2)

-                if num_selection == M.shape[0]:
-                    Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32)
-                    Z[:num_selection, :num_selection] = M
-                    M = Z
+            if len(peak_indices) == 0:
+                break

-                mask = np.isin(myindices, overlapping_templates)
-                a, b = myindices[mask], myline[mask]
-                M[num_selection, idx[mask]] = local_overlaps[table[a], b]
+            for peak_index in peak_indices:
-                if self.vicinity == 0:
-                    scipy.linalg.solve_triangular(
-                        M[:num_selection, :num_selection],
-                        M[num_selection, :num_selection],
-                        trans=0,
-                        lower=1,
-                        overwrite_b=True,
-                        check_finite=False,
-                    )
-
-                    v = nrm2(M[num_selection, :num_selection]) ** 2
-                    Lkk = 1 - v
-                    if Lkk <= omp_tol:  # selected atoms are dependent
-                        break
-                    M[num_selection, num_selection] = np.sqrt(Lkk)
-                else:
-                    is_in_vicinity = np.where(np.abs(delta_t) < self.vicinity)[0]
+                best_cluster_ind = best_cluster_inds[0, peak_index]
+
+                if num_selection > 0:
+                    delta_t = selection[1] - peak_index
+                    idx = np.flatnonzero((delta_t < num_samples) & (delta_t > -num_samples))
+                    myline = neighbor_window + delta_t[idx]
+                    myindices = selection[0, idx]
+
+                    local_overlaps = overlaps_array[best_cluster_ind]
+                    overlapping_templates = self.unit_overlaps_indices[best_cluster_ind]
+                    table = self.unit_overlaps_tables[best_cluster_ind]

-                    if len(is_in_vicinity) > 0:
-                        L = M[is_in_vicinity, :][:, is_in_vicinity]
+                    if num_selection == M.shape[0]:
+                        Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32)
+                        Z[:num_selection, :num_selection] = M
+                        M = Z

-                        M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular(
-                            L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False
+                    mask = np.isin(myindices, overlapping_templates)
+                    a, b = myindices[mask], myline[mask]
+                    M[num_selection, idx[mask]] = local_overlaps[table[a], b]
+
+                    if self.vicinity == 0:
+                        scipy.linalg.solve_triangular(
+                            M[:num_selection, :num_selection],
+                            M[num_selection, :num_selection],
+                            trans=0,
+                            lower=1,
+                            overwrite_b=True,
+                            check_finite=False,
                        )

-                        v = nrm2(M[num_selection, is_in_vicinity]) ** 2
+                        v = nrm2(M[num_selection, :num_selection]) ** 2
                        Lkk = 1 - v
                        if Lkk <= omp_tol:  # selected atoms are dependent
                            break
                        M[num_selection, num_selection] = np.sqrt(Lkk)
                    else:
-                        M[num_selection, num_selection] = 1.0
-            else:
-                M[0, 0] = 1
-
-            all_selections[:, num_selection] = [best_cluster_ind, peak_index]
-            num_selection += 1
-
-            selection = all_selections[:, :num_selection]
-            res_sps = full_sps[selection[0], selection[1]]
-
-            if self.vicinity == 0:
-                all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False)
-                all_amplitudes /= self.norms[selection[0]]
-            else:
-                is_in_vicinity = np.append(is_in_vicinity, num_selection - 1)
-                all_amplitudes = np.append(all_amplitudes, np.float32(1))
-                L = M[is_in_vicinity, :][:, is_in_vicinity]
-                all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False)
-                all_amplitudes[is_in_vicinity] /= self.norms[selection[0][is_in_vicinity]]
-
-            diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]]
-            modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0]
-            final_amplitudes[selection[0], selection[1]] = all_amplitudes
-
-            for i in modified:
-                tmp_best, tmp_peak = selection[:, i]
-                diff_amp = diff_amplitudes[i] * self.norms[tmp_best]
-
-                local_overlaps = overlaps_array[tmp_best]
-                overlapping_templates = self.units_overlaps[tmp_best]
+                        is_in_vicinity = np.flatnonzero(np.abs(delta_t) < self.vicinity)
+
+                        if len(is_in_vicinity) > 0:
+                            L = M[is_in_vicinity, :][:, is_in_vicinity]
+
+                            M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular(
+                                L,
+                                M[num_selection, is_in_vicinity],
+                                trans=0,
+                                lower=1,
+                                overwrite_b=True,
+                                check_finite=False,
+                            )
+
+                            v = nrm2(M[num_selection, is_in_vicinity]) ** 2
+                            Lkk = 1 - v
+                            if Lkk <= omp_tol:  # selected atoms are dependent
+                                break
+                            M[num_selection, num_selection] = np.sqrt(Lkk)
+                        else:
+                            M[num_selection, num_selection] = 1.0
+                else:
+                    M[0, 0] = 1
-                if not tmp_peak in neighbors.keys():
-                    idx = [max(0, tmp_peak - neighbor_window), min(num_peaks, tmp_peak + num_samples)]
-                    tdx = [neighbor_window + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak - 1]
-                    neighbors[tmp_peak] = {"idx": idx, "tdx": tdx}
+                all_selections[:, num_selection] = [best_cluster_ind, peak_index]
+                num_selection += 1
-                idx = neighbors[tmp_peak]["idx"]
-                tdx = neighbors[tmp_peak]["tdx"]
+                selection = all_selections[:, :num_selection]
+                res_sps = full_sps[selection[0], selection[1]]
-                to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]]
-                scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add
+                if self.vicinity == 0:
+                    new_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False)
+                    sub_selection = selection
+                    new_amplitudes /= self.norms[sub_selection[0]]
+                else:
+                    is_in_vicinity = np.append(is_in_vicinity, num_selection - 1)
+                    all_amplitudes = np.append(all_amplitudes, np.float32(1))
+                    L = M[is_in_vicinity, :][:, is_in_vicinity]
+                    new_amplitudes, _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False)
+                    sub_selection = selection[:, is_in_vicinity]
+                    new_amplitudes /= self.norms[sub_selection[0]]
+
+                diff_amplitudes = new_amplitudes - final_amplitudes[sub_selection[0], sub_selection[1]]
+                modified = np.flatnonzero(np.abs(diff_amplitudes) > omp_tol)
+                final_amplitudes[sub_selection[0], sub_selection[1]] = new_amplitudes
+
+                for i in modified:
+                    tmp_best, tmp_peak = sub_selection[:, i]
+                    diff_amp = diff_amplitudes[i] * self.norms[tmp_best]
+                    local_overlaps = overlaps_array[tmp_best]
+                    overlapping_templates = self.units_overlaps[tmp_best]
+                    tmp = tmp_peak - neighbor_window
+                    idx = [max(0, tmp), min(num_peaks, tmp_peak + num_samples)]
+                    tdx = [idx[0] - tmp, idx[1] - tmp]
+                    to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]]
+                    scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add

            # We stop when updates do not modify the chosen spikes anymore
            if self.stop_criteria == "omp_min_sps":
@@ -462,12 +475,9 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
        spikes["cluster_index"][:num_spikes] = valid_indices[0]
        spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]]

-        print("yep0", spikes.size, num_spikes, spikes.shape, spikes.dtype)
        spikes = spikes[:num_spikes]
-        print("yep1", spikes.size, spikes.shape, spikes.dtype)
-        if spikes.size > 0:
-            order = np.argsort(spikes["sample_index"])
-            spikes = spikes[order]
+        order = np.argsort(spikes["sample_index"])
+        spikes = spikes[order]

        return spikes
diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index ad8897df91..d608c5d105 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -50,7 +50,14 @@


 def detect_peaks(
-    recording, method="locally_exclusive", pipeline_nodes=None, gather_mode="memory", folder=None, names=None, **kwargs
+    recording,
+    method="locally_exclusive",
+    pipeline_nodes=None,
+    gather_mode="memory",
+    folder=None,
+    names=None,
+    skip_after_n_peaks=None,
+    **kwargs,
 ):
     """Peak detection based on threshold crossing in term of k x MAD.
@@ -73,6 +80,9 @@
         If gather_mode is "npy", the folder where the files are created.
     names : list
         List of strings with file stems associated with returns.
+    skip_after_n_peaks : None | int
+        Skip the computation after n_peaks.
+        This is not an exact because internally this skip is done per worker in average.
     {method_doc}
     {job_doc}
@@ -124,6 +134,7 @@ def detect_peaks(
         squeeze_output=squeeze_output,
         folder=folder,
         names=names,
+        skip_after_n_peaks=skip_after_n_peaks,
     )
     return outs
diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py
index ddc8add995..08bcabf5e5 100644
--- a/src/spikeinterface/sortingcomponents/peak_localization.py
+++ b/src/spikeinterface/sortingcomponents/peak_localization.py
@@ -135,7 +135,7 @@ def __init__(self, recording, return_output=True, parents=None, radius_um=75.0):
         self.radius_um = radius_um
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < radius_um
+        self.neighbours_mask = self.channel_distance <= radius_um
         self._kwargs["radius_um"] = radius_um

     def get_dtype(self):
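
For context on how the changed defaults above reach users, here is a minimal usage sketch. It assumes spikeinterface's standard run_sorter entry point and a generated test recording; the output folder name and parameter values are illustrative, and any key omitted falls back to the _default_params shown in the patch.

    # Hedged sketch: override a few of the defaults changed in this patch.
    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.sorters import run_sorter

    recording, _ = generate_ground_truth_recording(durations=[10.0], num_channels=16, seed=42)

    sorting = run_sorter(
        "spykingcircus2",
        recording,
        folder="sc2_output",                                   # illustrative output folder
        general={"ms_before": 2, "ms_after": 2, "radius_um": 75},
        whitening={"mode": "local", "regularize": True},       # new "whitening" block
        matching={"method": "circus-omp-svd"},                 # new default matching engine
        motion_correction={"preset": "dredge_fast"},           # new default motion preset
    )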
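The whitening change in _run_from_folder forwards the new params["whitening"] block to the whiten preprocessor, forcing dtype and radius_um and switching regularization off for single-channel recordings. A standalone sketch of that call pattern (the recording is generated here only for illustration; the keyword names are the ones visible in the hunk):

    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.preprocessing import whiten

    recording_f, _ = generate_ground_truth_recording(durations=[5.0], num_channels=16, seed=0)

    whitening_kwargs = {"mode": "local", "regularize": True}   # mirrors the new default block
    whitening_kwargs["dtype"] = "float32"
    whitening_kwargs["radius_um"] = 75.0
    if recording_f.get_num_channels() == 1:
        whitening_kwargs["regularize"] = False                 # regularization needs more than one channel
    recording_w = whiten(recording_f, **whitening_kwargs)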
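The skip_after_n_peaks option added to detect_peaks is what lets the sorter run a cheap, capped first pass just to estimate a spike prototype before matched filtering. A minimal sketch of that two-step pattern, assuming the sortingcomponents API shown in the patch (threshold and job settings are illustrative):

    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.sortingcomponents.peak_detection import detect_peaks
    from spikeinterface.sortingcomponents.tools import get_prototype_spike

    recording, _ = generate_ground_truth_recording(durations=[10.0], num_channels=16, seed=42)
    job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=False)

    # Capped first pass: stop early once roughly 5000 peaks have been collected
    # (the cap is applied per worker on average, so the final count is approximate).
    peaks = detect_peaks(
        recording, "locally_exclusive",
        peak_sign="neg", detect_threshold=4,
        skip_after_n_peaks=5000, **job_kwargs,
    )

    # Use those peaks to build the waveform prototype for matched filtering.
    prototype = get_prototype_spike(recording, peaks, ms_before=2, ms_after=2, **job_kwargs)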
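The split() change replaces the absolute minimum_common_channels count with minimum_overlap_ratio: a group of peaks is only considered splittable when the channels shared by all of its local channels' neighbourhoods cover a large enough fraction of the channels touched by any of them. A small numpy-only illustration of that criterion, with a toy neighbours mask (values are made up):

    import numpy as np

    # Rows: the cluster's local channels; columns: all channels. True means "is a neighbour".
    neighbours_mask = np.array(
        [
            [True, True, True, False],
            [False, True, True, True],
        ]
    )
    target_intersection_channels = np.flatnonzero(np.all(neighbours_mask, axis=0))  # channels 1, 2
    target_union_channels = np.flatnonzero(np.any(neighbours_mask, axis=0))         # channels 0-3
    overlap_ratio = len(target_intersection_channels) / len(target_union_channels)  # 0.5 here
    can_split = overlap_ratio >= 0.25  # 0.25 is the minimum_overlap_ratio default in the patch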