From 8574de9e58294f8fd8d976a627d2019fd177e19e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 8 Dec 2023 11:44:39 +0100 Subject: [PATCH 1/2] NumpySorting.from_peaks make unit_ids mandatory. --- src/spikeinterface/core/numpyextractors.py | 8 +++----- .../sortingcomponents/waveforms/temporal_pca.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 7abb5596e8..76ee88b84c 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -302,7 +302,7 @@ def from_neo_spiketrain_list(neo_spiketrains, sampling_frequency, unit_ids=None) return sorting @staticmethod - def from_peaks(peaks, sampling_frequency, unit_ids=None) -> "NumpySorting": + def from_peaks(peaks, sampling_frequency, unit_ids) -> "NumpySorting": """ Construct a sorting from peaks returned by 'detect_peaks()' function. The unit ids correspond to the recording channel ids and spike trains are the @@ -314,7 +314,8 @@ def from_peaks(peaks, sampling_frequency, unit_ids=None) -> "NumpySorting": Peaks array as returned by the 'detect_peaks()' function sampling_frequency : float the sampling frequency in Hz - + unit_ids: np.array + The unit_ids vector which is generally the channel_ids but can be different. Returns ------- sorting @@ -325,9 +326,6 @@ def from_peaks(peaks, sampling_frequency, unit_ids=None) -> "NumpySorting": spikes["unit_index"] = peaks["channel_index"] spikes["segment_index"] = peaks["segment_index"] - if unit_ids is None: - unit_ids = np.unique(peaks["channel_index"]) - sorting = NumpySorting(spikes, sampling_frequency, unit_ids) return sorting diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index a9e6126ccc..95a5166d10 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -135,7 +135,7 @@ def fit( peaks = select_peaks(peaks, **peak_selection_params) # How to select n_peaks # Creates a numpy sorting object where the spike times are the peak times and the unit ids are the peak channel - sorting = NumpySorting.from_peaks(peaks, sampling_frequency=recording.sampling_frequency) + sorting = NumpySorting.from_peaks(peaks, recording.sampling_frequency, recording.channel_ids) # Create a waveform extractor we = extract_waveforms( recording, From 6dac6c7d184aef08b97618ae6aab512c1bef0cd3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 12 Dec 2023 11:49:34 +0100 Subject: [PATCH 2/2] Update src/spikeinterface/core/numpyextractors.py --- src/spikeinterface/core/numpyextractors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 76ee88b84c..9a04eeec37 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -316,6 +316,7 @@ def from_peaks(peaks, sampling_frequency, unit_ids) -> "NumpySorting": the sampling frequency in Hz unit_ids: np.array The unit_ids vector which is generally the channel_ids but can be different. + Returns ------- sorting