Skip to content

Commit

Permalink
update tests as poc, they pass
Browse files Browse the repository at this point in the history
  • Loading branch information
scottrbrtsn committed Dec 3, 2024
1 parent 9dfc89f commit 4ad40d4
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions mne/tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from scipy.signal import resample as sp_resample

from mne import Epochs, create_info
from mne.cuda import get_shared_mem
from mne._fiff.pick import _DATA_CH_TYPES_SPLIT
from mne.filter import (
_length_factors,
Expand Down Expand Up @@ -408,6 +409,10 @@ def test_resample_scipy():
err_msg = f"{N}: {window}"
x_2_sp = sp_resample(x, 2 * N, window=window)
for n_jobs in n_jobs_test:
if n_jobs == "cuda":
tmp = x
x = get_shared_mem(x.shape)
x[:] = tmp
x_2 = resample(x, 2, 1, npad=0, window=window, n_jobs=n_jobs)
assert_allclose(x_2, x_2_sp, atol=1e-12, err_msg=err_msg)
new_len = int(round(len(x) * (1.0 / 2.0)))
Expand All @@ -421,6 +426,12 @@ def test_resample_scipy():
def test_n_jobs(n_jobs):
"""Test resampling against SciPy."""
x = np.random.RandomState(0).randn(4, 100)

if n_jobs == "cuda":
tmp = x
x = get_shared_mem(x.shape)
x[:] = tmp

y1 = resample(x, 2, 1, n_jobs=None)
y2 = resample(x, 2, 1, n_jobs=n_jobs)
assert_allclose(y1, y2)
Expand Down Expand Up @@ -846,6 +857,8 @@ def test_cuda_resampling():
a = rng.randn(2, N)
for fro, to in ((1, 2), (2, 1), (1, 3), (3, 1)):
a1 = resample(a, fro, to, n_jobs=None, npad="auto", window=window)
x = get_shared_mem(a.shape)
x[:] = a
a2 = resample(a, fro, to, n_jobs="cuda", npad="auto", window=window)
assert_allclose(a1, a2, rtol=1e-7, atol=1e-14)
assert_array_almost_equal(a1, a2, 14)
Expand Down

0 comments on commit 4ad40d4

Please sign in to comment.