Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update to prepare_model_for_kbit_training #728

Merged
merged 3 commits into from
Sep 12, 2023

Conversation

mnoukhov
Copy link
Contributor

@mnoukhov mnoukhov commented Sep 2, 2023

since peft has deprecated prepare_model_for_int8_training

also add use_gradient_checkpointing=args.gradient_checkpointing to automatically follow the gradient checkpointing choice in training args

For RewardTrainer, this is the workaround to #480 proposed by #694.

Concurrently @lewtun is working on #726 which adds the use_gradient_checkpointing for RewardTrainer. I'm happy to wait until it is merged to merge this.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Sep 2, 2023

The documentation is not available anymore as the PR was closed or merged.

from deprecated `prepare_model_for_int8_training`
and add `use_gradient_checkpointing=args.gradient_checkpointing` to
automatically follow the gradient checkpointing choice

is also the workaround for huggingface#694
calling model.gradient_checkpointing_enable() twice causes issues
this workaround calls it in prepare_model_for_kbit_training and then
changes the arg to false to make sure it isn't called again in
huggingface trainer inner loop

also changes stack_llama_2 sft trainer to use correct device map for ddp
training so that you can test this issue
@mnoukhov
Copy link
Contributor Author

mnoukhov commented Sep 3, 2023

I've realized this fix actually causes an issue. Calling model.gradient_checkpointing_enable() twice leads to the same error as #480. And it is called twice:

  1. in peft's prepare_model_for_kbit_training here
  2. in huggingface Trainer's inner training loop here

I've made a workaround here in sft_trainer but it is ugly and maybe there's a better way.

To demonstrate the issue, I fixed stack_llama_2/sft_llama2.py by making the device map use Accelerator so it actually runs on multi-gpu on a single device. Run accelerator launch --multi_gpu sft_llama2.py without fix in sft_trainer and it will error

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me! Thanks for deepdiving and explaining, I left one question, let me know what do you think

model, use_gradient_checkpointing=args.gradient_checkpointing
)

args = dataclasses.replace(args, gradient_checkpointing=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change here and not above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do want to call gradient_checkpointing_enable once, we just don't want to call it twice. We will call it in 'prepare_for_kbit_trainingbut this change makes sure we don't call it inTrainer`

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect makes sense!

@lewtun
Copy link
Member

lewtun commented Sep 6, 2023

FYI @mnoukhov my PR #726 has now been merged so feel free to wrap this one up - thank you 🚀 !

@younesbelkada
Copy link
Contributor

@mnoukhov thanks again for your work on this ! Would you be happy to fix the merge conflicts? After that we should be good to merge!

@mnoukhov
Copy link
Contributor Author

Pulled and should be ready to merge!

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for this great effort!

@younesbelkada younesbelkada merged commit b87ec2d into huggingface:main Sep 12, 2023
11 checks passed
kushal-tri pushed a commit to kushalarora/trl that referenced this pull request Sep 19, 2023
* update to `prepare_model_for_kbit_training`

from deprecated `prepare_model_for_int8_training`
and add `use_gradient_checkpointing=args.gradient_checkpointing` to
automatically follow the gradient checkpointing choice

is also the workaround for huggingface#694

* workaround for gradient checkpointing issue

calling model.gradient_checkpointing_enable() twice causes issues
this workaround calls it in prepare_model_for_kbit_training and then
changes the arg to false to make sure it isn't called again in
huggingface trainer inner loop

also changes stack_llama_2 sft trainer to use correct device map for ddp
training so that you can test this issue
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* update to `prepare_model_for_kbit_training`

from deprecated `prepare_model_for_int8_training`
and add `use_gradient_checkpointing=args.gradient_checkpointing` to
automatically follow the gradient checkpointing choice

is also the workaround for huggingface#694

* workaround for gradient checkpointing issue

calling model.gradient_checkpointing_enable() twice causes issues
this workaround calls it in prepare_model_for_kbit_training and then
changes the arg to false to make sure it isn't called again in
huggingface trainer inner loop

also changes stack_llama_2 sft trainer to use correct device map for ddp
training so that you can test this issue
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants