Parallelize parameter weight computation using PyTorch Distributed (#22)
## Description
This PR introduces parallelization to the `create_parameter_weights.py`
script using PyTorch Distributed. The main changes include:

1. Added functions `get_rank()`, `get_world_size()`, `setup()`, and `cleanup()` to initialize and manage the distributed process group (a sketch follows below).
   - `get_rank()` retrieves the rank of the current process in the distributed group.
   - `get_world_size()` retrieves the total number of processes in the distributed group.
   - `setup()` initializes the distributed process group using the NCCL (GPU) or Gloo (CPU) backend.
   - `cleanup()` destroys the distributed process group.
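A minimal sketch of these helpers, assuming env-style initialization; the SLURM environment-variable fallbacks and defaults shown here are assumptions, not taken verbatim from the PR:

```python
import os

import torch
import torch.distributed as dist


def get_rank():
    # Rank of this process; falling back to the SLURM task ID is an assumption
    return int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", 0)))


def get_world_size():
    # Total number of processes in the job
    return int(os.environ.get("WORLD_SIZE", os.environ.get("SLURM_NTASKS", 1)))


def setup(rank, world_size):
    # NCCL backend when GPUs are available, Gloo otherwise;
    # env:// initialization expects MASTER_ADDR/MASTER_PORT in the environment
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()
```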

2. Modified the `main()` function to take `rank` and `world_size` as arguments and set up the distributed environment (a sketch follows below).
   - The device is set based on the rank and the available GPUs.
   - The dataset size is adjusted to be divisible by `(world_size * batch_size)` using the `adjust_dataset_size()` function.
   - A `DistributedSampler` partitions the dataset among the processes.
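Roughly, the distributed setup in `main()` follows the pattern below. It reuses the `setup()` helper sketched above; the synthetic dataset, its shape, the default batch size, and the body of `adjust_dataset_size()` are placeholders for illustration only:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def adjust_dataset_size(ds, multiple):
    # Drop trailing samples so the dataset length is divisible by `multiple`
    usable = (len(ds) // multiple) * multiple
    return Subset(ds, range(usable))


def main(rank, world_size, batch_size=4):
    setup(rank, world_size)
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    else:
        device = torch.device("cpu")

    # Synthetic stand-in for the real dataset: (samples, time steps, variables)
    ds = TensorDataset(torch.randn(1000, 8, 17))
    ds = adjust_dataset_size(ds, world_size * batch_size)

    # Each process iterates only over its own shard of the dataset
    sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=False)
    loader = DataLoader(ds, batch_size=batch_size, sampler=sampler)
```

The statistics computation that follows inside `main()` is sketched under items 3 and 4.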

3. Parallelized the computation of means and squared values across the dataset (a sketch follows below).
   - Each process computes the means and squared values for its assigned portion of the dataset.
   - The results are gathered from all processes using `dist.all_gather_object()`.
   - The root process (rank 0) computes the final mean, standard deviation, and flux statistics from the gathered results.
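Continuing from the previous sketch (and assuming `loader`, `device`, `rank`, and `world_size` from it; flux statistics omitted), the per-process computation and gather step could look roughly like this:

```python
means, squares = [], []
for (batch,) in loader:
    batch = batch.to(device)
    means.append(batch.mean(dim=(0, 1)).cpu())            # per-variable mean of this batch
    squares.append((batch ** 2).mean(dim=(0, 1)).cpu())   # per-variable mean of squared values

# Gather the per-process lists of batch statistics on every rank
gathered_means = [None] * world_size
gathered_squares = [None] * world_size
dist.all_gather_object(gathered_means, means)
dist.all_gather_object(gathered_squares, squares)

if rank == 0:
    all_means = torch.stack([m for sub in gathered_means for m in sub])
    all_squares = torch.stack([s for sub in gathered_squares for s in sub])
    mean = all_means.mean(dim=0)
    std = torch.sqrt(all_squares.mean(dim=0) - mean ** 2)  # Var[X] = E[X²] − (E[X])²
```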

4. Parallelized the computation of one-step difference means and squared values (a sketch follows below).
   - As in step 3, each process computes the difference means and squared values for its assigned portion of the dataset.
   - The results are gathered from all processes using `dist.all_gather_object()`.
   - The final difference mean and standard deviation are computed from the gathered results.
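The one-step difference statistics follow the same pattern; treating axis 1 of each batch as the time dimension is an assumption tied to the synthetic shapes above:

```python
diff_means, diff_squares = [], []
for (batch,) in loader:
    batch = batch.to(device)
    diff = batch[:, 1:] - batch[:, :-1]                    # one-step differences along time (assumed dim 1)
    diff_means.append(diff.mean(dim=(0, 1)).cpu())
    diff_squares.append((diff ** 2).mean(dim=(0, 1)).cpu())

gathered_diff_means = [None] * world_size
gathered_diff_squares = [None] * world_size
dist.all_gather_object(gathered_diff_means, diff_means)
dist.all_gather_object(gathered_diff_squares, diff_squares)

if rank == 0:
    dm = torch.stack([m for sub in gathered_diff_means for m in sub]).mean(dim=0)
    dsq = torch.stack([s for sub in gathered_diff_squares for s in sub]).mean(dim=0)
    diff_std = torch.sqrt(dsq - dm ** 2)
```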

These changes enable the script to leverage multiple processes/GPUs to
speed up the computation of parameter weights, means, and standard
deviations. The dataset is partitioned among the processes, and the
results are gathered and aggregated by the root process.

To run the script in a distributed manner, launch it with Slurm; a minimal job-script sketch is shown below.
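As one possible launch setup (all resource values are illustrative and not taken from this PR), a Slurm batch script could look roughly like this:

```bash
#!/bin/bash
#SBATCH --job-name=param_weights   # illustrative job name
#SBATCH --nodes=1                  # illustrative resource request
#SBATCH --ntasks-per-node=4        # one task per GPU
#SBATCH --gres=gpu:4
#SBATCH --time=01:00:00

# srun starts one process per task; each process derives its rank and world
# size from the SLURM environment variables read by get_rank()/get_world_size()
srun python create_parameter_weights.py
```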

Please review the changes and provide any feedback or suggestions.

---------

Co-authored-by: Simon Adamov <[email protected]>
sadamov and Simon Adamov authored Jun 3, 2024
1 parent e5400bb commit 743c07a
Showing 2 changed files with 302 additions and 74 deletions.
`.gitignore` — 2 additions, 0 deletions
```diff
@@ -2,12 +2,14 @@
 wandb
 slurm_log*
 saved_models
+lightning_logs
 data
 graphs
 *.sif
 sweeps
 test_*.sh
 .vscode
+*slurm*
 
 ### Python ###
 # Byte-compiled / optimized / DLL files
```
