From 9da9d81dcc7a3f26acd092bd2987d87656dc7212 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 29 Apr 2024 19:34:26 -0400 Subject: [PATCH] Fix torch 2.3 GPU tests (#3218) --- composer/core/state.py | 71 +++++++++++++--------- composer/trainer/mosaic_fsdp.py | 6 ++ composer/trainer/mosaic_fsdp_utils.py | 86 +++++++++++++++++++++++++++ composer/utils/checkpoint.py | 21 +++---- tests/trainer/test_fsdp.py | 12 ++-- tests/trainer/test_fsdp_checkpoint.py | 20 ++++--- 6 files changed, 157 insertions(+), 59 deletions(-) diff --git a/composer/core/state.py b/composer/core/state.py index 4ff658cf5b..28ec5df1e2 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -11,6 +11,7 @@ from collections import OrderedDict from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union, cast +from unittest.mock import MagicMock import numpy as np import torch @@ -888,7 +889,7 @@ def get_model_state_dict(self) -> Dict[str, Any]: model=self.model, submodules=None, options=StateDictOptions( - full_state_dict=self.fsdp_state_dict_type != 'sharded', + full_state_dict=self.fsdp_state_dict_type == 'full', cpu_offload=True, ), ) @@ -928,7 +929,7 @@ def get_optim_state_dict(self) -> Dict[str, Any]: optimizers=optimizer, submodules=None, options=StateDictOptions( - full_state_dict=self.fsdp_state_dict_type != 'sharded', + full_state_dict=self.fsdp_state_dict_type == 'full', cpu_offload=True, ), ) @@ -1238,7 +1239,11 @@ def load_model_state( set_model_state_dict( model=self.model, model_state_dict=state_dict['model'], - options=StateDictOptions(strict=strict, cpu_offload=True), + options=StateDictOptions( + full_state_dict=self.fsdp_state_dict_type == 'full', + strict=strict, + cpu_offload=True, + ), ) else: missing_keys, unexpected_keys = [], [] @@ -1297,35 +1302,43 @@ def load_optim_state(self, state_dict: Dict[str, Any], strict: bool = True): strict (bool): Whether the keys (i.e., optimizer parameter names) in the optimizer state dict should perfectly match the keys in the optimizer instance. 
""" - if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized(): - from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict - optimizer = self.optimizers[0] - set_optimizer_state_dict( - model=self.model, - optimizers=optimizer, - optim_state_dict=state_dict['optimizers'].get(type(optimizer).__qualname__, {}), - options=StateDictOptions(strict=strict, cpu_offload=True), + serialized_value = state_dict['optimizers'] + for optimizer in ensure_tuple(self.optimizers): + # Broadcast compatibility check as monolith rank 0 only loads won't have optimizer on all ranks + skip_optimizer_load = 1 if serialized_value is not None and type( + optimizer, + ).__qualname__ not in serialized_value else 0 + skip_optimizer_load_tensor = self.device.tensor_to_device( + torch.tensor([skip_optimizer_load], dtype=torch.uint8), ) - else: - serialized_value = state_dict['optimizers'] - for optimizer in ensure_tuple(self.optimizers): - # Broadcast compatibility check as monolith rank 0 only loads won't have optimizer on all ranks - skip_optimizer_load = 1 if serialized_value is not None and type( - optimizer, - ).__qualname__ not in serialized_value else 0 - skip_optimizer_load_tensor = self.device.tensor_to_device( - torch.tensor([skip_optimizer_load], dtype=torch.uint8), + dist.all_reduce(skip_optimizer_load_tensor, reduce_operation='MAX') + if skip_optimizer_load_tensor.item() == 1: + warnings.warn( + f'{type(optimizer).__qualname__} is not in the state_dict. Its state will not be restored.', + category=UserWarning, ) - dist.all_reduce(skip_optimizer_load_tensor, reduce_operation='MAX') - if skip_optimizer_load_tensor.item() == 1: - warnings.warn( - f'{type(optimizer).__qualname__} is not in the state_dict. Its state will not be restored.', - category=UserWarning, - ) - continue + continue - optim_state_dict = serialized_value[type(optimizer).__qualname__ - ] if serialized_value is not None else None + optim_state_dict = serialized_value[type(optimizer).__qualname__] if serialized_value is not None else None + if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized(): + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict + + # optim_state_dict is `None` on non-zero ranks when loading FSDP monolith + # checkpoint on rank 0 only. However, PyTorch modifies the state_dict (producing + # errors) before discarding the output. Accordingly, we mock the state dict. 
+ # See: https://github.com/pytorch/pytorch/issues/125177 + optim_state_dict = MagicMock() if optim_state_dict is None else optim_state_dict + set_optimizer_state_dict( + model=self.model, + optimizers=optimizer, + optim_state_dict=optim_state_dict, + options=StateDictOptions( + full_state_dict=self.fsdp_state_dict_type == 'full', + strict=strict, + cpu_offload=True, + ), + ) + else: if self.fsdp_enabled: assert self.fsdp_state_dict_type is not None # pyright log.debug(f'Loading FSDP optimizer with fsdp_state_dict_type={self.fsdp_state_dict_type}') diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py index 3271f627ae..97af3ddaa4 100644 --- a/composer/trainer/mosaic_fsdp.py +++ b/composer/trainer/mosaic_fsdp.py @@ -66,3 +66,9 @@ def patch_pytorch(): from composer.trainer.mosaic_fsdp_utils import _same_storage _flat_param._same_storage = _same_storage + + from torch.distributed.checkpoint import state_dict + + from composer.trainer.mosaic_fsdp_utils import set_model_state_dict, set_optimizer_state_dict + state_dict.set_model_state_dict = set_model_state_dict + state_dict.set_optimizer_state_dict = set_optimizer_state_dict diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 20658da6f0..8b90634516 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -565,3 +565,89 @@ def _same_storage(a, b): if isinstance(b, DTensor): b = b._local_tensor return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr() + + from torch.distributed.checkpoint.state_dict import _unflatten_model_state_dict, _verify_options, _load_model_state_dict, gc_context, _verify_state_dict, _load_optim_state_dict + + def set_model_state_dict( + model: nn.Module, + model_state_dict, + *, + options = None, + ): + """Load the model state_dict. + + The counterpart of ``get_model_state_dict`` to set the state_dict to the + model. See ``set_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + model_state_dict: (Dict[str, ValueType]): + the model state_dict to load. If the key of the ``model_state_dict`` + is nn.Module, the key is a submodule of ``model`` and the value should + be the state_dict of the submodule. When loading the state_dict, + the prefix of the submodule will be append to the state_dict. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + :type model_state_dict: typing.Dict[str, ValueType] + """ + from torch.distributed.fsdp._runtime_utils import _lazy_init + for module in model.modules(): + if isinstance(module, FullyShardedDataParallel): + _lazy_init(module, module) + model_state_dict = _unflatten_model_state_dict( + model, model_state_dict, + ) + with gc_context(): + info = _verify_options(model, tuple(), optim_only=False, options=options) + + _verify_state_dict(model_state_dict, {}, info) + return _load_model_state_dict(model, model_state_dict, info) + + def set_optimizer_state_dict( + model: nn.Module, + optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], + *, + optim_state_dict, + options = None, + ) -> None: + """Load the optimizers state_dict. 
+ + The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the + optimizers. See ``set_state_dict`` for the detail usage. + + Args: + model (nn.Module): the nn.Module to the model. + optimizers (Union[Optimizer, Iterable[Optimizer]]): + The optimizers that are used to optimize ``model``. + optim_state_dict: OptimizerStateType: + the optimizer state_dict to load. + options (StateDictOptions): the options to control how + model state_dict and optimizer state_dict should be loaded. See + `StateDictOptions` for the details. + + Returns: + None + + :type optim_state_dict: typing.OptimizerStateType + """ + from torch.distributed.fsdp._runtime_utils import _lazy_init + for module in model.modules(): + if isinstance(module, FullyShardedDataParallel): + _lazy_init(module, module) + with gc_context(): + optimizers = ( + (optimizers,) + if isinstance(optimizers, torch.optim.Optimizer) + else tuple(optimizers) + ) + info = _verify_options(model, optimizers, optim_only=True, options=options) + + _verify_state_dict({}, optim_state_dict, info) + _load_optim_state_dict(model, optimizers, optim_state_dict, info) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 97070e5b18..590b4f1619 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -643,19 +643,14 @@ def load_sharded_checkpoint( # Ensure state exists state_dict['state'] = state_dict.get('state', {}) - if version.parse(torch.__version__) >= version.parse('2.3.0'): - dist_cp.load( - state_dict=state_dict, - storage_reader=storage_reader, - planner=state.fsdp_config['load_planner'], - ) - else: - dist_cp.load_state_dict( - state_dict=state_dict, - storage_reader=storage_reader, - planner=state.fsdp_config['load_planner'], - no_dist=(not dist.is_initialized()), - ) + # dist_cp.load breaks unless the specified state_dict supports `load_state_dict` + # See: https://github.com/pytorch/pytorch/issues/125096 + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader=storage_reader, + planner=state.fsdp_config['load_planner'], + no_dist=(not dist.is_initialized()), + ) log.info(f'Loaded state dict') state.load_state_dict( diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index d8b2b053b0..b766b4d1ed 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -95,10 +95,6 @@ def test_fsdp_device_initialization( @pytest.mark.parametrize('device', _INIT_DEVICES) @world_size(2) @pytest.mark.gpu -@pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.1.0'), - reason='This has only been fixed and tested starting with torch 2.1.0', -) def test_fsdp_inits_params_once(model: ComposerClassifier, device: str, world_size: int, expected_param_inits: int): resolved_device = device if device == 'mixed': @@ -132,7 +128,7 @@ def dummy_param_init_fn(module: torch.nn.Module): train_dataloader=dataloader, fsdp_config={ 'mixed_precision': 'PURE', - 'sharding_strategy': 'NO_SHARD', + 'sharding_strategy': 'SHARD_GRAD_OP', 'sync_module_states': True if device == 'mixed' else False, }, max_duration='3ba', @@ -173,7 +169,7 @@ def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precisio train_dataloader=dataloader, fsdp_config={ 'mixed_precision': mixed_precision, - 'sharding_strategy': 'NO_SHARD', + 'sharding_strategy': 'SHARD_GRAD_OP', }, max_duration='3ba', ) @@ -235,12 +231,12 @@ def test_fsdp_process_group(world_size: int): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.2.0'), reason='Device mesh 
requires Torch 2.2') @pytest.mark.parametrize( 'sharding_strategy', - ['NO_SHARD', 'SHARD_GRAD_OP', 'FULL_SHARD', 'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'], + ['SHARD_GRAD_OP', 'FULL_SHARD', 'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'], ) @pytest.mark.parametrize('device_mesh', [[2], [1, 2]]) def test_wrong_size_device_mesh_error(world_size: int, sharding_strategy: str, device_mesh: list[int]): context = contextlib.nullcontext() - if sharding_strategy in ['NO_SHARD', 'SHARD_GRAD_OP', 'FULL_SHARD'] and len(device_mesh) != 1: + if sharding_strategy in ['SHARD_GRAD_OP', 'FULL_SHARD'] and len(device_mesh) != 1: context = pytest.raises(ValueError, match='.*requires a device mesh of size 1.*') if sharding_strategy in ['HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'] and len(device_mesh) != 2: context = pytest.raises(ValueError, match='.*requires a device mesh of size 2.*') diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 2fa86df8c0..672cf803da 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -18,6 +18,7 @@ import pytest import torch from packaging import version +from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.utils.data import DataLoader from torchmetrics import Metric, MetricCollection from torchmetrics.classification import MulticlassAccuracy @@ -184,11 +185,11 @@ def _compare_optims_between_state_dicts(state_dict1, state_dict2): for moment_name in state_dict2_param_moment_dict.keys(): state_dict1_moment = state_dict1_param_moment_dict[moment_name].cpu() state_dict2_moment = state_dict2_param_moment_dict[moment_name].cpu() - assert torch.equal(state_dict1_moment, state_dict2_moment), ( - f'Moment {moment_name} for parameter {param_name} not the same ' - 'between state dicts,\n\t{state_dict1_moment}\n\t' - '{state_dict2_moment}' - ) + if isinstance(state_dict1_moment, ShardedTensor): + state_dict1_moment = state_dict1_moment.local_tensor() + if isinstance(state_dict2_moment, ShardedTensor): + state_dict2_moment = state_dict2_moment.local_tensor() + torch.testing.assert_close(state_dict1_moment, state_dict2_moment) def _compare_model_params_between_state_dicts(state_dict1, state_dict2): @@ -207,10 +208,11 @@ def _compare_model_params_between_state_dicts(state_dict1, state_dict2): for param_name in state_dict2_model_params.keys(): state_dict1_model_tensor = state_dict1_model_params[param_name].cpu() state_dict2_model_tensor = state_dict2_model_params[param_name].cpu() - assert torch.equal( - state_dict1_model_tensor, - state_dict2_model_tensor, - ), f'Weight named {param_name} not the same between state_dicts' + if isinstance(state_dict1_model_tensor, ShardedTensor): + state_dict1_model_tensor = state_dict1_model_tensor.local_tensor() + if isinstance(state_dict2_model_tensor, ShardedTensor): + state_dict2_model_tensor = state_dict2_model_tensor.local_tensor() + torch.testing.assert_close(state_dict1_model_tensor, state_dict2_model_tensor) def _compare_rng_states_between_trainers(rng_state1, rng_state2):
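
A condensed sketch of the torch >= 2.3 optimizer-loading path introduced in the composer/core/state.py hunk above. This is illustrative only: the helper name and its parameters are hypothetical stand-ins, not Composer or PyTorch APIs; the real logic is inlined in State.load_optim_state. The key idea is that every rank must enter the collective set_optimizer_state_dict call, and ranks that hold no state during rank-0-only monolith loading pass a MagicMock placeholder so the in-place edits PyTorch makes before discarding the input are harmless (see pytorch/pytorch#125177).

    # sketch only -- load_optim_state_sketch and its arguments are hypothetical stand-ins
    from unittest.mock import MagicMock

    import torch
    import torch.distributed as torch_dist
    from packaging import version


    def load_optim_state_sketch(model, optimizer, optim_state_dict, strict=True, fsdp_state_dict_type='full'):
        """Mirror of the torch >= 2.3 branch of State.load_optim_state (assumptions noted above)."""
        if version.parse(torch.__version__) >= version.parse('2.3.0') and torch_dist.is_initialized():
            from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict

            # Non-zero ranks receive `None` when only rank 0 loads a monolith checkpoint;
            # substitute a MagicMock so PyTorch's mutation of the input does not error.
            optim_state_dict = MagicMock() if optim_state_dict is None else optim_state_dict
            set_optimizer_state_dict(
                model=model,
                optimizers=optimizer,
                optim_state_dict=optim_state_dict,
                options=StateDictOptions(
                    full_state_dict=fsdp_state_dict_type == 'full',  # 'sharded' selects a sharded load
                    strict=strict,
                    cpu_offload=True,
                ),
            )
        elif optim_state_dict is not None:
            # Single-process / pre-2.3 fallback: hand the state dict to the optimizer directly.
            optimizer.load_state_dict(optim_state_dict)

The same full_state_dict toggle drives the model-loading path in the patch, which is why get_model_state_dict, get_optim_state_dict, and load_model_state all switch from checking `!= 'sharded'` to `== 'full'`.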