update to prepare_model_for_kbit_training #728

Merged · 3 commits · Sep 12, 2023
4 changes: 2 additions & 2 deletions docs/source/sft_trainer.mdx
@@ -336,7 +336,7 @@ trainer.train()
Pay attention to the following best practices when training a model with that trainer:

- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
- - For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_int8_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
+ - For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
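
For readers who prefer the "create the `PeftModel` outside the [`SFTTrainer`] and pass it" route mentioned above, here is a minimal sketch. The model name, dataset, and LoRA hyperparameters are illustrative assumptions rather than part of this PR, and the argument names follow the TRL/PEFT APIs as of this release:

```python
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer

# Load the base model in 8-bit, prepare it for k-bit training, then attach a LoRA adapter.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", load_in_8bit=True)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM"))

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train")

# The PeftModel is passed directly, so no peft_config is given to the trainer.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
    tokenizer=tokenizer,
)
trainer.train()
```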

@@ -346,4 +346,4 @@ Pay attention to the following best practices when training a model with that trainer:

## ConstantLengthDataset

- [[autodoc]] trainer.ConstantLengthDataset
+ [[autodoc]] trainer.ConstantLengthDataset
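
For context, a small sketch of how `ConstantLengthDataset` can be used to pack raw text into fixed-length training blocks; the tokenizer, dataset, and parameter values are assumptions for illustration:

```python
from datasets import load_dataset
from transformers import AutoTokenizer
from trl.trainer import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = load_dataset("imdb", split="train")

# Tokenizes the "text" column and packs it into constant-length sequences.
packed = ConstantLengthDataset(
    tokenizer,
    dataset,
    dataset_text_field="text",
    seq_length=1024,
    infinite=False,
)

for example in packed:
    print(example["input_ids"].shape)  # expected: torch.Size([1024])
    break
```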
4 changes: 2 additions & 2 deletions docs/source/using_llama_models.mdx
@@ -52,7 +52,7 @@ model = AutoModelForCausalLM.from_pretrained(
    load_in_8bit=True,
    device_map={"": Accelerator().local_process_index}
)
- model = prepare_model_for_int8_training(model)
+ model = prepare_model_for_kbit_training(model)

# add LoRA to model
lora_config = LoraConfig(
@@ -157,4 +157,4 @@ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    ppo_trainer.log_stats(stats, batch, rewards)
```

- For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
+ For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
(changes to an example training script; the file path was not captured in this view)
@@ -4,6 +4,7 @@
from typing import Optional

import torch
+ from accelerate import Accelerator
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
@@ -148,7 +149,7 @@ def create_datasets(tokenizer, args):
base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    quantization_config=bnb_config,
-     device_map={"": 0},
+     device_map={"": Accelerator().local_process_index},
    trust_remote_code=True,
    use_auth_token=True,
)
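
The `device_map` change above matters when the script is launched with `accelerate` on multiple GPUs: each process should place its quantized copy of the base model on its own device instead of every rank loading onto GPU 0. A minimal sketch, with an assumed model name and 4-bit config:

```python
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Assumed quantization settings; the script in this PR builds its own bnb_config.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# device_map={"": 0} would pin every rank to GPU 0 under `accelerate launch`;
# using the local process index gives each rank its own GPU.
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model name
    quantization_config=bnb_config,
    device_map={"": Accelerator().local_process_index},
)
```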
12 changes: 6 additions & 6 deletions trl/models/modeling_base.py
@@ -35,7 +35,7 @@
    PeftModelForSeq2SeqLM,
    PromptLearningConfig,
    get_peft_model,
-     prepare_model_for_int8_training,
+     prepare_model_for_kbit_training,
)
from peft.peft_model import set_peft_model_state_dict

@@ -108,7 +108,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        `from_pretrained` method. We also pre-process the kwargs to extract
        the arguments that are specific to the `transformers.PreTrainedModel`
        class and the arguments that are specific to trl models. The kwargs
-         also support `prepare_model_for_int8_training` arguments from
+         also support `prepare_model_for_kbit_training` arguments from
        `peft` library.
        """
        if kwargs is not None:
@@ -203,7 +203,7 @@ class and the arguments that are specific to trl models. The kwargs
if peft_config is not None:
    # Initialize a new peft adapter with the given config
    if is_loaded_in_8bit or is_loaded_in_4bit:
-         pretrained_model = prepare_model_for_int8_training(
+         pretrained_model = prepare_model_for_kbit_training(
            pretrained_model,
            **peft_quantization_kwargs,
        )
@@ -216,7 +216,7 @@ class and the arguments that are specific to trl models. The kwargs
if peft_config is not None and isinstance(pretrained_model, PreTrainedModel):
    # Initialize a new peft adapter with the given config
    if is_loaded_in_8bit or is_loaded_in_4bit:
-         pretrained_model = prepare_model_for_int8_training(
+         pretrained_model = prepare_model_for_kbit_training(
            pretrained_model,
            **peft_quantization_kwargs,
        )
@@ -339,7 +339,7 @@ def _split_kwargs(cls, kwargs):
check_peft_kwargs = False

if is_peft_available():
-     from peft import prepare_model_for_int8_training
+     from peft import prepare_model_for_kbit_training

    check_peft_kwargs = True

@@ -354,7 +354,7 @@ def _split_kwargs(cls, kwargs):
    unsupported_kwargs[key] = value

if check_peft_kwargs:
-     if key in prepare_model_for_int8_training.__code__.co_varnames:
+     if key in prepare_model_for_kbit_training.__code__.co_varnames:
        peft_kwargs[key] = value
        if key in unsupported_kwargs:
            unsupported_kwargs.pop(key)
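
The hunk above routes keyword arguments by introspecting the parameter names of `prepare_model_for_kbit_training`. A standalone sketch of the same idea (the helper function and example kwargs are made up for illustration):

```python
from peft import prepare_model_for_kbit_training

def split_peft_kwargs(kwargs):
    """Separate kwargs accepted by prepare_model_for_kbit_training from the rest."""
    # co_varnames lists the function's parameters (and locals), which is
    # sufficient for this kind of routing.
    accepted = prepare_model_for_kbit_training.__code__.co_varnames
    peft_kwargs, other_kwargs = {}, {}
    for key, value in kwargs.items():
        if key in accepted:
            peft_kwargs[key] = value
        else:
            other_kwargs[key] = value
    return peft_kwargs, other_kwargs

# `use_gradient_checkpointing` is a real parameter of the PEFT helper, while
# `torch_dtype` belongs with `from_pretrained`-style kwargs.
peft_kwargs, other_kwargs = split_peft_kwargs(
    {"use_gradient_checkpointing": True, "torch_dtype": "bfloat16"}
)
print(peft_kwargs)   # {'use_gradient_checkpointing': True}
print(other_kwargs)  # {'torch_dtype': 'bfloat16'}
```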
4 changes: 2 additions & 2 deletions trl/trainer/dpo_trainer.py
@@ -29,7 +29,7 @@


if is_peft_available():
-     from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
+     from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


class DPOTrainer(Trainer):
@@ -116,7 +116,7 @@ def __init__(
    )
elif is_peft_available() and peft_config is not None:
    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
-         model = prepare_model_for_int8_training(model)
+         model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
    model = get_peft_model(model, peft_config)

if model is not None:
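
Outside of the trainer, the same guard looks roughly like the sketch below; the model name and LoRA settings are placeholders, and `use_gradient_checkpointing` simply mirrors the training arguments the way the hunk above does:

```python
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, TrainingArguments

args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", load_in_8bit=True)

# Only quantized checkpoints (which expose flags like `is_loaded_in_8bit`) need the
# k-bit preparation step before a LoRA adapter is attached.
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=args.gradient_checkpointing
    )
model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM"))
```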
4 changes: 2 additions & 2 deletions trl/trainer/reward_trainer.py
@@ -29,7 +29,7 @@


if is_peft_available():
-     from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
+     from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


class RewardTrainer(Trainer):
@@ -113,7 +113,7 @@ def __init__(
    )
elif is_peft_available() and peft_config is not None:
    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
-         model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
+         model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)

    model = get_peft_model(model, peft_config)

9 changes: 7 additions & 2 deletions trl/trainer/sft_trainer.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ import dataclasses
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union

@@ -35,7 +36,7 @@


if is_peft_available():
-     from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_int8_training
+     from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training


class SFTTrainer(Trainer):
@@ -147,7 +148,11 @@ def __init__(
    )

if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
-     model = prepare_model_for_int8_training(model)
+     model = prepare_model_for_kbit_training(
+         model, use_gradient_checkpointing=args.gradient_checkpointing
+     )
+
+     args = dataclasses.replace(args, gradient_checkpointing=False)
Review thread on the change above:

Contributor: Why this change here and not above?

Author: We do want to call `gradient_checkpointing_enable` once; we just don't want to call it twice. We will call it in `prepare_model_for_kbit_training`, but this change makes sure we don't call it again in `Trainer`.

Contributor: Perfect, makes sense!


model = get_peft_model(model, peft_config)
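
To make the discussion above concrete, a small sketch of the double-enabling fix, shown outside the trainer; the model is an arbitrary placeholder:

```python
import dataclasses

from peft import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, TrainingArguments

args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder model

# prepare_model_for_kbit_training calls model.gradient_checkpointing_enable() itself
# when use_gradient_checkpointing=True, so the flag is cleared on the args copy that
# will reach the Trainer, which avoids enabling checkpointing a second time.
model = prepare_model_for_kbit_training(
    model, use_gradient_checkpointing=args.gradient_checkpointing
)
args = dataclasses.replace(args, gradient_checkpointing=False)
```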
