Skip to content

Commit

Permalink
By default use absoulute painting with
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Dec 5, 2024
1 parent c1b276d commit e0c118a
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions jaxpm/pm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def pm_forces(positions,
mesh_shape=None,
delta=None,
r_split=0,
paint_particles=False,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
"""
Expand All @@ -28,7 +28,7 @@ def pm_forces(positions,
"If mesh_shape is not provided, delta should be provided"
mesh_shape = delta.shape

if paint_particles:
if paint_absolute_pos:
paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding),
x,
halo_size=halo_size,
Expand Down Expand Up @@ -72,14 +72,14 @@ def lpt(cosmo,
Computes first and second order LPT displacement and momentum,
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
"""
paint_particles = particles is not None
paint_absolute_pos = particles is not None

a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
delta_k = fft3d(initial_conditions)
initial_force = pm_forces(particles,
delta=delta_k,
paint_particles=paint_particles,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
dx = growth_factor(cosmo, a) * initial_force
Expand Down Expand Up @@ -111,7 +111,7 @@ def lpt(cosmo,
delta_k2 = fft3d(delta2)
init_force2 = pm_forces(particles,
delta=delta_k2,
paint_particles=paint_particles,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
Expand Down Expand Up @@ -144,18 +144,20 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
return field


def make_ode_fn(mesh_shape, particles=None, halo_size=0, sharding=None):
def make_ode_fn(mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):

def nbody_ode(state, a, cosmo):
"""
state is a tuple (position, velocities)
"""
pos, vel = state
paint_particles = particles is not None

forces = pm_forces(pos,
mesh_shape=mesh_shape,
paint_particles=paint_particles,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding) * 1.5 * cosmo.Omega_m

Expand All @@ -165,6 +167,8 @@ def nbody_ode(state, a, cosmo):
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces

#dpos = dpos if not paint_absolute_pos else dpos + pos

return dpos, dvel

return nbody_ode
Expand Down

0 comments on commit e0c118a

Please sign in to comment.