Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add option to store and return TFR taper weights #12910

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9fe1fb6
Add option to store and return tfr taper weights
tsbinns Oct 22, 2024
45c6a0b
Merge remote-tracking branch 'upstream/main' into add_tfr_weights
tsbinns Oct 22, 2024
82fc2f7
Update docstrings
tsbinns Oct 22, 2024
9f30a59
Merge branch 'main' into add_tfr_weights
tsbinns Oct 22, 2024
a49f934
Remove whitespace
tsbinns Oct 22, 2024
48afced
Merge branch 'add_tfr_weights' of https://github.com/tsbinns/mne-pyth…
tsbinns Oct 22, 2024
7c3dcfa
Add PR num
tsbinns Oct 22, 2024
8c16716
Revert "Update docstrings"
tsbinns Oct 22, 2024
51b8cd0
Remove outdated default setting
tsbinns Oct 22, 2024
b4537b2
Update docstrings
tsbinns Oct 22, 2024
f155238
Merge branch 'main' into add_tfr_weights
tsbinns Oct 24, 2024
2a03e9b
Merge branch 'main' into add_tfr_weights
tsbinns Oct 28, 2024
045d9a2
Merge branch 'main' into add_tfr_weights
tsbinns Oct 29, 2024
8d645bb
Enforce return_weights as named param
tsbinns Oct 29, 2024
5ad9bd5
Merge branch 'main' into add_tfr_weights
tsbinns Dec 9, 2024
1c02b40
Add missing test coverage
tsbinns Dec 9, 2024
54f2a32
Add changelog entry
tsbinns Dec 9, 2024
01c486c
Begin add support for tapers in array objs
tsbinns Dec 9, 2024
ca27179
Fix docstring entries
tsbinns Dec 9, 2024
b14a100
Fix faulty state check
tsbinns Dec 10, 2024
972aba2
Add weights to AverageTFR
tsbinns Dec 10, 2024
e11fa2b
Expand test coverage
tsbinns Dec 10, 2024
aaef4b7
Merge branch 'main' into add_tfr_weights
tsbinns Dec 10, 2024
999d122
Disallow aggregating tapers in combine_tfr
tsbinns Dec 10, 2024
e12b09a
Updated docstrings
tsbinns Dec 10, 2024
dd61955
Merge branch 'main' into add_tfr_weights
tsbinns Dec 10, 2024
728701e
Add placeholder versionadded tags
tsbinns Dec 10, 2024
6af3310
Merge branch 'add_tfr_weights' of https://github.com/tsbinns/mne-pyth…
tsbinns Dec 10, 2024
e3a3c4b
Merge remote-tracking branch 'upstream/main' into add_tfr_weights
tsbinns Dec 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/changes/devel/12910.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Added the option to return taper weights from
:func:`mne.time_frequency.tfr_array_multitaper`, and taper weights are now stored in the
:class:`mne.time_frequency.BaseTFR` objects, by `Thomas Binns`_.
11 changes: 11 additions & 0 deletions mne/time_frequency/multitaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def tfr_array_multitaper(
output="complex",
n_jobs=None,
*,
return_weights=False,
verbose=None,
):
"""Compute Time-Frequency Representation (TFR) using DPSS tapers.
Expand Down Expand Up @@ -502,8 +503,14 @@ def tfr_array_multitaper(
* ``'itc'`` : inter-trial coherence.
* ``'avg_power_itc'`` : average of single trial power and inter-trial
coherence across trials.

%(n_jobs)s
The parallelization is implemented across channels.
return_weights : bool, default False
If True, return the taper weights. Only applies if ``output='complex'`` or
``'phase'``.

.. versionadded:: 1.X.0
%(verbose)s

Returns
Expand All @@ -520,6 +527,9 @@ def tfr_array_multitaper(
If ``output`` is ``'avg_power_itc'``, the real values in ``out``
contain the average power and the imaginary values contain the
inter-trial coherence: :math:`out = power_{avg} + i * ITC`.
weights : array of shape (n_tapers, n_freqs)
The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and
``return_weights=True``.

See Also
--------
Expand Down Expand Up @@ -550,6 +560,7 @@ def tfr_array_multitaper(
use_fft=use_fft,
decim=decim,
output=output,
return_weights=return_weights,
n_jobs=n_jobs,
verbose=verbose,
)
139 changes: 125 additions & 14 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,17 +432,21 @@ def test_tfr_morlet():
def test_dpsswavelet():
"""Test DPSS tapers."""
freqs = np.arange(5, 25, 3)
Ws = _make_dpss(
1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True
Ws, weights = _make_dpss(
1000,
freqs=freqs,
n_cycles=freqs / 2.0,
time_bandwidth=4.0,
zero_mean=True,
return_weights=True,
)

assert len(Ws) == 3 # 3 tapers expected
assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected
assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs)

# Check that zero mean is true
assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5

assert len(Ws[0]) == len(freqs) # As many wavelets as asked for


@pytest.mark.slowtest
def test_tfr_multitaper():
Expand Down Expand Up @@ -664,6 +668,17 @@ def test_tfr_io(inst, average_tfr, request, tmp_path):
with tfr.info._unlock():
tfr.info["meas_date"] = want
assert tfr_loaded == tfr
# test with taper dimension and weights
n_tapers = 3 # anything >= 1 should do
weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs
state = tfr.__getstate__()
state["data"] = np.repeat(np.expand_dims(tfr.data, 2), n_tapers, axis=2) # add dim
state["weights"] = weights # add weights
state["dims"] = ("epoch", "channel", "taper", "freq", "time") # update dims
tfr = EpochsTFR(inst=state)
tfr.save(fname, overwrite=True)
tfr_loaded = read_tfrs(fname)
assert tfr_loaded == tfr
# test overwrite
with pytest.raises(OSError, match="Destination file exists."):
tfr.save(fname, overwrite=False)
Expand Down Expand Up @@ -722,17 +737,31 @@ def test_average_tfr_init(full_evoked):
AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace)


def test_epochstfr_init_errors(epochs_tfr):
"""Test __init__ for EpochsTFR."""
state = epochs_tfr.__getstate__()
with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"):
EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0]))
@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr"))
def test_tfr_init_errors(inst, request, average_tfr):
"""Test __init__ for {Raw,Epochs,Average}TFR."""
# Load data
inst = _get_inst(inst, request, average_tfr=average_tfr)
state = inst.__getstate__()
# Prepare for TFRArray object instantiation
inst_name = inst.__class__.__name__
class_mapping = dict(RawTFR=RawTFR, EpochsTFR=EpochsTFR, AverageTFR=AverageTFR)
ndims_mapping = dict(
RawTFR=("3D or 4D"), EpochsTFR=("4D or 5D"), AverageTFR=("3D or 4D")
)
TFR = class_mapping[inst_name]
allowed_ndims = ndims_mapping[inst_name]
# Check errors caught
with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"):
TFR(inst=state | dict(data=inst.data[..., 0]))
with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"):
TFR(inst=state | dict(data=np.expand_dims(inst.data, axis=(0, 1))))
with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"):
EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1]))
TFR(inst=state | dict(data=inst.data[..., :-1, :, :]))
with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"):
EpochsTFR(inst=state | dict(times=epochs_tfr.times[:-1]))
TFR(inst=state | dict(times=inst.times[:-1]))
with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"):
EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1]))
TFR(inst=state | dict(freqs=inst.freqs[:-1]))


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1154,6 +1183,15 @@ def test_averaging_epochsTFR():
):
power.average(method=np.mean)

# Check it doesn't run for taper spectra
tapered = epochs.compute_tfr(
method="multitaper", freqs=freqs, n_cycles=n_cycles, output="complex"
)
with pytest.raises(
NotImplementedError, match=r"Averaging multitaper tapers .* is not supported."
):
tapered.average()


def test_averaging_freqsandtimes_epochsTFR():
"""Test that EpochsTFR averaging freqs methods work."""
Expand Down Expand Up @@ -1534,7 +1572,8 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc):
def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output):
"""Test Epochs.compute_tfr(output="complex"/"phase")."""
tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output)
assert len(tfr.shape) == 5
assert len(tfr.shape) == 5 # epoch x channel x taper x freq x time
assert tfr.weights.shape == tfr.shape[2:4] # check weights and coeffs shapes match


@pytest.mark.parametrize("copy", (False, True))
Expand All @@ -1546,6 +1585,42 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy):
assert avgs[0].comment == str(epochs_tfr.events[0, -1])


@pytest.mark.parametrize("inst", ("raw", "epochs", "evoked"))
def test_tfrarray_tapered_spectra(inst, evoked, request):
"""Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra."""
# Load data object
inst = _get_inst(inst, request, evoked=evoked)
inst.pick("mag")
# Compute TFR with taper dimension (can be complex or phase output)
tfr = inst.compute_tfr(
method="multitaper", freqs=freqs_linspace, n_cycles=4, output="complex"
)
tfr_array, weights = tfr.get_data(), tfr.weights
# Prepare for TFRArray object instantiation
defaults = dict(
info=inst.info, data=tfr_array, times=inst.times, freqs=freqs_linspace
)
class_mapping = dict(Raw=RawTFRArray, Epochs=EpochsTFRArray, Evoked=AverageTFRArray)
TFRArray = class_mapping[inst.__class__.__name__]
# Check TFRArray instantiation runs with good data
TFRArray(**defaults, weights=weights)
# Check taper dimension but no weights caught
with pytest.raises(
ValueError, match="Taper dimension in data, but no weights found."
):
TFRArray(**defaults)
# Check mismatching n_taper in weights caught
with pytest.raises(
ValueError, match=r"Taper axis .* doesn't match weights attribute"
):
TFRArray(**defaults, weights=weights[:-1])
# Check mismatching n_freq in weights caught
with pytest.raises(
ValueError, match=r"Frequency axis .* doesn't match weights attribute"
):
TFRArray(**defaults, weights=weights[:, :-1])


def test_tfr_proj(epochs):
"""Test `compute_tfr(proj=True)`."""
epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True)
Expand Down Expand Up @@ -1727,3 +1802,39 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request):
assert re.match(
rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title()
)


def test_combine_tfr_error_catch(request, average_tfr):
"""Test combine_tfr() catches errors."""
# check unrecognised weights string caught
with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'):
combine_tfr([average_tfr, average_tfr], weights="foo")
# check bad weights size caught
with pytest.raises(ValueError, match="Weights must be the same size as all_tfr"):
combine_tfr([average_tfr, average_tfr], weights=[1, 1, 1])
# check different channel names caught
state = average_tfr.__getstate__()
new_info = average_tfr.info.copy()
average_tfr_bad = AverageTFR(
inst=state | dict(info=new_info.rename_channels({new_info.ch_names[0]: "foo"}))
)
with pytest.raises(AssertionError, match=".* do not contain the same channels"):
combine_tfr([average_tfr, average_tfr_bad])
# check different times caught
average_tfr_bad = AverageTFR(inst=state | dict(times=average_tfr.times + 1))
with pytest.raises(
AssertionError, match=".* do not contain the same time instants"
):
combine_tfr([average_tfr, average_tfr_bad])
# check taper dim caught
n_tapers = 3 # anything >= 1 should do
weights = np.ones((n_tapers, average_tfr.shape[1])) # tapers x freqs
state["data"] = np.repeat(np.expand_dims(average_tfr.data, 1), n_tapers, axis=1)
state["weights"] = weights
state["dims"] = ("channel", "taper", "freq", "time")
average_tfr_taper = AverageTFR(inst=state)
with pytest.raises(
NotImplementedError,
match="Aggregating multitaper tapers across TFR datasets is not supported.",
):
combine_tfr([average_tfr_taper, average_tfr_taper])
Loading
Loading