From afa3734cc6db103042f250053b6f22b017478f85 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Mon, 5 Aug 2024 21:52:59 -0400 Subject: [PATCH] Revert "Use gloo as part of DeviceGPU's process group backend (#3509)" This reverts commit cccc8a7c02a8bab52263365cec1da87a7e2c9bdc. reverting --- composer/devices/device_gpu.py | 5 ----- tests/checkpoint/test_state_dict.py | 6 +----- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/composer/devices/device_gpu.py b/composer/devices/device_gpu.py index c17dee3a3a..19cb0a774a 100644 --- a/composer/devices/device_gpu.py +++ b/composer/devices/device_gpu.py @@ -12,9 +12,7 @@ import torch.backends.cudnn import torch.cuda import torch.cuda.amp -import torch.distributed as torch_dist import torch.utils.data -from packaging import version from composer.devices.device import Device from composer.utils import dist @@ -44,9 +42,6 @@ def __init__( ): if not torch.cuda.is_available(): raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.') - if torch_dist.is_gloo_available() and version.parse(torch.__version__) >= version.parse('2.3.0'): - # Composer checkpoint load / save from before torch 2.3.0 is not compatible with gloo + nccl backends. - DeviceGPU.dist_backend = 'cuda:nccl,cpu:gloo' if device_id is None: device_id = dist.get_local_rank() self._device = torch.device(f'cuda:{device_id}') diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index 99d9146aae..5a316ee286 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -6,7 +6,6 @@ import pytest import torch -import torch.distributed as torch_dist from packaging import version from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -447,10 +446,7 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz assert 'model_name' in metadata_sd assert 'dist_backend' in metadata_sd - if torch_dist.is_gloo_available() and version.parse(torch.__version__) >= version.parse('2.3.0'): - assert metadata_sd['dist_backend'] == 'cuda:nccl,cpu:gloo' - else: - assert metadata_sd['dist_backend'] == 'nccl' + assert metadata_sd['dist_backend'] == 'nccl' @pytest.mark.filterwarnings('ignore:SWA has')