Skip to content

Commit

Permalink
Patch fqns for torch 2.3.0 (#3210)
Browse files Browse the repository at this point in the history
* fqn patch

* fsdp import

* patch fqns

* patch torch

* linting
  • Loading branch information
snarayan21 authored Apr 30, 2024
1 parent fbe533a commit ffb40c0
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 2 deletions.
5 changes: 4 additions & 1 deletion composer/trainer/mosaic_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ def patch_pytorch():
from composer.trainer.mosaic_fsdp_utils import _same_storage
_flat_param._same_storage = _same_storage

# Monkeypatch state_dict to get FQNs correctly.
# Issue: https://github.com/pytorch/pytorch/pull/124698
from torch.distributed.checkpoint import state_dict

from composer.trainer.mosaic_fsdp_utils import set_model_state_dict, set_optimizer_state_dict
from composer.trainer.mosaic_fsdp_utils import _get_fqns, set_model_state_dict, set_optimizer_state_dict
state_dict.set_model_state_dict = set_model_state_dict
state_dict.set_optimizer_state_dict = set_optimizer_state_dict
state_dict._get_fqns = _get_fqns
68 changes: 67 additions & 1 deletion composer/trainer/mosaic_fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,73 @@ def _same_storage(a, b):
b = b._local_tensor
return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

from torch.distributed.checkpoint.state_dict import _unflatten_model_state_dict, _verify_options, _load_model_state_dict, gc_context, _verify_state_dict, _load_optim_state_dict
from torch.distributed.checkpoint.state_dict import (_unflatten_model_state_dict, _verify_options,
_load_model_state_dict, gc_context,
_verify_state_dict, _load_optim_state_dict,
FQNS_T)

@no_type_check
def _get_fqns(
model: nn.Module,
name: str,
skip_ddp_prefix: bool = True,
skip_compiler_prefix: bool = True,
) -> FQNS_T:
"""Used to convert the name of a parameter to the FQNs.
For FSDP without `use_orig_params`, the name of FlatParameter can be mapped to
multiple original parameters. As a result, the return type of this function
is `Set[str]`.
Args:
module (nn.Module): the root model.
name (str): the name
skip_ddp_prefix (bool): whether to skip DDP's `module` prefix
Returns:
The canonical FQNs based on the model traversal.
"""
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import _CHECKPOINT_PREFIX
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import FLAT_PARAM
from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE

# Remove the checkpoint prefix, if it exists.
name = name.replace(_CHECKPOINT_PREFIX, '')
if '.' not in name:
return {name}

obj_names = name.split('.')
fqn_obj_names = []
curr_obj = model
for i, curr_obj_name in enumerate(obj_names):
if isinstance(curr_obj, DDP):
assert curr_obj_name == 'module'
curr_obj = curr_obj.module
if not skip_ddp_prefix:
fqn_obj_names.append(curr_obj_name)
elif isinstance(curr_obj, FSDP):
if i < len(obj_names) - 1 and obj_names[i + 1] == FLAT_PARAM:
prefix = '.'.join(fqn_obj_names)
flat_param = getattr(curr_obj, FLAT_PARAM)
if prefix:
prefix = f'{prefix}.'
return {f'{prefix}{fqn}' for fqn in flat_param._fqns}
curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
if curr_obj_name != FSDP_WRAPPED_MODULE:
fqn_obj_names.append(curr_obj_name)
curr_obj = getattr(curr_obj, curr_obj_name)
elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
assert curr_obj_name == '_orig_mod'
curr_obj = curr_obj._orig_mod
if not skip_compiler_prefix:
fqn_obj_names.append(curr_obj_name)
else:
fqn_obj_names.append(curr_obj_name)
curr_obj = getattr(curr_obj, curr_obj_name)

return {'.'.join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, '')}

def set_model_state_dict(
model: nn.Module,
Expand Down

0 comments on commit ffb40c0

Please sign in to comment.