diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 4007355..510c002 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -117,10 +117,12 @@ def get_local_shape(mesh_shape, sharding): else: pdims = gpu_mesh.devices.shape return [ - mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], *mesh_shape[2:] + mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], + *mesh_shape[2:] ] -def zeros(mesh_shape , sharding): + +def zeros(mesh_shape, sharding): gpu_mesh = sharding.mesh if sharding is not None else None if not gpu_mesh is None and not (gpu_mesh.empty): local_mesh_shape = get_local_shape(mesh_shape, sharding) @@ -132,6 +134,7 @@ def zeros(mesh_shape , sharding): else: return jnp.zeros(mesh_shape) + def normal_field(mesh_shape, seed, sharding): """Generate a Gaussian random field with the given power spectrum.""" gpu_mesh = sharding.mesh if sharding is not None else None diff --git a/jaxpm/growth.py b/jaxpm/growth.py index 8194b06..80e6698 100644 --- a/jaxpm/growth.py +++ b/jaxpm/growth.py @@ -588,4 +588,5 @@ def dGf2a(cosmo, a): f2p = cache['h2'] / cache['a'] * cache['g2'] f2p = interp(np.log(a), np.log(cache['a']), f2p) E_a = E(cosmo, a) - return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E_a * D2f) + return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) + + 3 * a**2 * E_a * D2f) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index b41f261..f8059c7 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -5,7 +5,7 @@ from jax.sharding import PartitionSpec as P from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d, - normal_field,zeros) + normal_field, zeros) from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second, growth_rate, growth_rate_second) from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, @@ -29,8 +29,10 @@ def pm_forces(positions, mesh_shape = delta.shape if paint_particles: - paint_fn = lambda x: cic_paint( - zeros(mesh_shape,sharding), x , halo_size=halo_size, sharding=sharding) + paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding), + x, + halo_size=halo_size, + sharding=sharding) read_fn = lambda x: cic_read( x, positions, halo_size=halo_size, sharding=sharding) else: diff --git a/notebooks/03-MultiHost_PM.py b/notebooks/03-MultiHost_PM.py index 03c61f6..eb39ac8 100644 --- a/notebooks/03-MultiHost_PM.py +++ b/notebooks/03-MultiHost_PM.py @@ -33,7 +33,8 @@ mesh_shape = [512, 512, 512] box_size = [500., 500., 1000.] halo_size = 64 -snapshots = jnp.linspace(0.1,1.,2) +snapshots = jnp.linspace(0.1, 1., 2) + @jax.jit def run_simulation(omega_c, sigma8): @@ -89,9 +90,11 @@ def run_simulation(omega_c, sigma8): # Gather the results -pm_dict = {"initial_conditions": all_gather(initial_conditions), - "lpt_displacements": all_gather(lpt_displacements), - "solver_stats": solver_stats} +pm_dict = { + "initial_conditions": all_gather(initial_conditions), + "lpt_displacements": all_gather(lpt_displacements), + "solver_stats": solver_stats +} for i in range(len(ode_solutions)): sol = ode_solutions[i] diff --git a/notebooks/visualize.py b/notebooks/visualize.py index 136c297..e586c48 100644 --- a/notebooks/visualize.py +++ b/notebooks/visualize.py @@ -62,11 +62,9 @@ def plot_fields_single_projection(fields_dict, sum_over=None): slicing = tuple(slicing) # Sum projection over axis 0 and plot - axes[i].imshow( - field[slicing].sum(axis=0) + 1, - cmap='magma', - extent=[0, field.shape[1], 0, field.shape[2]] - ) + axes[i].imshow(field[slicing].sum(axis=0) + 1, + cmap='magma', + extent=[0, field.shape[1], 0, field.shape[2]]) axes[i].set_xlabel('Mpc/h') axes[i].set_ylabel('Mpc/h') axes[i].set_title(f"{name} projection 0")