Skip to content

Commit

Permalink
update mne/cuda to call cupy.asarray when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
scottrbrtsn committed Dec 3, 2024
1 parent aa9dcb6 commit 9dfc89f
Showing 1 changed file with 43 additions and 3 deletions.
46 changes: 43 additions & 3 deletions mne/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
from scipy.fft import irfft, rfft


from .utils import (
_check_option,
_explain_exception,
Expand All @@ -18,6 +19,45 @@

# Module-level CUDA availability flag; it starts out as False.
# NOTE(review): presumably switched to True elsewhere in this module once
# CUDA initialization succeeds -- confirm against the code that sets it.
_cuda_capable = False

def get_shared_mem(
    shape,
    dtype=np.float64,
    strides=None,
    order="C",
    stream=0,
    portable=False,
    wc=True,
):
    """Get shared memory space to avoid copying from cpu to gpu when possible.

    Allocate a mapped ndarray with a buffer that is pinned and mapped on
    to the device. Similar to :func:`numpy.empty`. This is a thin wrapper
    around ``numba.cuda.mapped_array``: every argument is forwarded to it
    unchanged.

    Parameters
    ----------
    shape : int | tuple of int
        Shape of the array to allocate, as for :func:`numpy.empty`.
    dtype : data-type
        Element type of the array. Defaults to ``np.float64``.
    strides : tuple of int | None
        Strides of the array, forwarded to ``numba.cuda.mapped_array``.
        Defaults to ``None``.
    order : str
        Memory layout of the array, forwarded to
        ``numba.cuda.mapped_array``. Defaults to ``"C"``.
    stream : int
        CUDA stream argument, forwarded to ``numba.cuda.mapped_array``.
        Defaults to ``0``.
    portable : bool
        A boolean flag to allow the allocated device memory to be
        usable in multiple devices.
    wc : bool
        A boolean flag to enable write-combined allocation, which is faster
        to write by the host and to read by the device, but slower to
        read by the host and slower to write by the device. Note that the
        default here is ``True``.

    Returns
    -------
    arr : np.ndarray
        A mapped array intended to be passed to ``cupy.asarray``, which is
        expected not to copy when the memory is already mapped on the device.

    Notes
    -----
    ``numba`` is imported lazily inside the function, so it is only required
    when this function is actually called.
    """
    # Local import: numba is only needed by callers of this function.
    from numba import cuda

    return cuda.mapped_array(
        shape,
        dtype=dtype,
        strides=strides,
        order=order,
        stream=stream,
        portable=portable,
        wc=wc,
    )

def get_cuda_memory(kind="available"):
"""Get the amount of free memory for CUDA operations.
Expand Down Expand Up @@ -176,7 +216,7 @@ def _setup_cuda_fft_multiply_repeated(n_jobs, h, n_fft, kind="FFT FIR filtering"

try:
# do the IFFT normalization now so we don't have to later
h_fft = cupy.array(cuda_dict["h_fft"])
h_fft = cupy.asarray(cuda_dict["h_fft"])
logger.info(f"Using CUDA for {kind}")
except Exception as exp:
logger.info(
Expand Down Expand Up @@ -276,7 +316,7 @@ def _setup_cuda_fft_resample(n_jobs, W, new_len):
import cupy

# do the IFFT normalization now so we don't have to later
W = cupy.array(W)
W = cupy.asarray(W)
logger.info("Using CUDA for FFT resampling")
except Exception:
logger.info(
Expand All @@ -301,7 +341,7 @@ def _cuda_upload_rfft(x, n, axis=-1):
"""Upload and compute rfft."""
import cupy

return cupy.fft.rfft(cupy.array(x), n=n, axis=axis)
return cupy.fft.rfft(cupy.asarray(x), n=n, axis=axis)


def _cuda_irfft_get(x, n, axis=-1):
Expand Down

0 comments on commit 9dfc89f

Please sign in to comment.