Skip to content

Commit

Permalink
update notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Oct 27, 2024
1 parent b4fdb74 commit c93894f
Show file tree
Hide file tree
Showing 5 changed files with 336 additions and 169 deletions.
247 changes: 180 additions & 67 deletions notebooks/01-Introduction.ipynb

Large diffs are not rendered by default.

102 changes: 45 additions & 57 deletions notebooks/02-MultiGPU_PM.ipynb

Large diffs are not rendered by default.

94 changes: 64 additions & 30 deletions notebooks/03-MultiHost_PM.ipynb

Large diffs are not rendered by default.

29 changes: 14 additions & 15 deletions notebooks/03-MultiHost_PM.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

mesh_shape = [2024, 1024, 1024]
box_size = [1024., 1024., 1024.]
halo_size = 512
snapshots = jnp.linspace(0.1, 1., 2)

mesh_shape = [512, 512, 512]
box_size = [500., 500., 1000.]
halo_size = 64
snapshots = jnp.linspace(0.1,1.,2)

@jax.jit
def run_simulation(omega_c, sigma8):
Expand All @@ -59,8 +58,7 @@ def run_simulation(omega_c, sigma8):
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
0.1,
a=0.1,
halo_size=halo_size,
sharding=sharding)

Expand Down Expand Up @@ -90,15 +88,16 @@ def run_simulation(omega_c, sigma8):
print(f"[{rank}] Solver stats: {solver_stats}")

# Gather the results
initial_conditions = all_gather(initial_conditions)
lpt_displacements = all_gather(lpt_displacements)
ode_solutions = [all_gather(sol) for sol in ode_solutions]

pm_dict = {"initial_conditions": all_gather(initial_conditions),
"lpt_displacements": all_gather(lpt_displacements),
"solver_stats": solver_stats}

for i in range(len(ode_solutions)):
sol = ode_solutions[i]
pm_dict[f"ode_solution_{i}"] = all_gather(sol)

if rank == 0:
np.savez("multihost_pm.npz",
initial_conditions=initial_conditions,
lpt_displacements=lpt_displacements,
ode_solutions=ode_solutions,
solver_stats=solver_stats)
np.savez("multihost_pm.npz", **pm_dict)

print(f"[{rank}] Simulation results saved")
33 changes: 33 additions & 0 deletions notebooks/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,36 @@ def plot_subplots(proj_axis, field, row, title):

plt.tight_layout()
plt.show()


def plot_fields_single_projection(fields_dict, sum_over=None):
"""
Plots a single projection (along axis 0) of 3D fields in one row,
summing over the first `sum_over` elements along the 0-axis.
Args:
- fields_dict: dictionary where keys are field names and values are 3D arrays
- sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
"""
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
nb_cols = len(fields_dict)
fig, axes = plt.subplots(1, nb_cols, figsize=(5 * nb_cols, 5))

for i, (name, field) in enumerate(fields_dict.items()):
# Define the slice for the 0-axis projection
slicing = [slice(None)] * field.ndim
slicing[0] = slice(None, sum_over)
slicing = tuple(slicing)

# Sum projection over axis 0 and plot
axes[i].imshow(
field[slicing].sum(axis=0) + 1,
cmap='magma',
extent=[0, field.shape[1], 0, field.shape[2]]
)
axes[i].set_xlabel('Mpc/h')
axes[i].set_ylabel('Mpc/h')
axes[i].set_title(f"{name} projection 0")

plt.tight_layout()
plt.show()

0 comments on commit c93894f

Please sign in to comment.