Skip to content

Commit

Permalink
Expand test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Dec 10, 2024
1 parent 972aba2 commit e11fa2b
Showing 1 changed file with 78 additions and 8 deletions.
86 changes: 78 additions & 8 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,17 @@ def test_tfr_io(inst, average_tfr, request, tmp_path):
with tfr.info._unlock():
tfr.info["meas_date"] = want
assert tfr_loaded == tfr
# test with taper dimension and weights
n_tapers = 3 # anything >= 1 should do
weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs
state = tfr.__getstate__()
state["data"] = np.repeat(np.expand_dims(tfr.data, 2), n_tapers, axis=2) # add dim
state["weights"] = weights # add weights
state["dims"] = ("epoch", "channel", "taper", "freq", "time") # update dims
tfr = EpochsTFR(inst=state)
tfr.save(fname, overwrite=True)
tfr_loaded = read_tfrs(fname)
assert tfr_loaded == tfr
# test overwrite
with pytest.raises(OSError, match="Destination file exists."):
tfr.save(fname, overwrite=False)
Expand Down Expand Up @@ -726,17 +737,31 @@ def test_average_tfr_init(full_evoked):
AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace)


def test_epochstfr_init_errors(epochs_tfr):
"""Test __init__ for EpochsTFR."""
state = epochs_tfr.__getstate__()
with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"):
EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0]))
@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr"))
def test_tfr_init_errors(inst, request, average_tfr):
"""Test __init__ for Raw/Epochs/AverageTFR."""
# Load data
inst = _get_inst(inst, request, average_tfr=average_tfr)
state = inst.__getstate__()
# Prepare for TFRArray object instantiation
inst_name = inst.__class__.__name__
class_mapping = dict(RawTFR=RawTFR, EpochsTFR=EpochsTFR, AverageTFR=AverageTFR)
ndims_mapping = dict(
RawTFR=("3D or 4D"), EpochsTFR=("4D or 5D"), AverageTFR=("3D or 4D")
)
TFR = class_mapping[inst_name]
allowed_ndims = ndims_mapping[inst_name]
# Check errors caught
with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"):
TFR(inst=state | dict(data=inst.data[..., 0]))
with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"):
TFR(inst=state | dict(data=np.expand_dims(inst.data, axis=(0, 1))))
with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"):
EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1]))
TFR(inst=state | dict(data=inst.data[..., :-1, :, :]))
with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"):
EpochsTFR(inst=state | dict(times=epochs_tfr.times[:-1]))
TFR(inst=state | dict(times=inst.times[:-1]))
with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"):
EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1]))
TFR(inst=state | dict(freqs=inst.freqs[:-1]))


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1158,6 +1183,15 @@ def test_averaging_epochsTFR():
):
power.average(method=np.mean)

# Check it doesn't run for taper spectra
tapered = epochs.compute_tfr(
method="multitaper", freqs=freqs, n_cycles=n_cycles, output="complex"
)
with pytest.raises(
NotImplementedError, match=r"Averaging multitaper tapers .* is not supported."
):
tapered.average()


def test_averaging_freqsandtimes_epochsTFR():
"""Test that EpochsTFR averaging freqs methods work."""
Expand Down Expand Up @@ -1551,6 +1585,42 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy):
assert avgs[0].comment == str(epochs_tfr.events[0, -1])


@pytest.mark.parametrize("inst", ("raw", "epochs", "evoked"))
def test_tfrarray_tapered_spectra(inst, evoked, request):
"""Test Raw/Epochs/AverageTFRArray instantiation with tapered spectra."""
# Load data object
inst = _get_inst(inst, request, evoked=evoked)
inst.pick("mag")
# Compute TFR with taper dimension (can be complex or phase output)
tfr = inst.compute_tfr(
method="multitaper", freqs=freqs_linspace, n_cycles=4, output="complex"
)
tfr_array, weights = tfr.get_data(), tfr.weights
# Prepare for TFRArray object instantiation
defaults = dict(
info=inst.info, data=tfr_array, times=inst.times, freqs=freqs_linspace
)
class_mapping = dict(Raw=RawTFRArray, Epochs=EpochsTFRArray, Evoked=AverageTFRArray)
TFRArray = class_mapping[inst.__class__.__name__]
# Check TFRArray instantiation runs with good data
TFRArray(**defaults, weights=weights)
# Check taper dimension but no weights caught
with pytest.raises(
ValueError, match="Taper dimension in data, but no weights found."
):
TFRArray(**defaults)
# Check mismatching n_taper in weights caught
with pytest.raises(
ValueError, match=r"Taper axis .* doesn't match weights attribute"
):
TFRArray(**defaults, weights=weights[:-1])
# Check mismatching n_freq in weights caught
with pytest.raises(
ValueError, match=r"Frequency axis .* doesn't match weights attribute"
):
TFRArray(**defaults, weights=weights[:, :-1])


def test_tfr_proj(epochs):
"""Test `compute_tfr(proj=True)`."""
epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True)
Expand Down

0 comments on commit e11fa2b

Please sign in to comment.