Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Feb 27, 2024
1 parent 3edbb2b commit 77133ea
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 34 deletions.
68 changes: 35 additions & 33 deletions tests/test_builtin_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,20 @@ def test_instantiate():
normalized_simulator(num_spatial_dims, num_points)


def test_linear_normalized_stepper():
@pytest.mark.parametrize(
"coefficients",
[
[
0.5,
], # drag
[0.0, -0.3], # advection
[0.0, 0.0, 0.01], # diffusion
[0.0, -0.2, 0.01], # advection-diffusion
[0.0, 0.0, 0.0, 0.001], # dispersion
[0.0, 0.0, 0.0, 0.0, -0.0001], # hyperdiffusion
]
)
def test_linear_normalized_stepper(coefficients):
num_spatial_dims = 1
domain_extent = 3.0
num_points = 50
Expand All @@ -58,37 +71,27 @@ def test_linear_normalized_stepper():
cutoff=5,
)(num_points, key=jax.random.PRNGKey(0))

for coefficients in (
[
0.5,
], # drag
[0.0, -0.3], # advection
[0.0, 0.0, 0.01], # diffusion
[0.0, -0.2, 0.01], # advection-diffusion
[0.0, 0.0, 0.0, 0.001], # dispersion
[0.0, 0.0, 0.0, 0.0, -0.0001], # hyperdiffusion
):
regular_linear_stepper = ex.GeneralLinearStepper(
num_spatial_dims,
regular_linear_stepper = ex.GeneralLinearStepper(
num_spatial_dims,
domain_extent,
num_points,
dt,
coefficients=coefficients,
)
normalized_linear_stepper = ex.NormalizedLinearStepper(
num_spatial_dims,
num_points,
normalized_coefficients=ex.normalize_coefficients(
domain_extent,
num_points,
dt,
coefficients=coefficients,
)
normalized_linear_stepper = ex.NormalizedLinearStepper(
num_spatial_dims,
num_points,
normalized_coefficients=ex.normalize_coefficients(
domain_extent,
dt,
coefficients,
),
)

regular_linear_pred = regular_linear_stepper(u_0)
normalized_linear_pred = normalized_linear_stepper(u_0)

assert regular_linear_pred == pytest.approx(normalized_linear_pred)
coefficients,
),
dt=dt,
)

regular_linear_pred = regular_linear_stepper(u_0)
normalized_linear_pred = normalized_linear_stepper(u_0)

assert regular_linear_pred == pytest.approx(normalized_linear_pred, rel=1e-4)


def test_nonlinear_normalized_stepper():
Expand All @@ -97,7 +100,7 @@ def test_nonlinear_normalized_stepper():
num_points = 50
dt = 0.1
diffusivity = 0.1
convection_scale = 0.5
convection_scale = 1.0

grid = ex.get_grid(num_spatial_dims, domain_extent, num_points)
u_0 = jnp.sin(2 * jnp.pi * grid / domain_extent) + 0.3
Expand All @@ -116,7 +119,6 @@ def test_nonlinear_normalized_stepper():
dt=dt,
normalized_coefficients=ex.normalize_coefficients(
domain_extent,
dt,
[0.0, 0.0, diffusivity],
),
normalized_convection_scale=ex.normalize_convection_scale(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_diffusion_1d():
diffusivity = 0.1

analytical_solution = lambda t, x: jnp.exp(
-4 * (2 * jnp.pi / domain_extent) ** 2 * diffusivity * t
-(4 * 2 * jnp.pi / domain_extent) ** 2 * diffusivity * t
) * jnp.sin(4 * 2 * jnp.pi * x / domain_extent)

grid = ex.get_grid(num_spatial_dims, domain_extent, num_points)
Expand Down

0 comments on commit 77133ea

Please sign in to comment.