From c7a9e76f8c00c63258321c62e03a82b5fe5de07d Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sat, 16 Mar 2024 14:11:38 +0900
Subject: [PATCH 1/4] fix(config): change default use_reentrant to true

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 826c16045c..4d17f2efda 100644
--- a/README.md
+++ b/README.md
@@ -843,7 +843,7 @@ group_by_length: false
 gradient_checkpointing: false
 # additional kwargs to pass to the trainer for gradient checkpointing
 # gradient_checkpointing_kwargs:
-#   use_reentrant: false
+#   use_reentrant: true
 
 # Stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback

From 34211ff6a5be1b1e0f58774f403eefe4ef8983ef Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sat, 16 Mar 2024 14:18:34 +0900
Subject: [PATCH 2/4] Update trainer_builder.py

---
 src/axolotl/core/trainer_builder.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index d11f0c6532..7c5aa00af6 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -837,10 +837,6 @@ def build(self, total_num_steps):
                 training_arguments_kwargs[
                     "gradient_checkpointing_kwargs"
                 ] = self.cfg.gradient_checkpointing_kwargs
-            else:
-                training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
-                    "use_reentrant": False
-                }
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config:

From a914cb37dc455a3fd0368e3a0898867f25b3a6c9 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sat, 16 Mar 2024 14:20:22 +0900
Subject: [PATCH 3/4] fix: make sure to pass kwargs to enable checkpoint

---
 src/axolotl/utils/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index fce7b20a7a..29461e0e4d 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -888,7 +888,7 @@ def load_and_quantize_parallel(name_param, model, **kwargs):
 
     if cfg.adapter in ["lora", "qlora"]:
         if cfg.gradient_checkpointing:
-            model.gradient_checkpointing_enable()
+            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs)
         if (
             cfg.load_in_8bit or cfg.load_in_4bit
         ) and not skip_prepare_model_for_kbit_training:

From c9fe32761ec8d02eb80743a850000efca07fa4f3 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sat, 16 Mar 2024 14:26:28 +0900
Subject: [PATCH 4/4] chore: lint

---
 src/axolotl/utils/models.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 29461e0e4d..40090a07c0 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -888,7 +888,9 @@ def load_and_quantize_parallel(name_param, model, **kwargs):
 
     if cfg.adapter in ["lora", "qlora"]:
         if cfg.gradient_checkpointing:
-            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs)
+            model.gradient_checkpointing_enable(
+                gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
+            )
         if (
             cfg.load_in_8bit or cfg.load_in_4bit
         ) and not skip_prepare_model_for_kbit_training:
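
Illustrative config sketch, not part of the patch series: with the hard-coded fallback removed
from the trainer builder (patch 2/4), gradient checkpointing kwargs are only forwarded when the
user config sets them, so a config that pins the documented default explicitly would use the keys
shown in the README excerpt above, e.g.

    # enable gradient checkpointing and pass kwargs through to the trainer
    gradient_checkpointing: true
    gradient_checkpointing_kwargs:
      use_reentrant: true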