From 2e065838b82c0c28a39c5186f6ba6f3872a9fff1 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Fri, 7 Jun 2024 18:38:37 +0200 Subject: [PATCH 1/2] Add equalization of epochs randomly (#12649) --- doc/changes/devel/12649.newfeature.rst | 1 + mne/epochs.py | 50 ++++++++++++++------------ mne/tests/test_epochs.py | 15 ++++++++ mne/utils/docs.py | 10 ++++++ 4 files changed, 54 insertions(+), 22 deletions(-) create mode 100644 doc/changes/devel/12649.newfeature.rst diff --git a/doc/changes/devel/12649.newfeature.rst b/doc/changes/devel/12649.newfeature.rst new file mode 100644 index 00000000000..33908f8dc89 --- /dev/null +++ b/doc/changes/devel/12649.newfeature.rst @@ -0,0 +1 @@ +Adding argument ``'random'`` to :func:`~mne.epochs.equalize_epoch_counts` and to :meth:`~mne.Epochs.equalize_event_counts` to randomly select epochs or events. By `Mathieu Scheltienne`_. diff --git a/mne/epochs.py b/mne/epochs.py index c8ab7f5a440..cb31491f151 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -2341,7 +2341,10 @@ def export(self, fname, fmt="auto", *, overwrite=False, verbose=None): export_epochs(fname, self, fmt, overwrite=overwrite, verbose=verbose) - def equalize_event_counts(self, event_ids=None, method="mintime"): + @fill_doc + def equalize_event_counts( + self, event_ids=None, method="mintime", *, random_state=None + ): """Equalize the number of trials in each condition. It tries to make the remaining epochs occurring as close as possible in @@ -2381,10 +2384,8 @@ def equalize_event_counts(self, event_ids=None, method="mintime"): matched by the provided tags had been supplied instead. The ``event_ids`` must identify non-overlapping subsets of the epochs. - method : str - If ``'truncate'``, events will be truncated from the end of each - type of events. If ``'mintime'``, timing differences between each - event type will be minimized. + %(equalize_events_method)s + %(random_state)s Used only if ``method='random'``. Returns ------- @@ -2486,7 +2487,7 @@ def equalize_event_counts(self, event_ids=None, method="mintime"): eq_inds.append(self._keys_to_idx(eq)) sample_nums = [self.events[e, 0] for e in eq_inds] - indices = _get_drop_indices(sample_nums, method) + indices = _get_drop_indices(sample_nums, method, random_state) # need to re-index indices indices = np.concatenate([e[idx] for e, idx in zip(eq_inds, indices)]) self.drop(indices, reason="EQUALIZED_COUNT") @@ -3830,23 +3831,22 @@ def combine_event_ids(epochs, old_event_ids, new_event_id, copy=True): return epochs -def equalize_epoch_counts(epochs_list, method="mintime"): +@fill_doc +def equalize_epoch_counts(epochs_list, method="mintime", *, random_state=None): """Equalize the number of trials in multiple Epochs or EpochsTFR instances. Parameters ---------- epochs_list : list of Epochs instances The Epochs instances to equalize trial counts for. - method : str - If 'truncate', events will be truncated from the end of each event - list. If 'mintime', timing differences between each event list will be - minimized. + %(equalize_events_method)s + %(random_state)s Used only if ``method='random'``. Notes ----- - This tries to make the remaining epochs occurring as close as possible in - time. This method works based on the idea that if there happened to be some - time-varying (like on the scale of minutes) noise characteristics during + The method ``'mintime'`` tries to make the remaining epochs occurring as close as + possible in time. This method is motivated by the possibility that if there happened + to be some time-varying (like on the scale of minutes) noise characteristics during a recording, they could be compensated for (to some extent) in the equalization process. This method thus seeks to reduce any of those effects by minimizing the differences in the times of the events in the two sets of @@ -3860,29 +3860,35 @@ def equalize_epoch_counts(epochs_list, method="mintime"): """ if not all(isinstance(epoch, (BaseEpochs, EpochsTFR)) for epoch in epochs_list): raise ValueError("All inputs must be Epochs instances") - # make sure bad epochs are dropped for epoch in epochs_list: if not epoch._bad_dropped: epoch.drop_bad() sample_nums = [epoch.events[:, 0] for epoch in epochs_list] - indices = _get_drop_indices(sample_nums, method) + indices = _get_drop_indices(sample_nums, method, random_state) for epoch, inds in zip(epochs_list, indices): epoch.drop(inds, reason="EQUALIZED_COUNT") -def _get_drop_indices(sample_nums, method): +def _get_drop_indices(sample_nums, method, random_state): """Get indices to drop from multiple event timing lists.""" - small_idx = np.argmin([e.shape[0] for e in sample_nums]) + small_idx = np.argmin([e.size for e in sample_nums]) small_epoch_indices = sample_nums[small_idx] - _check_option("method", method, ["mintime", "truncate"]) + _check_option("method", method, ["mintime", "truncate", "random"]) indices = list() for event in sample_nums: if method == "mintime": mask = _minimize_time_diff(small_epoch_indices, event) - else: - mask = np.ones(event.shape[0], dtype=bool) - mask[small_epoch_indices.shape[0] :] = False + elif method == "truncate": + mask = np.ones(event.size, dtype=bool) + mask[small_epoch_indices.size :] = False + elif method == "random": + rng = check_random_state(random_state) + mask = np.zeros(event.size, dtype=bool) + idx = rng.choice( + np.arange(event.size), size=small_epoch_indices.size, replace=False + ) + mask[idx] = True indices.append(np.where(np.logical_not(mask))[0]) return indices diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 85076c0ee6d..bba0955d40e 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -3001,6 +3001,21 @@ def test_epoch_eq(): epochs.equalize_event_counts(1.5) +def test_equalize_epoch_counts_random(): + """Test random equalization of epochs.""" + raw, events, picks = _get_data() + # create epochs with unequal counts + events_1 = events[events[:, 2] == event_id] + epochs_1 = Epochs(raw, events_1, event_id, tmin, tmax, picks=picks) + events_2 = events[events[:, 2] == event_id_2] + epochs_2 = Epochs(raw, events_2, event_id_2, tmin, tmax, picks=picks) + epochs_1.drop_bad() + epochs_2.drop_bad() + assert len(epochs_1) != len(epochs_2) + equalize_epoch_counts([epochs_1, epochs_2], method="random") + assert len(epochs_1) == len(epochs_2) + + def test_access_by_name(tmp_path): """Test accessing epochs by event name and on_missing for rare events.""" raw, events, picks = _get_data() diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 9aabafb90b6..492cd97a052 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1289,6 +1289,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): time are included. Defaults to ``-0.2`` and ``0.5``, respectively. """ +docdict["equalize_events_method"] = """ +method : ``'truncate'`` | ``'mintime'`` | ``'random'`` + If ``'truncate'``, events will be truncated from the end of each event + list. If ``'mintime'``, timing differences between each event list will be + minimized. If ``'random'``, events will be randomly selected from each event + list. + + .. versionadded:: 1.8 +""" + docdict["estimate_plot_psd"] = """\ estimate : str, {'power', 'amplitude'} Can be "power" for power spectral density (PSD; default), "amplitude" for From 6fb9aaef6380033ad2fdeeaf38f0724fdb1aff70 Mon Sep 17 00:00:00 2001 From: Xabier de Zuazo Date: Fri, 7 Jun 2024 23:19:23 +0200 Subject: [PATCH 2/2] Enhance documentation on decimation filtering to prevent aliasing (#12650) --- doc/changes/devel/12650.other.rst | 1 + doc/changes/names.inc | 2 ++ doc/help/faq.rst | 8 +++++--- mne/utils/docs.py | 1 + 4 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 doc/changes/devel/12650.other.rst diff --git a/doc/changes/devel/12650.other.rst b/doc/changes/devel/12650.other.rst new file mode 100644 index 00000000000..b97a204dff4 --- /dev/null +++ b/doc/changes/devel/12650.other.rst @@ -0,0 +1 @@ +Enhance documentation on decimation filtering to prevent aliasing, by :newcontrib:`Xabier de Zuazo`. diff --git a/doc/changes/names.inc b/doc/changes/names.inc index cdb9a62b855..ebd6775ad29 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -602,6 +602,8 @@ .. _Victoria Peterson: https://github.com/vpeterson +.. _Xabier de Zuazo: https://github.com/zuazo + .. _Xiaokai Xia: https://github.com/dddd1007 .. _Will Turner: https://bootstrapbill.github.io diff --git a/doc/help/faq.rst b/doc/help/faq.rst index 14d85f4e038..7720885d643 100644 --- a/doc/help/faq.rst +++ b/doc/help/faq.rst @@ -234,9 +234,11 @@ of data. We'll discuss some major ones here, with some of their implications: - :func:`mne.Epochs.decimate`, which does the same thing as the ``decim`` parameter in the :class:`mne.Epochs` constructor, sub-selects every - :math:`N^{th}` sample before and after each event. This should only be - used when the raw data have been sufficiently low-passed e.g. by - :func:`mne.io.Raw.filter` to avoid aliasing artifacts. + :math:`N^{th}` sample before and after each event. To avoid aliasing + artifacts, the raw data should be sufficiently low-passed before decimation. + It is recommended to use :func:`mne.io.Raw.filter` with ``h_freq`` set to + half the new sampling rate (fs/2N) or lower, as per the Nyquist criterion, to + ensure effective attenuation of frequency content above this threshold. - :func:`mne.Epochs.resample`, :func:`mne.Evoked.resample`, and :func:`mne.SourceEstimate.resample` all resample data. diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 492cd97a052..7de7eb2dd69 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1058,6 +1058,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``decim``), i.e., it compresses the signal (see Notes). If the data are not properly filtered, aliasing artifacts may occur. + See :ref:`resampling-and-decimating` for more information. """ docdict["decim_notes"] = """