diff --git a/tests/test_builtin_solvers.py b/tests/test_builtin_solvers.py index 37f45b0..88919a7 100644 --- a/tests/test_builtin_solvers.py +++ b/tests/test_builtin_solvers.py @@ -46,7 +46,20 @@ def test_instantiate(): normalized_simulator(num_spatial_dims, num_points) -def test_linear_normalized_stepper(): +@pytest.mark.parametrize( + "coefficients", + [ + [ + 0.5, + ], # drag + [0.0, -0.3], # advection + [0.0, 0.0, 0.01], # diffusion + [0.0, -0.2, 0.01], # advection-diffusion + [0.0, 0.0, 0.0, 0.001], # dispersion + [0.0, 0.0, 0.0, 0.0, -0.0001], # hyperdiffusion + ] +) +def test_linear_normalized_stepper(coefficients): num_spatial_dims = 1 domain_extent = 3.0 num_points = 50 @@ -58,37 +71,27 @@ def test_linear_normalized_stepper(): cutoff=5, )(num_points, key=jax.random.PRNGKey(0)) - for coefficients in ( - [ - 0.5, - ], # drag - [0.0, -0.3], # advection - [0.0, 0.0, 0.01], # diffusion - [0.0, -0.2, 0.01], # advection-diffusion - [0.0, 0.0, 0.0, 0.001], # dispersion - [0.0, 0.0, 0.0, 0.0, -0.0001], # hyperdiffusion - ): - regular_linear_stepper = ex.GeneralLinearStepper( - num_spatial_dims, + regular_linear_stepper = ex.GeneralLinearStepper( + num_spatial_dims, + domain_extent, + num_points, + dt, + coefficients=coefficients, + ) + normalized_linear_stepper = ex.NormalizedLinearStepper( + num_spatial_dims, + num_points, + normalized_coefficients=ex.normalize_coefficients( domain_extent, - num_points, - dt, - coefficients=coefficients, - ) - normalized_linear_stepper = ex.NormalizedLinearStepper( - num_spatial_dims, - num_points, - normalized_coefficients=ex.normalize_coefficients( - domain_extent, - dt, - coefficients, - ), - ) - - regular_linear_pred = regular_linear_stepper(u_0) - normalized_linear_pred = normalized_linear_stepper(u_0) - - assert regular_linear_pred == pytest.approx(normalized_linear_pred) + coefficients, + ), + dt=dt, + ) + + regular_linear_pred = regular_linear_stepper(u_0) + normalized_linear_pred = normalized_linear_stepper(u_0) + + assert regular_linear_pred == pytest.approx(normalized_linear_pred, rel=1e-4) def test_nonlinear_normalized_stepper(): @@ -97,7 +100,7 @@ def test_nonlinear_normalized_stepper(): num_points = 50 dt = 0.1 diffusivity = 0.1 - convection_scale = 0.5 + convection_scale = 1.0 grid = ex.get_grid(num_spatial_dims, domain_extent, num_points) u_0 = jnp.sin(2 * jnp.pi * grid / domain_extent) + 0.3 @@ -116,7 +119,6 @@ def test_nonlinear_normalized_stepper(): dt=dt, normalized_coefficients=ex.normalize_coefficients( domain_extent, - dt, [0.0, 0.0, diffusivity], ), normalized_convection_scale=ex.normalize_convection_scale( diff --git a/tests/test_validation.py b/tests/test_validation.py index e617cb9..d1ae329 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -46,7 +46,7 @@ def test_diffusion_1d(): diffusivity = 0.1 analytical_solution = lambda t, x: jnp.exp( - -4 * (2 * jnp.pi / domain_extent) ** 2 * diffusivity * t + -(4 * 2 * jnp.pi / domain_extent) ** 2 * diffusivity * t ) * jnp.sin(4 * 2 * jnp.pi * x / domain_extent) grid = ex.get_grid(num_spatial_dims, domain_extent, num_points)