Skip to content

Commit

Permalink
fix for accelerate env var for auto bf16, add new base image and expa…
Browse files Browse the repository at this point in the history
…nd torch_cuda_arch_list support (axolotl-ai-cloud#1413)
  • Loading branch information
winglian authored Mar 26, 2024
1 parent e07347b commit da265dd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
11 changes: 8 additions & 3 deletions .github/workflows/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@ jobs:
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down
6 changes: 6 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from accelerate.logging import get_logger
from datasets import set_caching_enabled
from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
Expand Down Expand Up @@ -324,6 +325,11 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed

if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"


def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair"]:
Expand Down

0 comments on commit da265dd

Please sign in to comment.