Skip to content

Commit

Permalink
Adds DTensor Support (#2821)
Browse files Browse the repository at this point in the history
* fixes to get dtensor to work

* more fixes

* Change state dict materialization for new version of torch

* get load working for new set_state_dict api

* use device_mesh

* Add fsdp init monkeypatch for DTensor

* Add checkpoint profiling logs

* attempt

* working single node

* fix optimizer

* allow 3d device mesh

* attempt to use different pg during 3d mesh save

* undo 3d mesh changes

* load_state_dict -> load

* allow parent mesh in FSDP init

* allow override of force_sync_module_states

* remove unnecessary exit

* ignore _validate_and_get_shard_state()

* save/load hsdp-moe working

* remove prints

* v1

* v2

* lint

* add more tests

* switch to PRs

* ignore warning

* fix lint

* version error

* fix version

* fix state dict

* update versions

* lint

* lint

* disable lint for mosaic fsdp utils

* remove bad line

* move around for legacy

* device mesh

* ignore warning

* fix import

* always init

* fix error

* fix load planner

* remove

* fix lint

* lint

* delay state dict

* test checkpoint

* checkpoint

* fix cpu tests

* fix rotate tests

* fix precision

* lint

* fix alibi

* cleanup

* cleanup

* remove force sync

* fix type

* merge

* lint

* fix gpt

* comment

* fix test

* lint

* minor optimizations

* Update composer/core/state.py

Co-authored-by: Evan Racah <[email protected]>

* revert tests

---------

Co-authored-by: Evan Racah <[email protected]>
Co-authored-by: Abhinav Venigalla <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Abhinav Venigalla <[email protected]>
Co-authored-by: Your Name <[email protected]>
Co-authored-by: Evan Racah <[email protected]>
  • Loading branch information
7 people authored Jan 9, 2024
1 parent 7b70dde commit 94e0386
Show file tree
Hide file tree
Showing 14 changed files with 565 additions and 112 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/pr-cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ jobs:
markers: 'not daily and not remote and not gpu and not vision and not doctest'
pytest_command: 'coverage run -m pytest'
composer_package_name: 'mosaicml'
# - name: 'cpu-3.10-2.2'
# container: mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04
# markers: 'not daily and not remote and not gpu and not vision and not doctest'
# pytest_command: 'coverage run -m pytest'
# composer_package_name: 'mosaicml'
- name: 'cpu-vision'
container: mosaicml/pytorch_vision:1.13.1_cpu-python3.10-ubuntu20.04
markers: 'not daily and not remote and not gpu and vision and not doctest'
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/pr-gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ jobs:
markers: 'not daily and not remote and gpu and (doctest or not doctest)'
pytest_command: 'coverage run -m pytest'
composer_package_name: 'mosaicml'
# - name: 'gpu-3.10-2.2'
# container: mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04
# markers: 'not daily and not remote and gpu and (doctest or not doctest)'
# pytest_command: 'coverage run -m pytest'
# composer_package_name: 'mosaicml'
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
Expand Down
10 changes: 6 additions & 4 deletions composer/algorithms/alibi/attention_surgery_functions/_bert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import copy
import math
from types import MethodType
from typing import Optional, Tuple
Expand All @@ -20,13 +21,14 @@ def bert_embedding_converter(module: torch.nn.Module, module_index: int, max_seq
"""
assert isinstance(module, (BertEmbeddings, RobertaEmbeddings))
del module_index # unused
zero_and_freeze_expand_position_embeddings(module,
new_module = copy.deepcopy(module)
zero_and_freeze_expand_position_embeddings(new_module,
max_sequence_length,
position_embedding_attribute='position_embeddings')

module_device = next(module.parameters()).device
module.register_buffer('position_ids', torch.arange(max_sequence_length).expand((1, -1)).to(module_device))
return module
module_device = next(new_module.parameters()).device
new_module.register_buffer('position_ids', torch.arange(max_sequence_length).expand((1, -1)).to(module_device))
return new_module


@policy_registry.register(BertSelfAttention, RobertaSelfAttention)
Expand Down
179 changes: 121 additions & 58 deletions composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import warnings
from collections import OrderedDict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast

import numpy as np
import torch
Expand Down Expand Up @@ -792,6 +792,15 @@ def fsdp_state_dict_type(self):
def fsdp_sharded_state_dict_enabled(self):
return self.fsdp_config is not None and self.fsdp_enabled and self.fsdp_state_dict_type in ['sharded', 'local']

@property
def fsdp_device_mesh(self):
if self.fsdp_enabled:
if not hasattr(self.model, 'model'):
return None
return self.model.model._device_mesh
else:
return None

@property
def load_fsdp_monolith_rank0_only(self):
return self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config[
Expand Down Expand Up @@ -864,6 +873,9 @@ def get_model_state_dict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The state dict for the model.
"""
return self.get_model_and_optimizer_state_dict(model_only=True)[0]

def _legacy_get_model_state_dict(self) -> Dict[str, Any]:
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
model_state_dict = self.model.state_dict()
Expand All @@ -876,30 +888,60 @@ def get_model_state_dict(self) -> Dict[str, Any]:
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.')
return model_state_dict

def _legacy_get_optim_state_dict(self) -> Dict[str, Any]:
optimizer = ensure_tuple(self.optimizers)[0] # Let's stop pretending. We don't support more than one optimizer.
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
optim_state_dict = {
type(optimizer).__qualname__:
fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type)
}
else:
optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()}
return optim_state_dict

