Skip to content

Commit

Permalink
Add weights to AverageTFR
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Dec 10, 2024
1 parent b14a100 commit 972aba2
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2785,6 +2785,7 @@ class AverageTFR(BaseTFR):
%(nave_tfr_attr)s
%(sfreq_tfr_attr)s
%(shape_tfr_attr)s
%(weights_tfr_attr)s
See Also
--------
Expand Down Expand Up @@ -2901,10 +2902,15 @@ def __getstate__(self):

def __setstate__(self, state):
"""Unpack AverageTFR from serialized format."""
if state["data"].ndim != 3:
raise ValueError(f"RawTFR data should be 3D, got {state['data'].ndim}.")
if state["data"].ndim not in [3, 4]:
raise ValueError(
f"RawTFR data should be 3D or 4D, got {state['data'].ndim}."
)
# Set dims now since optional tapers makes it difficult to disentangle later
state["dims"] = ("channel", "freq", "time")
state["dims"] = ("channel",)
if state["data"].ndim == 4:
state["dims"] += ("taper",)
state["dims"] += ("freq", "time")
super().__setstate__(state)
self._comment = state.get("comment", "")
self._nave = state.get("nave", 1)
Expand Down Expand Up @@ -2948,6 +2954,7 @@ class AverageTFRArray(AverageTFR):
The number of averaged TFRs.
%(comment_averagetfr_attr)s
%(method_tfr_array)s
%(weights_tfr_array)s
Attributes
----------
Expand All @@ -2960,6 +2967,7 @@ class AverageTFRArray(AverageTFR):
%(nave_tfr_attr)s
%(sfreq_tfr_attr)s
%(shape_tfr_attr)s
%(weights_tfr_attr)s
See Also
--------
Expand All @@ -2970,12 +2978,22 @@ class AverageTFRArray(AverageTFR):
"""

def __init__(
self, info, data, times, freqs, *, nave=None, comment=None, method=None
self,
info,
data,
times,
freqs,
*,
nave=None,
comment=None,
method=None,
weights=None,
):
state = dict(info=info, data=data, times=times, freqs=freqs)
for name, optional in dict(nave=nave, comment=comment, method=method).items():
if optional is not None:
state[name] = optional
optional = dict(nave=nave, comment=comment, method=method, weights=weights)
for name, value in optional.items():
if value is not None:
state[name] = value
self.__setstate__(state)


Expand Down

0 comments on commit 972aba2

Please sign in to comment.