Skip to content

Commit

Permalink
Add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Apr 8, 2024
1 parent c7105ec commit 063e227
Showing 1 changed file with 145 additions and 1 deletion.
146 changes: 145 additions & 1 deletion exponax/viz/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,39 @@ def plot_spatio_temporal(
include_init: bool = False,
**kwargs,
):
"""
Plot a trajectory of a 1d state as a spatio-temporal plot (space in y-axis,
and time in x-axis).
Requires the input to be a real array with three axis: a leading time axis,
a channel axis, and a spatial axis. Only the leading dimension in the
channel axis will be plotted. See `plot_spatio_temporal_facet` for plotting
multiple trajectories.
Periodic boundary conditions will be applied to the spatial axis (the state
is wrapped around).
**Arguments:**
- `trj`: The trajectory to plot as a three axis array. The first axis should
be the time axis, the second axis the channel axis, and the third axis
the spatial axis.
- `vlim`: The limits of the color scale.
- `ax`: The axis to plot on. If not provided, a new figure will be created.
- `domain_extent`: The extent of the spatial domain. If not provided, the
domain extent will be the number of points in the spatial axis. This
adjusts the y-axis.
- `dt`: The time step. This adjust the extent of the x-axis. If not
provided, the time axis will be the number of time steps.
- `include_init`: Will affect the ticks of the time axis. If `True`, they
will start at zero. If `False`, they will start at the time step.
- `**kwargs`: Additional arguments to pass to the imshow function.
**Returns:**
- If `ax` is not provided, returns a tuple with the figure, axis, and image
object. Otherwise, returns the image object.
"""
if trj.ndim != 3:
raise ValueError("trj must be a two-axis array.")

Expand Down Expand Up @@ -207,6 +240,49 @@ def plot_spatio_temporal_facet(
include_init: bool = False,
**kwargs,
):
"""
Plot a facet of spatio-temporal trajectories.
Requires the input to be a real array with either three or four axes:
* Three axes: a leading time axis, a channel axis, and a spatial axis. The
faceting is performed over the channel axis. Requires the
`facet_over_channels` argument to be `True` (default).
* Four axes: a leading batch axis, a time axis, a channel axis, and a
spatial
axis. The faceting is performed over the batch axis. Requires the
`facet_over_channels` argument to be `False`. Only the zeroth channel
for each trajectory will be plotted.
Periodic boundary conditions will be applied to the spatial axis (the state
is wrapped around).
**Arguments:**
- `trjs`: The trajectories to plot as a three or four axis array. See above
for the requirements.
- `facet_over_channels`: Whether to facet over the channel axis (three axes)
or the batch axis (four axes).
- `vlim`: The limits of the color scale.
- `grid`: The grid layout for the facet plot. This should be a tuple with
two integers. If the number of trajectories is less than the product of
the grid, the remaining axes will be removed.
- `figsize`: The size of the figure.
- `titles`: The titles for each plot. This should be a list of strings with
the same length as the number of trajectories.
- `domain_extent`: The extent of the spatial domain. If not provided, the
domain extent will be the number of points in the spatial axis. This
adjusts the y-axis.
- `dt`: The time step. This adjust the extent of the x-axis. If not
provided, the time axis will be the number of time steps.
- `include_init`: Will affect the ticks of the time axis. If `True`, they
will start at zero. If `False`, they will start at the time step.
- `**kwargs`: Additional arguments to pass to the imshow function.
**Returns:**
- The figure.
"""
if facet_over_channels:
if trjs.ndim != 3:
raise ValueError("trjs must be a three-axis array.")
Expand Down Expand Up @@ -250,6 +326,36 @@ def plot_state_2d(
ax=None,
**kwargs,
):
"""
Visualizes a two-dimensional state as an image.
Requires the input to be a real array with three axes: a leading channel
axis, and two subsequent spatial axes. This function will visualize the
zeroth channel. For plotting multiple channels at the same time, see
`plot_state_2d_facet`.
Periodic boundary conditions will be applied to the spatial axes (the state
is wrapped around).
**Arguments:**
- `state`: The state to plot as a three axis array. The first axis should be
the channel axis, and the subsequent two axes the spatial axes.
- `vlim`: The limits of the color scale.
- `domain_extent`: The extent of the spatial domain. If not provided, the
domain extent will be the number of points in the spatial axes. This
adjusts the x and y axes.
- `ax`: The axis to plot on. If not provided, a new figure will be created.
- `**kwargs`: Additional arguments to pass to the imshow function.
**Returns:**
- If `ax` is not provided, returns a tuple with the figure, axis, and image
object. Otherwise, returns the image object.
"""
if state.ndim != 3:
raise ValueError("state must be a three-axis array.")

if domain_extent is not None:
space_range = (0, domain_extent)
else:
Expand Down Expand Up @@ -279,7 +385,7 @@ def plot_state_2d(


def plot_state_2d_facet(
states: Union[Float[Array, "B N N"], Float[Array, "B 1 N N"]],
states: Union[Float[Array, "C N N"], Float[Array, "B 1 N N"]],
*,
facet_over_channels: bool = True,
vlim: tuple[float, float] = (-1.0, 1.0),
Expand All @@ -289,6 +395,44 @@ def plot_state_2d_facet(
domain_extent: float = None,
**kwargs,
):
"""
Plot a facet of 2d states.
Requires the input to be a real array with three or four axes:
* Three axes: a leading channel axis, and two subsequent spatial axes. The
facet will be done over the channel axis, requires the
`facet_over_channels` argument to be `True` (default).
* Four axes: a leading batch axis, a channel axis, and two subsequent
spatial axes. The facet will be done over the batch axis, requires the
`facet_over_channels` argument to be `False`. Only the zeroth channel
for each state will be plotted.
Periodic boundary conditions will be applied to the spatial axes (the state
is wrapped around).
**Arguments:**
- `states`: The states to plot as a three or four axis array. See above for
the requirements.
- `facet_over_channels`: Whether to facet over the channel axis (three axes)
or the batch axis (four axes).
- `vlim`: The limits of the color scale.
- `grid`: The grid layout for the facet plot. This should be a tuple with
two integers. If the number of states is less than the product of the
grid, the remaining axes will be removed.
- `figsize`: The size of the figure.
- `titles`: The titles for each plot. This should be a list of strings with
the same length as the number of states.
- `domain_extent`: The extent of the spatial domain. If not provided, the
domain extent will be the number of points in the spatial axes. This
adjusts the x and y axes.
- `**kwargs`: Additional arguments to pass to the imshow function.
**Returns:**
- The figure.
"""
if facet_over_channels:
if states.ndim != 3:
raise ValueError("states must be a three-axis array.")
Expand Down

0 comments on commit 063e227

Please sign in to comment.