diff --git a/mne/_fiff/tests/test_reference.py b/mne/_fiff/tests/test_reference.py index 73ec780aac9..1297377c8b3 100644 --- a/mne/_fiff/tests/test_reference.py +++ b/mne/_fiff/tests/test_reference.py @@ -437,72 +437,72 @@ def test_set_eeg_reference_rest(): nullcontext(), ), ( - {2: "EEG 001"}, - "epochs", - pytest.raises( - AssertionError, match=f"Keys in dict-type.*You provided {int}" - ), + {2: "EEG 001"}, + "epochs", + pytest.raises( + AssertionError, match=f"Keys in dict-type.*You provided {int}" + ), ), ( - {"EEG 001": (1, 2)}, - "epochs", - pytest.raises( - ValueError, match=f"Values in dict-type.*You provided {type((1, 2))}" - ), + {"EEG 001": (1, 2)}, + "epochs", + pytest.raises( + ValueError, match=f"Values in dict-type.*You provided {type((1, 2))}" + ), ), ( - {"EEG 001": [1, 2]}, - "epochs", - pytest.raises( - AssertionError, - match="Values in dict-type.*You provided a list of ", - ), + {"EEG 001": [1, 2]}, + "epochs", + pytest.raises( + AssertionError, + match="Values in dict-type.*You provided a list of ", + ), ), ( - {"EEG 999": "EEG 001"}, - "epochs", - pytest.raises( - AssertionError, - match="Channel EEG 999 in ref_channels is not in the instance", - ), + {"EEG 999": "EEG 001"}, + "epochs", + pytest.raises( + AssertionError, + match="Channel EEG 999 in ref_channels is not in the instance", + ), ), ( - {"EEG 001": "EEG 999"}, - "epochs", - pytest.raises( - AssertionError, - match="Channel EEG 999 in ref_channels is not in the instance", - ), + {"EEG 001": "EEG 999"}, + "epochs", + pytest.raises( + AssertionError, + match="Channel EEG 999 in ref_channels is not in the instance", + ), ), ( - {"EEG 001": "EEG 057"}, - "epochs", - pytest.warns( - RuntimeWarning, - match="Channel EEG 057 in ref_channels is marked as bad!", - ), + {"EEG 001": "EEG 057"}, + "epochs", + pytest.warns( + RuntimeWarning, + match="Channel EEG 057 in ref_channels is marked as bad!", + ), ), ( - {"EEG 001": "STI 001"}, - "epochs", - pytest.warns( - RuntimeWarning, - match="Channel EEG 001 is of type EEG, but reference channel STI 001 is of type Stimulus.", - ), + {"EEG 001": "STI 001"}, + "epochs", + pytest.warns( + RuntimeWarning, + match="Channel EEG 001 is of type EEG, but reference channel STI 001 is of type Stimulus.", + ), ), ( - {"EEG 001": "EEG 002", "EEG 002": "EEG 002", "EEG 003": "EEG 005"}, - "epochs", - nullcontext(), + {"EEG 001": "EEG 002", "EEG 002": "EEG 002", "EEG 003": "EEG 005"}, + "epochs", + nullcontext(), ), ( - { - "EEG 001": ["EEG 002", "EEG 003"], - "EEG 002": "EEG 002", - "EEG 003": "EEG 005", - }, - "epochs", - nullcontext(), + { + "EEG 001": ["EEG 002", "EEG 003"], + "EEG 002": "EEG 002", + "EEG 003": "EEG 005", + }, + "epochs", + nullcontext(), ), ], ) @@ -554,15 +554,20 @@ def test_set_eeg_reference_dict(ref_channels, inst_type, expectation): for val in ref_channels.values(): if isinstance(val, str): val = [val] # pick_channels expects a list - ref_data.append(_data[..., pick_channels(inst.ch_names, val, ordered=True), :].mean( + ref_data.append( + _data[..., pick_channels(inst.ch_names, val, ordered=True), :].mean( -2, keepdims=True - )) + ) + ) if inst_type == "epochs": ref_data = np.concatenate(ref_data, axis=1) else: ref_data = np.squeeze(np.array(ref_data)) assert_allclose( - _data[..., ch_reref, :], _reref[..., ch_reref, :] + ref_data, 1e-6, atol=1e-15 + _data[..., ch_reref, :], + _reref[..., ch_reref, :] + ref_data, + 1e-6, + atol=1e-15, )