Skip to content

Commit

Permalink
remove duplicate get_ode_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Oct 22, 2024
1 parent 82f2987 commit 31ca41b
Showing 1 changed file with 0 additions and 22 deletions.
22 changes: 0 additions & 22 deletions jaxpm/pm.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,28 +158,6 @@ def nbody_ode(state, a, cosmo):
return nbody_ode


def get_ode_fn(cosmo, mesh_shape):

def nbody_ode(a, state, args):
"""
State is an array [position, velocities]
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
"""
pos, vel = state
forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m

# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel

# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces

return jnp.stack([dpos, dvel])

return nbody_ode


def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):

def nbody_ode(a, state, args):
Expand Down

0 comments on commit 31ca41b

Please sign in to comment.