diff --git a/mne/tests/test_filter.py b/mne/tests/test_filter.py index e259ececbce..52e555a775e 100644 --- a/mne/tests/test_filter.py +++ b/mne/tests/test_filter.py @@ -16,6 +16,7 @@ from scipy.signal import resample as sp_resample from mne import Epochs, create_info +from mne.cuda import get_shared_mem from mne._fiff.pick import _DATA_CH_TYPES_SPLIT from mne.filter import ( _length_factors, @@ -408,6 +409,10 @@ def test_resample_scipy(): err_msg = f"{N}: {window}" x_2_sp = sp_resample(x, 2 * N, window=window) for n_jobs in n_jobs_test: + if n_jobs == "cuda": + tmp = x + x = get_shared_mem(x.shape) + x[:] = tmp x_2 = resample(x, 2, 1, npad=0, window=window, n_jobs=n_jobs) assert_allclose(x_2, x_2_sp, atol=1e-12, err_msg=err_msg) new_len = int(round(len(x) * (1.0 / 2.0))) @@ -421,6 +426,12 @@ def test_resample_scipy(): def test_n_jobs(n_jobs): """Test resampling against SciPy.""" x = np.random.RandomState(0).randn(4, 100) + + if n_jobs == "cuda": + tmp = x + x = get_shared_mem(x.shape) + x[:] = tmp + y1 = resample(x, 2, 1, n_jobs=None) y2 = resample(x, 2, 1, n_jobs=n_jobs) assert_allclose(y1, y2) @@ -846,6 +857,8 @@ def test_cuda_resampling(): a = rng.randn(2, N) for fro, to in ((1, 2), (2, 1), (1, 3), (3, 1)): a1 = resample(a, fro, to, n_jobs=None, npad="auto", window=window) + x = get_shared_mem(a.shape) + x[:] = a a2 = resample(a, fro, to, n_jobs="cuda", npad="auto", window=window) assert_allclose(a1, a2, rtol=1e-7, atol=1e-14) assert_array_almost_equal(a1, a2, 14)