From 957861fd43861a124880c41c5cbcc8921db30889 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 14 Oct 2024 17:19:07 +0200 Subject: [PATCH 1/4] Sparsify the weights --- src/spikeinterface/sortingcomponents/peak_detection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index d608c5d105..51b3e4dc77 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -631,7 +631,7 @@ def __init__( weight_method={}, ): PeakDetector.__init__(self, recording, return_output=True) - + import scipy if not HAVE_NUMBA: raise ModuleNotFoundError('matched_filtering" needs numba which is not installed') @@ -664,7 +664,7 @@ def __init__( self.num_templates *= 2 self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1) - + self.weights = scipy.sparse.csr_matrix(self.weights) random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs) conv_random_data = self.get_convolved_traces(random_data) medians = np.median(conv_random_data, axis=1) @@ -737,7 +737,7 @@ def get_convolved_traces(self, traces): import scipy.signal tmp = scipy.signal.oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") - scalar_products = np.dot(self.weights, tmp) + scalar_products = self.weights.dot(tmp) return scalar_products From 5568e1a3cd98f6f9c77953c294fdd558c4457e6c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Oct 2024 15:23:19 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/peak_detection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 51b3e4dc77..2961f11981 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -632,6 +632,7 @@ def __init__( ): PeakDetector.__init__(self, recording, return_output=True) import scipy + if not HAVE_NUMBA: raise ModuleNotFoundError('matched_filtering" needs numba which is not installed') From 14278161efe44c5955dc2072a5354de73dcf6bb3 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 14 Oct 2024 23:36:43 +0200 Subject: [PATCH 3/4] Imports --- src/spikeinterface/sortingcomponents/peak_detection.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 2961f11981..d2d1afaafb 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -631,7 +631,7 @@ def __init__( weight_method={}, ): PeakDetector.__init__(self, recording, return_output=True) - import scipy + from scipy.sparse import csr_matrix if not HAVE_NUMBA: raise ModuleNotFoundError('matched_filtering" needs numba which is not installed') @@ -665,7 +665,7 @@ def __init__( self.num_templates *= 2 self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1) - self.weights = scipy.sparse.csr_matrix(self.weights) + self.weights = csr_matrix(self.weights) random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs) conv_random_data = self.get_convolved_traces(random_data) medians = np.median(conv_random_data, axis=1) @@ -735,9 +735,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (local_peaks,) def get_convolved_traces(self, traces): - import scipy.signal - - tmp = scipy.signal.oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") + from scipy.signal import oaconvolve + tmp = oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") scalar_products = self.weights.dot(tmp) return scalar_products From b9f2cc803b295097a6cf4ae95eee5f82d5be222f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Oct 2024 21:37:04 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/peak_detection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index d2d1afaafb..134481289e 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -736,6 +736,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): def get_convolved_traces(self, traces): from scipy.signal import oaconvolve + tmp = oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") scalar_products = self.weights.dot(tmp) return scalar_products