Skip to content

Commit

Permalink
update to prepare_model_for_kbit_training
Browse files Browse the repository at this point in the history
from deprecated `prepare_model_for_int8_training`
and add `use_gradient_checkpointing=args.gradient_checkpointing` to
automatically follow the gradient checkpointing choice

is also the workaround for #694
  • Loading branch information
mnoukhov committed Sep 2, 2023
1 parent 34e6948 commit 16ef09b
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 16 deletions.
4 changes: 2 additions & 2 deletions docs/source/sft_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ trainer.train()
Pay attention to the following best practices when training a model with that trainer:

- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_int8_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.

Expand All @@ -346,4 +346,4 @@ Pay attention to the following best practices when training a model with that tr

## ConstantLengthDataset

[[autodoc]] trainer.ConstantLengthDataset
[[autodoc]] trainer.ConstantLengthDataset
4 changes: 2 additions & 2 deletions docs/source/using_llama_models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ model = AutoModelForCausalLM.from_pretrained(
load_in_8bit=True,
device_map={"": Accelerator().local_process_index}
)
model = prepare_model_for_int8_training(model)
model = prepare_model_for_kbit_training(model)

# add LoRA to model
lora_config = LoraConfig(
Expand Down Expand Up @@ -157,4 +157,4 @@ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
ppo_trainer.log_stats(stats, batch, rewards)
```

For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
12 changes: 6 additions & 6 deletions trl/models/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
PeftModelForSeq2SeqLM,
PromptLearningConfig,
get_peft_model,
prepare_model_for_int8_training,
prepare_model_for_kbit_training,
)
from peft.peft_model import set_peft_model_state_dict

Expand Down Expand Up @@ -108,7 +108,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
`from_pretrained` method. We also pre-process the kwargs to extract
the arguments that are specific to the `transformers.PreTrainedModel`
class and the arguments that are specific to trl models. The kwargs
also support `prepare_model_for_int8_training` arguments from
also support `prepare_model_for_kbit_training` arguments from
`peft` library.
"""
if kwargs is not None:
Expand Down Expand Up @@ -203,7 +203,7 @@ class and the arguments that are specific to trl models. The kwargs
if peft_config is not None:
# Initialize a new peft adapter with the given config
if is_loaded_in_8bit or is_loaded_in_4bit:
pretrained_model = prepare_model_for_int8_training(
pretrained_model = prepare_model_for_kbit_training(
pretrained_model,
**peft_quantization_kwargs,
)
Expand All @@ -216,7 +216,7 @@ class and the arguments that are specific to trl models. The kwargs
if peft_config is not None and isinstance(pretrained_model, PreTrainedModel):
# Initialize a new peft adapter with the given config
if is_loaded_in_8bit or is_loaded_in_4bit:
pretrained_model = prepare_model_for_int8_training(
pretrained_model = prepare_model_for_kbit_training(
pretrained_model,
**peft_quantization_kwargs,
)
Expand Down Expand Up @@ -339,7 +339,7 @@ def _split_kwargs(cls, kwargs):
check_peft_kwargs = False

if is_peft_available():
from peft import prepare_model_for_int8_training
from peft import prepare_model_for_kbit_training

check_peft_kwargs = True

Expand All @@ -354,7 +354,7 @@ def _split_kwargs(cls, kwargs):
unsupported_kwargs[key] = value

if check_peft_kwargs:
if key in prepare_model_for_int8_training.__code__.co_varnames:
if key in prepare_model_for_kbit_training.__code__.co_varnames:
peft_kwargs[key] = value
if key in unsupported_kwargs:
unsupported_kwargs.pop(key)
Expand Down
4 changes: 2 additions & 2 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


class DPOTrainer(Trainer):
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(
)
elif is_peft_available() and peft_config is not None:
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
model = prepare_model_for_int8_training(model)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
model = get_peft_model(model, peft_config)

self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
Expand Down
4 changes: 2 additions & 2 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


class RewardTrainer(Trainer):
Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(
)
elif is_peft_available() and peft_config is not None:
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
model = prepare_model_for_int8_training(model)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)

model = get_peft_model(model, peft_config)

Expand Down
6 changes: 4 additions & 2 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


if is_peft_available():
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_int8_training
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training


class SFTTrainer(Trainer):
Expand Down Expand Up @@ -145,7 +145,9 @@ def __init__(
)

if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
model = prepare_model_for_int8_training(model)
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=args.gradient_checkpointing
)

model = get_peft_model(model, peft_config)

Expand Down

0 comments on commit 16ef09b

Please sign in to comment.