Skip to content

Commit

Permalink
NumpySorting.from_peaks make unit_ids mandatory.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Dec 8, 2023
1 parent 8e32955 commit 8574de9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
8 changes: 3 additions & 5 deletions src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def from_neo_spiketrain_list(neo_spiketrains, sampling_frequency, unit_ids=None)
return sorting

@staticmethod
def from_peaks(peaks, sampling_frequency, unit_ids=None) -> "NumpySorting":
def from_peaks(peaks, sampling_frequency, unit_ids) -> "NumpySorting":
"""
Construct a sorting from peaks returned by 'detect_peaks()' function.
The unit ids correspond to the recording channel ids and spike trains are the
Expand All @@ -314,7 +314,8 @@ def from_peaks(peaks, sampling_frequency, unit_ids=None) -> "NumpySorting":
Peaks array as returned by the 'detect_peaks()' function
sampling_frequency : float
the sampling frequency in Hz
unit_ids: np.array
The unit_ids vector which is generally the channel_ids but can be different.
Returns
-------
sorting
Expand All @@ -325,9 +326,6 @@ def from_peaks(peaks, sampling_frequency, unit_ids=None) -> "NumpySorting":
spikes["unit_index"] = peaks["channel_index"]
spikes["segment_index"] = peaks["segment_index"]

if unit_ids is None:
unit_ids = np.unique(peaks["channel_index"])

sorting = NumpySorting(spikes, sampling_frequency, unit_ids)

return sorting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def fit(
peaks = select_peaks(peaks, **peak_selection_params) # How to select n_peaks

# Creates a numpy sorting object where the spike times are the peak times and the unit ids are the peak channel
sorting = NumpySorting.from_peaks(peaks, sampling_frequency=recording.sampling_frequency)
sorting = NumpySorting.from_peaks(peaks, recording.sampling_frequency, recording.channel_ids)
# Create a waveform extractor
we = extract_waveforms(
recording,
Expand Down

0 comments on commit 8574de9

Please sign in to comment.