Commit

make sure to set gradient checkpointing too
winglian committed Jan 25, 2024
1 parent 239000f commit 44595f6
Showing 1 changed file with 11 additions and 7 deletions.

src/axolotl/core/trainer_builder.py
@@ -1015,14 +1015,18 @@ def build_training_arguments(self, total_num_steps):
             training_args_kwargs[
                 "dataloader_prefetch_factor"
             ] = self.cfg.dataloader_prefetch_factor
-        if self.cfg.gradient_checkpointing_kwargs is not None:
+        if self.cfg.gradient_checkpointing:
             training_args_kwargs[
-                "gradient_checkpointing_kwargs"
-            ] = self.cfg.gradient_checkpointing_kwargs
-        else:
-            training_args_kwargs["gradient_checkpointing_kwargs"] = {
-                "use_reentrant": False
-            }
+                "gradient_checkpointing"
+            ] = self.cfg.gradient_checkpointing
+            if self.cfg.gradient_checkpointing_kwargs is not None:
+                training_args_kwargs[
+                    "gradient_checkpointing_kwargs"
+                ] = self.cfg.gradient_checkpointing_kwargs
+            else:
+                training_args_kwargs["gradient_checkpointing_kwargs"] = {
+                    "use_reentrant": False
+                }
 
         training_args = TrainingArguments(
             per_device_train_batch_size=self.cfg.micro_batch_size,
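For context, a minimal self-contained sketch of the logic this commit produces, assuming transformers >= 4.35 (where TrainingArguments gained gradient_checkpointing_kwargs); the cfg_* variables below are hypothetical stand-ins for axolotl's self.cfg:

from transformers import TrainingArguments

# Hypothetical stand-ins for axolotl's self.cfg values.
cfg_gradient_checkpointing = True
cfg_gradient_checkpointing_kwargs = None  # or e.g. {"use_reentrant": True}

training_args_kwargs = {}
if cfg_gradient_checkpointing:
    # The point of the commit: pass gradient_checkpointing itself, not just
    # its kwargs; transformers does not enable checkpointing from
    # gradient_checkpointing_kwargs alone.
    training_args_kwargs["gradient_checkpointing"] = cfg_gradient_checkpointing
    if cfg_gradient_checkpointing_kwargs is not None:
        training_args_kwargs[
            "gradient_checkpointing_kwargs"
        ] = cfg_gradient_checkpointing_kwargs
    else:
        # Default to non-reentrant checkpointing, as in the diff above.
        training_args_kwargs["gradient_checkpointing_kwargs"] = {
            "use_reentrant": False
        }

args = TrainingArguments(output_dir="out", **training_args_kwargs)
print(args.gradient_checkpointing)          # True
print(args.gradient_checkpointing_kwargs)   # {'use_reentrant': False}

The use_reentrant=False default selects PyTorch's non-reentrant checkpointing implementation, which PyTorch recommends over the older reentrant variant.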
