Fix torch 2.3 GPU tests (#3218)
mvpatel2000 authored Apr 29, 2024
1 parent ba84d89 commit 9da9d81
Showing 6 changed files with 157 additions and 59 deletions.
71 changes: 42 additions & 29 deletions composer/core/state.py
@@ -11,6 +11,7 @@
from collections import OrderedDict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union, cast
from unittest.mock import MagicMock

import numpy as np
import torch
@@ -888,7 +889,7 @@ def get_model_state_dict(self) -> Dict[str, Any]:
model=self.model,
submodules=None,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type != 'sharded',
full_state_dict=self.fsdp_state_dict_type == 'full',
cpu_offload=True,
),
)
@@ -928,7 +929,7 @@ def get_optim_state_dict(self) -> Dict[str, Any]:
optimizers=optimizer,
submodules=None,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type != 'sharded',
full_state_dict=self.fsdp_state_dict_type == 'full',
cpu_offload=True,
),
)
@@ -1238,7 +1239,11 @@ def load_model_state(
set_model_state_dict(
model=self.model,
model_state_dict=state_dict['model'],
options=StateDictOptions(strict=strict, cpu_offload=True),
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type == 'full',
strict=strict,
cpu_offload=True,
),
)
else:
missing_keys, unexpected_keys = [], []
@@ -1297,35 +1302,43 @@ def load_optim_state(self, state_dict: Dict[str, Any], strict: bool = True):
strict (bool): Whether the keys (i.e., optimizer parameter names) in the optimizer
state dict should perfectly match the keys in the optimizer instance.
"""
if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict
    optimizer = self.optimizers[0]
    set_optimizer_state_dict(
        model=self.model,
        optimizers=optimizer,
        optim_state_dict=state_dict['optimizers'].get(type(optimizer).__qualname__, {}),
        options=StateDictOptions(strict=strict, cpu_offload=True),
    )
else:
    serialized_value = state_dict['optimizers']
    for optimizer in ensure_tuple(self.optimizers):
        # Broadcast compatibility check as monolith rank 0 only loads won't have optimizer on all ranks
        skip_optimizer_load = 1 if serialized_value is not None and type(
            optimizer,
        ).__qualname__ not in serialized_value else 0
        skip_optimizer_load_tensor = self.device.tensor_to_device(
            torch.tensor([skip_optimizer_load], dtype=torch.uint8),
        )
        dist.all_reduce(skip_optimizer_load_tensor, reduce_operation='MAX')
        if skip_optimizer_load_tensor.item() == 1:
            warnings.warn(
                f'{type(optimizer).__qualname__} is not in the state_dict. Its state will not be restored.',
                category=UserWarning,
            )
            continue

        optim_state_dict = serialized_value[type(optimizer).__qualname__
                                           ] if serialized_value is not None else None

serialized_value = state_dict['optimizers']
for optimizer in ensure_tuple(self.optimizers):
    # Broadcast compatibility check as monolith rank 0 only loads won't have optimizer on all ranks
    skip_optimizer_load = 1 if serialized_value is not None and type(
        optimizer,
    ).__qualname__ not in serialized_value else 0
    skip_optimizer_load_tensor = self.device.tensor_to_device(
        torch.tensor([skip_optimizer_load], dtype=torch.uint8),
    )
    dist.all_reduce(skip_optimizer_load_tensor, reduce_operation='MAX')
    if skip_optimizer_load_tensor.item() == 1:
        warnings.warn(
            f'{type(optimizer).__qualname__} is not in the state_dict. Its state will not be restored.',
            category=UserWarning,
        )
        continue

    optim_state_dict = serialized_value[type(optimizer).__qualname__] if serialized_value is not None else None
    if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
        from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict

        # optim_state_dict is `None` on non-zero ranks when loading FSDP monolith
        # checkpoint on rank 0 only. However, PyTorch modifies the state_dict (producing
        # errors) before discarding the output. Accordingly, we mock the state dict.
        # See: https://github.com/pytorch/pytorch/issues/125177
        optim_state_dict = MagicMock() if optim_state_dict is None else optim_state_dict
        set_optimizer_state_dict(
            model=self.model,
            optimizers=optimizer,
            optim_state_dict=optim_state_dict,
            options=StateDictOptions(
                full_state_dict=self.fsdp_state_dict_type == 'full',
                strict=strict,
                cpu_offload=True,
            ),
        )
    else:
        if self.fsdp_enabled:
            assert self.fsdp_state_dict_type is not None  # pyright
            log.debug(f'Loading FSDP optimizer with fsdp_state_dict_type={self.fsdp_state_dict_type}')
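For context, the cross-rank agreement used in load_optim_state above can be reduced to a small standalone sketch: each rank computes a local skip flag and all-reduces it with MAX, so a missing optimizer entry on any rank makes every rank skip the load together and keeps collectives aligned when only rank 0 holds a monolithic checkpoint. The helper name and the single-process gloo group below are illustrative only, not Composer code.

import os

import torch
import torch.distributed as dist


def all_ranks_agree_to_skip(serialized_value, optimizer) -> bool:
    # Mirrors the check above: skip only when a serialized optimizer dict exists
    # but has no entry for this optimizer class.
    local_skip = int(serialized_value is not None and type(optimizer).__qualname__ not in serialized_value)
    flag = torch.tensor([local_skip])
    # MAX reduction: a single rank reporting "skip" forces all ranks to skip.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())


if __name__ == '__main__':
    # Single-process gloo group purely for demonstration.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
    print(all_ranks_agree_to_skip({'SGD': {}}, opt))   # False: state for SGD is present
    print(all_ranks_agree_to_skip({'Adam': {}}, opt))  # True: no entry for SGD
    dist.destroy_process_group()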
6 changes: 6 additions & 0 deletions composer/trainer/mosaic_fsdp.py
@@ -66,3 +66,9 @@ def patch_pytorch():

from composer.trainer.mosaic_fsdp_utils import _same_storage
_flat_param._same_storage = _same_storage

from torch.distributed.checkpoint import state_dict

from composer.trainer.mosaic_fsdp_utils import set_model_state_dict, set_optimizer_state_dict
state_dict.set_model_state_dict = set_model_state_dict
state_dict.set_optimizer_state_dict = set_optimizer_state_dict
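As a general illustration of the monkeypatching idiom patch_pytorch uses above (assigning replacements onto the imported module object), here is a hedged, minimal sketch. The temporarily_patch helper and the math.sqrt example are illustrative and unrelated to Composer, which applies its patches for the lifetime of the process rather than inside a context manager.

import contextlib
import math
import types


@contextlib.contextmanager
def temporarily_patch(module: types.ModuleType, attr: str, replacement):
    """Swap module.attr for replacement, restoring the original on exit."""
    original = getattr(module, attr)
    setattr(module, attr, replacement)
    try:
        yield
    finally:
        setattr(module, attr, original)


# Callers that look the attribute up through the module (math.sqrt) see the patch;
# callers that bound the function earlier via `from math import sqrt` do not.
with temporarily_patch(math, 'sqrt', lambda x: -1.0):
    assert math.sqrt(4) == -1.0
assert math.sqrt(4) == 2.0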
86 changes: 86 additions & 0 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -565,3 +565,89 @@ def _same_storage(a, b):
if isinstance(b, DTensor):
b = b._local_tensor
return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

from torch.distributed.checkpoint.state_dict import _unflatten_model_state_dict, _verify_options, _load_model_state_dict, gc_context, _verify_state_dict, _load_optim_state_dict

def set_model_state_dict(
model: nn.Module,
model_state_dict,
*,
options = None,
):
"""Load the model state_dict.
The counterpart of ``get_model_state_dict`` to set the state_dict to the
model. See ``set_state_dict`` for the detail usage.
Args:
model (nn.Module): the nn.Module to the model.
model_state_dict: (Dict[str, ValueType]):
the model state_dict to load. If the key of the ``model_state_dict``
is nn.Module, the key is a submodule of ``model`` and the value should
be the state_dict of the submodule. When loading the state_dict,
the prefix of the submodule will be appended to the state_dict.
options (StateDictOptions): the options to control how
model state_dict and optimizer state_dict should be loaded. See
`StateDictOptions` for the details.
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
:type model_state_dict: typing.Dict[str, ValueType]
"""
from torch.distributed.fsdp._runtime_utils import _lazy_init
for module in model.modules():
if isinstance(module, FullyShardedDataParallel):
_lazy_init(module, module)
model_state_dict = _unflatten_model_state_dict(
model, model_state_dict,
)
with gc_context():
info = _verify_options(model, tuple(), optim_only=False, options=options)

_verify_state_dict(model_state_dict, {}, info)
return _load_model_state_dict(model, model_state_dict, info)

def set_optimizer_state_dict(
model: nn.Module,
optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
*,
optim_state_dict,
options = None,
) -> None:
"""Load the optimizers state_dict.
The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the
optimizers. See ``set_state_dict`` for the detail usage.
Args:
model (nn.Module): the nn.Module to the model.
optimizers (Union[Optimizer, Iterable[Optimizer]]):
The optimizers that are used to optimize ``model``.
optim_state_dict: OptimizerStateType:
the optimizer state_dict to load.
options (StateDictOptions): the options to control how
model state_dict and optimizer state_dict should be loaded. See
`StateDictOptions` for the details.
Returns:
None
:type optim_state_dict: typing.OptimizerStateType
"""
from torch.distributed.fsdp._runtime_utils import _lazy_init
for module in model.modules():
if isinstance(module, FullyShardedDataParallel):
_lazy_init(module, module)
with gc_context():
optimizers = (
(optimizers,)
if isinstance(optimizers, torch.optim.Optimizer)
else tuple(optimizers)
)
info = _verify_options(model, optimizers, optim_only=True, options=options)

_verify_state_dict({}, optim_state_dict, info)
_load_optim_state_dict(model, optimizers, optim_state_dict, info)
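The docstrings above mirror the upstream torch.distributed.checkpoint.state_dict helpers; the missing_keys/unexpected_keys return described for set_model_state_dict follows the same semantics as plain nn.Module.load_state_dict, as in this small non-distributed sketch (the model layout and the stray key are made up for illustration).

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

# Drop the second Linear's weights and add a stray key to the state dict.
partial_sd = {k: v for k, v in model.state_dict().items() if k.startswith('0.')}
partial_sd['not_a_real_param'] = torch.zeros(1)

result = model.load_state_dict(partial_sd, strict=False)
print(result.missing_keys)     # ['2.weight', '2.bias']
print(result.unexpected_keys)  # ['not_a_real_param']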
21 changes: 8 additions & 13 deletions composer/utils/checkpoint.py
@@ -643,19 +643,14 @@ def load_sharded_checkpoint(
# Ensure state exists
state_dict['state'] = state_dict.get('state', {})

if version.parse(torch.__version__) >= version.parse('2.3.0'):
dist_cp.load(
state_dict=state_dict,
storage_reader=storage_reader,
planner=state.fsdp_config['load_planner'],
)
else:
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=storage_reader,
planner=state.fsdp_config['load_planner'],
no_dist=(not dist.is_initialized()),
)
# dist_cp.load breaks unless the specified state_dict supports `load_state_dict`
# See: https://github.com/pytorch/pytorch/issues/125096
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=storage_reader,
planner=state.fsdp_config['load_planner'],
no_dist=(not dist.is_initialized()),
)

log.info(f'Loaded state dict')
state.load_state_dict(
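For orientation, a single-process sketch of the dist_cp save/load round trip that load_sharded_checkpoint builds on. The temporary directory and the two-level dict are illustrative, and it assumes a torch 2.x release where save_state_dict/load_state_dict (with no_dist) are still available.

import tempfile

import torch
import torch.distributed.checkpoint as dist_cp

ckpt_dir = tempfile.mkdtemp()  # throwaway location for the sketch
model = torch.nn.Linear(4, 2)

# Save: the default planner flattens the nested {'state': {'model': ...}} mapping.
dist_cp.save_state_dict(
    state_dict={'state': {'model': model.state_dict()}},
    storage_writer=dist_cp.FileSystemWriter(ckpt_dir),
    no_dist=True,
)

# Load: load_state_dict restores tensors *in place* into the dict passed in,
# so the target must already contain matching keys and tensor shapes.
fresh = torch.nn.Linear(4, 2)
target = {'state': {'model': fresh.state_dict()}}
dist_cp.load_state_dict(
    state_dict=target,
    storage_reader=dist_cp.FileSystemReader(ckpt_dir),
    no_dist=True,
)
fresh.load_state_dict(target['state']['model'])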
12 changes: 4 additions & 8 deletions tests/trainer/test_fsdp.py
@@ -95,10 +95,6 @@ def test_fsdp_device_initialization(
@pytest.mark.parametrize('device', _INIT_DEVICES)
@world_size(2)
@pytest.mark.gpu
@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse('2.1.0'),
reason='This has only been fixed and tested starting with torch 2.1.0',
)
def test_fsdp_inits_params_once(model: ComposerClassifier, device: str, world_size: int, expected_param_inits: int):
resolved_device = device
if device == 'mixed':
@@ -132,7 +128,7 @@ def dummy_param_init_fn(module: torch.nn.Module):
train_dataloader=dataloader,
fsdp_config={
'mixed_precision': 'PURE',
'sharding_strategy': 'NO_SHARD',
'sharding_strategy': 'SHARD_GRAD_OP',
'sync_module_states': True if device == 'mixed' else False,
},
max_duration='3ba',
@@ -173,7 +169,7 @@ def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precisio
train_dataloader=dataloader,
fsdp_config={
'mixed_precision': mixed_precision,
'sharding_strategy': 'NO_SHARD',
'sharding_strategy': 'SHARD_GRAD_OP',
},
max_duration='3ba',
)
@@ -235,12 +231,12 @@ def test_fsdp_process_group(world_size: int):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.2.0'), reason='Device mesh requires Torch 2.2')
@pytest.mark.parametrize(
'sharding_strategy',
['NO_SHARD', 'SHARD_GRAD_OP', 'FULL_SHARD', 'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'],
['SHARD_GRAD_OP', 'FULL_SHARD', 'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'],
)
@pytest.mark.parametrize('device_mesh', [[2], [1, 2]])
def test_wrong_size_device_mesh_error(world_size: int, sharding_strategy: str, device_mesh: list[int]):
context = contextlib.nullcontext()
if sharding_strategy in ['NO_SHARD', 'SHARD_GRAD_OP', 'FULL_SHARD'] and len(device_mesh) != 1:
if sharding_strategy in ['SHARD_GRAD_OP', 'FULL_SHARD'] and len(device_mesh) != 1:
context = pytest.raises(ValueError, match='.*requires a device mesh of size 1.*')
if sharding_strategy in ['HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'] and len(device_mesh) != 2:
context = pytest.raises(ValueError, match='.*requires a device mesh of size 2.*')
20 changes: 11 additions & 9 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -18,6 +18,7 @@
import pytest
import torch
from packaging import version
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.utils.data import DataLoader
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import MulticlassAccuracy
@@ -184,11 +185,11 @@ def _compare_optims_between_state_dicts(state_dict1, state_dict2):
for moment_name in state_dict2_param_moment_dict.keys():
state_dict1_moment = state_dict1_param_moment_dict[moment_name].cpu()
state_dict2_moment = state_dict2_param_moment_dict[moment_name].cpu()
assert torch.equal(state_dict1_moment, state_dict2_moment), (
f'Moment {moment_name} for parameter {param_name} not the same '
'between state dicts,\n\t{state_dict1_moment}\n\t'
'{state_dict2_moment}'
)
if isinstance(state_dict1_moment, ShardedTensor):
state_dict1_moment = state_dict1_moment.local_tensor()
if isinstance(state_dict2_moment, ShardedTensor):
state_dict2_moment = state_dict2_moment.local_tensor()
torch.testing.assert_close(state_dict1_moment, state_dict2_moment)


def _compare_model_params_between_state_dicts(state_dict1, state_dict2):
@@ -207,10 +208,11 @@ def _compare_model_params_between_state_dicts(state_dict1, state_dict2):
for param_name in state_dict2_model_params.keys():
state_dict1_model_tensor = state_dict1_model_params[param_name].cpu()
state_dict2_model_tensor = state_dict2_model_params[param_name].cpu()
assert torch.equal(
state_dict1_model_tensor,
state_dict2_model_tensor,
), f'Weight named {param_name} not the same between state_dicts'
if isinstance(state_dict1_model_tensor, ShardedTensor):
state_dict1_model_tensor = state_dict1_model_tensor.local_tensor()
if isinstance(state_dict2_model_tensor, ShardedTensor):
state_dict2_model_tensor = state_dict2_model_tensor.local_tensor()
torch.testing.assert_close(state_dict1_model_tensor, state_dict2_model_tensor)


def _compare_rng_states_between_trainers(rng_state1, rng_state2):
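The updated comparison helpers above unwrap ShardedTensor values before comparing. The same normalization can be expressed as a tiny standalone helper (the _to_local name is illustrative, not part of the test suite), which also runs in a single process where no ShardedTensors exist.

import torch
from torch.distributed._shard.sharded_tensor import ShardedTensor


def _to_local(t):
    # ShardedTensor instances only appear inside an initialized process group;
    # plain tensors pass straight through.
    return t.local_tensor() if isinstance(t, ShardedTensor) else t


a, b = torch.ones(3), torch.ones(3)
torch.testing.assert_close(_to_local(a), _to_local(b))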
