diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 1829bb1bb98..52ebcd2c1f3 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -2785,6 +2785,7 @@ class AverageTFR(BaseTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2901,10 +2902,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack AverageTFR from serialized format.""" - if state["data"].ndim != 3: - raise ValueError(f"RawTFR data should be 3D, got {state['data'].ndim}.") + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." + ) # Set dims now since optional tapers makes it difficult to disentangle later - state["dims"] = ("channel", "freq", "time") + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._comment = state.get("comment", "") self._nave = state.get("nave", 1) @@ -2948,6 +2954,7 @@ class AverageTFRArray(AverageTFR): The number of averaged TFRs. %(comment_averagetfr_attr)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -2960,6 +2967,7 @@ class AverageTFRArray(AverageTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2970,12 +2978,22 @@ class AverageTFRArray(AverageTFR): """ def __init__( - self, info, data, times, freqs, *, nave=None, comment=None, method=None + self, + info, + data, + times, + freqs, + *, + nave=None, + comment=None, + method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - for name, optional in dict(nave=nave, comment=comment, method=method).items(): - if optional is not None: - state[name] = optional + optional = dict(nave=nave, comment=comment, method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state)