From 10eb855f77f4e262dfb618b1a22f25a3efe3d633 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Fri, 1 Sep 2023 13:08:02 +0530
Subject: [PATCH 1/8] fix torch compile with FSDP

---
 src/accelerate/accelerator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 236d0189360..56fb4ebdee2 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1215,7 +1215,7 @@ def prepare(self, *args, device_placement=None):
             for obj in args:
                 if isinstance(obj, torch.nn.Module):
                     model_count += 1
-                    is_type_fsdp = type(obj) == FSDP
+                    is_type_fsdp = (type(obj) == FSDP) or isinstance(getattr(obj, "_orig_mod", None), FSDP)
                 if isinstance(obj, torch.optim.Optimizer):
                     optimizer_present = True
             if model_count > 1 and optimizer_present:

From 379f5b106e55e6f6bf8866d65e162e6b5a5e25b0 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Fri, 1 Sep 2023 13:29:41 +0530
Subject: [PATCH 2/8] Update accelerator.py

---
 src/accelerate/accelerator.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 56fb4ebdee2..c1abfaae75a 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1213,9 +1213,10 @@ def prepare(self, *args, device_placement=None):
             optimizer_present = False
             is_type_fsdp = False
             for obj in args:
-                if isinstance(obj, torch.nn.Module):
+                is_torch_compiled = getattr(obj, "_orig_mod", None)
+                if isinstance(obj, torch.nn.Module) or is_torch_compiled is not None:
                     model_count += 1
-                    is_type_fsdp = (type(obj) == FSDP) or isinstance(getattr(obj, "_orig_mod", None), FSDP)
+                    is_type_fsdp = (type(obj) == FSDP) or isinstance(is_torch_compiled, FSDP)
                 if isinstance(obj, torch.optim.Optimizer):
                     optimizer_present = True
             if model_count > 1 and optimizer_present:
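The `_orig_mod` checks introduced in patches 1/8 and 2/8 rely on `torch.compile` returning a wrapper rather than modifying the module in place: in PyTorch 2.x, `torch.compile(module)` returns a `torch._dynamo.eval_frame.OptimizedModule` that keeps the original module on its `_orig_mod` attribute, so an FSDP instance that was compiled after wrapping is only reachable through that attribute. A minimal sketch of this behaviour, not part of the patches; a plain `nn.Linear` stands in for an FSDP-wrapped module, which would additionally need an initialized process group:

import torch

model = torch.nn.Linear(4, 4)
compiled = torch.compile(model)

print(type(compiled))               # torch._dynamo.eval_frame.OptimizedModule, not nn.Linear
print(compiled._orig_mod is model)  # True: the original module is kept on `_orig_mod`
print(isinstance(getattr(compiled, "_orig_mod", None), torch.nn.Linear))  # True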
From cca37712ccab34858b670451edca14aae255f802 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Fri, 1 Sep 2023 15:12:16 +0530
Subject: [PATCH 3/8] fixes

---
 src/accelerate/accelerator.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index c1abfaae75a..60a5939dd90 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1213,10 +1213,9 @@ def prepare(self, *args, device_placement=None):
             optimizer_present = False
             is_type_fsdp = False
             for obj in args:
-                is_torch_compiled = getattr(obj, "_orig_mod", None)
-                if isinstance(obj, torch.nn.Module) or is_torch_compiled is not None:
+                if isinstance(obj, torch.nn.Module):
                     model_count += 1
-                    is_type_fsdp = (type(obj) == FSDP) or isinstance(is_torch_compiled, FSDP)
+                    is_type_fsdp = (type(obj) == FSDP) or isinstance(getattr(obj, "_orig_mod", None), FSDP)
                 if isinstance(obj, torch.optim.Optimizer):
                     optimizer_present = True
             if model_count > 1 and optimizer_present:
@@ -1372,7 +1371,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         elif device_placement and not self.verify_device_map(model):
             model = model.to(self.device)
 
-        if self.native_amp and self.distributed_type != DistributedType.FSDP:
+        if self.native_amp:
             model._original_forward = model.forward
             model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward
             autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
@@ -1424,7 +1423,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
 
             # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
             # don't wrap it again
-            if type(model) != FSDP:
+            is_torch_compiled = getattr(model, "_orig_mod", None)
+            is_type_fsdp = (type(model) == FSDP) or isinstance(is_torch_compiled, FSDP)
+
+            if not is_type_fsdp:
                 self.state.fsdp_plugin.set_auto_wrap_policy(model)
                 fsdp_plugin = self.state.fsdp_plugin
                 kwargs = {
@@ -1457,15 +1459,17 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                        ),
                        auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
                    )
-
+            # if the previous and current models are same, delete the previous one
+            if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
+                del self._models[-2]
             self._models[-1] = model
         elif self.distributed_type == DistributedType.MULTI_CPU:
             kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
             model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
         elif self.distributed_type == DistributedType.TPU and self.state.fork_launched:
             model = xmp.MpModelWrapper(model).to(self.device)
-        # torch.compile should be called last.
-        if self.state.dynamo_plugin.backend != DynamoBackend.NO:
+        # torch.compile should be called last and only if the model isn't already compiled.
+        if self.state.dynamo_plugin.backend != DynamoBackend.NO and getattr(model, "_orig_mod", None) is None:
             if not is_torch_version(">=", "2.0"):
                 raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
             model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())

From bc320ce8ae30285d860bd733a79b6218e5c4dd1b Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Wed, 6 Sep 2023 11:31:58 +0530
Subject: [PATCH 4/8] resolve comments

---
 src/accelerate/accelerator.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 3402ba8d3f9..5ca7247d047 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -96,6 +96,7 @@
     wait_for_everyone,
 )
 from .utils.constants import FSDP_PYTORCH_VERSION
+from .utils.other import is_compiled_module
 
 
 if is_deepspeed_available():
@@ -1215,7 +1216,9 @@ def prepare(self, *args, device_placement=None):
             for obj in args:
                 if isinstance(obj, torch.nn.Module):
                     model_count += 1
-                    is_type_fsdp = (type(obj) == FSDP) or isinstance(getattr(obj, "_orig_mod", None), FSDP)
+                    if is_compiled_module(obj):
+                        unwrapped_model = obj._orig_mod
+                    is_type_fsdp = (type(obj) == FSDP) or isinstance(unwrapped_model, FSDP)
                 if isinstance(obj, torch.optim.Optimizer):
                     optimizer_present = True
             if model_count > 1 and optimizer_present:
@@ -1423,8 +1426,9 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
 
             # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
             # don't wrap it again
-            is_torch_compiled = getattr(model, "_orig_mod", None)
-            is_type_fsdp = (type(model) == FSDP) or isinstance(is_torch_compiled, FSDP)
+            if is_compiled_module(model):
+                unwrapped_model = model._orig_mod
+            is_type_fsdp = (type(model) == FSDP) or isinstance(unwrapped_model, FSDP)
 
             if not is_type_fsdp:
                 self.state.fsdp_plugin.set_auto_wrap_policy(model)
@@ -1469,7 +1473,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         elif self.distributed_type == DistributedType.TPU and self.state.fork_launched:
             model = xmp.MpModelWrapper(model).to(self.device)
         # torch.compile should be called last and only if the model isn't already compiled.
-        if self.state.dynamo_plugin.backend != DynamoBackend.NO and getattr(model, "_orig_mod", None) is None:
+        if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
             if not is_torch_version(">=", "2.0"):
                 raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
             model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
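Patch 4/8 replaces the raw `getattr(obj, "_orig_mod", None)` probing with the `is_compiled_module` helper imported from `.utils.other`; the helper's implementation is not part of this diff. A hedged sketch of how such a check can be written (the `_sketch` suffix marks it as an illustration; the real helper may differ, for example in how it guards older PyTorch versions):

import torch

def is_compiled_module_sketch(module: torch.nn.Module) -> bool:
    # Modules returned by torch.compile are OptimizedModule instances (PyTorch >= 2.0);
    # older builds have no torch._dynamo namespace, so treat them as "not compiled".
    if not hasattr(torch, "_dynamo"):
        return False
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)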
From efe9c111babe1cd2c0db30c4fb8bd8fa21e6548f Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Thu, 7 Sep 2023 12:07:36 +0530
Subject: [PATCH 5/8] fix bug

---
 src/accelerate/accelerator.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 5ca7247d047..d01e0042fb1 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1216,8 +1216,7 @@ def prepare(self, *args, device_placement=None):
             for obj in args:
                 if isinstance(obj, torch.nn.Module):
                     model_count += 1
-                    if is_compiled_module(obj):
-                        unwrapped_model = obj._orig_mod
+                    unwrapped_model = obj._orig_mod if is_compiled_module(obj) else obj
                     is_type_fsdp = (type(obj) == FSDP) or isinstance(unwrapped_model, FSDP)
                 if isinstance(obj, torch.optim.Optimizer):
                     optimizer_present = True
@@ -1426,8 +1425,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
 
             # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
             # don't wrap it again
-            if is_compiled_module(model):
-                unwrapped_model = model._orig_mod
+            unwrapped_model = model._orig_mod if is_compiled_module(model) else model
             is_type_fsdp = (type(model) == FSDP) or isinstance(unwrapped_model, FSDP)
 
             if not is_type_fsdp:

From e46056e445254bed3fa13ce8aaa18d0303faef8f Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Fri, 8 Sep 2023 22:48:18 +0530
Subject: [PATCH 6/8] address comments

---
 src/accelerate/accelerator.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index d01e0042fb1..8005a02efab 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1216,6 +1216,9 @@ def prepare(self, *args, device_placement=None):
             for obj in args:
                 if isinstance(obj, torch.nn.Module):
                     model_count += 1
+                    # if the model is compiled using PyTorch 2.0,
+                    # check that the wrapped model is FSDP or not;
+                    # else check if it is FSDP or not;
                     unwrapped_model = obj._orig_mod if is_compiled_module(obj) else obj
                     is_type_fsdp = (type(obj) == FSDP) or isinstance(unwrapped_model, FSDP)
                 if isinstance(obj, torch.optim.Optimizer):
@@ -1425,6 +1428,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
 
             # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
             # don't wrap it again
+            # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
+            # is a FSDP model, don't wrap it again
             unwrapped_model = model._orig_mod if is_compiled_module(model) else model
             is_type_fsdp = (type(model) == FSDP) or isinstance(unwrapped_model, FSDP)
 
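The commit message of patch 5/8 only says "fix bug"; the bug is presumably that in the patch 4/8 version `unwrapped_model` was assigned only when the module was compiled, so an ordinary (uncompiled, non-FSDP) module would reach the subsequent `isinstance(unwrapped_model, FSDP)` check with an unbound local, or with a stale value from an earlier iteration of the loop in `prepare`. A standalone illustration of the two patterns (generic names, `dict` standing in for FSDP; not accelerate code):

def check_before_fix(obj):
    # patch 4/8 pattern: the name is bound only on the "compiled" branch
    if hasattr(obj, "_orig_mod"):
        unwrapped = obj._orig_mod
    return isinstance(unwrapped, dict)  # UnboundLocalError for a plain object

def check_after_fix(obj):
    # patch 5/8 pattern: the name is always bound
    unwrapped = obj._orig_mod if hasattr(obj, "_orig_mod") else obj
    return isinstance(unwrapped, dict)  # well-defined for any input

print(check_after_fix(object()))  # False
# check_before_fix(object())      # raises UnboundLocalError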
From 5ec951b516a095f1a1dafc8e6e88e32bc95ca3d0 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Thu, 14 Sep 2023 11:19:07 +0530
Subject: [PATCH 7/8] addressing comments

---
 src/accelerate/accelerator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 8005a02efab..f9f9b9fa2ab 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1220,7 +1220,7 @@ def prepare(self, *args, device_placement=None):
                     # check that the wrapped model is FSDP or not;
                     # else check if it is FSDP or not;
                     unwrapped_model = obj._orig_mod if is_compiled_module(obj) else obj
-                    is_type_fsdp = (type(obj) == FSDP) or isinstance(unwrapped_model, FSDP)
+                    is_type_fsdp = isinstance(obj, FSDP) or isinstance(unwrapped_model, FSDP)
                 if isinstance(obj, torch.optim.Optimizer):
                     optimizer_present = True
             if model_count > 1 and optimizer_present:
@@ -1431,7 +1431,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
             # is a FSDP model, don't wrap it again
             unwrapped_model = model._orig_mod if is_compiled_module(model) else model
-            is_type_fsdp = (type(model) == FSDP) or isinstance(unwrapped_model, FSDP)
+            is_type_fsdp = isinstance(model, FSDP) or isinstance(unwrapped_model, FSDP)
 
             if not is_type_fsdp:
                 self.state.fsdp_plugin.set_auto_wrap_policy(model)

From 2041fa7cea30e4c67431eb7950fb83300524b8af Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Thu, 14 Sep 2023 12:57:58 +0530
Subject: [PATCH 8/8] address comments

---
 src/accelerate/accelerator.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index f9f9b9fa2ab..58e294c7efc 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1219,8 +1219,9 @@ def prepare(self, *args, device_placement=None):
                     # if the model is compiled using PyTorch 2.0,
                     # check that the wrapped model is FSDP or not;
                     # else check if it is FSDP or not;
-                    unwrapped_model = obj._orig_mod if is_compiled_module(obj) else obj
-                    is_type_fsdp = isinstance(obj, FSDP) or isinstance(unwrapped_model, FSDP)
+                    is_type_fsdp = isinstance(obj, FSDP) or (
+                        is_compiled_module(obj) and isinstance(obj._orig_mod, FSDP)
+                    )
                 if isinstance(obj, torch.optim.Optimizer):
                     optimizer_present = True
             if model_count > 1 and optimizer_present:
@@ -1430,8 +1431,9 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             # don't wrap it again
             # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
             # is a FSDP model, don't wrap it again
-            unwrapped_model = model._orig_mod if is_compiled_module(model) else model
-            is_type_fsdp = isinstance(model, FSDP) or isinstance(unwrapped_model, FSDP)
+            is_type_fsdp = isinstance(model, FSDP) or (
+                is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
+            )
 
             if not is_type_fsdp:
                 self.state.fsdp_plugin.set_auto_wrap_policy(model)
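Taken together, the series makes `Accelerator.prepare` treat a `torch.compile` wrapper around an FSDP model as FSDP: the model is not wrapped a second time and, when a dynamo backend is configured, it is not compiled again. A minimal usage sketch of that scenario (placeholder model; assumes the script runs under `accelerate launch` with an FSDP-enabled config so the distributed process group is already initialized):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(16, 16)   # placeholder for a real model
model = FSDP(model)               # manual FSDP wrapping
model = torch.compile(model)      # compile the FSDP-wrapped model

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# With these patches, prepare() detects the FSDP instance behind `_orig_mod`,
# skips re-wrapping it in FSDP, and does not call torch.compile on it again.
model, optimizer = accelerator.prepare(model, optimizer)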