From 70673cee6a570350b142758d52fba215bec2cf46 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Nov 2024 10:29:07 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 ++-- src/spikeinterface/sortingcomponents/clustering/circus.py | 8 ++++---- .../sortingcomponents/clustering/random_projections.py | 8 ++++---- src/spikeinterface/sortingcomponents/peak_detection.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index fcba32aeb9..c55a152d10 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -204,7 +204,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): detection_params["prototype"] = prototype detection_params["ms_before"] = ms_before if params["debug"]: - np.save(clustering_folder / 'prototype.npy', prototype) + np.save(clustering_folder / "prototype.npy", prototype) if skip_peaks: detection_params["skip_after_n_peaks"] = n_peaks peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs) @@ -450,4 +450,4 @@ def get_prototype(recording, n_peaks, ms_before, ms_after, **all_kwargs): waveforms = res[1] with np.errstate(divide="ignore", invalid="ignore"): prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) - return prototype \ No newline at end of file + return prototype diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 4a8ff78c95..cc8ce9e551 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -60,7 +60,7 @@ class CircusClustering: "n_svd": [5, 2], "ms_before": 0.5, "ms_after": 0.5, - "noise_threshold" : 4, + "noise_threshold": 4, "rank": 5, "noise_levels": None, "tmp_folder": None, @@ -241,10 +241,10 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): job_name=None, **job_kwargs, ) - - peak_snrs = np.abs(templates_array[:, nbefore, :])/templates_array_std[:, nbefore, :] + + peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :] best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) - best_snrs_ratio = (peak_snrs/params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] + best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] valid_templates = best_snrs_ratio > params["noise_threshold"] if d["rank"] is not None: diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 666cc9e747..8db7202deb 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -53,7 +53,7 @@ class RandomProjectionClustering: "random_seed": 42, "noise_levels": None, "smoothing_kwargs": {"window_length_ms": 0.25}, - "noise_threshold" : 4, + "noise_threshold": 4, "tmp_folder": None, "verbose": True, } @@ -144,10 +144,10 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): job_name=None, **job_kwargs, ) - - peak_snrs = np.abs(templates_array[:, nbefore, :])/templates_array_std[:, nbefore, :] + + peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :] best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) - best_snrs_ratio = (peak_snrs/params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] + best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] valid_templates = best_snrs_ratio > params["noise_threshold"] templates = Templates( diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index a3b7f35f2b..8a22b35152 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -727,7 +727,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): peak_sample_ind += self.exclude_sweep_size + self.conv_margin + self.nbefore peak_amplitude = traces[peak_sample_ind, peak_chan_ind] - + local_peaks = np.zeros(peak_sample_ind.size, dtype=self._dtype) local_peaks["sample_index"] = peak_sample_ind local_peaks["channel_index"] = peak_chan_ind