def get_model_and_optimizer_state_dict(self, model_only=False) -> Tuple[Dict[str, Any], Dict[str, Any]]:
if version.parse(torch.__version__) > version.parse('2.1.3'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
textwrap.dedent(f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for '
f'torch version {{version.parse(torch.__version__)}} > 2.1.3. Please set '
'fsdp_state_dict_type to None, "full", or "sharded".'))

optimizer = ensure_tuple(self.optimizers)[0]
model_state_dict, optim_state_dict = get_state_dict(
model=self.model,
optimizers=([] if model_only else optimizer),
submodules=None,
options=StateDictOptions(
full_state_dict=self.fsdp_state_dict_type != 'sharded',
cpu_offload=True,
),
)
optim_state_dict = {type(optimizer).__qualname__: optim_state_dict}
else:
model_state_dict = self._legacy_get_model_state_dict()
optim_state_dict = self._legacy_get_optim_state_dict()

return model_state_dict, optim_state_dict

def state_dict(self) -> Dict[str, Any]:
"""Collect the state dicts of our serializable attributes.
Returns:
Dict[str, Any]: The state dict.
"""
state_dict = {}

model_state_dict, optim_state_dict = None, None
if 'model' in self.serialized_attributes or 'optimizers' in self.serialized_attributes:
model_state_dict, optim_state_dict = self.get_model_and_optimizer_state_dict()
for attribute_name in self.serialized_attributes:
attribute_value = getattr(self, attribute_name)
if attribute_name == 'dataset_state':
serialized_value = self._dataset_state_dict()
elif attribute_name == 'model':
serialized_value = self.get_model_state_dict()
serialized_value = model_state_dict
elif attribute_name == 'optimizers':
optimizer = ensure_tuple(attribute_value)[
0] # Let's stop pretending. We don't support more than one optimizer.
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
optim_state_dict = {
type(optimizer).__qualname__:
fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type)
}
else:
optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()}
serialized_value = optim_state_dict
elif attribute_name == 'algorithms':
# Store as list to preserve order in which algorithms were applied
Expand Down Expand Up @@ -1058,49 +1100,34 @@ def _apply_required_algorithms(
'have undergone surgery, the following algorithms may be excluded using '
f'`load_exclude_algorithms`, e.g. `load_exclude_algorithms=[{missing_algo_names}]`.')) from e

def load_model_state(
def _legacy_load_model_state(
self,
state_dict: Dict[str, Any],
logger: Logger,
strict: bool,
exclude_algorithms: Optional[List[str]] = None,
algorithm_passes: Optional[List[AlgorithmPass]] = None,
):
"""Loads the model's state from a ``state_dict``.
Args:
state_dict (Dict[str, Any]): The state dict, generated from a previous call to :meth:`state_dict`.
logger (Logger): The logger.
strict (bool): Whether the keys (i.e., model parameter names) in the model state dict should
perfectly match the keys in the model instance.
exclude_algorithms (List[str], optional): List of algorithm names to exclude from autoloading. (default: ``None``)
algorithm_passes (List[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms
to sort them into the correct order. (default: ``None``)
"""
if 'algorithms' in state_dict:
self._apply_required_algorithms(state_dict, logger, exclude_algorithms, algorithm_passes)

if state_dict.get('is_model_ddp', False) and not self.is_model_ddp:
# This check is for backwards compatibility, as pre-v0.6.0 checkpoints serialized the state
# with the `module.` prefix
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.')

# For FSDP monolith checkpoints, the model does not exist on ranks > 0
model_on_rank = state_dict['model'] is not None
if state_dict['model'] is None:
return

missing_keys, unexpected_keys = [], []
try:
# Load model if it exists. For FSDP monolith checkpoints, the model does not exist on ranks > 0
if model_on_rank:
if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_fsdp_monolith_rank0_only:
log.debug(
f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}'
)
with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
else:
log.debug(f'Loading model state dict with strict={strict}')
# Load model if it exists
if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_fsdp_monolith_rank0_only:
log.debug(
f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}'
)
with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
else:
log.debug(f'Loading model state dict with strict={strict}')
missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict)
except RuntimeError as e:
if 'Missing key(s) in state_dict' in str(e) or 'Unexpected key(s) in state_dict' in str(e):
raise RuntimeError(
Expand All @@ -1110,9 +1137,9 @@ def load_model_state(
else:
raise e

if model_on_rank and len(missing_keys) > 0:
if len(missing_keys) > 0:
log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
if model_on_rank and len(unexpected_keys) > 0:
if len(unexpected_keys) > 0:
if self.fsdp_config is not None and self.fsdp_config[
'use_orig_params'] and self.fsdp_state_dict_type == 'local':
log.warning(
Expand All @@ -1122,16 +1149,7 @@ def load_model_state(
'was still loaded correctly.')
log.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")

# If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading
if self.load_fsdp_monolith_rank0_only:
assert self.fsdp_config is not None
log.info('Wrapping model with FSDP after loading model_state.')
from composer.trainer.dist_strategy import prepare_fsdp_module
prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
self.auto_microbatching)
log.debug('Finished wrapping model with FSDP.')

def load_optim_state(self, state_dict: Dict[str, Any]):
def _legacy_load_optim_state(self, state_dict: Dict[str, Any]):
"""Load the optimizer state.
Args:
Expand Down Expand Up @@ -1205,6 +1223,55 @@ def _load_dataset_state(self, obj: Dict[str, Any]) -> None:
# starts. This avoids "CUDA error: initialization error" -- its not clear why.
# self.dataset_resumption['eval'][evaluator.label] = True

def load_model_and_optimizer_state(
self,
state_dict: Dict[str, Any],
logger: Logger,
strict: bool,
exclude_algorithms: Optional[List[str]] = None,
algorithm_passes: Optional[List[AlgorithmPass]] = None,
load_model_only: bool = False,
):
if 'algorithms' in state_dict:
self._apply_required_algorithms(state_dict, logger, exclude_algorithms, algorithm_passes)

if state_dict.get('is_model_ddp', False) and not self.is_model_ddp:
# This check is for backwards compatibility, as pre-v0.6.0 checkpoints serialized the state
# with the `module.` prefix
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.')

# Load model and optimizer state
use_state_dict_fns = version.parse(torch.__version__) > version.parse('2.1.3')
if use_state_dict_fns:
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_state_dict
model_state_dict = state_dict.get('model', {})
optimizer, optim_state_dict = [], {}
if not load_model_only:
optimizer = ensure_tuple(self.optimizers)[0]
optim_state_dict = state_dict['optimizers'].get(type(optimizer).__qualname__, {})
set_state_dict(
self.model,
optimizers=optimizer,
model_state_dict=model_state_dict,
optim_state_dict=optim_state_dict,
options=StateDictOptions(strict=strict, cpu_offload=True),
)
else:
self._legacy_load_model_state(state_dict, strict)

# If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading
if self.load_fsdp_monolith_rank0_only:
assert self.fsdp_config is not None
log.info('Wrapping model with FSDP after loading model_state.')
from composer.trainer.dist_strategy import prepare_fsdp_module
prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
self.auto_microbatching)
log.debug('Finished wrapping model with FSDP.')

# Legacy optimizer state load must happen after FSDP monolith
if not use_state_dict_fns and not load_model_only:
self._legacy_load_optim_state(state_dict)

def load_state_dict(
self,
state: Dict[str, Any],
Expand All @@ -1228,28 +1295,26 @@ def load_state_dict(

# Call load_model_state first since it applies required algorithms
if 'model' in state:
self.load_model_state(
self.load_model_and_optimizer_state(
state,
logger,
strict=strict,
exclude_algorithms=exclude_algorithms,
algorithm_passes=algorithm_passes,
load_model_only=(not 'optimizers' in state),
)

for attribute_name in sorted(state.keys()): # Sort so all ranks load in the same order
serialized_value = state[attribute_name]
# Skip removed attributes as well as algorithms and model, which was already loaded
if attribute_name not in self.serialized_attributes or attribute_name == 'model':
if attribute_name not in self.serialized_attributes or attribute_name in ['model', 'optimizers']:
continue

# Integrations are extra information about other libraries (e.g. huggingface) and not attributes to be loaded here
if attribute_name == 'integrations':
continue

# Skip metadata, which is not an attribute on State
if attribute_name == 'metadata':
continue

log.debug(f'Loading {attribute_name} into state.')

# Restructure algorithms serialized_value from list to dict
Expand All @@ -1258,8 +1323,6 @@ def load_state_dict(

if attribute_name == 'dataset_state':
self._load_dataset_state(serialized_value)
elif attribute_name == 'optimizers':
self.load_optim_state(state)
elif attribute_name == 'train_metrics':
# Get current metrics object and populate each metric present
# in serialization with serialized data via load_state_dict()
Expand Down
16 changes: 12 additions & 4 deletions composer/trainer/dist_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,11 @@ def prepare_fsdp_module(
'gpu and some ranks are on meta. Either keep all ranks on the same '
"device or set fsdp_config['sync_module_states'] = True. Otherwise, "
'some weights may be randomly initialized when loading a checkpoint.')
if fsdp_config['sharding_strategy'] in ('HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'):
raise ValueError('HSDP (HYBRID_SHARD or _HYBRID_SHARD_ZERO2) requires '
'fsdp_config["sync_module_states"] = True or different replicas will '
'have different weights.')
# Comment out while we debug deadlock
# if fsdp_config['sharding_strategy'] in ('HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'):
# raise ValueError('HSDP (HYBRID_SHARD or _HYBRID_SHARD_ZERO2) requires '
# 'fsdp_config["sync_module_states"] = True or different replicas will '
# 'have different weights.')

# Check if other ranks OOMed after forward/backward pass when using auto microbatching. This
# may happen when close to memory limit or with uneven memory usage across ranks. Since we
Expand All @@ -273,6 +274,13 @@ def sync_hook(*args):
# `nn.Module.named_parameters`.
# Setting it to `True` is mandatory when using `torch.compile()`.
kwargs['use_orig_params'] = fsdp_config['use_orig_params']
if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0'):
if 'device_mesh' in fsdp_config:
from torch.distributed._tensor import init_device_mesh
kwargs['device_mesh'] = init_device_mesh(
'cuda',
tuple([int(x) for x in fsdp_config['device_mesh']]),
)

# necessary variables for optimizers with multiple param groups in FSDP
num_param_groups = None
Expand Down
Loading

0 comments on commit 94e0386

Please sign in to comment.