diff --git a/doc/changes/devel/12910.newfeature.rst b/doc/changes/devel/12910.newfeature.rst new file mode 100644 index 00000000000..d4af832923a --- /dev/null +++ b/doc/changes/devel/12910.newfeature.rst @@ -0,0 +1,3 @@ +Added the option to return taper weights from +:func:`mne.time_frequency.tfr_array_multitaper`, and taper weights are now stored in the +:class:`mne.time_frequency.BaseTFR` objects, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 73a3308685d..0fa48db49d7 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -471,6 +471,7 @@ def tfr_array_multitaper( output="complex", n_jobs=None, *, + return_weights=False, verbose=None, ): """Compute Time-Frequency Representation (TFR) using DPSS tapers. @@ -502,8 +503,14 @@ def tfr_array_multitaper( * ``'itc'`` : inter-trial coherence. * ``'avg_power_itc'`` : average of single trial power and inter-trial coherence across trials. + %(n_jobs)s The parallelization is implemented across channels. + return_weights : bool, default False + If True, return the taper weights. Only applies if ``output='complex'`` or + ``'phase'``. + + .. versionadded:: 1.X.0 %(verbose)s Returns @@ -520,6 +527,9 @@ def tfr_array_multitaper( If ``output`` is ``'avg_power_itc'``, the real values in ``out`` contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and + ``return_weights=True``. See Also -------- @@ -550,6 +560,7 @@ def tfr_array_multitaper( use_fft=use_fft, decim=decim, output=output, + return_weights=return_weights, n_jobs=n_jobs, verbose=verbose, ) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index cd3a97ab90a..62db87f3a83 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -432,17 +432,21 @@ def test_tfr_morlet(): def test_dpsswavelet(): """Test DPSS tapers.""" freqs = np.arange(5, 25, 3) - Ws = _make_dpss( - 1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True + Ws, weights = _make_dpss( + 1000, + freqs=freqs, + n_cycles=freqs / 2.0, + time_bandwidth=4.0, + zero_mean=True, + return_weights=True, ) - assert len(Ws) == 3 # 3 tapers expected + assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected + assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs) # Check that zero mean is true assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5 - assert len(Ws[0]) == len(freqs) # As many wavelets as asked for - @pytest.mark.slowtest def test_tfr_multitaper(): @@ -664,6 +668,17 @@ def test_tfr_io(inst, average_tfr, request, tmp_path): with tfr.info._unlock(): tfr.info["meas_date"] = want assert tfr_loaded == tfr + # test with taper dimension and weights + n_tapers = 3 # anything >= 1 should do + weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs + state = tfr.__getstate__() + state["data"] = np.repeat(np.expand_dims(tfr.data, 2), n_tapers, axis=2) # add dim + state["weights"] = weights # add weights + state["dims"] = ("epoch", "channel", "taper", "freq", "time") # update dims + tfr = EpochsTFR(inst=state) + tfr.save(fname, overwrite=True) + tfr_loaded = read_tfrs(fname) + assert tfr_loaded == tfr # test overwrite with pytest.raises(OSError, match="Destination file exists."): tfr.save(fname, overwrite=False) @@ -722,17 +737,31 @@ def test_average_tfr_init(full_evoked): AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace) -def test_epochstfr_init_errors(epochs_tfr): - """Test __init__ for EpochsTFR.""" - state = epochs_tfr.__getstate__() - with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"): - EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0])) +@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr")) +def test_tfr_init_errors(inst, request, average_tfr): + """Test __init__ for {Raw,Epochs,Average}TFR.""" + # Load data + inst = _get_inst(inst, request, average_tfr=average_tfr) + state = inst.__getstate__() + # Prepare for TFRArray object instantiation + inst_name = inst.__class__.__name__ + class_mapping = dict(RawTFR=RawTFR, EpochsTFR=EpochsTFR, AverageTFR=AverageTFR) + ndims_mapping = dict( + RawTFR=("3D or 4D"), EpochsTFR=("4D or 5D"), AverageTFR=("3D or 4D") + ) + TFR = class_mapping[inst_name] + allowed_ndims = ndims_mapping[inst_name] + # Check errors caught + with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): + TFR(inst=state | dict(data=inst.data[..., 0])) + with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): + TFR(inst=state | dict(data=np.expand_dims(inst.data, axis=(0, 1)))) with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"): - EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1])) + TFR(inst=state | dict(data=inst.data[..., :-1, :, :])) with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"): - EpochsTFR(inst=state | dict(times=epochs_tfr.times[:-1])) + TFR(inst=state | dict(times=inst.times[:-1])) with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"): - EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1])) + TFR(inst=state | dict(freqs=inst.freqs[:-1])) @pytest.mark.parametrize( @@ -1154,6 +1183,15 @@ def test_averaging_epochsTFR(): ): power.average(method=np.mean) + # Check it doesn't run for taper spectra + tapered = epochs.compute_tfr( + method="multitaper", freqs=freqs, n_cycles=n_cycles, output="complex" + ) + with pytest.raises( + NotImplementedError, match=r"Averaging multitaper tapers .* is not supported." + ): + tapered.average() + def test_averaging_freqsandtimes_epochsTFR(): """Test that EpochsTFR averaging freqs methods work.""" @@ -1254,12 +1292,15 @@ def test_to_data_frame(): ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] n_picks = len(ch_names) ch_types = ["eeg"] * n_picks + n_tapers = 2 n_freqs = 5 n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) + data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) + times = np.arange(n_times) srate = 1000.0 - freqs = np.arange(5) + freqs = np.arange(n_freqs) + tapers = np.arange(n_tapers) + weights = np.ones((n_tapers, n_freqs)) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 5 + n_epos) @@ -1272,6 +1313,7 @@ def test_to_data_frame(): freqs=freqs, events=events, event_id=event_id, + weights=weights, ) # test index checking with pytest.raises(ValueError, match="options. Valid index options are"): @@ -1283,10 +1325,21 @@ def test_to_data_frame(): # test wide format df_wide = tfr.to_data_frame() assert all(np.isin(tfr.ch_names, df_wide.columns)) - assert all(np.isin(["time", "condition", "freq", "epoch"], df_wide.columns)) + assert all( + np.isin(["time", "condition", "freq", "epoch", "taper"], df_wide.columns) + ) # test long format df_long = tfr.to_data_frame(long_format=True) - expected = ("condition", "epoch", "freq", "time", "channel", "ch_type", "value") + expected = ( + "condition", + "epoch", + "freq", + "time", + "channel", + "ch_type", + "value", + "taper", + ) assert set(expected) == set(df_long.columns) assert set(tfr.ch_names) == set(df_long["channel"]) assert len(df_long) == tfr.data.size @@ -1294,21 +1347,29 @@ def test_to_data_frame(): df_long = tfr.to_data_frame(long_format=True, index=["freq"]) del df_wide, df_long # test whether data is in correct shape - df = tfr.to_data_frame(index=["condition", "epoch", "freq", "time"]) + df = tfr.to_data_frame(index=["condition", "epoch", "taper", "freq", "time"]) data = tfr.data assert_array_equal(df.values[:, 0], data[:, 0, :, :].reshape(1, -1).squeeze()) # compare arbitrary observation: assert ( - df.loc[("he", slice(None), freqs[1], times[2]), ch_names[3]].iat[0] - == data[1, 3, 1, 2] + df.loc[("he", slice(None), tapers[1], freqs[1], times[2]), ch_names[3]].iat[0] + == data[1, 3, 1, 1, 2] ) # Check also for AverageTFR: + # (remove taper dimension before averaging) + state = tfr.__getstate__() + state["data"] = state["data"][:, :, 0] + state["dims"] = ("epoch", "channel", "freq", "time") + state["weights"] = None + tfr = EpochsTFR(inst=state) tfr = tfr.average() with pytest.raises(ValueError, match="options. Valid index options are"): tfr.to_data_frame(index=["epoch", "condition"]) with pytest.raises(ValueError, match='"epoch" is not a valid option'): tfr.to_data_frame(index="epoch") + with pytest.raises(ValueError, match='"taper" is not a valid option'): + tfr.to_data_frame(index="taper") with pytest.raises(TypeError, match="index must be `None` or a string "): tfr.to_data_frame(index=np.arange(400)) # test wide format @@ -1344,11 +1405,13 @@ def test_to_data_frame_index(index): ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] n_picks = len(ch_names) ch_types = ["eeg"] * n_picks + n_tapers = 2 n_freqs = 5 n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) - freqs = np.arange(5) + data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.ones((n_tapers, n_freqs)) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 8) @@ -1361,6 +1424,7 @@ def test_to_data_frame_index(index): freqs=freqs, events=events, event_id=event_id, + weights=weights, ) df = tfr.to_data_frame(picks=[0, 2, 3], index=index) # test index order/hierarchy preservation @@ -1368,7 +1432,7 @@ def test_to_data_frame_index(index): index = [index] assert list(df.index.names) == index # test that non-indexed data were present as columns - non_index = list(set(["condition", "time", "freq", "epoch"]) - set(index)) + non_index = list(set(["condition", "time", "freq", "taper", "epoch"]) - set(index)) if len(non_index): assert all(np.isin(non_index, df.columns)) @@ -1534,7 +1598,8 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc): def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output): """Test Epochs.compute_tfr(output="complex"/"phase").""" tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output) - assert len(tfr.shape) == 5 + assert len(tfr.shape) == 5 # epoch x channel x taper x freq x time + assert tfr.weights.shape == tfr.shape[2:4] # check weights and coeffs shapes match @pytest.mark.parametrize("copy", (False, True)) @@ -1546,6 +1611,42 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy): assert avgs[0].comment == str(epochs_tfr.events[0, -1]) +@pytest.mark.parametrize("inst", ("raw", "epochs", "evoked")) +def test_tfrarray_tapered_spectra(inst, evoked, request): + """Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra.""" + # Load data object + inst = _get_inst(inst, request, evoked=evoked) + inst.pick("mag") + # Compute TFR with taper dimension (can be complex or phase output) + tfr = inst.compute_tfr( + method="multitaper", freqs=freqs_linspace, n_cycles=4, output="complex" + ) + tfr_array, weights = tfr.get_data(), tfr.weights + # Prepare for TFRArray object instantiation + defaults = dict( + info=inst.info, data=tfr_array, times=inst.times, freqs=freqs_linspace + ) + class_mapping = dict(Raw=RawTFRArray, Epochs=EpochsTFRArray, Evoked=AverageTFRArray) + TFRArray = class_mapping[inst.__class__.__name__] + # Check TFRArray instantiation runs with good data + TFRArray(**defaults, weights=weights) + # Check taper dimension but no weights caught + with pytest.raises( + ValueError, match="Taper dimension in data, but no weights found." + ): + TFRArray(**defaults) + # Check mismatching n_taper in weights caught + with pytest.raises( + ValueError, match=r"Taper axis .* doesn't match weights attribute" + ): + TFRArray(**defaults, weights=weights[:-1]) + # Check mismatching n_freq in weights caught + with pytest.raises( + ValueError, match=r"Frequency axis .* doesn't match weights attribute" + ): + TFRArray(**defaults, weights=weights[:, :-1]) + + def test_tfr_proj(epochs): """Test `compute_tfr(proj=True)`.""" epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True) @@ -1727,3 +1828,39 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request): assert re.match( rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title() ) + + +def test_combine_tfr_error_catch(request, average_tfr): + """Test combine_tfr() catches errors.""" + # check unrecognised weights string caught + with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'): + combine_tfr([average_tfr, average_tfr], weights="foo") + # check bad weights size caught + with pytest.raises(ValueError, match="Weights must be the same size as all_tfr"): + combine_tfr([average_tfr, average_tfr], weights=[1, 1, 1]) + # check different channel names caught + state = average_tfr.__getstate__() + new_info = average_tfr.info.copy() + average_tfr_bad = AverageTFR( + inst=state | dict(info=new_info.rename_channels({new_info.ch_names[0]: "foo"})) + ) + with pytest.raises(AssertionError, match=".* do not contain the same channels"): + combine_tfr([average_tfr, average_tfr_bad]) + # check different times caught + average_tfr_bad = AverageTFR(inst=state | dict(times=average_tfr.times + 1)) + with pytest.raises( + AssertionError, match=".* do not contain the same time instants" + ): + combine_tfr([average_tfr, average_tfr_bad]) + # check taper dim caught + n_tapers = 3 # anything >= 1 should do + weights = np.ones((n_tapers, average_tfr.shape[1])) # tapers x freqs + state["data"] = np.repeat(np.expand_dims(average_tfr.data, 1), n_tapers, axis=1) + state["weights"] = weights + state["dims"] = ("channel", "taper", "freq", "time") + average_tfr_taper = AverageTFR(inst=state) + with pytest.raises( + NotImplementedError, + match="Aggregating multitaper tapers across TFR datasets is not supported.", + ): + combine_tfr([average_tfr_taper, average_tfr_taper]) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index eaf173092bb..49bc15d8833 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -264,8 +264,11 @@ def _make_dpss( ------- Ws : list of array The wavelets time series. + Cs : list of array + The concentration weights. Only returned if return_weights=True. """ Ws = list() + Cs = list() freqs = np.array(freqs) if np.any(freqs <= 0): @@ -281,6 +284,7 @@ def _make_dpss( for m in range(n_taps): Wm = list() + Cm = list() for k, f in enumerate(freqs): if len(n_cycles) != 1: this_n_cycles = n_cycles[k] @@ -302,12 +306,15 @@ def _make_dpss( real_offset = Wk.mean() Wk -= real_offset Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) + Ck = np.sqrt(conc[m]) Wm.append(Wk) + Cm.append(Ck) Ws.append(Wm) + Cs.append(Cm) if return_weights: - return Ws, conc + return Ws, Cs return Ws @@ -428,6 +435,7 @@ def _compute_tfr( use_fft=True, decim=1, output="complex", + return_weights=False, n_jobs=None, *, verbose=None, @@ -479,6 +487,9 @@ def _compute_tfr( * 'avg_power_itc' : average of single trial power and inter-trial coherence across trials. + return_weights : bool, default False + Whether to return the taper weights. Only applies if method='multitaper' and + output='complex' or 'phase'. %(n_jobs)s The number of epochs to process at the same time. The parallelization is implemented across channels. @@ -495,6 +506,9 @@ def _compute_tfr( n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the real values in the ``output`` contain average power' and the imaginary values contain the ITC: ``out = avg_power + i * itc``. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if method='multitaper', output='complex' or + 'phase', and return_weights=True. """ # Check data epoch_data = np.asarray(epoch_data) @@ -516,6 +530,9 @@ def _compute_tfr( decim, output, ) + return_weights = ( + return_weights and method == "multitaper" and output in ["complex", "phase"] + ) decim = _ensure_slice(decim) if (freqs > sfreq / 2.0).any(): @@ -531,13 +548,18 @@ def _compute_tfr( Ws = [W] # to have same dimensionality as the 'multitaper' case elif method == "multitaper": - Ws = _make_dpss( + out = _make_dpss( sfreq, freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth, zero_mean=zero_mean, + return_weights=return_weights, ) + if return_weights: + Ws, weights = out + else: + Ws = out # Check wavelets if len(Ws[0][0]) > epoch_data.shape[2]: @@ -561,6 +583,8 @@ def _compute_tfr( out = np.empty((n_chans, n_freqs, n_times), dtype) elif output in ["complex", "phase"] and method == "multitaper": out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) + if return_weights: + weights = np.array(weights) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) @@ -585,6 +609,9 @@ def _compute_tfr( out = out.transpose(2, 0, 1, 3, 4) else: out = out.transpose(1, 0, 2, 3) + + if return_weights: + return out, weights return out @@ -1187,9 +1214,6 @@ def __init__( f'{classname} got unsupported parameter value{_pl(problem)} ' f'{" and ".join(problem)}.' ) - # shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release) - if method == "morlet": - method_kw.setdefault("zero_mean", True) # check method valid_methods = ["morlet", "multitaper"] if isinstance(inst, BaseEpochs): @@ -1203,6 +1227,9 @@ def __init__( method_kw.setdefault("output", "power") self._freqs = np.asarray(freqs, dtype=np.float64) del freqs + # always store weights for per-taper outputs + if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]: + method_kw["return_weights"] = True # check validity of kwargs manually to save compute time if any are invalid tfr_funcs = dict( morlet=tfr_array_morlet, @@ -1224,6 +1251,7 @@ def __init__( self._method = method self._inst_type = type(inst) self._baseline = None + self._weights = None self.preload = True # needed for __getitem__, never False for TFRs # self._dims may also get updated by child classes self._dims = ["channel", "freq", "time"] @@ -1382,6 +1410,7 @@ def __getstate__(self): info=self.info, baseline=self._baseline, decim=self._decim, + weights=self._weights, ) def __setstate__(self, state): @@ -1392,7 +1421,6 @@ def __setstate__(self, state): defaults = dict( method="unknown", - dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :], baseline=None, decim=1, data_type="TFR", @@ -1410,12 +1438,13 @@ def __setstate__(self, state): self._decim = defaults["decim"] self.preload = True self._set_times(self._raw_times) + self._weights = state.get("weights") # objs saved before #12910 won't have # Handle instance type. Prior to gh-11282, Raw was not a possibility so if # `inst_type_str` is missing it must be Epochs or Evoked unknown_class = Epochs if "epoch" in self._dims else Evoked inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class) self._inst_type = inst_types[defaults["inst_type_str"]] - # sanity check data/freqs/times/info agreement + # sanity check data/freqs/times/info/weights agreement self._check_state() def __repr__(self): @@ -1468,18 +1497,29 @@ def _check_compatibility(self, other): raise RuntimeError(msg.format(problem, extra)) def _check_state(self): - """Check data/freqs/times/info agreement during __setstate__.""" + """Check data/freqs/times/info/weights agreement during __setstate__.""" msg = "{} axis of data ({}) doesn't match {} attribute ({})" n_chan_info = len(self.info["chs"]) n_chan = self._data.shape[self._dims.index("channel")] n_freq = self._data.shape[self._dims.index("freq")] n_time = self._data.shape[self._dims.index("time")] + n_taper = ( + self._data.shape[self._dims.index("taper")] + if "taper" in self._dims + else None + ) + if n_taper is not None and self._weights is None: + raise ValueError("Taper dimension in data, but no weights found.") if n_chan_info != n_chan: msg = msg.format("Channel", n_chan, "info", n_chan_info) elif n_freq != len(self.freqs): msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) elif n_time != len(self.times): msg = msg.format("Time", n_time, "times", self.times.size) + elif n_taper is not None and n_taper != self._weights.shape[0]: + msg = msg.format("Taper", n_taper, "weights", self._weights.shape[0]) + elif n_taper is not None and n_freq != self._weights.shape[1]: + msg = msg.format("Frequency", n_freq, "weights", self._weights.shape[1]) else: return raise ValueError(msg) @@ -1516,6 +1556,10 @@ def _compute_tfr(self, data, n_jobs, verbose): if self.method == "stockwell": self._data, self._itc, freqs = result assert np.array_equal(self._freqs, freqs) + elif self.method == "multitaper" and self._tfr_func.keywords.get( + "output", "" + ) in ["complex", "phase"]: + self._data, self._weights = result elif self._tfr_func.keywords.get("output", "").endswith("_itc"): self._data, self._itc = result.real, result.imag else: @@ -1694,6 +1738,11 @@ def times(self): """The time points present in the data (in seconds).""" return self._times_readonly + @property + def weights(self): + """The weights used for each taper in the time-frequency estimates.""" + return self._weights + @fill_doc def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): """Crop data to a given time interval in place. @@ -1788,6 +1837,7 @@ def get_data( tmax=None, return_times=False, return_freqs=False, + return_tapers=False, ): """Get time-frequency data in NumPy array format. @@ -1803,6 +1853,10 @@ def get_data( return_freqs : bool Whether to return the frequency bin values for the requested frequency range. Default is ``False``. + return_tapers : bool + Whether to return the taper numbers. Default is ``False``. + + .. versionadded:: 1.X.0 Returns ------- @@ -1814,6 +1868,9 @@ def get_data( freqs : array The frequency values for the requested data range. Only returned if ``return_freqs`` is ``True``. + tapers : array | None + The taper numbers. Only returned if ``return_tapers`` is ``True``. Will be + ``None`` if a taper dimension is not present in the data. Notes ----- @@ -1851,7 +1908,13 @@ def get_data( if return_freqs: freqs = self._freqs[fmin_idx:fmax_idx] out.append(freqs) - if not return_times and not return_freqs: + if return_tapers: + if "taper" in self._dims: + tapers = np.arange(self.shape[self._dims.index("taper")]) + else: + tapers = None + out.append(tapers) + if not return_times and not return_freqs and not return_tapers: return out[0] return tuple(out) @@ -2627,21 +2690,21 @@ def to_data_frame( ): """Export data in tabular structure as a pandas DataFrame. - Channels are converted to columns in the DataFrame. By default, - additional columns ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` (epoch event description) are added, unless ``index`` - is not ``None`` (in which case the columns specified in ``index`` will - be used to form the DataFrame's index instead). ``'epoch'``, and - ``'condition'`` are not supported for ``AverageTFR``. + Channels are converted to columns in the DataFrame. By default, additional + columns ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, and ``'condition'`` + (epoch event description) are added, unless ``index`` is not ``None`` (in which + case the columns specified in ``index`` will be used to form the DataFrame's + index instead). ``'epoch'``, and ``'condition'`` are not supported for + ``AverageTFR``. ``'taper'`` is only supported when a taper dimensions is + present, such as for complex or phase multitaper data. Parameters ---------- %(picks_all)s %(index_df_epo)s - Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'`` - for ``AverageTFR``. - Defaults to ``None``. + Valid string values are ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, + and ``'condition'`` for ``EpochsTFR`` and ``'time'``, ``'freq'``, and + ``'taper'`` for ``AverageTFR``. Defaults to ``None``. %(long_format_df_epo)s %(time_format_df)s @@ -2654,42 +2717,58 @@ def to_data_frame( """ # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa + # triage for Epoch-derived or unaggregated spectra + from_epo = isinstance(self, EpochsTFR) + unagg_mt = "taper" in self._dims # arg checking valid_index_args = ["time", "freq"] - if isinstance(self, EpochsTFR): + if from_epo: valid_index_args.extend(["epoch", "condition"]) + if unagg_mt: + valid_index_args.append("taper") valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) time_format = _check_time_format(time_format, valid_time_formats) # get data picks = _picks_to_idx(self.info, picks, "all", exclude=()) - data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) - axis = self._dims.index("channel") - if not isinstance(self, EpochsTFR): + data, times, freqs, tapers = self.get_data( + picks, return_times=True, return_freqs=True, return_tapers=True + ) + ch_axis = self._dims.index("channel") + if not from_epo: data = data[np.newaxis] # add singleton "epochs" axis - axis += 1 - n_epochs, n_picks, n_freqs, n_times = data.shape - # reshape to (epochs*freqs*times) x signals - data = np.moveaxis(data, axis, -1) - data = data.reshape(n_epochs * n_freqs * n_times, n_picks) + ch_axis += 1 + if not unagg_mt: + data = np.expand_dims(data, -3) # add singleton "tapers" axis + n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape + # reshape to (epochs*tapers*freqs*times) x signals + data = np.moveaxis(data, ch_axis, -1) + data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks) # prepare extra columns / multiindex mindex = list() + default_index = list() times = _convert_times(times, time_format, self.info["meas_date"]) - times = np.tile(times, n_epochs * n_freqs) - freqs = np.tile(np.repeat(freqs, n_times), n_epochs) + times = np.tile(times, n_epochs * n_freqs * n_tapers) + freqs = np.tile(np.repeat(freqs, n_times), n_epochs * n_tapers) mindex.append(("time", times)) mindex.append(("freq", freqs)) - if isinstance(self, EpochsTFR): - mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) + if from_epo: + mindex.append( + ("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers)) + ) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) + mindex.append( + ("condition", np.repeat(conditions, n_times * n_freqs * n_tapers)) + ) + default_index.extend(["condition", "epoch"]) + if unagg_mt: + tapers = np.repeat(np.tile(tapers, n_epochs), n_freqs * n_times) + mindex.append(("taper", tapers)) + default_index.append("taper") + default_index.extend(["freq", "time"]) assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) # build DataFrame - if isinstance(self, EpochsTFR): - default_index = ["condition", "epoch", "freq", "time"] - else: - default_index = ["freq", "time"] df = _build_data_frame( self, data, picks, long_format, mindex, index, default_index=default_index ) @@ -2736,6 +2815,7 @@ class AverageTFR(BaseTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2852,6 +2932,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack AverageTFR from serialized format.""" + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._comment = state.get("comment", "") self._nave = state.get("nave", 1) @@ -2895,6 +2984,7 @@ class AverageTFRArray(AverageTFR): The number of averaged TFRs. %(comment_averagetfr_attr)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -2907,6 +2997,7 @@ class AverageTFRArray(AverageTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2917,12 +3008,22 @@ class AverageTFRArray(AverageTFR): """ def __init__( - self, info, data, times, freqs, *, nave=None, comment=None, method=None + self, + info, + data, + times, + freqs, + *, + nave=None, + comment=None, + method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - for name, optional in dict(nave=nave, comment=comment, method=method).items(): - if optional is not None: - state[name] = optional + optional = dict(nave=nave, comment=comment, method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) @@ -3007,6 +3108,7 @@ class EpochsTFR(BaseTFR, GetEpochsMixin): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3091,8 +3193,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack EpochsTFR from serialized format.""" - if state["data"].ndim != 4: - raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.") + if state["data"].ndim not in [4, 5]: + raise ValueError( + f"EpochsTFR data should be 4D or 5D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("epoch", "channel") + if state["data"].ndim == 5: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._metadata = state.get("metadata", None) n_epochs = self.shape[0] @@ -3196,7 +3305,16 @@ def average(self, method="mean", *, dim="epochs", copy=False): See discussion here: https://github.com/scipy/scipy/pull/12676#issuecomment-783370228 + + Averaging is not supported for data containing a taper dimension. """ + if "taper" in self._dims: + raise NotImplementedError( + "Averaging multitaper tapers across epochs, frequencies, or times is " + "not supported. If averaging across epochs, consider averaging the " + "epochs before computing the complex/phase spectrum." + ) + _check_option("dim", dim, ("epochs", "freqs", "times")) axis = self._dims.index(dim[:-1]) # self._dims entries aren't plural @@ -3568,6 +3686,7 @@ class EpochsTFRArray(EpochsTFR): %(selection)s %(drop_log)s %(metadata_epochstfr)s + %(weights_tfr_array)s Attributes ---------- @@ -3584,6 +3703,7 @@ class EpochsTFRArray(EpochsTFR): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3606,6 +3726,7 @@ def __init__( selection=None, drop_log=None, metadata=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) optional = dict( @@ -3616,6 +3737,7 @@ def __init__( selection=selection, drop_log=drop_log, metadata=metadata, + weights=weights, ) for name, value in optional.items(): if value is not None: @@ -3658,6 +3780,7 @@ class RawTFR(BaseTFR): method : str The method used to compute the spectra (``'morlet'``, ``'multitaper'`` or ``'stockwell'``). + %(weights_tfr_attr)s See Also -------- @@ -3707,6 +3830,19 @@ def __init__( **method_kw, ) + def __setstate__(self, state): + """Unpack RawTFR from serialized format.""" + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") + super().__setstate__(state) + def __getitem__(self, item): """Get RawTFR data. @@ -3772,6 +3908,7 @@ class RawTFRArray(RawTFR): %(times)s %(freqs_tfr_array)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -3782,6 +3919,7 @@ class RawTFRArray(RawTFR): %(method_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3799,10 +3937,13 @@ def __init__( freqs, *, method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - if method is not None: - state["method"] = method + optional = dict(method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) @@ -3830,8 +3971,16 @@ def combine_tfr(all_tfr, weights="nave"): Notes ----- + Aggregating multitaper TFR datasets with a taper dimension such as for complex or + phase data is not supported. + .. versionadded:: 0.11.0 """ + if any("taper" in tfr._dims for tfr in all_tfr): + raise NotImplementedError( + "Aggregating multitaper tapers across TFR datasets is not supported." + ) + tfr = all_tfr[0].copy() if isinstance(weights, str): if weights not in ("nave", "equal"): diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 2ff6984dce9..5ff7b480ddc 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -5008,6 +5008,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): solution. """ +docdict["weights_tfr_array"] = """ +weights : array, shape (n_tapers, n_freqs) | None + The weights for each taper. Must be provided if ``data`` has a taper dimension, such + as for complex or phase multitaper data. + + .. versionadded:: 1.X.0 +""" +docdict["weights_tfr_attr"] = """ +weights : array, shape (n_tapers, n_freqs) | None + The weights used for each taper in the time-frequency estimates. +""" + docdict["window_psd"] = """\ window : str | float | tuple Windowing function to use. See :func:`scipy.signal.get_window`. diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index c287fb42305..4bf8d094f81 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -550,6 +550,9 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): Notes ----- + Aggregating multitaper TFR datasets with a taper dimension such as for complex or + phase data is not supported. + .. versionadded:: 0.11.0 """ # check if all elements in the given list are evoked data