
fix torch compile with FSDP #1919

Merged: 9 commits, Sep 14, 2023
Changes from 5 commits
src/accelerate/accelerator.py (21 changes: 15 additions & 6 deletions)
@@ -96,6 +96,7 @@
     wait_for_everyone,
 )
 from .utils.constants import FSDP_PYTORCH_VERSION
+from .utils.other import is_compiled_module


 if is_deepspeed_available():
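The new import pulls in a helper that detects models wrapped by torch.compile(). A minimal sketch of what such a check can look like (the name is_compiled_module_sketch and the version guard are illustrative assumptions, not the library's code): in PyTorch 2.x, torch.compile() returns a torch._dynamo OptimizedModule that keeps the original model reachable via _orig_mod.

import torch


def is_compiled_module_sketch(module: torch.nn.Module) -> bool:
    # torch.compile() wraps an nn.Module in torch._dynamo's OptimizedModule,
    # which exposes the original model via `_orig_mod`.
    if not hasattr(torch, "_dynamo"):
        return False  # pre-2.0 PyTorch: nothing can be compiled
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)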
@@ -1215,7 +1216,9 @@ def prepare(self, *args, device_placement=None):
             for obj in args:
                 if isinstance(obj, torch.nn.Module):
                     model_count += 1
-                    is_type_fsdp = type(obj) == FSDP
+                    if is_compiled_module(obj):
+                        unwrapped_model = obj._orig_mod
+                    is_type_fsdp = (type(obj) == FSDP) or isinstance(unwrapped_model, FSDP)
                 if isinstance(obj, torch.optim.Optimizer):
                     optimizer_present = True
             if model_count > 1 and optimizer_present:
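For context on why the plain type(obj) == FSDP check is no longer enough, here is a hedged illustration (it assumes a distributed process group is already initialized, e.g. via torchrun, which FSDP requires):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# assumes torch.distributed is already initialized
fsdp_model = FSDP(torch.nn.Linear(8, 8).cuda())
compiled = torch.compile(fsdp_model)

print(type(compiled) == FSDP)                # False: the outer object is a dynamo OptimizedModule
print(isinstance(compiled._orig_mod, FSDP))  # True: the wrapped model is still FSDP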
@@ -1371,7 +1374,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         elif device_placement and not self.verify_device_map(model):
             model = model.to(self.device)

-        if self.native_amp and self.distributed_type != DistributedType.FSDP:
+        if self.native_amp:
             model._original_forward = model.forward
             model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward
             autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
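Dropping the FSDP exclusion means the mixed-precision forward patching above now also applies to FSDP-wrapped models. A rough sketch of the general idea (not Accelerate's exact code; patch_forward_with_autocast is a made-up name):

import functools
import torch


def patch_forward_with_autocast(model: torch.nn.Module, dtype=torch.bfloat16):
    # Keep a handle to the original forward, then route calls through autocast
    # so the model runs its forward pass in mixed precision.
    model._original_forward = model.forward

    @functools.wraps(model._original_forward)
    def forward(*args, **kwargs):
        with torch.autocast(device_type="cuda", dtype=dtype):
            return model._original_forward(*args, **kwargs)

    model.forward = forward
    return model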
@@ -1423,7 +1426,11 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e

             # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
             # don't wrap it again
-            if type(model) != FSDP:
+            if is_compiled_module(model):
+                unwrapped_model = model._orig_mod
+            is_type_fsdp = (type(model) == FSDP) or isinstance(unwrapped_model, FSDP)
+
+            if not is_type_fsdp:
                 self.state.fsdp_plugin.set_auto_wrap_policy(model)
                 fsdp_plugin = self.state.fsdp_plugin
                 kwargs = {
@@ -1456,15 +1463,17 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                         ),
                         auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
                     )
-
+                # if the previous and current models are same, delete the previous one
+                if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
+                    del self._models[-2]
             self._models[-1] = model
         elif self.distributed_type == DistributedType.MULTI_CPU:
             kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
             model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
         elif self.distributed_type == DistributedType.TPU and self.state.fork_launched:
             model = xmp.MpModelWrapper(model).to(self.device)
-        # torch.compile should be called last.
-        if self.state.dynamo_plugin.backend != DynamoBackend.NO:
+        # torch.compile should be called last and only if the model isn't already compiled.
+        if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
             if not is_torch_version(">=", "2.0"):
                 raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
             model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
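Taken together, these changes let a model pass through Accelerator.prepare() when FSDP and torch.compile are combined, whichever wrapper ends up on the outside. A hedged end-to-end sketch of that scenario (the configuration details are assumptions; FSDP and the dynamo backend are normally selected via accelerate config rather than in code):

import torch
from accelerate import Accelerator

# Launched under an accelerate config that enables both FSDP and a dynamo backend.
accelerator = Accelerator()

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# prepare() now recognizes compiled and/or FSDP-wrapped models, and only calls
# torch.compile at the very end if the model is not already compiled.
model, optimizer = accelerator.prepare(model, optimizer)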