support for fp8
winglian committed Nov 10, 2023
1 parent 105d0b3 · commit ac49990
Showing 3 changed files with 9 additions and 1 deletion.
src/axolotl/core/trainer_builder.py (4 additions, 0 deletions)
@@ -483,6 +483,10 @@ def build(self, total_num_steps):
         training_arguments_kwargs["fp16"] = (
             self.cfg.fp16 and not self.cfg.bf16
         ) or False
+        if self.cfg.fp8:
+            training_arguments_kwargs["fp16"] = False
+            training_arguments_kwargs["bf16"] = False
+
         training_arguments_kwargs["tf32"] = self.cfg.tf32
         training_arguments_kwargs["warmup_steps"] = warmup_steps
         training_arguments_kwargs["logging_steps"] = logging_steps
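
Taken together, these lines make `fp8` take precedence over the half-precision flags: when fp8 is requested, neither `fp16` nor `bf16` is passed to `TrainingArguments`, leaving mixed precision entirely to Accelerate (via the environment variable set in `utils/trainer.py` below). A minimal standalone sketch of the resulting precedence — the `Cfg` dataclass here is hypothetical, standing in for axolotl's config object:

```python
from dataclasses import dataclass

@dataclass
class Cfg:  # hypothetical stand-in for axolotl's config mapping
    fp16: bool = False
    bf16: bool = False
    fp8: bool = False

def precision_kwargs(cfg: Cfg) -> dict:
    # Mirrors the hunk above: fp16 only when bf16 is unset, and fp8
    # forces both half-precision TrainingArguments flags off.
    kwargs = {"fp16": (cfg.fp16 and not cfg.bf16) or False, "bf16": cfg.bf16}
    if cfg.fp8:
        kwargs["fp16"] = False
        kwargs["bf16"] = False
    return kwargs

assert precision_kwargs(Cfg(fp16=True, fp8=True)) == {"fp16": False, "bf16": False}
```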
src/axolotl/utils/config.py (3 additions, 1 deletion)
@@ -70,7 +70,9 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
 
-    if cfg.bf16 or cfg.bfloat16:
+    if cfg.fp8:
+        cfg.torch_dtype = torch.bfloat16
+    elif cfg.bf16 or cfg.bfloat16:
         cfg.torch_dtype = torch.bfloat16
     elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
         cfg.torch_dtype = torch.float16
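
Note that `fp8` does not change the load dtype to an 8-bit type: the model's master weights are still materialized in `torch.bfloat16`, and the fp8 cast is applied per-operation by the mixed-precision backend at train time. A sketch of the resolution order after this change (a hypothetical standalone version of the branch in `normalize_config`):

```python
import torch

def resolve_torch_dtype(cfg) -> torch.dtype:
    # fp8 mixed precision keeps master weights in bfloat16; the 8-bit
    # cast happens per-op during training, not when the model is loaded.
    if cfg.fp8:
        return torch.bfloat16
    if cfg.bf16 or cfg.bfloat16:
        return torch.bfloat16
    if cfg.load_in_8bit or cfg.fp16 or cfg.float16:
        return torch.float16
    return torch.float32  # assumed fallback, matching the usual default
```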
src/axolotl/utils/trainer.py (2 additions, 0 deletions)
@@ -268,6 +268,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
         setup_fsdp_envs(cfg)
     elif cfg.deepspeed:
         os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+    if cfg.fp8:
+        os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
 
     trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
     trainer_builder.train_dataset = train_dataset
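
`ACCELERATE_MIXED_PRECISION` is read by Hugging Face Accelerate when the `Accelerator` is constructed, so exporting it before the trainer spins up is equivalent to passing `--mixed_precision fp8` to `accelerate launch`. A small sketch of that behavior (actually training in fp8 also needs a supported backend such as TransformerEngine and hardware that supports it, e.g. H100-class GPUs):

```python
import os

# Must be set before the Accelerator is created; construction may fail
# if no fp8 backend (e.g. transformer-engine) is installed.
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"

from accelerate import Accelerator

accelerator = Accelerator()         # picks up fp8 from the environment
print(accelerator.mixed_precision)  # -> "fp8"
```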
