Add FP8 quantization test (#1114)
* Add llm fp8 quantization test; rename test variables; add dataset name check

* Create a variable for supported language datasets

* ruff
nikita-savelyevv authored Jan 20, 2025
1 parent 2590794 commit 38b6e54
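
In user-facing terms, the new test exercises a flow like the following sketch. The config values mirror the test parameters added in tests/openvino/test_quantization.py below; the model id is a placeholder, and export=True is assumed for a PyTorch checkpoint:

from optimum.intel import OVModelForCausalLM, OVQuantizationConfig

# Full static quantization to FP8 (f8e4m3) for both weights and activations,
# calibrated on one of the supported language datasets. weight_only=False
# requests full quantization rather than weight-only compression.
quantization_config = OVQuantizationConfig(
    dataset="wikitext2",
    num_samples=1,
    weight_only=False,
    weight_format="f8e4m3",
    activation_format="f8e4m3",
)

model = OVModelForCausalLM.from_pretrained(
    "tiny-llama-checkpoint",  # placeholder model id
    export=True,
    quantization_config=quantization_config,
)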
Showing 4 changed files with 90 additions and 48 deletions.
23 changes: 18 additions & 5 deletions optimum/intel/openvino/configuration.py
@@ -26,7 +26,12 @@
 from optimum.configuration_utils import BaseConfig
 
 from ..utils.import_utils import is_nncf_available
-from .utils import PREDEFINED_SD_DATASETS, PREDEFINED_VISUAL_LM_DATASETS
+from .utils import (
+    LANGUAGE_DATASETS,
+    PREDEFINED_SD_DATASETS,
+    PREDEFINED_SPEECH_TO_TEXT_DATASETS,
+    PREDEFINED_VISUAL_LM_DATASETS,
+)
 
 
 if is_nncf_available():
@@ -467,13 +472,12 @@ def post_init(self):
                    f"If you wish to provide a custom dataset, please use the `OVQuantizer` instead."
                )
            if self.dataset is not None and isinstance(self.dataset, str):
-               lm_datasets = ["wikitext2", "c4", "c4-new", "auto"]
                visual_lm_datasets = list(PREDEFINED_VISUAL_LM_DATASETS.keys())
                stable_diffusion_datasets = list(PREDEFINED_SD_DATASETS.keys())
-               if self.dataset not in lm_datasets + visual_lm_datasets + stable_diffusion_datasets:
+               if self.dataset not in LANGUAGE_DATASETS + visual_lm_datasets + stable_diffusion_datasets:
                    raise ValueError(
                        f"""You have entered a string value for dataset. You can only choose between
-                       {lm_datasets} for LLMs, {visual_lm_datasets} for visual LLMs
+                       {LANGUAGE_DATASETS} for LLMs, {visual_lm_datasets} for visual LLMs
                        or {stable_diffusion_datasets} for diffusion models, but we found {self.dataset}"""
                    )
 
@@ -617,7 +621,8 @@ def __init__(
        overflow_fix (`str`, default to "disable"):
            Parameter for controlling overflow fix setting.
        dataset (`str`, *optional*):
-           The dataset used for quantization. For text-to-speech model quantization the allowed value is 'librispeech'.
+           The dataset used for quantization. For language models the allowed values are
+           ['auto', 'wikitext2','c4','c4-new']. For text-to-speech model quantization the allowed value is 'librispeech'.
        tokenizer (`str`, *optional*):
            The tokenizer used to process the dataset. You can pass either:
                - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
@@ -673,6 +678,14 @@ def post_init(self):
        """
        super().post_init()
 
+       if self.dataset is not None:
+           speech_to_text_datasets = list(PREDEFINED_SPEECH_TO_TEXT_DATASETS.keys())
+           if self.dataset not in LANGUAGE_DATASETS + speech_to_text_datasets:
+               raise ValueError(
+                   f"""You can only choose between the following datasets: {LANGUAGE_DATASETS} for LLMs or
+                   {speech_to_text_datasets} for speech-to-text models, but we found {self.dataset}."""
+               )
+
        if self.bits != 8:
            raise ValueError(f"Only support 8-bit for static quantization but found {self.bits}")
 
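Taken together, the configuration.py changes make an unsupported dataset name fail fast. A minimal sketch, assuming post_init runs at construction time, as the existing bits check implies:

from optimum.intel import OVQuantizationConfig

try:
    # "imagenet" is neither in LANGUAGE_DATASETS nor a predefined
    # speech-to-text dataset, so the new check rejects it.
    OVQuantizationConfig(dataset="imagenet")
except ValueError as err:
    print(err)  # names the supported language and speech-to-text datasets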
2 changes: 2 additions & 0 deletions optimum/intel/openvino/utils.py
@@ -136,6 +136,8 @@
 }
 
 
+LANGUAGE_DATASETS = ["wikitext2", "c4", "c4-new", "auto"]
+
 PREDEFINED_SD_DATASETS = {
     "conceptual_captions": {"split": "train", "inputs": {"prompt": "caption"}},
     "laion/220k-GPT4Vision-captions-from-LIVIS": {"split": "train", "inputs": {"prompt": "caption"}},
24 changes: 13 additions & 11 deletions tests/openvino/test_exporters_cli.py
@@ -365,19 +365,21 @@ def test_exporters_cli_int8(self, task: str, model_type: str):
            self.assertEqual(expected_int8[i], num_weight_nodes["int8"])
 
    @parameterized.expand(SUPPORTED_SD_HYBRID_ARCHITECTURES)
-   def test_exporters_cli_hybrid_quantization(self, model_type: str, exp_num_fq: int, exp_num_int8: int):
+   def test_exporters_cli_hybrid_quantization(
+       self, model_type: str, expected_fake_nodes: int, expected_int8_nodes: int
+   ):
        with TemporaryDirectory() as tmpdir:
            subprocess.run(
                f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --dataset laion/filtered-wit --weight-format int8 {tmpdir}",
                shell=True,
                check=True,
            )
            model = eval(_HEAD_TO_AUTOMODELS[model_type.replace("-refiner", "")]).from_pretrained(tmpdir)
-           num_fq, num_weight_nodes = get_num_quantized_nodes(
+           num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(
                model.unet if model.unet is not None else model.transformer
            )
-           self.assertEqual(exp_num_int8, num_weight_nodes["int8"])
-           self.assertEqual(exp_num_fq, num_fq)
+           self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
+           self.assertEqual(expected_fake_nodes, num_fake_nodes)
 
    @parameterized.expand(TEST_4BIT_CONFIGURATIONS)
    def test_exporters_cli_4bit(
@@ -422,8 +424,8 @@ def test_exporters_cli_full_quantization(
        model_type: str,
        quant_mode: str,
        option: str,
-       expected_num_f_nodes_per_model: Tuple[int],
-       expected_num_weight_nodes_per_model: Tuple[int],
+       expected_fake_nodes: Tuple[int],
+       expected_low_precision_nodes: Tuple[int],
    ):
        with TemporaryDirectory() as tmpdir:
            subprocess.run(
@@ -439,12 +441,12 @@
            if model.decoder_with_past is not None:
                models.append(model.decoder_with_past)
            else:
-               expected_num_f_nodes_per_model = expected_num_f_nodes_per_model[:-1]
-           self.assertEqual(len(expected_num_f_nodes_per_model), len(models))
+               expected_fake_nodes = expected_fake_nodes[:-1]
+           self.assertEqual(len(expected_fake_nodes), len(models))
            for i, model in enumerate(models):
-               actual_num_f_nodes, actual_num_weight_nodes = get_num_quantized_nodes(model)
-               self.assertEqual(expected_num_f_nodes_per_model[i], actual_num_f_nodes)
-               self.assertEqual(expected_num_weight_nodes_per_model[i], actual_num_weight_nodes[quant_mode])
+               num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+               self.assertEqual(expected_fake_nodes[i], num_fake_nodes)
+               self.assertEqual(expected_low_precision_nodes[i], num_weight_nodes[quant_mode])
 
    def test_exporters_cli_int4_with_local_model_and_default_config(self):
        with TemporaryDirectory() as tmpdir:
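The renames above track the shape of the helper's return value. Roughly, as a sketch (the helper lives in the test utilities; the import path is assumed):

from utils_tests import get_num_quantized_nodes  # import path assumed

# ov_model stands for any exported OpenVINO model object.
num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(ov_model)
# num_fake_nodes: int count of fake-quantize/fake-convert nodes
# num_weight_nodes: dict of weight-node counts keyed by precision,
# e.g. {"int8": ..., "int4": ..., "f8e4m3": ...}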
89 changes: 57 additions & 32 deletions tests/openvino/test_quantization.py
@@ -114,10 +114,23 @@ class OVQuantizerTest(unittest.TestCase):
            (14, 22, 21) if is_transformers_version("<=", "4.42.4") else (14, 22, 25),
            (14, 21, 17) if is_transformers_version("<=", "4.42.4") else (14, 22, 18),
        ),
+       (
+           OVModelForCausalLM,
+           "llama",
+           OVQuantizationConfig(
+               dataset="wikitext2",
+               num_samples=1,
+               weight_only=False,
+               weight_format="f8e4m3",
+               activation_format="f8e4m3",
+           ),
+           (13,),
+           (16,),
+       ),
    ]
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_TORCH_MODEL)
-   def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
+   def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_nodes, expected_int8_nodes):
        model_id = MODEL_NAMES[model_name]
        task = model_cls.export_feature
        dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
@@ -149,9 +162,9 @@ def preprocess_function(examples, tokenizer):
                ov_config=ov_config,
            )
            model = model_cls.from_pretrained(tmp_dir, file_name=file_name)
-           num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model)
-           self.assertEqual(expected_fake_quantize, num_fake_quantize)
-           self.assertEqual(expected_int8, num_weight_nodes["int8"])
+           num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+           self.assertEqual(expected_fake_nodes, num_fake_nodes)
+           self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
 
            tokens = tokenizer("This is a sample input", return_tensors="pt")
            outputs = model(**tokens)
@@ -162,7 +175,7 @@ def preprocess_function(examples, tokenizer):
        self.assertEqual(ov_config.quantization_config.to_dict(), loaded_config.quantization_config.to_dict())
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL)
-   def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
+   def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_nodes, expected_int8_nodes):
        model_id = MODEL_NAMES[model_name]
        task = model_cls.export_feature
        dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
@@ -190,9 +203,9 @@ def preprocess_function(examples, tokenizer):
 
            model = model_cls.from_pretrained(tmp_dir)
 
-           num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model)
-           self.assertEqual(expected_fake_quantize, num_fake_quantize)
-           self.assertEqual(expected_int8, num_weight_nodes["int8"])
+           num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+           self.assertEqual(expected_fake_nodes, num_fake_nodes)
+           self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
 
            tokens = tokenizer("This is a sample input", return_tensors="pt")
            outputs = model(**tokens)
@@ -204,9 +217,10 @@ def preprocess_function(examples, tokenizer):
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL_WITH_AUTO_DATASET)
    def test_ov_model_static_quantization_with_auto_dataset(
-       self, model_cls, model_name, quantization_config, expected_fake_quantize, expected_int8
+       self, model_cls, model_name, quantization_config, expected_fake_nodes, expected_low_precision_nodes
    ):
        model_id = MODEL_NAMES[model_name]
+       quant_mode = quantization_config.activation_format
 
        with TemporaryDirectory() as tmp_dir:
            ov_model = model_cls.from_pretrained(model_id, quantization_config=quantization_config)
@@ -217,17 +231,28 @@ def test_ov_model_static_quantization_with_auto_dataset(
 
                if ov_model.decoder_with_past is not None:
                    models.append(ov_model.decoder_with_past.model)
-               for model, expected_fq, expected_i8 in zip(
+               for model, expected_fake_nodes, expected_lp_nodes in zip(
                    models,
-                   expected_fake_quantize,
-                   expected_int8,
+                   expected_fake_nodes,
+                   expected_low_precision_nodes,
                ):
-                   num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model)
-                   self.assertEqual(expected_fq, num_fake_quantize)
-                   self.assertEqual(expected_i8, num_weight_nodes["int8"])
+                   num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+                   self.assertEqual(expected_fake_nodes, num_fake_nodes)
+                   self.assertEqual(expected_lp_nodes, num_weight_nodes[quant_mode])
 
                input_features = torch.randn((1, 128, 3000), dtype=torch.float32)
                ov_model.generate(input_features)
+           elif model_cls == OVModelForCausalLM:
+               num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(ov_model.model)
+               self.assertEqual(expected_fake_nodes[0], num_fake_nodes)
+               self.assertEqual(expected_low_precision_nodes[0], num_weight_nodes[quant_mode])
+
+               tokenizer = AutoTokenizer.from_pretrained(model_id)
+               if tokenizer.pad_token is None:
+                   tokenizer.pad_token = tokenizer.eos_token
+               tokens = tokenizer("This is a sample input", return_tensors="pt")
+               outputs = ov_model(**tokens)
+               self.assertTrue("logits" in outputs)
            else:
                raise Exception("Unexpected model class.")
 
@@ -608,7 +633,7 @@ def test_ovmodel_8bit_weight_compression(self, model_cls, model_name, expected_p
        self.assertEqual(OVWeightQuantizationConfig().to_dict(), loaded_config.quantization_config.to_dict())
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS)
-   def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_int8, expected_int4):
+   def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_int8_nodes, expected_int4_nodes):
        task = model_cls.export_feature
        model_id = MODEL_NAMES[model_name]
        with TemporaryDirectory() as tmp_dir:
@@ -623,8 +648,8 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i
            model = model_cls.from_pretrained(tmp_dir)
 
            _, num_weight_nodes = get_num_quantized_nodes(model)
-           self.assertEqual(expected_int8, num_weight_nodes["int8"])
-           self.assertEqual(expected_int4, num_weight_nodes["int4"])
+           self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
+           self.assertEqual(expected_int4_nodes, num_weight_nodes["int4"])
 
            tokens = tokenizer("This is a sample input", return_tensors="pt")
            outputs = model(**tokens)
@@ -699,17 +724,17 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type, trust
            self.assertEqual(expected_ov_int8[i], num_weight_nodes["int8"])
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION)
-   def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8):
+   def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_fake_nodes, expected_int8_nodes):
        model_id = MODEL_NAMES[model_type]
        quantization_config = OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions", num_samples=2)
        with TemporaryDirectory() as tmp_dir:
            model = model_cls.from_pretrained(model_id, export=True, quantization_config=quantization_config)
 
-           num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(
+           num_fake, num_weight_nodes = get_num_quantized_nodes(
                model.unet if model.unet is not None else model.transformer
            )
-           self.assertEqual(expected_num_fake_quantize, num_fake_quantize)
-           self.assertEqual(expected_ov_int8, num_weight_nodes["int8"])
+           self.assertEqual(expected_fake_nodes, num_fake)
+           self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
            self.assertEqual(0, num_weight_nodes["int4"])
 
            model.save_pretrained(tmp_dir)
@@ -721,16 +746,16 @@ def test_stable_diffusion_with_weight_compression(self):
 
        quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config))
 
-       num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(
+       num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(
            int8_pipe.unet if int8_pipe.unet is not None else int8_pipe.transformer
        )
-       self.assertEqual(0, num_fake_quantize)
+       self.assertEqual(0, num_fake_nodes)
        self.assertEqual(242, num_weight_nodes["int8"])
        self.assertEqual(0, num_weight_nodes["int4"])
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION[-1:])
    def test_ovmodel_hybrid_quantization_with_custom_dataset(
-       self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8
+       self, model_cls, model_type, expected_fake_nodes, expected_int8_nodes
    ):
        model_id = MODEL_NAMES[model_type]
        dataset = [
@@ -742,11 +767,11 @@ def test_ovmodel_hybrid_quantization_with_custom_dataset(
        self.assertEqual(quantization_config.quant_method, OVQuantizationMethod.HYBRID)
 
        quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config), calibration_dataset=dataset)
-       num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(
+       num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(
            model.unet if model.unet is not None else model.transformer
        )
-       self.assertEqual(expected_num_fake_quantize, num_fake_quantize)
-       self.assertEqual(expected_ov_int8, num_weight_nodes["int8"])
+       self.assertEqual(expected_fake_nodes, num_fake_nodes)
+       self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
        self.assertEqual(0, num_weight_nodes["int4"])
 
    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTOCOMPRESSED_MATMULS)
@@ -1050,7 +1075,7 @@ class OVTrainerTest(unittest.TestCase):
    @unittest.skipIf(
        is_transformers_version(">=", "4.46"), reason="OVTrainer is not compatible with transformers>=v4.46"
    )
-   def test_aware_training_quantization(self, model_name, expected_fake_quantize, expected_int8):
+   def test_aware_training_quantization(self, model_name, expected_fake_nodes, expected_int8_nodes):
        model_id = MODEL_NAMES[model_name]
        model = AutoModelForSequenceClassification.from_pretrained(model_id, attn_implementation="eager")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -1084,9 +1109,9 @@ def compute_metrics(p):
            trainer.save_model()
 
            model = OVModelForSequenceClassification.from_pretrained(tmp_dir)
-           num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model)
-           self.assertEqual(expected_fake_quantize, num_fake_quantize)
-           self.assertEqual(expected_int8, num_weight_nodes["int8"])
+           num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+           self.assertEqual(expected_fake_nodes, num_fake_nodes)
+           self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
 
            tokens = tokenizer("This is a sample input", return_tensors="pt")
            outputs = model(**tokens)
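To run the suites touched by this commit locally, something like the following should work (the -k selectors are taken from test names in this diff):

import subprocess

subprocess.run(
    "python -m pytest tests/openvino/test_quantization.py -k static_quantization",
    shell=True,
    check=True,
)
subprocess.run(
    "python -m pytest tests/openvino/test_exporters_cli.py -k full_quantization",
    shell=True,
    check=True,
)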
