Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fourier Interpolation and Up/Downsampling #33

Merged
merged 21 commits into from
Sep 4, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add tests for mapping between resolutions
Ceyron committed Sep 4, 2024
commit baa0511947a1d8de465425d818e06dfc66024aa9
138 changes: 138 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
@@ -163,3 +163,141 @@ def test_fourier_interpolator_3d(

# Looser rel and abs tol because JAX runs in single precision by default
assert interpolated_u == pytest.approx(correct_val, rel=1e-5, abs=1e-5)


@pytest.mark.parametrize(
"fn_name, domain_extent, num_points_old, num_points_new,",
[
("constant", 1.0, 20, 30),
("constant", 1.0, 20, 15),
("simple_sine", 1.0, 20, 30),
("simple_sine", 1.0, 20, 15),
("simple_cosine", 1.0, 20, 30),
("simple_cosine", 1.0, 20, 15),
("complicated_fn", 10.0, 100, 200),
("complicated_fn", 10.0, 100, 51),
],
)
def test_map_between_resolutions_1d(
fn_name: str,
domain_extent: float,
num_points_old: int,
num_points_new: int,
):
fn = FN_DICT_1D[fn_name]
grid_old = ex.make_grid(1, domain_extent, num_points_old)
grid_new = ex.make_grid(1, domain_extent, num_points_new)

u_old = fn(grid_old)

u_new = ex.map_between_resolutions(u_old, num_points_new)

u_new_correct = fn(grid_new)

# Looser rel tol because JAX runs in single precision by default, and the
# FFT incorse some rounding errors
assert u_new == pytest.approx(u_new_correct, rel=10.0, abs=1e-6)


@pytest.mark.parametrize(
"fn_name, domain_extent, num_points_old, num_points_new,",
[
("constant", 1.0, 20, 30),
("constant", 1.0, 20, 15),
("simple_sine_x", 1.0, 20, 30),
("simple_sine_x", 1.0, 20, 15),
("simple_cosine_x", 1.0, 20, 30),
("simple_cosine_x", 1.0, 20, 15),
("simple_sine_y", 1.0, 20, 30),
("simple_sine_y", 1.0, 20, 15),
("simple_cosine_y", 1.0, 20, 30),
("simple_cosine_y", 1.0, 20, 15),
("mixed_sine", 1.0, 20, 30),
("mixed_sine", 1.0, 20, 15),
("mixed_cosine", 1.0, 20, 30),
("mixed_cosine", 1.0, 20, 15),
("complicated_fn_x", 10.0, 100, 200),
("complicated_fn_x", 10.0, 100, 51),
("complicated_fn_y", 10.0, 100, 200),
("complicated_fn_y", 10.0, 100, 51),
("complicated_fn_xy", 10.0, 100, 200),
("complicated_fn_xy", 10.0, 100, 51),
],
)
def test_map_between_resolutions_2d(
fn_name: str,
domain_extent: float,
num_points_old: int,
num_points_new: int,
):
fn = FN_DICT_2D[fn_name]
grid_old = ex.make_grid(2, domain_extent, num_points_old)
grid_new = ex.make_grid(2, domain_extent, num_points_new)

u_old = fn(grid_old)

u_new = ex.map_between_resolutions(u_old, num_points_new)

u_new_correct = fn(grid_new)

# Looser rel tol because JAX runs in single precision by default, and the
# FFT incorse some rounding errors
assert u_new == pytest.approx(u_new_correct, rel=10.0, abs=1e-6)


@pytest.mark.parametrize(
"fn_name, domain_extent, num_points_old, num_points_new,",
[
("constant", 1.0, 20, 30),
("constant", 1.0, 20, 15),
("simple_sine_x", 1.0, 20, 30),
("simple_sine_x", 1.0, 20, 15),
("simple_cosine_x", 1.0, 20, 30),
("simple_cosine_x", 1.0, 20, 15),
("simple_sine_y", 1.0, 20, 30),
("simple_sine_y", 1.0, 20, 15),
("simple_cosine_y", 1.0, 20, 30),
("simple_cosine_y", 1.0, 20, 15),
("simple_sine_z", 1.0, 20, 30),
("simple_sine_z", 1.0, 20, 15),
("simple_cosine_z", 1.0, 20, 30),
("simple_cosine_z", 1.0, 20, 15),
("mixed_sine", 1.0, 20, 30),
("mixed_sine", 1.0, 20, 15),
("mixed_cosine", 1.0, 20, 30),
("mixed_cosine", 1.0, 20, 15),
("complicated_fn_x", 10.0, 40, 50),
("complicated_fn_x", 10.0, 40, 33),
("complicated_fn_y", 10.0, 40, 50),
("complicated_fn_y", 10.0, 40, 33),
("complicated_fn_z", 10.0, 40, 50),
("complicated_fn_z", 10.0, 40, 33),
("complicated_fn_xy", 10.0, 40, 50),
("complicated_fn_xy", 10.0, 40, 33),
("complicated_fn_xz", 10.0, 40, 50),
("complicated_fn_xz", 10.0, 40, 33),
("complicated_fn_yz", 10.0, 40, 50),
("complicated_fn_yz", 10.0, 40, 33),
("complicated_fn_xyz", 10.0, 40, 50),
("complicated_fn_xyz", 10.0, 40, 33),
],
)
def test_map_between_resolutions_3d(
fn_name: str,
domain_extent: float,
num_points_old: int,
num_points_new: int,
):
fn = FN_DICT_3D[fn_name]
grid_old = ex.make_grid(3, domain_extent, num_points_old)
grid_new = ex.make_grid(3, domain_extent, num_points_new)

u_old = fn(grid_old)

u_new = ex.map_between_resolutions(u_old, num_points_new)

u_new_correct = fn(grid_new)

# Looser rel tol because JAX runs in single precision by default, and the
# FFT incorse some rounding errors
assert u_new == pytest.approx(u_new_correct, rel=10.0, abs=3e-4)