diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 2c297662f4..a16b642dd5 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -67,6 +67,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
 
         # recording_f = whiten(recording_f, dtype="float32")
         recording_f = zscore(recording_f, dtype="float32")
+        noise_levels = np.ones(num_channels, dtype=np.float32)
 
         ## Then, we are detecting peaks with a locally_exclusive method
         detection_params = params["detection"].copy()
@@ -87,7 +88,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels
         selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"])
 
-        noise_levels = np.ones(num_channels, dtype=np.float32)
         selection_params.update({"noise_levels": noise_levels})
         selected_peaks = select_peaks(
             peaks, method="smart_sampling_amplitudes", select_per_channel=False, **selection_params
@@ -107,6 +107,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         clustering_params.update(dict(shared_memory=params["shared_memory"]))
         clustering_params["job_kwargs"] = job_kwargs
         clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
+        clustering_params.update({"noise_levels": noise_levels})
 
         labels, peak_labels = find_cluster_from_peaks(
             recording_f, selected_peaks, method="random_projections", method_kwargs=clustering_params
diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
index a81458d7a8..72acd49f4f 100644
--- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py
+++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -20,7 +20,12 @@
 from spikeinterface.core import extract_waveforms
 from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
 from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature
-from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever
+from spikeinterface.core.node_pipeline import (
+    run_node_pipeline,
+    ExtractDenseWaveforms,
+    ExtractSparseWaveforms,
+    PeakRetriever,
+)
 
 
 class RandomProjectionClustering:
@@ -43,7 +48,8 @@ class RandomProjectionClustering:
         "ms_before": 1,
         "ms_after": 1,
         "random_seed": 42,
-        "smoothing_kwargs": {"window_length_ms": 1},
+        "noise_levels": None,
+        "smoothing_kwargs": {"window_length_ms": 0.25},
         "shared_memory": True,
         "tmp_folder": None,
         "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True},
@@ -72,7 +78,10 @@ def main_function(cls, recording, peaks, params):
         num_samples = nbefore + nafter
         num_chans = recording.get_num_channels()
 
-        noise_levels = get_noise_levels(recording, return_scaled=False)
+        if d["noise_levels"] is None:
+            noise_levels = get_noise_levels(recording, return_scaled=False)
+        else:
+            noise_levels = d["noise_levels"]
 
         np.random.seed(d["random_seed"])
 
@@ -82,10 +91,16 @@ def main_function(cls, recording, peaks, params):
         else:
             tmp_folder = Path(params["tmp_folder"]).absolute()
 
-        ### Then we extract the SVD features
+        tmp_folder.mkdir(parents=True, exist_ok=True)
+
         node0 = PeakRetriever(recording, peaks)
-        node1 = ExtractDenseWaveforms(
-            recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"]
+        node1 = ExtractSparseWaveforms(
+            recording,
+            parents=[node0],
+            return_output=False,
+            ms_before=params["ms_before"],
+            ms_after=params["ms_after"],
+            radius_um=params["radius_um"],
         )
 
         node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"])
@@ -123,6 +138,8 @@ def sigmoid(x, L, x0, k, b):
             return_output=True,
             projections=projections,
             radius_um=params["radius_um"],
+            sigmoid=None,
+            sparse=True,
         )
 
         pipeline_nodes = [node0, node1, node2, node3]
@@ -136,6 +153,18 @@ def sigmoid(x, L, x0, k, b):
         clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
         peak_labels = clustering[0]
 
+        # peak_labels = -1 * np.ones(len(peaks), dtype=int)
+        # nb_clusters = 0
+        # for c in np.unique(peaks['channel_index']):
+        #     mask = peaks['channel_index'] == c
+        #     clustering = hdbscan.hdbscan(hdbscan_data[mask], **d['hdbscan_kwargs'])
+        #     local_labels = clustering[0]
+        #     valid_clusters = local_labels > -1
+        #     if np.sum(valid_clusters) > 0:
+        #         local_labels[valid_clusters] += nb_clusters
+        #         peak_labels[mask] = local_labels
+        #         nb_clusters += len(np.unique(local_labels[valid_clusters]))
+
         labels = np.unique(peak_labels)
         labels = labels[labels >= 0]
 
@@ -174,15 +203,6 @@ def sigmoid(x, L, x0, k, b):
         if verbose:
             print("We found %d raw clusters, starting to clean with matching..." % (len(labels)))
 
-        # create a tmp folder
-        if params["tmp_folder"] is None:
-            name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
-            tmp_folder = get_global_tmp_folder() / name
-        else:
-            tmp_folder = Path(params["tmp_folder"])
-
-        tmp_folder.mkdir(parents=True, exist_ok=True)
-
         sorting_folder = tmp_folder / "sorting"
         unit_ids = np.arange(len(np.unique(spikes["unit_index"])))
         sorting = NumpySorting(spikes, fs, unit_ids=unit_ids)
diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py
index b534c2356d..06d22181cb 100644
--- a/src/spikeinterface/sortingcomponents/features_from_peaks.py
+++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py
@@ -186,6 +186,7 @@ def __init__(
         projections=None,
         sigmoid=None,
         radius_um=None,
+        sparse=True,
     ):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)
 
@@ -195,7 +196,8 @@ def __init__(
         self.channel_distance = get_channel_distances(recording)
         self.neighbours_mask = self.channel_distance < radius_um
         self.radius_um = radius_um
-        self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um))
+        self.sparse = sparse
+        self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um, sparse=sparse))
         self._dtype = recording.get_dtype()
 
     def get_dtype(self):
@@ -213,7 +215,10 @@ def compute(self, traces, peaks, waveforms):
             (idx,) = np.nonzero(peaks["channel_index"] == main_chan)
             (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan])
             local_projections = self.projections[chan_inds, :]
-            wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)
+            if self.sparse:
+                wf_ptp = np.ptp(waveforms[idx][:, :, : len(chan_inds)], axis=1)
+            else:
+                wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)
 
             if self.sigmoid is not None:
                 wf_ptp *= self._sigmoid(wf_ptp)
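The behavioral core of the patch is the `sparse` branch in `RandomProjectionsFeature.compute()`: with `ExtractSparseWaveforms`, the channel axis of each waveform holds only the channels within `radius_um` of the peak's main channel, so slicing with global channel indices no longer applies. Below is a minimal NumPy sketch of that indexing difference; the shapes and array names are illustrative assumptions, not code from the repository.

```python
import numpy as np

rng = np.random.default_rng(42)
num_peaks, num_samples, num_chans = 5, 30, 16

# Hypothetical neighbourhood of one main channel (channels within radius_um).
chan_inds = np.array([2, 3, 4, 5])

# Dense extraction: the channel axis spans all recording channels,
# so neighbour channels are selected by their global indices.
dense_wfs = rng.normal(size=(num_peaks, num_samples, num_chans))
ptp_dense = np.ptp(dense_wfs[:, :, chan_inds], axis=1)

# Sparse extraction: the channel axis holds only the neighbourhood
# channels, in order, so the first len(chan_inds) columns are the valid
# ones; global indices would point at the wrong (or missing) columns.
sparse_wfs = dense_wfs[:, :, chan_inds]  # stand-in for the sparse node's output
ptp_sparse = np.ptp(sparse_wfs[:, :, : len(chan_inds)], axis=1)

# Both paths yield the same peak-to-peak features per neighbour channel.
assert np.allclose(ptp_dense, ptp_sparse)
```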