From 5979473a69281d7ec79a2e40817d13377487873f Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com> Date: Thu, 5 Dec 2024 20:50:52 +0800 Subject: [PATCH] Fix optimum compat (#3) * add meta info * cleanup * cleanup * The value of quantizer should be an array * Update quantizer.py * If is_auto_gptq_available() also writes "auto_gptq:version" to "quantizer" * If is_auto_gptq_available() also writes "auto_gptq:version" to "quantizer" * Update quantizer.py * cleanup * comment on meta * hf_select_quant_linear pass checkpoint_format * add todo fix * move convert code to quantizer.save() * Update quantizer.py * Optimize hf_convert_gptq_v2_to_v1_format() * Optimize hf_convert_gptq_v1_to_v2_format() * fix GPTQTestCUDA * hf_select_quant_linear() always set pack=True * gptqmodel.hf_select_quant_linear() now does not select ExllamaV2 * gptqmodel.hf_select_quant_linear() now does not select ExllamaV2 * GPTQQuantizer add backend * lower checkpoint_format and backend * cleanup * move backend to bottom * no need to check gptqmodel version for ipex support * Update import_utils.py * Update quantizer.py * fix UnboundLocalError: cannot access local variable 'version' where it is not associated with a value * make version var short * Update import_utils.py * fix unittest * use assertLessEqual --------- Co-authored-by: Qubitium-ModelCloud Co-authored-by: LRL --- optimum/gptq/quantizer.py | 71 +++++++++++++++++++++++---------- optimum/utils/import_utils.py | 16 ++++++-- tests/gptq/test_quantization.py | 43 ++++++++++++-------- 3 files changed, 89 insertions(+), 41 deletions(-) diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py index 5d020c5ef12..d92d1b850ac 100644 --- a/optimum/gptq/quantizer.py +++ b/optimum/gptq/quantizer.py @@ -32,6 +32,7 @@ from .constants import GPTQ_CONFIG from .data import get_dataset, prepare_dataset from .utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen +from ..version import __version__ as optimum_version if is_accelerate_available(): @@ -46,6 +47,7 @@ from auto_gptq.modeling._utils import autogptq_post_init as gptq_post_init from auto_gptq.quantization import GPTQ from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear + from auto_gptq import __version__ as autogptq_version if is_gptqmodel_available(): from gptqmodel import exllama_set_max_input_length @@ -53,6 +55,7 @@ from gptqmodel.utils.importer import hf_select_quant_linear from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format from gptqmodel.utils.model import hf_gptqmodel_post_init as gptq_post_init + from gptqmodel.version import __version__ as gptqmodel_version logger = getLogger(__name__) @@ -80,15 +83,17 @@ def __init__( desc_act: bool = False, sym: bool = True, true_sequential: bool = True, - use_cuda_fp16: bool = False, checkpoint_format: str = "gptq", + meta: Optional[Dict[str, any]] = None, + backend: Optional[str] = None, + use_cuda_fp16: bool = False, model_seqlen: Optional[int] = None, block_name_to_quantize: Optional[str] = None, module_name_preceding_first_block: Optional[List[str]] = None, batch_size: int = 1, pad_token_id: Optional[int] = None, disable_exllama: bool = False, - exllama_config: Dict[str, Any] = None, + exllama_config: Optional[Dict[str, Any]] = None, max_input_length: Optional[int] = None, cache_block_outputs: Optional[bool] = True, modules_in_block_to_quantize: Optional[List[List[str]]] = None, @@ -117,6 +122,14 @@ def __init__( Whether to perform sequential quantization even within a single Transformer block. Instead of quantizing the entire block at once, we perform layer-wise quantization. As a result, each layer undergoes quantization using inputs that have passed through the previously quantized layers. + checkpoint_format (`str`, *optional*, defaults to `gptq`): + GPTQ weight format. `gptq`(v1) is supported by both gptqmodel and auto-gptq. `gptq_v2` is gptqmodel only. + meta (`Dict[str, any]`, *optional*): + Properties, such as tooling:version, that do not directly contributes to quantization or quant inference are stored in meta. + i.e. `meta.quantizer`: ["optimum:_version_", "gptqmodel:_version_"] + backend (`str`, *optional*): + Controls which gptq kernel to be used. Valid values for gptqmodel are `auto`, `auto_trainable` and more. For auto-gptq, only + valid value is None and `auto_trainable`. Ref gptqmodel backends: https://github.com/ModelCloud/GPTQModel/blob/main/gptqmodel/utils/backend.py use_cuda_fp16 (`bool`, defaults to `False`): Whether or not to use optimized cuda kernel for fp16 model. Need to have model in fp16. model_seqlen (`Optional[int]`, defaults to `None`): @@ -152,6 +165,9 @@ def __init__( self.desc_act = desc_act self.sym = sym self.true_sequential = true_sequential + self.checkpoint_format = checkpoint_format.lower() + self.meta = meta + self.backend = backend.lower() if backend is not None else None self.use_cuda_fp16 = use_cuda_fp16 self.model_seqlen = model_seqlen self.block_name_to_quantize = block_name_to_quantize @@ -164,7 +180,6 @@ def __init__( self.quant_method = QuantizationMethod.GPTQ self.cache_block_outputs = cache_block_outputs self.modules_in_block_to_quantize = modules_in_block_to_quantize - self.checkpoint_format = checkpoint_format self.serialization_keys = [ "bits", @@ -177,6 +192,7 @@ def __init__( "quant_method", "modules_in_block_to_quantize", "checkpoint_format", + "meta", ] if self.bits not in [2, 3, 4, 8]: @@ -198,15 +214,17 @@ def __init__( ) self.exllama_version = self.exllama_config["version"] - def select_quant_linear(self, pack: bool, device_map: Union[str, dict]): + def select_quant_linear(self, device_map: Union[str, dict]): if is_gptqmodel_available(): self.quant_linear = hf_select_quant_linear( bits=self.bits, group_size=self.group_size, desc_act=self.desc_act, sym=self.sym, + checkpoint_format=self.checkpoint_format, + meta=self.meta, device_map=device_map, - pack=pack, + backend=self.backend, ) else: self.quant_linear = hf_select_quant_linear( @@ -225,6 +243,20 @@ def to_dict(self): gptq_dict = {} for key in self.serialization_keys: gptq_dict[key] = getattr(self, key) + + if gptq_dict.get("meta") is None: + gptq_dict["meta"] = {} + + meta = gptq_dict["meta"] + # store both optimum:version and gptq_lib:version into quantize_config.meta.quantizer + if meta.get("quantizer") is None: + meta["quantizer"] = [f"optimum:{optimum_version}"] + + if is_gptqmodel_available(): + meta["quantizer"].append(f"gptqmodel:{gptqmodel_version}") + elif is_auto_gptq_available(): + meta["quantizer"].append(f"auto_gptq:{autogptq_version}") + return gptq_dict @classmethod @@ -263,7 +295,7 @@ def convert_model(self, model: nn.Module, **kwargs): ) del layers_to_be_replaced[name] - self.select_quant_linear(pack=False, device_map=kwargs.get("device_map", None)) + self.select_quant_linear(device_map=kwargs.get("device_map", None)) self._replace_by_quant_layers(model, layers_to_be_replaced) @@ -379,10 +411,7 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None): gptq_supports_cpu = ( is_auto_gptq_available() and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2") - ) or ( - is_gptqmodel_available() - and version.parse(importlib.metadata.version("gptqmodel")) > version.parse("1.3.1") - ) + ) or is_gptqmodel_available() if not gptq_supports_cpu and not torch.cuda.is_available(): raise RuntimeError( @@ -663,18 +692,12 @@ def tmp(_, input, output): # Step 5: Any post-initialization that require device information, for example buffers initialization on device. model = self.post_init_model(model) - # convert gptqmodel internal gptq_v2 format to v1 for saving/compat - # sym=False is valid for gptq_v2 format only. for sym=True, need to convert to v1 before save. - if self.sym and self.checkpoint_format == "gptq_v2": - model = hf_convert_gptq_v2_to_v1_format(model, self.bits, self.quant_linear) - self.checkpoint_format = "gptq" - torch.cuda.empty_cache() if hasattr(torch, "xpu"): torch.xpu.empty_cache() return model - def post_init_model(self, model, **kwargs): + def post_init_model(self, model): """ Post-initialization that require device information, for example buffers initialization on device. @@ -695,8 +718,8 @@ def post_init_model(self, model, **kwargs): class StoreAttr(object): pass - if is_gptqmodel_available() and self.checkpoint_format == "gptq": - model = hf_convert_gptq_v1_to_v2_format(model, self.bits, self.quant_linear) + if is_gptqmodel_available(): + model, _ = hf_convert_gptq_v1_to_v2_format(model, self.bits, self.quant_linear, self.checkpoint_format, self.meta) model.quantize_config = StoreAttr() model.quantize_config.desc_act = self.desc_act @@ -727,7 +750,7 @@ def pack_model( layers = get_layers(model) layers = {n: layers[n] for n in quantizers} - self.select_quant_linear(pack=True, device_map=model.hf_device_map) + self.select_quant_linear(device_map=model.hf_device_map) self._replace_by_quant_layers(model, quantizers) qlayers = get_layers(model, [self.quant_linear]) @@ -765,6 +788,12 @@ def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", sa Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). """ + + # convert gptqmodel internal gptq_v2 format to v1 for max compatibility + model, converted = hf_convert_gptq_v2_to_v1_format(model, self.sym, self.bits, self.quant_linear, self.checkpoint_format, self.meta) + if converted: + self.checkpoint_format = "gptq" + os.makedirs(save_dir, exist_ok=True) model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) with open(os.path.join(save_dir, GPTQ_CONFIG), "w", encoding="utf-8") as f: @@ -871,7 +900,7 @@ def load_quantized_model( quantizer.exllama_version = quantizer.exllama_config["version"] quantizer.max_input_length = max_input_length - model = quantizer.convert_model(model) + model = quantizer.convert_model(model, device_map=device_map) if no_split_module_classes is None: no_split_module_classes = quantizer.get_no_split_module_classes(model) diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index bad6910653d..8c2de4cae59 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -52,6 +52,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0") DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0") AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0 +GPTQMODEL_MINIMUM_VERSION = version.parse("1.3.99") # Allows 1.4.0.dev0 # This is the minimal required version to support some ONNX Runtime features @@ -139,17 +140,24 @@ def is_datasets_available(): def is_auto_gptq_available(): if _auto_gptq_available: - version_autogptq = version.parse(importlib_metadata.version("auto_gptq")) - if AUTOGPTQ_MINIMUM_VERSION < version_autogptq: + v = version.parse(importlib_metadata.version("auto_gptq")) + if v >= AUTOGPTQ_MINIMUM_VERSION: return True else: raise ImportError( - f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, but only version above {AUTOGPTQ_MINIMUM_VERSION} are supported" + f"Found an incompatible version of auto-gptq. Found version {v}, but only version >= {AUTOGPTQ_MINIMUM_VERSION} are supported" ) def is_gptqmodel_available(): - return _gptqmodel_available + if _gptqmodel_available: + v = version.parse(importlib_metadata.version("gptqmodel")) + if v >= GPTQMODEL_MINIMUM_VERSION: + return True + else: + raise ImportError( + f"Found an incompatible version of gptqmodel. Found version {v}, but only version >= {GPTQMODEL_MINIMUM_VERSION} are supported" + ) @contextmanager diff --git a/tests/gptq/test_quantization.py b/tests/gptq/test_quantization.py index bbb0db4aec0..b16e77fcc52 100644 --- a/tests/gptq/test_quantization.py +++ b/tests/gptq/test_quantization.py @@ -55,6 +55,7 @@ class GPTQTest(unittest.TestCase): bits = 4 group_size = 128 desc_act = False + sym = True disable_exllama = True exllama_config = None cache_block_outputs = True @@ -73,6 +74,7 @@ def setUpClass(cls): """ cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + cls.config = AutoConfig.from_pretrained(cls.model_name) cls.model_fp16 = AutoModelForCausalLM.from_pretrained( cls.model_name, torch_dtype=torch.float16, device_map=cls.device_map_for_quantization @@ -87,6 +89,7 @@ def setUpClass(cls): dataset=cls.dataset, group_size=cls.group_size, desc_act=cls.desc_act, + sym=cls.sym, disable_exllama=cls.disable_exllama, exllama_config=cls.exllama_config, cache_block_outputs=cls.cache_block_outputs, @@ -116,13 +119,20 @@ def test_quantized_layers_class(self): """ if is_gptqmodel_available(): + if hasattr(self.config, "quantization_config"): + checkpoint_format = self.config.quantization_config.get("checkpoint_format") + meta = self.config.quantization_config.get("meta") + else: + checkpoint_format = "gptq" + meta = None QuantLinear = hf_select_quant_linear( bits=self.bits, group_size=self.group_size, desc_act=self.desc_act, - sym=True, + sym=self.sym, device_map=self.device_map_for_quantization, - pack=False, + checkpoint_format=checkpoint_format, + meta=meta, ) else: QuantLinear = hf_select_quant_linear( @@ -133,10 +143,10 @@ def test_quantized_layers_class(self): disable_exllama=self.disable_exllama or self.exllama_config["version"] != 1, disable_exllamav2=self.disable_exllama or self.exllama_config["version"] != 2, ) - self.assertTrue(self.quantized_model.model.layers[0].mlp.gate_proj.__class__ == QuantLinear) + self.assertEqual(self.quantized_model.model.layers[0].mlp.gate_proj.__class__, QuantLinear) def check_quantized_layers_type(self, model, value): - self.assertTrue(model.model.layers[0].mlp.gate_proj.QUANT_TYPE == value) + self.assertEqual(model.model.layers[0].mlp.gate_proj.QUANT_TYPE, value) def test_serialization(self): """ @@ -161,7 +171,7 @@ def test_serialization(self): if is_auto_gptq_available() and not is_gptqmodel_available(): quant_type = "cuda-old" if self.disable_exllama else "exllama" else: - quant_type = "ipex" if self.device_map_for_quantization == "cpu" else "cuda" + quant_type = "ipex" if self.device_map_for_quantization == "cpu" else "exllama" self.check_quantized_layers_type(quantized_model_from_saved, quant_type) @@ -179,7 +189,10 @@ def test_serialization(self): class GPTQTestCUDA(GPTQTest): device_map_for_quantization = "cuda" device_for_inference = 0 - expected_compression_ratio = 1.66 + expected_compression_ratio = 1.2577 + expected_fp16_perplexity = 38 + expected_quantized_perplexity = 45 + def test_perplexity(self): """ @@ -187,8 +200,8 @@ def test_perplexity(self): the perplexity of the converted models """ - self.assertEqual(int(self.fp16_ppl), self.expected_fp16_perplexity) - self.assertEqual(int(self.quantized_ppl), self.expected_quantized_perplexity) + self.assertLessEqual(int(self.fp16_ppl), self.expected_fp16_perplexity) + self.assertLessEqual(int(self.quantized_ppl), self.expected_quantized_perplexity) class GPTQTestExllama(GPTQTestCUDA): @@ -199,6 +212,7 @@ class GPTQTestExllama(GPTQTestCUDA): class GPTQTestActOrder(GPTQTestCUDA): disable_exllama = True desc_act = True + expected_quantized_perplexity = 46 def test_serialization(self): # act_order don't work with qlinear_cuda kernel @@ -282,7 +296,6 @@ def test_exllama_serialization(self): """ Test the serialization of the model and the loading of the quantized weights with exllamav2 kernel """ - with tempfile.TemporaryDirectory() as tmpdirname: self.quantizer.save(self.quantized_model, tmpdirname) self.quantized_model.config.save_pretrained(tmpdirname) @@ -296,16 +309,13 @@ def test_exllama_serialization(self): save_folder=tmpdirname, device_map={"": self.device_for_inference}, ) - self.check_quantized_layers_type(quantized_model_from_saved, "exllamav2") + self.check_quantized_layers_type(quantized_model_from_saved, "exllama" if is_gptqmodel_available else "exllamav2") # transformers and auto-gptq compatibility # quantized models are more compatible with device map than # device context managers (they're never used in transformers testing suite) _ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference}) - if is_gptqmodel_available(): - _ = GPTQModel.load(tmpdirname, device_map={"": self.device_for_inference}) - else: - _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference}) + _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference}) class GPTQTestNoBlockCaching(GPTQTestCUDA): @@ -318,11 +328,12 @@ class GPTQTestModuleQuant(GPTQTestCUDA): ["self_attn.q_proj"], ["mlp.gate_proj"], ] - expected_compression_ratio = 1.577 + expected_compression_ratio = 1.068 + expected_quantized_perplexity = 39 def test_not_converted_layers(self): # self_attention.dense should not be converted - self.assertTrue(self.quantized_model.model.layers[0].self_attn.k_proj.__class__.__name__ == "Linear") + self.assertEqual(self.quantized_model.model.layers[0].self_attn.k_proj.__class__.__name__, "Linear") class GPTQUtilsTest(unittest.TestCase):