From a914cb37dc455a3fd0368e3a0898867f25b3a6c9 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sat, 16 Mar 2024 14:20:22 +0900
Subject: [PATCH] fix: make sure to pass kwargs to enable checkpoint

---
 src/axolotl/utils/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index fce7b20a7a..29461e0e4d 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -888,7 +888,7 @@ def load_and_quantize_parallel(name_param, model, **kwargs):
 
     if cfg.adapter in ["lora", "qlora"]:
         if cfg.gradient_checkpointing:
-            model.gradient_checkpointing_enable()
+            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs)
         if (
             cfg.load_in_8bit or cfg.load_in_4bit
         ) and not skip_prepare_model_for_kbit_training:
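
Note (not part of the patch): the changed line forwards the user-supplied gradient_checkpointing_kwargs from the axolotl config to Hugging Face Transformers' PreTrainedModel.gradient_checkpointing_enable, instead of dropping them. A minimal sketch of the effect, assuming a Transformers version that accepts the gradient_checkpointing_kwargs argument; the gpt2 checkpoint and the use_reentrant setting are illustrative only:

# Sketch only: shows what the patched call amounts to when the axolotl config sets
#   gradient_checkpointing: true
#   gradient_checkpointing_kwargs:
#     use_reentrant: false
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model

# Before the patch: kwargs from the config were silently ignored.
# model.gradient_checkpointing_enable()

# After the patch: the kwargs are forwarded to Transformers, which passes them
# on to torch.utils.checkpoint when checkpointing activations.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)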