Skip to content

Commit

Permalink
Start writing tests for viz routines
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Jun 12, 2024
1 parent 30a5205 commit 452a4d0
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/test_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import jax
import matplotlib.pyplot as plt

import exponax as ex


def test_plot_state_1d():
state = jax.random.normal(
jax.random.PRNGKey(0),
(
10,
100,
),
)

fig = ex.viz.plot_state_1d(state)
plt.close(fig)


def test_plot_spatio_temporal():
trj = jax.random.normal(jax.random.PRNGKey(0), (100, 1, 64))

fig = ex.viz.plot_spatio_temporal(trj)
plt.close(fig)


def test_plot_state_2d():
state = jax.random.normal(jax.random.PRNGKey(0), (1, 100, 100))

fig = ex.viz.plot_state_2d(state)
plt.close(fig)


def test_plot_state_3d():
state = jax.random.normal(jax.random.PRNGKey(0), (1, 32, 32, 32))

fig = ex.viz.plot_state_3d(state)
plt.close(fig)

0 comments on commit 452a4d0

Please sign in to comment.