Add KV cache quantization.
gushiqiao committed Nov 28, 2024
1 parent e9b83ce commit e2ec48c
Showing 15 changed files with 477 additions and 38 deletions.
37 changes: 37 additions & 0 deletions configs/quantization/methods/RTN/rtn_w_a_kv.yml
@@ -0,0 +1,37 @@
base:
    seed: &seed 42
model:
    type: model_type
    path: model path
    torch_dtype: auto
eval:
    eval_pos: [pretrain, fake_quant]
    name: wikitext2
    download: False
    path: eval data path
    seq_len: 2048
    # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
    # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
    bs: 1
    inference_per_block: False
    # Consistency of tokens between original and fake-quantized model output.
    eval_token_consist: True
quant:
    method: RTN
    weight:
        bit: 8
        symmetric: True
        granularity: per_channel
        group_size: -1
    act:
        bit: 8
        symmetric: True
        granularity: per_token
    kvcache:
        method: Naive
        bit: 8
        symmetric: True
        granularity: per_token
save:
    save_fake: False
    save_path: /path/to/save/
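
As a side note on how the quant.kvcache block of this file surfaces at runtime, a minimal standalone sketch of reading it with PyYAML is shown below. The variable names here are invented; also note that the set_quant_config change further down injects a 'static' key into this dict from the activation settings, so the YAML itself does not carry one.

import yaml

# Minimal sketch (not llmc's own config loader): read the kvcache section.
with open('configs/quantization/methods/RTN/rtn_w_a_kv.yml') as f:
    cfg = yaml.safe_load(f)

kv_cfg = cfg['quant']['kvcache']
print(kv_cfg)  # {'method': 'Naive', 'bit': 8, 'symmetric': True, 'granularity': 'per_token'}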
47 changes: 47 additions & 0 deletions configs/quantization/methods/RTN/rtn_w_a_pertensor_static_kv.yml
@@ -0,0 +1,47 @@
base:
    seed: &seed 42
model:
    type: model_type
    path: model path
    torch_dtype: auto
calib:
    name: pileval
    download: False
    path: calib data path
    n_samples: 128
    bs: 1
    seq_len: 2048
    preproc: general
    seed: *seed
eval:
    eval_pos: [pretrain, fake_quant]
    name: wikitext2
    download: False
    path: eval data path
    seq_len: 2048
    # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
    # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
    bs: 1
    inference_per_block: False
    # Consistency of tokens between original and fake-quantized model output.
    eval_token_consist: True
quant:
    method: RTN
    weight:
        bit: 8
        symmetric: True
        granularity: per_channel
        group_size: -1
    act:
        bit: 8
        symmetric: True
        granularity: per_tensor
        static: True
    kvcache:
        method: Naive
        bit: 8
        symmetric: True
        granularity: per_tensor
save:
    save_fake: False
    save_path: /path/to/save/
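
The only difference from rtn_w_a_kv.yml above is the granularity (and the static flag): per_token quantizes keys/values with one scale per token computed on the fly, while per_tensor with static: True uses a single calibrated scale for the whole tensor. A rough standalone illustration of the two granularities for naive symmetric 8-bit fake quantization follows; it is not the repository's NaiveQuantKVCache implementation, and fake_quant_sym_int8 is an invented helper name.

import torch

def fake_quant_sym_int8(x, granularity):
    # Symmetric 8-bit fake quantization of a KV tensor shaped
    # (batch, heads, seq_len, head_dim). Illustration only.
    qmax = 127
    if granularity == 'per_token':
        # One scale per token (per head here): reduce over head_dim, no calibration needed.
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    elif granularity == 'per_tensor':
        # One scale for the whole tensor; with `static: True` this scale would
        # come from calibration statistics rather than the current batch.
        scale = x.abs().amax().clamp(min=1e-8) / qmax
    else:
        raise ValueError(granularity)
    q = torch.clamp(torch.round(x / scale), -qmax - 1, qmax)
    return q * scale  # dequantized ("fake quant") values

k = torch.randn(1, 32, 16, 128)                    # toy key states
k_dynamic = fake_quant_sym_int8(k, 'per_token')
k_static_like = fake_quant_sym_int8(k, 'per_tensor')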
9 changes: 9 additions & 0 deletions llmc/compression/blockwise_optimization.py
@@ -59,6 +59,15 @@ def cache_input_hook(self, m, x, y, name, feat_dict):
            else:
                feat_dict[name].append(tuple(inputs))

    def kv_cache_input_hook(self):
        def hook_fn(module, args, kwargs):
            kvcache = getattr(module, 'kvcache')
            kwargs['past_key_value'] = kvcache
            kwargs['use_cache'] = False
            return args, kwargs

        return hook_fn

    @abstractmethod
    def block_opt(self, block):
        pass
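
The hook above relies on PyTorch's keyword-aware forward pre-hooks: a hook registered with with_kwargs=True receives (module, args, kwargs) and may return a replacement (args, kwargs) pair, so the quantized cache attached to the attention layer can be slotted in as past_key_value before the layer's forward runs. A toy, self-contained example of that mechanism (ToyAttention is invented; the kvcache attribute and the rewritten kwargs mirror the hook above):

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def forward(self, hidden_states, past_key_value=None, use_cache=True):
        # A real attention layer would read from / write to past_key_value here.
        print('past_key_value:', type(past_key_value).__name__, '| use_cache:', use_cache)
        return hidden_states

attn = ToyAttention()
attn.kvcache = object()  # stand-in for the quantized KV-cache module

def hook_fn(module, args, kwargs):
    kwargs['past_key_value'] = module.kvcache
    kwargs['use_cache'] = False
    return args, kwargs

attn.register_forward_pre_hook(hook_fn, with_kwargs=True)
attn(torch.randn(1, 4, 8), use_cache=True)  # kwargs are rewritten before forward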
1 change: 1 addition & 0 deletions llmc/compression/quantization/__init__.py
@@ -4,6 +4,7 @@
from .dgq import DGQ
from .gptq import GPTQ
from .hqq import HQQ
from .kvquant import NaiveQuantKVCache
from .llmint8 import LlmInt8
from .module_utils import FakeQuantLinear
from .ntweak import NormTweaking
50 changes: 44 additions & 6 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -12,6 +12,7 @@
from loguru import logger

from llmc.utils import copy_files
from llmc.utils.registry_factory import KV_REGISTRY

from ..blockwise_optimization import BlockwiseOpt
from .auto_clip import AutoClipper
@@ -176,6 +177,13 @@ def set_quant_config(self):
        self.tp = self.quant_config.get('tp', 1)
        self.quant_config['weight']['tp'] = self.tp

        # set model config
        self.hidden_size = self.model.model_config.hidden_size
        self.num_heads = self.model.model_config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.intermediate_size = self.model.model_config.intermediate_size
        self.num_hidden_layers = self.model.model_config.num_hidden_layers

        # select quant module
        self.quant_type = self.quant_config.get('quant_type', 'int-quant')
        if self.quant_type == 'int-quant':
@@ -227,6 +235,19 @@ def set_quant_config(self):
                f'{json.dumps(self.mix_bits_map, ensure_ascii=False, indent=4)}'
            )

        # set kv cache quant config
        if 'kvcache' in self.quant_config:
            self.quant_config['kvcache']['static'] = self.act_static
            self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
                self.quant_type, self.quant_config['kvcache'],
                self.num_hidden_layers, self.config.calib.n_samples,
                self.config.calib.bs
            )
            self.quant_kvcache = True
            self.model.kvcache_buffer.append(self.kv_module)
        else:
            self.quant_kvcache = False

        # set special quant config
        special_config = self.quant_config.get('special', {})
        self.true_sequential = special_config.get('true_sequential', False)
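
KV_REGISTRY here follows the repository's registry-factory pattern (llmc/utils/registry_factory.py, not shown in this view): kvquant.py registers NaiveQuantKVCache under a name, and set_quant_config instantiates whichever class the config's kvcache.method selects. A minimal sketch of that pattern, not the actual Registry implementation:

class Registry:
    # Minimal name -> class mapping with decorator-based registration (sketch only).
    def __init__(self):
        self._classes = {}

    def register(self, name):
        def _wrap(cls):
            self._classes[name] = cls
            return cls
        return _wrap

    def __getitem__(self, name):
        return self._classes[name]

KV_REGISTRY = Registry()

@KV_REGISTRY.register('Naive')
class NaiveKVCacheSketch:
    pass

kv_cls = KV_REGISTRY['Naive']  # `method: Naive` in the YAML selects this class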
@@ -263,12 +284,6 @@ def set_quant_config(self):
        self.online_rotate = special_config.get('online_rotate', False)
        if self.online_rotate:
            assert self.config['model']['type'] in ['Opt', 'Llama']

        self.hidden_size = self.model.model_config.hidden_size
        if self.online_rotate:
            self.num_heads = self.model.model_config.num_attention_heads
            self.head_dim = self.hidden_size // self.num_heads
            self.intermediate_size = self.model.model_config.intermediate_size
        self.fp32_had = special_config.get('fp32_had', False)

        self.quant_objects = self.quant_config.get('quant_objects', ['language'])
@@ -372,6 +387,10 @@ def block_forward(self, block, input_data=None):
        return output

    def block_opt(self, block):

        if self.quant_kvcache:
            self.register_kv_cache(block)

        block = block.cuda()
        named_linears = self.model.get_block_linears(block)
        extra_modules = self.model.get_extra_modules(block)
@@ -512,6 +531,15 @@ def collect_layers_weights(self, layers, tensor_parallelize_style=None):
            weights.append(_m.weight)
        return weights

    @torch.no_grad()
    def register_kv_cache(self, block):
        attn_layers_dict = self.model.get_attn_in_block(block)
        attn_layer = attn_layers_dict[list(attn_layers_dict.keys())[0]]
        setattr(attn_layer, 'kvcache', self.kv_module)
        attn_layer.register_forward_pre_hook(
            self.kv_cache_input_hook(), with_kwargs=True
        )

    @torch.no_grad()
    def register_non_linear_qparams(self, block, input_feat):
        layer_types = [
@@ -780,6 +808,8 @@ def set_non_linear_mode(self, quant_format, module, mode):
        if quant_format != 'fake_quant':
            return
        for name, m in module.named_modules():
            if 'kvcache' in name:
                continue
            if getattr(m, 'calib', None) is not None:
                m.calib = mode

@@ -814,6 +844,14 @@ def deploy(self, quant_format, keep_device=False):
        )
        self.set_non_linear_mode(quant_format, self.model.model, False)

        if self.quant_kvcache:
            if quant_format == 'transformed':
                self.kv_module.transformed = True
            elif quant_format == 'fake_quant':
                self.kv_module.transformed = False
                if self.act_static:
                    self.kv_module.calib = False

        if self.model.vlm_model is not None:
            logger.info(f'Now, the vlm_model is: {self.model.vlm_model}')

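Taken together, these hunks imply a rough contract for the object KV_REGISTRY constructs: it is built from the kvcache config plus the layer count and calibration sample/batch counts, it is attached to each block's attention layer and passed in as past_key_value, and deploy() toggles its transformed and calib flags. The kvquant.py file itself is not shown in this truncated view; the skeleton below is only one reading of that contract with invented method bodies, not the actual NaiveQuantKVCache code.

class KVCacheQuantizerSketch:
    # Assumed interface, inferred from how kv_module is used in the hunks above.
    def __init__(self, quant_type, kv_quant_config, num_hidden_layers, n_samples, bs):
        self.bit = kv_quant_config['bit']
        self.static = kv_quant_config.get('static', False)
        self.num_hidden_layers = num_hidden_layers
        self.calib = True         # collect scales while calibration batches run
        self.transformed = False  # deploy('transformed') keeps values unquantized

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # Since the hook passes this object as past_key_value, a HuggingFace
        # Cache-style update() is the natural place to (fake-)quantize incoming
        # keys/values before returning them to the attention layer. (Assumption.)
        raise NotImplementedError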
