
Commit

use recommended setting for use_reentrant w gradient checkpointing (#1021)

* use recommended setting for use_reentrant w gradient checkpointing

* add doc for gradient_checkpointing_kwargs
winglian authored Jan 2, 2024
1 parent 3678a6c commit 4d2e842
Showing 2 changed files with 11 additions and 0 deletions.
README.md: 3 additions & 0 deletions
@@ -741,6 +741,9 @@ group_by_length: false

 # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
 gradient_checkpointing: false
+# additional kwargs to pass to the trainer for gradient checkpointing
+# gradient_checkpointing_kwargs:
+#   use_reentrant: false
 
 # Stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
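
The YAML options added above map onto PyTorch's activation checkpointing API. As a minimal, illustrative sketch (not part of this commit), the use_reentrant flag is the same one accepted by torch.utils.checkpoint.checkpoint, where the non-reentrant variant is the one PyTorch recommends:

# Illustrative only, not axolotl code: shows the PyTorch-level flag that the
# YAML option above ultimately controls.
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# Non-reentrant checkpointing: activations are recomputed during backward
# without the reentrant autograd machinery. This is the value the commit
# defaults to when no kwargs are supplied.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()

The reentrant implementation has documented limitations (for example, it produces no gradients when none of the checkpointed inputs require grad), which is why PyTorch recommends use_reentrant=False and why this commit makes it the default.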
src/axolotl/core/trainer_builder.py: 8 additions & 0 deletions
@@ -566,6 +566,14 @@ def build(self, total_num_steps):
         training_arguments_kwargs[
             "gradient_checkpointing"
         ] = self.cfg.gradient_checkpointing
+        if self.cfg.gradient_checkpointing_kwargs:
+            training_arguments_kwargs[
+                "gradient_checkpointing_kwargs"
+            ] = self.cfg.gradient_checkpointing_kwargs
+        else:
+            training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
+                "use_reentrant": False
+            }
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config:
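
Downstream, these kwargs are consumed by the Hugging Face Trainer. A rough sketch of the effect, assuming a transformers release whose TrainingArguments accepts gradient_checkpointing_kwargs (4.35 or newer); the output_dir below is a hypothetical placeholder and this is not axolotl's actual call site:

# Sketch of the downstream effect; assumes transformers >= 4.35.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",  # hypothetical path, required by TrainingArguments
    gradient_checkpointing=True,
    # With no user-supplied kwargs, the change above defaults to the
    # non-reentrant implementation recommended by PyTorch.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

# The Trainer later applies these to the model roughly like this:
# model.gradient_checkpointing_enable(
#     gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs
# )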
