Fix FSDP for Models with Frozen Weights (#5484) (#5539)
* Fix fsdp not freeing frozen full params

* add test

* formatting

* remove unnecessary env var in test

Co-authored-by: Liyang90 <[email protected]>
wonjoolee95 and Liyang90 authored Sep 6, 2023
1 parent a1d3651 commit 89574d2
Showing 2 changed files with 39 additions and 1 deletion.
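
For context, here is a minimal sketch (not part of this commit) of the scenario the fix targets: a frozen backbone with a trainable head, wrapped with XLA FSDP. The backbone/head modules and their sizes are illustrative, a TPU or GPU XLA runtime is assumed, and the pattern mirrors the new test below.

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

backbone = nn.Linear(1024, 1024)
for p in backbone.parameters():
  p.requires_grad = False  # freeze every backbone parameter

head = nn.Linear(1024, 16)  # trainable head
model = FSDP(nn.Sequential(backbone, head))  # wrap with XLA FSDP, as in the test below

x = torch.rand((2, 1024), device=xm.xla_device())
loss = model(x).sum()
loss.backward()

# With this fix, the frozen full params are released along with the trainable ones.
assert not any(p._has_full_param for p in model.full_params)

Before this change, the final assertion would fail: the frozen backbone's full (unsharded) parameters were never freed after the backward pass.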
32 changes: 32 additions & 0 deletions test/test_torch_distributed_fsdp_frozen_weight.py
@@ -0,0 +1,32 @@
import sys
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP


def _mp_fn(index):
  dev = xm.xla_device()
  if xm.xla_device_hw(dev) not in ('TPU', 'GPU'):
    print(
        'Default device {} is not a TPU or GPU device'.format(dev),
        file=sys.stderr)
    return

  model = nn.Linear(1024, 1024)
  model.weight.requires_grad = False  # the weight param is frozen

  model = FSDP(model)  # wrapping the linear module with FSDP

  input = torch.rand((2, 1024), device=xm.xla_device())

  output = model(input)
  loss = torch.sum(output)
  loss.backward()
  assert not any(p._has_full_param for p in model.full_params), \
      'Expecting all the full params to be freed at this moment.'


if __name__ == "__main__":
  xmp.spawn(_mp_fn, args=())
8 changes: 7 additions & 1 deletion torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -1299,13 +1299,19 @@ def _wait_for_post_backward(self) -> None:
     # A backward pass is done, clean up below.
     def _finalize_parameters(fsdp_module: XlaFullyShardedDataParallel) -> None:
       """Helper used below on all fsdp modules."""
+      frozen_params = []
       for p in fsdp_module.full_params:
         if not p.requires_grad:
-          continue
+          frozen_params.append(p)
         if hasattr(p, "_shard_bwd_hook"):
           assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
           p._shard_bwd_hook[1].remove()
           delattr(p, "_shard_bwd_hook")
+      # Free the full params with `requires_grad==False`
+      if frozen_params:
+        fsdp_module._free_full_params(
+            frozen_params,
+            apply_opt_barrier=self.optimization_barrier_in_backward)
 
     # Update root and nested FSDP's hooks and flags.
     for m in self.modules():  # includes self
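Frozen parameters produce no gradients, so the post-backward logic that normally frees a full parameter after its gradient is computed never runs for them; before this change their full (unsharded) copies therefore stayed alive after backward. Below is a minimal sketch (not from the commit) of the bookkeeping the added code performs over fsdp_module.full_params, shown on a plain module:

import torch.nn as nn

# Illustrative only: partition parameters by requires_grad, as the patch does.
module = nn.Linear(8, 8)
module.weight.requires_grad = False  # frozen, as in the test above

frozen = [p for p in module.parameters() if not p.requires_grad]
trainable = [p for p in module.parameters() if p.requires_grad]
assert len(frozen) == 1 and len(trainable) == 1  # weight frozen, bias trainable

In the patch, the frozen group is then handed to _free_full_params(...) once the backward pass has finished, using the same apply_opt_barrier setting as the rest of the backward pass.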
