Commit 2ecf711
comments - suppress warnings on state dict load, tests, fixes
Benjamin committed Aug 20, 2024
1 parent ab74d26 commit 2ecf711
Showing 6 changed files with 97 additions and 19 deletions.
1 change: 1 addition & 0 deletions src/transformers/modeling_utils.py
@@ -4182,6 +4182,7 @@ def _fix_key(key):
             for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
         if hf_quantizer is not None:
+            unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys, prefix)
             missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
 
         # retrieve weights on meta device and put them back on CPU.
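Note: `unexpected_keys` is what feeds the "Some weights of the checkpoint were not used" warning at load time, so letting the quantizer prune the list is what actually suppresses the spurious messages. A minimal sketch of the surrounding control flow (simplified names, not the real `modeling_utils` logic):

```python
def reconcile_keys(model, checkpoint_keys, hf_quantizer=None, prefix=""):
    """Simplified sketch: compare model and checkpoint keys, then give the
    quantizer a chance to prune entries that legitimately differ when a
    quantized checkpoint is loaded."""
    model_keys = set(model.state_dict().keys())
    missing_keys = sorted(model_keys - set(checkpoint_keys))
    unexpected_keys = sorted(set(checkpoint_keys) - model_keys)
    if hf_quantizer is not None:
        unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys, prefix)
        missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
    return missing_keys, unexpected_keys
```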
4 changes: 2 additions & 2 deletions src/transformers/quantizers/auto.py
@@ -51,7 +51,7 @@
     "quanto": QuantoHfQuantizer,
     "eetq": EetqHfQuantizer,
     "hqq": HqqHfQuantizer,
-    "compressed_tensors": CompressedTensorsHfQuantizer,
+    "compressed-tensors": CompressedTensorsHfQuantizer,
     "fbgemm_fp8": FbgemmFp8HfQuantizer,
     "torchao": TorchAoHfQuantizer,
 }
@@ -65,7 +65,7 @@
     "aqlm": AqlmConfig,
     "quanto": QuantoConfig,
     "hqq": HqqConfig,
-    "compressed_tensors": CompressedTensorsConfig,
+    "compressed-tensors": CompressedTensorsConfig,
     "fbgemm_fp8": FbgemmFp8Config,
     "torchao": TorchAoConfig,
 }
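The mapping keys must match the `quant_method` string stored in a checkpoint's `quantization_config`, which for compressed-tensors checkpoints is presumably the hyphenated form — hence the rename. Assuming the compressed-tensors package is installed, loading stays transparent:

```python
from transformers import AutoModelForCausalLM

# The checkpoint config carries quant_method="compressed-tensors", which now
# resolves to CompressedTensorsHfQuantizer via the mapping above.
model = AutoModelForCausalLM.from_pretrained(
    "nm-testing/tinyllama-oneshot-w8a8-test-static-shape-change-v3"
)
```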
10 changes: 10 additions & 0 deletions src/transformers/quantizers/base.py
@@ -99,6 +99,16 @@ def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         """
         return torch_dtype
 
+    def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]:
+        """
+        Override this method if you want to adjust the `unexpected_keys`.
+
+        Args:
+            unexpected_keys (`List[str]`, *optional*):
+                The list of unexpected keys in the state dict of the model compared to the checkpoint
+        """
+        return unexpected_keys
+
     def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
         """
         Override this method if you want to adjust the `missing_keys`.
18 changes: 18 additions & 0 deletions src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List
+
 from ..utils import is_compressed_tensors_available, is_torch_available, logging
 from ..utils.quantization_config import QuantizationConfigMixin
 from .base import HfQuantizer
@@ -58,6 +60,22 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         )
         return torch_dtype
 
+    def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]:
+        def _is_compressed_key(key: str) -> bool:
+            # key names in compressed state dict that will not be present in
+            # a decompressed state dict
+            return key.endswith("weight_shape") or key.endswith("weight_packed")
+
+        return [key for key in unexpected_keys if not _is_compressed_key(key)]
+
+    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
+        def _is_decompressed_key(key: str) -> bool:
+            # key names in decompressed state dict that will not be present in
+            # a compressed state dict
+            return key.endswith("weight") or "scale" in key or "zero_point" in key
+
+        return [key for key in missing_keys if not _is_decompressed_key(key)]
+
     def _process_model_before_weight_loading(self, model, **kwargs):
         if self.quantization_config.quantization_config is not None:
             from compressed_tensors.quantization import apply_quantization_config
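To illustrate the two overrides with made-up key names: compressed-only tensors (`weight_packed`, `weight_shape`) no longer count as unexpected, and the decompressed tensors that replace them (`weight`, scales, zero points) no longer count as missing:

```python
# Hypothetical key names, for illustration only.
unexpected = [
    "model.layers.0.mlp.down_proj.weight_packed",  # compressed-only
    "model.layers.0.mlp.down_proj.weight_shape",   # compressed-only
]
missing = [
    "model.layers.0.mlp.down_proj.weight",        # decompressed-only
    "model.layers.0.mlp.down_proj.weight_scale",  # decompressed-only
]

# Same predicates as the quantizer above.
unexpected = [k for k in unexpected if not (k.endswith("weight_shape") or k.endswith("weight_packed"))]
missing = [k for k in missing if not (k.endswith("weight") or "scale" in k or "zero_point" in k)]

assert unexpected == [] and missing == []  # nothing left to warn about
```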
27 changes: 27 additions & 0 deletions src/transformers/utils/quantization_config.py
@@ -1100,6 +1100,7 @@ def __init__(
                 "quantization_status": quantization_status,
                 "global_compression_ratio": global_compression_ratio,
                 "ignore": ignore,
+                **kwargs,
             }
         )
 
@@ -1110,6 +1111,32 @@
 
         super().__init__(quant_method=QuantizationMethod.COMPRESSED_TENSORS)
 
+    @classmethod
+    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
+        """
+        Instantiates a [`CompressedTensorsConfig`] from a Python dictionary of parameters.
+        Optionally unwraps any args from the nested `quantization_config`.
+
+        Args:
+            config_dict (`Dict[str, Any]`):
+                Dictionary that will be used to instantiate the configuration object.
+            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
+                `PreTrainedModel`.
+            kwargs (`Dict[str, Any]`):
+                Additional parameters from which to initialize the configuration object.
+
+        Returns:
+            [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
+        """
+        if "quantization_config" in config_dict:
+            config_dict = dict(
+                sparsity_config=config_dict.get("sparsity_config"),
+                **config_dict["quantization_config"],
+            )
+
+        return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs)
+
     def to_dict(self) -> Dict[str, Any]:
         """
         Serializes this instance to a Python dictionary. Returns:
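A short sketch of what the unwrapping enables (values are illustrative and mirror the tests below; assumes the compressed-tensors package is installed): a dict produced by `to_dict` nests the quantization arguments under a `quantization_config` key, and `from_dict` now flattens them back out before delegating to the parent implementation:

```python
from transformers import CompressedTensorsConfig

# Shape of a serialized config, with quantization args nested under
# "quantization_config" as to_dict() produces them.
config_dict = {
    "quantization_config": {"config_groups": {"FP8": ["Linear"]}},
    "sparsity_config": {"format": "dense"},
}

# from_dict unwraps the nested args, so serialization round-trips.
config = CompressedTensorsConfig.from_dict(config_dict)
```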
56 changes: 39 additions & 17 deletions tests/quantization/compressed_tensor/test_compressed_tensors.py
@@ -1,7 +1,7 @@
 import gc
 import unittest
 
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, CompressedTensorsConfig
 from transformers.testing_utils import require_compressed_tensors, require_torch
 from transformers.utils import is_torch_available
 
@@ -13,7 +13,8 @@
 @require_compressed_tensors
 @require_torch
 class CompressedTensorsTest(unittest.TestCase):
-    quantized_model_name = "nm-testing/tinyllama-oneshot-w8a8-test-static-shape-change-v3"
+    tinyllama_w8a8 = "nm-testing/tinyllama-oneshot-w8a8-test-static-shape-change-v3"
+    llama3_8b_fp8 = "nm-testing/Meta-Llama-3-8B-Instruct-fp8-hf_compat"
 
     prompt = "Paris is the capital of which country?"
 
@@ -22,31 +23,52 @@ def tearDown(self):
         torch.cuda.empty_cache()
         gc.collect()
 
-    @classmethod
-    def setUpClass(self):
-        """
-        Setup quantized model
-        """
-        self.tokenizer = AutoTokenizer.from_pretrained(self.quantized_model_name)
-        self.quantized_model = AutoModelForCausalLM.from_pretrained(self.quantized_model_name)
-        self.device = self.quantized_model.device
+    def test_config_args(self):
+        with self.assertRaises(ValueError):
+            # passing quant scheme directly is not allowed
+            CompressedTensorsConfig(config_groups={"weights": {"num_bits": 8}})
+        CompressedTensorsConfig(
+            config_groups={"FP8": ["Linear"]},
+            ignore=["lm_head"],
+            quantization_status="frozen",
+            sparsity_config={"format": "dense"},
+        )
+
+    def test_config_to_from_dict(self):
+        config = CompressedTensorsConfig(config_groups={"FP8": ["Linear"]}, sparsity_config={"format": "dense"})
+        config_dict = config.to_dict()
+        config_from_dict = CompressedTensorsConfig.from_dict(config_dict)
+
+        from compressed_tensors import QuantizationConfig, SparsityCompressionConfig
+
+        self.assertIsInstance(config_from_dict.quantization_config, QuantizationConfig)
+        self.assertIsInstance(config_from_dict.sparsity_config, SparsityCompressionConfig)
+
+    def test_tinyllama_w8a8(self):
+        self._test_quantized_model(self.tinyllama_w8a8)
+
+    def test_llama_8b_fp8(self):
+        self._test_quantized_model(self.llama3_8b_fp8)
 
-    def test_quantized_model(self):
+    def _test_quantized_model(self, model_name: str):
         """Carry out generation"""
+        quantized_model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        device = quantized_model.device
         self.assertIsNotNone(
-            self.quantized_model.config.quantization_config,
+            quantized_model.config.quantization_config,
             "quantization_config should not be None",
         )
         self.assertTrue(
             any(
                 key
-                for key, tensor in self.quantized_model.state_dict().items()
+                for key, tensor in quantized_model.state_dict().items()
                 if "scale" in key and not torch.all(tensor == 1.0)
             ),
-            "quantized model should load a non-trivail scale into the state dict",
+            "quantized model should load a non-trivial scale into the state dict",
         )
-        inputs = self.tokenizer(self.prompt, return_tensors="pt").to(self.device)
-        generated_ids = self.quantized_model.generate(**inputs, max_length=50)
-        outputs = self.tokenizer.batch_decode(generated_ids)
+        inputs = tokenizer(self.prompt, return_tensors="pt").to(device)
+        generated_ids = quantized_model.generate(**inputs, max_length=50)
+        outputs = tokenizer.batch_decode(generated_ids)
 
         self.assertIsNotNone(outputs)

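For reference, the new tests should be runnable in the usual way, assuming the compressed-tensors package and a CUDA-capable torch build are available:

```python
# Equivalent to: python -m pytest tests/quantization/compressed_tensor/test_compressed_tensors.py -v
import pytest

pytest.main(["tests/quantization/compressed_tensor/test_compressed_tensors.py", "-v"])
```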