diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index f8093dd25f..f3e9a8221f 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -51,7 +51,9 @@ def test_whiten(): # test regularization with pytest.raises(AssertionError): - W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, regularize=True) + W, M = compute_whitening_matrix( + rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, regularize=True + ) # W must be sparse np.sum(W == 0) == 6 diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 5c5d167ba8..f3f0a1368b 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -77,15 +77,21 @@ def __init__( if dtype_.kind == "i": assert int_scale is not None, "For recording with dtype=int you must set dtype=float32 OR set a int_scale" - + if W is not None: W = np.asarray(W) if M is not None: M = np.asarray(M) else: W, M = compute_whitening_matrix( - recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps, regularize=regularize, - regularize_kwargs=regularize_kwargs + recording, + mode, + random_chunk_kwargs, + apply_mean, + radius_um=radius_um, + eps=eps, + regularize=regularize, + regularize_kwargs=regularize_kwargs, ) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -142,8 +148,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): def compute_whitening_matrix( - recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False, - regularize_kwargs=None + recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False, regularize_kwargs=None ): """ Compute whitening matrix @@ -197,12 +202,13 @@ def compute_whitening_matrix( cov = cov / data.shape[0] else: import sklearn.covariance + if regularize_kwargs is None: regularize_kwargs = {} - regularize_kwargs['assume_centered'] = True + regularize_kwargs["assume_centered"] = True job_kwargs = get_global_job_kwargs() - if 'n_jobs' in job_kwargs and 'n_jobs' not in regularize_kwargs: - regularize_kwargs['n_jobs'] = job_kwargs['n_jobs'] + if "n_jobs" in job_kwargs and "n_jobs" not in regularize_kwargs: + regularize_kwargs["n_jobs"] = job_kwargs["n_jobs"] estimator = sklearn.covariance.GraphicalLassoCV(**regularize_kwargs) estimator.fit(data) cov = estimator.covariance_