Skip to content

Commit

Permalink
MAINT: Switch to sparse arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Jun 5, 2024
1 parent 1d56eb9 commit 075d8d1
Show file tree
Hide file tree
Showing 27 changed files with 112 additions and 112 deletions.
4 changes: 2 additions & 2 deletions mne/_fiff/proc_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from ..fixes import _csc_matrix_cast
from ..fixes import _csc_array_cast
from ..utils import _check_fname, warn
from .constants import FIFF
from .open import fiff_open, read_tag
Expand Down Expand Up @@ -198,7 +198,7 @@ def _write_proc_history(fid, info):
FIFF.FIFF_DECOUPLER_MATRIX,
)
_sss_ctc_writers = (write_id, write_int, write_string, write_float_sparse)
_sss_ctc_casters = (dict, np.array, str, _csc_matrix_cast)
_sss_ctc_casters = (dict, np.array, str, _csc_array_cast)

_sss_cal_keys = ("cal_chans", "cal_corrs")
_sss_cal_ids = (FIFF.FIFF_SSS_CAL_CHANS, FIFF.FIFF_SSS_CAL_CORRS)
Expand Down
6 changes: 3 additions & 3 deletions mne/_fiff/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Any

import numpy as np
from scipy.sparse import csc_matrix, csr_matrix
from scipy.sparse import csc_array, csr_array

from ..utils import _check_option, warn
from ..utils.numerics import _julian_to_cal
Expand Down Expand Up @@ -214,7 +214,7 @@ def _read_matrix(fid, tag, shape, rlims):
)
)
indptr = np.frombuffer(tmp_ptr, dtype="<i4")
data = csc_matrix((data, indices, indptr), shape=shape)
data = csc_array((data, indices, indptr), shape=shape)
else:
assert matrix_coding == "sparse RCS", matrix_coding
tmp_indices = fid.read(4 * nnz)
Expand All @@ -231,7 +231,7 @@ def _read_matrix(fid, tag, shape, rlims):
)
)
indptr = np.frombuffer(tmp_ptr, dtype="<i4")
data = csr_matrix((data, indices, indptr), shape=shape)
data = csr_array((data, indices, indptr), shape=shape)
return data


Expand Down
8 changes: 4 additions & 4 deletions mne/_fiff/tests/test_meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,23 +803,23 @@ def test_csr_csc(tmp_path):
sss_ctc = info["proc_history"][0]["max_info"]["sss_ctc"]
ct = sss_ctc["decoupler"].copy()
# CSC
assert isinstance(ct, sparse.csc_matrix)
assert isinstance(ct, sparse.csc_array)
fname = tmp_path / "test.fif"
write_info(fname, info)
info_read = read_info(fname)
ct_read = info_read["proc_history"][0]["max_info"]["sss_ctc"]["decoupler"]
assert isinstance(ct_read, sparse.csc_matrix)
assert isinstance(ct_read, sparse.csc_array)
assert_array_equal(ct_read.toarray(), ct.toarray())
# Now CSR
csr = ct.tocsr()
assert isinstance(csr, sparse.csr_matrix)
assert isinstance(csr, sparse.csr_array)
assert_array_equal(csr.toarray(), ct.toarray())
info["proc_history"][0]["max_info"]["sss_ctc"]["decoupler"] = csr
fname = tmp_path / "test1.fif"
write_info(fname, info)
info_read = read_info(fname)
ct_read = info_read["proc_history"][0]["max_info"]["sss_ctc"]["decoupler"]
assert isinstance(ct_read, sparse.csc_matrix) # this gets cast to CSC
assert isinstance(ct_read, sparse.csc_array) # this gets cast to CSC
assert_array_equal(ct_read.toarray(), ct.toarray())


Expand Down
6 changes: 3 additions & 3 deletions mne/_fiff/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from gzip import GzipFile

import numpy as np
from scipy.sparse import csc_matrix, csr_matrix
from scipy.sparse import csc_array, csr_array

from ..utils import _file_like, _validate_type, logger
from ..utils.numerics import _cal_to_julian
Expand Down Expand Up @@ -417,8 +417,8 @@ def write_float_sparse_rcs(fid, kind, mat):
def write_float_sparse(fid, kind, mat, fmt="auto"):
"""Write a single-precision floating-point sparse matrix tag."""
if fmt == "auto":
fmt = "csr" if isinstance(mat, csr_matrix) else "csc"
need = csr_matrix if fmt == "csr" else csc_matrix
fmt = "csr" if isinstance(mat, csr_array) else "csc"
need = csr_array if fmt == "csr" else csc_array
matrix_type = getattr(FIFF, f"FIFFT_SPARSE_{fmt[-1].upper()}CS_MATRIX")
_validate_type(mat, need, "sparse")
matrix_type = matrix_type | FIFF.FIFFT_MATRIX | FIFF.FIFFT_FLOAT
Expand Down
14 changes: 7 additions & 7 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import numpy as np
from scipy.io import loadmat
from scipy.sparse import csr_matrix, lil_matrix
from scipy.sparse import csr_array, lil_array
from scipy.spatial import Delaunay
from scipy.stats import zscore

