
Commit

fix
mvpatel2000 committed Feb 22, 2024
1 parent 0a31299 commit 8aeeb8a
Showing 2 changed files with 60 additions and 92 deletions.
6 changes: 0 additions & 6 deletions composer/trainer/mosaic_fsdp.py
@@ -85,9 +85,3 @@ def patch_pytorch():

from composer.trainer.mosaic_fsdp_utils import _shard_orig_param_state
_optim_utils._shard_orig_param_state = _shard_orig_param_state

-# Monkeypatch dtensor wrapping
-from torch.distributed.fsdp import _flat_param
-
-from composer.trainer.mosaic_fsdp_utils import _use_unsharded_views
-_flat_param._use_unsharded_views = _use_unsharded_views
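For context, a minimal sketch of how the tail of this patch_pytorch() branch reads after the deletion above; it is reconstructed from the hunk, not copied from the repository, and the enclosing torch-version guard is assumed. The optimizer-state sharding override stays, while the dtensor _use_unsharded_views override (whose definition is also removed from mosaic_fsdp_utils.py below) is no longer installed.

# Hedged sketch: remaining tail of this patch_pytorch() branch after the change.
# The surrounding version check is assumed and not shown in this hunk.
from torch.distributed.fsdp import _optim_utils

from composer.trainer.mosaic_fsdp_utils import _shard_orig_param_state
_optim_utils._shard_orig_param_state = _shard_orig_param_state
# The `_flat_param._use_unsharded_views = _use_unsharded_views` assignment that
# used to follow here is dropped by this commit.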
146 changes: 60 additions & 86 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -617,6 +617,7 @@ def _sharded_pre_load_state_dict_hook(

if version.parse(torch.__version__) > version.parse('2.2.9') and version.parse(
torch.__version__) < version.parse('2.3.1'):
+import os
import copy

from torch.distributed._tensor import DeviceMesh, DTensor, Replicate
@@ -630,7 +631,7 @@ def _sharded_pre_load_state_dict_hook(
_is_valid_hybrid_shard_pg_type, _init_extension)
from torch.distributed.fsdp.fully_sharded_data_parallel import (_annotate_modules_for_dynamo, _auto_wrap,
_check_orig_params_flattened, _init_buffer_state,
-_init_core_state, _init_device_handle,
+_init_device_handle,
_init_ignored_module_states,
_init_param_handle_from_module,
_init_prefetching_state, _init_runtime_state,
@@ -640,6 +641,63 @@ def _sharded_pre_load_state_dict_hook(
from torch.distributed.fsdp.wrap import CustomPolicy, ModuleWrapPolicy, _Policy
from torch.distributed.tensor.parallel.fsdp import DTensorExtensions


+from torch.distributed.fsdp._common_utils import _FSDPState, TrainingState
+from torch.distributed.fsdp._flat_param import _FSDP_USE_FULL_PREC_IN_EVAL, FlatParameter, FlatParamHandle
+import torch.distributed.fsdp._exec_order_utils as exec_order_utils
+from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
+def _init_core_state_t2p3p0(
+    state: _FSDPState,
+    sharding_strategy: Optional[ShardingStrategy],
+    mixed_precision: Optional[MixedPrecision],
+    cpu_offload: Optional[CPUOffload],
+    limit_all_gathers: bool,
+    use_orig_params: bool,
+    backward_prefetch_limit: int,
+    forward_prefetch_limit: int,
+) -> _FSDPState:
+    if sharding_strategy == ShardingStrategy.NO_SHARD:
+        warnings.warn(
+            "The `NO_SHARD` sharding strategy is deprecated. If having issues, "
+            "please use DistributedDataParallel instead.",
+            # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and
+            # level 3 is from the true caller
+            stacklevel=3,
+        )
+    state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD
+    state.mixed_precision = mixed_precision or MixedPrecision()
+    if mixed_precision is not None:
+        torch._C._log_api_usage_once(
+            f"torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}"
+        )
+    state._use_full_prec_in_eval = (
+        os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
+    )
+    state.cpu_offload = cpu_offload or CPUOffload()
+    state.limit_all_gathers = limit_all_gathers
+    state._use_orig_params = use_orig_params
+    state.training_state = TrainingState.IDLE
+    state._is_root = None
+    state._free_event_queue = _FreeEventQueue()
+    state._debug_level = dist.get_debug_level()
+    state._exec_order_data = exec_order_utils._ExecOrderData(
+        state._debug_level,
+        backward_prefetch_limit,
+        forward_prefetch_limit,
+    )
+    # Mapping from fully sharded module to the handles it is responsible to
+    # unshard and reshard (see [Note: Fully Sharded Module])
+    _fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = dict()
+    state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle
+    # Invariant: `state.params` contains exactly the `FlatParameter`s of the
+    # handles in `state._handle`
+    _handle: FlatParamHandle = None
+    state._handle = _handle
+    params: List[FlatParameter] = []
+    state.params = params
+    return state


def all_gather_dtensor_t2p3p0(
self,
tensor: DTensor,
@@ -883,7 +941,7 @@ def init_fn_t2p3p0(

backward_prefetch_limit = 1
forward_prefetch_limit = 1
-_init_core_state(
+_init_core_state_t2p3p0(
self,
sharding_strategy,
mixed_precision,
@@ -1033,87 +1091,3 @@ def _shard_orig_param_state(
new_optim_state[state_name] = value
torch.cuda.synchronize()
return new_optim_state

-from torch.distributed.utils import _p_assert
-from torch.distributed.fsdp._common_utils import _set_fsdp_flattened, HandleTrainingState
-@torch.enable_grad()
-def _use_unsharded_views(self, as_params: bool) -> None:
-    """Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it."""
-    flat_param = self.flat_param
-    self._check_unsharded(flat_param)
-    views = self._get_unflat_views()
-    from torch.distributed._tensor import DTensor
-
-    for i, (view, (param_name, module, _)) in enumerate(
-        zip(views, flat_param._param_infos)
-    ):
-        if self._use_orig_params and as_params:
-            if type(view) is DTensor:
-                # A `DTensor` `view` is not compatible with assigning
-                # `param.data = view`, so we cannot preserve the parameter
-                # variable.
-                self._setattr_param(
-                    module,
-                    param_name,
-                    nn.Parameter(view, requires_grad=flat_param.requires_grad),
-                )
-                # Since wrapping with `nn.Parameter` constructs a new
-                # object, we re-set that it is FSDP-flattened
-                _set_fsdp_flattened(getattr(module, param_name))
-                continue
-            param = self.flat_param._params[i]
-            self._setattr_param(module, param_name, param)
-            param.data = view
-        elif as_params:
-            self._setattr_param(
-                module,
-                param_name,
-                nn.Parameter(view, requires_grad=flat_param.requires_grad),
-            )
-        else:  # `as_params=False`
-            param_var = view
-            if self._use_orig_params:
-                if self._training_state == HandleTrainingState.FORWARD:
-                    # Save the `Tensor` for the pre-backward
-                    self.flat_param._tensors[i] = view  # save for pre-backward
-                elif self._training_state == HandleTrainingState.BACKWARD_PRE:
-                    # Use the saved `Tensor` variable from the forward to
-                    # preserve the autograd graph so that the post-backward
-                    # hook fires (e.g. for reentrant AC)
-                    tensor = self.flat_param._tensors[i]
-                    tensor.data = view
-                    param_var = tensor
-            self._setattr_tensor(module, param_name, param_var)
-            if (
-                self._use_orig_params
-                and self._training_state == HandleTrainingState.FORWARD
-            ):
-                module._parameters[param_name] = param_var
-    for i, (
-        param_name,
-        module,
-        _,
-        prim_param_name,
-        prim_module,
-        _,
-    ) in enumerate(self.flat_param._shared_param_infos):
-        prim_param = getattr(
-            prim_module, prim_param_name
-        )
-        _p_assert(
-            not as_params or isinstance(prim_param, nn.Parameter),
-            f'as_params={as_params} type(prim_param)={type(prim_param)}',
-        )
-        if self._use_orig_params and as_params:
-            shared_param = self.flat_param._shared_params[i]
-            self._setattr_param(module, param_name, shared_param)
-            shared_param.data = prim_param
-        elif as_params:
-            self._setattr_param(module, param_name, prim_param)
-        else:
-            self._setattr_tensor(module, param_name, prim_param)
-        if (
-            self._use_orig_params
-            and self._training_state == HandleTrainingState.FORWARD
-        ):
-            module._parameters[param_name] = prim_param
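To connect the two files: _init_core_state_t2p3p0 is only reached through init_fn_t2p3p0, the patched FSDP constructor this diff switches over to it. Below is a minimal sketch of how such a constructor patch is typically installed for this torch version range; the helper name _apply_fsdp_init_patch is hypothetical and the exact attribute being patched is an assumption, since the install site is outside this diff.

# Hedged sketch of installing the patched constructor. _apply_fsdp_init_patch is
# a hypothetical helper; in composer the wiring lives in patch_pytorch().
import torch
from packaging import version


def _apply_fsdp_init_patch():
    if version.parse('2.2.9') < version.parse(torch.__version__) < version.parse('2.3.1'):
        from torch.distributed.fsdp import FullyShardedDataParallel

        from composer.trainer.mosaic_fsdp_utils import init_fn_t2p3p0

        # init_fn_t2p3p0 mirrors the upstream constructor but routes core-state
        # setup through _init_core_state_t2p3p0 instead of torch's _init_core_state.
        FullyShardedDataParallel.__init__ = init_fn_t2p3p0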
