-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Implement logic of fourier substepping * Forward .step() to call signature with shape check * Add hints from Base Stepper * Add hints on cost-saving * Also add hint to constructor * Add a test for the new repeated stepper
- Loading branch information
Showing
2 changed files
with
64 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import jax | ||
import pytest | ||
|
||
import exponax as ex | ||
|
||
|
||
def test_repeated_stepper(): | ||
DOMAIN_EXTENT = 1.0 | ||
NUM_POINTS = 81 | ||
DT = 0.1 | ||
NUM_REPEATS = 10 | ||
|
||
burgers_stepper = ex.stepper.Burgers(1, DOMAIN_EXTENT, NUM_POINTS, DT) | ||
|
||
burgers_stepper_repeated = ex.RepeatedStepper(burgers_stepper, NUM_REPEATS) | ||
|
||
burgers_stepper_repeated_manually = ex.repeat(burgers_stepper, NUM_REPEATS) | ||
|
||
u_0 = ex.ic.RandomTruncatedFourierSeries(1, max_one=True)( | ||
NUM_POINTS, key=jax.random.PRNGKey(0) | ||
) | ||
|
||
u_final = burgers_stepper_repeated(u_0) | ||
u_final_manually = burgers_stepper_repeated_manually(u_0) | ||
|
||
# Need a looser rel tolerance because Burgers is a decaying phenomenon, | ||
# hence the expected/reference state has low magnitude after 10 steps. | ||
assert u_final == pytest.approx(u_final_manually, rel=1e-3) |