Skip to content

Commit

Permalink
apply formating
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Oct 27, 2024
1 parent c93894f commit 19011d0
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 15 deletions.
7 changes: 5 additions & 2 deletions jaxpm/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,12 @@ def get_local_shape(mesh_shape, sharding):
else:
pdims = gpu_mesh.devices.shape
return [
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], *mesh_shape[2:]
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1],
*mesh_shape[2:]
]

def zeros(mesh_shape , sharding):

def zeros(mesh_shape, sharding):
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
Expand All @@ -132,6 +134,7 @@ def zeros(mesh_shape , sharding):
else:
return jnp.zeros(mesh_shape)


def normal_field(mesh_shape, seed, sharding):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
Expand Down
3 changes: 2 additions & 1 deletion jaxpm/growth.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,4 +588,5 @@ def dGf2a(cosmo, a):
f2p = cache['h2'] / cache['a'] * cache['g2']
f2p = interp(np.log(a), np.log(cache['a']), f2p)
E_a = E(cosmo, a)
return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E_a * D2f)
return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
3 * a**2 * E_a * D2f)
8 changes: 5 additions & 3 deletions jaxpm/pm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jax.sharding import PartitionSpec as P

from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d,
normal_field,zeros)
normal_field, zeros)
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
growth_rate, growth_rate_second)
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
Expand All @@ -29,8 +29,10 @@ def pm_forces(positions,
mesh_shape = delta.shape

if paint_particles:
paint_fn = lambda x: cic_paint(
zeros(mesh_shape,sharding), x , halo_size=halo_size, sharding=sharding)
paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding),
x,
halo_size=halo_size,
sharding=sharding)
read_fn = lambda x: cic_read(
x, positions, halo_size=halo_size, sharding=sharding)
else:
Expand Down
11 changes: 7 additions & 4 deletions notebooks/03-MultiHost_PM.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
mesh_shape = [512, 512, 512]
box_size = [500., 500., 1000.]
halo_size = 64
snapshots = jnp.linspace(0.1,1.,2)
snapshots = jnp.linspace(0.1, 1., 2)


@jax.jit
def run_simulation(omega_c, sigma8):
Expand Down Expand Up @@ -89,9 +90,11 @@ def run_simulation(omega_c, sigma8):

# Gather the results

pm_dict = {"initial_conditions": all_gather(initial_conditions),
"lpt_displacements": all_gather(lpt_displacements),
"solver_stats": solver_stats}
pm_dict = {
"initial_conditions": all_gather(initial_conditions),
"lpt_displacements": all_gather(lpt_displacements),
"solver_stats": solver_stats
}

for i in range(len(ode_solutions)):
sol = ode_solutions[i]
Expand Down
8 changes: 3 additions & 5 deletions notebooks/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,9 @@ def plot_fields_single_projection(fields_dict, sum_over=None):
slicing = tuple(slicing)

# Sum projection over axis 0 and plot
axes[i].imshow(
field[slicing].sum(axis=0) + 1,
cmap='magma',
extent=[0, field.shape[1], 0, field.shape[2]]
)
axes[i].imshow(field[slicing].sum(axis=0) + 1,
cmap='magma',
extent=[0, field.shape[1], 0, field.shape[2]])
axes[i].set_xlabel('Mpc/h')
axes[i].set_ylabel('Mpc/h')
axes[i].set_title(f"{name} projection 0")
Expand Down

0 comments on commit 19011d0

Please sign in to comment.