[ENH] Add option to store and return TFR taper weights #12910
@@ -302,12 +306,15 @@ def _make_dpss(
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])
I am somewhat unsure about this. The existing implementation uses `conc` as-is, but the MNE-Connectivity implementation takes the square root: https://github.com/mne-tools/mne-connectivity/blob/97147a57eefb36a5c9680e539fdc6343a1183f20/mne_connectivity/spectral/time.py#L825
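One way to see what is at stake in the `conc` versus `np.sqrt(conc)` choice: if the weights multiply the complex per-taper coefficients, power is scaled by the square of each weight. The sketch below is only illustrative. The concentration values are made up (they are not output from `_make_dpss`), and `effective_power_weights` is a hypothetical helper, not an MNE function.

```python
import math

# Hypothetical DPSS concentration ratios for three tapers; real values
# would come from the DPSS computation, not from this list.
conc = [0.999, 0.98, 0.85]


def effective_power_weights(amplitude_weights):
    """Return the normalised weights that end up applied to per-taper power.

    Assumes the weights multiply complex coefficients, so power is scaled
    by the square of each weight.
    """
    squared = [w**2 for w in amplitude_weights]
    total = sum(squared)
    return [s / total for s in squared]


# Option A: use the concentrations directly as amplitude weights.
power_w_raw = effective_power_weights(conc)

# Option B: use the square root, as in the MNE-Connectivity code linked above.
power_w_sqrt = effective_power_weights([math.sqrt(c) for c in conc])

print("conc as-is :", [round(w, 4) for w in power_w_raw])
print("sqrt(conc) :", [round(w, 4) for w in power_w_sqrt])
```

Under that assumption, option B weights each taper's power in proportion to its concentration, while option A weights it in proportion to the concentration squared, which down-weights the less concentrated tapers more strongly.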
I'm also somewhat confused about the design of the mne-python/mne/time_frequency/tfr.py Lines 285 to 315 in 82fc2f7
It is looping over tapers, and then over frequencies. However, the Would it not be more efficient to only loop over frequencies and take advantage of the fact that this will also return information for each taper?
I also have a question regarding testing: for the I/O tests, we're reading Apart from this there are still some tests I need to expand.
Thanks for the review @drammock! I will sort out those remaining tests, although I'm in the process of moving at the moment so it might not be for some days. Regarding those issues I came across with TFR multitapers and converting to dataframes / plotting: would you like me to incorporate that into this PR?
Sorry for the lack of work on this, had to organise things for my PhD defence. Everything new added here has test coverage now. @drammock, just a couple points I would appreciate your input on:
Also tagging @larsoner and @ruuskas in case they can help clarify an outstanding point: #12910 (comment)
Currently working on support for a tapers dimension in
Yes I think we should. Most (all?) of them are created by pytest fixtures at present. I see 3 options:
To really test thoroughly, option (2) is probably best, because then you can also patch in things that are expected to fail, and test that they do fail in the expected way.
@@ -1392,7 +1421,6 @@ def __setstate__(self, state):

defaults = dict(
    method="unknown",
    dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :],
Have removed `dims` being set in `BaseTFR` since the possibility of the optional `epoch` and `taper` dimensions makes it really difficult to disentangle here. It's much easier to handle this in the individual `RawTFR`, `EpochsTFR`, and `AverageTFR` classes.
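To illustrate why the number of dimensions alone is no longer enough, here is a minimal sketch. `dims_by_slicing` mimics the removed default, while `dims_by_class` and its `kind` argument are hypothetical names for this example only and are not part of MNE.

```python
BASE_DIMS = ("epoch", "channel", "freq", "time")


def dims_by_slicing(ndim):
    """Removed approach: take the trailing ``ndim`` names of a fixed tuple."""
    return BASE_DIMS[-ndim:]


def dims_by_class(kind, ndim):
    """Hypothetical per-class lookup; ``kind`` is 'raw', 'epochs' or 'average'."""
    has_epoch = kind == "epochs"
    n_required = 3 + int(has_epoch)  # channel, freq, time (+ epoch)
    if ndim not in (n_required, n_required + 1):
        raise ValueError(f"unexpected ndim={ndim} for kind={kind!r}")
    dims = ("epoch",) if has_epoch else ()
    dims += ("channel",)
    if ndim == n_required + 1:  # one extra axis means tapers are present
        dims += ("taper",)
    return dims + ("freq", "time")


# A 4D array is ambiguous without knowing which class it belongs to.
print(dims_by_slicing(4))           # always labelled as epochs data
print(dims_by_class("average", 4))  # averaged data with a taper axis
print(dims_by_class("epochs", 4))   # epochs data without tapers
```

Once the class is known, at most one optional axis remains, so checking `ndim` against the expected count is unambiguous.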
# Set dims now since optional tapers makes it difficult to disentangle later
state["dims"] = ("channel",)
if state["data"].ndim == 4:
    state["dims"] += ("taper",)
state["dims"] += ("freq", "time")
Example of handling `dims` in the `AverageTFR` class where only one dimension (`taper`) is optional.
Averaging is not supported for data containing a taper dimension.
"""
if "taper" in self._dims:
    raise NotImplementedError(
        "Averaging multitaper tapers across epochs, frequencies, or times is "
        "not supported. If averaging across epochs, consider averaging the "
        "epochs before computing the complex/phase spectrum."
    )
In terms of averaging for data with tapers, I went for the same approach we're using for `Spectrum` and just disallowing this.
I don't think this is an API change requiring a deprecation cycle since:
- the docstring expects the data to not have a taper dimension, e.g. "If callable, must take a NumPy array of shape (n_epochs, n_channels, n_freqs, n_times)".
- trying to call this method on an object with a taper dimension would raise an uncaught error: `n_epochs, n_channels, n_freqs, n_times = self.data.shape` (wouldn't be able to unpack this properly).
So explicitly preventing this method being called with a taper dimension doesn't change current behaviour, it just gives a nicer error as to why this can't be done.
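The difference between the two failure modes can be reproduced without MNE. In the sketch below, `FakeEpochsTFR` and its two methods are hypothetical stand-ins that hold only a shape and dimension names; they are not the real `EpochsTFR` implementation.

```python
class FakeEpochsTFR:
    """Hypothetical stand-in holding only a shape and dimension names."""

    def __init__(self, shape, dims):
        self.shape = shape
        self._dims = dims

    def average_unguarded(self):
        # Mirrors the unpacking quoted above; fails if a taper axis is present.
        n_epochs, n_channels, n_freqs, n_times = self.shape
        return (n_channels, n_freqs, n_times)

    def average_guarded(self):
        # Explicit check that explains why averaging is refused.
        if "taper" in self._dims:
            raise NotImplementedError(
                "Averaging is not supported for data with a taper dimension."
            )
        return self.average_unguarded()


tfr = FakeEpochsTFR(
    shape=(10, 4, 3, 20, 100),
    dims=("epoch", "channel", "taper", "freq", "time"),
)

for method in (tfr.average_unguarded, tfr.average_guarded):
    try:
        method()
    except (ValueError, NotImplementedError) as err:
        print(f"{method.__name__}: {type(err).__name__}: {err}")
```

Both calls fail on five-dimensional data; the guarded version only replaces an opaque unpacking error with a message that names the cause.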
Notes
-----
Aggregating multitaper TFR datasets with a taper dimension such as for complex or
phase data is not supported.

.. versionadded:: 0.11.0
"""
if any("taper" in tfr._dims for tfr in all_tfr):
    raise NotImplementedError(
        "Aggregating multitaper tapers across TFR datasets is not supported."
    )
It's a similar case to averaging for the `time_frequency.combine_tfr()` function (which also gets called by the `grand_average()` function).
However, unlike the `EpochsTFR.average()` method, this could be considered an API change since `combine_tfr()` should currently run with taper data. Does preventing this use case require a deprecation cycle?
On a side note, I noticed that while a public function, `combine_tfr()` is not listed in the API (the equivalent `combine_evoked()` is). Is this an oversight or an intended omission?
Reference issue (if any)
PR for #12851
What does this implement/fix?
Adds an option to return taper weights for complex and phase outputs of the multitaper method in `tfr_array_multitaper()`, and also ensures taper weights are stored in `TFR` objects.
Additional information
When working on this, I discovered a couple of other issues with the per-taper TFR implementations (#12851 (comment)), including the fact that the `TFR` object plotting methods and `to_data_frame` methods do not account for a taper dimension, leading to errors. I wasn't sure whether people want me to also address these here or in a separate PR.
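As a rough illustration of why the taper weights are worth returning alongside complex output: per-taper coefficients cannot be collapsed to a single power estimate without them. The NumPy sketch below uses random stand-in data and an assumed weight shape of (n_tapers, n_freqs). The `weighted_power` helper and its normalisation are one common convention chosen for this example and are not necessarily what MNE does internally.

```python
import numpy as np

rng = np.random.default_rng(0)
n_channels, n_tapers, n_freqs, n_times = 2, 3, 4, 50
shape = (n_channels, n_tapers, n_freqs, n_times)

# Stand-in per-taper complex coefficients, (channel, taper, freq, time).
coefs = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

# Stand-in taper weights, one per taper and frequency (assumed layout).
weights = rng.uniform(0.5, 1.0, size=(n_tapers, n_freqs))


def weighted_power(coefs, weights):
    """Collapse the taper axis into a weighted power estimate."""
    w = weights[np.newaxis, :, :, np.newaxis]  # broadcast over channel, time
    numerator = (np.abs(w * coefs) ** 2).sum(axis=1)
    denominator = (np.abs(w) ** 2).sum(axis=1)
    return numerator / denominator


power = weighted_power(coefs, weights)
print(power.shape)  # taper axis removed: (channel, freq, time)
```

With equal weights this reduces to a plain mean of the per-taper power, so the weights only matter when tapers differ in how well they are concentrated.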