Expand Down Expand Up @@ -1336,7 +1336,7 @@ def read_ch_adjacency(fname, picks=None):
Returns
-------
ch_adjacency : scipy.sparse.csr_matrix, shape (n_channels, n_channels)
ch_adjacency : scipy.sparse.csr_array, shape (n_channels, n_channels)
The adjacency matrix.
ch_names : list
The list of channel names present in adjacency matrix.
Expand Down Expand Up @@ -1437,7 +1437,7 @@ def _ch_neighbor_adjacency(ch_names, neighbors):
ch_adjacency = np.eye(len(ch_names), dtype=bool)
for ii, neigbs in enumerate(neighbors):
ch_adjacency[ii, [ch_names.index(i) for i in neigbs]] = True
ch_adjacency = csr_matrix(ch_adjacency)
ch_adjacency = csr_array(ch_adjacency)
return ch_adjacency


Expand All @@ -1459,7 +1459,7 @@ def find_ch_adjacency(info, ch_type):
Returns
-------
ch_adjacency : scipy.sparse.csr_matrix, shape (n_channels, n_channels)
ch_adjacency : scipy.sparse.csr_array, shape (n_channels, n_channels)
The adjacency matrix.
ch_names : list
The list of channel names present in adjacency matrix.
Expand Down Expand Up @@ -1572,7 +1572,7 @@ def _compute_ch_adjacency(info, ch_type):
Returns
-------
ch_adjacency : scipy.sparse.csr_matrix, shape (n_channels, n_channels)
ch_adjacency : scipy.sparse.csr_array, shape (n_channels, n_channels)
The adjacency matrix.
ch_names : list
The list of channel names present in adjacency matrix.
Expand Down Expand Up @@ -1611,9 +1611,9 @@ def _compute_ch_adjacency(info, ch_type):
for jj in range(2):
ch_adjacency[idx * 2 + ii, neigbs * 2 + jj] = True
ch_adjacency[idx * 2 + ii, idx * 2 + jj] = True # pair
ch_adjacency = csr_matrix(ch_adjacency)
ch_adjacency = csr_array(ch_adjacency)
else:
ch_adjacency = lil_matrix(neighbors)
ch_adjacency = lil_array(neighbors)
ch_adjacency.setdiag(np.repeat(1, ch_adjacency.shape[0]))
ch_adjacency = ch_adjacency.tocsr()

Expand Down
6 changes: 3 additions & 3 deletions mne/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def _safe_svd(A, **kwargs):
return linalg.svd(A, lapack_driver="gesvd", **kwargs)


def _csc_matrix_cast(x):
from scipy.sparse import csc_matrix
def _csc_array_cast(x):
from scipy.sparse import csc_array

return csc_matrix(x)
return csc_array(x)


