Commit

improve vram use w gradient checkpointing
winglian committed Jan 22, 2024
1 parent fccb542 commit 6029ced
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/axolotl/utils/config.py
@@ -160,6 +160,13 @@ def normalize_config(cfg):
    if isinstance(cfg.pretraining_dataset, dict):
        cfg.pretraining_dataset = [cfg.pretraining_dataset]

    if (
        cfg.gradient_checkpointing
        and cfg.unfrozen_parameters is None
        and cfg.gradient_checkpointing_kwargs is None
    ):
        cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}

    log_gpu_memory_usage(LOG, "baseline", cfg.device)
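
The condition added in this hunk can be exercised in isolation. The sketch below is not part of the commit: it copies the new `if` block into a hypothetical function, `apply_checkpointing_default`, and uses `types.SimpleNamespace` as a stand-in for the real `cfg` object (an assumption; the actual config class is not shown in this diff). It only illustrates which combinations of settings receive the `{"use_reentrant": True}` default.

```python
from types import SimpleNamespace


def make_cfg(**overrides):
    # Hypothetical helper: builds a stand-in config with the three fields the
    # new condition reads. Unset fields default to None / False.
    base = {
        "gradient_checkpointing": False,
        "unfrozen_parameters": None,
        "gradient_checkpointing_kwargs": None,
    }
    base.update(overrides)
    return SimpleNamespace(**base)


def apply_checkpointing_default(cfg):
    # Same condition as the block added to normalize_config in this commit:
    # only fill in the default when checkpointing is on, no parameters are
    # selectively unfrozen, and the user has not supplied their own kwargs.
    if (
        cfg.gradient_checkpointing
        and cfg.unfrozen_parameters is None
        and cfg.gradient_checkpointing_kwargs is None
    ):
        cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
    return cfg


# Checkpointing enabled, nothing else set: the default is applied.
defaulted = apply_checkpointing_default(make_cfg(gradient_checkpointing=True))

# An explicit user setting is left untouched.
explicit = apply_checkpointing_default(
    make_cfg(
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )
)

# With unfrozen_parameters set, no default is injected.
partial = apply_checkpointing_default(
    make_cfg(gradient_checkpointing=True, unfrozen_parameters=["lm_head"])
)

# Checkpointing disabled: the kwargs stay None.
disabled = apply_checkpointing_default(make_cfg())
```

The default only fills a gap: a user-supplied `gradient_checkpointing_kwargs` always wins, and runs that set `unfrozen_parameters` are left without a default.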

