Skip to content

Commit

Permalink
fix all the bugs [circle full]
Browse files Browse the repository at this point in the history
  • Loading branch information
drammock committed Jan 31, 2024
1 parent f1d66f1 commit ca729a1
Show file tree
Hide file tree
Showing 14 changed files with 106 additions and 72 deletions.
2 changes: 1 addition & 1 deletion mne/beamformer/tests/test_dics.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def test_apply_dics_tfr(return_generator):
data = rng.random((n_epochs, n_chans, len(freqs), n_times))
data *= 1e-6
data = data + data * 1j # add imag. component to simulate phase
epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs)
epochs_tfr = EpochsTFR(dict(info=info, data=data, times=times, freqs=freqs))

# Create a DICS beamformer and convert the EpochsTFR to source space.
csd = csd_tfr(epochs_tfr)
Expand Down
3 changes: 0 additions & 3 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,6 @@ def drop_channels(self, ch_names, on_missing="raise"):
def _pick_drop_channels(self, idx, *, verbose=None):
# avoid circular imports
from ..io import BaseRaw
from ..time_frequency import AverageTFR, EpochsTFR

msg = "adding, dropping, or reordering channels"
if isinstance(self, BaseRaw):
Expand All @@ -634,8 +633,6 @@ def _pick_drop_channels(self, idx, *, verbose=None):

