diff --git a/mne/beamformer/tests/test_dics.py b/mne/beamformer/tests/test_dics.py index 1daaaf17eb0..3f2d9f939cb 100644 --- a/mne/beamformer/tests/test_dics.py +++ b/mne/beamformer/tests/test_dics.py @@ -727,7 +727,7 @@ def test_apply_dics_tfr(return_generator): data = rng.random((n_epochs, n_chans, len(freqs), n_times)) data *= 1e-6 data = data + data * 1j # add imag. component to simulate phase - epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs) + epochs_tfr = EpochsTFR(dict(info=info, data=data, times=times, freqs=freqs)) # Create a DICS beamformer and convert the EpochsTFR to source space. csd = csd_tfr(epochs_tfr) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 8db7acfbc61..5c67fef4272 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -607,7 +607,6 @@ def drop_channels(self, ch_names, on_missing="raise"): def _pick_drop_channels(self, idx, *, verbose=None): # avoid circular imports from ..io import BaseRaw - from ..time_frequency import AverageTFR, EpochsTFR msg = "adding, dropping, or reordering channels" if isinstance(self, BaseRaw): @@ -634,8 +633,6 @@ def _pick_drop_channels(self, idx, *, verbose=None): if hasattr(self, "_dims"): # Spectrum and "new-style" TFRs axis = self._dims.index("channel") - elif isinstance(self, (AverageTFR, EpochsTFR)): # "old-style" TFRs - axis = -3 else: # All others (Evoked, Epochs, Raw) have chs axis=-2 axis = -2 if hasattr(self, "_data"): # skip non-preloaded Raw diff --git a/mne/epochs.py b/mne/epochs.py index 756a9799c86..4c7cb1732fd 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -2602,8 +2602,9 @@ def compute_tfr( state["data"] = out._itc state["data_type"] = "Inter-trial coherence" itc = AverageTFR(state, method=None, freqs=None) + del out._itc return out, itc - del out._itc # if it's None, don't keep it around + del out._itc return out # now handle average=False return EpochsTFR( diff --git a/mne/html_templates/repr/tfr.html.jinja b/mne/html_templates/repr/tfr.html.jinja index ae3fba06bcd..616a1899f80 100644 --- a/mne/html_templates/repr/tfr.html.jinja +++ b/mne/html_templates/repr/tfr.html.jinja @@ -21,6 +21,12 @@ {{ tfr.shape[0] }} {% endif -%} + {%- inst_type == "Evoked" %} + + Number of averaged trials + {{ nave }} + + {% endif -%} Dims {{ tfr._dims | join(", ") }} diff --git a/mne/minimum_norm/tests/test_inverse.py b/mne/minimum_norm/tests/test_inverse.py index 58722a19fd5..23e3822d524 100644 --- a/mne/minimum_norm/tests/test_inverse.py +++ b/mne/minimum_norm/tests/test_inverse.py @@ -1377,7 +1377,7 @@ def test_apply_inverse_tfr(return_generator): times = np.arange(sfreq) / sfreq # make epochs 1s long data = rng.random((n_epochs, len(info.ch_names), freqs.size, times.size)) data = data + 1j * data # make complex to simulate amplitude + phase - epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs) + epochs_tfr = EpochsTFR(dict(info=info, data=data, times=times, freqs=freqs)) epochs_tfr.apply_baseline((0, 0.5)) pick_ori = "vector" diff --git a/mne/time_frequency/_stockwell.py b/mne/time_frequency/_stockwell.py index 708cfd34e53..002c9627bee 100644 --- a/mne/time_frequency/_stockwell.py +++ b/mne/time_frequency/_stockwell.py @@ -302,17 +302,28 @@ def tfr_stockwell( ) times = inst.times[decim].copy() nave = len(data) - out = AverageTFR(info, power, times, freqs, nave, method="stockwell-power") + out = AverageTFR( + dict( + info=info, + data=power, + times=times, + freqs=freqs, + nave=nave, + method="stockwell-power", + ) + ) if return_itc: out = ( out, AverageTFR( - deepcopy(info), - itc, - times.copy(), - freqs.copy(), - nave, - method="stockwell-itc", + dict( + info=deepcopy(info), + data=itc, + times=times.copy(), + freqs=freqs.copy(), + nave=nave, + method="stockwell-itc", + ) ), ) return out diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 095199163d2..47154e70c71 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -50,6 +50,7 @@ from ..viz.utils import ( _format_units_psd, _get_plot_ch_type, + _make_combine_callable, _plot_psd, _prepare_sensor_names, plt_show, @@ -1395,12 +1396,11 @@ def average(self, method="mean"): spectrum : instance of Spectrum The aggregated spectrum object. """ - # TODO: we probably should avoid `np.median` here when data are complex - # (like we do in EpochsTFR.average) # TODO: probably should have a `.nave` attribute? - if isinstance(method, str): - method = getattr(np, method) # mean, median, std, etc - method = partial(method, axis=0) + _validate_type(method, ("str", "callable")) + method = _make_combine_callable( + method, axis=0, valid=("mean", "median"), keepdims=False + ) if not callable(method): raise ValueError( '"method" must be a valid string or callable, ' diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index b030124dfa0..846e0b6b061 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -1727,6 +1727,7 @@ def test_tfr_arithmetic(epochs): @parametrize_tfr_inst def test_tfr_save_load(inst, average_tfr, request, tmp_path): """Test TFR I/O.""" + pytest.importorskip("h5io") tfr = _get_inst(inst, request, average_tfr=average_tfr) fname = tmp_path / "temp_tfr.hdf5" tfr.save(fname, overwrite=True) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index c303bedb4db..cdae28afa95 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1173,9 +1173,19 @@ def __init__( for k, v in dict(method=method, freqs=freqs).items() if v is None ] - class_name = inspect.currentframe().f_back.f_code.co_qualname.split(".")[0] + # TODO when py3.11 is min version, replace if/elif/else block with + # classname = inspect.currentframe().f_back.f_code.co_qualname.split(".")[0] + _varnames = inspect.currentframe().f_back.f_code.co_varnames + if "BaseRaw" in _varnames: + classname = "RawTFR" + elif "Evoked" in _varnames: + classname = "AverageTFR" + else: + assert "BaseEpochs" in _varnames and "Evoked" not in _varnames + classname = "EpochsTFR" + # end TODO raise ValueError( - f'{class_name} got unsupported parameter value{_pl(problem)} ' + f'{classname} got unsupported parameter value{_pl(problem)} ' f'{" and ".join(problem)}.' ) # check method @@ -1218,23 +1228,16 @@ def __init__( self.preload = True # needed for __getitem__, never False self._method = method # self._dims may also get updated by child classes - self._dims = ( - "channel", - "freq", - "time", - ) - # get the instance data. `_time_mask` is later deleted inside _compute_tfr, - # when decim is applied. - self._time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq) - get_instance_data_kw = ( - dict() - if reject_by_annotation is None - else dict(reject_by_annotation=reject_by_annotation) - ) + self._dims = ("channel", "freq", "time") + # get the instance data. + time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq) + get_instance_data_kw = dict(time_mask=time_mask) + if reject_by_annotation is not None: + get_instance_data_kw.update(reject_by_annotation=reject_by_annotation) data = self._get_instance_data(**get_instance_data_kw) # compute the TFR self._decim = _check_decim(decim) - self._raw_times = inst.times[self._time_mask] + self._raw_times = inst.times[time_mask] self._compute_tfr(data, n_jobs, verbose) self._update_epoch_attributes() # "apply" decim to the rest of the object (data is decimated in _compute_tfr) @@ -1250,7 +1253,6 @@ def __init__( # we don't need these anymore, and they make save/load harder del self._picks del self._tfr_func - del self._time_mask del self._shape # calculated from self._data henceforth del self.inst # save memory @@ -1392,18 +1394,7 @@ def __setstate__(self, state): inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class) self._inst_type = inst_types[defaults["inst_type_str"]] # sanity check data/freqs/times/info agreement - msg = "{} axis of data ({}) doesn't match {} attribute ({})" - n_chan_info = len(self.info["chs"]) - n_chan, n_freq, n_time = self._data.shape[self._dims.index("channel") :] - if n_chan_info != n_chan: - msg = msg.format("Channel", n_chan, "info", n_chan_info) - elif n_freq != self.freqs.size: - msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) - elif n_time != self.times.size: - msg = msg.format("Time", n_time, "times", self.times.size) - else: - return - raise ValueError(msg) + self._check_state() def __repr__(self): """Build string representation of the TFR object.""" @@ -1427,9 +1418,10 @@ def _repr_html_(self, caption=None): from ..html_templates import repr_templates_env inst_type_str = _get_instance_type_string(self) + nave = getattr(self, "nave", 0) units = [f"{ch_type}: {unit}" for ch_type, unit in self.units().items()] t = repr_templates_env.get_template("tfr.html.jinja") - t = t.render(tfr=self, inst_type=inst_type_str, units=units) + t = t.render(tfr=self, inst_type=inst_type_str, units=units, nave=nave) return t def _check_compatibility(self, other): @@ -1454,6 +1446,21 @@ def _check_compatibility(self, other): return raise RuntimeError(msg.format(problem, extra)) + def _check_state(self): + """Check data/freqs/times/info agreement during __setstate__.""" + msg = "{} axis of data ({}) doesn't match {} attribute ({})" + n_chan_info = len(self.info["chs"]) + n_chan, n_freq, n_time = self._data.shape[self._dims.index("channel") :] + if n_chan_info != n_chan: + msg = msg.format("Channel", n_chan, "info", n_chan_info) + elif n_freq != self.freqs.size: + msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) + elif n_time != self.times.size: + msg = msg.format("Time", n_time, "times", self.times.size) + else: + return + raise ValueError(msg) + def _check_values(self, wants_complex=False): """Check TFR results for correct shape and bad values.""" assert len(self._dims) == self._data.ndim @@ -1897,13 +1904,6 @@ def plot( figs : list of instances of matplotlib.figure.Figure A list of figures containing the time-frequency power. """ - # triage EpochsTFR - if isinstance(self, EpochsTFR) and self.shape[0] > 1: - raise NotImplementedError( - "Plotting EpochsTFR objects containing more than one epoch is not " - "supported; either plot an average `EpochsTFR.average().plot()` or " - "plot individual epochs `EpochsTFR[0].plot()`." - ) # the rectangle selector plots topomaps, which needs all channels uncombined, # so we keep a reference to that state here, and update it with `comment` and # `nave` values in case we started out with a singleton EpochsTFR or RawTFR @@ -2784,11 +2784,11 @@ def nave(self): def nave(self, nave): self._nave = nave - def _get_instance_data(self): + def _get_instance_data(self, time_mask): # AverageTFRs can be constructed from Epochs data, so we triage shape here. # Evoked data get a fake singleton "epoch" axis prepended dim = slice(None) if _get_instance_type_string(self) == "Epochs" else np.newaxis - data = self.inst.get_data(picks=self._picks)[dim, :, self._time_mask] + data = self.inst.get_data(picks=self._picks)[dim, :, time_mask] self._nave = getattr(self.inst, "nave", data.shape[0]) return data @@ -2957,14 +2957,15 @@ def __next__(self, return_event_id=False): def _check_singleton(self): """Check if self contains only one Epoch, and return it as an AverageTFR.""" if self.shape[0] > 1: + calling_func = inspect.currentframe().f_back.f_code.co_name raise NotImplementedError( - "Cannot plot topomap for multiple EpochsTFR epochs; please subselect a" - "single epoch before plotting." + f"Cannot call {calling_func}() from EpochsTFR with multiple epochs; " + "please subselect a single epoch before plotting." ) return list(self.iter_evoked())[0] - def _get_instance_data(self): - return self.inst.get_data(picks=self._picks)[:, :, self._time_mask] + def _get_instance_data(self, time_mask): + return self.inst.get_data(picks=self._picks)[:, :, time_mask] def _update_epoch_attributes(self): # adjust dims and shape @@ -3554,8 +3555,8 @@ def __getitem__(self, item): self._parse_get_set_params = partial(BaseRaw._parse_get_set_params, self) return BaseRaw._getitem(self, item, return_times=False) - def _get_instance_data(self, reject_by_annotation): - start, stop = np.where(self._time_mask)[0][[0, -1]] + def _get_instance_data(self, time_mask, reject_by_annotation): + start, stop = np.where(time_mask)[0][[0, -1]] rba = "NaN" if reject_by_annotation else None data = self.inst.get_data( self._picks, start, stop + 1, reject_by_annotation=rba diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index 97622dabde4..24306fafdd0 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -764,6 +764,7 @@ def _prepare_read_metadata(metadata): assert isinstance(metadata, list) if pd: metadata = pd.DataFrame.from_records(metadata) - metadata.set_index("index", inplace=True) + if "index" in metadata.columns: + metadata.set_index("index", inplace=True) assert isinstance(metadata, pd.DataFrame) return metadata diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index 5830c647edb..fa99424f261 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -309,7 +309,15 @@ def test_plot_tfr_topo(): data = np.random.RandomState(0).randn( len(epochs.ch_names), n_freqs, len(epochs.times) ) - tfr = AverageTFR(epochs.info, data, epochs.times, np.arange(n_freqs), nave) + tfr = AverageTFR( + dict( + info=epochs.info, + data=data, + times=epochs.times, + freqs=np.arange(n_freqs), + nave=nave, + ) + ) plt.close("all") fig = tfr.plot_topo( baseline=(None, 0), mode="ratio", title="Average power", vmin=0.0, vmax=14.0 diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index 2774e198fe8..d816f2b3c92 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -578,7 +578,15 @@ def test_plot_tfr_topomap(): data = rng.randn(len(picks), n_freqs, len(times)) # test complex numbers - tfr = AverageTFR(info, data * (1 + 1j), times, np.arange(n_freqs), nave) + tfr = AverageTFR( + dict( + info=info, + data=data * (1 + 1j), + times=times, + freqs=np.arange(n_freqs), + nave=nave, + ) + ) tfr.plot_topomap( ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 ) diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 1751c7efa57..ae95b294df4 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -428,7 +428,6 @@ def _imshow_tfr( cnorm=None, ): """Show time-frequency map as two-dimensional image.""" - from matplotlib import pyplot as plt from matplotlib.widgets import RectangleSelector _check_option("yscale", yscale, ["auto", "linear", "log"]) @@ -460,7 +459,7 @@ def _imshow_tfr( if isinstance(colorbar, DraggableColorbar): cbar = colorbar.cbar # this happens with multiaxes case else: - cbar = plt.colorbar(mappable=img, ax=ax) + cbar = ax.get_figure().colorbar(mappable=img, ax=ax) if interactive_cmap: ax.CB = DraggableColorbar(cbar, img, kind="tfr_image", ch_type=None) ax.RS = RectangleSelector(ax, onselect=onselect) # reference must be kept diff --git a/mne/viz/utils.py b/mne/viz/utils.py index a83dcf2823f..7dec170b621 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -2355,7 +2355,7 @@ def _make_combine_callable( """ kwargs = dict(axis=axis, keepdims=keepdims) if combine is None: - combine = _identity_function + combine = _identity_function if keepdims else partial(np.squeeze, axis=axis) elif isinstance(combine, str): combine_dict = { key: partial(getattr(np, key), **kwargs) @@ -2366,9 +2366,10 @@ def _make_combine_callable( if "median" in valid: combine_dict["median"] = partial(_median_complex, axis=axis) # RMS and GFP are computed the same way - for key in ("gfp", "rms"): - if key in valid: - combine_dict[key] = lambda data: np.sqrt((data**2).mean(**kwargs)) + if "rms" in valid: + combine_dict["rms"] = lambda data: np.sqrt((data**2).mean(**kwargs)) + if "gfp" in valid: + combine_dict["gfp"] = lambda data: data.std(axis=axis, ddof=0) try: combine = combine_dict[combine] except KeyError: