Skip to content

Commit

Permalink
Fourier repeated stepper (#29)
Browse files Browse the repository at this point in the history
* Implement logic of fourier substepping

* Forward .step() to call signature with shape check

* Add hints from Base Stepper

* Add hints on cost-saving

* Also add hint to constructor

* Add a test for the new repeated stepper
  • Loading branch information
Ceyron authored Sep 4, 2024
1 parent 96a1116 commit 1d52578
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
38 changes: 36 additions & 2 deletions exponax/_repeated_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from jaxtyping import Array, Complex, Float

from ._base_stepper import BaseStepper
from ._spectral import fft, ifft, spatial_shape
from ._utils import repeat


Expand All @@ -25,6 +26,10 @@ def __init__(
Sugarcoat the utility function `repeat` in a callable PyTree for easy
composition with other equinox modules.
!!! info
Performs the substepping in Fourier space to avoid unnecessary
back-and-forth transformations.
One intended usage is to get "more accurate" or "more stable" time steppers
that perform substeps.
Expand Down Expand Up @@ -56,6 +61,10 @@ def step(
Step the PDE forward in time by `self.num_sub_steps` time steps given the
current state `u`.
!!! info
Performs the substepping in Fourier space to avoid unnecessary
back-and-forth transformations.
**Arguments:**
- `u`: The current state.
Expand All @@ -64,7 +73,14 @@ def step(
- `u_next`: The state after `self.num_sub_steps` time steps.
"""
return repeat(self.stepper.step, self.num_sub_steps)(u)
u_hat = fft(u, num_spatial_dims=self.num_spatial_dims)
u_hat_after_steps = self.step_fourier(u_hat)
u_after_steps = ifft(
u_hat_after_steps,
num_spatial_dims=self.num_spatial_dims,
num_points=self.num_points,
)
return u_after_steps

def step_fourier(
self,
Expand Down Expand Up @@ -93,12 +109,30 @@ def __call__(
Step the PDE forward in time by self.num_sub_steps time steps given the
current state `u`.
!!! info
Performs the substepping in Fourier space to avoid unnecessary
back-and-forth transformations.
**Arguments:**
- `u`: The current state.
**Returns:**
- `u_next`: The state after `self.num_sub_steps` time steps.
!!! tip
Use this call method together with `exponax.rollout` to efficiently
produce temporal trajectories.
!!! info
For batched operation, use `jax.vmap` on this function.
"""
return repeat(self.stepper, self.num_sub_steps)(u)
expected_shape = (self.num_channels,) + spatial_shape(
self.num_spatial_dims, self.num_points
)
if u.shape != expected_shape:
raise ValueError(
f"Expected shape {expected_shape}, got {u.shape}. For batched operation use `jax.vmap` on this function."
)
return self.step(u)
28 changes: 28 additions & 0 deletions tests/test_repeated_stepper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import jax
import pytest

import exponax as ex


def test_repeated_stepper():
DOMAIN_EXTENT = 1.0
NUM_POINTS = 81
DT = 0.1
NUM_REPEATS = 10

burgers_stepper = ex.stepper.Burgers(1, DOMAIN_EXTENT, NUM_POINTS, DT)

burgers_stepper_repeated = ex.RepeatedStepper(burgers_stepper, NUM_REPEATS)

burgers_stepper_repeated_manually = ex.repeat(burgers_stepper, NUM_REPEATS)

u_0 = ex.ic.RandomTruncatedFourierSeries(1, max_one=True)(
NUM_POINTS, key=jax.random.PRNGKey(0)
)

u_final = burgers_stepper_repeated(u_0)
u_final_manually = burgers_stepper_repeated_manually(u_0)

# Need a looser rel tolerance because Burgers is a decaying phenomenon,
# hence the expected/reference state has low magnitude after 10 steps.
assert u_final == pytest.approx(u_final_manually, rel=1e-3)

0 comments on commit 1d52578

Please sign in to comment.