From 44eabf08f71f9cd080433e63627b2d5eaf56a4fe Mon Sep 17 00:00:00 2001 From: samadpls Date: Thu, 10 Oct 2024 23:59:36 +0500 Subject: [PATCH] Refactor dipole visualization to invert y-axis when min_freq > max_freq Signed-off-by: samadpls --- hnn_core/tests/test_viz.py | 7 +++++++ hnn_core/viz.py | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index 081f2416e..41038bb3d 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -187,6 +187,13 @@ def test_dipole_visualization(setup_net): # multiple TFRs get averaged fig = plot_tfr_morlet(dpls, freqs=np.arange(23, 26, 1.), n_cycles=3, show=False) + # when min_freq > max_freq (y-axis inversion) + fig = plot_tfr_morlet(dpls, freqs=np.array([30, 20, 10]), + n_cycles=3, show=False) + ax = fig.get_axes()[0] + y_limits = ax.get_ylim() + assert y_limits[0] > y_limits[1], \ + "Y-axis should be inverted when min_freq > max_freq" with pytest.raises(RuntimeError, match="All dipoles must be scaled equally!"): diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 45169e9f3..879e94433 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -740,6 +740,11 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, power = np.mean(trial_power, axis=0) im = ax.pcolormesh(times, freqs, power[0, 0, ...], cmap=colormap, shading='auto') + + if freqs[0] > freqs[-1]: + freqs = freqs[::-1] + ax.invert_yaxis() + ax.set_xlabel('Time (ms)') ax.set_ylabel('Frequency (Hz)')