Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Dec 21, 2024
1 parent a924458 commit bbacd45
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tests/test_gradients.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import jax
import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve , RecursiveCheckpointAdjoint, BacksolveAdjoint
from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController,
RecursiveCheckpointAdjoint, SaveAt, diffeqsolve)
from helpers import MSE
from jax import numpy as jnp

Expand All @@ -15,15 +16,16 @@
@pytest.mark.parametrize("adjoint", ['DTO', 'OTD'])
def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor,
nbody_from_lpt1, nbody_from_lpt2, cosmo, order,
absolute_painting , adjoint):
absolute_painting, adjoint):

mesh_shape, _ = simulation_config
cosmo._workspace = {}

if adjoint == 'OTD':
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")
pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)")

adjoint = RecursiveCheckpointAdjoint() if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())
adjoint = RecursiveCheckpointAdjoint(
) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5())

@jax.jit
@jax.grad
Expand Down

0 comments on commit bbacd45

Please sign in to comment.