Skip to content

Commit

Permalink
Add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Apr 9, 2024
1 parent 68a78fe commit 48bd412
Showing 1 changed file with 210 additions and 3 deletions.
213 changes: 210 additions & 3 deletions exponax/viz/_animate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,37 @@ def animate_state_1d(
include_init: bool = False,
**kwargs,
):
"""
Animate a trajectory of 1d states.
Requires the input to be a three-axis array with a leading time axis, a
channel axis, and a spatial axis. If there is more than one dimension in the
channel axis, this will be plotted in a different color.
Periodic boundary conditions will be applied to the spatial axis (the state
is wrapped around).
**Arguments**:
- `trj`: The trajectory of states to animate. Must be a three-axis array
with shape `(n_timesteps, n_channels, n_spatial)`. If the channel axis
has more than one dimension, the different channels will be plotted in
different colors.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `domain_extent`: The extent of the spatial domain. Default is `None`. This
affects the x-axis limits of the plot.
- `dt`: The time step between each frame. Default is `None`. If provided,
a title will be displayed with the current time. If not provided, just
the frames are counted.
- `include_init`: Whether to the state starts at an initial condition (t=0)
or at the first frame in the trajectory. This affects is the the time
range is [0, (T-1)dt] or [dt, Tdt]. Default is `False`.
- `**kwargs`: Additional keyword arguments to pass to the plotting function.
**Returns**:
- `ani`: The animation object.
"""
fig, ax = plt.subplots()

