From a0501e3c319799af33fe1c55415db4a51fb8e090 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 13 Nov 2023 10:49:38 +0100 Subject: [PATCH 01/30] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a16b642dd5..a120d4e97a 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -169,4 +169,4 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = sorting.save(folder=sorting_folder) - return sorting + return sorting \ No newline at end of file From fcdca11d488c6fd0d92420a6462da704de81c0b7 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 13 Nov 2023 11:27:43 +0100 Subject: [PATCH 02/30] Baseline implementation of circus1 --- .../sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/circus.py | 373 +++++++++--------- 2 files changed, 198 insertions(+), 177 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a120d4e97a..53a72e2696 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -110,7 +110,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params.update({"noise_levels": noise_levels}) labels, peak_labels = find_cluster_from_peaks( - recording_f, selected_peaks, method="random_projections", method_kwargs=clustering_params + recording_f, selected_peaks, method="circus", method_kwargs=clustering_params ) ## We get the labels for our peaks diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 39f46475dc..83a92f1970 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -19,7 +19,53 @@ from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms from spikeinterface.core.recording_tools import get_channel_distances, get_random_data_chunks +from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer +from spikeinterface.sortingcomponents.peak_selection import select_peaks +from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection +from sklearn.decomposition import TruncatedSVD +import pickle, json +from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ExtractDenseWaveforms, + ExtractSparseWaveforms, + PeakRetriever, +) + + +def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): + """ + Helper function to extractor waveforms at max channel from a peak list + + """ + n = rec.get_num_channels() + unit_ids = np.arange(n, dtype="int64") + sparsity_mask = np.eye(n, dtype="bool") + + spikes = np.zeros( + peaks.size, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + ) + spikes["sample_index"] = peaks["sample_index"] + spikes["unit_index"] = peaks["channel_index"] + spikes["segment_index"] = peaks["segment_index"] + + nbefore = int(ms_before * rec.sampling_frequency / 1000.0) + nafter = int(ms_after * rec.sampling_frequency / 1000.0) + + all_wfs = extract_waveforms_to_single_buffer( + rec, + spikes, + unit_ids, + nbefore, + nafter, + mode="shared_memory", + return_scaled=False, + sparsity_mask=sparsity_mask, + copy=True, + **job_kwargs, + ) + + return all_wfs class CircusClustering: """ @@ -27,8 +73,6 @@ class CircusClustering: """ _default_params = { - "peak_locations": None, - "peak_localization_kwargs": {"method": "center_of_mass"}, "hdbscan_kwargs": { "min_cluster_size": 50, "allow_single_cluster": True, @@ -36,15 +80,17 @@ class CircusClustering: "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, - "tmp_folder": None, + "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, - "n_pca": 10, - "max_spikes_per_unit": 200, - "ms_before": 1.5, - "ms_after": 2.5, - "cleaning_method": "dip", - "waveform_mode": "memmap", - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M"}, + "selection_method": "closest_to_centroid", + "n_svd": 10, + "ms_before": 1, + "ms_after": 1, + "random_seed": 42, + "noise_levels": None, + "shared_memory": True, + "tmp_folder": None, + "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, } @classmethod @@ -70,131 +116,109 @@ def _check_params(cls, recording, peaks, params): @classmethod def main_function(cls, recording, peaks, params): - assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed" - - params = cls._check_params(recording, peaks, params) - d = params + assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed" - if d["peak_locations"] is None: - from spikeinterface.sortingcomponents.peak_localization import localize_peaks + if "n_jobs" in params["job_kwargs"]: + if params["job_kwargs"]["n_jobs"] == -1: + params["job_kwargs"]["n_jobs"] = os.cpu_count() - peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **d["job_kwargs"]) - else: - peak_locations = d["peak_locations"] - - tmp_folder = d["tmp_folder"] - if tmp_folder is not None: - tmp_folder.mkdir(exist_ok=True) + if "core_dist_n_jobs" in params["hdbscan_kwargs"]: + if params["hdbscan_kwargs"]["core_dist_n_jobs"] == -1: + params["hdbscan_kwargs"]["core_dist_n_jobs"] = os.cpu_count() - location_keys = ["x", "y"] - locations = np.stack([peak_locations[k] for k in location_keys], axis=1) - - chan_locs = recording.get_channel_locations() + d = params + verbose = d["job_kwargs"]["verbose"] peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] - spikes = np.zeros(peaks.size, dtype=peak_dtype) - spikes["sample_index"] = peaks["sample_index"] - spikes["segment_index"] = peaks["segment_index"] - spikes["unit_index"] = peaks["channel_index"] - - num_chans = recording.get_num_channels() - sparsity_mask = np.zeros((peaks.size, num_chans), dtype="bool") - - unit_inds = range(num_chans) - chan_distances = get_channel_distances(recording) - - for main_chan in unit_inds: - (closest_chans,) = np.nonzero(chan_distances[main_chan, :] <= params["radius_um"]) - sparsity_mask[main_chan, closest_chans] = True - - if params["waveform_mode"] == "shared_memory": - wf_folder = None - else: - assert params["tmp_folder"] is not None, "tmp_folder must be supplied" - wf_folder = params["tmp_folder"] / "sparse_snippets" - wf_folder.mkdir() fs = recording.get_sampling_frequency() - nbefore = int(params["ms_before"] * fs / 1000.0) - nafter = int(params["ms_after"] * fs / 1000.0) + ms_before = params["ms_before"] + ms_after = params["ms_after"] + nbefore = int(ms_before * fs / 1000.0) + nafter = int(ms_after * fs / 1000.0) num_samples = nbefore + nafter + num_chans = recording.get_num_channels() - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_inds, - nbefore, - nafter, - mode=params["waveform_mode"], - return_scaled=False, - folder=wf_folder, - dtype=recording.get_dtype(), - sparsity_mask=sparsity_mask, - copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], - ) - - n_loc = len(location_keys) - import sklearn.decomposition, hdbscan + if d["noise_levels"] is None: + noise_levels = get_noise_levels(recording, return_scaled=False) + else: + noise_levels = d["noise_levels"] - noise_levels = get_noise_levels(recording, return_scaled=False) + np.random.seed(d["random_seed"]) - nb_clusters = 0 - peak_labels = np.zeros(len(spikes), dtype=np.int32) + if params["tmp_folder"] is None: + name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) + tmp_folder = get_global_tmp_folder() / name + else: + tmp_folder = Path(params["tmp_folder"]).absolute() - noise = get_random_data_chunks( - recording, - return_scaled=False, - num_chunks_per_segment=params["max_spikes_per_unit"], - chunk_size=nbefore + nafter, - concatenated=False, - seed=None, - ) - noise = np.stack(noise, axis=0) + tmp_folder.mkdir(parents=True, exist_ok=True) - for main_chan, waveforms in wfs_arrays.items(): - idx = np.where(spikes["unit_index"] == main_chan)[0] - (channels,) = np.nonzero(sparsity_mask[main_chan]) - sub_noise = noise[:, :, channels] + # SVD for time compression + few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) + few_wfs = extract_waveform_at_max_channel(recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"]) - if len(waveforms) > 0: - sub_waveforms = waveforms + wfs = few_wfs[:, :, 0] + tsvd = TruncatedSVD(params["n_svd"]) + tsvd.fit(wfs) - wfs = np.swapaxes(sub_waveforms, 1, 2).reshape(len(sub_waveforms), -1) - noise_wfs = np.swapaxes(sub_noise, 1, 2).reshape(len(sub_noise), -1) + model_folder = tmp_folder / "tsvd_model" - n_pca = min(d["n_pca"], len(wfs)) - pca = sklearn.decomposition.PCA(n_pca) + model_folder.mkdir(exist_ok=True) + with open(model_folder / "pca_model.pkl", "wb") as f: + pickle.dump(tsvd, f) - hdbscan_data = np.vstack((wfs, noise_wfs)) + model_params = { + "ms_before": ms_before, + "ms_after": ms_after, + "sampling_frequency": float(fs), + } - pca.fit(wfs) - hdbscan_data_pca = pca.transform(hdbscan_data) - clustering = hdbscan.hdbscan(hdbscan_data_pca, **d["hdbscan_kwargs"]) + with open(model_folder / "params.json", "w") as f: + json.dump(model_params, f) - noise_labels = clustering[0][len(wfs) :] - valid_labels = clustering[0][: len(wfs)] + # features + features_folder = model_folder / "features" + node0 = PeakRetriever(recording, peaks) - shared_indices = np.intersect1d(np.unique(noise_labels), np.unique(valid_labels)) - for l in shared_indices: - idx_noise = noise_labels == l - idx_valid = valid_labels == l - if np.sum(idx_noise) > np.sum(idx_valid): - valid_labels[idx_valid] = -1 + radius_um = params["radius_um"] + node3 = ExtractSparseWaveforms( + recording, + parents=[node0], + return_output=False, + ms_before=ms_before, + ms_after=ms_after, + radius_um=radius_um, + ) - if np.unique(valid_labels).min() == -1: - valid_labels += 1 + node4 = TemporalPCAProjection( + recording, parents=[node0, node3], return_output=True, model_folder_path=model_folder + ) - for l in np.unique(valid_labels): - idx_valid = valid_labels == l - if np.sum(idx_valid) < d["hdbscan_kwargs"]["min_cluster_size"]: - valid_labels[idx_valid] = -1 + # pipeline_nodes = [node0, node1, node2, node3, node4] + pipeline_nodes = [node0, node3, node4] - peak_labels[idx] = valid_labels + nb_clusters + all_pc_data = run_node_pipeline( + recording, + pipeline_nodes, + params["job_kwargs"], + job_name="extracting PCs", + ) - labels = np.unique(valid_labels) - labels = labels[labels >= 0] - nb_clusters += len(labels) + peak_labels = -1 * np.ones(len(peaks), dtype=int) + nb_clusters = 0 + for c in np.unique(peaks['channel_index']): + mask = peaks['channel_index'] == c + tsvd = TruncatedSVD(params["n_svd"]) + sub_data = all_pc_data[mask] + hdbscan_data = tsvd.fit_transform(sub_data.reshape(len(sub_data), -1)) + clustering = hdbscan.hdbscan(hdbscan_data, **d['hdbscan_kwargs']) + local_labels = clustering[0] + valid_clusters = local_labels > -1 + if np.sum(valid_clusters) > 0: + local_labels[valid_clusters] += nb_clusters + peak_labels[mask] = local_labels + nb_clusters += len(np.unique(local_labels[valid_clusters])) labels = np.unique(peak_labels) labels = labels[labels >= 0] @@ -202,11 +226,22 @@ def main_function(cls, recording, peaks, params): best_spikes = {} nb_spikes = 0 + import sklearn + all_indices = np.arange(0, peak_labels.size) + max_spikes = params["waveforms"]["max_spikes_per_unit"] + selection_method = params["selection_method"] + for unit_ind in labels: mask = peak_labels == unit_ind - best_spikes[unit_ind] = np.random.permutation(all_indices[mask])[: params["max_spikes_per_unit"]] + if selection_method == "closest_to_centroid": + data = all_pc_data[mask].reshape(np.sum(mask), -1) + centroid = np.median(data, axis=0) + distances = sklearn.metrics.pairwise_distances(centroid[np.newaxis, :], data)[0] + best_spikes[unit_ind] = all_indices[mask][np.argsort(distances)[:max_spikes]] + elif selection_method == "random": + best_spikes[unit_ind] = np.random.permutation(all_indices[mask])[:max_spikes] nb_spikes += best_spikes[unit_ind].size spikes = np.zeros(nb_spikes, dtype=peak_dtype) @@ -222,72 +257,58 @@ def main_function(cls, recording, peaks, params): spikes["segment_index"] = peaks[mask]["segment_index"] spikes["unit_index"] = peak_labels[mask] - if params["waveform_mode"] == "shared_memory": - wf_folder = None + if verbose: + print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) + + sorting_folder = tmp_folder / "sorting" + unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) + sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) + + if params["shared_memory"]: + waveform_folder = None + mode = "memory" else: - assert params["tmp_folder"] is not None, "tmp_folder must be supplied" - wf_folder = params["tmp_folder"] / "dense_snippets" - wf_folder.mkdir() - - cleaning_method = params["cleaning_method"] - - print(f"We found {len(labels)} raw clusters, starting to clean with {cleaning_method}...") - - if cleaning_method == "cosine": - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - labels, - nbefore, - nafter, - mode=params["waveform_mode"], - return_scaled=False, - folder=wf_folder, - dtype=recording.get_dtype(), - sparsity_mask=None, - copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], - ) - - labels, peak_labels = remove_duplicates( - wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"] - ) - - elif cleaning_method == "dip": - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - labels, - nbefore, - nafter, - mode=params["waveform_mode"], - return_scaled=False, - folder=wf_folder, - dtype=recording.get_dtype(), - sparsity_mask=None, - copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], - ) - - labels, peak_labels = remove_duplicates_via_dip(wfs_arrays, peak_labels) - - elif cleaning_method == "matching": - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = Path(os.path.join(get_global_tmp_folder(), name)) - - sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs) - we = extract_waveforms( - recording, - sorting, - tmp_folder, - overwrite=True, - ms_before=params["ms_before"], - ms_after=params["ms_after"], - **params["job_kwargs"], - ) - labels, peak_labels = remove_duplicates_via_matching(we, peak_labels, job_kwargs=params["job_kwargs"]) + waveform_folder = tmp_folder / "waveforms" + mode = "folder" + sorting = sorting.save(folder=sorting_folder) + + we = extract_waveforms( + recording, + sorting, + waveform_folder, + **params["job_kwargs"], + **params["waveforms"], + return_scaled=False, + mode=mode, + ) + + cleaning_matching_params = params["job_kwargs"].copy() + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: + if value in cleaning_matching_params: + cleaning_matching_params.pop(value) + cleaning_matching_params["chunk_duration"] = "100ms" + cleaning_matching_params["n_jobs"] = 1 + cleaning_matching_params["verbose"] = False + cleaning_matching_params["progress_bar"] = False + + cleaning_params = params["cleaning_kwargs"].copy() + cleaning_params["tmp_folder"] = tmp_folder + + labels, peak_labels = remove_duplicates_via_matching( + we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + ) + + del we, sorting + + if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) + else: + if not params["shared_memory"]: + shutil.rmtree(tmp_folder / "waveforms") + shutil.rmtree(tmp_folder / "sorting") - print(f"We kept {len(labels)} non-duplicated clusters...") + if verbose: + print("We kept %d non-duplicated clusters..." % len(labels)) return labels, peak_labels + From d5dd8447cdd535d4664cc2243c5ebd608602db65 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 13 Nov 2023 11:34:22 +0100 Subject: [PATCH 03/30] WIP --- src/spikeinterface/sortingcomponents/clustering/circus.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 83a92f1970..4a2f36e1e1 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -74,7 +74,7 @@ class CircusClustering: _default_params = { "hdbscan_kwargs": { - "min_cluster_size": 50, + "min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1, "cluster_selection_method": "leaf", @@ -83,7 +83,7 @@ class CircusClustering: "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, "selection_method": "closest_to_centroid", - "n_svd": 10, + "n_svd": 5, "ms_before": 1, "ms_after": 1, "random_seed": 42, From a00c4374abc257cf71b8145db18df06c2883e8d0 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 14:43:12 +0100 Subject: [PATCH 04/30] Circus 1 like --- .../comparison/groundtruthstudy.py | 2 +- .../sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/circus.py | 29 +++---------------- .../clustering/clustering_tools.py | 2 +- 4 files changed, 7 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 0d08922543..adc2898071 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -21,7 +21,7 @@ # This is to separate names when the key are tuples when saving folders -_key_separator = " ## " +_key_separator = "--" class GroundTruthStudy: diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 53a72e2696..76cc9684fa 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -21,7 +21,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, + "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 0.5}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 4a2f36e1e1..c193b1f93d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -77,13 +77,13 @@ class CircusClustering: "min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1, - "cluster_selection_method": "leaf", + "cluster_selection_method": "eom", }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, "selection_method": "closest_to_centroid", - "n_svd": 5, + "n_svd": [6, 6], "ms_before": 1, "ms_after": 1, "random_seed": 42, @@ -93,27 +93,6 @@ class CircusClustering: "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, } - @classmethod - def _check_params(cls, recording, peaks, params): - d = params - params2 = params.copy() - - tmp_folder = params["tmp_folder"] - if params["waveform_mode"] == "memmap": - if tmp_folder is None: - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = Path(os.path.join(get_global_tmp_folder(), name)) - else: - tmp_folder = Path(tmp_folder) - tmp_folder.mkdir() - params2["tmp_folder"] = tmp_folder - elif params["waveform_mode"] == "shared_memory": - assert tmp_folder is None, "tmp_folder must be None for shared_memory" - else: - raise ValueError("'waveform_mode' must be 'memmap' or 'shared_memory'") - - return params2 - @classmethod def main_function(cls, recording, peaks, params): assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed" @@ -159,7 +138,7 @@ def main_function(cls, recording, peaks, params): few_wfs = extract_waveform_at_max_channel(recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"]) wfs = few_wfs[:, :, 0] - tsvd = TruncatedSVD(params["n_svd"]) + tsvd = TruncatedSVD(params["n_svd"][0]) tsvd.fit(wfs) model_folder = tmp_folder / "tsvd_model" @@ -209,7 +188,7 @@ def main_function(cls, recording, peaks, params): nb_clusters = 0 for c in np.unique(peaks['channel_index']): mask = peaks['channel_index'] == c - tsvd = TruncatedSVD(params["n_svd"]) + tsvd = TruncatedSVD(params["n_svd"][1]) sub_data = all_pc_data[mask] hdbscan_data = tsvd.fit_transform(sub_data.reshape(len(sub_data), -1)) clustering = hdbscan.hdbscan(hdbscan_data, **d['hdbscan_kwargs']) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index b4938717f8..66fe660918 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -599,7 +599,7 @@ def remove_duplicates_via_matching( { "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, - "amplitudes": [0.95, 1.05], + "amplitudes": [0.975, 1.025], "omp_min_sps": 0.05, } ) From a7e297018d76317cf5b4c1a55ba9a64c35178d0a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 20:50:32 +0100 Subject: [PATCH 05/30] WIP for circus 1 --- .../sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/circus.py | 11 +++++------ .../sortingcomponents/clustering/clustering_tools.py | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 76cc9684fa..e746883259 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -23,7 +23,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 0.5}, "filtering": {"freq_min": 150, "dtype": "float32"}, - "detection": {"peak_sign": "neg", "detect_threshold": 5}, + "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, "localization": {}, "clustering": {}, diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index c193b1f93d..ef733224bd 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -161,7 +161,7 @@ def main_function(cls, recording, peaks, params): node0 = PeakRetriever(recording, peaks) radius_um = params["radius_um"] - node3 = ExtractSparseWaveforms( + node1 = ExtractSparseWaveforms( recording, parents=[node0], return_output=False, @@ -170,18 +170,17 @@ def main_function(cls, recording, peaks, params): radius_um=radius_um, ) - node4 = TemporalPCAProjection( - recording, parents=[node0, node3], return_output=True, model_folder_path=model_folder + node2 = TemporalPCAProjection( + recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder ) - # pipeline_nodes = [node0, node1, node2, node3, node4] - pipeline_nodes = [node0, node3, node4] + pipeline_nodes = [node0, node1, node2] all_pc_data = run_node_pipeline( recording, pipeline_nodes, params["job_kwargs"], - job_name="extracting PCs", + job_name="extracting features", ) peak_labels = -1 * np.ones(len(peaks), dtype=int) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 66fe660918..6b6aba892e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -600,7 +600,7 @@ def remove_duplicates_via_matching( "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.975, 1.025], - "omp_min_sps": 0.05, + "omp_min_sps": 0.1, } ) From 292c96d53d920f007283450dd6f781a619b054e5 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 20:53:39 +0100 Subject: [PATCH 06/30] Legacy mode --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index e746883259..fad744d143 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -21,7 +21,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 0.5}, + "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6b6aba892e..aaddc15b46 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -600,7 +600,7 @@ def remove_duplicates_via_matching( "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.975, 1.025], - "omp_min_sps": 0.1, + "omp_min_sps": 0.025, } ) From 91edc001961480b51e32b2c8b8d9f68590eedc82 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 20:53:39 +0100 Subject: [PATCH 07/30] Legacy mode --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index e746883259..fad744d143 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -21,7 +21,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 0.5}, + "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6b6aba892e..66fe660918 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -600,7 +600,7 @@ def remove_duplicates_via_matching( "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.975, 1.025], - "omp_min_sps": 0.1, + "omp_min_sps": 0.05, } ) From 5ffa8911ccfe72d1d532462fde2aa2f7a571c855 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 21:48:24 +0100 Subject: [PATCH 08/30] Adding a legacy mode for the clustering, similar as circus 1 --- .../sorters/internal/spyking_circus2.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index fad744d143..955f228ad5 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -21,12 +21,11 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, + "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "energy", "threshold": 0.25}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "localization": {}, - "clustering": {}, + "clustering": {"legacy" : False}, "matching": {}, "apply_preprocessing": True, "shared_memory": True, @@ -109,8 +108,18 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params["tmp_folder"] = sorter_output_folder / "clustering" clustering_params.update({"noise_levels": noise_levels}) + if "legacy" in clustering_params: + legacy = clustering_params["legacy"] + else: + legacy = False + + if legacy: + clustering_method = "circus" + else: + clustering_method = "random_projections" + labels, peak_labels = find_cluster_from_peaks( - recording_f, selected_peaks, method="circus", method_kwargs=clustering_params + recording_f, selected_peaks, method=clustering_method, method_kwargs=clustering_params ) ## We get the labels for our peaks From bd3e3b04f73828323e9ceb9fd6fefca7f6a103f9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Nov 2023 21:09:38 +0000 Subject: [PATCH 09/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 12 +++++++++--- .../sortingcomponents/clustering/circus.py | 12 +++++++----- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 955f228ad5..29de8a4b0d 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -21,11 +21,17 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "energy", "threshold": 0.25}, + "waveforms": { + "max_spikes_per_unit": 200, + "overwrite": True, + "sparse": True, + "method": "energy", + "threshold": 0.25, + }, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "clustering": {"legacy" : False}, + "clustering": {"legacy": False}, "matching": {}, "apply_preprocessing": True, "shared_memory": True, @@ -178,4 +184,4 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = sorting.save(folder=sorting_folder) - return sorting \ No newline at end of file + return sorting diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index ef733224bd..d7d94f73cc 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -67,6 +67,7 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j return all_wfs + class CircusClustering: """ hdbscan clustering on peak_locations previously done by localize_peaks() @@ -135,7 +136,9 @@ def main_function(cls, recording, peaks, params): # SVD for time compression few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) - few_wfs = extract_waveform_at_max_channel(recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"]) + few_wfs = extract_waveform_at_max_channel( + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"] + ) wfs = few_wfs[:, :, 0] tsvd = TruncatedSVD(params["n_svd"][0]) @@ -185,12 +188,12 @@ def main_function(cls, recording, peaks, params): peak_labels = -1 * np.ones(len(peaks), dtype=int) nb_clusters = 0 - for c in np.unique(peaks['channel_index']): - mask = peaks['channel_index'] == c + for c in np.unique(peaks["channel_index"]): + mask = peaks["channel_index"] == c tsvd = TruncatedSVD(params["n_svd"][1]) sub_data = all_pc_data[mask] hdbscan_data = tsvd.fit_transform(sub_data.reshape(len(sub_data), -1)) - clustering = hdbscan.hdbscan(hdbscan_data, **d['hdbscan_kwargs']) + clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) local_labels = clustering[0] valid_clusters = local_labels > -1 if np.sum(valid_clusters) > 0: @@ -289,4 +292,3 @@ def main_function(cls, recording, peaks, params): print("We kept %d non-duplicated clusters..." % len(labels)) return labels, peak_labels - From d970f8a27bfb00ab1180d2ca06d885496f686f52 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 22:14:21 +0100 Subject: [PATCH 10/30] Minor edits --- .../sorters/internal/spyking_circus2.py | 11 +++++++++-- .../sortingcomponents/clustering/circus.py | 5 +++-- .../clustering/random_projections.py | 5 +++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 955f228ad5..c94c49a4bf 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -109,7 +109,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params.update({"noise_levels": noise_levels}) if "legacy" in clustering_params: - legacy = clustering_params["legacy"] + legacy = clustering_params.pop("legacy") else: legacy = False @@ -147,7 +147,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): waveforms_folder = sorter_output_folder / "waveforms" we = extract_waveforms( - recording_f, sorting, waveforms_folder, mode=mode, **waveforms_params, return_scaled=False + recording_f, + sorting, + waveforms_folder, + return_scaled=False, + precompute_template=["median"], + mode=mode, + **waveforms_params + ) ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index ef733224bd..8a511b5fdc 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -254,10 +254,11 @@ def main_function(cls, recording, peaks, params): recording, sorting, waveform_folder, - **params["job_kwargs"], - **params["waveforms"], return_scaled=False, + precompute_template=["median"], mode=mode, + **params["job_kwargs"], + **params["waveforms"] ) cleaning_matching_params = params["job_kwargs"].copy() diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 72acd49f4f..3053bfbdd0 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -219,10 +219,11 @@ def sigmoid(x, L, x0, k, b): recording, sorting, waveform_folder, - **params["job_kwargs"], - **params["waveforms"], return_scaled=False, mode=mode, + precompute_template=["median"], + **params["job_kwargs"], + **params["waveforms"], ) cleaning_matching_params = params["job_kwargs"].copy() From 954cb450d7a84567e34de65a5d3eba0d8f02b5b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Nov 2023 21:14:52 +0000 Subject: [PATCH 11/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 11 +++++------ .../sortingcomponents/clustering/circus.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index f12d5c4fb9..28b9652a3a 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -153,14 +153,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): waveforms_folder = sorter_output_folder / "waveforms" we = extract_waveforms( - recording_f, - sorting, - waveforms_folder, + recording_f, + sorting, + waveforms_folder, return_scaled=False, precompute_template=["median"], - mode=mode, - **waveforms_params - + mode=mode, + **waveforms_params, ) ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index d20e33d244..24f4b29718 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -261,7 +261,7 @@ def main_function(cls, recording, peaks, params): precompute_template=["median"], mode=mode, **params["job_kwargs"], - **params["waveforms"] + **params["waveforms"], ) cleaning_matching_params = params["job_kwargs"].copy() From 0343ac2ccc3c26b0501d001505792067c765e2e7 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 22:44:55 +0100 Subject: [PATCH 12/30] Patch for hdbscan --- src/spikeinterface/sortingcomponents/clustering/circus.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index d20e33d244..44cddc4f70 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -193,8 +193,11 @@ def main_function(cls, recording, peaks, params): tsvd = TruncatedSVD(params["n_svd"][1]) sub_data = all_pc_data[mask] hdbscan_data = tsvd.fit_transform(sub_data.reshape(len(sub_data), -1)) - clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) - local_labels = clustering[0] + try: + clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) + local_labels = clustering[0] + except Exception: + local_labels = -1 * np.ones(len(hdbscan_data)) valid_clusters = local_labels > -1 if np.sum(valid_clusters) > 0: local_labels[valid_clusters] += nb_clusters From d31d402f880c6dc2224332acd622319f6b98a620 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 15 Nov 2023 14:17:12 +0100 Subject: [PATCH 13/30] Still a gap --- src/spikeinterface/sorters/internal/spyking_circus2.py | 1 + src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- src/spikeinterface/sortingcomponents/matching/circus.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 28b9652a3a..b344606c52 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -74,6 +74,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_f = zscore(recording_f, dtype="float32") noise_levels = np.ones(num_channels, dtype=np.float32) + ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() detection_params.update(job_kwargs) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index dd36135b8d..59983cbe03 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -84,7 +84,7 @@ class CircusClustering: "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, "selection_method": "closest_to_centroid", - "n_svd": [6, 6], + "n_svd": [5, 10], "ms_before": 1, "ms_after": 1, "random_seed": 42, diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index ea36b75847..6278067987 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -508,8 +508,8 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): """ _default_params = { - "amplitudes": [0.6, 2], - "omp_min_sps": 0.1, + "amplitudes": [0.75, 1.25], + "omp_min_sps": 0.05, "waveform_extractor": None, "random_chunk_kwargs": {}, "noise_levels": None, From b4f760e0f89c5a1db80efff8edec36ec9c47c111 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:17:34 +0000 Subject: [PATCH 14/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index b344606c52..28b9652a3a 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -74,7 +74,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_f = zscore(recording_f, dtype="float32") noise_levels = np.ones(num_channels, dtype=np.float32) - ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() detection_params.update(job_kwargs) From baf0280da90e05a51852d132fbe815db7a769b63 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 15 Nov 2023 15:14:02 +0100 Subject: [PATCH 15/30] Cleaning --- .../clustering/random_projections.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 3053bfbdd0..7d5b58551b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -113,32 +113,12 @@ def main_function(cls, recording, peaks, params): nafter = int(params["ms_after"] * fs / 1000) nsamples = nbefore + nafter - import scipy - - x = np.random.randn(100, nsamples, num_chans).astype(np.float32) - x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1) - - ptps = np.ptp(x, axis=1) - a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000)) - ydata = np.cumsum(a) / a.sum() - xdata = b[1:] - - from scipy.optimize import curve_fit - - def sigmoid(x, L, x0, k, b): - y = L / (1 + np.exp(-k * (x - x0))) + b - return y - - p0 = [max(ydata), np.median(xdata), 1, min(ydata)] # this is an mandatory initial guess - popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) - node3 = RandomProjectionsFeature( recording, parents=[node0, node2], return_output=True, projections=projections, radius_um=params["radius_um"], - sigmoid=None, sparse=True, ) From 50853f0a2d90dfa116d5ab547272edaaa7099198 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 16 Nov 2023 10:50:24 +0100 Subject: [PATCH 16/30] WIP on the peeler --- .../sorters/internal/spyking_circus2.py | 1 - .../sortingcomponents/clustering/circus.py | 9 +---- .../clustering/clustering_tools.py | 6 +--- .../clustering/random_projections.py | 8 +---- .../sortingcomponents/matching/circus.py | 34 ++++++++----------- 5 files changed, 18 insertions(+), 40 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 37478b1aa4..c690b4228d 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -167,7 +167,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces matching_params = params["matching"].copy() matching_params["waveform_extractor"] = we - matching_params.update({"noise_levels": noise_levels}) matching_job_params = job_kwargs.copy() for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 59983cbe03..0905e61169 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -88,7 +88,6 @@ class CircusClustering: "ms_before": 1, "ms_after": 1, "random_seed": 42, - "noise_levels": None, "shared_memory": True, "tmp_folder": None, "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, @@ -118,12 +117,6 @@ def main_function(cls, recording, peaks, params): nafter = int(ms_after * fs / 1000.0) num_samples = nbefore + nafter num_chans = recording.get_num_channels() - - if d["noise_levels"] is None: - noise_levels = get_noise_levels(recording, return_scaled=False) - else: - noise_levels = d["noise_levels"] - np.random.seed(d["random_seed"]) if params["tmp_folder"] is None: @@ -280,7 +273,7 @@ def main_function(cls, recording, peaks, params): cleaning_params["tmp_folder"] = tmp_folder labels, peak_labels = remove_duplicates_via_matching( - we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + we, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params ) del we, sorting diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 66fe660918..72da52d7a0 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -534,7 +534,6 @@ def remove_duplicates( def remove_duplicates_via_matching( waveform_extractor, - noise_levels, peak_labels, method_kwargs={}, job_kwargs={}, @@ -542,7 +541,6 @@ def remove_duplicates_via_matching( method="circus-omp-svd", ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates - from spikeinterface import get_noise_levels from spikeinterface.core import BinaryRecordingExtractor from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms @@ -598,9 +596,7 @@ def remove_duplicates_via_matching( local_params.update( { "waveform_extractor": waveform_extractor, - "noise_levels": noise_levels, - "amplitudes": [0.975, 1.025], - "omp_min_sps": 0.05, + "amplitudes": [0.975, 1.025] } ) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 7d5b58551b..fee35709d7 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -77,12 +77,6 @@ def main_function(cls, recording, peaks, params): nafter = int(params["ms_after"] * fs / 1000.0) num_samples = nbefore + nafter num_chans = recording.get_num_channels() - - if d["noise_levels"] is None: - noise_levels = get_noise_levels(recording, return_scaled=False) - else: - noise_levels = d["noise_levels"] - np.random.seed(d["random_seed"]) if params["tmp_folder"] is None: @@ -219,7 +213,7 @@ def main_function(cls, recording, peaks, params): cleaning_params["tmp_folder"] = tmp_folder labels, peak_labels = remove_duplicates_via_matching( - we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + we, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params ) del we, sorting diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 6278067987..8bc4b34806 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -496,9 +496,6 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): (Minimal, Maximal) amplitudes allowed for every template omp_min_sps: float Stopping criteria of the OMP algorithm, in percentage of the norm - noise_levels: array - The noise levels, for every channels. If None, they will be automatically - computed random_chunk_kwargs: dict Parameters for computing noise levels, if not provided (sub optimal) sparse_kwargs: dict @@ -509,10 +506,9 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): _default_params = { "amplitudes": [0.75, 1.25], - "omp_min_sps": 0.05, + "omp_min_sps": 1e-4, "waveform_extractor": None, "random_chunk_kwargs": {}, - "noise_levels": None, "rank": 5, "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], @@ -612,10 +608,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() d["vicinity"] *= d["num_samples"] - if d["noise_levels"] is None: - print("CircusOMPPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - if "templates" not in d: d = cls._prepare_templates(d) else: @@ -638,10 +630,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int) d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i])) - omp_min_sps = d["omp_min_sps"] - # d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) - d["stop_criteria"] = omp_min_sps * np.maximum(d["norms"], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) - + d["stop_criteria"] = d["omp_min_sps"] return d @classmethod @@ -675,7 +664,7 @@ def main_function(cls, traces, d): neighbor_window = num_samples - 1 min_amplitude, max_amplitude = d["amplitudes"] ignored_ids = d["ignored_ids"] - stop_criteria = d["stop_criteria"][:, np.newaxis] + stop_criteria = d["stop_criteria"] vicinity = d["vicinity"] rank = d["rank"] @@ -717,13 +706,15 @@ def main_function(cls, traces, d): neighbors = {} cached_overlaps = {} - is_valid = scalar_products > stop_criteria all_amplitudes = np.zeros(0, dtype=np.float32) is_in_vicinity = np.zeros(0, dtype=np.int32) + new_error = np.linalg.norm(scalar_products) + delta_error = np.inf - while np.any(is_valid): - best_amplitude_ind = scalar_products[is_valid].argmax() - best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) + while delta_error > stop_criteria: + + best_amplitude_ind = scalar_products.argmax() + best_cluster_ind, peak_index = np.unravel_index(best_amplitude_ind, scalar_products.shape) if num_selection > 0: delta_t = selection[1] - peak_index @@ -818,7 +809,12 @@ def main_function(cls, traces, d): to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add - is_valid = scalar_products > stop_criteria + previous_error = new_error + new_error = np.linalg.norm(scalar_products) + if previous_error != 0: + delta_error = np.abs(new_error / previous_error - 1) + else: + delta_error = 0 is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) valid_indices = np.where(is_valid) From df226b651db4685162115fc5f5c2a2cbb3f8e57f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Nov 2023 09:51:53 +0000 Subject: [PATCH 17/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/clustering_tools.py | 7 +------ src/spikeinterface/sortingcomponents/matching/circus.py | 1 - 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 72da52d7a0..052a596c63 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -593,12 +593,7 @@ def remove_duplicates_via_matching( local_params = method_kwargs.copy() - local_params.update( - { - "waveform_extractor": waveform_extractor, - "amplitudes": [0.975, 1.025] - } - ) + local_params.update({"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025]}) spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) indices = np.argsort(counts) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 8bc4b34806..21de446162 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -712,7 +712,6 @@ def main_function(cls, traces, d): delta_error = np.inf while delta_error > stop_criteria: - best_amplitude_ind = scalar_products.argmax() best_cluster_ind, peak_index = np.unravel_index(best_amplitude_ind, scalar_products.shape) From 919a40494ab3635efbc3f8df31ffd3d7308f1b98 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 16 Nov 2023 11:54:04 +0100 Subject: [PATCH 18/30] Fix for cleaning via matching --- .../sortingcomponents/matching/circus.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 8bc4b34806..091b5d32a4 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -676,13 +676,13 @@ def main_function(cls, traces, d): # Filter using overlap-and-add convolution if len(ignored_ids) > 0: - mask = ~np.isin(np.arange(num_templates), ignored_ids) - spatially_filtered_data = np.matmul(d["spatial"][:, mask, :], traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * d["singular"][:, mask, :] + not_ignored = ~np.isin(np.arange(num_templates), ignored_ids) + spatially_filtered_data = np.matmul(d["spatial"][:, not_ignored, :], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d["singular"][:, not_ignored, :] objective_by_rank = scipy.signal.oaconvolve( - scaled_filtered_data, d["temporal"][:, mask, :], axes=2, mode="valid" + scaled_filtered_data, d["temporal"][:, not_ignored, :], axes=2, mode="valid" ) - scalar_products[mask] += np.sum(objective_by_rank, axis=0) + scalar_products[not_ignored] += np.sum(objective_by_rank, axis=0) scalar_products[ignored_ids] = -np.inf else: spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) @@ -693,7 +693,6 @@ def main_function(cls, traces, d): num_spikes = 0 spikes = np.empty(scalar_products.size, dtype=spike_dtype) - idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) M = np.zeros((num_templates, num_templates), dtype=np.float32) @@ -708,7 +707,10 @@ def main_function(cls, traces, d): all_amplitudes = np.zeros(0, dtype=np.float32) is_in_vicinity = np.zeros(0, dtype=np.int32) - new_error = np.linalg.norm(scalar_products) + if len(ignored_ids) > 0: + new_error = np.linalg.norm(scalar_products[not_ignored]) + else: + new_error = np.linalg.norm(scalar_products) delta_error = np.inf while delta_error > stop_criteria: @@ -810,11 +812,11 @@ def main_function(cls, traces, d): scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add previous_error = new_error - new_error = np.linalg.norm(scalar_products) - if previous_error != 0: - delta_error = np.abs(new_error / previous_error - 1) + if len(ignored_ids) > 0: + new_error = np.linalg.norm(scalar_products[not_ignored]) else: - delta_error = 0 + new_error = np.linalg.norm(scalar_products) + delta_error = np.abs(new_error / previous_error - 1) is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) valid_indices = np.where(is_valid) From ee239e7a0ba940e74a6086c1cc0376d9f740ed0b Mon Sep 17 00:00:00 2001 From: Sebastien Date: Thu, 16 Nov 2023 14:34:02 +0100 Subject: [PATCH 19/30] Closing the gap --- src/spikeinterface/sorters/internal/spyking_circus2.py | 1 + src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 3 +-- src/spikeinterface/sortingcomponents/matching/circus.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index c690b4228d..1d4f04a382 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -71,6 +71,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_f = common_reference(recording_f) else: recording_f = recording + recording_f.annotate(is_filtered=True) # recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 0905e61169..6d29fe3b37 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -12,7 +12,7 @@ HAVE_HDBSCAN = False import random, string, os -from spikeinterface.core import get_global_tmp_folder, get_noise_levels, get_channel_distances +from spikeinterface.core import get_global_tmp_folder, get_channel_distances from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index fee35709d7..dcb84cb6ff 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -12,7 +12,7 @@ HAVE_HDBSCAN = False import random, string, os -from spikeinterface.core import get_global_tmp_folder, get_noise_levels, get_channel_distances, get_random_data_chunks +from spikeinterface.core import get_global_tmp_folder, get_channel_distances, get_random_data_chunks from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip @@ -48,7 +48,6 @@ class RandomProjectionClustering: "ms_before": 1, "ms_after": 1, "random_seed": 42, - "noise_levels": None, "smoothing_kwargs": {"window_length_ms": 0.25}, "shared_memory": True, "tmp_folder": None, diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 39decc2380..839fe1dbd2 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -505,8 +505,8 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): """ _default_params = { - "amplitudes": [0.75, 1.25], - "omp_min_sps": 1e-4, + "amplitudes": [0.6, 1.4], + "omp_min_sps": 1e-5, "waveform_extractor": None, "random_chunk_kwargs": {}, "rank": 5, From 9054f7b79ea7cd742bd383f9921ecafe58fd02f5 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 16 Nov 2023 14:46:01 +0100 Subject: [PATCH 20/30] Speeding up merging --- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 052a596c63..1167541ebf 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -593,7 +593,7 @@ def remove_duplicates_via_matching( local_params = method_kwargs.copy() - local_params.update({"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025]}) + local_params.update({"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025], "omp_min_sps" : 1e-3}) spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) indices = np.argsort(counts) From b3852eabe99770b5ca12f8b3ba2276c544041ede Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Nov 2023 13:46:23 +0000 Subject: [PATCH 21/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 1167541ebf..629b0b13ac 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -593,7 +593,7 @@ def remove_duplicates_via_matching( local_params = method_kwargs.copy() - local_params.update({"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025], "omp_min_sps" : 1e-3}) + local_params.update({"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025], "omp_min_sps": 1e-3}) spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) indices = np.argsort(counts) From ddb7eb964ca7d975d06ecb8f2e9a02b226ab924c Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 16 Nov 2023 15:31:23 +0100 Subject: [PATCH 22/30] WIP --- src/spikeinterface/sortingcomponents/matching/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 839fe1dbd2..77bbf3a73b 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -506,7 +506,7 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): _default_params = { "amplitudes": [0.6, 1.4], - "omp_min_sps": 1e-5, + "omp_min_sps": 5e-5, "waveform_extractor": None, "random_chunk_kwargs": {}, "rank": 5, From 3b514fe3d32d511556b65c50b965e90dbacf8f3a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 17 Nov 2023 09:59:22 +0100 Subject: [PATCH 23/30] Documentation --- .../sortingcomponents/matching/circus.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 77bbf3a73b..b0311e10bd 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -495,12 +495,15 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): amplitude: tuple (Minimal, Maximal) amplitudes allowed for every template omp_min_sps: float - Stopping criteria of the OMP algorithm, in percentage of the norm - random_chunk_kwargs: dict - Parameters for computing noise levels, if not provided (sub optimal) + Stopping criteria of the OMP algorithm, as relative error sparse_kwargs: dict Parameters to extract a sparsity mask from the waveform_extractor, if not already sparse. + rank: int + Number of components used internally by the SVD (default 5) + vicinity: int + Size of the area surrounding a spike to perform modification (expressed in terms + of template temporal width) ----- """ @@ -508,7 +511,6 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): "amplitudes": [0.6, 1.4], "omp_min_sps": 5e-5, "waveform_extractor": None, - "random_chunk_kwargs": {}, "rank": 5, "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], From 47643f7555a9c950d63c62085ae1c5f830e7898d Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 17 Nov 2023 13:00:09 +0100 Subject: [PATCH 24/30] Update src/spikeinterface/sortingcomponents/clustering/circus.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 6d29fe3b37..401ed58871 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -95,7 +95,7 @@ class CircusClustering: @classmethod def main_function(cls, recording, peaks, params): - assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed" + assert HAVE_HDBSCAN, "random projections clustering needs hdbscan to be installed" if "n_jobs" in params["job_kwargs"]: if params["job_kwargs"]["n_jobs"] == -1: From b78433068d9a464485e4397fb7f0272d0d4b9093 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 17 Nov 2023 13:00:19 +0100 Subject: [PATCH 25/30] Update src/spikeinterface/sortingcomponents/matching/circus.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sortingcomponents/matching/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index b0311e10bd..d23095d838 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -500,7 +500,7 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): Parameters to extract a sparsity mask from the waveform_extractor, if not already sparse. rank: int - Number of components used internally by the SVD (default 5) + Number of components used internally by the SVD vicinity: int Size of the area surrounding a spike to perform modification (expressed in terms of template temporal width) From 0c65b2ca8895895b1a89ddc9a8eb9688b4b42c9a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 17 Nov 2023 13:00:25 +0100 Subject: [PATCH 26/30] Update src/spikeinterface/sortingcomponents/matching/circus.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sortingcomponents/matching/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index d23095d838..cfdca6f612 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -499,7 +499,7 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): sparse_kwargs: dict Parameters to extract a sparsity mask from the waveform_extractor, if not already sparse. - rank: int + rank: int, default: 5 Number of components used internally by the SVD vicinity: int Size of the area surrounding a spike to perform modification (expressed in terms From 3def285b8af6600337abde9374ec67dcea685048 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 17 Nov 2023 13:00:34 +0100 Subject: [PATCH 27/30] Update src/spikeinterface/sortingcomponents/clustering/circus.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 401ed58871..47c5a1e58f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -34,7 +34,7 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): """ - Helper function to extractor waveforms at max channel from a peak list + Helper function to extract waveforms at the max channel from a peak list """ From eb00e8bdba72087d540d0c1fd8ff01d5eb0a1e97 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 20 Nov 2023 12:16:36 +0100 Subject: [PATCH 28/30] Update src/spikeinterface/comparison/groundtruthstudy.py Co-authored-by: Garcia Samuel --- src/spikeinterface/comparison/groundtruthstudy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index adc2898071..23d13c0afe 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -21,7 +21,7 @@ # This is to separate names when the key are tuples when saving folders -_key_separator = "--" +_key_separator = "_##_" class GroundTruthStudy: From 231f352bda0212f412c44325a8b001bb7c1789d2 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 21 Nov 2023 12:14:49 +0100 Subject: [PATCH 29/30] Move extract_waveforms_to_single_buffer to tools.py --- .../sorters/internal/tridesclous2.py | 46 +++----------- .../sortingcomponents/clustering/circus.py | 38 +---------- src/spikeinterface/sortingcomponents/tools.py | 63 ++++++++++++++----- 3 files changed, 55 insertions(+), 92 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index eb2ddc922d..9e67bbf4f4 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -9,12 +9,14 @@ NumpySorting, get_channel_distances, ) -from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer + from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel + import numpy as np import pickle @@ -115,9 +117,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print("We kept %d peaks for clustering" % len(peaks)) + ms_before = params["waveforms"]["ms_before"] + ms_after = params["waveforms"]["ms_after"] + # SVD for time compression few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) - few_wfs = extract_waveform_at_max_channel(recording, few_peaks, **job_kwargs) + few_wfs = extract_waveform_at_max_channel(recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs) wfs = few_wfs[:, :, 0] tsvd = TruncatedSVD(params["svd"]["n_components"]) @@ -129,8 +134,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): with open(model_folder / "pca_model.pkl", "wb") as f: pickle.dump(tsvd, f) - ms_before = params["waveforms"]["ms_before"] - ms_after = params["waveforms"]["ms_after"] model_params = { "ms_before": ms_before, "ms_after": ms_after, @@ -321,37 +324,4 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): return sorting -def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): - """ - Helper function to extractor waveforms at max channel from a peak list - - - """ - n = rec.get_num_channels() - unit_ids = np.arange(n, dtype="int64") - sparsity_mask = np.eye(n, dtype="bool") - - spikes = np.zeros( - peaks.size, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] - ) - spikes["sample_index"] = peaks["sample_index"] - spikes["unit_index"] = peaks["channel_index"] - spikes["segment_index"] = peaks["segment_index"] - - nbefore = int(ms_before * rec.sampling_frequency / 1000.0) - nafter = int(ms_after * rec.sampling_frequency / 1000.0) - - all_wfs = extract_waveforms_to_single_buffer( - rec, - spikes, - unit_ids, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - sparsity_mask=sparsity_mask, - copy=True, - **job_kwargs, - ) - - return all_wfs + diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 47c5a1e58f..4dbd88c411 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -18,8 +18,6 @@ from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms -from spikeinterface.core.recording_tools import get_channel_distances, get_random_data_chunks -from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection from sklearn.decomposition import TruncatedSVD @@ -30,43 +28,9 @@ ExtractSparseWaveforms, PeakRetriever, ) +from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel -def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): - """ - Helper function to extract waveforms at the max channel from a peak list - - - """ - n = rec.get_num_channels() - unit_ids = np.arange(n, dtype="int64") - sparsity_mask = np.eye(n, dtype="bool") - - spikes = np.zeros( - peaks.size, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] - ) - spikes["sample_index"] = peaks["sample_index"] - spikes["unit_index"] = peaks["channel_index"] - spikes["segment_index"] = peaks["segment_index"] - - nbefore = int(ms_before * rec.sampling_frequency / 1000.0) - nafter = int(ms_after * rec.sampling_frequency / 1000.0) - - all_wfs = extract_waveforms_to_single_buffer( - rec, - spikes, - unit_ids, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - sparsity_mask=sparsity_mask, - copy=True, - **job_kwargs, - ) - - return all_wfs - class CircusClustering: """ diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index cd9226d5e8..1e8a933990 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -1,7 +1,7 @@ import numpy as np from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractSparseWaveforms, PeakRetriever - +from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer def make_multi_method_doc(methods, ident=" "): doc = "" @@ -18,23 +18,52 @@ def make_multi_method_doc(methods, ident=" "): return doc -def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5): - # TODO for Pierre: this function is really inefficient because it runs a full pipeline only for a few - # spikes, which means that all traces need to be accesses! Please find a better way - nb_peaks = min(len(peaks), nb_peaks) - idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False)) - peak_retriever = PeakRetriever(recording, peaks[idx]) - - sparse_waveforms = ExtractSparseWaveforms( - recording, - parents=[peak_retriever], - ms_before=ms_before, - ms_after=ms_after, - return_output=True, - radius_um=5, +def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): + """ + Helper function to extract waveforms at the max channel from a peak list + + + """ + n = rec.get_num_channels() + unit_ids = np.arange(n, dtype="int64") + sparsity_mask = np.eye(n, dtype="bool") + + spikes = np.zeros( + peaks.size, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + ) + spikes["sample_index"] = peaks["sample_index"] + spikes["unit_index"] = peaks["channel_index"] + spikes["segment_index"] = peaks["segment_index"] + + nbefore = int(ms_before * rec.sampling_frequency / 1000.0) + nafter = int(ms_after * rec.sampling_frequency / 1000.0) + + all_wfs = extract_waveforms_to_single_buffer( + rec, + spikes, + unit_ids, + nbefore, + nafter, + mode="shared_memory", + return_scaled=False, + sparsity_mask=sparsity_mask, + copy=True, + **job_kwargs, ) - nbefore = sparse_waveforms.nbefore - waveforms = run_node_pipeline(recording, [peak_retriever, sparse_waveforms], job_kwargs=job_kwargs) + return all_wfs + + +def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5): + + if peaks.size > nb_peaks: + idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False)) + some_peaks = peaks[idx] + else: + some_peaks = peaks + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + + waveforms = extract_waveform_at_max_channel(recording, some_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs) prototype = np.median(waveforms[:, :, 0] / (waveforms[:, nbefore, 0][:, np.newaxis]), axis=0) return prototype From 20cc70ef9b6e4fb6a18e6942fb2449000105f82e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Nov 2023 11:16:39 +0000 Subject: [PATCH 30/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/tridesclous2.py | 7 +++---- src/spikeinterface/sortingcomponents/clustering/circus.py | 1 - src/spikeinterface/sortingcomponents/tools.py | 6 ++++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 9e67bbf4f4..6d53414c9f 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -122,7 +122,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # SVD for time compression few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) - few_wfs = extract_waveform_at_max_channel(recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs) + few_wfs = extract_waveform_at_max_channel( + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) wfs = few_wfs[:, :, 0] tsvd = TruncatedSVD(params["svd"]["n_components"]) @@ -322,6 +324,3 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = sorting.save(folder=sorter_output_folder / "sorting") return sorting - - - diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 4dbd88c411..238b16260c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -31,7 +31,6 @@ from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel - class CircusClustering: """ hdbscan clustering on peak_locations previously done by localize_peaks() diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 1e8a933990..328e3b715d 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -3,6 +3,7 @@ from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractSparseWaveforms, PeakRetriever from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer + def make_multi_method_doc(methods, ident=" "): doc = "" @@ -55,7 +56,6 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5): - if peaks.size > nb_peaks: idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False)) some_peaks = peaks[idx] @@ -64,6 +64,8 @@ def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0 nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - waveforms = extract_waveform_at_max_channel(recording, some_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs) + waveforms = extract_waveform_at_max_channel( + recording, some_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) prototype = np.median(waveforms[:, :, 0] / (waveforms[:, nbefore, 0][:, np.newaxis]), axis=0) return prototype