diff --git a/exponax/_repeated_stepper.py b/exponax/_repeated_stepper.py index 4914622..f17c033 100644 --- a/exponax/_repeated_stepper.py +++ b/exponax/_repeated_stepper.py @@ -2,6 +2,7 @@ from jaxtyping import Array, Complex, Float from ._base_stepper import BaseStepper +from ._spectral import fft, ifft, spatial_shape from ._utils import repeat @@ -25,6 +26,10 @@ def __init__( Sugarcoat the utility function `repeat` in a callable PyTree for easy composition with other equinox modules. + !!! info + Performs the substepping in Fourier space to avoid unnecessary + back-and-forth transformations. + One intended usage is to get "more accurate" or "more stable" time steppers that perform substeps. @@ -56,6 +61,10 @@ def step( Step the PDE forward in time by `self.num_sub_steps` time steps given the current state `u`. + !!! info + Performs the substepping in Fourier space to avoid unnecessary + back-and-forth transformations. + **Arguments:** - `u`: The current state. @@ -64,7 +73,14 @@ def step( - `u_next`: The state after `self.num_sub_steps` time steps. """ - return repeat(self.stepper.step, self.num_sub_steps)(u) + u_hat = fft(u, num_spatial_dims=self.num_spatial_dims) + u_hat_after_steps = self.step_fourier(u_hat) + u_after_steps = ifft( + u_hat_after_steps, + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + ) + return u_after_steps def step_fourier( self, @@ -93,6 +109,10 @@ def __call__( Step the PDE forward in time by self.num_sub_steps time steps given the current state `u`. + !!! info + Performs the substepping in Fourier space to avoid unnecessary + back-and-forth transformations. + **Arguments:** - `u`: The current state. @@ -100,5 +120,19 @@ def __call__( **Returns:** - `u_next`: The state after `self.num_sub_steps` time steps. + + !!! tip + Use this call method together with `exponax.rollout` to efficiently + produce temporal trajectories. + + !!! info + For batched operation, use `jax.vmap` on this function. """ - return repeat(self.stepper, self.num_sub_steps)(u) + expected_shape = (self.num_channels,) + spatial_shape( + self.num_spatial_dims, self.num_points + ) + if u.shape != expected_shape: + raise ValueError( + f"Expected shape {expected_shape}, got {u.shape}. For batched operation use `jax.vmap` on this function." + ) + return self.step(u) diff --git a/tests/test_repeated_stepper.py b/tests/test_repeated_stepper.py new file mode 100644 index 0000000..4e486bc --- /dev/null +++ b/tests/test_repeated_stepper.py @@ -0,0 +1,28 @@ +import jax +import pytest + +import exponax as ex + + +def test_repeated_stepper(): + DOMAIN_EXTENT = 1.0 + NUM_POINTS = 81 + DT = 0.1 + NUM_REPEATS = 10 + + burgers_stepper = ex.stepper.Burgers(1, DOMAIN_EXTENT, NUM_POINTS, DT) + + burgers_stepper_repeated = ex.RepeatedStepper(burgers_stepper, NUM_REPEATS) + + burgers_stepper_repeated_manually = ex.repeat(burgers_stepper, NUM_REPEATS) + + u_0 = ex.ic.RandomTruncatedFourierSeries(1, max_one=True)( + NUM_POINTS, key=jax.random.PRNGKey(0) + ) + + u_final = burgers_stepper_repeated(u_0) + u_final_manually = burgers_stepper_repeated_manually(u_0) + + # Need a looser rel tolerance because Burgers is a decaying phenomenon, + # hence the expected/reference state has low magnitude after 10 steps. + assert u_final == pytest.approx(u_final_manually, rel=1e-3)