Skip to content

Commit

Permalink
Merge branch 'dev' into spectral-updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Sep 5, 2024
2 parents 375623e + 1b2f310 commit f639e14
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 41 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
run-test:
strategy:
matrix:
python-version: [ 3.10, 3.11, 3.12 ]
python-version: [ "3.10", "3.11", "3.12" ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
Expand Down
4 changes: 3 additions & 1 deletion tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
jaxlib
jaxlib
pytest
vape4d
4 changes: 3 additions & 1 deletion tests/test_builtin_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,4 +366,6 @@ def test_nonlinear_normalized_stepper():
regular_burgers_pred = regular_burgers_stepper(u_0)
normalized_burgers_pred = normalized_burgers_stepper(u_0)

assert regular_burgers_pred == pytest.approx(normalized_burgers_pred)
assert regular_burgers_pred == pytest.approx(
normalized_burgers_pred, rel=1e-5, abs=1e-5
)
90 changes: 56 additions & 34 deletions tests/test_mode_slices_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,73 +3,95 @@

def test_mode_slices_generation():
# 1D
assert ex.spectral.get_modes_slices(1, 10) == [
[slice(None), slice(None, (10 // 2) + 1)]
]
assert ex.spectral.get_modes_slices(1, 11) == [
[slice(None), slice(None, (11 // 2) + 1)]
]
assert ex.spectral.get_modes_slices(1, 10) == (
(
slice(None),
slice(None, (10 // 2) + 1),
),
)
assert ex.spectral.get_modes_slices(1, 11) == (
(
slice(None),
slice(None, (11 // 2) + 1),
),
)

# 2D
assert ex.spectral.get_modes_slices(2, 10) == [
[slice(None), slice(None, (10 // 2)), slice(None, (10 // 2) + 1)],
[slice(None), slice(-(10 // 2), None), slice(None, (10 // 2) + 1)],
]
assert ex.spectral.get_modes_slices(2, 11) == [
[slice(None), slice(None, (11 // 2) + 1), slice(None, (11 // 2) + 1)],
[slice(None), slice(-(11 // 2), None), slice(None, (11 // 2) + 1)],
]
assert ex.spectral.get_modes_slices(2, 10) == (
(
slice(None),
slice(None, (10 // 2)),
slice(None, (10 // 2) + 1),
),
(
slice(None),
slice(-(10 // 2), None),
slice(None, (10 // 2) + 1),
),
)
assert ex.spectral.get_modes_slices(2, 11) == (
(
slice(None),
slice(None, (11 // 2) + 1),
slice(None, (11 // 2) + 1),
),
(
slice(None),
slice(-(11 // 2), None),
slice(None, (11 // 2) + 1),
),
)

# 3D
assert ex.spectral.get_modes_slices(3, 10) == [
[
assert ex.spectral.get_modes_slices(3, 10) == (
(
slice(None),
slice(None, (10 // 2)),
slice(None, (10 // 2)),
slice(None, (10 // 2) + 1),
],
[
),
(
slice(None),
slice(-(10 // 2), None),
slice(None, (10 // 2)),
slice(None, (10 // 2) + 1),
],
[
),
(
slice(None),
slice(None, (10 // 2)),
slice(-(10 // 2), None),
slice(None, (10 // 2) + 1),
],
[
),
(
slice(None),
slice(-(10 // 2), None),
slice(-(10 // 2), None),
slice(None, (10 // 2) + 1),
],
]
assert ex.spectral.get_modes_slices(3, 11) == [
[
),
)
assert ex.spectral.get_modes_slices(3, 11) == (
(
slice(None),
slice(None, (11 // 2) + 1),
slice(None, (11 // 2) + 1),
slice(None, (11 // 2) + 1),
],
[
),
(
slice(None),
slice(-(11 // 2), None),
slice(None, (11 // 2) + 1),
slice(None, (11 // 2) + 1),
],
[
),
(
slice(None),
slice(None, (11 // 2) + 1),
slice(-(11 // 2), None),
slice(None, (11 // 2) + 1),
],
[
),
(
slice(None),
slice(-(11 // 2), None),
slice(-(11 // 2), None),
slice(None, (11 // 2) + 1),
],
]
),
)
10 changes: 6 additions & 4 deletions tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ def test_plot_state_2d():
plt.close(fig)


def test_plot_state_3d():
state = jax.random.normal(jax.random.PRNGKey(0), (1, 32, 32, 32))
# # Requires a GPU and therefore cannot easily be tested on GitHub Actions

fig = ex.viz.plot_state_3d(state)
plt.close(fig)
# def test_plot_state_3d():
# state = jax.random.normal(jax.random.PRNGKey(0), (1, 32, 32, 32))

# fig = ex.viz.plot_state_3d(state)
# plt.close(fig)

0 comments on commit f639e14

Please sign in to comment.