Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add ability to reject epochs using callables #12195

Merged
merged 54 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
81ef83e
Add ability to reject epochs using functions
withmywoessner Nov 11, 2023
d8dda07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 11, 2023
e9294de
Merge branch 'main' into epoch_reject
withmywoessner Nov 13, 2023
06e6770
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Nov 16, 2023
867496d
Update docs
withmywoessner Nov 19, 2023
1b4f5b3
Add ability to reject based on callables
withmywoessner Nov 19, 2023
2a66049
Add tutorial
withmywoessner Nov 19, 2023
e16f4a1
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Nov 19, 2023
e708465
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2023
fbdec77
Make flake8 compliant
withmywoessner Nov 20, 2023
6e23ecc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2023
cdd2843
Add docstrings and make flake8 compliant
withmywoessner Nov 20, 2023
cd9f7b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2023
3f5fd84
Update mne/epochs.py
withmywoessner Dec 1, 2023
a74ccf6
Update tutorials/preprocessing/20_rejecting_bad_data.py
withmywoessner Dec 1, 2023
f0cb1b8
Update tutorials/preprocessing/20_rejecting_bad_data.py
withmywoessner Dec 1, 2023
7d02fca
Update mne/utils/docs.py
withmywoessner Dec 1, 2023
1b836ad
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Dec 4, 2023
f3e8841
Make callable check more fine, doc, add noqa
withmywoessner Dec 6, 2023
984c604
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Dec 6, 2023
fbe4cd2
Update epochs so that adding refl tuple doesnt cause error
withmywoessner Jan 5, 2024
ee599a7
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Jan 5, 2024
24e669c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2024
bce6486
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Jan 9, 2024
8401c92
return callable/reasons
withmywoessner Jan 9, 2024
cf1facf
allow callables
withmywoessner Jan 9, 2024
e98bee2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2024
a579491
Delete mne/_version.py
withmywoessner Jan 9, 2024
c9da4db
Add None Check
withmywoessner Jan 9, 2024
5a7a618
Update mne/tests/test_epochs.py
withmywoessner Jan 10, 2024
98b92c4
Update mne/utils/mixin.py
withmywoessner Jan 10, 2024
44fdf8a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
3ff8a9e
Update mne/epochs.py
withmywoessner Jan 10, 2024
e729e81
Update mne/tests/test_epochs.py
withmywoessner Jan 10, 2024
fd4c75f
Update mne/epochs.py
withmywoessner Jan 10, 2024
a685b89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
65778d1
Update mne/epochs.py
withmywoessner Jan 10, 2024
3ece314
Apply suggestions from code review
withmywoessner Jan 10, 2024
db0cf11
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Jan 10, 2024
b2686c0
Apply reason to all dropped epochs
withmywoessner Jan 16, 2024
3ae37a4
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Jan 16, 2024
45b30f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
e77ae63
Add check
withmywoessner Jan 16, 2024
f560ccd
Merge branch 'main' into epoch_reject
withmywoessner Jan 16, 2024
235f6d3
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Jan 23, 2024
4e4b369
Apply suggestions from code review
withmywoessner Jan 23, 2024
b7c6a36
Add suggestions
withmywoessner Jan 23, 2024
ba45b5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 23, 2024
b5829b0
devel
withmywoessner Jan 23, 2024
a11a410
Remove support for callabes in constructor
withmywoessner Jan 24, 2024
b918ea8
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Jan 24, 2024
ab009bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2024
8cf2e49
Merge branch 'main' into epoch_reject
withmywoessner Feb 1, 2024
fb6a5b1
Apply suggestions from code review
larsoner Feb 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 57 additions & 27 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,9 +819,14 @@ def _reject_setup(self, reject, flat):
# check for invalid values
for rej, kind in zip((reject, flat), ("Rejection", "Flat")):
for key, val in rej.items():
if val is None or val < 0:
if callable(val):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would simplify to:

if callable(val):
    continue
else:
    name = f"{kind} dict value for {key}"
    _validate_type(val, "numeric", name, extra="or callable")
    if val < 0:
         raise ValueError("{kind} {name} must be >= 0, got {val}")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to do it this way, but it breaks a test. I think its because _validate_type returns a TypeError and Pytest wants a ValueError:

   for val in (None, -1):  # protect against older MNE-C types
        for kwarg in ("reject", "flat"):
            pytest.raises(
                ValueError,
                Epochs,
                raw,
                events,
                event_id,
                tmin,
                tmax,
                picks=picks_meg,
                preload=False,
                **{kwarg: dict(grad=val)},
            )

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should just update the test, TypeError is a better error in the case where the type is wrong (i.e., _validate_type does something better than what we do in main)

continue
elif val is not None and val >= 0:
continue
else:
raise ValueError(
'%s value must be a number >= 0, not "%s"' % (kind, val)
"%s value must be a number >= 0 or a valid function,"
'not "%s"' % (kind, val)
)

# now check to see if our rejection and flat are getting more
larsoner marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -834,33 +839,48 @@ def _reject_setup(self, reject, flat):
"previous ones"
)

