Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 28, 2024
1 parent d83d0c1 commit b2e28f7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
4 changes: 3 additions & 1 deletion src/spikeinterface/preprocessing/tests/test_whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def test_whiten():

# test regularization
with pytest.raises(AssertionError):
W, M = compute_whitening_matrix(rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, regularize=True)
W, M = compute_whitening_matrix(
rec, "local", random_chunk_kwargs, apply_mean=False, radius_um=None, regularize=True
)
# W must be sparse
np.sum(W == 0) == 6

Expand Down
22 changes: 14 additions & 8 deletions src/spikeinterface/preprocessing/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,21 @@ def __init__(

if dtype_.kind == "i":
assert int_scale is not None, "For recording with dtype=int you must set dtype=float32 OR set a int_scale"

if W is not None:
W = np.asarray(W)
if M is not None:
M = np.asarray(M)
else:
W, M = compute_whitening_matrix(
recording, mode, random_chunk_kwargs, apply_mean, radius_um=radius_um, eps=eps, regularize=regularize,
regularize_kwargs=regularize_kwargs
recording,
mode,
random_chunk_kwargs,
apply_mean,
radius_um=radius_um,
eps=eps,
regularize=regularize,
regularize_kwargs=regularize_kwargs,
)

BasePreprocessor.__init__(self, recording, dtype=dtype_)
Expand Down Expand Up @@ -142,8 +148,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):


def compute_whitening_matrix(
recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False,
regularize_kwargs=None
recording, mode, random_chunk_kwargs, apply_mean, radius_um=None, eps=None, regularize=False, regularize_kwargs=None
):
"""
Compute whitening matrix
Expand Down Expand Up @@ -197,12 +202,13 @@ def compute_whitening_matrix(
cov = cov / data.shape[0]
else:
import sklearn.covariance

if regularize_kwargs is None:
regularize_kwargs = {}
regularize_kwargs['assume_centered'] = True
regularize_kwargs["assume_centered"] = True
job_kwargs = get_global_job_kwargs()
if 'n_jobs' in job_kwargs and 'n_jobs' not in regularize_kwargs:
regularize_kwargs['n_jobs'] = job_kwargs['n_jobs']
if "n_jobs" in job_kwargs and "n_jobs" not in regularize_kwargs:
regularize_kwargs["n_jobs"] = job_kwargs["n_jobs"]
estimator = sklearn.covariance.GraphicalLassoCV(**regularize_kwargs)
estimator.fit(data)
cov = estimator.covariance_
Expand Down

0 comments on commit b2e28f7

Please sign in to comment.