diff --git a/tests/test_viz.py b/tests/test_viz.py index 045ac57..0440727 100644 --- a/tests/test_viz.py +++ b/tests/test_viz.py @@ -28,8 +28,8 @@ def test_plot_state_2d(): plt.close(fig) -def test_plot_state_3d(): - state = jax.random.normal(jax.random.PRNGKey(0), (1, 32, 32, 32)) +# def test_plot_state_3d(): +# state = jax.random.normal(jax.random.PRNGKey(0), (1, 32, 32, 32)) - fig = ex.viz.plot_state_3d(state) - plt.close(fig) +# fig = ex.viz.plot_state_3d(state) +# plt.close(fig)