From fa82f108a1a15e4eeb347a9c86294a65960bbd6d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Oct 2023 08:20:50 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 1 - .../clustering/random_projections.py | 20 +++++++++++++------ .../sortingcomponents/features_from_peaks.py | 4 ++-- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 780e6a14aa..a16b642dd5 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -88,7 +88,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"]) - selection_params.update({"noise_levels": noise_levels}) selected_peaks = select_peaks( peaks, method="smart_sampling_amplitudes", select_per_channel=False, **selection_params diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index b1dab9b27c..72acd49f4f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -20,7 +20,12 @@ from spikeinterface.core import extract_waveforms from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature -from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, ExtractSparseWaveforms, PeakRetriever +from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ExtractDenseWaveforms, + ExtractSparseWaveforms, + PeakRetriever, +) class RandomProjectionClustering: @@ -43,7 +48,7 @@ class RandomProjectionClustering: "ms_before": 1, "ms_after": 1, "random_seed": 42, - "noise_levels" : None, + "noise_levels": None, "smoothing_kwargs": {"window_length_ms": 0.25}, "shared_memory": True, "tmp_folder": None, @@ -86,13 +91,16 @@ def main_function(cls, recording, peaks, params): else: tmp_folder = Path(params["tmp_folder"]).absolute() - tmp_folder.mkdir(parents=True, exist_ok=True) node0 = PeakRetriever(recording, peaks) node1 = ExtractSparseWaveforms( - recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"], - radius_um=params['radius_um'] + recording, + parents=[node0], + return_output=False, + ms_before=params["ms_before"], + ms_after=params["ms_after"], + radius_um=params["radius_um"], ) node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"]) @@ -131,7 +139,7 @@ def sigmoid(x, L, x0, k, b): projections=projections, radius_um=params["radius_um"], sigmoid=None, - sparse=True + sparse=True, ) pipeline_nodes = [node0, node1, node2, node3] diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index 3ca53b05fb..06d22181cb 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -186,7 +186,7 @@ def __init__( projections=None, sigmoid=None, radius_um=None, - sparse=True + sparse=True, ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) @@ -216,7 +216,7 @@ def compute(self, traces, peaks, waveforms): (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) local_projections = self.projections[chan_inds, :] if self.sparse: - wf_ptp = np.ptp(waveforms[idx][:, :, :len(chan_inds)], axis=1) + wf_ptp = np.ptp(waveforms[idx][:, :, : len(chan_inds)], axis=1) else: wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1)