diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 5ccae819c4..fd00d18cd3 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -11,7 +11,11 @@ from spikeinterface.core.template import Templates from spikeinterface.core.waveform_tools import estimate_templates from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion -from spikeinterface.sortingcomponents.tools import cache_preprocessing, get_prototype_and_waveforms, get_shuffled_recording_slices +from spikeinterface.sortingcomponents.tools import ( + cache_preprocessing, + get_prototype_and_waveforms, + get_shuffled_recording_slices, +) from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.core.sortinganalyzer import create_sorting_analyzer @@ -208,8 +212,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ms_after=ms_after, seed=params["seed"], return_waveforms=True, - **detection_params, - **job_kwargs + **detection_params, + **job_kwargs, ) detection_params["prototype"] = prototype detection_params["ms_before"] = ms_before diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index e0bbf66af5..b6ef240f80 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -314,7 +314,9 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) ms_before = 1.0 ms_after = 1.0 - prototype = get_prototype_and_waveforms(recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs) + prototype = get_prototype_and_waveforms( + recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) peaks_local_mf_filtering = detect_peaks( recording, diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 439e71f0fa..3c968f5da4 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -68,7 +68,10 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j return all_wfs -def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, return_waveforms=False, **all_kwargs): + +def get_prototype_and_waveforms( + recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, return_waveforms=False, **all_kwargs +): """ Function to extract a prototype waveform from a peak list or from a peak detection. Note that in case of a peak detection, the detection stops as soon as n_peaks are detected. @@ -99,7 +102,7 @@ def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0 waveforms : numpy.array, optional The extracted waveforms, returned if return_waveforms is True. """ - + seed = seed if seed else None rng = np.random.default_rng(seed=seed) @@ -110,6 +113,7 @@ def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0 if peaks is None: from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ExtractSparseWaveforms + node = ExtractSparseWaveforms( recording, parents=None, @@ -123,16 +127,20 @@ def get_prototype_and_waveforms(recording, n_peaks=5000, peaks=None, ms_before=0 recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs) res = detect_peaks( - recording, pipeline_nodes=pipeline_nodes, - skip_after_n_peaks=n_peaks, - recording_slices=recording_slices, - **detection_kwargs, - **job_kwargs, + recording, + pipeline_nodes=pipeline_nodes, + skip_after_n_peaks=n_peaks, + recording_slices=recording_slices, + **detection_kwargs, + **job_kwargs, ) waveforms = res[1] else: from spikeinterface.sortingcomponents.peak_selection import select_peaks - few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed) + + few_peaks = select_peaks( + peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed + ) waveforms = extract_waveform_at_max_channel( recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) @@ -229,4 +237,4 @@ def get_shuffled_recording_slices(recording, seed=None, **job_kwargs): rng = np.random.RandomState(seed) recording_slices = rng.permutation(recording_slices) - return recording_slices \ No newline at end of file + return recording_slices