Skip to content

Commit

Permalink
ENH: Add ability to reject epochs using callables (mne-tools#12195)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
3 people authored and snwnde committed Mar 20, 2024
1 parent 7eb9e6f commit d64694e
Show file tree
Hide file tree
Showing 6 changed files with 399 additions and 38 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12195.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ability reject :class:`mne.Epochs` using callables, by `Jacob Woessner`_.
102 changes: 78 additions & 24 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ def apply_baseline(self, baseline=(None, 0), *, verbose=None):
self.baseline = baseline
return self

def _reject_setup(self, reject, flat):
def _reject_setup(self, reject, flat, *, allow_callable=False):
"""Set self._reject_time and self._channel_type_idx."""
idx = channel_indices_by_type(self.info)
reject = deepcopy(reject) if reject is not None else dict()
Expand All @@ -814,11 +814,21 @@ def _reject_setup(self, reject, flat):
f"{key.upper()}."
)

# check for invalid values
for rej, kind in zip((reject, flat), ("Rejection", "Flat")):
for key, val in rej.items():
if val is None or val < 0:
raise ValueError(f'{kind} value must be a number >= 0, not "{val}"')
# check for invalid values
for rej, kind in zip((reject, flat), ("Rejection", "Flat")):
for key, val in rej.items():
name = f"{kind} dict value for {key}"
if callable(val) and allow_callable:
continue
extra_str = ""
if allow_callable:
extra_str = "or callable"
_validate_type(val, "numeric", name, extra=extra_str)
if val is None or val < 0:
raise ValueError(
f"If using numerical {name} criteria, the value "
f"must be >= 0, not {repr(val)}"
)

# now check to see if our rejection and flat are getting more
# restrictive
Expand All @@ -836,6 +846,9 @@ def _reject_setup(self, reject, flat):
reject[key] = old_reject[key]
# make sure new thresholds are at least as stringent as the old ones
for key in reject:
# Skip this check if old_reject and reject are callables
if callable(reject[key]) and allow_callable:
continue
if key in old_reject and reject[key] > old_reject[key]:
raise ValueError(
bad_msg.format(
Expand All @@ -851,6 +864,8 @@ def _reject_setup(self, reject, flat):
for key in set(old_flat) - set(flat):
flat[key] = old_flat[key]
for key in flat:
if callable(flat[key]) and allow_callable:
continue
if key in old_flat and flat[key] < old_flat[key]:
raise ValueError(
bad_msg.format(
Expand Down Expand Up @@ -1404,7 +1419,7 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None):
flat = self.flat
if any(isinstance(rej, str) and rej != "existing" for rej in (reject, flat)):
raise ValueError('reject and flat, if strings, must be "existing"')
self._reject_setup(reject, flat)
self._reject_setup(reject, flat, allow_callable=True)
self._get_data(out=False, verbose=verbose)
return self

Expand Down Expand Up @@ -1520,8 +1535,9 @@ def drop(self, indices, reason="USER", verbose=None):
Set epochs to remove by specifying indices to remove or a boolean
mask to apply (where True values get removed). Events are
correspondingly modified.
reason : str
Reason for dropping the epochs ('ECG', 'timeout', 'blink' etc).
reason : list | tuple | str
Reason(s) for dropping the epochs ('ECG', 'timeout', 'blink' etc).
Reason(s) are applied to all indices specified.
Default: 'USER'.
%(verbose)s
Expand All @@ -1533,7 +1549,9 @@ def drop(self, indices, reason="USER", verbose=None):
indices = np.atleast_1d(indices)

if indices.ndim > 1:
raise ValueError("indices must be a scalar or a 1-d array")
raise TypeError("indices must be a scalar or a 1-d array")
# Check if indices and reasons are of the same length
# if using collection to drop epochs

if indices.dtype == np.dtype(bool):
indices = np.where(indices)[0]
Expand Down Expand Up @@ -3199,6 +3217,10 @@ class Epochs(BaseEpochs):
See :meth:`~mne.Epochs.equalize_event_counts`
- 'USER'
For user-defined reasons (see :meth:`~mne.Epochs.drop`).
When dropping based on flat or reject parameters the tuple of
reasons contains a tuple of channels that satisfied the rejection
criteria.
filename : str
The filename of the object.
times : ndarray
Expand Down Expand Up @@ -3667,37 +3689,69 @@ def _is_good(
):
"""Test if data segment e is good according to reject and flat.
The reject and flat parameters can accept functions as values.
If full_report=True, it will give True/False as well as a list of all
offending channels.
"""
bad_tuple = tuple()
has_printed = False
checkable = np.ones(len(ch_names), dtype=bool)
checkable[np.array([c in ignore_chs for c in ch_names], dtype=bool)] = False

for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]):
if refl is not None:
for key, thresh in refl.items():
for key, refl in refl.items():
criterion = refl
idx = channel_type_idx[key]
name = key.upper()
if len(idx) > 0:
e_idx = e[idx]
deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1)
checkable_idx = checkable[idx]
idx_deltas = np.where(
np.logical_and(f(deltas, thresh), checkable_idx)
)[0]
# Check if criterion is a function and apply it
if callable(criterion):
result = criterion(e_idx)
_validate_type(result, tuple, "reject/flat output")
if len(result) != 2:
raise TypeError(
"Function criterion must return a tuple of length 2"
)
cri_truth, reasons = result
_validate_type(cri_truth, (bool, np.bool_), cri_truth, "bool")
_validate_type(
reasons, (str, list, tuple), reasons, "str, list, or tuple"
)
idx_deltas = np.where(np.logical_and(cri_truth, checkable_idx))[
0
]
else:
deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1)
idx_deltas = np.where(
np.logical_and(f(deltas, criterion), checkable_idx)
)[0]

