diff --git a/mne/beamformer/tests/test_dics.py b/mne/beamformer/tests/test_dics.py
index 1daaaf17eb0..3f2d9f939cb 100644
--- a/mne/beamformer/tests/test_dics.py
+++ b/mne/beamformer/tests/test_dics.py
@@ -727,7 +727,7 @@ def test_apply_dics_tfr(return_generator):
data = rng.random((n_epochs, n_chans, len(freqs), n_times))
data *= 1e-6
data = data + data * 1j # add imag. component to simulate phase
- epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs)
+ epochs_tfr = EpochsTFR(dict(info=info, data=data, times=times, freqs=freqs))
# Create a DICS beamformer and convert the EpochsTFR to source space.
csd = csd_tfr(epochs_tfr)
diff --git a/mne/channels/channels.py b/mne/channels/channels.py
index 8db7acfbc61..5c67fef4272 100644
--- a/mne/channels/channels.py
+++ b/mne/channels/channels.py
@@ -607,7 +607,6 @@ def drop_channels(self, ch_names, on_missing="raise"):
def _pick_drop_channels(self, idx, *, verbose=None):
# avoid circular imports
from ..io import BaseRaw
- from ..time_frequency import AverageTFR, EpochsTFR
msg = "adding, dropping, or reordering channels"
if isinstance(self, BaseRaw):
@@ -634,8 +633,6 @@ def _pick_drop_channels(self, idx, *, verbose=None):
if hasattr(self, "_dims"): # Spectrum and "new-style" TFRs
axis = self._dims.index("channel")
- elif isinstance(self, (AverageTFR, EpochsTFR)): # "old-style" TFRs
- axis = -3
else: # All others (Evoked, Epochs, Raw) have chs axis=-2
axis = -2
if hasattr(self, "_data"): # skip non-preloaded Raw
diff --git a/mne/epochs.py b/mne/epochs.py
index 756a9799c86..4c7cb1732fd 100644
--- a/mne/epochs.py
+++ b/mne/epochs.py
@@ -2602,8 +2602,9 @@ def compute_tfr(
state["data"] = out._itc
state["data_type"] = "Inter-trial coherence"
itc = AverageTFR(state, method=None, freqs=None)
+ del out._itc
return out, itc
- del out._itc # if it's None, don't keep it around
+ del out._itc
return out
# now handle average=False
return EpochsTFR(
diff --git a/mne/html_templates/repr/tfr.html.jinja b/mne/html_templates/repr/tfr.html.jinja
index ae3fba06bcd..616a1899f80 100644
--- a/mne/html_templates/repr/tfr.html.jinja
+++ b/mne/html_templates/repr/tfr.html.jinja
@@ -21,6 +21,12 @@
{{ tfr.shape[0] }} |
{% endif -%}
+ {%- inst_type == "Evoked" %}
+
+ Number of averaged trials |
+ {{ nave }} |
+
+ {% endif -%}
Dims |
{{ tfr._dims | join(", ") }} |
diff --git a/mne/minimum_norm/tests/test_inverse.py b/mne/minimum_norm/tests/test_inverse.py
index 58722a19fd5..23e3822d524 100644
--- a/mne/minimum_norm/tests/test_inverse.py
+++ b/mne/minimum_norm/tests/test_inverse.py
@@ -1377,7 +1377,7 @@ def test_apply_inverse_tfr(return_generator):
times = np.arange(sfreq) / sfreq # make epochs 1s long
data = rng.random((n_epochs, len(info.ch_names), freqs.size, times.size))
data = data + 1j * data # make complex to simulate amplitude + phase
- epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs)
+ epochs_tfr = EpochsTFR(dict(info=info, data=data, times=times, freqs=freqs))
epochs_tfr.apply_baseline((0, 0.5))
pick_ori = "vector"
diff --git a/mne/time_frequency/_stockwell.py b/mne/time_frequency/_stockwell.py
index 708cfd34e53..002c9627bee 100644
--- a/mne/time_frequency/_stockwell.py
+++ b/mne/time_frequency/_stockwell.py
@@ -302,17 +302,28 @@ def tfr_stockwell(
)
times = inst.times[decim].copy()
nave = len(data)
- out = AverageTFR(info, power, times, freqs, nave, method="stockwell-power")
+ out = AverageTFR(
+ dict(
+ info=info,
+ data=power,
+ times=times,
+ freqs=freqs,
+ nave=nave,
+ method="stockwell-power",
+ )
+ )
if return_itc:
out = (
out,
AverageTFR(
- deepcopy(info),
- itc,
- times.copy(),
- freqs.copy(),
- nave,
- method="stockwell-itc",
+ dict(
+ info=deepcopy(info),
+ data=itc,
+ times=times.copy(),
+ freqs=freqs.copy(),
+ nave=nave,
+ method="stockwell-itc",
+ )
),
)
return out
diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py
index 095199163d2..47154e70c71 100644
--- a/mne/time_frequency/spectrum.py
+++ b/mne/time_frequency/spectrum.py
@@ -50,6 +50,7 @@
from ..viz.utils import (
_format_units_psd,
_get_plot_ch_type,
+ _make_combine_callable,
_plot_psd,
_prepare_sensor_names,
plt_show,
@@ -1395,12 +1396,11 @@ def average(self, method="mean"):
spectrum : instance of Spectrum
The aggregated spectrum object.
"""
- # TODO: we probably should avoid `np.median` here when data are complex
- # (like we do in EpochsTFR.average)
# TODO: probably should have a `.nave` attribute?
- if isinstance(method, str):
- method = getattr(np, method) # mean, median, std, etc
- method = partial(method, axis=0)
+ _validate_type(method, ("str", "callable"))
+ method = _make_combine_callable(
+ method, axis=0, valid=("mean", "median"), keepdims=False
+ )
if not callable(method):
raise ValueError(
'"method" must be a valid string or callable, '
diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py
index b030124dfa0..846e0b6b061 100644
--- a/mne/time_frequency/tests/test_tfr.py
+++ b/mne/time_frequency/tests/test_tfr.py
@@ -1727,6 +1727,7 @@ def test_tfr_arithmetic(epochs):
@parametrize_tfr_inst
def test_tfr_save_load(inst, average_tfr, request, tmp_path):
"""Test TFR I/O."""
+ pytest.importorskip("h5io")
tfr = _get_inst(inst, request, average_tfr=average_tfr)
fname = tmp_path / "temp_tfr.hdf5"
tfr.save(fname, overwrite=True)
diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py
index c303bedb4db..cdae28afa95 100644
--- a/mne/time_frequency/tfr.py
+++ b/mne/time_frequency/tfr.py
@@ -1173,9 +1173,19 @@ def __init__(
for k, v in dict(method=method, freqs=freqs).items()
if v is None
]
- class_name = inspect.currentframe().f_back.f_code.co_qualname.split(".")[0]
+ # TODO when py3.11 is min version, replace if/elif/else block with
+ # classname = inspect.currentframe().f_back.f_code.co_qualname.split(".")[0]
+ _varnames = inspect.currentframe().f_back.f_code.co_varnames
+ if "BaseRaw" in _varnames:
+ classname = "RawTFR"
+ elif "Evoked" in _varnames:
+ classname = "AverageTFR"
+ else:
+ assert "BaseEpochs" in _varnames and "Evoked" not in _varnames
+ classname = "EpochsTFR"
+ # end TODO
raise ValueError(
- f'{class_name} got unsupported parameter value{_pl(problem)} '
+ f'{classname} got unsupported parameter value{_pl(problem)} '
f'{" and ".join(problem)}.'
)
# check method
@@ -1218,23 +1228,16 @@ def __init__(
self.preload = True # needed for __getitem__, never False
self._method = method
# self._dims may also get updated by child classes
- self._dims = (
- "channel",
- "freq",
- "time",
- )
- # get the instance data. `_time_mask` is later deleted inside _compute_tfr,
- # when decim is applied.
- self._time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq)
- get_instance_data_kw = (
- dict()
- if reject_by_annotation is None
- else dict(reject_by_annotation=reject_by_annotation)
- )
+ self._dims = ("channel", "freq", "time")
+ # get the instance data.
+ time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq)
+ get_instance_data_kw = dict(time_mask=time_mask)
+ if reject_by_annotation is not None:
+ get_instance_data_kw.update(reject_by_annotation=reject_by_annotation)
data = self._get_instance_data(**get_instance_data_kw)
# compute the TFR
self._decim = _check_decim(decim)
- self._raw_times = inst.times[self._time_mask]
+ self._raw_times = inst.times[time_mask]
self._compute_tfr(data, n_jobs, verbose)
self._update_epoch_attributes()
# "apply" decim to the rest of the object (data is decimated in _compute_tfr)
@@ -1250,7 +1253,6 @@ def __init__(
# we don't need these anymore, and they make save/load harder
del self._picks
del self._tfr_func
- del self._time_mask
del self._shape # calculated from self._data henceforth
del self.inst # save memory
@@ -1392,18 +1394,7 @@ def __setstate__(self, state):
inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class)
self._inst_type = inst_types[defaults["inst_type_str"]]
# sanity check data/freqs/times/info agreement
- msg = "{} axis of data ({}) doesn't match {} attribute ({})"
- n_chan_info = len(self.info["chs"])
- n_chan, n_freq, n_time = self._data.shape[self._dims.index("channel") :]
- if n_chan_info != n_chan:
- msg = msg.format("Channel", n_chan, "info", n_chan_info)
- elif n_freq != self.freqs.size:
- msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size)
- elif n_time != self.times.size:
- msg = msg.format("Time", n_time, "times", self.times.size)
- else:
- return
- raise ValueError(msg)
+ self._check_state()
def __repr__(self):
"""Build string representation of the TFR object."""
@@ -1427,9 +1418,10 @@ def _repr_html_(self, caption=None):
from ..html_templates import repr_templates_env
inst_type_str = _get_instance_type_string(self)
+ nave = getattr(self, "nave", 0)
units = [f"{ch_type}: {unit}" for ch_type, unit in self.units().items()]
t = repr_templates_env.get_template("tfr.html.jinja")
- t = t.render(tfr=self, inst_type=inst_type_str, units=units)
+ t = t.render(tfr=self, inst_type=inst_type_str, units=units, nave=nave)
return t
def _check_compatibility(self, other):
@@ -1454,6 +1446,21 @@ def _check_compatibility(self, other):
return
raise RuntimeError(msg.format(problem, extra))
+ def _check_state(self):
+ """Check data/freqs/times/info agreement during __setstate__."""
+ msg = "{} axis of data ({}) doesn't match {} attribute ({})"
+ n_chan_info = len(self.info["chs"])
+ n_chan, n_freq, n_time = self._data.shape[self._dims.index("channel") :]
+ if n_chan_info != n_chan:
+ msg = msg.format("Channel", n_chan, "info", n_chan_info)
+ elif n_freq != self.freqs.size:
+ msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size)
+ elif n_time != self.times.size:
+ msg = msg.format("Time", n_time, "times", self.times.size)
+ else:
+ return
+ raise ValueError(msg)
+
def _check_values(self, wants_complex=False):
"""Check TFR results for correct shape and bad values."""
assert len(self._dims) == self._data.ndim
@@ -1897,13 +1904,6 @@ def plot(
figs : list of instances of matplotlib.figure.Figure
A list of figures containing the time-frequency power.
"""
- # triage EpochsTFR
- if isinstance(self, EpochsTFR) and self.shape[0] > 1:
- raise NotImplementedError(
- "Plotting EpochsTFR objects containing more than one epoch is not "
- "supported; either plot an average `EpochsTFR.average().plot()` or "
- "plot individual epochs `EpochsTFR[0].plot()`."
- )
# the rectangle selector plots topomaps, which needs all channels uncombined,
# so we keep a reference to that state here, and update it with `comment` and
# `nave` values in case we started out with a singleton EpochsTFR or RawTFR
@@ -2784,11 +2784,11 @@ def nave(self):
def nave(self, nave):
self._nave = nave
- def _get_instance_data(self):
+ def _get_instance_data(self, time_mask):
# AverageTFRs can be constructed from Epochs data, so we triage shape here.
# Evoked data get a fake singleton "epoch" axis prepended
dim = slice(None) if _get_instance_type_string(self) == "Epochs" else np.newaxis
- data = self.inst.get_data(picks=self._picks)[dim, :, self._time_mask]
+ data = self.inst.get_data(picks=self._picks)[dim, :, time_mask]
self._nave = getattr(self.inst, "nave", data.shape[0])
return data
@@ -2957,14 +2957,15 @@ def __next__(self, return_event_id=False):
def _check_singleton(self):
"""Check if self contains only one Epoch, and return it as an AverageTFR."""
if self.shape[0] > 1:
+ calling_func = inspect.currentframe().f_back.f_code.co_name
raise NotImplementedError(
- "Cannot plot topomap for multiple EpochsTFR epochs; please subselect a"
- "single epoch before plotting."
+ f"Cannot call {calling_func}() from EpochsTFR with multiple epochs; "
+ "please subselect a single epoch before plotting."
)
return list(self.iter_evoked())[0]
- def _get_instance_data(self):
- return self.inst.get_data(picks=self._picks)[:, :, self._time_mask]
+ def _get_instance_data(self, time_mask):
+ return self.inst.get_data(picks=self._picks)[:, :, time_mask]
def _update_epoch_attributes(self):
# adjust dims and shape
@@ -3554,8 +3555,8 @@ def __getitem__(self, item):
self._parse_get_set_params = partial(BaseRaw._parse_get_set_params, self)
return BaseRaw._getitem(self, item, return_times=False)
- def _get_instance_data(self, reject_by_annotation):
- start, stop = np.where(self._time_mask)[0][[0, -1]]
+ def _get_instance_data(self, time_mask, reject_by_annotation):
+ start, stop = np.where(time_mask)[0][[0, -1]]
rba = "NaN" if reject_by_annotation else None
data = self.inst.get_data(
self._picks, start, stop + 1, reject_by_annotation=rba
diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py
index 97622dabde4..24306fafdd0 100644
--- a/mne/utils/mixin.py
+++ b/mne/utils/mixin.py
@@ -764,6 +764,7 @@ def _prepare_read_metadata(metadata):
assert isinstance(metadata, list)
if pd:
metadata = pd.DataFrame.from_records(metadata)
- metadata.set_index("index", inplace=True)
+ if "index" in metadata.columns:
+ metadata.set_index("index", inplace=True)
assert isinstance(metadata, pd.DataFrame)
return metadata
diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py
index 5830c647edb..fa99424f261 100644
--- a/mne/viz/tests/test_topo.py
+++ b/mne/viz/tests/test_topo.py
@@ -309,7 +309,15 @@ def test_plot_tfr_topo():
data = np.random.RandomState(0).randn(
len(epochs.ch_names), n_freqs, len(epochs.times)
)
- tfr = AverageTFR(epochs.info, data, epochs.times, np.arange(n_freqs), nave)
+ tfr = AverageTFR(
+ dict(
+ info=epochs.info,
+ data=data,
+ times=epochs.times,
+ freqs=np.arange(n_freqs),
+ nave=nave,
+ )
+ )
plt.close("all")
fig = tfr.plot_topo(
baseline=(None, 0), mode="ratio", title="Average power", vmin=0.0, vmax=14.0
diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py
index 2774e198fe8..d816f2b3c92 100644
--- a/mne/viz/tests/test_topomap.py
+++ b/mne/viz/tests/test_topomap.py
@@ -578,7 +578,15 @@ def test_plot_tfr_topomap():
data = rng.randn(len(picks), n_freqs, len(times))
# test complex numbers
- tfr = AverageTFR(info, data * (1 + 1j), times, np.arange(n_freqs), nave)
+ tfr = AverageTFR(
+ dict(
+ info=info,
+ data=data * (1 + 1j),
+ times=times,
+ freqs=np.arange(n_freqs),
+ nave=nave,
+ )
+ )
tfr.plot_topomap(
ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0
)
diff --git a/mne/viz/topo.py b/mne/viz/topo.py
index 1751c7efa57..ae95b294df4 100644
--- a/mne/viz/topo.py
+++ b/mne/viz/topo.py
@@ -428,7 +428,6 @@ def _imshow_tfr(
cnorm=None,
):
"""Show time-frequency map as two-dimensional image."""
- from matplotlib import pyplot as plt
from matplotlib.widgets import RectangleSelector
_check_option("yscale", yscale, ["auto", "linear", "log"])
@@ -460,7 +459,7 @@ def _imshow_tfr(
if isinstance(colorbar, DraggableColorbar):
cbar = colorbar.cbar # this happens with multiaxes case
else:
- cbar = plt.colorbar(mappable=img, ax=ax)
+ cbar = ax.get_figure().colorbar(mappable=img, ax=ax)
if interactive_cmap:
ax.CB = DraggableColorbar(cbar, img, kind="tfr_image", ch_type=None)
ax.RS = RectangleSelector(ax, onselect=onselect) # reference must be kept
diff --git a/mne/viz/utils.py b/mne/viz/utils.py
index a83dcf2823f..7dec170b621 100644
--- a/mne/viz/utils.py
+++ b/mne/viz/utils.py
@@ -2355,7 +2355,7 @@ def _make_combine_callable(
"""
kwargs = dict(axis=axis, keepdims=keepdims)
if combine is None:
- combine = _identity_function
+ combine = _identity_function if keepdims else partial(np.squeeze, axis=axis)
elif isinstance(combine, str):
combine_dict = {
key: partial(getattr(np, key), **kwargs)
@@ -2366,9 +2366,10 @@ def _make_combine_callable(
if "median" in valid:
combine_dict["median"] = partial(_median_complex, axis=axis)
# RMS and GFP are computed the same way
- for key in ("gfp", "rms"):
- if key in valid:
- combine_dict[key] = lambda data: np.sqrt((data**2).mean(**kwargs))
+ if "rms" in valid:
+ combine_dict["rms"] = lambda data: np.sqrt((data**2).mean(**kwargs))
+ if "gfp" in valid:
+ combine_dict["gfp"] = lambda data: data.std(axis=axis, ddof=0)
try:
combine = combine_dict[combine]
except KeyError: