diff --git a/README.md b/README.md
index 98b8a78239..4dd80339a4 100644
--- a/README.md
+++ b/README.md
@@ -741,6 +741,9 @@ group_by_length: false
 
 # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
 gradient_checkpointing: false
+# additional kwargs to pass to the trainer for gradient checkpointing
+# gradient_checkpointing_kwargs:
+#   use_reentrant: false
 
 # Stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index fed26de464..4ca2877d19 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -566,6 +566,14 @@ def build(self, total_num_steps):
         training_arguments_kwargs[
             "gradient_checkpointing"
         ] = self.cfg.gradient_checkpointing
+        if self.cfg.gradient_checkpointing_kwargs:
+            training_arguments_kwargs[
+                "gradient_checkpointing_kwargs"
+            ] = self.cfg.gradient_checkpointing_kwargs
+        else:
+            training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
+                "use_reentrant": False
+            }
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config:
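
For context, a minimal sketch of what these kwargs amount to once they reach the Hugging Face trainer, assuming a transformers release new enough to expose `gradient_checkpointing_kwargs` on `TrainingArguments` (4.35+). The `user_gc_kwargs` name is hypothetical; the fallback mirrors the `else` branch added in `trainer_builder.py`.

```python
from transformers import TrainingArguments

# Hypothetical stand-in for cfg.gradient_checkpointing_kwargs; None means the
# option was left out of the YAML config.
user_gc_kwargs = None

args = TrainingArguments(
    output_dir="./out",
    gradient_checkpointing=True,
    # Default to non-reentrant activation checkpointing when the config does
    # not supply kwargs, matching the new else branch in trainer_builder.py.
    gradient_checkpointing_kwargs=user_gc_kwargs or {"use_reentrant": False},
)
print(args.gradient_checkpointing_kwargs)  # {'use_reentrant': False}
```

Non-reentrant checkpointing (`use_reentrant=False`) is the variant PyTorch recommends going forward, which is presumably why it serves as the fallback here.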