Check notebook launcher for 3090+ #2212

Merged (8 commits, Dec 5, 2023)
src/accelerate/launchers.py (24 changes: 19 additions & 5 deletions)
@@ -19,7 +19,14 @@
 import torch

 from .state import AcceleratorState, PartialState
-from .utils import PrecisionType, PrepareForLaunch, are_libraries_initialized, is_mps_available, patch_environment
+from .utils import (
+    PrecisionType,
+    PrepareForLaunch,
+    are_libraries_initialized,
+    check_cuda_p2p_ib_support,
+    is_mps_available,
+    patch_environment,
+)


 def test_launch():
@@ -153,16 +160,23 @@ def train(*args):
                 err += f"\n\t* `{lib_name}`"
             raise RuntimeError(err)

-        # torch.distributed will expect a few environment variables to be here. We set the ones common to each
-        # process here (the other ones will be set by the launcher).
-        with patch_environment(
+        patched_env = dict(
             nproc=num_processes,
             node_rank=node_rank,
             world_size=num_nodes * num_processes,
             master_addr=master_addr,
             master_port=use_port,
             mixed_precision=mixed_precision,
-        ):
+        )
+
+        # Check for CUDA P2P and IB issues
+        if not check_cuda_p2p_ib_support():
+            patched_env["nccl_p2p_disable"] = "1"
+            patched_env["nccl_ib_disable"] = "1"
+
+        # torch.distributed will expect a few environment variables to be here. We set the ones common to each
+        # process here (the other ones will be set by the launcher).
+        with patch_environment(**patched_env):
             # First dummy launch
             if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
                 launcher = PrepareForLaunch(test_launch, distributed_type="MULTI_GPU")
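For readers unfamiliar with the pattern in this hunk: the launcher collects the distributed-setup variables in a plain dict, appends the NCCL flags only when P2P/IB support is missing, and then applies everything at once. Below is a minimal standalone sketch of that flow. The `_patch_environment` helper is a simplified stand-in for `accelerate.utils.patch_environment` (which similarly upper-cases keys and restores prior values), and `p2p_ib_supported` is a hypothetical flag standing in for `check_cuda_p2p_ib_support()`.

import os
from contextlib import contextmanager


@contextmanager
def _patch_environment(**kwargs):
    # Simplified stand-in for accelerate.utils.patch_environment:
    # upper-case the keys, set them for the duration of the block, then restore.
    saved = {key.upper(): os.environ.get(key.upper()) for key in kwargs}
    for key, value in kwargs.items():
        os.environ[key.upper()] = str(value)
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old


patched_env = dict(master_addr="127.0.0.1", master_port=29500)
p2p_ib_supported = False  # hypothetical stand-in for check_cuda_p2p_ib_support()
if not p2p_ib_supported:
    patched_env["nccl_p2p_disable"] = "1"
    patched_env["nccl_ib_disable"] = "1"

with _patch_environment(**patched_env):
    # Inside the block the NCCL flags are visible to anything launched from here.
    assert os.environ["NCCL_P2P_DISABLE"] == "1"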
src/accelerate/state.py (15 changes: 7 additions & 8 deletions)
@@ -182,14 +182,6 @@ def __init__(self, cpu: bool = False, **kwargs):
             self.backend = "nccl"
             dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)

-            if not check_cuda_p2p_ib_support():
-                if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
-                    raise NotImplementedError(
-                        "RTX 3090 and 4000-series GPUs don't support faster communication via P2P or IB. "
-                        'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1"`, or use `accelerate launch` which '
-                        "will do this automatically."
-                    )
-
             self.num_processes = torch.distributed.get_world_size()
             self.process_index = torch.distributed.get_rank()
             self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
@@ -206,6 +198,13 @@ def __init__(self, cpu: bool = False, **kwargs):
                     self.device = torch.device("cuda", self.local_process_index)
                 if self.device is not None:
                     torch.cuda.set_device(self.device)
+            if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
+                if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
+                    raise NotImplementedError(
+                        "RTX 3090 and 4000-series GPUs don't support faster communication via P2P or IB. "
+                        'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1"`, or use `accelerate launch` which '
+                        "will do this automatically."
+                    )
             self._mixed_precision = "no"  # deepspeed handles mixed_precision using deepspeed_config
         elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu and torch.cuda.is_available():
             self.distributed_type = DistributedType.MULTI_GPU
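The workaround named in the error message is straightforward on the user side: set both variables before anything initializes an NCCL process group. A minimal notebook sketch (the exact cell placement is up to the user; the variables only need to be set before distributed initialization):

import os

# Must run before accelerate / torch.distributed creates the process group;
# `accelerate launch` and the notebook launcher now do this automatically
# on affected hardware.
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"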
src/accelerate/utils/environment.py (49 changes: 41 additions & 8 deletions)
@@ -13,11 +13,12 @@
 # limitations under the License.

 import os
+import platform
+import subprocess
 import sys
+from distutils import spawn
 from typing import Dict

-import torch
-

 def str_to_bool(value) -> int:
     """
@@ -61,17 +62,49 @@ def are_libraries_initialized(*library_names: str) -> Dict[str, bool]:
     return [lib_name for lib_name in library_names if lib_name in sys.modules]


+def get_gpu_info():
+    """
+    Gets GPU count and names using `nvidia-smi` instead of torch, so as not to initialize CUDA.
+
+    Largely based on the `gputil` library.
+    """
+    if platform.system() == "Windows":
+        # If the platform is Windows and nvidia-smi can't be found in the path,
+        # try the system drive with the default installation path
+        command = spawn.find_executable("nvidia-smi")
+        if command is None:
+            command = "%s\\Program Files\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe" % os.environ["systemdrive"]
+    else:
+        command = "nvidia-smi"
+    # Returns as list of `n` GPUs and their names
+    output = subprocess.check_output(
+        [command, "--query-gpu=count,name", "--format=csv,noheader"], universal_newlines=True
+    )
+    output = output.strip()
+    gpus = output.split(os.linesep)
+    # Get names from output
+    gpu_count = len(gpus)
+    gpu_names = [gpu.split(",")[1].strip() for gpu in gpus]
+    return gpu_names, gpu_count
+
+
 def check_cuda_p2p_ib_support():
     """
     Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after
     the 3090.
+
+    Notably uses `nvidia-smi` instead of torch, so as not to initialize CUDA.
     """
-    if torch.cuda.is_available():
-        # Get the first device/default
-        device_name = torch.cuda.get_device_name()
-        device_count = torch.cuda.device_count()
-        unsupported_devices = ["RTX 3090", "RTX 40"]
+    try:
+        device_names, device_count = get_gpu_info()
+        unsupported_devices = {"RTX 3090", "RTX 40"}
         if device_count > 1:
-            if any(device in device_name for device in unsupported_devices):
+            if any(
+                unsupported_device in device_name
+                for device_name in device_names
+                for unsupported_device in unsupported_devices
+            ):
                 return False
+    except Exception:
+        pass
     return True
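Putting the two new helpers together, the detection flow can be sketched standalone as below. This assumes `nvidia-smi` is on the PATH, uses `splitlines()` rather than `os.linesep` for simplicity, and the device names in the closing comment are made up for illustration.

import subprocess


def p2p_ib_supported() -> bool:
    # Query GPU names without initializing CUDA, like get_gpu_info() above.
    try:
        output = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=count,name", "--format=csv,noheader"],
            universal_newlines=True,
        ).strip()
        names = [line.split(",")[1].strip() for line in output.splitlines()]
        # Same substring test as check_cuda_p2p_ib_support(): a multi-GPU box
        # containing any RTX 3090 / RTX 40xx card lacks P2P and IB support.
        if len(names) > 1 and any(bad in n for n in names for bad in ("RTX 3090", "RTX 40")):
            return False
    except Exception:
        pass  # mirror the PR's behavior: assume support when detection fails
    return True


# For example, on a machine with two RTX 4090s this returns False, so the
# launcher sets NCCL_P2P_DISABLE=1 and NCCL_IB_DISABLE=1 before launching.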