diff --git a/tests/test_viz.py b/tests/test_viz.py index 045ac57..a06e747 100644 --- a/tests/test_viz.py +++ b/tests/test_viz.py @@ -28,8 +28,10 @@ def test_plot_state_2d(): plt.close(fig) -def test_plot_state_3d(): - state = jax.random.normal(jax.random.PRNGKey(0), (1, 32, 32, 32)) +# # Requires a GPU and therefore cannot easily be tested on GitHub Actions - fig = ex.viz.plot_state_3d(state) - plt.close(fig) +# def test_plot_state_3d(): +# state = jax.random.normal(jax.random.PRNGKey(0), (1, 32, 32, 32)) + +# fig = ex.viz.plot_state_3d(state) +# plt.close(fig)