Parallelize parameter weight computation using PyTorch Distributed (#22)
## Description

This PR introduces parallelization to the `create_parameter_weights.py` script using PyTorch Distributed. The main changes include:

1. Added functions `get_rank()`, `get_world_size()`, `setup()`, and `cleanup()` to initialize and manage the distributed process group (see the first sketch below).
   - `get_rank()` retrieves the rank of the current process in the distributed group.
   - `get_world_size()` retrieves the total number of processes in the distributed group.
   - `setup()` initializes the distributed process group using the NCCL (GPU) or gloo (CPU) backend.
   - `cleanup()` destroys the distributed process group.
2. Modified the `main()` function to take `rank` and `world_size` as arguments and set up the distributed environment (see the second sketch below).
   - The device is set based on the rank and the available GPUs.
   - The dataset is adjusted with the `adjust_dataset_size()` function so that its size is divisible by `(world_size * batch_size)`.
   - A `DistributedSampler` is used to partition the dataset among the processes.
3. Parallelized the computation of means and squared values across the dataset (see the third sketch below).
   - Each process computes the means and squared values for its assigned portion of the dataset.
   - The results are gathered from all processes using `dist.all_gather_object()`.
   - The root process (rank 0) computes the final mean, standard deviation, and flux statistics from the gathered results.
4. Parallelized the computation of one-step difference means and squared values.
   - As in step 3, each process computes the difference means and squared values for its assigned portion of the dataset.
   - The results are gathered from all processes using `dist.all_gather_object()`.
   - The final difference mean and standard deviation are computed from the gathered results.

These changes enable the script to use multiple processes/GPUs to speed up the computation of parameter weights, means, and standard deviations. The dataset is partitioned among the processes, and the results are gathered and aggregated by the root process. To run the script in a distributed manner, it can be launched via Slurm.
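Below is a minimal sketch, assuming a standard PyTorch Distributed setup, of what the helpers in item 1 could look like. The `MASTER_ADDR`/`MASTER_PORT` defaults are illustrative placeholders, not values from the actual script.

```python
import os

import torch
import torch.distributed as dist


def get_rank():
    """Rank of the current process in the distributed group (0 if not initialized)."""
    return dist.get_rank() if dist.is_initialized() else 0


def get_world_size():
    """Total number of processes in the distributed group (1 if not initialized)."""
    return dist.get_world_size() if dist.is_initialized() else 1


def setup(rank, world_size):
    """Initialize the process group: NCCL when GPUs are available, gloo otherwise."""
    # Placeholder rendezvous settings for illustration only.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)


def cleanup():
    """Destroy the distributed process group once all work is done."""
    dist.destroy_process_group()
```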
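A sketch of the `main(rank, world_size)` flow from item 2, reusing `setup()`/`cleanup()` from the previous sketch. The stand-in `TensorDataset`, the `batch_size` default, and the trimming strategy inside `adjust_dataset_size()` are assumptions made for illustration; the real script builds the weather dataset itself.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def adjust_dataset_size(ds, world_size, batch_size):
    """Trim the dataset so its length is divisible by world_size * batch_size.

    Trimming (rather than padding) is an assumption for this sketch.
    """
    chunk = world_size * batch_size
    usable = (len(ds) // chunk) * chunk
    return Subset(ds, range(usable))


def main(rank, world_size, batch_size=4):
    setup(rank, world_size)  # helpers from the previous sketch

    # One GPU per rank when CUDA is available, otherwise fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
        torch.cuda.set_device(device)  # required for NCCL collectives on this rank
    else:
        device = torch.device("cpu")

    # Stand-in dataset for illustration; the real script loads the weather data.
    ds = TensorDataset(torch.randn(1000, 17))
    ds = adjust_dataset_size(ds, world_size, batch_size)

    # Each rank iterates over a disjoint shard of the (now evenly divisible) dataset.
    sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=False)
    loader = DataLoader(ds, batch_size=batch_size, sampler=sampler)

    # Per-rank statistics and the rank-0 reduction are sketched below.
    local_means, local_squares = compute_local_stats(loader, device)
    mean, std = gather_and_reduce(local_means, local_squares, world_size, rank)
    if rank == 0:
        print("mean:", mean, "std:", std)

    cleanup()
```

Under Slurm, `rank` and `world_size` would typically come from environment variables such as `SLURM_PROCID` and `SLURM_NTASKS`; for a local test the same entry point can be spawned with `torch.multiprocessing.spawn(main, args=(world_size,), nprocs=world_size)`.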
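A sketch of the gather-and-reduce pattern used in items 3 and 4: each rank accumulates partial means and mean-squares over its shard, `dist.all_gather_object()` collects the per-rank lists, and rank 0 combines them into the final statistics. The helper names (`compute_local_stats`, `gather_and_reduce`) and the exact reduction are illustrative assumptions, not the script's actual functions.

```python
import torch
import torch.distributed as dist


def compute_local_stats(loader, device):
    """Accumulate per-batch means of the values and of their squares on this rank."""
    means, squares = [], []
    for (batch,) in loader:
        batch = batch.to(device)
        means.append(torch.mean(batch, dim=0))         # mean over the batch dimension
        squares.append(torch.mean(batch ** 2, dim=0))  # mean of the squared values
    return means, squares


def gather_and_reduce(local_means, local_squares, world_size, rank):
    """Gather every rank's partial results and let rank 0 compute mean and std."""
    gathered_means = [None] * world_size
    gathered_squares = [None] * world_size
    # Move tensors to CPU before pickling so this works with both gloo and NCCL.
    dist.all_gather_object(gathered_means, [m.cpu() for m in local_means])
    dist.all_gather_object(gathered_squares, [s.cpu() for s in local_squares])

    if rank == 0:
        all_means = torch.stack([m for part in gathered_means for m in part])
        all_squares = torch.stack([s for part in gathered_squares for s in part])
        mean = torch.mean(all_means, dim=0)
        # Var(X) = E[X^2] - (E[X])^2, so std follows from the two aggregates.
        std = torch.sqrt(torch.mean(all_squares, dim=0) - mean ** 2)
        return mean, std
    return None, None
```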
Please review the changes and provide any feedback or suggestions.

---------

Co-authored-by: Simon Adamov <[email protected]>