From bbacd45dcfa5d3921abb295e51a1cda592d92caf Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Sat, 21 Dec 2024 23:27:05 +0100 Subject: [PATCH] format --- tests/test_gradients.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_gradients.py b/tests/test_gradients.py index b35d656..bb48920 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -1,6 +1,7 @@ import jax import pytest -from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve , RecursiveCheckpointAdjoint, BacksolveAdjoint +from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController, + RecursiveCheckpointAdjoint, SaveAt, diffeqsolve) from helpers import MSE from jax import numpy as jnp @@ -15,15 +16,16 @@ @pytest.mark.parametrize("adjoint", ['DTO', 'OTD']) def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, cosmo, order, - absolute_painting , adjoint): + absolute_painting, adjoint): mesh_shape, _ = simulation_config cosmo._workspace = {} if adjoint == 'OTD': - pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)") + pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)") - adjoint = RecursiveCheckpointAdjoint() if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5()) + adjoint = RecursiveCheckpointAdjoint( + ) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5()) @jax.jit @jax.grad