Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More validation #39

Merged
merged 21 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions exponax/_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,13 +932,24 @@ def power_in_bucket(p, k):
mask = (wavenumbers_norm[0] >= lower_limit) & (
wavenumbers_norm[0] < upper_limit
)
return jnp.sum(p[mask])
# return jnp.sum(p[mask])
return jnp.where(
mask,
p,
0.0,
).sum()

for k in wavenumbers_1d[0, :]:
spectrum.append(jax.vmap(power_in_bucket, in_axes=(0, None))(magnitude, k))
def scan_fn(_, k):
return None, jax.vmap(power_in_bucket, in_axes=(0, None))(magnitude, k)

spectrum = jnp.stack(spectrum, axis=-1)
# spectrum /= jnp.sum(spectrum, axis=-1, keepdims=True)
_, spectrum = jax.lax.scan(scan_fn, None, wavenumbers_1d[0, :])

spectrum = jnp.moveaxis(spectrum, 0, -1)

# for k in wavenumbers_1d[0, :]:
# spectrum.append(jax.vmap(power_in_bucket, in_axes=(0, None))(magnitude, k))

# spectrum = jnp.stack(spectrum, axis=-1)

return spectrum

Expand Down
7 changes: 7 additions & 0 deletions exponax/viz/_animate_facet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Array, Float
from matplotlib.animation import FuncAnimation