if len(idx_deltas) > 0:
bad_names = [ch_names[idx[i]] for i in idx_deltas]
if not has_printed:
logger.info(
f" Rejecting {t} epoch based on {name} : {bad_names}"
)
has_printed = True
if not full_report:
return False
# Check to verify that refl is a callable that returns
# (bool, reason). Reason must be a str/list/tuple.
# If using tuple
if callable(refl):
if isinstance(reasons, str):
reasons = (reasons,)
for idx, reason in enumerate(reasons):
_validate_type(reason, str, reason)
bad_tuple += tuple(reasons)
else:
bad_tuple += tuple(bad_names)
bad_names = [ch_names[idx[i]] for i in idx_deltas]
if not has_printed:
logger.info(
" Rejecting %s epoch based on %s : "
"%s" % (t, name, bad_names)
)
has_printed = True
if not full_report:
return False
else:
bad_tuple += tuple(bad_names)

if not full_report:
return True
Expand Down
165 changes: 161 additions & 4 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,20 @@ def test_reject():
preload=False,
reject=dict(eeg=np.inf),
)
for val in (None, -1): # protect against older MNE-C types

# Good function
def my_reject_1(epoch_data):
bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35)
reasons = "a" * len(bad_idxs[0])
return len(bad_idxs) > 0, reasons

# Bad function
def my_reject_2(epoch_data):
bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35)
reasons = "a" * len(bad_idxs[0])
return len(bad_idxs), reasons

for val in (-1, -2): # protect against older MNE-C types
for kwarg in ("reject", "flat"):
pytest.raises(
ValueError,
Expand All @@ -564,6 +577,44 @@ def test_reject():
preload=False,
**{kwarg: dict(grad=val)},
)

# Check that reject and flat in constructor are not callables
val = my_reject_1
for kwarg in ("reject", "flat"):
with pytest.raises(
TypeError,
match=r".* must be an instance of numeric, got <class 'function'> instead.",
):
Epochs(
raw,
events,
event_id,
tmin,
tmax,
picks=picks_meg,
preload=False,
**{kwarg: dict(grad=val)},
)

# Check if callable returns a tuple with reasons
bad_types = [my_reject_2, ("Hi" "Hi"), (1, 1), None]
for val in bad_types: # protect against bad types
for kwarg in ("reject", "flat"):
with pytest.raises(
TypeError,
match=r".* must be an instance of .* got <class '.*'> instead.",
):
epochs = Epochs(
raw,
events,
event_id,
tmin,
tmax,
picks=picks_meg,
preload=True,
)
epochs.drop_bad(**{kwarg: dict(grad=val)})

