Skip to content

Commit

Permalink
overhaul TFR classes [circle full]
Browse files Browse the repository at this point in the history
refactor method_kw checking

cleaner imports

refactor _get_instance_type_string

many changes; raw.compute_tfr(multitaper, freqs) works!

add verbose to tfr_array_stockwell

add get_data() method

DRY / fixes for method=stockwell

partially handle ITC; comments

forgotten (temporary) import

rework evoked fixtures to provide uncropped evoked option

override plot() method for EpochsTFR

add nave to save/load roundtrip

move copy to base class, fix docstrings, work on iter_evoked

work on ITC handling; add ValueErrors for unsupported param combos; cleanup TODOs

move arg gymnastics into EpochsTFR init; get iter_evoked working

get TFR.average() working

working on making stockwell/ITC behave sanely

reorg for saner class behavior

cleanup/move util funcs

make plotting work for singleton EpochsTFR

docstring fixes; support comments for AverageTFR

fix stockwell bugs

fix attributes, fix docstring format, don't store method_kw in an attribute

fix import nesting

add forgotten utils file

add arithmetic methods

fix check_option for stockwell

get baseline and crop working

add loader func; allow both h5 and hdf5 extensions

get plot_topomap working

better decim support for stockwell

make read_tfrs always return new class

refactor plot_joint

get plot_joint interaction working for grads; propogate topomap_kw to popup figs

small improvements to util func set_title_multi_electrodes

get onselect topomaps to look like plot_joint topomaps

copy everything to tfr.py; delete spectrogram.py; mark legacy

plot_joint debugging
  • Loading branch information
drammock committed Jan 30, 2024
1 parent 195a2cc commit ae76b38
Show file tree
Hide file tree
Showing 22 changed files with 3,096 additions and 1,609 deletions.
5 changes: 2 additions & 3 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,6 @@ def _pick_drop_channels(self, idx, *, verbose=None):
# avoid circular imports
from ..io import BaseRaw
from ..time_frequency import AverageTFR, EpochsTFR
from ..time_frequency.spectrum import BaseSpectrum

msg = "adding, dropping, or reordering channels"
if isinstance(self, BaseRaw):
Expand All @@ -633,9 +632,9 @@ def _pick_drop_channels(self, idx, *, verbose=None):
if mat is not None:
setattr(self, key, mat[idx][:, idx])

if isinstance(self, BaseSpectrum):
if hasattr(self, "_dims"): # Spectrum and "new-style" TFRs
axis = self._dims.index("channel")
elif isinstance(self, (AverageTFR, EpochsTFR)):
elif isinstance(self, (AverageTFR, EpochsTFR)): # "old-style" TFRs
axis = -3
else: # All others (Evoked, Epochs, Raw) have chs axis=-2
axis = -2
Expand Down
46 changes: 37 additions & 9 deletions mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,25 @@ def epochs_spectrum():
return _get_epochs().load_data().compute_psd()


@pytest.fixture()
def epochs_tfr():
"""Get an EpochsTFR computed from mne.io.tests.data."""
epochs = _get_epochs().load_data()
return epochs.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5))


@pytest.fixture()
def average_tfr(full_evoked):
"""Get an AverageTFR computed from mne.io.tests.data."""
return full_evoked.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5))


@pytest.fixture()
def raw_tfr(raw):
"""Get a RawTFR computed from mne.io.tests.data."""
return raw.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5))


@pytest.fixture()
def epochs_empty():
"""Get empty epochs from mne.io.tests.data."""
Expand All @@ -406,22 +425,31 @@ def epochs_empty():


@pytest.fixture(scope="session", params=[testing._pytest_param()])
def _evoked():
# This one is session scoped, so be sure not to modify it (use evoked
# instead)
evoked = mne.read_evokeds(
fname_evoked, condition="Left Auditory", baseline=(None, 0)
)
evoked.crop(0, 0.2)
return evoked
def _full_evoked():
# This is session scoped, so be sure not to modify its return value (use
# `full_evoked` fixture instead)
return mne.read_evokeds(fname_evoked, condition="Left Auditory", baseline=(None, 0))


