fix 8-bit multi-gpu training bug (huggingface#1353)
* fix 8-bit multi-gpu training bug; see huggingface#1348

* Update dpo_llama2.py

make gradient_checkpointing_kwargs configurable.

* Update dpo_llama2.py

remove unnecessary config of device_map

* format with make precommit

---------

Co-authored-by: ubuntu <[email protected]>
2 people authored and Andrew Lapp committed May 10, 2024
1 parent 7e203c2 commit 52ec199
Showing 1 changed file with 7 additions and 0 deletions.
dpo_llama2.py

```diff
@@ -4,6 +4,7 @@
 from typing import Dict, Optional
 
 import torch
+from accelerate import Accelerator
 from datasets import Dataset, load_dataset
 from peft import LoraConfig
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
@@ -41,6 +42,10 @@ class ScriptArguments:
         default=True, metadata={"help": "whether to use gradient checkpointing"}
     )
 
+    gradient_checkpointing_use_reentrant: Optional[bool] = field(
+        default=True, metadata={"help": "whether to use reentrant for gradient checkpointing"}
+    )
+
     lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
     lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
     lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
@@ -129,6 +134,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
         low_cpu_mem_usage=True,
         torch_dtype=torch.float16,
         load_in_4bit=True,
+        device_map={"": Accelerator().local_process_index},
     )
     model.config.use_cache = False
 
@@ -175,6 +181,7 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
         bf16=True,
         remove_unused_columns=False,
         run_name="dpo_llama2",
+        gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
     )
 
     peft_config = LoraConfig(
```
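For readers hitting the same failure, here is a minimal sketch (not part of the commit) of the loading pattern the fix adopts. Under a DDP-style multi-GPU launch such as `accelerate launch`, each process should place its own full copy of the quantized model on its local GPU; mapping the empty module prefix `""` to `Accelerator().local_process_index` does exactly that, whereas leaving the device map unset (or using `"auto"`) can shard or misplace weights across processes. The model name below is a placeholder; the actual script takes it from a CLI argument.

```python
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

# One training process per GPU (e.g. started with `accelerate launch`).
# Mapping the empty prefix "" to the local process index puts the entire
# quantized model on that process's own GPU instead of sharding it.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder; the script uses a CLI argument
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    device_map={"": Accelerator().local_process_index},
)
model.config.use_cache = False  # the KV cache is incompatible with gradient checkpointing
```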

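And a sketch of what the newly configurable flag controls, assuming a transformers version recent enough to accept `gradient_checkpointing_kwargs` (the dict is forwarded to `torch.utils.checkpoint`). `use_reentrant=False` selects the non-reentrant checkpointing implementation, which tends to be the one that cooperates with DDP and PEFT; the other argument values here are illustrative only.

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./dpo_llama2",      # illustrative
    per_device_train_batch_size=4,  # illustrative
    gradient_checkpointing=True,
    # Forwarded to torch.utils.checkpoint; the commit exposes this value
    # as the gradient_checkpointing_use_reentrant script argument.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```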