-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Former-commit-id: da6fdea
- Loading branch information
Showing
7 changed files
with
2,418 additions
and
14 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,320 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# **Animating Particle Mesh density fields**\n", | ||
"\n", | ||
"In this tutorial, we will animate the density field of a particle mesh simulation. We will use the `manim` library to create the animation. \n", | ||
"\n", | ||
"The density fields are created exactly like in the notebook [**05-MultiHost_PM.ipynb**](05-MultiHost_PM.ipynb) using the same script [**05-MultiHost_PM.py**](05-MultiHost_PM.py)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To run a multi-host simulation, you first need to **allocate a job** with `salloc`. This command requests resources on an HPC cluster.\n", | ||
"\n", | ||
"just like in notebook [**05-MultiHost_PM.ipynb**]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 & " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"**A few hours later**\n", | ||
"\n", | ||
"Use `!squeue -u $USER -o \"%i %D %b\"` to **check the JOB ID** and verify your resource allocation.\n", | ||
"\n", | ||
"In this example, we’ve been allocated **32 GPUs split across 4 nodes**.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!squeue -u $USER -o \"%i %D %b\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"del os.environ['VSCODE_PROXY_URI']\n", | ||
"del os.environ['NO_PROXY']\n", | ||
"del os.environ['no_proxy']" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Checking Available Compute Resources\n", | ||
"\n", | ||
"Run the following command to initialize JAX distributed computing and display the devices available for this job:\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!srun --jobid=467745 -n 32 python -c \"import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Multi-Host Simulation Script with Arguments (reminder)\n", | ||
"\n", | ||
"This script is nearly identical to the single-host version, with the main addition being the call to `jax.distributed.initialize()` at the start, enabling multi-host parallelism. Here’s a breakdown of the key arguments:\n", | ||
"\n", | ||
"- **`--pdims`** (`-p`): Specifies processor grid dimensions as two integers, like `16 2` for 16 x 2 device mesh (default is `[1, jax.devices()]`).\n", | ||
"- **`--mesh_shape`** (`-m`): Defines the simulation mesh shape as three integers (default is `[512, 512, 512]`).\n", | ||
"- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n", | ||
"- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n", | ||
"- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n", | ||
"- **`--snapthots`** (`-st`) : Number of snapshots to save (warning, increases memory usage)\n", | ||
"\n", | ||
"### Running the Multi-Host Simulation Script\n", | ||
"\n", | ||
"To create a smooth animation, we need a series of closely spaced snapshots to capture the evolution of the density field over time. In this example, we set the number of snapshots to **10** to ensure smooth transitions in the animation.\n", | ||
"\n", | ||
"Using a larger number of GPUs helps process these snapshots efficiently, especially with a large simulation mesh or high-resolution data. This allows us to achieve both the desired snapshot frequency and the necessary simulation detail without excessive runtime.\n", | ||
"\n", | ||
"The command to run the multi-host simulation with these settings will look something like this:\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import subprocess\n", | ||
"\n", | ||
"# Define parameters as variables\n", | ||
"jobid = \"467745\"\n", | ||
"num_processes = 32\n", | ||
"script_name = \"05-MultiHost_PM.py\"\n", | ||
"mesh_shape = (1024, 1024, 1024)\n", | ||
"box_size = (1000., 1000., 1000.)\n", | ||
"halo_size = 128\n", | ||
"solver = \"leapfrog\"\n", | ||
"pdims = (16, 2)\n", | ||
"snapshots = 8\n", | ||
"\n", | ||
"# Build the command as a list, incorporating variables\n", | ||
"command = [\n", | ||
" \"srun\",\n", | ||
" f\"--jobid={jobid}\",\n", | ||
" \"-n\", str(num_processes),\n", | ||
" \"python\", script_name,\n", | ||
" \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n", | ||
" \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n", | ||
" \"--halo_size\", str(halo_size),\n", | ||
" \"-s\", solver,\n", | ||
" \"--pdims\", str(pdims[0]), str(pdims[1]),\n", | ||
" \"--snapshots\", str(snapshots)\n", | ||
"]\n", | ||
"\n", | ||
"# Execute the command as a subprocess\n", | ||
"subprocess.run(command)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Projecting the 3D Density Fields to 2D\n", | ||
"\n", | ||
"To visualize the 3D density fields in 2D, we need to create a projection:\n", | ||
"\n", | ||
"- **`project_to_2d` Function**: This function reduces the 3D array to 2D by summing over a portion of one axis.\n", | ||
" - We sum the top one-eighth of the data along the first axis to capture a slice of the density field.\n", | ||
"\n", | ||
"- **Creating 2D Projections**: Apply `project_to_2d` to each 3D field (`initial_conditions`, `lpt_displacements`, `ode_solution_0`, and `ode_solution_1`) to get 2D arrays that represent the density fields.\n", | ||
"\n", | ||
"### Applying the Magma Colormap\n", | ||
"\n", | ||
"To improve visualization, apply the \"magma\" colormap to each 2D projection:\n", | ||
"\n", | ||
"- **`apply_colormap` Function**: This function maps values in the 2D array to colors using the \"magma\" colormap.\n", | ||
" - First, normalize the array to the `[0, 1]` range.\n", | ||
" - Apply the colormap to create RGB images, which will be used for the animation.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from matplotlib import colormaps\n", | ||
"\n", | ||
"# Define a function to project the 3D field to 2D\n", | ||
"def project_to_2d(field):\n", | ||
" sum_over = field.shape[0] // 8\n", | ||
" slicing = [slice(None)] * field.ndim\n", | ||
" slicing[0] = slice(None, sum_over)\n", | ||
" slicing = tuple(slicing)\n", | ||
"\n", | ||
" return field[slicing].sum(axis=0)\n", | ||
"\n", | ||
"\n", | ||
"def apply_colormap(array, cmap_name=\"magma\"):\n", | ||
" cmap = colormaps[cmap_name]\n", | ||
" normalized_array = (array - array.min()) / (array.max() - array.min())\n", | ||
" colored_image = cmap(normalized_array)[:, :, :3] # Drop alpha channel for RGB\n", | ||
" return (colored_image * 255).astype(np.uint8)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Loading and Visualizing Results\n", | ||
"\n", | ||
"After running the multi-host simulation, we load the saved results from disk:\n", | ||
"\n", | ||
"- **`initial_conditions.npy`**: Initial conditions for the simulation.\n", | ||
"- **`lpt_displacements.npy`**: Linear perturbation displacements.\n", | ||
"- **`ode_solution_*.npy`** : Solutions from the ODE solver at each snapshot.\n", | ||
"\n", | ||
"We will now project the fields to 2D maps and apply the color map\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"\n", | ||
"initial_conditions = apply_colormap(project_to_2d(np.load('fields/initial_conditions.npy')))\n", | ||
"lpt_displacements = apply_colormap(project_to_2d(np.load('fields/lpt_displacements.npy')))\n", | ||
"ode_solutions = []\n", | ||
"for i in range(8):\n", | ||
" ode_solutions.append(apply_colormap(project_to_2d(np.load(f'fields/ode_solution_{i}.npy'))))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Animating with Manim\n", | ||
"\n", | ||
"To create animations with `manim` in a Jupyter notebook, we start by configuring some settings to ensure the output displays correctly and without a background.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from manim import *\n", | ||
"config.media_width = \"100%\"\n", | ||
"config.verbosity = \"WARNING\"\n", | ||
"config.background_color = \"#00000000\" # Transparent background" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Defining the Animation in Manim\n", | ||
"\n", | ||
"This animation class, `FieldTransition`, smoothly transitions through the stages of the particle mesh density field evolution.\n", | ||
"\n", | ||
"- **Setup**: Each density field snapshot is loaded as an image and aligned for smooth transitions.\n", | ||
"- **Animation Sequence**:\n", | ||
" - The animation begins with a fade-in of the initial conditions.\n", | ||
" - It then transitions through the stages in sequence, showing each snapshot of the density field evolution with brief pauses in between.\n", | ||
"\n", | ||
"To run the animation, execute `%manim -v WARNING -qm FieldTransition` to render it in the Jupyter Notebook.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Define the animation in Manim\n", | ||
"class FieldTransition(Scene):\n", | ||
" def construct(self):\n", | ||
" init_conditions_img = ImageMobject(initial_conditions).scale(4)\n", | ||
" lpt_img = ImageMobject(lpt_displacements).scale(4)\n", | ||
" snapshots_imgs = [ImageMobject(sol).scale(4) for sol in ode_solutions]\n", | ||
"\n", | ||
"\n", | ||
" # Place the images on top of each other initially\n", | ||
" lpt_img.move_to(init_conditions_img)\n", | ||
" for img in snapshots_imgs:\n", | ||
" img.move_to(init_conditions_img)\n", | ||
"\n", | ||
" # Show initial field and then transform between fields\n", | ||
" self.play(FadeIn(init_conditions_img))\n", | ||
" self.wait(0.2)\n", | ||
" self.play(Transform(init_conditions_img, lpt_img))\n", | ||
" self.wait(0.2)\n", | ||
" self.play(Transform(lpt_img, snapshots_imgs[0]))\n", | ||
" self.wait(0.2)\n", | ||
" for img1, img2 in zip(snapshots_imgs, snapshots_imgs[1:]):\n", | ||
" self.play(Transform(img1, img2))\n", | ||
" self.wait(0.2)\n", | ||
"\n", | ||
"%manim -v WARNING -qm -o anim.gif --format=gif FieldTransition " | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |