From 6029ced54c9dbc346d47821fad035683f8593d2c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 17:00:08 -0500 Subject: [PATCH] improve vram use w gradient checkpointing --- src/axolotl/utils/config.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index ca7d037ddc..c88d952c8c 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -160,6 +160,13 @@ def normalize_config(cfg): if isinstance(cfg.pretraining_dataset, dict): cfg.pretraining_dataset = [cfg.pretraining_dataset] + if ( + cfg.gradient_checkpointing + and cfg.unfrozen_parameters is None + and cfg.gradient_checkpointing_kwargs is None + ): + cfg.gradient_checkpointing_kwargs = {"use_reentrant": True} + log_gpu_memory_usage(LOG, "baseline", cfg.device)