if hasattr(self, "_dims"): # Spectrum and "new-style" TFRs
axis = self._dims.index("channel")
elif isinstance(self, (AverageTFR, EpochsTFR)): # "old-style" TFRs
axis = -3
else: # All others (Evoked, Epochs, Raw) have chs axis=-2
axis = -2
if hasattr(self, "_data"): # skip non-preloaded Raw
Expand Down
3 changes: 2 additions & 1 deletion mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2602,8 +2602,9 @@ def compute_tfr(
state["data"] = out._itc
state["data_type"] = "Inter-trial coherence"
itc = AverageTFR(state, method=None, freqs=None)
del out._itc
return out, itc
del out._itc # if it's None, don't keep it around
del out._itc
return out
# now handle average=False
return EpochsTFR(
Expand Down
6 changes: 6 additions & 0 deletions mne/html_templates/repr/tfr.html.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
<td>{{ tfr.shape[0] }}</td>
</tr>
{% endif -%}
{%- inst_type == "Evoked" %}
<tr>
<th>Number of averaged trials</th>
<td>{{ nave }}</td>
</tr>
{% endif -%}
<tr>
<th>Dims</th>
<td>{{ tfr._dims | join(", ") }}</td>
Expand Down
2 changes: 1 addition & 1 deletion mne/minimum_norm/tests/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,7 +1377,7 @@ def test_apply_inverse_tfr(return_generator):
times = np.arange(sfreq) / sfreq # make epochs 1s long
data = rng.random((n_epochs, len(info.ch_names), freqs.size, times.size))
data = data + 1j * data # make complex to simulate amplitude + phase
epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs)
epochs_tfr = EpochsTFR(dict(info=info, data=data, times=times, freqs=freqs))
epochs_tfr.apply_baseline((0, 0.5))
pick_ori = "vector"

Expand Down
25 changes: 18 additions & 7 deletions mne/time_frequency/_stockwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,17 +302,28 @@ def tfr_stockwell(
)
times = inst.times[decim].copy()
nave = len(data)
out = AverageTFR(info, power, times, freqs, nave, method="stockwell-power")
out = AverageTFR(
dict(
info=info,
data=power,
times=times,
freqs=freqs,
nave=nave,
method="stockwell-power",
)
)
if return_itc:
out = (
out,
AverageTFR(
deepcopy(info),
itc,
times.copy(),
freqs.copy(),
nave,
method="stockwell-itc",
dict(
info=deepcopy(info),
data=itc,
times=times.copy(),
freqs=freqs.copy(),
nave=nave,
method="stockwell-itc",
)
),
)
return out
10 changes: 5 additions & 5 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from ..viz.utils import (
_format_units_psd,
_get_plot_ch_type,
_make_combine_callable,
_plot_psd,
_prepare_sensor_names,
plt_show,
Expand Down Expand Up @@ -1395,12 +1396,11 @@ def average(self, method="mean"):
spectrum : instance of Spectrum
The aggregated spectrum object.
"""
# TODO: we probably should avoid `np.median` here when data are complex
# (like we do in EpochsTFR.average)
# TODO: probably should have a `.nave` attribute?
if isinstance(method, str):
method = getattr(np, method) # mean, median, std, etc
method = partial(method, axis=0)
_validate_type(method, ("str", "callable"))
method = _make_combine_callable(
method, axis=0, valid=("mean", "median"), keepdims=False
)
if not callable(method):
raise ValueError(
'"method" must be a valid string or callable, '
Expand Down
1 change: 1 addition & 0 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,6 +1727,7 @@ def test_tfr_arithmetic(epochs):
@parametrize_tfr_inst
def test_tfr_save_load(inst, average_tfr, request, tmp_path):
"""Test TFR I/O."""
pytest.importorskip("h5io")
tfr = _get_inst(inst, request, average_tfr=average_tfr)
fname = tmp_path / "temp_tfr.hdf5"
tfr.save(fname, overwrite=True)
Expand Down
91 changes: 46 additions & 45 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,9 +1173,19 @@ def __init__(
for k, v in dict(method=method, freqs=freqs).items()
if v is None
]
class_name = inspect.currentframe().f_back.f_code.co_qualname.split(".")[0]
# TODO when py3.11 is min version, replace if/elif/else block with
# classname = inspect.currentframe().f_back.f_code.co_qualname.split(".")[0]
_varnames = inspect.currentframe().f_back.f_code.co_varnames
if "BaseRaw" in _varnames:
classname = "RawTFR"
elif "Evoked" in _varnames:
classname = "AverageTFR"
else:
assert "BaseEpochs" in _varnames and "Evoked" not in _varnames
classname = "EpochsTFR"
# end TODO
raise ValueError(
f'{class_name} got unsupported parameter value{_pl(problem)} '
f'{classname} got unsupported parameter value{_pl(problem)} '
f'{" and ".join(problem)}.'
)
# check method
Expand Down Expand Up @@ -1218,23 +1228,16 @@ def __init__(
self.preload = True # needed for __getitem__, never False
self._method = method
# self._dims may also get updated by child classes
self._dims = (
"channel",
"freq",
"time",
)
# get the instance data. `_time_mask` is later deleted inside _compute_tfr,
# when decim is applied.
self._time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq)
get_instance_data_kw = (
dict()
if reject_by_annotation is None
else dict(reject_by_annotation=reject_by_annotation)
)
self._dims = ("channel", "freq", "time")
# get the instance data.
time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq)
get_instance_data_kw = dict(time_mask=time_mask)
if reject_by_annotation is not None:
get_instance_data_kw.update(reject_by_annotation=reject_by_annotation)
data = self._get_instance_data(**get_instance_data_kw)
# compute the TFR
self._decim = _check_decim(decim)
self._raw_times = inst.times[self._time_mask]
self._raw_times = inst.times[time_mask]
self._compute_tfr(data, n_jobs, verbose)
self._update_epoch_attributes()
# "apply" decim to the rest of the object (data is decimated in _compute_tfr)
Expand All @@ -1250,7 +1253,6 @@ def __init__(
# we don't need these anymore, and they make save/load harder
del self._picks
del self._tfr_func
del self._time_mask
del self._shape # calculated from self._data henceforth
del self.inst # save memory

Expand Down Expand Up @@ -1392,18 +1394,7 @@ def __setstate__(self, state):
inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class)
self._inst_type = inst_types[defaults["inst_type_str"]]
# sanity check data/freqs/times/info agreement
msg = "{} axis of data ({}) doesn't match {} attribute ({})"
n_chan_info = len(self.info["chs"])
n_chan, n_freq, n_time = self._data.shape[self._dims.index("channel") :]
if n_chan_info != n_chan:
msg = msg.format("Channel", n_chan, "info", n_chan_info)
elif n_freq != self.freqs.size:
msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size)
elif n_time != self.times.size:
msg = msg.format("Time", n_time, "times", self.times.size)
else:
return
raise ValueError(msg)
self._check_state()

def __repr__(self):
"""Build string representation of the TFR object."""
Expand All @@ -1427,9 +1418,10 @@ def _repr_html_(self, caption=None):
from ..html_templates import repr_templates_env

inst_type_str = _get_instance_type_string(self)
nave = getattr(self, "nave", 0)
units = [f"{ch_type}: {unit}" for ch_type, unit in self.units().items()]
t = repr_templates_env.get_template("tfr.html.jinja")
t = t.render(tfr=self, inst_type=inst_type_str, units=units)
t = t.render(tfr=self, inst_type=inst_type_str, units=units, nave=nave)
return t

def _check_compatibility(self, other):
Expand All @@ -1454,6 +1446,21 @@ def _check_compatibility(self, other):
return
raise RuntimeError(msg.format(problem, extra))

def _check_state(self):
"""Check data/freqs/times/info agreement during __setstate__."""
msg = "{} axis of data ({}) doesn't match {} attribute ({})"
n_chan_info = len(self.info["chs"])
n_chan, n_freq, n_time = self._data.shape[self._dims.index("channel") :]
if n_chan_info != n_chan:
msg = msg.format("Channel", n_chan, "info", n_chan_info)
elif n_freq != self.freqs.size:
msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size)
elif n_time != self.times.size:
msg = msg.format("Time", n_time, "times", self.times.size)
else:
return
raise ValueError(msg)

def _check_values(self, wants_complex=False):
"""Check TFR results for correct shape and bad values."""
assert len(self._dims) == self._data.ndim
Expand Down Expand Up @@ -1897,13 +1904,6 @@ def plot(
figs : list of instances of matplotlib.figure.Figure
A list of figures containing the time-frequency power.
"""
# triage EpochsTFR
if isinstance(self, EpochsTFR) and self.shape[0] > 1:
raise NotImplementedError(
"Plotting EpochsTFR objects containing more than one epoch is not "
"supported; either plot an average `EpochsTFR.average().plot()` or "
"plot individual epochs `EpochsTFR[0].plot()`."
)
# the rectangle selector plots topomaps, which needs all channels uncombined,
# so we keep a reference to that state here, and update it with `comment` and
# `nave` values in case we started out with a singleton EpochsTFR or RawTFR
Expand Down Expand Up @@ -2784,11 +2784,11 @@ def nave(self):
def nave(self, nave):
self._nave = nave

def _get_instance_data(self):
def _get_instance_data(self, time_mask):
# AverageTFRs can be constructed from Epochs data, so we triage shape here.
# Evoked data get a fake singleton "epoch" axis prepended
dim = slice(None) if _get_instance_type_string(self) == "Epochs" else np.newaxis
data = self.inst.get_data(picks=self._picks)[dim, :, self._time_mask]
data = self.inst.get_data(picks=self._picks)[dim, :, time_mask]
self._nave = getattr(self.inst, "nave", data.shape[0])
return data

Expand Down Expand Up @@ -2957,14 +2957,15 @@ def __next__(self, return_event_id=False):
def _check_singleton(self):
"""Check if self contains only one Epoch, and return it as an AverageTFR."""
if self.shape[0] > 1:
calling_func = inspect.currentframe().f_back.f_code.co_name
raise NotImplementedError(
"Cannot plot topomap for multiple EpochsTFR epochs; please subselect a"
"single epoch before plotting."
f"Cannot call {calling_func}() from EpochsTFR with multiple epochs; "
"please subselect a single epoch before plotting."
)
return list(self.iter_evoked())[0]

def _get_instance_data(self):
return self.inst.get_data(picks=self._picks)[:, :, self._time_mask]
def _get_instance_data(self, time_mask):
return self.inst.get_data(picks=self._picks)[:, :, time_mask]

def _update_epoch_attributes(self):
# adjust dims and shape
Expand Down Expand Up @@ -3554,8 +3555,8 @@ def __getitem__(self, item):
self._parse_get_set_params = partial(BaseRaw._parse_get_set_params, self)
return BaseRaw._getitem(self, item, return_times=False)

def _get_instance_data(self, reject_by_annotation):
start, stop = np.where(self._time_mask)[0][[0, -1]]
def _get_instance_data(self, time_mask, reject_by_annotation):
start, stop = np.where(time_mask)[0][[0, -1]]
rba = "NaN" if reject_by_annotation else None
data = self.inst.get_data(
self._picks, start, stop + 1, reject_by_annotation=rba
Expand Down
3 changes: 2 additions & 1 deletion mne/utils/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ def _prepare_read_metadata(metadata):
assert isinstance(metadata, list)
if pd:
metadata = pd.DataFrame.from_records(metadata)
metadata.set_index("index", inplace=True)
if "index" in metadata.columns:
metadata.set_index("index", inplace=True)
assert isinstance(metadata, pd.DataFrame)
return metadata
10 changes: 9 additions & 1 deletion mne/viz/tests/test_topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,15 @@ def test_plot_tfr_topo():
data = np.random.RandomState(0).randn(
len(epochs.ch_names), n_freqs, len(epochs.times)
)
tfr = AverageTFR(epochs.info, data, epochs.times, np.arange(n_freqs), nave)
tfr = AverageTFR(
dict(
info=epochs.info,
data=data,
times=epochs.times,
freqs=np.arange(n_freqs),
nave=nave,
)
)
plt.close("all")
fig = tfr.plot_topo(
baseline=(None, 0), mode="ratio", title="Average power", vmin=0.0, vmax=14.0
Expand Down
10 changes: 9 additions & 1 deletion mne/viz/tests/test_topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,15 @@ def test_plot_tfr_topomap():
data = rng.randn(len(picks), n_freqs, len(times))

# test complex numbers
tfr = AverageTFR(info, data * (1 + 1j), times, np.arange(n_freqs), nave)
tfr = AverageTFR(
dict(
info=info,
data=data * (1 + 1j),
times=times,
freqs=np.arange(n_freqs),
nave=nave,
)
)
tfr.plot_topomap(
ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0
)
Expand Down
3 changes: 1 addition & 2 deletions mne/viz/topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,6 @@ def _imshow_tfr(
cnorm=None,
):
"""Show time-frequency map as two-dimensional image."""
from matplotlib import pyplot as plt
from matplotlib.widgets import RectangleSelector

_check_option("yscale", yscale, ["auto", "linear", "log"])
Expand Down Expand Up @@ -460,7 +459,7 @@ def _imshow_tfr(
if isinstance(colorbar, DraggableColorbar):
cbar = colorbar.cbar # this happens with multiaxes case
else:
cbar = plt.colorbar(mappable=img, ax=ax)
cbar = ax.get_figure().colorbar(mappable=img, ax=ax)
if interactive_cmap:
ax.CB = DraggableColorbar(cbar, img, kind="tfr_image", ch_type=None)
ax.RS = RectangleSelector(ax, onselect=onselect) # reference must be kept
Expand Down
9 changes: 5 additions & 4 deletions mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2355,7 +2355,7 @@ def _make_combine_callable(
"""
kwargs = dict(axis=axis, keepdims=keepdims)
if combine is None:
combine = _identity_function
combine = _identity_function if keepdims else partial(np.squeeze, axis=axis)
elif isinstance(combine, str):
combine_dict = {
key: partial(getattr(np, key), **kwargs)
Expand All @@ -2366,9 +2366,10 @@ def _make_combine_callable(
if "median" in valid:
combine_dict["median"] = partial(_median_complex, axis=axis)
# RMS and GFP are computed the same way
for key in ("gfp", "rms"):
if key in valid:
combine_dict[key] = lambda data: np.sqrt((data**2).mean(**kwargs))
if "rms" in valid:
combine_dict["rms"] = lambda data: np.sqrt((data**2).mean(**kwargs))
if "gfp" in valid:
combine_dict["gfp"] = lambda data: data.std(axis=axis, ddof=0)
try:
combine = combine_dict[combine]
except KeyError:
Expand Down

0 comments on commit ca729a1

Please sign in to comment.