ammo.py (forked from NVIDIA/TensorRT-LLM)
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Dict, Literal, Optional, Union

import torch
from torch.utils.data import DataLoader

try:
    import ammo.torch.quantization as atq
    from ammo.torch.export import export_model_config
except ImportError:
    raise ImportError("AMMO toolkit is not installed. Please install it first.")

from ...logger import logger


def _register_falcon_linears(model):
    """Register Falcon linear modules for quantization.

    Falcon models may rely on remote code that is loaded dynamically when the
    model is built, so the linear layers have to be registered on the fly
    before quantization.
    """
    if type(model).__name__ in ["RWForCausalLM", "FalconForCausalLM"]:
        from ammo.torch.quantization import tensor_quant
        from ammo.torch.quantization.nn.modules.quant_module import \
            QuantLinearConvBase

        linear_type = type(model.transformer.h[0].self_attention.dense)

        class QuantFalconLinearRW1B(linear_type,
                                    QuantLinearConvBase):  # type: ignore
            default_quant_desc_weight = tensor_quant.QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW

        atq.module_mapping.QUANT_MODULE_MAPPING[
            linear_type] = QuantFalconLinearRW1B.convert
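
# For illustration only (a sketch, not part of the original module): once a
# Falcon checkpoint has been registered, the dynamically loaded linear class
# used by its attention blocks appears in AMMO's module mapping, which is how
# atq.quantize() later knows to swap those layers for quantized versions.
#
#     _register_falcon_linears(model)
#     falcon_linear = type(model.transformer.h[0].self_attention.dense)
#     assert falcon_linear in atq.module_mapping.QUANT_MODULE_MAPPING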


def _quantize_model(model: torch.nn.Module,
                    qformat: Literal['fp8', 'int8_sq', 'int4_awq'],
                    calib_dataloader: DataLoader,
                    quant_cfg_dict: Optional[Dict] = None) -> torch.nn.Module:
    assert qformat in ['fp8', 'int8_sq', 'int4_awq'], \
        f'Got unsupported AMMO quantization format, {qformat}'

    if qformat == "fp8":
        quant_cfg = atq.FP8_DEFAULT_CFG
    elif qformat == "int8_sq":
        quant_cfg = atq.INT8_SMOOTHQUANT_CFG
    elif qformat == "int4_awq":
        quant_cfg = atq.INT4_AWQ_CFG
        # AMMO 0.5.0 disables lm_head quantization by default, remove the filter
        if "*lm_head*" in quant_cfg["quant_cfg"]:
            del quant_cfg["quant_cfg"]["*lm_head*"]
    else:
        raise ValueError(f"Unsupported quantization format: {qformat}")

    if quant_cfg_dict:
        for name, cfg in quant_cfg_dict.items():
            quant_cfg['quant_cfg'][name] = cfg

    def calibrate_loop():
        """Adjusts weights and scaling factors based on selected algorithms."""
        for idx, data in enumerate(calib_dataloader):
            logger.debug(f"Calibrating batch {idx}")
            model(data)

    _register_falcon_linears(model)

    logger.debug("Starting quantization...")
    atq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    logger.debug("Quantization done")
    return model
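
# For illustration only (a sketch, not part of the original module): the
# quant_cfg_dict argument patches individual wildcard entries of the selected
# AMMO config before calibration. For example, a hypothetical override that
# keeps lm_head unquantized (the same filter the int4_awq branch removes):
#
#     _quantize_model(model, "fp8", calib_dataloader,
#                     quant_cfg_dict={"*lm_head*": {"enable": False}})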


def quantize_and_export(
        model: torch.nn.Module,
        qformat: Literal['fp8', 'int8_sq', 'int4_awq'],
        calib_dataloader: DataLoader,
        export_path: Optional[Union[str, Path]] = None,
        tensor_parallel_size: int = 1,
        quant_cfg_dict: Optional[Dict] = None) -> torch.nn.Module:
    model_cls_name = type(model).__name__
    # Map substrings of the model class name to TensorRT-LLM model types.
    model_lookup = {
        ("llama", "mistral"): "llama",
        ("gptj", ): "gptj",
        ("falcon", "rw"): "falcon",
        ("baichuan", ): "baichuan",
        ("mpt", ): "mpt",
        ("gpt2", ): "gpt2",
        ("chatglm", ): "chatglm",
        ("qwen", ): "qwen",
    }
    for templates, model_type_target in model_lookup.items():
        if any(t in model_cls_name.lower() for t in templates):
            model_type = model_type_target
            break
    else:
        raise NotImplementedError(
            f"Deploying quantized model {model_cls_name} is not supported")

    model = _quantize_model(model,
                            qformat=qformat,
                            calib_dataloader=calib_dataloader,
                            quant_cfg_dict=quant_cfg_dict)

    if export_path:
        with torch.inference_mode():
            # ChatGLM checkpoints, and Qwen checkpoints quantized with INT4
            # AWQ, are saved as plain state dicts; other models are exported
            # via export_model_config.
            if (qformat == "int4_awq" and model_type == "qwen") \
                    or model_type == "chatglm":
                torch.save(model.state_dict(), export_path)
            else:
                export_model_config(
                    model,
                    model_type,
                    torch.float16,
                    export_dir=export_path,
                    inference_tensor_parallel=tensor_parallel_size,
                )
        logger.info(f"Quantized model exported to: {export_path}")
    return model
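

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module: quantize a
    # hypothetical Hugging Face checkpoint to FP8 and export it. The model id,
    # prompts, and export path are illustrative assumptions; real calibration
    # typically uses a few hundred representative samples.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "meta-llama/Llama-2-7b-hf"  # hypothetical model id
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    hf_model = AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype=torch.float16).cuda().eval()

    prompts = [
        "TensorRT-LLM accelerates large language model inference.",
        "Calibration data should resemble the deployment workload.",
    ]
    # calibrate_loop() passes each batch directly to model(), so the loader
    # must yield input_ids tensors that already live on the model's device.
    samples = [
        tokenizer(p, return_tensors="pt").input_ids[0].cuda() for p in prompts
    ]
    calib_loader = DataLoader(samples, batch_size=1)

    quantize_and_export(hf_model,
                        qformat="fp8",
                        calib_dataloader=calib_loader,
                        export_path="./llama-7b-fp8",
                        tensor_parallel_size=1)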