Skip to content

Commit

Permalink
Refactor dipole visualization to invert y-axis when min_freq > max_freq
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Oct 25, 2024
1 parent b616356 commit dca5a42
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
7 changes: 7 additions & 0 deletions hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ def test_dipole_visualization(setup_net):
# multiple TFRs get averaged
fig = plot_tfr_morlet(dpls, freqs=np.arange(23, 26, 1.), n_cycles=3,
show=False)
# when min_freq > max_freq (y-axis inversion)
fig = plot_tfr_morlet(dpls, freqs=np.array([30, 20, 10]),
n_cycles=3, show=False)
ax = fig.get_axes()[0]
y_limits = ax.get_ylim()
assert y_limits[0] > y_limits[1], \
"Y-axis should be inverted when min_freq > max_freq"

with pytest.raises(RuntimeError,
match="All dipoles must be scaled equally!"):
Expand Down
5 changes: 5 additions & 0 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,11 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None,
power = np.mean(trial_power, axis=0)
im = ax.pcolormesh(times, freqs, power[0, 0, ...], cmap=colormap,
shading='auto')

if freqs[0] > freqs[-1]:
freqs = freqs[::-1]
ax.invert_yaxis()

ax.set_xlabel('Time (ms)')
ax.set_ylabel('Frequency (Hz)')

Expand Down

0 comments on commit dca5a42

Please sign in to comment.