diff --git a/src/accelerate/utils/fsdp_utils.py b/src/accelerate/utils/fsdp_utils.py index fb8dc5f2cab..c92dfafd0d0 100644 --- a/src/accelerate/utils/fsdp_utils.py +++ b/src/accelerate/utils/fsdp_utils.py @@ -328,7 +328,7 @@ def merge_fsdp_weights( def ensure_weights_retied(param_init_fn, model: torch.nn.Module, device: torch.cuda.device): - _tied_names = model._tied_weights_keys + _tied_names = getattr(model, "_tied_weights_keys", None) if not _tied_names: # if no tied names just passthrough return param_init_fn