Skip to content

Commit

Permalink
Fix EpochsTFR.add_channels() (#12616)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbrnr authored May 21, 2024
1 parent b12396d commit 4c9a176
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 0 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12616.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix adding channels to :class:`~mne.time_frequency.EpochsTFR` objects, by `Clemens Brunner`_.
4 changes: 4 additions & 0 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ def add_channels(self, add_list, force_update_info=False):
# avoid circular imports
from ..epochs import BaseEpochs
from ..io import BaseRaw
from ..time_frequency import EpochsTFR

_validate_type(add_list, (list, tuple), "Input")

Expand All @@ -708,6 +709,9 @@ def add_channels(self, add_list, force_update_info=False):
elif isinstance(self, BaseEpochs):
con_axis = 1
comp_class = BaseEpochs
elif isinstance(self, EpochsTFR):
con_axis = 1
comp_class = EpochsTFR
else:
con_axis = 0
comp_class = type(self)
Expand Down
17 changes: 17 additions & 0 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,23 @@ def test_add_channels():
pytest.raises(ValueError, tfr_meg.add_channels, [tfr_meg])
pytest.raises(TypeError, tfr_meg.add_channels, tfr_badsf)

# Test for EpochsTFR(Array)
tfr1 = EpochsTFRArray(
info=mne.create_info(["EEG 001"], 1000, "eeg"),
data=np.zeros((5, 1, 2, 3)), # epochs, channels, freqs, times
times=[0.1, 0.2, 0.3],
freqs=[0.1, 0.2],
)
tfr2 = EpochsTFRArray(
info=mne.create_info(["EEG 002", "EEG 003"], 1000, "eeg"),
data=np.zeros((5, 2, 2, 3)), # epochs, channels, freqs, times
times=[0.1, 0.2, 0.3],
freqs=[0.1, 0.2],
)
tfr1.add_channels([tfr2])
assert tfr1.ch_names == ["EEG 001", "EEG 002", "EEG 003"]
assert tfr1.data.shape == (5, 3, 2, 3)


def test_compute_tfr():
"""Test _compute_tfr function."""
Expand Down

0 comments on commit 4c9a176

Please sign in to comment.