diff --git a/exponax/viz/_animate.py b/exponax/viz/_animate.py index 62ae9b4..948d798 100644 --- a/exponax/viz/_animate.py +++ b/exponax/viz/_animate.py @@ -72,19 +72,43 @@ def animate_state_1d_facet( fig, ax_s = plt.subplots(*grid, figsize=figsize) - for i, ax in enumerate(ax_s.flatten()): + num_subplots = trj.shape[0] + + for j, ax in enumerate(ax_s.flatten()): plot_state_1d( - trj[i], + trj[j, 0], vlim=vlim, domain_extent=domain_extent, labels=labels, ax=ax, **kwargs, ) - if titles is not None: - ax.set_title(titles[i]) + if j >= num_subplots: + ax.remove() + else: + if titles is not None: + ax.set_title(titles[j]) + + def animate(i): + for j, ax in enumerate(ax_s.flatten()): + ax.clear() + plot_state_1d( + trj[j, i], + vlim=vlim, + domain_extent=domain_extent, + labels=labels, + ax=ax, + **kwargs, + ) + if j >= num_subplots: + ax.remove() + else: + if titles is not None: + ax.set_title(titles[j]) + + ani = FuncAnimation(fig, animate, frames=trj.shape[1], interval=100, blit=False) - return fig + return ani def animate_spatio_temporal(