# Skip this check if old_reject, reject, old_flat, and flat are
# callables
is_callable = False
for rej in (reject, flat, old_reject, old_flat):
for key, val in rej.items():
if callable(val):
is_callable = True

# copy thresholds for channel types that were used previously, but not
# passed this time
for key in set(old_reject) - set(reject):
reject[key] = old_reject[key]
# make sure new thresholds are at least as stringent as the old ones
for key in reject:
if key in old_reject and reject[key] > old_reject[key]:
raise ValueError(
bad_msg.format(
kind="reject",
key=key,
new=reject[key],
old=old_reject[key],
op=">",

if not is_callable:
# make sure new thresholds are at least as stringent
# as the old ones
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... but I think you can make these checks more fine-grained without much more work. Instead of if any are callable, skip all checks (which is what you're doing now), you should within this loop be able to fairly easily make it so that if this type is or was callable, skip this check. I think it almost amounts to moving the conditional from above into this loop with a continue when callables are or were present. As a bonus, your diff will probably be smaller, too!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I edit my test to check for this?

for key in reject:
if key in old_reject and reject[key] > old_reject[key]:
raise ValueError(
bad_msg.format(
kind="reject",
key=key,
new=reject[key],
old=old_reject[key],
op=">",
)
)
)

# same for flat thresholds
for key in set(old_flat) - set(flat):
flat[key] = old_flat[key]
for key in flat:
if key in old_flat and flat[key] < old_flat[key]:
raise ValueError(
bad_msg.format(
kind="flat", key=key, new=flat[key], old=old_flat[key], op="<"
# same for flat thresholds
for key in set(old_flat) - set(flat):
flat[key] = old_flat[key]
for key in flat:
if key in old_flat and flat[key] < old_flat[key]:
raise ValueError(
bad_msg.format(
kind="flat",
key=key,
new=flat[key],
old=old_flat[key],
op="<",
)
)
)

# after validation, set parameters
self._bad_dropped = False
Expand Down Expand Up @@ -3621,25 +3641,35 @@ def _is_good(
):
"""Test if data segment e is good according to reject and flat.

The reject and flat dictionaries can accept functions as values.
withmywoessner marked this conversation as resolved.
Show resolved Hide resolved

If full_report=True, it will give True/False as well as a list of all
offending channels.
"""
bad_tuple = tuple()
has_printed = False
checkable = np.ones(len(ch_names), dtype=bool)
checkable[np.array([c in ignore_chs for c in ch_names], dtype=bool)] = False

for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]):
if refl is not None:
for key, thresh in refl.items():
for key, criterion in refl.items():
idx = channel_type_idx[key]
name = key.upper()
if len(idx) > 0:
e_idx = e[idx]
deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1)
checkable_idx = checkable[idx]
idx_deltas = np.where(
np.logical_and(f(deltas, thresh), checkable_idx)
)[0]

# Check if criterion is a function and apply it
if callable(criterion):
idx_deltas = np.where(
np.logical_and(criterion(e_idx), checkable_idx)
)[0]
else:
deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1)
idx_deltas = np.where(
np.logical_and(f(deltas, criterion), checkable_idx)
)[0]

if len(idx_deltas) > 0:
bad_names = [ch_names[idx[i]] for i in idx_deltas]
Expand Down
61 changes: 61 additions & 0 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2128,6 +2128,67 @@ def test_reject_epochs(tmp_path):
assert epochs_cleaned.flat == dict(grad=new_flat["grad"], mag=flat["mag"])


@testing.requires_testing_data
def test_callable_reject():
"""Test using a callable for rejection."""
raw = read_raw_fif(fname_raw_testing, preload=True)
raw.crop(0, 5)
raw.del_proj()
chans = raw.info["ch_names"][-6:-1]
raw.pick(chans)
data = raw.get_data()

# Add some artifacts
new_data = data
new_data[0, 180:200] *= 1e7
new_data[0, 610:880] += 1e-3
edit_raw = mne.io.RawArray(new_data, raw.info)

events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0)
epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None, preload=True)

assert len(epochs) == 5
epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
reject=dict(
eeg=lambda x: True if (np.median(x, axis=1) > 1e-3).any() else False
),
preload=True,
)
assert epochs.drop_log[2] != ()

epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1).any() else False),
preload=True,
)
assert epochs.drop_log[0] != ()

def reject_criteria(x):
max_condition = np.max(x, axis=1) > 1e-2
median_condition = np.median(x, axis=1) > 1e-4
return True if max_condition.any() or median_condition.any() else False

epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=reject_criteria),
preload=True,
)
assert epochs.drop_log[0] != () and epochs.drop_log[2] != ()
withmywoessner marked this conversation as resolved.
Show resolved Hide resolved


