
auto_wrap_policy for PEFT with FSDP #2253

Closed
fs4r wants to merge 4 commits

Conversation


@fs4r fs4r commented Dec 14, 2023

What does this PR do?

While fine-tuning a PEFT model with FSDP, I noticed excessive memory usage caused by the parameter use_orig_params, which leads to additional memory usage for frozen parameters. When trying to disable use_orig_params, FSDP threw the following error:

ValueError: Must flatten tensors with uniform requires_grad when use_orig_params=False

The reason for this is documented in the torch library:

> FSDP has some constraints on freezing parameters (i.e. setting param.requires_grad=False). For use_orig_params=False, each FSDP instance must manage parameters that are all frozen or all non-frozen. For use_orig_params=True, FSDP supports mixing frozen and non-frozen, but we recommend not doing so since then the gradient memory usage will be higher than expected (namely, equivalent to not freezing those parameters). This means that ideally, frozen parameters should be isolated into their own nn.Modules and wrapped separately with FSDP.

I implemented a check in the auto_wrap_policy function of the FullyShardedDataParallelPlugin and wrapped the unfrozen parameters individually if necessary.
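
For illustration, here is a minimal sketch of the idea using torch's lambda_auto_wrap_policy (this is not the exact diff of the PR; the helper name _wrap_trainable_leaf is made up for this example): each leaf module whose own parameters are all trainable gets its own FSDP unit, so every unit holds uniformly frozen or uniformly trainable parameters, as use_orig_params=False requires.

```python
import functools

from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


def _wrap_trainable_leaf(module) -> bool:
    # Wrap leaf modules (no children) whose own parameters all require
    # gradients, e.g. LoRA adapter layers in a PEFT model.
    if len(list(module.children())) > 0:
        return False
    own_params = list(module.parameters(recurse=False))
    return len(own_params) > 0 and all(p.requires_grad for p in own_params)


# Can be set as FullyShardedDataParallelPlugin(auto_wrap_policy=...), typically
# combined with a transformer_auto_wrap_policy so the frozen backbone is still
# sharded in reasonably sized units.
peft_aware_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=_wrap_trainable_leaf)
```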

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?
    I wrote an additional test called test_auto_wrap_policy_peft, which tests whether FSDP fails when using PEFT with use_orig_params=False.

Who can review?

@pacman100
@muellerzr

@fs4r (Author) commented Dec 14, 2023

I also added a test called test_auto_wrap_policy_peft in the FSDP tests.

@muellerzr requested a review from pacman100 on December 14, 2023.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) left a comment

Thanks a lot for improving the interop of accelerate with PEFT in an FSDP setting. I have a few small comments, please check.

Regarding the change in general, I have a question. I don't have experience with FSDP, so maybe it's obvious, but what is the special thing about PEFT that requires this extra policy? Is it just that PEFT results in a mixture of params with and without requires_grad or is there something else that's special for PEFT?

Three inline review comments on src/accelerate/utils/dataclasses.py (outdated, resolved).
@fs4r (Author) commented Dec 14, 2023

> Thanks a lot for improving the interop of accelerate with PEFT in an FSDP setting. I have a few small comments, please check.
>
> Regarding the change in general, I have a question. I don't have experience with FSDP, so maybe it's obvious, but what is the special thing about PEFT that requires this extra policy? Is it just that PEFT results in a mixture of params with and without requires_grad or is there something else that's special for PEFT?

Yes, the reason is the mixture of params with and without requires_grad. And PEFT does not make sense with the use_orig_params=True configuration; that's why this is necessary.

@BenjaminBossan (Member) commented

> Yes, the reason is the mixture of params with and without requires_grad. And PEFT does not make sense with the use_orig_params=True configuration; that's why this is necessary.

Thanks for explaining. I wonder if this could be detached from PEFT, WDYT? This would also allow avoiding the PeftModel import, which otherwise needs to be guarded, as PEFT is not a strict dependency of accelerate.

@fs4r (Author) commented Dec 14, 2023

> Yes, the reason is the mixture of params with and without requires_grad. And PEFT does not make sense with the use_orig_params=True configuration; that's why this is necessary.
>
> Thanks for explaining. I wonder if this could be detached from PEFT, WDYT? This would also allow avoiding the PeftModel import, which otherwise needs to be guarded, as PEFT is not a strict dependency of accelerate.

What we could do is check whether there are parameters with requires_grad=False in the model and, if so, run the wrapping. I don't know whether there are cases where you would not want the extra wrapping (when using use_orig_params=True). @pacman100, what do you think?
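
A rough sketch of that check (the helper name here is hypothetical, not code from this PR):

```python
import torch.nn as nn


def mixes_frozen_and_trainable(model: nn.Module) -> bool:
    # True when the model contains both frozen and trainable parameters,
    # which is the case a single FSDP unit with use_orig_params=False
    # cannot handle.
    flags = {p.requires_grad for p in model.parameters()}
    return len(flags) > 1
```

The plugin could then apply the extra wrapping only when this returns True, without importing anything from PEFT.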

@pacman100 (Contributor) left a comment

Hello, use_orig_params=True is required if one wants to prepare the model and optimizer together in a single accelerator.prepare call; without it, one needs to prepare/wrap the model in FSDP before creating the optimizer. See PR #2177 for more information.

With respect to auto wrap policy for PEFT when using FSDP, it is already present in PEFT here: https://github.com/huggingface/peft/blob/482a2a6d9aaa01d534b1240e8c1ab6d346eb278f/src/peft/utils/other.py#L354 and used in the example here: https://github.com/huggingface/peft/blob/main/examples/conditional_generation/peft_lora_seq2seq_accelerate_fsdp.py

Can you please provide a minimal reproducer for us to deep dive into the memory increase?
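
For context, a minimal sketch of how that existing PEFT helper is used in the linked example (the base checkpoint and LoRA config below are placeholders, not part of this PR):

```python
import torch
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model
from peft.utils.other import fsdp_auto_wrap_policy
from transformers import AutoModelForSeq2SeqLM

accelerator = Accelerator()

base_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model = get_peft_model(base_model, LoraConfig(task_type="SEQ_2_SEQ_LM"))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# When launched with an FSDP config, reuse PEFT's ready-made auto wrap policy
# so that the trainable adapter layers end up in their own FSDP units.
if getattr(accelerator.state, "fsdp_plugin", None) is not None:
    accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)

model, optimizer = accelerator.prepare(model, optimizer)
```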


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot closed this on Jan 22, 2024.