diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 7abb5596e8..9a04eeec37 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -302,7 +302,7 @@ def from_neo_spiketrain_list(neo_spiketrains, sampling_frequency, unit_ids=None) return sorting @staticmethod - def from_peaks(peaks, sampling_frequency, unit_ids=None) -> "NumpySorting": + def from_peaks(peaks, sampling_frequency, unit_ids) -> "NumpySorting": """ Construct a sorting from peaks returned by 'detect_peaks()' function. The unit ids correspond to the recording channel ids and spike trains are the @@ -314,6 +314,8 @@ def from_peaks(peaks, sampling_frequency, unit_ids=None) -> "NumpySorting": Peaks array as returned by the 'detect_peaks()' function sampling_frequency : float the sampling frequency in Hz + unit_ids: np.array + The unit_ids vector which is generally the channel_ids but can be different. Returns ------- @@ -325,9 +327,6 @@ def from_peaks(peaks, sampling_frequency, unit_ids=None) -> "NumpySorting": spikes["unit_index"] = peaks["channel_index"] spikes["segment_index"] = peaks["segment_index"] - if unit_ids is None: - unit_ids = np.unique(peaks["channel_index"]) - sorting = NumpySorting(spikes, sampling_frequency, unit_ids) return sorting diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index a9e6126ccc..95a5166d10 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -135,7 +135,7 @@ def fit( peaks = select_peaks(peaks, **peak_selection_params) # How to select n_peaks # Creates a numpy sorting object where the spike times are the peak times and the unit ids are the peak channel - sorting = NumpySorting.from_peaks(peaks, sampling_frequency=recording.sampling_frequency) + sorting = NumpySorting.from_peaks(peaks, recording.sampling_frequency, recording.channel_ids) # Create a waveform extractor we = extract_waveforms( recording,