From 086561326f1a3bd1f2a718726ae274ab46ec55f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20H=C3=A4llqvist?= <35776028+simhallq@users.noreply.github.com> Date: Sun, 14 Jan 2024 18:06:56 +0100 Subject: [PATCH] Enable or disable bf16 support based on availability (#1116) --- src/axolotl/utils/config.py | 8 ++++++++ tests/test_normalize_config.py | 21 +++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index f884903837..b7372c6fb9 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -61,6 +61,14 @@ def normalize_config(cfg): cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))} cfg.batch_size = cfg.batch_size * cfg.world_size + if cfg.bf16 == "auto": + if is_torch_bf16_gpu_available(): + LOG.debug("bf16 support detected, enabling for this configuration.") + cfg.bf16 = True + else: + LOG.debug("bf16 support not detected, disabling for this configuration.") + cfg.bf16 = False + if cfg.device == "mps": cfg.load_in_8bit = False cfg.tf32 = False diff --git a/tests/test_normalize_config.py b/tests/test_normalize_config.py index 004d0068ef..da039f6cd3 100644 --- a/tests/test_normalize_config.py +++ b/tests/test_normalize_config.py @@ -2,6 +2,7 @@ Test classes for checking functionality of the cfg normalization """ import unittest +from unittest.mock import patch from axolotl.utils.config import normalize_cfg_datasets, normalize_config from axolotl.utils.dict import DictDefault @@ -67,3 +68,23 @@ def test_chat_template_chatml(self): assert cfg.datasets[0].conversation == "vicuna_v1.1" assert cfg.datasets[1].conversation == "chatml" + + @patch("axolotl.utils.config.is_torch_bf16_gpu_available") + def test_bf16_auto_setter_available(self, mock_bf16_avail): + cfg = self._get_base_cfg() + cfg.bf16 = "auto" + mock_bf16_avail.return_value = True + + normalize_config(cfg) + + self.assertTrue(cfg.bf16) + + @patch("axolotl.utils.config.is_torch_bf16_gpu_available") + def test_bf16_auto_setter_not_available(self, mock_bf16_avail): + cfg = self._get_base_cfg() + cfg.bf16 = "auto" + mock_bf16_avail.return_value = False + + normalize_config(cfg) + + self.assertFalse(cfg.bf16)