
[core / DDP] Fix RM trainer + DDP + quantization + propagate gradient_checkpointing_kwargs in SFT & DPO #912

Merged
younesbelkada merged 14 commits into main from fix-rm-ddp on Oct 31, 2023

Conversation

@younesbelkada (Contributor) commented Oct 24, 2023

Needs: huggingface/peft#1036 / huggingface/transformers#27020

Fixes: #891
Fixes: #835

To avoid issues with PEFT + DDP, we need to call the gradient checkpointing method with use_reentrant=False; with huggingface/transformers#27020 you can pass that through the gradient_checkpointing_kwargs argument directly in the trainer.

Users that do not have a recent enough transformers version need to update it to get that feature; otherwise gradient_checkpointing_kwargs will simply be ignored.
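
For illustration, here is a minimal sketch (not part of this PR's diff) of passing the kwargs through the trainer, assuming a transformers version that includes huggingface/transformers#27020:

```python
from transformers import TrainingArguments

# Minimal sketch: recent transformers versions forward `gradient_checkpointing_kwargs`
# to model.gradient_checkpointing_enable() (see huggingface/transformers#27020);
# older versions silently ignore the argument.
training_args = TrainingArguments(
    output_dir="./reward_model",  # hypothetical output path
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},  # avoids the PEFT + DDP issue
)
```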

cc @lvwerra

@HuggingFaceDocBuilderDev commented Oct 24, 2023

The documentation is not available anymore as the PR was closed or merged.

@vwxyzjn (Contributor) left a comment

LG! One minor comment.

@@ -103,7 +103,7 @@ class ScriptArguments:

 # Step 2: Load the dataset and pre-process it
 tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-train_dataset = load_dataset(args.dataset_name, split="train")
+train_dataset = load_dataset(args.dataset_name, split="train[:50]")
@vwxyzjn (Contributor) commented Oct 24, 2023

Could this be not hardcoded? Maybe like split=args.split
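
For context, a rough sketch of what that suggestion could look like; the split field and the defaults below are hypothetical and not part of this PR:

```python
from dataclasses import dataclass, field

from datasets import load_dataset

@dataclass
class ScriptArguments:
    dataset_name: str = field(default="Anthropic/hh-rlhf")  # placeholder default
    # hypothetical field replacing the hardcoded "train[:50]" split
    split: str = field(default="train", metadata={"help": "dataset split to load"})

args = ScriptArguments()
train_dataset = load_dataset(args.dataset_name, split=args.split)
```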

@younesbelkada younesbelkada reopened this Oct 30, 2023
@younesbelkada younesbelkada changed the title [core / DDP] Fix RM trainer + DDP + quantization [core / DDP] Fix RM trainer + DDP + quantization + propagate gradient_checkpointing_kwargs Oct 30, 2023
@younesbelkada younesbelkada changed the title [core / DDP] Fix RM trainer + DDP + quantization + propagate gradient_checkpointing_kwargs [core / DDP] Fix RM trainer + DDP + quantization + propagate gradient_checkpointing_kwargs in SFT & DPO Oct 30, 2023
@HuggingFaceDocBuilderDev commented Oct 30, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada younesbelkada marked this pull request as ready for review October 31, 2023 15:55
@younesbelkada younesbelkada requested a review from lvwerra October 31, 2023 16:09
@lvwerra (Member) left a comment

Looks good!

@younesbelkada younesbelkada merged commit cbc6c9b into main Oct 31, 2023
8 checks passed
@younesbelkada younesbelkada deleted the fix-rm-ddp branch October 31, 2023 17:50
hijkzzz pushed a commit to OpenRLHF/OpenRLHF that referenced this pull request Feb 5, 2024
* fix bug: generate_args-do_sample

* fix gradient_checkpointing_kwargs bug

see: huggingface/trl#912 and huggingface/transformers#26969

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
…dient_checkpointing_kwargs` in SFT & DPO (huggingface#912)

* make use of forward hooks

* correctly delete attributes

* fix RM DPP issues

* revert unneeded changes

* more fixes

* fix diff

* fix

* propagate to SFT

* Update examples/scripts/reward_modeling.py

* propagate the fix on DPO trainer

* add to example scripts

* trigger CI