Skip to content

Commit

Permalink
Start setting up qualitative rollouts
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Mar 13, 2024
1 parent a1fcbfb commit 6dc8919
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Folder created by visual validation
validation/qualitative_rollouts/
69 changes: 69 additions & 0 deletions validation/qualitative_rollouts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
Work in Progress.
"""

import os
import sys
from pathlib import Path

import jax
import matplotlib.pyplot as plt
from tqdm import tqdm

sys.path.append(".")
import exponax as ex # noqa: E402

ic_key = jax.random.PRNGKey(0)

CONFIGURATIONS_1D = [
(
ex.stepper.Advection(1, 3.0, 110, 0.01, velocity=0.3),
"advection",
ex.ic.RandomTruncatedFourierSeries(1, cutoff=5),
100,
(-1.0, 1.0),
),
(
ex.stepper.Diffusion(1, 3.0, 110, 0.01, diffusivity=0.01),
"diffusion",
ex.ic.RandomTruncatedFourierSeries(1, cutoff=5),
100,
(-1.0, 1.0),
),
]

p_meter = tqdm(CONFIGURATIONS_1D, desc="", total=len(CONFIGURATIONS_1D))
dir_path = Path(os.path.dirname(os.path.realpath(__file__)))
img_folder = dir_path / Path("qualitative_rollouts")
img_folder.mkdir(exist_ok=True)


# 1d problems (produce spatio-temporal plots)
for stepper_1d, name, ic_distribution, steps, vlim in CONFIGURATIONS_1D:
p_meter.set_description(f"1d {name}")

ic = ic_distribution(stepper_1d.num_points, key=ic_key)
trj = ex.rollout(stepper_1d, steps, include_init=True)(ic)
num_channels = stepper_1d.num_channels
fig, ax_s = plt.subplots(num_channels, 1, figsize=(5, 5 * num_channels))
if num_channels == 1:
ax_s = [
ax_s,
]
for i, ax in enumerate(ax_s):
ax.imshow(
trj[:, i, :].T,
aspect="auto",
origin="lower",
vmin=vlim[0],
vmax=vlim[1],
cmap="RdBu_r",
)
ax.set_title(f"{name} channel {i}")
ax.set_xlabel("time")
ax.set_ylabel("space")

fig.savefig(img_folder / f"{name}_1d.png")
plt.close(fig)

p_meter.update(1)

0 comments on commit 6dc8919

Please sign in to comment.