Skip to content

Commit

Permalink
fix torch compile with FSDP (#1919)
Browse files Browse the repository at this point in the history
* fix torch compile with FSDP

* Update accelerator.py

* fixes

* resolve comments

* fix bug

* address comments

* addressing comments

* address comments
  • Loading branch information
pacman100 authored Sep 14, 2023
1 parent 40a73e0 commit e5452a6
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
wait_for_everyone,
)
from .utils.constants import FSDP_PYTORCH_VERSION
from .utils.other import is_compiled_module


if is_deepspeed_available():
Expand Down Expand Up @@ -1221,7 +1222,12 @@ def prepare(self, *args, device_placement=None):
for obj in args:
if isinstance(obj, torch.nn.Module):
model_count += 1
is_type_fsdp = type(obj) == FSDP
# if the model is compiled using PyTorch 2.0,
# check that the wrapped model is FSDP or not;
# else check if it is FSDP or not;
is_type_fsdp = isinstance(obj, FSDP) or (
is_compiled_module(obj) and isinstance(obj._orig_mod, FSDP)
)
if isinstance(obj, torch.optim.Optimizer):
optimizer_present = True
if model_count > 1 and optimizer_present:
Expand Down Expand Up @@ -1377,7 +1383,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
elif device_placement and not self.verify_device_map(model):
model = model.to(self.device)

if self.native_amp and self.distributed_type != DistributedType.FSDP:
if self.native_amp:
model._original_forward = model.forward
model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward
autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
Expand Down Expand Up @@ -1429,7 +1435,13 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e

# Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
# don't wrap it again
if type(model) != FSDP:
# In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
# is a FSDP model, don't wrap it again
is_type_fsdp = isinstance(model, FSDP) or (
is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
)

if not is_type_fsdp:
self.state.fsdp_plugin.set_auto_wrap_policy(model)
fsdp_plugin = self.state.fsdp_plugin
kwargs = {
Expand Down Expand Up @@ -1462,15 +1474,17 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
),
auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
)

# if the previous and current models are same, delete the previous one
if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
del self._models[-2]
self._models[-1] = model
elif self.distributed_type == DistributedType.MULTI_CPU:
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
elif self.distributed_type == DistributedType.TPU and self.state.fork_launched:
model = xmp.MpModelWrapper(model).to(self.device)
# torch.compile should be called last.
if self.state.dynamo_plugin.backend != DynamoBackend.NO:
# torch.compile should be called last and only if the model isn't already compiled.
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
if not is_torch_version(">=", "2.0"):
raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
Expand Down

0 comments on commit e5452a6

Please sign in to comment.