From 0ba3e9bb50589acc35aadc4db792e617070bdd01 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 29 Nov 2023 12:10:01 -0500 Subject: [PATCH] Explicitly disable P2P using `launch`, and pick up in `state` if a user will face issues. (#2195) * Disable P2P automatically * Clean * Right check * Set better * Check if just cuda * Spacing * replace str int for int as str --- src/accelerate/commands/launch.py | 25 ++++++++++++++++++++++++- src/accelerate/state.py | 16 ++++++++++++++++ src/accelerate/utils/__init__.py | 1 + src/accelerate/utils/environment.py | 18 ++++++++++++++++++ 4 files changed, 59 insertions(+), 1 deletion(-) diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py index 8e44919b23d..66bbea21707 100644 --- a/src/accelerate/commands/launch.py +++ b/src/accelerate/commands/launch.py @@ -34,6 +34,7 @@ DistributedType, PrepareForLaunch, _filter_args, + check_cuda_p2p_ib_support, is_bf16_available, is_deepspeed_available, is_npu_available, @@ -642,6 +643,17 @@ def multi_gpu_launcher(args): import torch.distributed.run as distrib_run current_env = prepare_multi_gpu_env(args) + if not check_cuda_p2p_ib_support(): + message = "Using RTX 3090 or 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled." + warn = False + if "NCCL_P2P_DISABLE" not in current_env: + current_env["NCCL_P2P_DISABLE"] = "1" + warn = True + if "NCCL_IB_DISABLE" not in current_env: + current_env["NCCL_IB_DISABLE"] = "1" + warn = True + if warn: + logger.warning(message) debug = getattr(args, "debug", False) args = _filter_args( @@ -668,6 +680,17 @@ def deepspeed_launcher(args): raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.") cmd, current_env = prepare_deepspeed_cmd_env(args) + if not check_cuda_p2p_ib_support(): + message = "Using RTX 3090 or 4000 series which doesn't support faster communication speedups. 
Ensuring P2P and IB communications are disabled." + warn = False + if "NCCL_P2P_DISABLE" not in current_env: + current_env["NCCL_P2P_DISABLE"] = "1" + warn = True + if "NCCL_IB_DISABLE" not in current_env: + current_env["NCCL_IB_DISABLE"] = "1" + warn = True + if warn: + logger.warning(message) if args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]: with open(".deepspeed_env", "a") as f: @@ -756,7 +779,7 @@ def tpu_pod_launcher(args): "--tpu", "--no_tpu_cluster", "--num_machines", - str(1), + "1", "--mixed_precision", "no", "--dynamo_backend", diff --git a/src/accelerate/state.py b/src/accelerate/state.py index 80f291fadb2..889a816ba47 100644 --- a/src/accelerate/state.py +++ b/src/accelerate/state.py @@ -28,6 +28,7 @@ DistributedType, DynamoBackend, GradientAccumulationPlugin, + check_cuda_p2p_ib_support, get_ccl_version, get_int_from_env, is_ccl_available, @@ -181,6 +182,14 @@ def __init__(self, cpu: bool = False, **kwargs): self.backend = "nccl" dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs) + if not check_cuda_p2p_ib_support(): + if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ: + raise NotImplementedError( + "Using RTX 3090 or 4000 series doesn't support faster communication broadband via P2P or IB. " + 'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which ' + "will do this automatically." 
+ ) + self.num_processes = torch.distributed.get_world_size() self.process_index = torch.distributed.get_rank() self.local_process_index = int(os.environ.get("LOCAL_RANK", -1)) @@ -206,6 +215,13 @@ def __init__(self, cpu: bool = False, **kwargs): if self.backend is None: self.backend = "nccl" torch.distributed.init_process_group(backend=self.backend, **kwargs) + if not check_cuda_p2p_ib_support(): + if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ: + raise NotImplementedError( + "Using RTX 3090 or 4000 series doesn't support faster communication broadband via P2P or IB. " + 'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which ' + "will do this automatically." + ) self.num_processes = torch.distributed.get_world_size() self.process_index = torch.distributed.get_rank() self.local_process_index = int(os.environ.get("LOCAL_RANK", -1)) diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 702d9697acd..83bb19502e5 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -39,6 +39,7 @@ ) from .environment import ( are_libraries_initialized, + check_cuda_p2p_ib_support, get_int_from_env, parse_choice_from_env, parse_flag_from_env, diff --git a/src/accelerate/utils/environment.py b/src/accelerate/utils/environment.py index cff6e73f380..99e153bf04e 100644 --- a/src/accelerate/utils/environment.py +++ b/src/accelerate/utils/environment.py @@ -16,6 +16,8 @@ import sys from typing import Dict +import torch + def str_to_bool(value) -> int: """ @@ -57,3 +59,19 @@ def are_libraries_initialized(*library_names: str) -> Dict[str, bool]: Checks if any of `library_names` are imported in the environment. Will return results as a `key:bool` pair. 
""" return [lib_name for lib_name in library_names if lib_name in sys.modules]
+
+
+def check_cuda_p2p_ib_support():
+    """
+    Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after
+    the 3090. Every visible CUDA device is inspected (not just the default one). Returns `False` when there are
+    several GPUs and at least one is unsupported, `True` otherwise (including when CUDA is unavailable).
+    """
+    device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
+    if device_count <= 1:
+        # P2P/IB only come into play when several GPUs have to talk to each other
+        return True
+    # Name fragments identifying the affected consumer cards
+    unsupported_devices = ("RTX 3090", "RTX 40")
+    device_names = [torch.cuda.get_device_name(index) for index in range(device_count)]
+    return not any(device in name for name in device_names for device in unsupported_devices)