Skip to content

Commit

Permalink
Allow changing cmap
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Jun 12, 2024
1 parent acf656b commit 30a5205
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 4 deletions.
8 changes: 8 additions & 0 deletions exponax/viz/_animate.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def animate_spatio_temporal(
trjs: Float[Array, "S T C N"],
*,
vlim: tuple[float, float] = (-1.0, 1.0),
cmap: str = "RdBu_r",
domain_extent: float = None,
dt: float = None,
include_init: bool = False,
Expand All @@ -117,6 +118,7 @@ def animate_spatio_temporal(
- `trjs`: The trajectory of states to animate. Must be a four-axis array
with shape `(n_timesteps_outer, n_time_steps, n_channels, n_spatial)`.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `cmap`: The colormap to use. Default is `"RdBu_r"`.
- `domain_extent`: The extent of the spatial domain. Default is `None`. This
affects the x-axis limits of the plot.
- `dt`: The time step between each frame. Default is `None`. If provided,
Expand All @@ -139,6 +141,7 @@ def animate_spatio_temporal(
plot_spatio_temporal(
trjs[0],
vlim=vlim,
cmap=cmap,
domain_extent=domain_extent,
dt=dt,
include_init=include_init,
Expand All @@ -151,6 +154,7 @@ def animate(i):
plot_spatio_temporal(
trjs[i],
vlim=vlim,
cmap=cmap,
domain_extent=domain_extent,
dt=dt,
include_init=include_init,
Expand All @@ -169,6 +173,7 @@ def animate_state_2d(
trj: Float[Array, "T 1 N N"],
*,
vlim: tuple[float, float] = (-1.0, 1.0),
cmap: str = "RdBu_r",
domain_extent: float = None,
dt: float = None,
include_init: bool = False,
Expand All @@ -189,6 +194,7 @@ def animate_state_2d(
- `trj`: The trajectory of states to animate. Must be a four-axis array with
shape `(n_timesteps, 1, n_spatial, n_spatial)`.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `cmap`: The colormap to use. Default is `"RdBu_r"`.
- `domain_extent`: The extent of the spatial domain. Default is `None`. This
affects the x- and y-axis limits of the plot.
- `dt`: The time step between each frame. Default is `None`. If provided,
Expand Down Expand Up @@ -219,6 +225,7 @@ def animate_state_2d(
plot_state_2d(
trj[0],
vlim=vlim,
cmap=cmap,
domain_extent=domain_extent,
ax=ax,
)
Expand All @@ -229,6 +236,7 @@ def animate(i):
plot_state_2d(
trj[i],
vlim=vlim,
cmap=cmap,
domain_extent=domain_extent,
ax=ax,
)
Expand Down
10 changes: 8 additions & 2 deletions exponax/viz/_animate_facet.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def animate_spatial_temporal_facet(
*,
facet_over_channels: bool = True,
vlim: tuple[float, float] = (-1.0, 1.0),
cmap: str = "RdBu_r",
domain_extent: float = None,
dt: float = None,
include_init: bool = False,
Expand Down Expand Up @@ -154,6 +155,7 @@ def animate_spatial_temporal_facet(
- `facet_over_channels`: Whether to facet over the channel axis or the batch
axis. Default is `True`.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `cmap`: The colormap to use. Default is `"RdBu_r"`.
- `domain_extent`: The extent of the spatial domain. Default is `None`. This
affects the x-axis limits of the plot.
- `dt`: The time step between each frame. Default is `None`. If provided,
Expand Down Expand Up @@ -184,10 +186,11 @@ def animate_state_2d_facet(
trj: Union[Float[Array, "T C N N"], Float[Array, "B T 1 N N"]],
*,
facet_over_channels: bool = True,
vlim: tuple[float, float] = (-1.0, 1.0),
cmap: str = "RdBu_r",
domain_extent: float = None,
dt: float = None,
include_init: bool = False,
vlim: tuple[float, float] = (-1.0, 1.0),
grid: tuple[int, int] = (3, 3),
figsize: tuple[float, float] = (10, 10),
titles=None,
Expand Down Expand Up @@ -217,13 +220,14 @@ def animate_state_2d_facet(
`facet_over_channels` is `False`.
- `facet_over_channels`: Whether to facet over the channel axis or the batch
axis. Default is `True`.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `cmap`: The colormap to use. Default is `"RdBu_r"`.
- `domain_extent`: The extent of the spatial domain. Default is `None`. This
affects the x-axis and y-axis limits of the plot.
- `dt`: The time step between each frame. Default is `None`.
- `include_init`: Whether to the state starts at an initial condition (t=0)
or at the first frame in the trajectory. This affects is the the time
range is [0, (T-1)dt] or [dt, Tdt]. Default is `False`.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `grid`: The grid of subplots. Default is `(3, 3)`.
- `figsize`: The size of the figure. Default is `(10, 10)`.
- `titles`: The titles for each subplot. Default is `None`.
Expand Down Expand Up @@ -257,6 +261,7 @@ def animate_state_2d_facet(
plot_state_2d(
trj[j, 0],
vlim=vlim,
cmap=cmap,
ax=ax,
domain_extent=domain_extent,
)
Expand All @@ -270,6 +275,7 @@ def animate(i):
plot_state_2d(
trj[j, i],
vlim=vlim,
cmap=cmap,
ax=ax,
)
if titles is not None:
Expand Down
8 changes: 6 additions & 2 deletions exponax/viz/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def plot_spatio_temporal(
trj: Float[Array, "T 1 N"],
*,
vlim: tuple[float, float] = (-1.0, 1.0),
cmap: str = "RdBu_r",
ax=None,
domain_extent: float = None,
dt: float = None,
Expand All @@ -110,6 +111,7 @@ def plot_spatio_temporal(
be the time axis, the second axis the channel axis, and the third axis
the spatial axis.
- `vlim`: The limits of the color scale.
- `cmap`: The colormap to use.
- `ax`: The axis to plot on. If not provided, a new figure will be created.
- `domain_extent`: The extent of the spatial domain. If not provided, the
domain extent will be the number of points in the spatial axis. This
Expand Down Expand Up @@ -153,7 +155,7 @@ def plot_spatio_temporal(
trj_wrapped[:, 0, :].T,
vmin=vlim[0],
vmax=vlim[1],
cmap="RdBu_r",
cmap=cmap,
origin="lower",
aspect="auto",
extent=(*time_range, *space_range),
Expand All @@ -173,6 +175,7 @@ def plot_state_2d(
state: Float[Array, "1 N N"],
*,
vlim: tuple[float, float] = (-1.0, 1.0),
cmap: str = "RdBu_r",
domain_extent: float = None,
ax=None,
**kwargs,
Expand All @@ -193,6 +196,7 @@ def plot_state_2d(
- `state`: The state to plot as a three axis array. The first axis should be
the channel axis, and the subsequent two axes the spatial axes.
- `vlim`: The limits of the color scale.
- `cmap`: The colormap to use.
- `domain_extent`: The extent of the spatial domain. If not provided, the
domain extent will be the number of points in the spatial axes. This
adjusts the x and y axes.
Expand Down Expand Up @@ -225,7 +229,7 @@ def plot_state_2d(
state_wrapped.T,
vmin=vlim[0],
vmax=vlim[1],
cmap="RdBu_r",
cmap=cmap,
origin="lower",
aspect="auto",
extent=(*space_range, *space_range),
Expand Down
6 changes: 6 additions & 0 deletions exponax/viz/_plot_facet.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def plot_spatio_temporal_facet(
*,
facet_over_channels: bool = True,
vlim: tuple[float, float] = (-1.0, 1.0),
cmap: str = "RdBu_r",
grid: tuple[int, int] = (3, 3),
figsize: tuple[float, float] = (10, 10),
titles: list[str] = None,
Expand Down Expand Up @@ -125,6 +126,7 @@ def plot_spatio_temporal_facet(
- `facet_over_channels`: Whether to facet over the channel axis (three axes)
or the batch axis (four axes).
- `vlim`: The limits of the color scale.
- `cmap`: The colormap to use.
- `grid`: The grid layout for the facet plot. This should be a tuple with
two integers. If the number of trajectories is less than the product of
the grid, the remaining axes will be removed.
Expand Down Expand Up @@ -164,6 +166,7 @@ def plot_spatio_temporal_facet(
plot_spatio_temporal(
single_trj,
vlim=vlim,
cmap=cmap,
ax=ax,
domain_extent=domain_extent,
dt=dt,
Expand All @@ -186,6 +189,7 @@ def plot_state_2d_facet(
*,
facet_over_channels: bool = True,
vlim: tuple[float, float] = (-1.0, 1.0),
cmap: str = "RdBu_r",
grid: tuple[int, int] = (3, 3),
figsize: tuple[float, float] = (10, 10),
titles: list[str] = None,
Expand Down Expand Up @@ -215,6 +219,7 @@ def plot_state_2d_facet(
- `facet_over_channels`: Whether to facet over the channel axis (three axes)
or the batch axis (four axes).
- `vlim`: The limits of the color scale.
- `cmap`: The colormap to use.
- `grid`: The grid layout for the facet plot. This should be a tuple with
two integers. If the number of states is less than the product of the
grid, the remaining axes will be removed.
Expand Down Expand Up @@ -245,6 +250,7 @@ def plot_state_2d_facet(
plot_state_2d(
states[i],
vlim=vlim,
cmap=cmap,
ax=ax,
domain_extent=domain_extent,
**kwargs,
Expand Down

0 comments on commit 30a5205

Please sign in to comment.