Use gloo as part of DeviceGPU's process group backend #3509

Merged (14 commits) on Aug 5, 2024
5 changes: 5 additions & 0 deletions composer/devices/device_gpu.py
@@ -12,7 +12,9 @@
 import torch.backends.cudnn
 import torch.cuda
 import torch.cuda.amp
+import torch.distributed as torch_dist
 import torch.utils.data
+from packaging import version

 from composer.devices.device import Device
 from composer.utils import dist
@@ -42,6 +44,9 @@ def __init__(
     ):
         if not torch.cuda.is_available():
             raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.')
+        if torch_dist.is_gloo_available() and version.parse(torch.__version__) >= version.parse('2.3.0'):
+            # Composer checkpoint load / save from before torch 2.3.0 is not compatible with gloo + nccl backends.
+            DeviceGPU.dist_backend = 'cuda:nccl,cpu:gloo'
         if device_id is None:
             device_id = dist.get_local_rank()
         self._device = torch.device(f'cuda:{device_id}')
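The 'cuda:nccl,cpu:gloo' value uses PyTorch's per-device-type backend syntax: collectives on CUDA tensors go through NCCL, while collectives on CPU tensors (such as the metadata gathers Composer performs during checkpoint save/load) go through Gloo, which NCCL alone cannot serve. A minimal standalone sketch of the same backend string, assuming environment-variable rendezvous (MASTER_ADDR/MASTER_PORT) rather than Composer's own dist setup; the helper name is illustrative, not from the PR:

    import torch
    import torch.distributed as torch_dist

    def init_with_dual_backend(rank: int, world_size: int) -> None:
        # Same backend string DeviceGPU now sets: NCCL handles CUDA tensors,
        # Gloo handles CPU tensors, all within a single process group.
        torch_dist.init_process_group(
            backend='cuda:nccl,cpu:gloo',
            rank=rank,
            world_size=world_size,
        )
        # With an NCCL-only backend this all_reduce on a CPU tensor would fail;
        # the combined backend routes it to Gloo instead.
        cpu_tensor = torch.zeros(1)
        torch_dist.all_reduce(cpu_tensor)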
6 changes: 5 additions & 1 deletion tests/checkpoint/test_state_dict.py
@@ -6,6 +6,7 @@

 import pytest
 import torch
+import torch.distributed as torch_dist
 from packaging import version
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

@@ -446,7 +447,10 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz
     assert 'model_name' in metadata_sd

     assert 'dist_backend' in metadata_sd
-    assert metadata_sd['dist_backend'] == 'nccl'
+    if torch_dist.is_gloo_available() and version.parse(torch.__version__) >= version.parse('2.3.0'):
+        assert metadata_sd['dist_backend'] == 'cuda:nccl,cpu:gloo'
+    else:
+        assert metadata_sd['dist_backend'] == 'nccl'


 @pytest.mark.filterwarnings('ignore:SWA has')
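The version gate in this test duplicates the one in DeviceGPU. One way to keep the two in sync would be a small shared helper along these lines; this is a hypothetical sketch, not code from the PR:

    import torch
    import torch.distributed as torch_dist
    from packaging import version

    def expected_dist_backend() -> str:
        # Mirrors DeviceGPU's selection: dual backend when Gloo is available
        # on torch >= 2.3.0, NCCL-only otherwise.
        if torch_dist.is_gloo_available() and version.parse(torch.__version__) >= version.parse('2.3.0'):
            return 'cuda:nccl,cpu:gloo'
        return 'nccl'

The assertion above would then collapse to assert metadata_sd['dist_backend'] == expected_dist_backend().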