@pytest.fixture(scope="session", params=[testing._pytest_param()])
def _evoked(_full_evoked):
# This is session scoped, so be sure not to modify its return value (use `evoked`
# fixture instead)
return _full_evoked.copy().crop(0, 0.2)


@pytest.fixture()
def evoked(_evoked):
"""Get evoked data."""
"""Get truncated evoked data."""
return _evoked.copy()


@pytest.fixture()
def full_evoked(_full_evoked):
"""Get full-duration evoked data (needed for, e.g., testing TFR)."""
return _full_evoked.copy()


@pytest.fixture(scope="function", params=[testing._pytest_param()])
def noise_cov():
"""Get a noise cov from the testing dataset."""
Expand Down
117 changes: 117 additions & 0 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@
from .fixes import rng_uniform
from .html_templates import _get_html_template
from .parallel import parallel_func
from .time_frequency._utils import _ensure_output_not_in_method_kw
from .time_frequency.spectrum import EpochsSpectrum, SpectrumMixin, _validate_method
from .time_frequency.tfr import AverageTFR, EpochsTFR
from .utils import (
ExtendedTimeMixin,
GetEpochsMixin,
Expand Down Expand Up @@ -2503,6 +2505,121 @@ def compute_psd(
**method_kw,
)

@verbose
def compute_tfr(
self,
method,
freqs,
*,
tmin=None,
tmax=None,
picks=None,
proj=False,
average="auto",
return_itc=False,
decim=1,
n_jobs=None,
verbose=None,
**method_kw,
):
"""Compute a time-frequency representation of epoched data.
Parameters
----------
%(method_tfr_epochs)s
%(freqs_tfr)s
%(tmin_tmax_psd)s
%(picks_good_data_noref)s
%(proj_psd)s
average : bool | "auto"
Whether to return average power across epochs (instead of single-trial
power). Default is "auto" which means ``True`` if method="stockwell" and
``False`` otherwise.
return_itc : bool
Whether to return inter-trial coherence (ITC) as well as power estimates.
Default is ``False``.
%(decim_tfr)s
%(n_jobs)s
%(verbose)s
%(method_kw_tfr)s
Returns
-------
tfr : instance of EpochsTFR or AverageTFR
The time-frequency-resolved power estimates.
itc : instance of AverageTFR
The inter-trial coherence (ITC). Only returned if ``return_itc=True``.
Notes
-----
.. versionadded:: 1.6
References
----------
.. footbibliography::
"""
# construct `output` value from `average` and `return_itc`
method_kw = _ensure_output_not_in_method_kw(self, method_kw)
if average == "auto": # stockwell method *must* average
average = method == "stockwell"
if average:
method_kw["output"] = "avg_power_itc" if return_itc else "avg_power"
else:
msg = (
"compute_tfr() got incompatible parameters `average=False` and `{}` "
"({} requires averaging over epochs)."
)
if return_itc:
raise ValueError(msg.format("return_itc=True", "computing ITC"))
if method == "stockwell":
raise ValueError(msg.format('method="stockwell"', "Stockwell method"))
if method == "stockwell":
method_kw["return_itc"] = return_itc
method_kw.pop("output")
if isinstance(freqs, str):
_check_option("freqs", freqs, "auto")
else:
_validate_type(freqs, "array-like")
_check_option("freqs", np.array(freqs).shape, ((2,),))
if average:
out = AverageTFR(
self,
method=method,
freqs=freqs,
tmin=tmin,
tmax=tmax,
picks=picks,
proj=proj,
decim=decim,
n_jobs=n_jobs,
verbose=verbose,
**method_kw,
)
# tfr_array_stockwell always returns ITC (but sometimes it's None)
if hasattr(out, "_itc"):
if out._itc is not None:
state = out.__getstate__()
state["data"] = out._itc
state["data_type"] = "Inter-trial coherence"
itc = AverageTFR(state, method=None, freqs=None)
return out, itc
del out._itc # if it's None, don't keep it around
return out
# now handle average=False
return EpochsTFR(
self,
method=method,
freqs=freqs,
tmin=tmin,
tmax=tmax,
picks=picks,
proj=proj,
decim=decim,
n_jobs=n_jobs,
verbose=verbose,
**method_kw,
)

@verbose
def plot_psd(
self,
Expand Down
59 changes: 59 additions & 0 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@
from .filter import FilterMixin, _check_fun, detrend
from .html_templates import _get_html_template
from .parallel import parallel_func
from .time_frequency._utils import _ensure_output_not_in_method_kw
from .time_frequency.spectrum import Spectrum, SpectrumMixin, _validate_method
from .time_frequency.tfr import AverageTFR
from .utils import (
ExtendedTimeMixin,
SizeMixin,
Expand Down Expand Up @@ -1122,6 +1124,63 @@ def compute_psd(
**method_kw,
)

@verbose
def compute_tfr(
self,
method,
freqs,
*,
tmin=None,
tmax=None,
picks=None,
proj=False,
decim=1,
n_jobs=None,
verbose=None,
**method_kw,
):
"""Compute a time-frequency representation of evoked data.
Parameters
----------
%(method_tfr)s
%(freqs_tfr)s
%(tmin_tmax_psd)s
%(picks_good_data_noref)s
%(proj_psd)s
%(decim_tfr)s
%(n_jobs)s
%(verbose)s
%(method_kw_tfr)s
Returns
-------
tfr : instance of AverageTFR
The time-frequency-resolved power estimates of the data.
Notes
-----
.. versionadded:: 1.6
References
----------
.. footbibliography::
"""
_ensure_output_not_in_method_kw(self, method_kw)
return AverageTFR(
self,
method=method,
freqs=freqs,
tmin=tmin,
tmax=tmax,
picks=picks,
proj=proj,
decim=decim,
n_jobs=n_jobs,
verbose=verbose,
**method_kw,
)

@verbose
def plot_psd(
self,
Expand Down
54 changes: 54 additions & 0 deletions mne/html_templates/repr/tfr.html.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
<table class="table table-hover table-striped table-sm table-responsive small">
<tr>
<th>Data type</th>
<td>{{ tfr._data_type }}</td>
</tr>
{%- for unit in units %}
<tr>
{%- if loop.index == 1 %}
<th rowspan={{ units | length }}>Units</th>
{%- endif %}
<td class="justify">{{ unit }}</td>
</tr>
{%- endfor %}
<tr>
<th>Data source</th>
<td>{{ inst_type }}</td>
</tr>
{%- if inst_type == "Epochs" %}
<tr>
<th>Number of epochs</th>
<td>{{ tfr.shape[0] }}</td>
</tr>
{% endif -%}
<tr>
<th>Dims</th>
<td>{{ tfr._dims | join(", ") }}</td>
</tr>
<tr>
<th>Estimation method</th>
<td>{{ tfr.method }}</td>
</tr>
{% if "taper" in tfr._dims %}
<tr>
<th>Number of tapers</th>
<td>{{ tfr._mt_weights.size }}</td>
</tr>
{% endif %}
<tr>
<th>Number of channels</th>
<td>{{ tfr.ch_names|length }}</td>
</tr>
<tr>
<th>Number of timepoints</th>
<td>{{ tfr.times|length }}</td>
</tr>
<tr>
<th>Number of frequency bins</th>
<td>{{ tfr.freqs|length }}</td>
</tr>
<tr>
<th>Frequency range</th>
<td>{{ '%.2f'|format(tfr.freqs[0]) }} – {{ '%.2f'|format(tfr.freqs[-1]) }} Hz</td>
</tr>
</table>
Loading

0 comments on commit ae76b38

Please sign in to comment.