fix torch compile with FSDP #1919

Merged · 9 commits · Sep 14, 2023

Changes from all commits
26 changes: 20 additions & 6 deletions src/accelerate/accelerator.py
@@ -96,6 +96,7 @@
wait_for_everyone,
)
from .utils.constants import FSDP_PYTORCH_VERSION
from .utils.other import is_compiled_module


if is_deepspeed_available():
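
The `is_compiled_module` helper imported above lives in `accelerate.utils.other`. As a rough, illustrative sketch (not necessarily the exact implementation), such a check can rely on the fact that `torch.compile` returns an `OptimizedModule` wrapper; the helper name below is hypothetical:

import torch

def looks_compiled(module: torch.nn.Module) -> bool:
    # torch.compile (PyTorch 2.0+) wraps the model in an OptimizedModule
    # that keeps the original module reachable via `_orig_mod`.
    try:
        from torch._dynamo.eval_frame import OptimizedModule
    except ImportError:  # PyTorch < 2.0 has no dynamo
        return False
    return isinstance(module, OptimizedModule)
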
@@ -1215,7 +1216,12 @@ def prepare(self, *args, device_placement=None):
for obj in args:
if isinstance(obj, torch.nn.Module):
model_count += 1
is_type_fsdp = type(obj) == FSDP
# if the model is compiled using PyTorch 2.0,
# check whether the wrapped module is an FSDP instance;
# otherwise, check the model itself
is_type_fsdp = isinstance(obj, FSDP) or (
is_compiled_module(obj) and isinstance(obj._orig_mod, FSDP)
)
if isinstance(obj, torch.optim.Optimizer):
optimizer_present = True
if model_count > 1 and optimizer_present:
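
The plain `type(obj) == FSDP` check breaks once the model has been compiled, because `torch.compile` returns a wrapper object rather than an instance of the original class. A small standalone illustration of the wrapper behaviour the new check relies on (PyTorch 2.0+, no distributed setup required):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)

print(type(compiled).__name__)           # OptimizedModule, not Linear
print(isinstance(compiled, nn.Linear))   # False -> a type check on the wrapper fails
print(compiled._orig_mod is model)       # True  -> unwrapping via `_orig_mod` recovers the model
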
@@ -1371,7 +1377,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
elif device_placement and not self.verify_device_map(model):
model = model.to(self.device)

if self.native_amp and self.distributed_type != DistributedType.FSDP:
if self.native_amp:
model._original_forward = model.forward
model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward
autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
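
This hunk drops the FSDP exclusion, so native mixed-precision autocast is now applied to FSDP models as well. Conceptually, the forward patching amounts to the following sketch; the function name and details are illustrative, not the accelerate implementation:

import functools
import torch

def patch_forward_with_autocast(model: torch.nn.Module, dtype=torch.float16):
    # Keep a handle to the unpatched forward (as `_original_forward` above),
    # then run the real forward under an autocast context manager.
    model._original_forward = model.forward
    original_forward = model.forward

    @functools.wraps(original_forward)
    def forward_with_autocast(*args, **kwargs):
        with torch.autocast(device_type="cuda", dtype=dtype):
            return original_forward(*args, **kwargs)

    model.forward = forward_with_autocast
    return model
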
@@ -1423,7 +1429,13 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e

# Check if the model is already an FSDP model due to `Manual Wrapping` and if so,
# don't wrap it again
if type(model) != FSDP:
# In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
# is an FSDP model, don't wrap it again
is_type_fsdp = isinstance(model, FSDP) or (
is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
)

if not is_type_fsdp:
self.state.fsdp_plugin.set_auto_wrap_policy(model)
fsdp_plugin = self.state.fsdp_plugin
kwargs = {
@@ -1456,15 +1468,17 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
),
auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
)

# if the previous and current models are the same, delete the previous one
if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
del self._models[-2]
self._models[-1] = model
elif self.distributed_type == DistributedType.MULTI_CPU:
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
elif self.distributed_type == DistributedType.TPU and self.state.fork_launched:
model = xmp.MpModelWrapper(model).to(self.device)
# torch.compile should be called last.
if self.state.dynamo_plugin.backend != DynamoBackend.NO:
# torch.compile should be called last and only if the model isn't already compiled.
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
if not is_torch_version(">=", "2.0"):
raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
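
Taken together, these changes are meant to make a flow like the following work. This is a hypothetical usage sketch, not code from the PR: it assumes a distributed launch (e.g. via `accelerate launch`), an FSDP-enabled accelerate config, and a placeholder model defined inline.

import torch
import torch.nn as nn
from accelerate import Accelerator
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class TinyModel(nn.Module):  # placeholder model for illustration
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)

accelerator = Accelerator()       # FSDP plugin picked up from the accelerate config
model = FSDP(TinyModel())         # manual FSDP wrapping (needs the initialized process group)
model = torch.compile(model)      # compile on top of the FSDP wrapper

# With this PR, prepare() detects the FSDP module behind the compiled wrapper,
# so it neither wraps the model in FSDP a second time nor calls torch.compile
# again, and mixed-precision autocast is applied to the FSDP model as well.
model = accelerator.prepare(model)
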