plot_state_1d(
Expand Down Expand Up @@ -57,7 +88,7 @@ def animate(i):


def animate_state_1d_facet(
trj: Float[Array, "T B C N"],
trj: Float[Array, "B T C N"],
*,
vlim: tuple[float, float] = (-1.0, 1.0),
labels: list[str] = None,
Expand All @@ -67,6 +98,38 @@ def animate_state_1d_facet(
figsize: tuple[float, float] = (10, 10),
**kwargs,
):
"""
Animate a trajectory of faceted 1d states.
Requires the input to be a four-axis array with a leading batch axis, a time
axis, a channel axis, and a spatial axis. If there is more than one
dimension in the channel axis, this will be plotted in a different color.
Hence, there are two ways to display multiple states: either via the batch
axis (resulting in faceted subplots) or via the channel axis (resulting in
different colors).
Periodic boundary conditions will be applied to the spatial axis (the state
is wrapped around).
**Arguments**:
- `trj`: The trajectory of states to animate. Must be a four-axis array with
shape `(n_batches, n_timesteps, n_channels, n_spatial)`. If the channel
axis has more than one dimension, the different channels will be plotted
in different colors.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `labels`: The labels for each channel. Default is `None`.
- `titles`: The titles for each subplot. Default is `None`.
- `domain_extent`: The extent of the spatial domain. Default is `None`. This
affects the x-axis limits of the plot.
- `grid`: The grid of subplots. Default is `(3, 3)`.
- `figsize`: The size of the figure. Default is `(10, 10)`.
- `**kwargs`: Additional keyword arguments to pass to the plotting function.
**Returns**:
- `ani`: The animation object.
"""
if trj.ndim != 4:
raise ValueError("states must be a four-axis array.")

Expand Down Expand Up @@ -120,6 +183,41 @@ def animate_spatio_temporal(
include_init: bool = False,
**kwargs,
):
"""
Animate a trajectory of spatio-temporal states. Allows to visualize "two
time dimensions". One time dimension is the x-axis. The other is via the
animation. For instance, this can be used to present how neural predictors
learn spatio-temporal dynamics over time.
Requires the input to be a four-axis array with a leading spatial axis, a
time axis, a channel axis, and a batch axis. Only the zeroth dimension in
the channel axis is plotted.
Periodic boundary conditions will be applied to the spatial axis (the state
is wrapped around).
**Arguments**:
- `trjs`: The trajectory of states to animate. Must be a four-axis array
with shape `(n_timesteps_outer, n_time_steps, n_channels, n_spatial)`.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `domain_extent`: The extent of the spatial domain. Default is `None`. This
affects the x-axis limits of the plot.
- `dt`: The time step between each frame. Default is `None`. If provided,
a title will be displayed with the current time. If not provided, just
the frames are counted.
- `include_init`: Whether to the state starts at an initial condition (t=0)
or at the first frame in the trajectory. This affects is the the time
range is [0, (T-1)dt] or [dt, Tdt]. Default is `False`.
- `**kwargs`: Additional keyword arguments to pass to the plotting function.
**Returns**:
- `ani`: The animation object.
"""
if trjs.ndim != 4:
raise ValueError("trjs must be a four-axis array.")

fig, ax = plt.subplots()

plot_spatio_temporal(
Expand Down Expand Up @@ -163,14 +261,60 @@ def animate_spatial_temporal_facet(
figsize: tuple[float, float] = (10, 10),
**kwargs,
):
"""
Animate a facet of trajectories of spatio-temporal states. Allows to
visualize "two time dimensions". One time dimension is the x-axis. The other
is via the animation. For instance, this can be used to present how neural
predictors learn spatio-temporal dynamics over time. The additional faceting
dimension can be used two compare multiple networks with one another.
Requires the input to be either a four-axis array or a five-axis array:
- If `facet_over_channels` is `True`, the input must be a four-axis array
with a leading outer time axis, a time axis, a channel axis, and a
spatial axis. Each faceted subplot displays a different channel.
- If `facet_over_channels` is `False`, the input must be a five-axis array
with a leading batch axis, an outer time axis, a time axis, a channel
axis, and a spatial axis. Each faceted subplot displays a different
batch, only the zeroth dimension in the channel axis is plotted.
Periodic boundary conditions will be applied to the spatial axis (the state
is wrapped around).
**Arguments**:
- `trjs`: The trajectory of states to animate. Must be a four-axis array
with shape `(n_timesteps_outer, n_time_steps, n_channels, n_spatial)` if
`facet_over_channels` is `True`, or a five-axis array with shape
`(n_batches, n_timesteps_outer, n_time_steps, n_channels, n_spatial)` if
`facet_over_channels` is `False`.
- `facet_over_channels`: Whether to facet over the channel axis or the batch
axis. Default is `True`.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `domain_extent`: The extent of the spatial domain. Default is `None`. This
affects the x-axis limits of the plot.
- `dt`: The time step between each frame. Default is `None`. If provided,
a title will be displayed with the current time. If not provided, just
the frames are counted.
- `include_init`: Whether to the state starts at an initial condition (t=0)
or at the first frame in the trajectory. This affects is the the time
range is [0, (T-1)dt] or [dt, Tdt]. Default is `False`.
- `grid`: The grid of subplots. Default is `(3, 3)`.
- `figsize`: The size of the figure. Default is `(10, 10)`.
- `**kwargs`: Additional keyword arguments to pass to the plotting function.
**Returns**:
- `ani`: The animation object.
"""
if facet_over_channels:
if trjs.ndim != 4:
raise ValueError("trjs must be a four-axis array.")
else:
if trjs.ndim != 5:
raise ValueError("states must be a five-axis array.")
# TODO
pass
raise NotImplementedError("Not implemented yet.")


def animate_state_2d(
Expand All @@ -182,6 +326,38 @@ def animate_state_2d(
include_init: bool = False,
**kwargs,
):
"""
Animate a trajectory of 2d states.
Requires the input to be a four-axis array with a leading time axis, a
channel axis, and two spatial axes. Only the zeroth dimension in the channel
axis is plotted.
Periodic boundary conditions will be applied to the spatial axes (the state
is wrapped around).
**Arguments**:
- `trj`: The trajectory of states to animate. Must be a four-axis array with
shape `(n_timesteps, 1, n_spatial, n_spatial)`.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `domain_extent`: The extent of the spatial domain. Default is `None`. This
affects the x- and y-axis limits of the plot.
- `dt`: The time step between each frame. Default is `None`. If provided,
a title will be displayed with the current time. If not provided, just
the frames are counted.
- `include_init`: Whether to the state starts at an initial condition (t=0)
or at the first frame in the trajectory. This affects is the the time
range is [0, (T-1)dt] or [dt, Tdt]. Default is `False`.
- `**kwargs`: Additional keyword arguments to pass to the plotting function.
**Returns**:
- `ani`: The animation object.
"""
if trj.ndim != 4:
raise ValueError("trj must be a four-axis array.")

fig, ax = plt.subplots()

if dt is not None:
Expand Down Expand Up @@ -224,7 +400,38 @@ def animate_state_2d_facet(
titles=None,
):
"""
trj.shape = (n_trjs, n_timesteps, ...)
Animate a facet of trajectories of 2d states.
Requires the input to be either a four-axis array or a five-axis array:
- If `facet_over_channels` is `True`, the input must be a four-axis array
with a leading time axis, a channel axis, and two spatial axes. Each
faceted subplot displays a different channel.
- If `facet_over_channels` is `False`, the input must be a five-axis array
with a leading batch axis, a time axis, a channel axis, and two spatial
axes. Each faceted subplot displays a different batch. Only the zeroth
dimension in the channel axis is plotted.
Periodic boundary conditions will be applied to the spatial axes (the state
is wrapped around).
**Arguments**:
- `trj`: The trajectory of states to animate. Must be a four-axis array with
shape `(n_timesteps, n_channels, n_spatial, n_spatial)` if
`facet_over_channels` is `True`, or a five-axis array with shape
`(n_batches, n_timesteps, n_channels, n_spatial, n_spatial)` if
`facet_over_channels` is `False`.
- `facet_over_channels`: Whether to facet over the channel axis or the batch
axis. Default is `True`.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `grid`: The grid of subplots. Default is `(3, 3)`.
- `figsize`: The size of the figure. Default is `(10, 10)`.
- `titles`: The titles for each subplot. Default is `None`.
**Returns**:
- `ani`: The animation object.
"""
if facet_over_channels:
if trj.ndim != 4:
Expand Down

0 comments on commit 48bd412

Please sign in to comment.