pytest.raises(
KeyError,
Epochs,
Expand Down Expand Up @@ -2149,6 +2200,93 @@ def test_reject_epochs(tmp_path):
assert epochs_cleaned.flat == dict(grad=new_flat["grad"], mag=flat["mag"])


@testing.requires_testing_data
def test_callable_reject():
"""Test using a callable for rejection."""
raw = read_raw_fif(fname_raw_testing, preload=True)
raw.crop(0, 5)
raw.del_proj()
chans = raw.info["ch_names"][-6:-1]
raw.pick(chans)
data = raw.get_data()

# Add some artifacts
new_data = data
new_data[0, 180:200] *= 1e7
new_data[0, 610:880] += 1e-3
edit_raw = mne.io.RawArray(new_data, raw.info)

events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0)
epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None, preload=True)
assert len(epochs) == 5

epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
preload=True,
)
epochs.drop_bad(
reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median"))
)

assert epochs.drop_log[2] == ("eeg median",)

epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
preload=True,
)
epochs.drop_bad(
reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",)))
)

assert epochs.drop_log[0] == ("eeg max",)

def reject_criteria(x):
max_condition = np.max(x, axis=1) > 1e-2
median_condition = np.median(x, axis=1) > 1e-4
return (max_condition.any() or median_condition.any()), "eeg max or median"

epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
preload=True,
)
epochs.drop_bad(reject=dict(eeg=reject_criteria))

assert epochs.drop_log[0] == ("eeg max or median",) and epochs.drop_log[2] == (
"eeg max or median",
)

# Test reasons must be str or tuple of str
with pytest.raises(
TypeError,
match=r".* must be an instance of str, got <class 'int'> instead.",
):
epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
preload=True,
)
epochs.drop_bad(
reject=dict(
eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), ("eeg median", 2))
)
)


def test_preload_epochs():
"""Test preload of epochs."""
raw, events, picks = _get_data()
Expand Down Expand Up @@ -3180,9 +3318,16 @@ def test_drop_epochs():
events1 = events[events[:, 2] == event_id]

# Bound checks
pytest.raises(IndexError, epochs.drop, [len(epochs.events)])
pytest.raises(IndexError, epochs.drop, [-len(epochs.events) - 1])
pytest.raises(ValueError, epochs.drop, [[1, 2], [3, 4]])
with pytest.raises(IndexError, match=r"Epoch index .* is out of bounds"):
epochs.drop([len(epochs.events)])
with pytest.raises(IndexError, match=r"Epoch index .* is out of bounds"):
epochs.drop([-len(epochs.events) - 1])
with pytest.raises(TypeError, match="indices must be a scalar or a 1-d array"):
epochs.drop([[1, 2], [3, 4]])
with pytest.raises(
TypeError, match=r".* must be an instance of .* got <class '.*'> instead."
):
epochs.drop([1], reason=("a", "b", 2))

# Test selection attribute
assert_array_equal(epochs.selection, np.where(events[:, 2] == event_id)[0])
Expand All @@ -3202,6 +3347,18 @@ def test_drop_epochs():
assert_array_equal(events[epochs[3:].selection], events1[[5, 6]])
assert_array_equal(events[epochs["1"].selection], events1[[0, 1, 3, 5, 6]])

# Test using tuple to drop epochs
raw, events, picks = _get_data()
epochs_tuple = Epochs(raw, events, event_id, tmin, tmax, picks=picks, preload=True)
selection_tuple = epochs_tuple.selection.copy()
epochs_tuple.drop((2, 3, 4), reason=("a", "b"))
n_events = len(epochs.events)
assert [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]] == [
("a", "b"),
("a", "b"),
("a", "b"),
]


@pytest.mark.parametrize("preload", (True, False))
def test_drop_epochs_mult(preload):
Expand Down
Loading

0 comments on commit d64694e

Please sign in to comment.