From da265dd79633dcacf9c010590c75b47d3cf2ac0d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 26 Mar 2024 13:46:19 -0700 Subject: [PATCH] fix for accelerate env var for auto bf16, add new base image and expand torch_cuda_arch_list support (#1413) --- .github/workflows/base.yml | 11 ++++++++--- src/axolotl/utils/trainer.py | 6 ++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 381cf21ac1..ea1da66840 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -16,17 +16,22 @@ jobs: cuda_version: 11.8.0 python_version: "3.10" pytorch: 2.1.2 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - cuda: "121" cuda_version: 12.1.0 python_version: "3.10" pytorch: 2.1.2 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - cuda: "121" cuda_version: 12.1.0 python_version: "3.11" pytorch: 2.1.2 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" + - cuda: "121" + cuda_version: 12.1.0 + python_version: "3.11" + pytorch: 2.2.1 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" steps: - name: Checkout uses: actions/checkout@v3 diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index d68681afe3..da9f071c08 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -11,6 +11,7 @@ from accelerate.logging import get_logger from datasets import set_caching_enabled from torch.utils.data import DataLoader, RandomSampler +from transformers.utils import is_torch_bf16_gpu_available from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first @@ -324,6 +325,11 @@ def prepare_optim_env(cfg): os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed + if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True: + os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16" + elif cfg.fp16: + os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" + def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): if cfg.rl in ["dpo", "ipo", "kto_pair"]: