From 2ecf7110ce0d8dabf44adab2ef805b57a9a3ed0c Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Tue, 20 Aug 2024 14:05:50 -0400
Subject: [PATCH] comments - suppress warnings on state dict load, tests, fixes

---
 src/transformers/modeling_utils.py            |  1 +
 src/transformers/quantizers/auto.py           |  4 +-
 src/transformers/quantizers/base.py           | 10 ++++
 .../quantizer_compressed_tensors.py           | 18 ++++++
 src/transformers/utils/quantization_config.py | 27 +++++++++
 .../test_compressed_tensors.py                | 56 +++++++++++++------
 6 files changed, 97 insertions(+), 19 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9cef7ff13edb88..125dd629f4cc5e 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4182,6 +4182,7 @@ def _fix_key(key):
             for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
         if hf_quantizer is not None:
+            unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys, prefix)
             missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)

         # retrieve weights on meta device and put them back on CPU.
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index d5c40436b00d1c..1dcd87c993a2e6 100755
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -51,7 +51,7 @@
     "quanto": QuantoHfQuantizer,
     "eetq": EetqHfQuantizer,
     "hqq": HqqHfQuantizer,
-    "compressed_tensors": CompressedTensorsHfQuantizer,
+    "compressed-tensors": CompressedTensorsHfQuantizer,
     "fbgemm_fp8": FbgemmFp8HfQuantizer,
     "torchao": TorchAoHfQuantizer,
 }
@@ -65,7 +65,7 @@
     "aqlm": AqlmConfig,
     "quanto": QuantoConfig,
     "hqq": HqqConfig,
-    "compressed_tensors": CompressedTensorsConfig,
+    "compressed-tensors": CompressedTensorsConfig,
     "fbgemm_fp8": FbgemmFp8Config,
     "torchao": TorchAoConfig,
 }
diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py
index 3ee28ada1bb25d..81eb8ac6956227 100644
--- a/src/transformers/quantizers/base.py
+++ b/src/transformers/quantizers/base.py
@@ -99,6 +99,16 @@ def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         """
         return torch_dtype

+    def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]:
+        """
+        Override this method if you want to adjust the `unexpected_keys`.
+
+        Args:
+            unexpected_keys (`List[str]`, *optional*):
+                The list of unexpected keys in the state dict of the model compared to the checkpoint.
+        """
+        return unexpected_keys
+
     def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
         """
         Override this method if you want to adjust the `missing_keys`.
diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py
index b4ca3f3d567cec..3b37d9c529899e 100644
--- a/src/transformers/quantizers/quantizer_compressed_tensors.py
+++ b/src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List
+
 from ..utils import is_compressed_tensors_available, is_torch_available, logging
 from ..utils.quantization_config import QuantizationConfigMixin
 from .base import HfQuantizer
@@ -58,6 +60,22 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         )
         return torch_dtype

+    def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]:
+        def _is_compressed_key(key: str) -> bool:
+            # key names in compressed state dict that will not be present in
+            # a decompressed state dict
+            return key.endswith("weight_shape") or key.endswith("weight_packed")
+
+        return [key for key in unexpected_keys if not _is_compressed_key(key)]
+
+    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
+        def _is_decompressed_key(key: str) -> bool:
+            # key names in decompressed state dict that will not be present in
+            # a compressed state dict
+            return key.endswith("weight") or "scale" in key or "zero_point" in key
+
+        return [key for key in missing_keys if not _is_decompressed_key(key)]
+
     def _process_model_before_weight_loading(self, model, **kwargs):
         if self.quantization_config.quantization_config is not None:
             from compressed_tensors.quantization import apply_quantization_config
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index a6227386cb94f8..6d3f953f65dfbd 100755
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -1100,6 +1100,7 @@ def __init__(
                 "quantization_status": quantization_status,
                 "global_compression_ratio": global_compression_ratio,
                 "ignore": ignore,
+                **kwargs,
             }
         )
@@ -1110,6 +1111,32 @@ def __init__(
         super().__init__(quant_method=QuantizationMethod.COMPRESSED_TENSORS)

+    @classmethod
+    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
+        """
+        Instantiates a [`CompressedTensorsConfig`] from a Python dictionary of parameters.
+        Optionally unwraps any args from the nested quantization_config.
+
+        Args:
+            config_dict (`Dict[str, Any]`):
+                Dictionary that will be used to instantiate the configuration object.
+            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
+                `PreTrainedModel`.
+            kwargs (`Dict[str, Any]`):
+                Additional parameters from which to initialize the configuration object.
+
+        Returns:
+            [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
+        """
+        if "quantization_config" in config_dict:
+            config_dict = dict(
+                sparsity_config=config_dict.get("sparsity_config"),
+                **config_dict["quantization_config"],
+            )
+
+        return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs)
+
     def to_dict(self) -> Dict[str, Any]:
         """
         Serializes this instance to a Python dictionary.
        Returns:
diff --git a/tests/quantization/compressed_tensor/test_compressed_tensors.py b/tests/quantization/compressed_tensor/test_compressed_tensors.py
index 46d0221ed6e724..e7710b0b594e48 100644
--- a/tests/quantization/compressed_tensor/test_compressed_tensors.py
+++ b/tests/quantization/compressed_tensor/test_compressed_tensors.py
@@ -1,7 +1,7 @@
 import gc
 import unittest

-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, CompressedTensorsConfig
 from transformers.testing_utils import require_compressed_tensors, require_torch
 from transformers.utils import is_torch_available

@@ -13,7 +13,8 @@
 @require_compressed_tensors
 @require_torch
 class CompressedTensorsTest(unittest.TestCase):
-    quantized_model_name = "nm-testing/tinyllama-oneshot-w8a8-test-static-shape-change-v3"
+    tinyllama_w8a8 = "nm-testing/tinyllama-oneshot-w8a8-test-static-shape-change-v3"
+    llama3_8b_fp8 = "nm-testing/Meta-Llama-3-8B-Instruct-fp8-hf_compat"

     prompt = "Paris is the capital of which country?"

@@ -22,31 +23,52 @@ def tearDown(self):
         torch.cuda.empty_cache()
         gc.collect()

-    @classmethod
-    def setUpClass(self):
-        """
-        Setup quantized model
-        """
-        self.tokenizer = AutoTokenizer.from_pretrained(self.quantized_model_name)
-        self.quantized_model = AutoModelForCausalLM.from_pretrained(self.quantized_model_name)
-        self.device = self.quantized_model.device
+    def test_config_args(self):
+        with self.assertRaises(ValueError):
+            # passing quant scheme directly is not allowed
+            CompressedTensorsConfig(config_groups={"weights": {"num_bits": 8}})
+
+        CompressedTensorsConfig(
+            config_groups={"FP8": ["Linear"]},
+            ignore=["lm_head"],
+            quantization_status="frozen",
+            sparsity_config={"format": "dense"},
+        )
+
+    def test_config_to_from_dict(self):
+        config = CompressedTensorsConfig(config_groups={"FP8": ["Linear"]}, sparsity_config={"format": "dense"})
+        config_dict = config.to_dict()
+        config_from_dict = CompressedTensorsConfig.from_dict(config_dict)
+
+        from compressed_tensors import QuantizationConfig, SparsityCompressionConfig
+
+        self.assertIsInstance(config_from_dict.quantization_config, QuantizationConfig)
+        self.assertIsInstance(config_from_dict.sparsity_config, SparsityCompressionConfig)
+
+    def test_tinyllama_w8a8(self):
+        self._test_quantized_model(self.tinyllama_w8a8)
+
+    def test_llama_8b_fp8(self):
+        self._test_quantized_model(self.llama3_8b_fp8)

-    def test_quantized_model(self):
+    def _test_quantized_model(self, model_name: str):
         """Carry out generation"""
+        quantized_model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        device = quantized_model.device
         self.assertIsNotNone(
-            self.quantized_model.config.quantization_config,
+            quantized_model.config.quantization_config,
             "quantization_config should not be None",
         )
         self.assertTrue(
             any(
                 key
-                for key, tensor in self.quantized_model.state_dict().items()
+                for key, tensor in quantized_model.state_dict().items()
                 if "scale" in key and not torch.all(tensor == 1.0)
             ),
-            "quantized model should load a non-trivail scale into the state dict",
+            "quantized model should load a non-trivial scale into the state dict",
         )
-        inputs = self.tokenizer(self.prompt, return_tensors="pt").to(self.device)
-        generated_ids = self.quantized_model.generate(**inputs, max_length=50)
-        outputs = self.tokenizer.batch_decode(generated_ids)
+        inputs = tokenizer(self.prompt, return_tensors="pt").to(device)
+        generated_ids = quantized_model.generate(**inputs, max_length=50)
+        outputs = tokenizer.batch_decode(generated_ids)

         self.assertIsNotNone(outputs)
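
Usage sketch (not part of the patch): the snippet below mirrors what the new tests exercise
end to end. It assumes the `compressed-tensors` package is installed and that the `nm-testing`
checkpoint referenced in the tests is reachable; the model id, config arguments, and prompt are
taken from the tests above.

    from transformers import AutoModelForCausalLM, AutoTokenizer, CompressedTensorsConfig

    model_name = "nm-testing/Meta-Llama-3-8B-Instruct-fp8-hf_compat"

    # With update_unexpected_keys/update_missing_keys in place, loading a
    # compressed-tensors checkpoint should no longer warn about keys such as
    # `*.weight_packed`/`*.weight_shape` or "missing" `weight`/`scale`/`zero_point` entries.
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # from_dict unwraps the nested `quantization_config` entry that to_dict emits,
    # so the config round-trips cleanly through serialization.
    config = CompressedTensorsConfig(config_groups={"FP8": ["Linear"]}, sparsity_config={"format": "dense"})
    restored = CompressedTensorsConfig.from_dict(config.to_dict())

    inputs = tokenizer("Paris is the capital of which country?", return_tensors="pt").to(model.device)
    print(tokenizer.batch_decode(model.generate(**inputs, max_length=50)))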