Skip to content

Commit

Permalink
Explicitly disable P2P using launch, and pick up in state if a user will face issues. (#2195)
Browse files Browse the repository at this point in the history

* Disable P2P automatically

* Clean

* Right check

* Set better

* Check if just cuda

* Spacing

* replace str int for int as str
  • Loading branch information
muellerzr authored Nov 29, 2023
1 parent b04d36c commit 0ba3e9b
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 1 deletion.
25 changes: 24 additions & 1 deletion src/accelerate/commands/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
DistributedType,
PrepareForLaunch,
_filter_args,
check_cuda_p2p_ib_support,
is_bf16_available,
is_deepspeed_available,
is_npu_available,
Expand Down Expand Up @@ -642,6 +643,17 @@ def multi_gpu_launcher(args):
import torch.distributed.run as distrib_run

current_env = prepare_multi_gpu_env(args)
if not check_cuda_p2p_ib_support():
message = "Using RTX 3090 or 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled."
warn = False
if "NCCL_P2P_DISABLE" not in current_env:
current_env["NCCL_P2P_DISABLE"] = "1"
warn = True
if "NCCL_IB_DISABLE" not in current_env:
current_env["NCCL_IB_DISABLE"] = "1"
warn = True
if warn:
logger.warning(message)

debug = getattr(args, "debug", False)
args = _filter_args(
Expand All @@ -668,6 +680,17 @@ def deepspeed_launcher(args):
raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.")

cmd, current_env = prepare_deepspeed_cmd_env(args)
if not check_cuda_p2p_ib_support():
message = "Using RTX 3090 or 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled."
warn = False
if "NCCL_P2P_DISABLE" not in current_env:
current_env["NCCL_P2P_DISABLE"] = "1"
warn = True
if "NCCL_IB_DISABLE" not in current_env:
current_env["NCCL_IB_DISABLE"] = "1"
warn = True
if warn:
logger.warning(message)

if args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
with open(".deepspeed_env", "a") as f:
Expand Down Expand Up @@ -756,7 +779,7 @@ def tpu_pod_launcher(args):
"--tpu",
"--no_tpu_cluster",
"--num_machines",
str(1),
"1",
"--mixed_precision",
"no",
"--dynamo_backend",
Expand Down
16 changes: 16 additions & 0 deletions src/accelerate/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DistributedType,
DynamoBackend,
GradientAccumulationPlugin,
check_cuda_p2p_ib_support,
get_ccl_version,
get_int_from_env,
is_ccl_available,
Expand Down Expand Up @@ -181,6 +182,14 @@ def __init__(self, cpu: bool = False, **kwargs):
self.backend = "nccl"
dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)

if not check_cuda_p2p_ib_support():
if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
raise NotImplementedError(
"Using RTX 3090 or 4000 series doesn't support faster communication broadband via P2P or IB. "
'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1"` or use `accelerate launch` which '
"will do this automatically."
)

self.num_processes = torch.distributed.get_world_size()
self.process_index = torch.distributed.get_rank()
self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
Expand All @@ -206,6 +215,13 @@ def __init__(self, cpu: bool = False, **kwargs):
if self.backend is None:
self.backend = "nccl"
torch.distributed.init_process_group(backend=self.backend, **kwargs)
if not check_cuda_p2p_ib_support():
if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
raise NotImplementedError(
"Using RTX 3090 or 4000 series doesn't support faster communication broadband via P2P or IB. "
'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1"` or use `accelerate launch` which '
"will do this automatically."
)
self.num_processes = torch.distributed.get_world_size()
self.process_index = torch.distributed.get_rank()
self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from .environment import (
are_libraries_initialized,
check_cuda_p2p_ib_support,
get_int_from_env,
parse_choice_from_env,
parse_flag_from_env,
Expand Down
18 changes: 18 additions & 0 deletions src/accelerate/utils/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import sys
from typing import Dict

import torch


def str_to_bool(value) -> int:
"""
Expand Down Expand Up @@ -57,3 +59,19 @@ def are_libraries_initialized(*library_names: str) -> Dict[str, bool]:
Checks if any of `library_names` are imported in the environment. Will return results as a `key:bool` pair.
"""
return [lib_name for lib_name in library_names if lib_name in sys.modules]


def check_cuda_p2p_ib_support():
    """
    Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after
    the 3090.

    Returns:
        `bool`: `False` only when more than one CUDA device is visible and at least one of them is an affected
        consumer card (an RTX 3090 or any RTX 40-series GPU); `True` otherwise (including when CUDA is unavailable).
    """
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        # P2P/IB only matters for cross-GPU communication; a single device is always fine.
        if device_count > 1:
            unsupported_devices = ("RTX 3090", "RTX 40")
            # Check every visible device, not just the default one: in a mixed setup an
            # affected consumer card at any index still breaks P2P/IB communication.
            for device_index in range(device_count):
                device_name = torch.cuda.get_device_name(device_index)
                if any(unsupported in device_name for unsupported in unsupported_devices):
                    return False
    return True

0 comments on commit 0ba3e9b

Please sign in to comment.