Skip to content

Commit

Permalink
fix(config): passing gradient_checkpoint_kwargs (#1412)
Browse files Browse the repository at this point in the history
* fix(config): change default use_reentrant to true

* Update trainer_builder.py

* fix: make sure to pass kwargs to enable checkpoint

* chore: lint
  • Loading branch information
NanoCode012 authored Mar 19, 2024
1 parent 35202dd commit e4b51b1
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ group_by_length: false
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
# gradient_checkpointing_kwargs:
# use_reentrant: false
# use_reentrant: true

# Stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
Expand Down
4 changes: 0 additions & 4 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,10 +970,6 @@ def build(self, total_num_steps):
training_arguments_kwargs[
"gradient_checkpointing_kwargs"
] = self.cfg.gradient_checkpointing_kwargs
else:
training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
Expand Down
4 changes: 3 additions & 1 deletion src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,9 @@ def load_and_quantize_parallel(name_param, model, **kwargs):

if cfg.adapter in ["lora", "qlora"]:
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable()
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
)
if (
cfg.load_in_8bit or cfg.load_in_4bit
) and not skip_prepare_model_for_kbit_training:
Expand Down

0 comments on commit e4b51b1

Please sign in to comment.