Expand Down Expand Up @@ -73,6 +74,8 @@ def animate_state_1d_facet(

num_subplots = trj.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for j, ax in enumerate(ax_s.flatten()):
plot_state_1d(
trj[j, 0],
Expand Down Expand Up @@ -257,6 +260,8 @@ def animate_state_2d_facet(

fig, ax_s = plt.subplots(*grid, sharex=True, sharey=True, figsize=figsize)

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for j, ax in enumerate(ax_s.flatten()):
plot_state_2d(
trj[j, 0],
Expand Down Expand Up @@ -412,6 +417,8 @@ def animate_state_3d_facet(

# num_subplots = trj.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for j, ax in enumerate(ax_s.flatten()):
ax.imshow(imgs[j, 0])
ax.axis("off")
Expand Down
11 changes: 11 additions & 0 deletions exponax/viz/_plot_facet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Array, Float

from ._plot import (
Expand Down Expand Up @@ -68,6 +69,8 @@ def plot_state_1d_facet(

num_batches = states.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for i, ax in enumerate(ax_s.flatten()):
if i < num_batches:
plot_state_1d(
Expand Down Expand Up @@ -161,6 +164,8 @@ def plot_spatio_temporal_facet(

num_subplots = trjs.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for i, ax in enumerate(ax_s.flatten()):
single_trj = trjs[i]
plot_spatio_temporal(
Expand Down Expand Up @@ -247,6 +252,8 @@ def plot_state_2d_facet(

num_subplots = states.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for i, ax in enumerate(ax_s.flatten()):
plot_state_2d(
states[i],
Expand Down Expand Up @@ -343,6 +350,8 @@ def plot_state_3d_facet(

num_subplots = states.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for i, ax in enumerate(ax_s.flatten()):
plot_state_3d(
states[i],
Expand Down Expand Up @@ -454,6 +463,8 @@ def plot_spatio_temporal_2d_facet(

num_subplots = trjs.shape[0]

if grid[0] * grid[1] == 1:
ax_s = np.array([[ax_s]])
for i, ax in enumerate(ax_s.flatten()):
single_trj = trjs[i]
plot_spatio_temporal_2d(
Expand Down
93 changes: 93 additions & 0 deletions tests/test_linear_components_of_nonlinear_solvers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import jax
import pytest

import exponax as ex


@pytest.mark.parametrize("num_spatial_dims", [1, 2, 3])
def test_kdv(num_spatial_dims: int):
DOMAIN_EXTENT = 5.0
NUM_POINTS = 48
DT = 0.01
DIFFUSIVITY = 0.1
DISPERSIVITY = 0.001
HYPER_DIFFUSIVITY = 0.0001

u_0 = ex.ic.RandomTruncatedFourierSeries(
num_spatial_dims=num_spatial_dims, max_one=True
)(NUM_POINTS, key=jax.random.PRNGKey(0))

kdv_stepper_only_viscous = ex.stepper.KortewegDeVries(
num_spatial_dims=num_spatial_dims,
domain_extent=DOMAIN_EXTENT,
num_points=NUM_POINTS,
dt=DT,
convection_scale=0.0,
diffusivity=DIFFUSIVITY,
dispersivity=0.0,
hyper_diffusivity=0.0,
advect_over_diffuse=False,
diffuse_over_diffuse=False,
single_channel=True,
)
diffusion_stepper = ex.stepper.Diffusion(
num_spatial_dims=num_spatial_dims,
domain_extent=DOMAIN_EXTENT,
num_points=NUM_POINTS,
dt=DT,
diffusivity=DIFFUSIVITY,
)

assert kdv_stepper_only_viscous(u_0) == pytest.approx(
diffusion_stepper(u_0), abs=1e-6
)

kdv_stepper_only_dispersion = ex.stepper.KortewegDeVries(
num_spatial_dims=num_spatial_dims,
domain_extent=DOMAIN_EXTENT,
num_points=NUM_POINTS,
dt=DT,
convection_scale=0.0,
diffusivity=0.0,
dispersivity=-DISPERSIVITY,
hyper_diffusivity=0.0,
advect_over_diffuse=False,
diffuse_over_diffuse=False,
single_channel=True,
)
dispersion_stepper = ex.stepper.Dispersion(
num_spatial_dims=num_spatial_dims,
domain_extent=DOMAIN_EXTENT,
num_points=NUM_POINTS,
dt=DT,
dispersivity=DISPERSIVITY,
)

assert kdv_stepper_only_dispersion(u_0) == pytest.approx(
dispersion_stepper(u_0), abs=1e-6
)

kdv_stepper_only_hyper_diffusion = ex.stepper.KortewegDeVries(
num_spatial_dims=num_spatial_dims,
domain_extent=DOMAIN_EXTENT,
num_points=NUM_POINTS,
dt=DT,
convection_scale=0.0,
diffusivity=0.0,
dispersivity=0.0,
hyper_diffusivity=HYPER_DIFFUSIVITY,
advect_over_diffuse=False,
diffuse_over_diffuse=False,
single_channel=True,
)
hyper_diffusion_stepper = ex.stepper.HyperDiffusion(
num_spatial_dims=num_spatial_dims,
domain_extent=DOMAIN_EXTENT,
num_points=NUM_POINTS,
dt=DT,
hyper_diffusivity=HYPER_DIFFUSIVITY,
)

assert kdv_stepper_only_hyper_diffusion(u_0) == pytest.approx(
hyper_diffusion_stepper(u_0), abs=1e-6
)
155 changes: 147 additions & 8 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,68 @@ def test_advection_1d():
assert u_1_pred == pytest.approx(u_1, rel=1e-4)


def test_advection_2d():
num_spatial_dims = 2
domain_extent = 10.0
num_points = 100
dt = 0.1
velocity = jnp.array([0.1, 0.2])

analytical_solution = lambda t, x: jnp.sin(
4 * 2 * jnp.pi * (x[0:1] - velocity[0] * t) / domain_extent
) * jnp.sin(6 * 2 * jnp.pi * (x[1:2] - velocity[1] * t) / domain_extent)

grid = ex.make_grid(num_spatial_dims, domain_extent, num_points)
u_0 = analytical_solution(0.0, grid)
u_1 = analytical_solution(dt, grid)

stepper = ex.stepper.Advection(
num_spatial_dims,
domain_extent,
num_points,
dt,
velocity=velocity,
)

u_1_pred = stepper(u_0)

# Primarily only check the absolute difference
assert u_1_pred == pytest.approx(u_1, rel=1e-3)


def test_advection_3d():
num_spatial_dims = 3
domain_extent = 10.0
num_points = 40
dt = 0.1
velocity = jnp.array([0.1, 0.2, 0.3])

analytical_solution = (
lambda t, x: jnp.sin(
4 * 2 * jnp.pi * (x[0:1] - velocity[0] * t) / domain_extent
)
* jnp.sin(6 * 2 * jnp.pi * (x[1:2] - velocity[1] * t) / domain_extent)
* jnp.sin(8 * 2 * jnp.pi * (x[2:3] - velocity[2] * t) / domain_extent)
)

grid = ex.make_grid(num_spatial_dims, domain_extent, num_points)
u_0 = analytical_solution(0.0, grid)
u_1 = analytical_solution(dt, grid)

stepper = ex.stepper.Advection(
num_spatial_dims,
domain_extent,
num_points,
dt,
velocity=velocity,
)

u_1_pred = stepper(u_0)

# Primarily only check the absolute difference
assert u_1_pred == pytest.approx(u_1, rel=1e-3)


def test_diffusion_1d():
num_spatial_dims = 1
domain_extent = 10.0
Expand Down Expand Up @@ -66,6 +128,89 @@ def test_diffusion_1d():
assert u_1_pred == pytest.approx(u_1, abs=1e-5)


def test_diffusion_2d():
num_spatial_dims = 2
domain_extent = 10.0
num_points = 100
dt = 0.1
diffusivity = 0.1

def analytical_solution(t, x):
# Third sine mode in x-direction and fourth sine mode in y-direction
third_sine_mode_x = jnp.sin(3 * 2 * jnp.pi * x[0:1] / domain_extent)
fourth_sine_mode_y = jnp.sin(4 * 2 * jnp.pi * x[1:2] / domain_extent)
exponent = (
-diffusivity
* (
(3 * 2 * jnp.pi / domain_extent) ** 2
+ (4 * 2 * jnp.pi / domain_extent) ** 2
)
* t
)
exp_term = jnp.exp(exponent)
return exp_term * third_sine_mode_x * fourth_sine_mode_y

grid = ex.make_grid(num_spatial_dims, domain_extent, num_points)

u_0 = analytical_solution(0.0, grid)
u_1 = analytical_solution(dt, grid)

stepper = ex.stepper.Diffusion(
num_spatial_dims,
domain_extent,
num_points,
dt,
diffusivity=diffusivity,
)

u_1_pred = stepper(u_0)

assert u_1_pred == pytest.approx(u_1, abs=1e-5)


def test_diffusion_3d():
num_spatial_dims = 3
domain_extent = 10.0
num_points = 40
dt = 0.1
diffusivity = 0.1

def analytical_solution(t, x):
# Third sine mode in x-direction, fourth sine mode in y-direction, and
# fifth sine mode in z-direction
third_sine_mode_x = jnp.sin(3 * 2 * jnp.pi * x[0:1] / domain_extent)
fourth_sine_mode_y = jnp.sin(4 * 2 * jnp.pi * x[1:2] / domain_extent)
fifth_sine_mode_z = jnp.sin(5 * 2 * jnp.pi * x[2:3] / domain_extent)
exponent = (
-diffusivity
* (
(3 * 2 * jnp.pi / domain_extent) ** 2
+ (4 * 2 * jnp.pi / domain_extent) ** 2
+ (5 * 2 * jnp.pi / domain_extent) ** 2
)
* t
)
exp_term = jnp.exp(exponent)
return exp_term * third_sine_mode_x * fourth_sine_mode_y * fifth_sine_mode_z

grid = ex.make_grid(num_spatial_dims, domain_extent, num_points)

u_0 = analytical_solution(0.0, grid)
u_1 = analytical_solution(dt, grid)

stepper = ex.stepper.Diffusion(
num_spatial_dims,
domain_extent,
num_points,
dt,
diffusivity=diffusivity,
)

u_1_pred = stepper(u_0)

assert u_1_pred == pytest.approx(u_1, abs=1e-5)


def test_validation_poisson_1d():
DOMAIN_EXTENT = 1.0
NUM_POINTS = 50
Expand Down Expand Up @@ -147,11 +292,5 @@ def test_validation_poisson_3d():
assert u == pytest.approx(analytical_solution, abs=1e-6)


# Nonlinear steppers

# Burgers can be test by comparing it with the solution obtained by Cole-Hopf
# transformation.


# The Korteveg-de Vries equation has an analytical solution, given the initial
# condition is a soliton.
# Find more validations that do not fit the format of a unit test in the
# `validation/` directory. For example, the following validations:
9 changes: 9 additions & 0 deletions validation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Validation

Since Fourier-pseudo spectral ETDRK methods are exact for linear bandlimited
problems (on periodic domains), this can be automatically validated to machine precision and is done in `tests/test_validation.py`. This folder contains additional validation notebooks for specific problems.

Additionally, run the script `qualitative rollouts` to produce a set
visualizations (1D -> spatio-temporal, 2D & 3D -> animations) of the
trajectories of the pre-built solvers. References to this can be found at:
https://github.com/Ceyron/exponax_qualitative_rollouts
Loading
Loading