def test_preload_epochs():
"""Test preload of epochs."""
raw, events, picks = _get_data()
Expand Down
37 changes: 26 additions & 11 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1702,11 +1702,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
"""

_flat_common = """\
Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP).
Valid **keys** can be any channel type present in the object. The
**values** are floats that set the minimum acceptable PTP. If the PTP
is smaller than this threshold, the epoch will be dropped. If ``None``
then no rejection is performed based on flatness of the signal."""
Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP)
or a custom function. Valid **keys** can be any channel type present
in the object. If using PTP, **values** are floats that set the minimum
acceptable PTP. If the PTP is smaller than this threshold, the epoch
will be dropped. If ``None`` then no rejection is performed based on
flatness of the signal. If a custom function is used than ``flat`` can be
used to reject epochs based on any criteria (including maxima and
minima)."""

docdict[
"flat"
Expand Down Expand Up @@ -3794,8 +3797,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
)

_reject_common = """\
Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP),
i.e. the absolute difference between the lowest and the highest signal
Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP)
or custom functions. Peak-to-peak signal amplitude is defined as
the absolute difference between the lowest and the highest signal
value. In each individual epoch, the PTP is calculated for every channel.
If the PTP of any one channel exceeds the rejection threshold, the
respective epoch will be dropped.
Expand All @@ -3811,10 +3815,21 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
eog=250e-6 # unit: V (EOG channels)
)

.. note:: Since rejection is based on a signal **difference**
calculated for each channel separately, applying baseline
correction does not affect the rejection procedure, as the
difference will be preserved.
Custom rejection criteria can be also be used by passing a callable
to the dictionary.

Example::

reject = dict(eeg=lambda x: True if (np.max(x, axis=1) >
1e-3).any() else False))
withmywoessner marked this conversation as resolved.
Show resolved Hide resolved

.. note:: If rejection is based on a signal **difference**
calculated for each channel separately, applying baseline
correction does not affect the rejection procedure, as the
difference will be preserved.

.. note:: If ``reject`` is a callable, than **any** criteria can be
used to reject epochs (including maxima and minima).
"""

docdict[
Expand Down
99 changes: 97 additions & 2 deletions tutorials/preprocessing/20_rejecting_bad_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import os

import numpy as np

import mne

sample_data_folder = mne.datasets.sample.data_path()
Expand Down Expand Up @@ -205,8 +207,8 @@
# %%
# .. _`tut-reject-epochs-section`:
#
# Rejecting Epochs based on channel amplitude
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Rejecting Epochs based on peak-to-peak channel amplitude
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Besides "bad" annotations, the :class:`mne.Epochs` class constructor has
# another means of rejecting epochs, based on signal amplitude thresholds for
Expand Down Expand Up @@ -328,6 +330,99 @@
epochs.drop_bad(reject=stronger_reject_criteria)
print(epochs.drop_log)

# %%
# .. _`tut-reject-epochs-func-section`:
#
# Rejecting Epochs using callables (functions)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Sometimes it is useful to reject epochs based criteria other than
# peak-to-peak amplitudes. For example, we might want to reject epochs
# based on the maximum or minimum amplitude of a channel.
# In this case, the :class:`mne.Epochs` class constructor also accepts
# callables (functions) in the ``reject`` and ``flat`` parameters. This
# allows us to define functions to reject epochs based on our desired criteria.
#
# Let's begin by generating Epoch data with large artifacts in one eeg channel
# in order to demonstrate the versatility of this approach.

raw.crop(0, 5)
raw.del_proj()
chans = raw.info["ch_names"][-5:-1]
raw.pick(chans)
data = raw.get_data()

new_data = data
new_data[0, 180:200] *= 1e3
new_data[0, 460:580] += 1e-3
edit_raw = mne.io.RawArray(new_data, raw.info)

# Create fixed length epochs of 1 second
events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0)
epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None)
epochs.plot(scalings=dict(eeg=50e-5))

# %%
# As you can see, we have two large artifacts in the first channel. One large
# spike in amplitude and one large increase in amplitude.

# Let's try to reject the epoch containing the spike in amplitude based on the
# maximum amplitude of the first channel.

epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1e-2).any() else False),
withmywoessner marked this conversation as resolved.
Show resolved Hide resolved
preload=True,
)
epochs.plot(scalings=dict(eeg=50e-5))

# %%
# Here, the epoch containing the spike in amplitude was rejected for having a
# maximum amplitude greater than 1e-2 Volts. Notice the use of the ``any()``
# function to check if any of the channels exceeded the threshold. We could
# have also used the ``all()`` function to check if all channels exceeded the
# threshold.

# Next, let's try to reject the epoch containing the increase in amplitude
# using the median.

epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=lambda x: True if (np.median(x, axis=1) > 1e-4).any() else False),
withmywoessner marked this conversation as resolved.
Show resolved Hide resolved
preload=True,
)
epochs.plot(scalings=dict(eeg=50e-5))

# %%
# Finally, let's try to reject both epochs using a combination of the maximum
# and median. We'll define a custom function and use boolean operators to
# combine the two criteria.


def reject_criteria(x):
max_condition = np.max(x, axis=1) > 1e-2
median_condition = np.median(x, axis=1) > 1e-4
return True if max_condition.any() or median_condition.any() else False


epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=reject_criteria),
preload=True,
)
epochs.plot(events=True)

# %%
# Note that a complementary Python module, the `autoreject package`_, uses
# machine learning to find optimal rejection criteria, and is designed to
Expand Down
Loading