diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py
index 97af3ddaa4..44f42650eb 100644
--- a/composer/trainer/mosaic_fsdp.py
+++ b/composer/trainer/mosaic_fsdp.py
@@ -67,8 +67,11 @@ def patch_pytorch():
         from composer.trainer.mosaic_fsdp_utils import _same_storage
         _flat_param._same_storage = _same_storage
 
+        # Monkeypatch state_dict to get FQNs correctly.
+        # Issue: https://github.com/pytorch/pytorch/pull/124698
         from torch.distributed.checkpoint import state_dict
 
-        from composer.trainer.mosaic_fsdp_utils import set_model_state_dict, set_optimizer_state_dict
+        from composer.trainer.mosaic_fsdp_utils import _get_fqns, set_model_state_dict, set_optimizer_state_dict
         state_dict.set_model_state_dict = set_model_state_dict
         state_dict.set_optimizer_state_dict = set_optimizer_state_dict
+        state_dict._get_fqns = _get_fqns
diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py
index 8b90634516..394bbce1bb 100644
--- a/composer/trainer/mosaic_fsdp_utils.py
+++ b/composer/trainer/mosaic_fsdp_utils.py
@@ -566,7 +566,73 @@ def _same_storage(a, b):
             b = b._local_tensor
         return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
 
-    from torch.distributed.checkpoint.state_dict import _unflatten_model_state_dict, _verify_options, _load_model_state_dict, gc_context, _verify_state_dict, _load_optim_state_dict
+    from torch.distributed.checkpoint.state_dict import (_unflatten_model_state_dict, _verify_options,
+                                                         _load_model_state_dict, gc_context,
+                                                         _verify_state_dict, _load_optim_state_dict,
+                                                         FQNS_T)
+
+    @no_type_check
+    def _get_fqns(
+        model: nn.Module,
+        name: str,
+        skip_ddp_prefix: bool = True,
+        skip_compiler_prefix: bool = True,
+    ) -> FQNS_T:
+        """Used to convert the name of a parameter to the FQNs.
+
+        For FSDP without `use_orig_params`, the name of a FlatParameter can be mapped to
+        multiple original parameters. As a result, the return type of this function
+        is `Set[str]`.
+
+        Args:
+            model (nn.Module): the root model.
+            name (str): the name of the parameter.
+            skip_ddp_prefix (bool): whether to skip DDP's `module` prefix.
+
+        Returns:
+            The canonical FQNs based on the model traversal.
+        """
+        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import _CHECKPOINT_PREFIX
+        from torch.nn.parallel import DistributedDataParallel as DDP
+        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+        from torch.distributed.checkpoint.state_dict import FLAT_PARAM
+        from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE
+
+        # Remove the checkpoint prefix, if it exists.
+        name = name.replace(_CHECKPOINT_PREFIX, '')
+        if '.' not in name:
+            return {name}
+
+        obj_names = name.split('.')
+        fqn_obj_names = []
+        curr_obj = model
+        for i, curr_obj_name in enumerate(obj_names):
+            if isinstance(curr_obj, DDP):
+                assert curr_obj_name == 'module'
+                curr_obj = curr_obj.module
+                if not skip_ddp_prefix:
+                    fqn_obj_names.append(curr_obj_name)
+            elif isinstance(curr_obj, FSDP):
+                if i < len(obj_names) - 1 and obj_names[i + 1] == FLAT_PARAM:
+                    prefix = '.'.join(fqn_obj_names)
+                    flat_param = getattr(curr_obj, FLAT_PARAM)
+                    if prefix:
+                        prefix = f'{prefix}.'
+                    return {f'{prefix}{fqn}' for fqn in flat_param._fqns}
+                curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
+                if curr_obj_name != FSDP_WRAPPED_MODULE:
+                    fqn_obj_names.append(curr_obj_name)
+                    curr_obj = getattr(curr_obj, curr_obj_name)
+            elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
+                assert curr_obj_name == '_orig_mod'
+                curr_obj = curr_obj._orig_mod
+                if not skip_compiler_prefix:
+                    fqn_obj_names.append(curr_obj_name)
+            else:
+                fqn_obj_names.append(curr_obj_name)
+                curr_obj = getattr(curr_obj, curr_obj_name)
+
+        return {'.'.join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, '')}
 
     def set_model_state_dict(
         model: nn.Module,