diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index 081f2416e..41038bb3d 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -187,6 +187,13 @@ def test_dipole_visualization(setup_net): # multiple TFRs get averaged fig = plot_tfr_morlet(dpls, freqs=np.arange(23, 26, 1.), n_cycles=3, show=False) + # when min_freq > max_freq (y-axis inversion) + fig = plot_tfr_morlet(dpls, freqs=np.array([30, 20, 10]), + n_cycles=3, show=False) + ax = fig.get_axes()[0] + y_limits = ax.get_ylim() + assert y_limits[0] > y_limits[1], \ + "Y-axis should be inverted when min_freq > max_freq" with pytest.raises(RuntimeError, match="All dipoles must be scaled equally!"): diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 45169e9f3..879e94433 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -740,6 +740,11 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, power = np.mean(trial_power, axis=0) im = ax.pcolormesh(times, freqs, power[0, 0, ...], cmap=colormap, shading='auto') + + if freqs[0] > freqs[-1]: + freqs = freqs[::-1] + ax.invert_yaxis() + ax.set_xlabel('Time (ms)') ax.set_ylabel('Frequency (Hz)')