From da97285e63811c17ce1e92b2c32c26c9ed8e2d5d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 12 Jan 2024 14:58:21 -0500 Subject: [PATCH] keep gate in fp32 for 16 bit loras (#1105) * keep gate in fp32 for loras * add e2e check for lora w/o flash attention for mixtral to check gate * add checks for gate in fp32 for mixtral, add typehints to train outputs * mixtral doesn't support basic lora :facepalm: add lora tests @ 16bit and fix gate layer check fix the parameter name, was using the old disco name don't lora over the gate so we can check that is in fp32 fix dtype check * ensure we're using fp16/bf16 for 16bit and qlora is always going to be in uint8 --- src/axolotl/train.py | 6 +- src/axolotl/utils/models.py | 2 +- tests/e2e/test_mixtral.py | 191 +++++++++++++++++++++++++++++++++++- 3 files changed, 191 insertions(+), 8 deletions(-) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 36e598de43..d68ae46b14 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -5,14 +5,16 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import Optional +from typing import Optional, Tuple, Union import torch import transformers.modelcard from accelerate.logging import get_logger from datasets import Dataset from optimum.bettertransformer import BetterTransformer +from peft import PeftModel from pkg_resources import get_distribution # type: ignore +from transformers import PreTrainedModel, PreTrainedTokenizer from transformers.deepspeed import is_deepspeed_zero3_enabled from axolotl.common.cli import TrainerCliArgs @@ -43,7 +45,7 @@ class TrainDatasetMeta: def train( *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta -): +) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: # load the tokenizer first LOG.debug( f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 41c43a0290..66f6e16acf 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -590,7 +590,7 @@ def load_model( # make sure these are fp32 per Ramesh et al. (2021) embedding_modules = get_linear_embedding_layers(cfg.model_config_type) for name, module in model.named_modules(): - if "norm" in name: + if any(m in name for m in ["norm", "gate"]): module.to(torch.float32) if model_config.model_type == "btlm": # don't upcast lm_head for btlm diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py index 896cc74d0f..ee6f06d875 100644 --- a/tests/e2e/test_mixtral.py +++ b/tests/e2e/test_mixtral.py @@ -7,6 +7,7 @@ import unittest from pathlib import Path +import torch from transformers.utils import is_torch_bf16_gpu_available from axolotl.cli import load_datasets @@ -27,7 +28,7 @@ class TestMixtral(unittest.TestCase): """ @with_temp_dir - def test_qlora(self, temp_dir): + def test_qlora_w_fa2(self, temp_dir): # pylint: disable=duplicate-code cfg = DictDefault( { @@ -37,10 +38,18 @@ def test_qlora(self, temp_dir): "sequence_len": 1024, "load_in_4bit": True, "adapter": "qlora", - "lora_r": 16, - "lora_alpha": 32, + "lora_r": 4, + "lora_alpha": 8, "lora_dropout": 0.1, - "lora_target_linear": True, + "lora_target_modules": [ + "o_proj", + "w3", + "k_proj", + "v_proj", + "w1", + "q_proj", + "w2", + ], "val_set_size": 0.1, "special_tokens": {}, "datasets": [ @@ -65,7 +74,179 @@ def test_qlora(self, temp_dir): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert ( + model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype + == torch.uint8 + ) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_qlora_wo_fa2(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": False, + "sequence_len": 1024, + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 4, + "lora_alpha": 8, + "lora_dropout": 0.1, + "lora_target_modules": [ + "o_proj", + "w3", + "k_proj", + "v_proj", + "w1", + "q_proj", + "w2", + ], + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert ( + model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype + == torch.uint8 + ) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_16bit_lora_w_fa2(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": True, + "sequence_len": 1024, + "adapter": "lora", + "lora_r": 4, + "lora_alpha": 8, + "lora_dropout": 0.1, + "lora_target_modules": [ + "o_proj", + "w3", + "k_proj", + "v_proj", + "w1", + "q_proj", + "w2", + ], + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert ( + model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype + == torch.float32 + ) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_16bit_lora_wo_fa2(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "hf-internal-testing/Mixtral-tiny", + "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1", + "flash_attention": False, + "sequence_len": 1024, + "adapter": "lora", + "lora_r": 4, + "lora_alpha": 8, + "lora_dropout": 0.1, + "lora_target_modules": [ + "o_proj", + "w3", + "k_proj", + "v_proj", + "w1", + "q_proj", + "w2", + ], + "val_set_size": 0.1, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + normalize_config(cfg) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert ( + model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype + == torch.float32 + ) assert (Path(temp_dir) / "adapter_model.bin").exists() @with_temp_dir