From 03e59077a008715099ea35ebf3fb0a9358f951e8 Mon Sep 17 00:00:00 2001
From: Wing Lian <wing.lian@gmail.com>
Date: Thu, 21 Sep 2023 21:52:12 -0400
Subject: [PATCH] misc fixes to add gptq tests (#621)

* misc fixes to add gptq tests

  - bench.py: treat any CPU device (e.g. "cpu:0") as CPU when skipping the
    GPU memory helpers
  - models.py: load GPTQ-quantized base models without the load_in_8bit /
    load_in_4bit flags, and include peft's QuantLinear when discovering
    modules for lora_target_linear
  - trainer.py: only check that save_steps is a multiple of eval_steps when
    eval_steps is actually set
  - test_lora_llama.py: add an e2e LoRA GPTQ test and assert that
    adapter_model.bin is written for the LoRA runs
  - test_phi.py: drop 8-bit loading, shorten sequence_len, and switch the
    optimizer to adamw_bnb_8bit

* set bf16, needed for fa2 (flash-attention 2 only supports fp16/bf16)
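
For context, a minimal sketch of why the bench.py guard normalizes the device
instead of comparing raw strings; this is illustrative only (not part of the
patch) and assumes a recent PyTorch:

    import torch

    # The previous check compared the value against the literal string "cpu",
    # which misses indexed CPU devices; torch.device(...).type normalizes them.
    for dev in ("cpu", "cpu:0"):
        print(dev, dev == "cpu", torch.device(dev).type == "cpu")
    # cpu True True
    # cpu:0 False True

Both entries are CPU devices, but only the normalized check recognizes
"cpu:0", which is what the updated guard relies on.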
---
 src/axolotl/utils/bench.py   |  6 +++-
 src/axolotl/utils/models.py  | 35 ++++++++++++++--------
 src/axolotl/utils/trainer.py |  1 +
 tests/e2e/test_lora_llama.py | 58 ++++++++++++++++++++++++++++++++++--
 tests/e2e/test_phi.py        | 14 +++++----
 5 files changed, 93 insertions(+), 21 deletions(-)

diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index 685be526f0..40be0d9ac8 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -19,7 +19,11 @@ def deco(func):
         def wrapper(*args, **kwargs):
             device = kwargs.get("device", args[0] if args else None)
 
-            if not torch.cuda.is_available() or device == "auto" or device == "cpu":
+            if (
+                not torch.cuda.is_available()
+                or device == "auto"
+                or torch.device(device).type == "cpu"
+            ):
                 return default_value
 
             return func(*args, **kwargs)
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index a349776d77..543a0e1a13 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -10,6 +10,7 @@
 import transformers
 from optimum.bettertransformer import BetterTransformer
 from peft import PeftConfig, prepare_model_for_kbit_training
+from peft.tuners.lora import QuantLinear
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -309,16 +310,26 @@ def load_model(
             ):
                 config.max_sequence_length = cfg.sequence_len
                 LOG.warning(f"increasing context length to {cfg.sequence_len}")
-            model = AutoModelForCausalLM.from_pretrained(
-                base_model,
-                config=config,
-                device_map=cfg.device_map,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=cfg.torch_dtype,
-                trust_remote_code=cfg.trust_remote_code or False,
-                **model_kwargs,
-            )
+            if cfg.gptq:
+                model = AutoModelForCausalLM.from_pretrained(
+                    base_model,
+                    config=config,
+                    device_map=cfg.device_map,
+                    torch_dtype=cfg.torch_dtype,
+                    trust_remote_code=cfg.trust_remote_code or False,
+                    **model_kwargs,
+                )
+            else:
+                model = AutoModelForCausalLM.from_pretrained(
+                    base_model,
+                    config=config,
+                    device_map=cfg.device_map,
+                    load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                    load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+                    torch_dtype=cfg.torch_dtype,
+                    trust_remote_code=cfg.trust_remote_code or False,
+                    **model_kwargs,
+                )
     except Exception as err:  # pylint: disable=broad-exception-caught
         LOG.error(
             "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
@@ -466,10 +477,10 @@ def load_llama_adapter(model, cfg):
 
 
 def find_all_linear_names(model):
-    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
+    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
     lora_module_names = set()
     for name, module in model.named_modules():
-        if isinstance(module, cls):
+        if isinstance(module, cls) or "Linear" in module.__class__.__name__:
             names = name.split(".")
             lora_module_names.add(names[0] if len(names) == 1 else names[-1])
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 944ac5f511..a4ec1553ef 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -676,6 +676,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             (cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
             and cfg.val_set_size > 0
             and cfg.save_steps
+            and cfg.eval_steps
             and cfg.save_steps % cfg.eval_steps == 0
         )
         or False,
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index fbca33633e..7d4b75cceb 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -6,6 +6,7 @@
 import os
 import tempfile
 import unittest
+from pathlib import Path
 
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -24,6 +25,7 @@ class TestLoraLlama(unittest.TestCase):
 
     def test_lora(self):
         # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
@@ -51,7 +53,7 @@ def test_lora(self):
                 "num_epochs": 2,
                 "micro_batch_size": 8,
                 "gradient_accumulation_steps": 1,
-                "output_dir": tempfile.mkdtemp(),
+                "output_dir": output_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
@@ -62,9 +64,11 @@ def test_lora(self):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "adapter_model.bin").exists()
 
     def test_lora_packing(self):
         # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
@@ -94,7 +98,7 @@ def test_lora_packing(self):
                 "num_epochs": 2,
                 "micro_batch_size": 8,
                 "gradient_accumulation_steps": 1,
-                "output_dir": tempfile.mkdtemp(),
+                "output_dir": output_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
@@ -105,3 +109,53 @@ def test_lora_packing(self):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "adapter_model.bin").exists()
+
+    def test_lora_gptq(self):
+        # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
+        cfg = DictDefault(
+            {
+                "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
+                "base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
+                "model_type": "AutoModelForCausalLM",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flash_attention": True,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "gptq": True,
+                "gptq_disable_exllama": True,
+                "lora_r": 32,
+                "lora_alpha": 64,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "save_steps": 0.5,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": output_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "adapter_model.bin").exists()
diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py
index fb8aa5d875..a84ef0778c 100644
--- a/tests/e2e/test_phi.py
+++ b/tests/e2e/test_phi.py
@@ -31,9 +31,9 @@ def test_ft(self):
                 "trust_remote_code": True,
                 "model_type": "MixFormerSequentialForCausalLM",
                 "tokenizer_type": "AutoTokenizer",
-                "sequence_len": 2048,
+                "sequence_len": 512,
                 "sample_packing": False,
-                "load_in_8bit": True,
+                "load_in_8bit": False,
                 "adapter": None,
                 "val_set_size": 0.1,
                 "special_tokens": {
@@ -55,8 +55,9 @@ def test_ft(self):
                 "gradient_accumulation_steps": 1,
                 "output_dir": tempfile.mkdtemp(),
                 "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
+                "optimizer": "adamw_bnb_8bit",
                 "lr_scheduler": "cosine",
+                "bf16": True,
             }
         )
         normalize_config(cfg)
@@ -74,9 +75,9 @@ def test_ft_packed(self):
                 "trust_remote_code": True,
                 "model_type": "MixFormerSequentialForCausalLM",
                 "tokenizer_type": "AutoTokenizer",
-                "sequence_len": 2048,
+                "sequence_len": 512,
                 "sample_packing": True,
-                "load_in_8bit": True,
+                "load_in_8bit": False,
                 "adapter": None,
                 "val_set_size": 0.1,
                 "special_tokens": {
@@ -98,8 +99,9 @@ def test_ft_packed(self):
                 "gradient_accumulation_steps": 1,
                 "output_dir": tempfile.mkdtemp(),
                 "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
+                "optimizer": "adamw_bnb_8bit",
                 "lr_scheduler": "cosine",
+                "bf16": True,
             }
         )
         normalize_config(cfg)