Skip to content

Commit

Permalink
Rework 2d animations
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Apr 8, 2024
1 parent 0dfe4b4 commit bf7331f
Showing 1 changed file with 64 additions and 24 deletions.
88 changes: 64 additions & 24 deletions exponax/viz/_animate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jaxtyping import Array, Float
from matplotlib.animation import FuncAnimation

from ._plot import plot_spatio_temporal, plot_state_1d
from ._plot import plot_spatio_temporal, plot_state_1d, plot_state_2d

N = TypeVar("N")

Expand Down Expand Up @@ -149,17 +149,39 @@ def animate_spatial_temporal_facet(
pass


def animate_state_2d(trj, *, vlim=(-1, 1)):
def animate_state_2d(
trj: Float[Array, "T 1 N N"],
*,
vlim: tuple[float, float] = (-1.0, 1.0),
domain_extent: float = None,
dt: float = None,
include_init: bool = False,
**kwargs,
):
fig, ax = plt.subplots()
im = ax.imshow(
trj[0].squeeze().T, vmin=vlim[0], vmax=vlim[1], cmap="RdBu_r", origin="lower"

if dt is not None:
time_range = (0, dt * trj.shape[0])
if not include_init:
time_range = (dt, time_range[1])
else:
time_range = (0, trj.shape[0] - 1)

plot_state_2d(
trj[0],
vlim=vlim,
domain_extent=domain_extent,
ax=ax,
)
im.set_data(jnp.zeros_like(trj[0]).squeeze())

def animate(i):
im.set_data(trj[i].squeeze().T)
fig.suptitle(f"t_i = {i:04d}")
return im
ax.clear()
plot_state_2d(
trj[i],
vlim=vlim,
domain_extent=domain_extent,
ax=ax,
)

plt.close(fig)

Expand All @@ -169,31 +191,49 @@ def animate(i):


def animate_state_2d_facet(
trj, *, vlim=(-1, 1), grid=(3, 3), figsize=(10, 10), titles=None
trj: Union[Float[Array, "T C N N"], Float[Array, "B T 1 N N"]],
*,
facet_over_channels: bool = True,
vlim: tuple[float, float] = (-1.0, 1.0),
grid: tuple[int, int] = (3, 3),
figsize: tuple[float, float] = (10, 10),
titles=None,
):
"""
trj.shape = (n_trjs, n_timesteps, ...)
"""
if facet_over_channels:
if trj.ndim != 4:
raise ValueError("trj must be a four-axis array.")
else:
if trj.ndim != 5:
raise ValueError("trj must be a five-axis array.")

if facet_over_channels:
trj = jnp.swapaxes(trj, 0, 1)
trj = trj[:, :, None]

fig, ax_s = plt.subplots(*grid, sharex=True, sharey=True, figsize=figsize)
im_s = []
for i, ax in enumerate(ax_s.flatten()):
im = ax.imshow(
trj[i, 0].squeeze().T,
vmin=vlim[0],
vmax=vlim[1],
cmap="RdBu_r",
origin="lower",

for j, ax in enumerate(ax_s.flatten()):
plot_state_2d(
trj[j, 0],
vlim=vlim,
ax=ax,
)
im.set_data(jnp.zeros_like(trj[i, 0]).squeeze())
im_s.append(im)
if titles is not None:
ax.set_title(titles[j])

def animate(i):
for j, im in enumerate(im_s):
im.set_data(trj[j, i].squeeze().T)
for j, ax in enumerate(ax_s.flatten()):
ax.clear()
plot_state_2d(
trj[j, i],
vlim=vlim,
ax=ax,
)
if titles is not None:
ax_s.flatten()[j].set_title(titles[j])
fig.suptitle(f"t_i = {i:04d}")
return im_s
ax.set_title(titles[j])

plt.close(fig)

Expand Down

0 comments on commit bf7331f

Please sign in to comment.