diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 36e598de43..d68ae46b14 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -5,14 +5,16 @@
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
 import transformers.modelcard
 from accelerate.logging import get_logger
 from datasets import Dataset
 from optimum.bettertransformer import BetterTransformer
+from peft import PeftModel
 from pkg_resources import get_distribution  # type: ignore
+from transformers import PreTrainedModel, PreTrainedTokenizer
 from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.common.cli import TrainerCliArgs
@@ -43,7 +45,7 @@ class TrainDatasetMeta:
 
 def train(
     *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
-):
+) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
     # load the tokenizer first
     LOG.debug(
         f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 41c43a0290..66f6e16acf 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -590,7 +590,7 @@ def load_model(
     # make sure these are fp32 per Ramesh et al. (2021)
     embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
     for name, module in model.named_modules():
-        if "norm" in name:
+        if any(m in name for m in ["norm", "gate"]):
             module.to(torch.float32)
         if model_config.model_type == "btlm":
             # don't upcast lm_head for btlm
diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py
index 896cc74d0f..ee6f06d875 100644
--- a/tests/e2e/test_mixtral.py
+++ b/tests/e2e/test_mixtral.py
@@ -7,6 +7,7 @@
 import unittest
 from pathlib import Path
 
+import torch
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.cli import load_datasets
@@ -27,7 +28,7 @@ class TestMixtral(unittest.TestCase):
     """
 
     @with_temp_dir
-    def test_qlora(self, temp_dir):
+    def test_qlora_w_fa2(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -37,10 +38,18 @@
                 "sequence_len": 1024,
                 "load_in_4bit": True,
                 "adapter": "qlora",
-                "lora_r": 16,
-                "lora_alpha": 32,
+                "lora_r": 4,
+                "lora_alpha": 8,
                 "lora_dropout": 0.1,
-                "lora_target_linear": True,
+                "lora_target_modules": [
+                    "o_proj",
+                    "w3",
+                    "k_proj",
+                    "v_proj",
+                    "w1",
+                    "q_proj",
+                    "w2",
+                ],
                 "val_set_size": 0.1,
                 "special_tokens": {},
                 "datasets": [
@@ -65,7 +74,179 @@
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
+            == torch.uint8
+        )
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_qlora_wo_fa2(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": False,
+                "sequence_len": 1024,
+                "load_in_4bit": True,
+                "adapter": "qlora",
+                "lora_r": 4,
+                "lora_alpha": 8,
+                "lora_dropout": 0.1,
+                "lora_target_modules": [
+                    "o_proj",
+                    "w3",
+                    "k_proj",
+                    "v_proj",
+                    "w1",
+                    "q_proj",
+                    "w2",
+                ],
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
+            == torch.uint8
+        )
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_16bit_lora_w_fa2(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "adapter": "lora",
+                "lora_r": 4,
+                "lora_alpha": 8,
+                "lora_dropout": 0.1,
+                "lora_target_modules": [
+                    "o_proj",
+                    "w3",
+                    "k_proj",
+                    "v_proj",
+                    "w1",
+                    "q_proj",
+                    "w2",
+                ],
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
+            == torch.float32
+        )
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_16bit_lora_wo_fa2(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": False,
+                "sequence_len": 1024,
+                "adapter": "lora",
+                "lora_r": 4,
+                "lora_alpha": 8,
+                "lora_dropout": 0.1,
+                "lora_target_modules": [
+                    "o_proj",
+                    "w3",
+                    "k_proj",
+                    "v_proj",
+                    "w1",
+                    "q_proj",
+                    "w2",
+                ],
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
+            == torch.float32
+        )
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
     @with_temp_dir