From 89574d2d5c4cb11a84f1f649319ea45cb2203123 Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Wed, 6 Sep 2023 10:00:40 -0700
Subject: [PATCH] Fix FSDP for Models with Frozen Weights (#5484) (#5539)

* Fix fsdp not freeing frozen full params

* add test

* formatting

* remove unnecessary env var in test

Co-authored-by: Liyang90
---
 ...st_torch_distributed_fsdp_frozen_weight.py | 32 +++++++++++++++++++
 .../fsdp/xla_fully_sharded_data_parallel.py   |  8 ++++-
 2 files changed, 39 insertions(+), 1 deletion(-)
 create mode 100644 test/test_torch_distributed_fsdp_frozen_weight.py

diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py
new file mode 100644
index 00000000000..79b65a46999
--- /dev/null
+++ b/test/test_torch_distributed_fsdp_frozen_weight.py
@@ -0,0 +1,32 @@
+import sys
+import torch
+import torch.nn as nn
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.xla_multiprocessing as xmp
+from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
+
+
+def _mp_fn(index):
+  dev = xm.xla_device()
+  if xm.xla_device_hw(dev) not in ('TPU', 'GPU'):
+    print(
+        'Default device {} is not a TPU or GPU device'.format(dev),
+        file=sys.stderr)
+    return
+
+  model = nn.Linear(1024, 1024)
+  model.weight.requires_grad = False  # the weight param is frozen
+
+  model = FSDP(model)  # wrapping the linear module with FSDP
+
+  input = torch.rand((2, 1024), device=xm.xla_device())
+
+  output = model(input)
+  loss = torch.sum(output)
+  loss.backward()
+  assert not any(p._has_full_param for p in model.full_params), \
+    'Expecting all the full params to be freed at this moment.'
+
+
+if __name__ == "__main__":
+  xmp.spawn(_mp_fn, args=())
diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
index e51a31e0b4d..f1b62d1700b 100644
--- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
+++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -1299,13 +1299,19 @@ def _wait_for_post_backward(self) -> None:
     # A backward pass is done, clean up below.
     def _finalize_parameters(fsdp_module: XlaFullyShardedDataParallel) -> None:
       """Helper used below on all fsdp modules."""
+      frozen_params = []
       for p in fsdp_module.full_params:
         if not p.requires_grad:
-          continue
+          frozen_params.append(p)
         if hasattr(p, "_shard_bwd_hook"):
           assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
           p._shard_bwd_hook[1].remove()
           delattr(p, "_shard_bwd_hook")
+      # Free the full params with `requires_grad==False`
+      if frozen_params:
+        fsdp_module._free_full_params(
+            frozen_params,
+            apply_opt_barrier=self.optimization_barrier_in_backward)
 
     # Update root and nested FSDP's hooks and flags.
     for m in self.modules():  # includes self