From b87ec2d5a072c2bbbee059ccdde9ca9e0fbf85f0 Mon Sep 17 00:00:00 2001
From: Michael
Date: Tue, 12 Sep 2023 04:56:10 -0400
Subject: [PATCH] update to `prepare_model_for_kbit_training` (#728)

* update to `prepare_model_for_kbit_training` from the deprecated
  `prepare_model_for_int8_training`, and pass
  `use_gradient_checkpointing=args.gradient_checkpointing` so that the
  gradient checkpointing choice is followed automatically; this is also the
  workaround for #694

* workaround for gradient checkpointing issue

  Calling `model.gradient_checkpointing_enable()` twice causes issues. This
  workaround calls it once inside `prepare_model_for_kbit_training`, then sets
  the argument to `False` so it is not called again in the Hugging Face
  `Trainer` inner loop.

  Also changes the stack_llama_2 SFT trainer to use the correct device map for
  DDP training so that this issue can be tested.
---
 docs/source/sft_trainer.mdx                   |  4 ++--
 docs/source/using_llama_models.mdx            |  4 ++--
 .../stack_llama_2/scripts/sft_llama2.py       |  3 ++-
 trl/models/modeling_base.py                   | 12 ++++++------
 trl/trainer/dpo_trainer.py                    |  4 ++--
 trl/trainer/reward_trainer.py                 |  4 ++--
 trl/trainer/sft_trainer.py                    |  9 +++++++--
 7 files changed, 23 insertions(+), 17 deletions(-)
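Note on the workaround (illustrative, not part of the applied diff): gradient checkpointing should be enabled exactly once, inside the PEFT helper, and the trainer flag then cleared so that `transformers.Trainer` does not call `model.gradient_checkpointing_enable()` a second time. A minimal sketch of that pattern, assuming an 8-bit base model and invented model/argument values:

```python
import dataclasses

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, TrainingArguments

# Invented values for illustration; any k-bit base model and TrainingArguments
# work the same way.
args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", load_in_8bit=True, device_map="auto"
)

# Enable gradient checkpointing once, inside the PEFT helper.
model = prepare_model_for_kbit_training(
    model, use_gradient_checkpointing=args.gradient_checkpointing
)

# Clear the flag so the Trainer does not enable gradient checkpointing again.
args = dataclasses.replace(args, gradient_checkpointing=False)

model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))
```

This is the same `prepare_model_for_kbit_training(...)` plus `dataclasses.replace(args, gradient_checkpointing=False)` sequence that the `SFTTrainer` change below applies internally.
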
diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx
index c15e03f2d0..8b53cb7dd1 100644
--- a/docs/source/sft_trainer.mdx
+++ b/docs/source/sft_trainer.mdx
@@ -336,7 +336,7 @@ trainer.train()
 Pay attention to the following best practices when training a model with that trainer:
 
 - [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
-- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_int8_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
+- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
 - For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
 - If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
 
@@ -346,4 +346,4 @@ Pay attention to the following best practices when training a model with that tr
 
 ## ConstantLengthDataset
 
-[[autodoc]] trainer.ConstantLengthDataset
\ No newline at end of file
+[[autodoc]] trainer.ConstantLengthDataset
diff --git a/docs/source/using_llama_models.mdx b/docs/source/using_llama_models.mdx
index ceb8476f78..cf602d2030 100644
--- a/docs/source/using_llama_models.mdx
+++ b/docs/source/using_llama_models.mdx
@@ -52,7 +52,7 @@ model = AutoModelForCausalLM.from_pretrained(
     load_in_8bit=True,
     device_map={"": Accelerator().local_process_index}
 )
-model = prepare_model_for_int8_training(model)
+model = prepare_model_for_kbit_training(model)
 
 # add LoRA to model
 lora_config = LoraConfig(
@@ -157,4 +157,4 @@ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
     ppo_trainer.log_stats(stats, batch, rewards)
 ```
 
-For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
\ No newline at end of file
+For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
diff --git a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
index d0c3d38340..e771966262 100644
--- a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
+++ b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
@@ -4,6 +4,7 @@
 from typing import Optional
 
 import torch
+from accelerate import Accelerator
 from datasets import load_dataset
 from peft import AutoPeftModelForCausalLM, LoraConfig
 from tqdm import tqdm
@@ -148,7 +149,7 @@ def create_datasets(tokenizer, args):
 base_model = AutoModelForCausalLM.from_pretrained(
     script_args.model_name,
     quantization_config=bnb_config,
-    device_map={"": 0},
+    device_map={"": Accelerator().local_process_index},
     trust_remote_code=True,
     use_auth_token=True,
 )
diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py
index 7cd85678d1..748d0d4adb 100644
--- a/trl/models/modeling_base.py
+++ b/trl/models/modeling_base.py
@@ -35,7 +35,7 @@
         PeftModelForSeq2SeqLM,
         PromptLearningConfig,
         get_peft_model,
-        prepare_model_for_int8_training,
+        prepare_model_for_kbit_training,
     )
     from peft.peft_model import set_peft_model_state_dict
 
@@ -108,7 +108,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 `from_pretrained` method. We also pre-process the kwargs to extract
                 the arguments that are specific to the `transformers.PreTrainedModel`
                 class and the arguments that are specific to trl models. The kwargs
-                also support `prepare_model_for_int8_training` arguments from
+                also support `prepare_model_for_kbit_training` arguments from
                 `peft` library.
         """
         if kwargs is not None:
@@ -203,7 +203,7 @@ class and the arguments that are specific to trl models. The kwargs
             if peft_config is not None:
                 # Initialize a new peft adapter with the given config
                 if is_loaded_in_8bit or is_loaded_in_4bit:
-                    pretrained_model = prepare_model_for_int8_training(
+                    pretrained_model = prepare_model_for_kbit_training(
                         pretrained_model,
                         **peft_quantization_kwargs,
                     )
@@ -216,7 +216,7 @@ class and the arguments that are specific to trl models. The kwargs
             if peft_config is not None and isinstance(pretrained_model, PreTrainedModel):
                 # Initialize a new peft adapter with the given config
                 if is_loaded_in_8bit or is_loaded_in_4bit:
-                    pretrained_model = prepare_model_for_int8_training(
+                    pretrained_model = prepare_model_for_kbit_training(
                         pretrained_model,
                         **peft_quantization_kwargs,
                     )
@@ -339,7 +339,7 @@ def _split_kwargs(cls, kwargs):
         check_peft_kwargs = False
 
         if is_peft_available():
-            from peft import prepare_model_for_int8_training
+            from peft import prepare_model_for_kbit_training
 
             check_peft_kwargs = True
 
@@ -354,7 +354,7 @@ def _split_kwargs(cls, kwargs):
                 unsupported_kwargs[key] = value
 
             if check_peft_kwargs:
-                if key in prepare_model_for_int8_training.__code__.co_varnames:
+                if key in prepare_model_for_kbit_training.__code__.co_varnames:
                     peft_kwargs[key] = value
                     if key in unsupported_kwargs:
                         unsupported_kwargs.pop(key)
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index f57cce161e..ebbf96aa74 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -29,7 +29,7 @@
 
 
 if is_peft_available():
-    from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
+    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
 
 
 class DPOTrainer(Trainer):
@@ -116,7 +116,7 @@ def __init__(
             )
         elif is_peft_available() and peft_config is not None:
             if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
-                model = prepare_model_for_int8_training(model)
+                model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
             model = get_peft_model(model, peft_config)
 
         if model is not None:
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index b32466beac..0af7d1efc8 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -29,7 +29,7 @@
 
 
 if is_peft_available():
-    from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
+    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
 
 
 class RewardTrainer(Trainer):
@@ -113,7 +113,7 @@ def __init__(
             )
         elif is_peft_available() and peft_config is not None:
             if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
-                model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
+                model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
 
             model = get_peft_model(model, peft_config)
 
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index ddc7e2b8c3..d8c899d5f7 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import dataclasses
 import warnings
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
@@ -35,7 +36,7 @@
 
 
 if is_peft_available():
-    from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_int8_training
+    from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
 
 
 class SFTTrainer(Trainer):
@@ -147,7 +148,11 @@ def __init__(
                 )
 
             if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
-                model = prepare_model_for_int8_training(model)
+                model = prepare_model_for_kbit_training(
+                    model, use_gradient_checkpointing=args.gradient_checkpointing
+                )
+
+                args = dataclasses.replace(args, gradient_checkpointing=False)
 
             model = get_peft_model(model, peft_config)
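
For reference, the `device_map` change in `sft_llama2.py` is what makes the script usable under DDP: each process loads the quantized base model onto its own GPU instead of every rank targeting device 0. A minimal loading sketch, with an invented model name and quantization config:

```python
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Under DDP, local_process_index is the local rank, so each process places its
# full copy of the quantized model on its own GPU.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # invented for the example
    quantization_config=bnb_config,
    device_map={"": Accelerator().local_process_index},
)
```

When the script is started with `accelerate launch`, `local_process_index` resolves to a different GPU index on each worker of the node.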
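
The rename also shows up in `_split_kwargs` in `trl/models/modeling_base.py`, which decides which keyword arguments to forward to PEFT by inspecting the helper's code object. A rough, simplified illustration of that routing idea (the example `kwargs` dictionary is invented):

```python
from peft import prepare_model_for_kbit_training

# Invented kwargs, as a caller might pass them through from_pretrained().
kwargs = {"use_gradient_checkpointing": True, "trust_remote_code": True}

peft_kwargs = {}
other_kwargs = {}
for key, value in kwargs.items():
    # __code__.co_varnames lists the function's parameter and local variable
    # names, so keys matching a prepare_model_for_kbit_training argument go to PEFT.
    if key in prepare_model_for_kbit_training.__code__.co_varnames:
        peft_kwargs[key] = value
    else:
        other_kwargs[key] = value

print(peft_kwargs)   # expected: {'use_gradient_checkpointing': True}
print(other_kwargs)  # expected: {'trust_remote_code': True}
```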