diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index d608c5d105..134481289e 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -631,6 +631,7 @@ def __init__( weight_method={}, ): PeakDetector.__init__(self, recording, return_output=True) + from scipy.sparse import csr_matrix if not HAVE_NUMBA: raise ModuleNotFoundError('matched_filtering" needs numba which is not installed') @@ -664,7 +665,7 @@ def __init__( self.num_templates *= 2 self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1) - + self.weights = csr_matrix(self.weights) random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs) conv_random_data = self.get_convolved_traces(random_data) medians = np.median(conv_random_data, axis=1) @@ -734,10 +735,10 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (local_peaks,) def get_convolved_traces(self, traces): - import scipy.signal + from scipy.signal import oaconvolve - tmp = scipy.signal.oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") - scalar_products = np.dot(self.weights, tmp) + tmp = oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") + scalar_products = self.weights.dot(tmp) return scalar_products