Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/mne-tools/mne-python into d…
Browse files Browse the repository at this point in the history
…evcontainer
  • Loading branch information
hoechenberger committed Jun 8, 2024
2 parents d317c0f + 6fb9aae commit 9043f31
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 25 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12649.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Adding argument ``'random'`` to :func:`~mne.epochs.equalize_epoch_counts` and to :meth:`~mne.Epochs.equalize_event_counts` to randomly select epochs or events. By `Mathieu Scheltienne`_.
1 change: 1 addition & 0 deletions doc/changes/devel/12650.other.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enhance documentation on decimation filtering to prevent aliasing, by :newcontrib:`Xabier de Zuazo`.
2 changes: 2 additions & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,8 @@

.. _Victoria Peterson: https://github.com/vpeterson

.. _Xabier de Zuazo: https://github.com/zuazo

.. _Xiaokai Xia: https://github.com/dddd1007

.. _Will Turner: https://bootstrapbill.github.io
Expand Down
8 changes: 5 additions & 3 deletions doc/help/faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,11 @@ of data. We'll discuss some major ones here, with some of their implications:

- :func:`mne.Epochs.decimate`, which does the same thing as the
``decim`` parameter in the :class:`mne.Epochs` constructor, sub-selects every
:math:`N^{th}` sample before and after each event. This should only be
used when the raw data have been sufficiently low-passed e.g. by
:func:`mne.io.Raw.filter` to avoid aliasing artifacts.
:math:`N^{th}` sample before and after each event. To avoid aliasing
artifacts, the raw data should be sufficiently low-passed before decimation.
It is recommended to use :func:`mne.io.Raw.filter` with ``h_freq`` set to
half the new sampling rate (fs/2N) or lower, as per the Nyquist criterion, to
ensure effective attenuation of frequency content above this threshold.

- :func:`mne.Epochs.resample`, :func:`mne.Evoked.resample`, and
:func:`mne.SourceEstimate.resample` all resample data.
Expand Down
50 changes: 28 additions & 22 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2341,7 +2341,10 @@ def export(self, fname, fmt="auto", *, overwrite=False, verbose=None):

export_epochs(fname, self, fmt, overwrite=overwrite, verbose=verbose)

