Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core / DDP] Fix RM trainer + DDP + quantization + propagate gradient_checkpointing_kwargs in SFT & DPO #912

Merged
merged 14 commits into the base branch from the feature branch on
Oct 31, 2023
4 changes: 2 additions & 2 deletions examples/scripts/reward_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class ScriptArguments:
num_train_epochs=1,
gradient_accumulation_steps=16,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": True},
learning_rate=1.41e-5,
report_to="tensorboard",
remove_unused_columns=False,
Expand All @@ -67,7 +68,6 @@ class ScriptArguments:
r=16,
lora_alpha=16,
bias="none",
task_type="CAUSAL_LM",
task_type="SEQ_CLS",
modules_to_save=["scores"],
),
Expand Down Expand Up @@ -103,7 +103,7 @@ class ScriptArguments:

# Step 2: Load the dataset and pre-process it
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
train_dataset = load_dataset(args.dataset_name, split="train")
train_dataset = load_dataset(args.dataset_name, split="train[:50]")
Copy link
Contributor

@vwxyzjn vwxyzjn Oct 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this not be hardcoded? Maybe something like `split=args.split`?

younesbelkada marked this conversation as resolved.
Show resolved Hide resolved


# Tokenize chosen/rejected pairs of inputs
Expand Down
52 changes: 22 additions & 30 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
import tempfile
import unittest
Expand Down Expand Up @@ -581,25 +580,23 @@ def test_sft_trainer_with_model_neftune(self):
packing=True,
)

# inspect input embeddings forward code source
input_embeddings_forward_code_source = inspect.getsource(trainer.model.get_input_embeddings().forward)
device = trainer.model.get_input_embeddings().weight.device
trainer.model.train()

self.assertTrue(
"mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)" in input_embeddings_forward_code_source
)
torch.random.manual_seed(42)
embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))

# training should work fine
trainer.train()
torch.random.manual_seed(24)
embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))

# inspect input embeddings forward code source - this time it should not contain any code from NEFTune.
input_embeddings_forward_code_source = inspect.getsource(trainer.model.get_input_embeddings().forward)
self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) > 0)

self.assertFalse(
"mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)" in input_embeddings_forward_code_source
)
trainer.train()

# Make sure forward pass works fine
_ = trainer.model(torch.LongTensor([[1, 0, 1]]))
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)

@require_peft
def test_peft_sft_trainer(self):
Expand Down Expand Up @@ -675,25 +672,19 @@ def test_peft_sft_trainer_neftune(self):

self.assertTrue(isinstance(trainer.model, PeftModel))

# inspect input embeddings forward code source
input_embeddings_forward_code_source = inspect.getsource(
trainer.model.base_model.get_input_embeddings().forward
)
device = trainer.model.get_input_embeddings().weight.device
trainer.model.train()

self.assertTrue(
"mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)" in input_embeddings_forward_code_source
)
torch.random.manual_seed(42)
embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))

trainer.train()
torch.random.manual_seed(24)
embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))

# inspect input embeddings forward code source - this time it should not contain any code from NEFTune.
input_embeddings_forward_code_source = inspect.getsource(
trainer.model.base_model.get_input_embeddings().forward
)
self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) > 0)

self.assertFalse(
"mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)" in input_embeddings_forward_code_source
)
trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
Expand All @@ -703,4 +694,5 @@ def test_peft_sft_trainer_neftune(self):
self.assertTrue("pytorch_model.bin" not in os.listdir(tmp_dir + "/checkpoint-2"))

# Make sure forward pass works fine to check if embeddings forward is not broken.
_ = trainer.model(torch.LongTensor([[1, 0, 1]]))
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)
17 changes: 15 additions & 2 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings
from dataclasses import FrozenInstanceError, replace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -128,10 +129,22 @@ def __init__(
elif is_peft_available() and peft_config is not None:
if not isinstance(model, PeftModel):
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=args.gradient_checkpointing
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
inspect.signature(prepare_model_for_kbit_training).parameters
)

preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
warnings.warn(
"You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
"please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
)
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)

model = get_peft_model(model, peft_config)

if is_peft_available() and callbacks is None and isinstance(model, PeftModel):
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[Dict] = None,
peft_config: Optional[PeftConfig] = None,
dataset_text_field: Optional[str] = None,
packing: Optional[bool] = False,
formatting_func: Optional[Callable] = None,
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/training_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,5 @@ class RewardConfig(TrainingArguments):
"""The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
gradient_checkpointing: Optional[bool] = True
"""If True, use gradient checkpointing to save memory at the expense of slower backward pass."""
gradient_checkpointing_kwargs: Optional[dict] = None
"""Keyword arguments to pass to the gradient checkpointing function."""