[core / DDP] Fix RM trainer + DDP + quantization + propagate gradient_checkpointing_kwargs in SFT & DPO #912

Merged · merged 14 commits on Oct 31, 2023
12 changes: 12 additions & 0 deletions examples/scripts/dpo.py
@@ -69,6 +69,15 @@ class ScriptArguments:
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether to use gradient checkpointing or not"}
)
gradient_checkpointing_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
},
)


def extract_anthropic_prompt(prompt_and_response):
@@ -149,6 +158,9 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
warmup_steps=150,
report_to=script_args.report_to,
bf16=True,
gradient_checkpointing=script_args.gradient_checkpointing,
# TODO: uncomment this on the next transformers release
# gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
)

# 5. initialize the DPO trainer
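The TODO above exists because `TrainingArguments` only gains a `gradient_checkpointing_kwargs` field in the then-upcoming transformers release. A minimal sketch of one way the script could resolve it without pinning a version, mirroring the signature check the trainers use internally (the literals and the `./dpo_output` path are placeholders, not from the PR):

```python
import inspect

from transformers import TrainingArguments

training_kwargs = {"output_dir": "./dpo_output", "gradient_checkpointing": True}

# TrainingArguments is a dataclass, so inspecting its generated __init__ tells us
# whether the installed transformers version already accepts the new kwarg.
if "gradient_checkpointing_kwargs" in inspect.signature(TrainingArguments.__init__).parameters:
    training_kwargs["gradient_checkpointing_kwargs"] = {"use_reentrant": False}

training_args = TrainingArguments(**training_kwargs)
```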
1 change: 1 addition & 0 deletions examples/scripts/reward_modeling.py
@@ -51,6 +51,7 @@ class ScriptArguments:
num_train_epochs=1,
gradient_accumulation_steps=16,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
learning_rate=1.41e-5,
report_to="tensorboard",
remove_unused_columns=False,
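The `"use_reentrant": False` default here is the heart of the DDP fix: PyTorch's reentrant checkpointing implementation is known to clash with DDP's gradient-ready hooks (typically surfacing as "Expected to mark a variable ready only once" errors), while the non-reentrant variant composes cleanly. A minimal, self-contained sketch of where the kwarg ultimately lands, using `torch.utils.checkpoint` directly rather than the trainer stack:

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)

# Non-reentrant checkpointing recomputes activations during backward without
# re-entering the autograd engine, which keeps DDP's bookkeeping intact.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
```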
12 changes: 12 additions & 0 deletions examples/scripts/sft.py
@@ -63,6 +63,15 @@ class ScriptArguments:
)
save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether to use gradient checkpointing or not"}
)
gradient_checkpointing_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
},
)
hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})


@@ -114,6 +123,9 @@ class ScriptArguments:
save_total_limit=script_args.save_total_limit,
push_to_hub=script_args.push_to_hub,
hub_model_id=script_args.hub_model_id,
gradient_checkpointing=script_args.gradient_checkpointing,
# TODO: uncomment this on the next transformers release
# gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
)

# Step 4: Define the LoraConfig
14 changes: 13 additions & 1 deletion trl/trainer/dpo_trainer.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import random
import warnings
from collections import defaultdict
@@ -137,7 +138,18 @@ def __init__(
)
elif is_peft_available() and peft_config is not None:
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
_supports_gc_kwargs = hasattr(
args, "gradient_checkpointing_kwargs"
) and "gradient_checkpointing_kwargs" in list(
inspect.signature(prepare_model_for_kbit_training).parameters
)

prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

if _supports_gc_kwargs:
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
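The block above is optional-kwarg feature detection: newer peft releases accept `gradient_checkpointing_kwargs` in `prepare_model_for_kbit_training`, older ones do not, and inspecting the signature keeps both working. The same pattern, reduced to a runnable sketch with a stand-in function (the real one lives in peft):

```python
import inspect

def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
    # Stand-in mimicking an older peft signature without the new kwarg.
    return model

prepare_model_kwargs = {"use_gradient_checkpointing": True}
if "gradient_checkpointing_kwargs" in inspect.signature(prepare_model_for_kbit_training).parameters:
    prepare_model_kwargs["gradient_checkpointing_kwargs"] = {"use_reentrant": False}

# With the old signature above, the extra kwarg is never added and the call succeeds.
model = prepare_model_for_kbit_training(object(), **prepare_model_kwargs)
```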
17 changes: 15 additions & 2 deletions trl/trainer/reward_trainer.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings
from dataclasses import FrozenInstanceError, replace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -128,10 +129,22 @@ def __init__(
elif is_peft_available() and peft_config is not None:
if not isinstance(model, PeftModel):
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=args.gradient_checkpointing
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
inspect.signature(prepare_model_for_kbit_training).parameters
)

prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
warnings.warn(
"You passed `gradient_checkpointing_kwargs` in the trainer's arguments, but your peft version does not support it. "
"Please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
)
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

model = get_peft_model(model, peft_config)

if is_peft_available() and isinstance(model, PeftModel):
14 changes: 12 additions & 2 deletions trl/trainer/sft_trainer.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import inspect
import warnings
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -159,10 +160,19 @@ def __init__(
)

if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=args.gradient_checkpointing
_supports_gc_kwargs = hasattr(
args, "gradient_checkpointing_kwargs"
) and "gradient_checkpointing_kwargs" in list(
inspect.signature(prepare_model_for_kbit_training).parameters
)

prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

if _supports_gc_kwargs:
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

args = dataclasses.replace(args, gradient_checkpointing=False)

model = get_peft_model(model, peft_config)
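One detail specific to the SFT path: after `prepare_model_for_kbit_training` has set up checkpointing, the trainer flips `gradient_checkpointing` off on `args`, presumably so the underlying `Trainer` does not enable it a second time. `dataclasses.replace` is used because it returns a modified copy rather than mutating the (possibly shared) arguments object. A minimal sketch with a stand-in dataclass:

```python
import dataclasses

@dataclasses.dataclass
class Args:
    # Minimal stand-in for transformers.TrainingArguments.
    gradient_checkpointing: bool = True

args = Args()
new_args = dataclasses.replace(args, gradient_checkpointing=False)
print(args.gradient_checkpointing, new_args.gradient_checkpointing)  # True False
```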
2 changes: 2 additions & 0 deletions trl/trainer/training_configs.py
@@ -39,3 +39,5 @@ class RewardConfig(TrainingArguments):
"""The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
gradient_checkpointing: Optional[bool] = True
"""If True, use gradient checkpointing to save memory at the expense of slower backward pass."""
gradient_checkpointing_kwargs: Optional[dict] = None
"""Keyword arguments to pass to the gradient checkpointing function."""