diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py
index a0e9a38656..6f9010bebe 100644
--- a/composer/trainer/dist_strategy.py
+++ b/composer/trainer/dist_strategy.py
@@ -303,7 +303,9 @@ def sync_hook(*args):
     if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0'):
         if 'device_mesh' in fsdp_config:
             device_mesh_size = len(fsdp_config['device_mesh'])
-            if sharding_strategy in [ShardingStrategy.FULL_SHARD, ShardingStrategy.NO_SHARD] and device_mesh_size != 1:
+            if sharding_strategy in [
+                    ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP, ShardingStrategy.NO_SHARD
+            ] and device_mesh_size != 1:
                 raise ValueError(f'FSDP sharding strategy {sharding_map_key.upper()} requires a device mesh '
                                  f'of size 1 but got device mesh size of {device_mesh_size}.')
             elif sharding_strategy in [ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2
diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py
index 76c59c0cc1..dfa4c1f3ee 100644
--- a/tests/trainer/test_fsdp.py
+++ b/tests/trainer/test_fsdp.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0

+import contextlib
 from unittest.mock import MagicMock

 import pytest
@@ -219,10 +220,19 @@ def test_fsdp_process_group(world_size: int):
 @pytest.mark.gpu
 @world_size(2)
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.2.0'), reason='Device mesh requires Torch 2.2')
-def test_wrong_size_device_mesh_error(world_size: int):
-    with pytest.raises(ValueError, match='.*requires a device mesh of size 1.*'):
+@pytest.mark.parametrize('sharding_strategy',
+                         ['NO_SHARD', 'SHARD_GRAD_OP', 'FULL_SHARD', 'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'])
+@pytest.mark.parametrize('device_mesh', [[2], [1, 2]])
+def test_wrong_size_device_mesh_error(world_size: int, sharding_strategy: str, device_mesh: list[int]):
+    context = contextlib.nullcontext()
+    if sharding_strategy in ['NO_SHARD', 'SHARD_GRAD_OP', 'FULL_SHARD'] and len(device_mesh) != 1:
+        context = pytest.raises(ValueError, match='.*requires a device mesh of size 1.*')
+    if sharding_strategy in ['HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'] and len(device_mesh) != 2:
+        context = pytest.raises(ValueError, match='.*requires a device mesh of size 2.*')
+    with context:
         Trainer(model=SimpleModel(), fsdp_config={
-            'device_mesh': [1, 2],
+            'sharding_strategy': sharding_strategy,
+            'device_mesh': device_mesh,
         })
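For reference, the device-mesh rule the patch enforces can be stated compactly. Below is a minimal sketch, not part of the diff or of Composer's API; the helper name `expected_mesh_size` is hypothetical.

```python
# Minimal sketch of the rule the new check enforces. `expected_mesh_size`
# is an illustrative helper, not Composer API.
from torch.distributed.fsdp import ShardingStrategy


def expected_mesh_size(strategy: ShardingStrategy) -> int:
    # FULL_SHARD, SHARD_GRAD_OP, and NO_SHARD shard (or replicate) over a
    # single process group, so they take a 1D device mesh.
    if strategy in (ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP, ShardingStrategy.NO_SHARD):
        return 1
    # HYBRID_SHARD and _HYBRID_SHARD_ZERO2 replicate across one mesh
    # dimension and shard across the other, so they take a 2D device mesh.
    return 2


assert expected_mesh_size(ShardingStrategy.SHARD_GRAD_OP) == 1
assert expected_mesh_size(ShardingStrategy.HYBRID_SHARD) == 2
```

Before the patch, `SHARD_GRAD_OP` fell through the size-1 check even though it uses a single sharding group, so a 2D mesh slipped past validation; the change adds it to the size-1 branch and the test now parametrizes every strategy against both a size-1 and a size-2 mesh.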