def equalize_event_counts(self, event_ids=None, method="mintime"):
@fill_doc
def equalize_event_counts(
self, event_ids=None, method="mintime", *, random_state=None
):
"""Equalize the number of trials in each condition.
It tries to make the remaining epochs occurring as close as possible in
Expand Down Expand Up @@ -2381,10 +2384,8 @@ def equalize_event_counts(self, event_ids=None, method="mintime"):
matched by the provided tags had been supplied instead.
The ``event_ids`` must identify non-overlapping subsets of the
epochs.
method : str
If ``'truncate'``, events will be truncated from the end of each
type of events. If ``'mintime'``, timing differences between each
event type will be minimized.
%(equalize_events_method)s
%(random_state)s Used only if ``method='random'``.
Returns
-------
Expand Down Expand Up @@ -2486,7 +2487,7 @@ def equalize_event_counts(self, event_ids=None, method="mintime"):
eq_inds.append(self._keys_to_idx(eq))

sample_nums = [self.events[e, 0] for e in eq_inds]
indices = _get_drop_indices(sample_nums, method)
indices = _get_drop_indices(sample_nums, method, random_state)
# need to re-index indices
indices = np.concatenate([e[idx] for e, idx in zip(eq_inds, indices)])
self.drop(indices, reason="EQUALIZED_COUNT")
Expand Down Expand Up @@ -3830,23 +3831,22 @@ def combine_event_ids(epochs, old_event_ids, new_event_id, copy=True):
return epochs


def equalize_epoch_counts(epochs_list, method="mintime"):
@fill_doc
def equalize_epoch_counts(epochs_list, method="mintime", *, random_state=None):
"""Equalize the number of trials in multiple Epochs or EpochsTFR instances.
Parameters
----------
epochs_list : list of Epochs instances
The Epochs instances to equalize trial counts for.
method : str
If 'truncate', events will be truncated from the end of each event
list. If 'mintime', timing differences between each event list will be
minimized.
%(equalize_events_method)s
%(random_state)s Used only if ``method='random'``.
Notes
-----
This tries to make the remaining epochs occurring as close as possible in
time. This method works based on the idea that if there happened to be some
time-varying (like on the scale of minutes) noise characteristics during
The method ``'mintime'`` tries to make the remaining epochs occurring as close as
possible in time. This method is motivated by the possibility that if there happened
to be some time-varying (like on the scale of minutes) noise characteristics during
a recording, they could be compensated for (to some extent) in the
equalization process. This method thus seeks to reduce any of those effects
by minimizing the differences in the times of the events in the two sets of
Expand All @@ -3860,29 +3860,35 @@ def equalize_epoch_counts(epochs_list, method="mintime"):
"""
if not all(isinstance(epoch, (BaseEpochs, EpochsTFR)) for epoch in epochs_list):
raise ValueError("All inputs must be Epochs instances")

# make sure bad epochs are dropped
for epoch in epochs_list:
if not epoch._bad_dropped:
epoch.drop_bad()
sample_nums = [epoch.events[:, 0] for epoch in epochs_list]
indices = _get_drop_indices(sample_nums, method)
indices = _get_drop_indices(sample_nums, method, random_state)
for epoch, inds in zip(epochs_list, indices):
epoch.drop(inds, reason="EQUALIZED_COUNT")


def _get_drop_indices(sample_nums, method):
def _get_drop_indices(sample_nums, method, random_state):
"""Get indices to drop from multiple event timing lists."""
small_idx = np.argmin([e.shape[0] for e in sample_nums])
small_idx = np.argmin([e.size for e in sample_nums])
small_epoch_indices = sample_nums[small_idx]
_check_option("method", method, ["mintime", "truncate"])
_check_option("method", method, ["mintime", "truncate", "random"])
indices = list()
for event in sample_nums:
if method == "mintime":
mask = _minimize_time_diff(small_epoch_indices, event)
else:
mask = np.ones(event.shape[0], dtype=bool)
mask[small_epoch_indices.shape[0] :] = False
elif method == "truncate":
mask = np.ones(event.size, dtype=bool)
mask[small_epoch_indices.size :] = False
elif method == "random":
rng = check_random_state(random_state)
mask = np.zeros(event.size, dtype=bool)
idx = rng.choice(
np.arange(event.size), size=small_epoch_indices.size, replace=False
)
mask[idx] = True
indices.append(np.where(np.logical_not(mask))[0])
return indices

Expand Down
15 changes: 15 additions & 0 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3001,6 +3001,21 @@ def test_epoch_eq():
epochs.equalize_event_counts(1.5)


def test_equalize_epoch_counts_random():
"""Test random equalization of epochs."""
raw, events, picks = _get_data()
# create epochs with unequal counts
events_1 = events[events[:, 2] == event_id]
epochs_1 = Epochs(raw, events_1, event_id, tmin, tmax, picks=picks)
events_2 = events[events[:, 2] == event_id_2]
epochs_2 = Epochs(raw, events_2, event_id_2, tmin, tmax, picks=picks)
epochs_1.drop_bad()
epochs_2.drop_bad()
assert len(epochs_1) != len(epochs_2)
equalize_epoch_counts([epochs_1, epochs_2], method="random")
assert len(epochs_1) == len(epochs_2)


def test_access_by_name(tmp_path):
"""Test accessing epochs by event name and on_missing for rare events."""
raw, events, picks = _get_data()
Expand Down
11 changes: 11 additions & 0 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
``decim``), i.e., it compresses the signal (see Notes).
If the data are not properly filtered, aliasing artifacts
may occur.
See :ref:`resampling-and-decimating` for more information.
"""

docdict["decim_notes"] = """
Expand Down Expand Up @@ -1289,6 +1290,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
time are included. Defaults to ``-0.2`` and ``0.5``, respectively.
"""

docdict["equalize_events_method"] = """
method : ``'truncate'`` | ``'mintime'`` | ``'random'``
If ``'truncate'``, events will be truncated from the end of each event
list. If ``'mintime'``, timing differences between each event list will be
minimized. If ``'random'``, events will be randomly selected from each event
list.
.. versionadded:: 1.8
"""

docdict["estimate_plot_psd"] = """\
estimate : str, {'power', 'amplitude'}
Can be "power" for power spectral density (PSD; default), "amplitude" for
Expand Down

0 comments on commit 9043f31

Please sign in to comment.