Skip to content

Commit

Permalink
Add tests for epochs object
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexLepauvre committed Jun 7, 2024
1 parent 6dc0633 commit 32fa6ac
Showing 1 changed file with 108 additions and 15 deletions.
123 changes: 108 additions & 15 deletions mne/_fiff/tests/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,57 +366,65 @@ def test_set_eeg_reference_rest():

@testing.requires_testing_data
@pytest.mark.parametrize(
"ref_channels, expectation",
"ref_channels, inst_type, expectation",
[
(
{2: "EEG 001"},
"raw",
pytest.raises(
AssertionError, match=f"Keys in dict-type.*You provided {int}"
),
),
(
{"EEG 001": (1, 2)},
"raw",
pytest.raises(
ValueError, match=f"Values in dict-type.*You provided {type((1,2))}"
),
),
(
{"EEG 001": [1, 2]},
"raw",
pytest.raises(
AssertionError,
match="Values in dict-type.*You provided a list of <class 'int'>",
),
),
(
{"EEG 999": "EEG 001"},
"raw",
pytest.raises(
AssertionError,
match="Channel EEG 999 in ref_channels is not in the instance",
),
),
(
{"EEG 001": "EEG 999"},
"raw",
pytest.raises(
AssertionError,
match="Channel EEG 999 in ref_channels is not in the instance",
),
),
(
{"EEG 001": "EEG 057"},
"raw",
pytest.warns(
RuntimeWarning,
match="Channel EEG 057 in ref_channels is marked as bad!",
),
),
(
{"EEG 001": "STI 001"},
"raw",
pytest.warns(
RuntimeWarning,
match="Channel EEG 001 is of type EEG, but reference channel STI 001 is of type Stimulus.",
),
),
(
{"EEG 001": "EEG 002", "EEG 002": "EEG 002", "EEG 003": "EEG 005"},
"raw",
nullcontext(),
),
(
Expand All @@ -425,51 +433,136 @@ def test_set_eeg_reference_rest():
"EEG 002": "EEG 002",
"EEG 003": "EEG 005",
},
"raw",
nullcontext(),
),
(
{2: "EEG 001"},
"epochs",
pytest.raises(
AssertionError, match=f"Keys in dict-type.*You provided {int}"
),
),
(
{"EEG 001": (1, 2)},
"epochs",
pytest.raises(
ValueError, match=f"Values in dict-type.*You provided {type((1, 2))}"
),
),
(
{"EEG 001": [1, 2]},
"epochs",
pytest.raises(
AssertionError,
match="Values in dict-type.*You provided a list of <class 'int'>",
),
),
(
{"EEG 999": "EEG 001"},
"epochs",
pytest.raises(
AssertionError,
match="Channel EEG 999 in ref_channels is not in the instance",
),
),
(
{"EEG 001": "EEG 999"},
"epochs",
pytest.raises(
AssertionError,
match="Channel EEG 999 in ref_channels is not in the instance",
),
),
(
{"EEG 001": "EEG 057"},
"epochs",
pytest.warns(
RuntimeWarning,
match="Channel EEG 057 in ref_channels is marked as bad!",
),
),
(
{"EEG 001": "STI 001"},
"epochs",
pytest.warns(
RuntimeWarning,
match="Channel EEG 001 is of type EEG, but reference channel STI 001 is of type Stimulus.",
),
),
(
{"EEG 001": "EEG 002", "EEG 002": "EEG 002", "EEG 003": "EEG 005"},
"epochs",
nullcontext(),
),
(
{
"EEG 001": ["EEG 002", "EEG 003"],
"EEG 002": "EEG 002",
"EEG 003": "EEG 005",
},
"epochs",
nullcontext(),
),
],
)
def test_set_eeg_reference_dict(ref_channels, expectation):
def test_set_eeg_reference_dict(ref_channels, inst_type, expectation):
"""Test setting dict-based reference."""
raw = read_raw_fif(fif_fname).crop(0, 1).pick(picks=["eeg", "stim"])
if inst_type == "raw":
inst = read_raw_fif(fif_fname).crop(0, 1).pick(picks=["eeg", "stim"])
# Test re-referencing Epochs object
elif inst_type == "epochs":
raw = read_raw_fif(fif_fname, preload=False)
events = read_events(eve_fname)
inst = Epochs(
raw,
events=events,
event_id=1,
tmin=-0.2,
tmax=0.5,
preload=False,
)
with pytest.raises(
RuntimeError,
match="By default, MNE does not load data.*Applying a reference requires.*",
):
raw.set_eeg_reference(ref_channels=ref_channels)
raw.load_data()
raw.info["bads"] = ["EEG 057"]
inst.set_eeg_reference(ref_channels=ref_channels)
inst.load_data()
inst.info["bads"] = ["EEG 057"]
with expectation:
reref, _reref = set_eeg_reference(raw.copy(), ref_channels, copy=False)
reref, _reref = set_eeg_reference(inst.copy(), ref_channels, copy=False)

if isinstance(expectation, nullcontext):
# Check that the custom_ref_applied is set correctly:
assert reref.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_ON

# Get raw data
_data = raw._data
_data = inst._data

# Get that channels that were and weren't re-referenced:
ch_raw = pick_channels(
raw.ch_names,
[ch for ch in raw.ch_names if ch not in list(ref_channels.keys())],
inst.ch_names,
[ch for ch in inst.ch_names if ch not in list(ref_channels.keys())],
)
ch_reref = pick_channels(raw.ch_names, list(ref_channels.keys()), ordered=True)
ch_reref = pick_channels(inst.ch_names, list(ref_channels.keys()), ordered=True)

# Check that the non re-reference channels are untouched:
assert_allclose(_data[ch_raw, :], _reref[ch_raw, :], 1e-6, atol=1e-15)
assert_allclose(_data[..., ch_raw, :], _reref[..., ch_raw, :], 1e-6, atol=1e-15)

# Compute the reference data:
ref_data = []
for val in ref_channels.values():
if isinstance(val, str):
val = [val] # pick_channels expects a list
ref_data.append(_data[..., pick_channels(raw.ch_names, val, ordered=True), :].mean(
ref_data.append(_data[..., pick_channels(inst.ch_names, val, ordered=True), :].mean(
-2, keepdims=True
))
ref_data = np.squeeze(np.array(ref_data))
if inst_type == "epochs":
ref_data = np.concatenate(ref_data, axis=1)
else:
ref_data = np.squeeze(np.array(ref_data))
assert_allclose(
_data[ch_reref, :], _reref[ch_reref, :] + ref_data, 1e-6, atol=1e-15
_data[..., ch_reref, :], _reref[..., ch_reref, :] + ref_data, 1e-6, atol=1e-15
)


Expand Down

0 comments on commit 32fa6ac

Please sign in to comment.