
Commit

[core / DDP] Fix RM trainer + DDP + quantization + propagate `gradient_checkpointing_kwargs` in SFT & DPO (huggingface#912)

* make use of forward hooks

* correctly delete attributes

* fix RM DDP issues

* revert unneeded changes

* more fixes

* fix diff

* fix

* propagate to SFT

* Update examples/scripts/reward_modeling.py

* propagate the fix on DPO trainer

* add to example scripts

* trigger CI
younesbelkada authored and Andrew Lapp committed May 10, 2024
1 parent 506ad02 commit 6767715
Showing 7 changed files with 67 additions and 5 deletions.
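
All three trainers touched below apply the same pattern: before handing a quantized model to peft's `prepare_model_for_kbit_training`, they inspect its signature to check whether the installed peft release already accepts `gradient_checkpointing_kwargs`, and only forward the new argument when it does. A minimal sketch of that feature-detection step, pulled out of the trainer code for illustration (the `build_prepare_kwargs` helper is not part of the patch):

import inspect

from peft import prepare_model_for_kbit_training


def build_prepare_kwargs(gradient_checkpointing, gradient_checkpointing_kwargs=None):
    # Detect at runtime whether the installed peft release exposes the new argument.
    supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
        inspect.signature(prepare_model_for_kbit_training).parameters
    )
    prepare_kwargs = {"use_gradient_checkpointing": gradient_checkpointing}
    if supports_gc_kwargs and gradient_checkpointing_kwargs is not None:
        prepare_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs
    return prepare_kwargs


# e.g. model = prepare_model_for_kbit_training(model, **build_prepare_kwargs(True, {"use_reentrant": False}))
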
12 changes: 12 additions & 0 deletions examples/scripts/dpo.py
@@ -69,6 +69,15 @@ class ScriptArguments:
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
+    gradient_checkpointing: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to use gradient checkpointing or not"}
+    )
+    gradient_checkpointing_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={
+            "help": "Keyword arguments to be passed along to the `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
+        },
+    )


def extract_anthropic_prompt(prompt_and_response):
@@ -149,6 +158,9 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
        warmup_steps=150,
        report_to=script_args.report_to,
        bf16=True,
+        gradient_checkpointing=script_args.gradient_checkpointing,
+        # TODO: uncomment that on the next transformers release
+        # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
    )

    # 5. initialize the DPO trainer
1 change: 1 addition & 0 deletions examples/scripts/reward_modeling.py
@@ -51,6 +51,7 @@ class ScriptArguments:
            num_train_epochs=1,
            gradient_accumulation_steps=16,
            gradient_checkpointing=True,
+            gradient_checkpointing_kwargs={"use_reentrant": False},
            learning_rate=1.41e-5,
            report_to="tensorboard",
            remove_unused_columns=False,
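
The one-line change above is what the rest of the commit plumbs through: non-reentrant checkpointing is what lets gradient checkpointing, DDP, and a quantized base model coexist in the reward-modeling example. A minimal configuration sketch, assuming `RewardConfig` is re-exported from the top-level `trl` package (the field itself is added in `trl/trainer/training_configs.py` further down):

from trl import RewardConfig

reward_config = RewardConfig(
    output_dir="output",  # placeholder output directory
    gradient_checkpointing=True,
    # use_reentrant=False selects the non-reentrant checkpointing path, the
    # combination this commit targets for DDP + quantized reward models.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
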
12 changes: 12 additions & 0 deletions examples/scripts/sft.py
@@ -63,6 +63,15 @@ class ScriptArguments:
    )
    save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
    push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
+    gradient_checkpointing: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to use gradient checkpointing or not"}
+    )
+    gradient_checkpointing_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={
+            "help": "Keyword arguments to be passed along to the `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
+        },
+    )
    hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})


@@ -115,6 +124,9 @@ class ScriptArguments:
    save_total_limit=script_args.save_total_limit,
    push_to_hub=script_args.push_to_hub,
    hub_model_id=script_args.hub_model_id,
+    gradient_checkpointing=script_args.gradient_checkpointing,
+    # TODO: uncomment that on the next release
+    # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
)

# Step 4: Define the LoraConfig
14 changes: 13 additions & 1 deletion trl/trainer/dpo_trainer.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import random
import warnings
from collections import defaultdict
@@ -137,7 +138,18 @@ def __init__(
            )
        elif is_peft_available() and peft_config is not None:
            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
-                model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
+                _support_gc_kwargs = hasattr(
+                    args, "gradient_checkpointing_kwargs"
+                ) and "gradient_checkpointing_kwargs" in list(
+                    inspect.signature(prepare_model_for_kbit_training).parameters
+                )
+
+                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+                if _support_gc_kwargs:
+                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
            elif getattr(args, "gradient_checkpointing", False):
                # For backward compatibility with older versions of transformers
                if hasattr(model, "enable_input_require_grads"):
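
The `enable_input_require_grads` branch kept above (and the "make use of forward hooks" item in the commit message) covers older transformers releases: if that helper is missing, a forward hook on the input embeddings keeps the embedding outputs trainable so checkpointing still produces gradients with frozen quantized weights. A hedged sketch of that fallback, given an already-loaded `model`; the hook function name is illustrative rather than quoted from the patch:

def make_inputs_require_grad(module, input, output):
    # Force the embedding output to carry gradients even though the
    # embedding weights themselves stay frozen.
    output.requires_grad_(True)


if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()
else:
    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
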
17 changes: 15 additions & 2 deletions trl/trainer/reward_trainer.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import warnings
from dataclasses import FrozenInstanceError, replace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -128,10 +129,22 @@ def __init__(
        elif is_peft_available() and peft_config is not None:
            if not isinstance(model, PeftModel):
                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
-                    model = prepare_model_for_kbit_training(
-                        model, use_gradient_checkpointing=args.gradient_checkpointing
+                    _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
+                        inspect.signature(prepare_model_for_kbit_training).parameters
                    )
+
+                    prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+                    if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
+                        warnings.warn(
+                            "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
+                            "Please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
+                        )
+                    elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
+                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

                model = get_peft_model(model, peft_config)

        if is_peft_available() and isinstance(model, PeftModel):
14 changes: 12 additions & 2 deletions trl/trainer/sft_trainer.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
+import inspect
import warnings
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -159,10 +160,19 @@ def __init__(
                )

            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
-                model = prepare_model_for_kbit_training(
-                    model, use_gradient_checkpointing=args.gradient_checkpointing
+                _support_gc_kwargs = hasattr(
+                    args, "gradient_checkpointing_kwargs"
+                ) and "gradient_checkpointing_kwargs" in list(
+                    inspect.signature(prepare_model_for_kbit_training).parameters
                )
+
+                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+                if _support_gc_kwargs:
+                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

                args = dataclasses.replace(args, gradient_checkpointing=False)

            model = get_peft_model(model, peft_config)
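
Taken together with the script changes above, the SFT path ends up looking roughly like the sketch below: a quantized base model plus a LoRA config, with the checkpointing kwargs flowing through the trainer into peft's `prepare_model_for_kbit_training`. Model name, dataset, and hyperparameters are placeholders, and the snippet assumes a transformers release that already accepts `gradient_checkpointing_kwargs` in `TrainingArguments` (the TODO comments in the example scripts gate on exactly that), plus bitsandbytes for 4-bit loading:

from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", load_in_4bit=True)
dataset = load_dataset("imdb", split="train")

training_args = TrainingArguments(
    output_dir="sft-output",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
    peft_config=LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32),
)
trainer.train()
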
2 changes: 2 additions & 0 deletions trl/trainer/training_configs.py
@@ -39,3 +39,5 @@ class RewardConfig(TrainingArguments):
    """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
    gradient_checkpointing: Optional[bool] = True
    """If True, use gradient checkpointing to save memory at the expense of slower backward pass."""
+    gradient_checkpointing_kwargs: Optional[dict] = None
+    """Keyword arguments to pass to the gradient checkpointing function."""
