Commit
* fixes NCCL_P2P_LEVEL=NVL #429 * adding more insights into various values of NCCL_P2P_LEVEL
1 parent e30f1e3, commit 5e2d8a4
Showing 2 changed files with 50 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# NCCL | ||
|
||
NVIDIA NCCL is a library that facilitates and optimizes multi-GPU communication operations such as broadcast, all-gather, reduce, and all-reduce. NCCL behavior is highly environment-specific and is controlled through several [environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html). A common NCCL-related problem occurs when a long-running operation times out, causing the training process to abort: | ||
|
||
```text | ||
Watchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out. | ||
``` | ||
|
||
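The `Timeout(ms)` field in this message is the configured limit in milliseconds. As a minimal sketch (the helper name `watchdog_timeout_minutes` is hypothetical and not part of PyTorch or NCCL), the value can be pulled out of a log line and converted to minutes to confirm which timeout was in effect:

```shell
# Hypothetical helper: print the configured watchdog timeout, in whole minutes,
# from a single log line. Returns non-zero if the line has no Timeout(ms) field.
watchdog_timeout_minutes() {
  ms=$(printf '%s\n' "$1" | sed -n 's/.*Timeout(ms)=\([0-9][0-9]*\).*/\1/p')
  [ -n "$ms" ] || return 1
  echo $(( ms / 60000 ))
}

line='Watchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out.'
watchdog_timeout_minutes "$line"   # prints 30
```
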
Often, this timeout occurs after 30 minutes (the default setting). Before the error is raised, it is typically accompanied by below-average power consumption alongside near-100% GPU utilization. NVIDIA recommends [disabling PCI access control services (ACS)](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#pci-access-control-services-acs) as a possible solution, if that option is available to you. | ||
|
||
Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink, run the following command: | ||
|
||
```shell | ||
nvidia-smi nvlink --status | ||
``` | ||
|
||
To force NCCL to use NVLink, set the following in the environment: | ||
|
||
```shell | ||
export NCCL_P2P_LEVEL=NVL | ||
``` | ||
|
||
If NVLink is not available in your environment, the table below lists other options for ``NCCL_P2P_LEVEL``: | ||
|
||
| NCCL_P2P_LEVEL | Description | | ||
| -------------- | ----------- | | ||
| PIX | P2P data transfers through no more than a single PCIe bridge. Faster data transfer rates than paths involving multiple bridges, but slower than direct GPU-to-GPU communication. | | ||
| PXB | P2P data transfers through multiple PCIe bridges, but not through the PCIe Host Bridge; this path involves more complex routing and may incur a moderate level of latency. | | ||
| PHB | P2P data transfers over PCIe through a PCIe Host Bridge, typically involving the CPU; this can facilitate direct memory access but may introduce additional latency compared to more direct paths (e.g., PIX, NVL). | | ||
|
||
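Because a mistyped level is easy to miss, a small guard can reject unknown values before the job starts. This is a sketch only: the function name `set_p2p_level` is hypothetical, and the accepted list assumes the string values documented for `NCCL_P2P_LEVEL` in the NCCL environment variable reference (`LOC`, `NVL`, `PIX`, `PXB`, `PHB`, `SYS`).

```shell
# Hypothetical helper: export NCCL_P2P_LEVEL only if the value is a known level.
# Assumption: LOC, NVL, PIX, PXB, PHB and SYS are the accepted string values.
set_p2p_level() {
  case "$1" in
    LOC|NVL|PIX|PXB|PHB|SYS)
      export NCCL_P2P_LEVEL="$1"
      ;;
    *)
      echo "unknown NCCL_P2P_LEVEL: $1" >&2
      return 1
      ;;
  esac
}

set_p2p_level PIX
```

Running `nvidia-smi topo -m` prints a GPU topology matrix that uses the same labels (PIX, PXB, PHB, and so on), which can help decide which level matches your hardware.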
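To check that data transfer speeds are acceptable for your training job, run [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) to help pinpoint bottlenecks. For example, the following all-reduce benchmark sweeps message sizes from 8 bytes to 128 MB, doubling the size at each step, across 3 GPUs: | ||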
|
||
```shell | ||
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3 | ||
``` | ||
|
||
When debugging NCCL communication timeouts, it can be useful to enable additional logging in both PyTorch and NCCL: | ||
|
||
```shell | ||
export NCCL_DEBUG=INFO | ||
export NCCL_DEBUG_SUBSYS=ALL | ||
export TORCH_DISTRIBUTED_DEBUG=INFO | ||
export TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log | ||
``` | ||
|
||
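These settings are verbose, so it may be preferable to scope them to a single run rather than exporting them for the whole shell session. A minimal sketch, assuming a hypothetical wrapper named `run_with_nccl_debug` that takes the error-file path as its first argument and the command to run as the rest:

```shell
# Hypothetical wrapper: run one command with the debug variables set,
# leaving the calling shell's environment untouched.
run_with_nccl_debug() {
  error_file="$1"
  shift
  env NCCL_DEBUG=INFO \
      NCCL_DEBUG_SUBSYS=ALL \
      TORCH_DISTRIBUTED_DEBUG=INFO \
      TORCHELASTIC_ERROR_FILE="$error_file" \
      "$@"
}
```

For instance, `run_with_nccl_debug /PATH/TO/torcherror.log <your launch command>` would apply the variables only to that launch.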
Finally, if you believe your training job needs more time, you can increase the timeout past 30 minutes by setting the ``ddp_timeout`` value in the Axolotl configuration. See [PyTorch init_process_group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for documentation on this value. |
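
As a rough illustration of choosing a value, the sketch below converts a desired limit in minutes into a configuration line. It assumes that ``ddp_timeout`` is expressed in seconds, so that the 30-minute default corresponds to 1800; the helper name `ddp_timeout_for_minutes` is hypothetical, and the unit should be confirmed against your Axolotl version.

```shell
# Hypothetical helper: print a ddp_timeout config line for a limit in minutes.
# Assumption: ddp_timeout is in seconds (30 minutes = 1800).
ddp_timeout_for_minutes() {
  echo "ddp_timeout: $(( $1 * 60 ))"
}

ddp_timeout_for_minutes 90   # prints "ddp_timeout: 5400"
```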