enable gptqmodel tests
Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Dec 4, 2024
1 parent 1bad53e commit ea29c3c
Showing 2 changed files with 77 additions and 49 deletions.
9 changes: 6 additions & 3 deletions optimum/utils/testing_utils.py
@@ -30,6 +30,7 @@
is_auto_gptq_available,
is_datasets_available,
is_diffusers_available,
is_gptqmodel_available,
is_sentence_transformers_available,
is_timm_available,
)
@@ -60,11 +61,13 @@ def require_accelerate(test_case):
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)


def require_auto_gptq(test_case):
def require_gptq(test_case):
"""
Decorator marking a test that requires auto-gptq. These tests are skipped when auto-gptq isn't installed.
Decorator marking a test that requires gptqmodel or auto-gptq. These tests are skipped when gptqmodel and auto-gptq are not installed.
"""
return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)
return unittest.skipUnless(is_auto_gptq_available() or is_gptqmodel_available(), "test requires auto-gptq")(
test_case
)


def require_torch_gpu(test_case):
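As a minimal usage sketch (assuming only what this diff adds: the `require_gptq` decorator in optimum.utils.testing_utils), a test that needs either GPTQ backend would be marked as below; the test class and method names are hypothetical, not part of the commit:

import unittest

from optimum.utils.testing_utils import require_gptq


@require_gptq  # skipped when neither auto-gptq nor gptqmodel is installed
class HypotheticalGPTQBackendTest(unittest.TestCase):
    def test_backend_importable(self):
        # If this test runs at all, at least one backend can be imported.
        try:
            import gptqmodel  # noqa: F401
        except ImportError:
            import auto_gptq  # noqa: F401
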
117 changes: 71 additions & 46 deletions tests/gptq/test_quantization.py
@@ -26,28 +26,31 @@
from optimum.gptq.eval import evaluate_perplexity
from optimum.gptq.utils import get_block_name_with_pattern, get_preceding_modules, get_seqlen
from optimum.utils import recurse_getattr
from optimum.utils.import_utils import is_accelerate_available, is_auto_gptq_available
from optimum.utils.testing_utils import require_auto_gptq, require_torch_gpu
from optimum.utils.import_utils import is_accelerate_available, is_auto_gptq_available, is_gptqmodel_available
from optimum.utils.testing_utils import require_gptq, require_torch_gpu


if is_auto_gptq_available():
from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear

if is_gptqmodel_available():
from gptqmodel import GPTQModel
from gptqmodel.utils.importer import hf_select_quant_linear

if is_accelerate_available():
from accelerate import init_empty_weights


@slow
@require_auto_gptq
@require_torch_gpu
@require_gptq
class GPTQTest(unittest.TestCase):
model_name = "bigscience/bloom-560m"
model_name = "Felladrin/Llama-68M-Chat-v1"

expected_fp16_perplexity = 30
expected_quantized_perplexity = 34

expected_compression_ratio = 1.66
expected_compression_ratio = 1.2577

bits = 4
group_size = 128
@@ -56,8 +59,8 @@ class GPTQTest(unittest.TestCase):
exllama_config = None
cache_block_outputs = True
modules_in_block_to_quantize = None
device_map_for_quantization = "cuda"
device_for_inference = 0
device_map_for_quantization = "cpu"
device_for_inference = "cpu"
dataset = [
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
]
@@ -104,33 +107,36 @@ def test_memory_footprint(self):
self.assertAlmostEqual(self.fp16_mem / self.quantized_mem, self.expected_compression_ratio, places=2)

def test_perplexity(self):
"""
A simple test to check if the model conversion has been done correctly by checking on the
the perplexity of the converted models
"""

self.assertEqual(int(self.fp16_ppl), self.expected_fp16_perplexity)
self.assertEqual(int(self.quantized_ppl), self.expected_quantized_perplexity)
pass

def test_quantized_layers_class(self):
"""
A simple test to check if the model conversion has been done correctly by checking on the
the class type of the linear layers of the converted models
"""

QuantLinear = dynamically_import_QuantLinear(
use_triton=False,
use_qigen=False,
desc_act=self.desc_act,
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama or self.exllama_config["version"] != 1,
disable_exllamav2=self.disable_exllama or self.exllama_config["version"] != 2,
)
self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear)
if is_gptqmodel_available():
QuantLinear = hf_select_quant_linear(
bits=self.bits,
group_size=self.group_size,
desc_act=self.desc_act,
sym=True,
device_map=self.device_map_for_quantization,
pack=False,
)
else:
QuantLinear = hf_select_quant_linear(
use_triton=False,
desc_act=self.desc_act,
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama or self.exllama_config["version"] != 1,
disable_exllamav2=self.disable_exllama or self.exllama_config["version"] != 2,
)
self.assertTrue(self.quantized_model.model.layers[0].mlp.gate_proj.__class__ == QuantLinear)

def check_quantized_layers_type(self, model, value):
self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.QUANT_TYPE == value)
self.assertTrue(model.model.layers[0].mlp.gate_proj.QUANT_TYPE == value)

def test_serialization(self):
"""
@@ -152,31 +158,45 @@ def test_serialization(self):
disable_exllama=self.disable_exllama,
exllama_config=self.exllama_config,
)
if self.disable_exllama:
self.check_quantized_layers_type(quantized_model_from_saved, "cuda-old")
else:
self.check_quantized_layers_type(quantized_model_from_saved, "exllama")
# Only auto-gptq need to check the quant type
if is_auto_gptq_available() and not is_gptqmodel_available():
if self.disable_exllama:
self.check_quantized_layers_type(quantized_model_from_saved, "cuda-old")
else:
self.check_quantized_layers_type(quantized_model_from_saved, "exllama")

# transformers and auto-gptq compatibility
# quantized models are more compatible with device map than
# device context managers (they're never used in transformers testing suite)
_ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference})
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})
if is_gptqmodel_available():
_ = GPTQModel.load(tmpdirname, device_map={"": self.device_for_inference})
else:
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})


class GPTQTestCPUInit(GPTQTest):
device_map_for_quantization = "cpu"
@require_torch_gpu
class GPTQTestCUDA(GPTQTest):
device_map_for_quantization = "cuda"
device_for_inference = 0
expected_compression_ratio = 1.66

def test_perplexity(self):
pass
"""
A simple test to check if the model conversion has been done correctly by checking on the
the perplexity of the converted models
"""

self.assertEqual(int(self.fp16_ppl), self.expected_fp16_perplexity)
self.assertEqual(int(self.quantized_ppl), self.expected_quantized_perplexity)


class GPTQTestExllama(GPTQTest):
class GPTQTestExllama(GPTQTestCUDA):
disable_exllama = False
exllama_config = {"version": 1}


class GPTQTestActOrder(GPTQTest):
class GPTQTestActOrder(GPTQTestCUDA):
disable_exllama = True
desc_act = True

@@ -209,7 +229,10 @@ def test_exllama_serialization(self):
# quantized models are more compatible with device map than
# device context managers (they're never used in transformers testing suite)
_ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference})
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})
if is_gptqmodel_available():
_ = GPTQModel.load(tmpdirname, device_map={"": self.device_for_inference})
else:
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})

def test_exllama_max_input_length(self):
"""
@@ -246,7 +269,7 @@ def test_exllama_max_input_length(self):
quantized_model_from_saved.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)


class GPTQTestExllamav2(GPTQTest):
class GPTQTestExllamav2(GPTQTestCUDA):
desc_act = False
disable_exllama = True
exllama_config = {"version": 2}
@@ -279,25 +302,27 @@ def test_exllama_serialization(self):
# quantized models are more compatible with device map than
# device context managers (they're never used in transformers testing suite)
_ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference})
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})
if is_gptqmodel_available():
_ = GPTQModel.load(tmpdirname, device_map={"": self.device_for_inference})
else:
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})


class GPTQTestNoBlockCaching(GPTQTest):
class GPTQTestNoBlockCaching(GPTQTestCUDA):
cache_block_outputs = False


class GPTQTestModuleQuant(GPTQTest):
class GPTQTestModuleQuant(GPTQTestCUDA):
# all layers are quantized apart from self_attention.dense
modules_in_block_to_quantize = [
["self_attention.query_key_value"],
["mlp.dense_h_to_4h"],
["mlp.dense_4h_to_h"],
["self_attn.q_proj"],
["mlp.gate_proj"],
]
expected_compression_ratio = 1.577

def test_not_converted_layers(self):
# self_attention.dense should not be converted
self.assertTrue(self.quantized_model.transformer.h[0].self_attention.dense.__class__.__name__ == "Linear")
self.assertTrue(self.quantized_model.model.layers[0].self_attn.k_proj.__class__.__name__ == "Linear")


class GPTQUtilsTest(unittest.TestCase):
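The serialization tests above repeat one dispatch pattern: load the saved checkpoint with gptqmodel when it is installed, otherwise fall back to auto-gptq. A standalone sketch of that pattern follows; the helper name and checkpoint directory are illustrative, while the two loader calls and `is_gptqmodel_available` come from this diff:

from optimum.utils.import_utils import is_gptqmodel_available


def load_quantized(checkpoint_dir: str, device="cpu"):
    # Prefer the gptqmodel backend when available, mirroring the tests above.
    if is_gptqmodel_available():
        from gptqmodel import GPTQModel

        return GPTQModel.load(checkpoint_dir, device_map={"": device})
    # Fall back to auto-gptq for environments that only ship that backend.
    from auto_gptq import AutoGPTQForCausalLM

    return AutoGPTQForCausalLM.from_quantized(checkpoint_dir, device_map={"": device})

Assuming optimum follows the usual slow-test convention for the `@slow` marker, the updated suite would be exercised with something like RUN_SLOW=1 pytest tests/gptq/test_quantization.py.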