From cb2a7ab17fd4141b685780d75d3b70986dd51417 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Sun, 22 Dec 2024 18:47:42 +0100 Subject: [PATCH] Fix pfft gradients (#34) * update jaxdecomp version and test gradients * Prepare for DTO tests * format --- pyproject.toml | 2 +- tests/test_gradients.py | 87 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 tests/test_gradients.py diff --git a/pyproject.toml b/pyproject.toml index a41096d..a204633 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ readme = "README.md" requires-python = ">=3.9" license = { file = "LICENSE" } urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" } -dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"] +dependencies = ["jax_cosmo", "jax>=0.4.35", "jaxdecomp>=0.2.3"] [tool.setuptools] packages = ["jaxpm"] diff --git a/tests/test_gradients.py b/tests/test_gradients.py new file mode 100644 index 0000000..bb48920 --- /dev/null +++ b/tests/test_gradients.py @@ -0,0 +1,87 @@ +import jax +import pytest +from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController, + RecursiveCheckpointAdjoint, SaveAt, diffeqsolve) +from helpers import MSE +from jax import numpy as jnp + +from jaxpm.distributed import uniform_particles +from jaxpm.painting import cic_paint, cic_paint_dx +from jaxpm.pm import lpt, make_diffrax_ode + + +@pytest.mark.single_device +@pytest.mark.parametrize("order", [1, 2]) +@pytest.mark.parametrize("absolute_painting", [True, False]) +@pytest.mark.parametrize("adjoint", ['DTO', 'OTD']) +def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, + nbody_from_lpt1, nbody_from_lpt2, cosmo, order, + absolute_painting, adjoint): + + mesh_shape, _ = simulation_config + cosmo._workspace = {} + + if adjoint == 'OTD': + pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)") + + adjoint = RecursiveCheckpointAdjoint( + ) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5()) + + @jax.jit + @jax.grad + def forward_model(initial_conditions, cosmo): + + # Initial displacement + if absolute_painting: + particles = uniform_particles(mesh_shape) + dx, p, _ = lpt(cosmo, + initial_conditions, + particles, + a=lpt_scale_factor, + order=order) + ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + y0 = jnp.stack([particles + dx, p]) + + else: + dx, p, _ = lpt(cosmo, + initial_conditions, + a=lpt_scale_factor, + order=order) + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + y0 = jnp.stack([dx, p]) + + solver = Dopri5() + controller = PIDController(rtol=1e-7, + atol=1e-7, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + solutions = diffeqsolve(ode_fn, + solver, + t0=lpt_scale_factor, + t1=1.0, + dt0=None, + y0=y0, + adjoint=adjoint, + stepsize_controller=controller, + saveat=saveat) + + if absolute_painting: + final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0]) + else: + final_field = cic_paint_dx(solutions.ys[-1, 0]) + + return MSE(final_field, + nbody_from_lpt1 if order == 1 else nbody_from_lpt2) + + bad_initial_conditions = initial_conditions + jax.random.normal( + jax.random.PRNGKey(0), initial_conditions.shape) * 0.5 + best_ic = forward_model(initial_conditions, cosmo) + bad_ic = forward_model(bad_initial_conditions, cosmo) + + assert jnp.max(best_ic) < 1e-5 + assert jnp.max(bad_ic) > 1e-5