Skip to content

Commit

Permalink
More validation (#39)
Browse files Browse the repository at this point in the history
* Add note

* Add 2d and 3d advection

* Add diffusion analytical solution in 2D

* Add diffusion test in 3D

* Ensure all presets are working

* Employ a fix for single facets

* Use Exponax viz routines

* Also employ fix for faceted animations

* Fix defaults for Kolmogorov

* Change to exponax viz routines

* Add 3D dynamics

* Add hint on where to find the reference qualitative rollouts

* Re-Execute Notebook

* Re-Execute Notebook

* Re-Execute notebook

* Start with a validation notebook for KdV

* Test if KdV without convection is purely linear

* Finished KdV soliton comparison

* Faster way to compute the spectrum binning in higher dimensions

* Improved version of kolmogorov comparison
  • Loading branch information
Ceyron authored Sep 12, 2024
1 parent 034b4c5 commit a17af96
Show file tree
Hide file tree
Showing 12 changed files with 1,184 additions and 339 deletions.
21 changes: 16 additions & 5 deletions exponax/_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,13 +932,24 @@ def power_in_bucket(p, k):
mask = (wavenumbers_norm[0] >= lower_limit) & (
wavenumbers_norm[0] < upper_limit
)
return jnp.sum(p[mask])
# return jnp.sum(p[mask])
return jnp.where(
mask,
p,
0.0,
).sum()

for k in wavenumbers_1d[0, :]:
spectrum.append(jax.vmap(power_in_bucket, in_axes=(0, None))(magnitude, k))
def scan_fn(_, k):
return None, jax.vmap(power_in_bucket, in_axes=(0, None))(magnitude, k)

spectrum = jnp.stack(spectrum, axis=-1)
# spectrum /= jnp.sum(spectrum, axis=-1, keepdims=True)
_, spectrum = jax.lax.scan(scan_fn, None, wavenumbers_1d[0, :])

spectrum = jnp.moveaxis(spectrum, 0, -1)

# for k in wavenumbers_1d[0, :]:
# spectrum.append(jax.vmap(power_in_bucket, in_axes=(0, None))(magnitude, k))

# spectrum = jnp.stack(spectrum, axis=-1)

return spectrum

Expand Down
7 changes: 7 additions & 0 deletions exponax/viz/_animate_facet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Array, Float
from matplotlib.animation import FuncAnimation

Expand Down Expand Up @@ -73,6 +74,8 @@ def animate_state_1d_facet(

num_subplots = trj.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for j, ax in enumerate(ax_s.flatten()):
plot_state_1d(
trj[j, 0],
Expand Down Expand Up @@ -257,6 +260,8 @@ def animate_state_2d_facet(

fig, ax_s = plt.subplots(*grid, sharex=True, sharey=True, figsize=figsize)

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for j, ax in enumerate(ax_s.flatten()):
plot_state_2d(
trj[j, 0],
Expand Down Expand Up @@ -412,6 +417,8 @@ def animate_state_3d_facet(

# num_subplots = trj.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for j, ax in enumerate(ax_s.flatten()):
ax.imshow(imgs[j, 0])
ax.axis("off")
Expand Down
11 changes: 11 additions & 0 deletions exponax/viz/_plot_facet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Array, Float

from ._plot import (
Expand Down Expand Up @@ -68,6 +69,8 @@ def plot_state_1d_facet(

num_batches = states.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for i, ax in enumerate(ax_s.flatten()):
if i < num_batches:
plot_state_1d(
Expand Down Expand Up @@ -161,6 +164,8 @@ def plot_spatio_temporal_facet(

num_subplots = trjs.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for i, ax in enumerate(ax_s.flatten()):
single_trj = trjs[i]
plot_spatio_temporal(
Expand Down Expand Up @@ -247,6 +252,8 @@ def plot_state_2d_facet(

num_subplots = states.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for i, ax in enumerate(ax_s.flatten()):
plot_state_2d(
states[i],
Expand Down Expand Up @@ -343,6 +350,8 @@ def plot_state_3d_facet(

num_subplots = states.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for i, ax in enumerate(ax_s.flatten()):
plot_state_3d(
states[i],
Expand Down Expand Up @@ -454,6 +463,8 @@ def plot_spatio_temporal_2d_facet(

num_subplots = trjs.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for i, ax in enumerate(ax_s.flatten()):
single_trj = trjs[i]
plot_spatio_temporal_2d(
Expand Down
93 changes: 93 additions & 0 deletions tests/test_linear_components_of_nonlinear_solvers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import jax
import pytest

import exponax as ex


@pytest.mark.parametrize("num_spatial_dims", [1, 2, 3])
def test_kdv(num_spatial_dims: int):
DOMAIN_EXTENT = 5.0
NUM_POINTS = 48
DT = 0.01
DIFFUSIVITY = 0.1
DISPERSIVITY = 0.001
HYPER_DIFFUSIVITY = 0.0001

u_0 = ex.ic.RandomTruncatedFourierSeries(
num_spatial_dims=num_spatial_dims, max_one=True
)(NUM_POINTS, key=jax.random.PRNGKey(0))

kdv_stepper_only_viscous = ex.stepper.KortewegDeVries(
num_spatial_dims=num_spatial_dims,
domain_extent=DOMAIN_EXTENT,
num_points=NUM_POINTS,
dt=DT,
convection_scale=0.0,
diffusivity=DIFFUSIVITY,
dispersivity=0.0,
hyper_diffusivity=0.0,
advect_over_diffuse=False,
diffuse_over_diffuse=False,
single_channel=True,
)
diffusion_stepper = ex.stepper.Diffusion(
num_spatial_dims=num_spatial_dims,
domain_extent=DOMAIN_EXTENT,
num_points=NUM_POINTS,
dt=DT,
diffusivity=DIFFUSIVITY,
)

assert kdv_stepper_only_viscous(u_0) == pytest.approx(
diffusion_stepper(u_0), abs=1e-6
)

kdv_stepper_only_dispersion = ex.stepper.KortewegDeVries(
num_spatial_dims=num_spatial_dims,
domain_extent=DOMAIN_EXTENT,
num_points=NUM_POINTS,
dt=DT,
convection_scale=0.0,
diffusivity=0.0,
dispersivity=-DISPERSIVITY,
hyper_diffusivity=0.0,
advect_over_diffuse=False,
diffuse_over_diffuse=False,
single_channel=True,
)
dispersion_stepper = ex.stepper.Dispersion(
num_spatial_dims=num_spatial_dims,
domain_extent=DOMAIN_EXTENT,
num_points=NUM_POINTS,
dt=DT,
dispersivity=DISPERSIVITY,
)

assert kdv_stepper_only_dispersion(u_0) == pytest.approx(
dispersion_stepper(u_0), abs=1e-6
)

kdv_stepper_only_hyper_diffusion = ex.stepper.KortewegDeVries(
num_spatial_dims=num_spatial_dims,
domain_extent=DOMAIN_EXTENT,
num_points=NUM_POINTS,
dt=DT,
convection_scale=0.0,
diffusivity=0.0,
dispersivity=0.0,
hyper_diffusivity=HYPER_DIFFUSIVITY,
advect_over_diffuse=False,
diffuse_over_diffuse=False,
single_channel=True,
)
hyper_diffusion_stepper = ex.stepper.HyperDiffusion(
num_spatial_dims=num_spatial_dims,
domain_extent=DOMAIN_EXTENT,
num_points=NUM_POINTS,
dt=DT,
hyper_diffusivity=HYPER_DIFFUSIVITY,
)

assert kdv_stepper_only_hyper_diffusion(u_0) == pytest.approx(
hyper_diffusion_stepper(u_0), abs=1e-6
)
155 changes: 147 additions & 8 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,68 @@ def test_advection_1d():
assert u_1_pred == pytest.approx(u_1, rel=1e-4)


def test_advection_2d():
num_spatial_dims = 2
domain_extent = 10.0
num_points = 100
dt = 0.1
velocity = jnp.array([0.1, 0.2])

analytical_solution = lambda t, x: jnp.sin(
4 * 2 * jnp.pi * (x[0:1] - velocity[0] * t) / domain_extent
) * jnp.sin(6 * 2 * jnp.pi * (x[1:2] - velocity[1] * t) / domain_extent)

grid = ex.make_grid(num_spatial_dims, domain_extent, num_points)
u_0 = analytical_solution(0.0, grid)
u_1 = analytical_solution(dt, grid)

stepper = ex.stepper.Advection(
num_spatial_dims,
domain_extent,
num_points,
dt,
velocity=velocity,
)

u_1_pred = stepper(u_0)

# Primarily only check the absolute difference
assert u_1_pred == pytest.approx(u_1, rel=1e-3)


def test_advection_3d():
num_spatial_dims = 3
domain_extent = 10.0
num_points = 40
dt = 0.1
velocity = jnp.array([0.1, 0.2, 0.3])

analytical_solution = (
lambda t, x: jnp.sin(
4 * 2 * jnp.pi * (x[0:1] - velocity[0] * t) / domain_extent
)
* jnp.sin(6 * 2 * jnp.pi * (x[1:2] - velocity[1] * t) / domain_extent)
* jnp.sin(8 * 2 * jnp.pi * (x[2:3] - velocity[2] * t) / domain_extent)
)

grid = ex.make_grid(num_spatial_dims, domain_extent, num_points)
u_0 = analytical_solution(0.0, grid)
u_1 = analytical_solution(dt, grid)

stepper = ex.stepper.Advection(
num_spatial_dims,
domain_extent,
num_points,
dt,
velocity=velocity,
)

u_1_pred = stepper(u_0)

# Primarily only check the absolute difference
assert u_1_pred == pytest.approx(u_1, rel=1e-3)


def test_diffusion_1d():
num_spatial_dims = 1
domain_extent = 10.0
Expand Down Expand Up @@ -66,6 +128,89 @@ def test_diffusion_1d():
assert u_1_pred == pytest.approx(u_1, abs=1e-5)


def test_diffusion_2d():
num_spatial_dims = 2
domain_extent = 10.0
num_points = 100
dt = 0.1
diffusivity = 0.1

def analytical_solution(t, x):
# Third sine mode in x-direction and fourth sine mode in y-direction
third_sine_mode_x = jnp.sin(3 * 2 * jnp.pi * x[0:1] / domain_extent)
fourth_sine_mode_y = jnp.sin(4 * 2 * jnp.pi * x[1:2] / domain_extent)
exponent = (
-diffusivity
* (
(3 * 2 * jnp.pi / domain_extent) ** 2
+ (4 * 2 * jnp.pi / domain_extent) ** 2
)
* t
)
exp_term = jnp.exp(exponent)
return exp_term * third_sine_mode_x * fourth_sine_mode_y

grid = ex.make_grid(num_spatial_dims, domain_extent, num_points)

u_0 = analytical_solution(0.0, grid)
u_1 = analytical_solution(dt, grid)

stepper = ex.stepper.Diffusion(
num_spatial_dims,
domain_extent,
num_points,
dt,
diffusivity=diffusivity,
)

u_1_pred = stepper(u_0)

assert u_1_pred == pytest.approx(u_1, abs=1e-5)


def test_diffusion_3d():
num_spatial_dims = 3
domain_extent = 10.0
num_points = 40
dt = 0.1
diffusivity = 0.1

def analytical_solution(t, x):
# Third sine mode in x-direction, fourth sine mode in y-direction, and
# fifth sine mode in z-direction
third_sine_mode_x = jnp.sin(3 * 2 * jnp.pi * x[0:1] / domain_extent)
fourth_sine_mode_y = jnp.sin(4 * 2 * jnp.pi * x[1:2] / domain_extent)
fifth_sine_mode_z = jnp.sin(5 * 2 * jnp.pi * x[2:3] / domain_extent)
exponent = (
-diffusivity
* (
(3 * 2 * jnp.pi / domain_extent) ** 2
+ (4 * 2 * jnp.pi / domain_extent) ** 2
+ (5 * 2 * jnp.pi / domain_extent) ** 2
)
* t
)
exp_term = jnp.exp(exponent)
return exp_term * third_sine_mode_x * fourth_sine_mode_y * fifth_sine_mode_z

grid = ex.make_grid(num_spatial_dims, domain_extent, num_points)

u_0 = analytical_solution(0.0, grid)
u_1 = analytical_solution(dt, grid)

stepper = ex.stepper.Diffusion(
num_spatial_dims,
domain_extent,
num_points,
dt,
diffusivity=diffusivity,
)

u_1_pred = stepper(u_0)

assert u_1_pred == pytest.approx(u_1, abs=1e-5)


def test_validation_poisson_1d():
DOMAIN_EXTENT = 1.0
NUM_POINTS = 50
Expand Down Expand Up @@ -147,11 +292,5 @@ def test_validation_poisson_3d():
assert u == pytest.approx(analytical_solution, abs=1e-6)


# Nonlinear steppers

# Burgers can be test by comparing it with the solution obtained by Cole-Hopf
# transformation.


# The Korteveg-de Vries equation has an analytical solution, given the initial
# condition is a soliton.
# Find more validations that do not fit the format of a unit test in the
# `validation/` directory. For example, the following validations:
9 changes: 9 additions & 0 deletions validation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Validation

Since Fourier-pseudo spectral ETDRK methods are exact for linear bandlimited
problems (on periodic domains), this can be automatically validated to machine precision and is done in `tests/test_validation.py`. This folder contains additional validation notebooks for specific problems.

Additionally, run the script `qualitative rollouts` to produce a set
visualizations (1D -> spatio-temporal, 2D & 3D -> animations) of the
trajectories of the pre-built solvers. References to this can be found at:
https://github.com/Ceyron/exponax_qualitative_rollouts
Loading

0 comments on commit a17af96

Please sign in to comment.