###############################################################################
Expand Down
14 changes: 7 additions & 7 deletions mne/forward/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def _block_diag(A, n):
Returns
-------
bd : scipy.sparse.spmatrix
bd : scipy.sparse.csc_array
The block diagonal matrix
"""
if sparse.issparse(A): # then make block sparse
Expand All @@ -311,7 +311,7 @@ def _block_diag(A, n):
jj = jj * np.ones(ma, dtype=np.int64)[:, None]
jj = jj.T.ravel() # column indices foreach sparse bd

bd = sparse.coo_matrix((A.T.ravel(), np.c_[ii, jj].T)).tocsc()
bd = sparse.coo_array((A.T.ravel(), np.c_[ii, jj].T)).tocsc()

return bd

Expand Down Expand Up @@ -797,11 +797,11 @@ def convert_forward_solution(
fix_rot = _block_diag(fwd["source_nn"].T, 1)
# newer versions of numpy require explicit casting here, so *= no
# longer works
fwd["sol"]["data"] = (fwd["_orig_sol"] * fix_rot).astype("float32")
fwd["sol"]["data"] = (fwd["_orig_sol"] @ fix_rot).astype("float32")
fwd["sol"]["ncol"] = fwd["nsource"]
if fwd["sol_grad"] is not None:
x = sparse.block_diag([fix_rot] * 3)
fwd["sol_grad"]["data"] = fwd["_orig_sol_grad"] * x # dot prod
fwd["sol_grad"]["data"] = fwd["_orig_sol_grad"] @ x
fwd["sol_grad"]["ncol"] = 3 * fwd["nsource"]
fwd["source_ori"] = FIFF.FIFFV_MNE_FIXED_ORI
fwd["surf_ori"] = True
Expand All @@ -828,7 +828,7 @@ def convert_forward_solution(
fix_rot = _block_diag(fwd["source_nn"].T, 1)
# newer versions of numpy require explicit casting here, so *= no
# longer works
fwd["sol"]["data"] = (fwd["_orig_sol"] * fix_rot).astype("float32")
fwd["sol"]["data"] = (fwd["_orig_sol"] @ fix_rot).astype("float32")
fwd["sol"]["ncol"] = fwd["nsource"]
if fwd["sol_grad"] is not None:
x = sparse.block_diag([fix_rot] * 3)
Expand All @@ -838,11 +838,11 @@ def convert_forward_solution(
fwd["surf_ori"] = True
else:
surf_rot = _block_diag(fwd["source_nn"].T, 3)
fwd["sol"]["data"] = fwd["_orig_sol"] * surf_rot
fwd["sol"]["data"] = fwd["_orig_sol"] @ surf_rot
fwd["sol"]["ncol"] = 3 * fwd["nsource"]
if fwd["sol_grad"] is not None:
x = sparse.block_diag([surf_rot] * 3)
fwd["sol_grad"]["data"] = fwd["_orig_sol_grad"] * x # dot prod
fwd["sol_grad"]["data"] = fwd["_orig_sol_grad"] @ x # dot prod
fwd["sol_grad"]["ncol"] = 9 * fwd["nsource"]
fwd["source_ori"] = FIFF.FIFFV_MNE_FREE_ORI
fwd["surf_ori"] = True
Expand Down
8 changes: 4 additions & 4 deletions mne/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,7 +1648,7 @@ def _verts_within_dist(graph, sources, max_dist):
Parameters
----------
graph : scipy.sparse.csr_matrix
graph : scipy.sparse.csr_array
Sparse matrix with distances between adjacent vertices.
sources : list of int
Source vertices.
Expand Down Expand Up @@ -2572,9 +2572,9 @@ def _labels_to_stc_surf(labels, values, tmin, tstep, subject):
data[hemi] = np.concatenate(data[hemi], axis=0).astype(float)
cols = np.arange(len(vertices[hemi]))
vertices[hemi], rows = np.unique(vertices[hemi], return_inverse=True)
mat = sparse.coo_matrix((np.ones(len(rows)), (rows, cols))).tocsr()
mat = mat * sparse.diags(1.0 / np.asarray(mat.sum(axis=-1))[:, 0])
data[hemi] = mat.dot(data[hemi])
mat = sparse.coo_array((np.ones(len(rows)), (rows, cols))).tocsr()
mat = mat @ sparse.diags_array(1.0 / np.asarray(mat.sum(axis=-1))[:, 0])
data[hemi] = mat @ data[hemi]
vertices = [vertices[hemi] for hemi in hemis]
data = np.concatenate([data[hemi] for hemi in hemis], axis=0)
return data, vertices, subject
Expand Down
2 changes: 1 addition & 1 deletion mne/minimum_norm/tests/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _compare(a, b):
assert len(a) == len(b)
for i, j in zip(a, b):
_compare(i, j)
elif isinstance(a, sparse.csr_matrix):
elif isinstance(a, sparse.csr_array):
assert_array_almost_equal(a.data, b.data)
assert_equal(a.indices, b.indices)
assert_equal(a.indptr, b.indptr)
Expand Down
30 changes: 15 additions & 15 deletions mne/morph.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def _compute_sparse_morph(vertices_from, subject_from, subject_to, subjects_dir=
cols = np.concatenate(cols)
rows = np.arange(len(cols))
data = np.ones(len(cols))
morph_mat = sparse.coo_matrix(
morph_mat = sparse.coo_array(
(data, (rows, cols)), shape=(len(cols), len(cols))
).tocsr()
return vertices, morph_mat
Expand Down Expand Up @@ -404,7 +404,7 @@ class SourceMorph:
See :func:`mne.compute_source_morph`.
xhemi : bool
Morph across hemisphere.
morph_mat : scipy.sparse.csr_matrix
morph_mat : scipy.sparse.csr_array
The sparse surface morphing matrix for spherical surface
based morphing :footcite:`GreveEtAl2013`.
vertices_to : list of ndarray
Expand All @@ -420,7 +420,7 @@ class SourceMorph:
(SDR) morph.
src_data : dict
Additional source data necessary to perform morphing.
vol_morph_mat : scipy.sparse.csr_matrix | None
vol_morph_mat : scipy.sparse.csr_array | None
The volumetric morph matrix, if :meth:`compute_vol_morph_mat`
was used.
%(verbose)s
Expand Down Expand Up @@ -587,7 +587,7 @@ def compute_vol_morph_mat(self, *, verbose=None):
-----
For a volumetric morph, this will compute the morph for an identity
source volume, i.e., with one source vertex active at a time, and store
the result as a :class:`sparse <scipy.sparse.csr_matrix>`
the result as a :class:`sparse <scipy.sparse.csr_array>`
morphing matrix. This takes a long time (minutes) to compute initially,
but drastically speeds up :meth:`apply` for STCs, so it can be
beneficial when many time points or many morphs (i.e., greater than
Expand Down Expand Up @@ -728,7 +728,7 @@ def _morph_vols(self, vols, mesg, subselect=True):
img_to[:, ii] = img_to[:, ii] + 1j * img_real

if vols is None:
img_to = sparse.csc_matrix(img_to, shape=(len(vol_verts), n_vols)).tocsr()
img_to = sparse.csc_array(img_to, shape=(len(vol_verts), n_vols)).tocsr()

return img_to

Expand Down Expand Up @@ -1216,24 +1216,24 @@ def _compute_morph_matrix(
data = np.concatenate(data)
# this is equivalent to morpher = sparse_block_diag(morpher).tocsr(),
# but works for xhemi mode
morpher = sparse.csr_matrix((data, indices, indptr), shape=shape)
morpher = sparse.csr_array((data, indices, indptr), shape=shape)
logger.info("[done]")
return morpher


def _hemi_morph(tris, vertices_to, vertices_from, smooth, maps, warn):
_validate_type(smooth, (str, None, "int-like"), "smoothing steps")
if len(vertices_from) == 0:
return sparse.csr_matrix((len(vertices_to), 0))
return sparse.csr_array((len(vertices_to), 0))
e = mesh_edges(tris)
e.data[e.data == 2] = 1
n_vertices = e.shape[0]
e += sparse.eye(n_vertices, format="csr")
e += sparse.eye_array(n_vertices, format="csr")
if isinstance(smooth, str):
_check_option("smooth", smooth, ("nearest",), extra=" when used as a string.")
mm = _surf_nearest(vertices_from, e).tocsr()
elif smooth == 0:
mm = sparse.csc_matrix(
mm = sparse.csc_array(
(
np.ones(len(vertices_from)), # data, indices, indptr
vertices_from,
Expand All @@ -1251,7 +1251,7 @@ def _hemi_morph(tris, vertices_to, vertices_from, smooth, maps, warn):
logger.info(f" {n_iter} smooth iterations done.")
assert mm.shape == (n_vertices, len(vertices_from))
if maps is not None:
mm = maps[vertices_to] * mm
mm = maps[vertices_to] @ mm
else: # to == from
mm = mm[vertices_to]
assert mm.shape == (len(vertices_to), len(vertices_from))
Expand Down Expand Up @@ -1346,7 +1346,7 @@ def _surf_nearest(vertices, adj_mat):
col = order[col]
row = np.arange(len(col))
data = np.ones(len(col))
mat = sparse.coo_matrix((data, (row, col)))
mat = sparse.coo_array((data, (row, col)))
assert mat.shape == (adj_mat.shape[0], len(vertices)), mat.shape
return mat

Expand All @@ -1364,12 +1364,12 @@ def _csr_row_norm(data, row_norm):
def _surf_upsampling_mat(idx_from, e, smooth):
"""Upsample data on a subject's surface given mesh edges."""
# we're in CSR format and it's to==from
assert isinstance(e, sparse.csr_matrix)
assert isinstance(e, sparse.csr_array)
n_tot = e.shape[0]
assert e.shape == (n_tot, n_tot)
# our output matrix starts out as a smaller matrix, and will gradually
# increase in size
data = sparse.eye(len(idx_from), format="csr")
data = sparse.eye_array(len(idx_from), format="csr")
_validate_type(smooth, ("int-like", str, None), "smoothing steps")
if smooth is not None: # number of steps
smooth = _ensure_int(smooth, "smoothing steps")
Expand All @@ -1389,7 +1389,7 @@ def _surf_upsampling_mat(idx_from, e, smooth):
data = data[idx]
# smoothing multiplication
use_e = e[:, idx] if len(idx) < n_tot else e
data = use_e * data
data = use_e @ data
del use_e
# compute row sums + output indices
if recompute_idx_sum:
Expand All @@ -1399,7 +1399,7 @@ def _surf_upsampling_mat(idx_from, e, smooth):
recompute_idx_sum = False
else:
mult[idx] = 1
row_sum = e * mult
row_sum = e @ mult
idx = np.where(row_sum)[0]
# do row normalization
_csr_row_norm(data, row_sum)
Expand Down
Loading

0 comments on commit 075d8d1

Please sign in to comment.