Skip to content

Commit

Permalink
Change default to batch rendering
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Jun 11, 2024
1 parent 8c6dcf5 commit beaa035
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 14 deletions.
8 changes: 6 additions & 2 deletions exponax/viz/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def plot_state_3d(
one_channel_state = state[0:1]
one_channel_state_wrapped = wrap_bc(one_channel_state)

img = volume_render_state_3d(
imgs = volume_render_state_3d(
one_channel_state_wrapped,
vlim=vlim,
domain_extent=domain_extent,
Expand All @@ -262,6 +262,8 @@ def plot_state_3d(
**kwargs,
)

img = imgs[0]

if ax is None:
fig, ax = plt.subplots()

Expand Down Expand Up @@ -298,7 +300,7 @@ def plot_spatio_temporal_2d(
jnp.array(trj_one_channel_wrapped.transpose(1, 2, 3, 0)), 3
)

img = volume_render_state_3d(
imgs = volume_render_state_3d(
trj_reshaped_to_3d,
vlim=vlim,
bg_color=bg_color,
Expand All @@ -310,6 +312,8 @@ def plot_spatio_temporal_2d(
**kwargs,
)

img = imgs[0]

if ax is None:
fig, ax = plt.subplots()

Expand Down
51 changes: 39 additions & 12 deletions exponax/viz/_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def zigzag_alpha(cmap, min_alpha=0.0):
)


def chunk_list(lst, n):
for i in range(0, len(lst), n):
yield lst[i : i + n]


def volume_render_state_3d(
states: Float[Array, "B N N N"],
*,
Expand All @@ -66,6 +71,7 @@ def volume_render_state_3d(
transfer_function: callable = zigzag_alpha,
distance_scale: float = 10.0,
gamma_correction: float = 2.4,
chunk_size: int = 64,
**kwargs,
) -> Float[Array, "B resolution resolution 3"]:
"""
Expand All @@ -88,18 +94,39 @@ def volume_render_state_3d(

cmap_with_alpha_transfer = transfer_function(plt.get_cmap(cmap))

imgs = vape.render(
states,
cmap=cmap_with_alpha_transfer,
time=0.0 if states.shape[0] == 1 else np.arange(states.shape[0]),
width=resolution,
height=resolution,
background=bg_color,
vmin=vlim[0],
vmax=vlim[1],
distance_scale=distance_scale,
)

num_images = states.shape[0]

imgs = []
for time_steps in chunk_list(range(num_images), chunk_size):
if num_images == 1:
sub_time_steps = [0.0]
else:
sub_time_steps = [i / (num_images - 1) for i in time_steps]
imgs_this_batch = vape.render(
states,
cmap=cmap_with_alpha_transfer,
time=sub_time_steps,
width=resolution,
height=resolution,
background=bg_color,
vmin=vlim[0],
vmax=vlim[1],
distance_scale=distance_scale,
)
# imgs = vape.render(
# states,
# cmap=cmap_with_alpha_transfer,
# time=[0.0,],
# width=resolution,
# height=resolution,
# background=bg_color,
# vmin=vlim[0],
# vmax=vlim[1],
# distance_scale=distance_scale,
# )
imgs.append(imgs_this_batch)

imgs = np.concatenate(imgs, axis=0)
imgs = ((imgs / 255.0) ** (gamma_correction) * 255).astype(np.uint8)

return imgs

0 comments on commit beaa035

Please sign in to comment.