
Commit

Fix kvquant decode bugs
gushiqiao committed Nov 29, 2024
1 parent 3cb7b47 commit 68d9dc4
Showing 3 changed files with 11 additions and 36 deletions.
llmc/compression/blockwise_optimization.py (8 additions, 2 deletions)
@@ -59,11 +59,17 @@ def cache_input_hook(self, m, x, y, name, feat_dict):
         else:
             feat_dict[name].append(tuple(inputs))
 
-    def kv_cache_input_hook(self):
+    def kv_cache_input_hook():
         def hook_fn(module, args, kwargs):
             kvcache = getattr(module, 'kvcache')
             kwargs['past_key_value'] = kvcache
-            kwargs['use_cache'] = False
+            kwargs['use_cache'] = True
+            if kwargs["hidden_states"].shape[1] == 1:
+                kwargs["position_ids"] = kwargs["position_ids"][:, -1].unsqueeze(0).unsqueeze(0)
+                cos = kwargs["position_embeddings"][0][:, -1, :].unsqueeze(1)
+                sin = kwargs["position_embeddings"][1][:, -1, :].unsqueeze(1)
+                kwargs["position_embeddings"] = (cos, sin)
+
             return args, kwargs
 
         return hook_fn
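The hook now sets use_cache to True and adds a decode-path branch: when hidden_states carries a single token (sequence length 1), only the last position's id and rotary embedding are forwarded. A minimal standalone sketch of that slicing, assuming batch size 1 and made-up prefill length and head_dim (the real values come from the model and calibration data):

import torch

# Assumed shapes for illustration only: batch 1, 16 prefilled positions, head_dim 64.
bsz, seq_len, head_dim = 1, 16, 64
position_ids = torch.arange(seq_len).unsqueeze(0)         # (1, seq_len)
cos = torch.randn(bsz, seq_len, head_dim)                 # rotary cos table
sin = torch.randn(bsz, seq_len, head_dim)                 # rotary sin table

# Decode step: hidden_states is (bsz, 1, hidden_size), so keep only the last
# position, mirroring the indexing done in hook_fn above.
last_pos = position_ids[:, -1].unsqueeze(0).unsqueeze(0)  # (1, 1, 1)
cos_last = cos[:, -1, :].unsqueeze(1)                     # (1, 1, head_dim)
sin_last = sin[:, -1, :].unsqueeze(1)                     # (1, 1, head_dim)
print(last_pos.shape, cos_last.shape, sin_last.shape)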
llmc/compression/quantization/kvquant.py (1 addition, 1 deletion)
@@ -12,6 +12,7 @@ class NaiveQuantKVCache(DynamicCache):
     def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples, bsz):
         super().__init__()
 
+        assert kvquant_cfg.granularity in ['per_token', 'per_tensor', 'per_group']
         self.num_hidden_layers, self.num_samples, self.bsz = (
             num_hidden_layers,
             num_samples,
@@ -23,7 +24,6 @@ def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples, bsz)
             self.kvquantizer = FloatQuantizer(**kvquant_cfg)
 
         self.kvquant_cfg = kvquant_cfg
-        assert self.kvquant_cfg.granularity in ['per_token', 'per_tensor', 'per_group']
         self.static = kvquant_cfg.get('static', False)
         self._quantized_key_cache = []
         self._quantized_value_cache = []
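The granularity check is hoisted above the point where the quantizer is built from kvquant_cfg, so an unsupported value now fails before FloatQuantizer(**kvquant_cfg) is ever constructed. A standalone sketch of that fail-fast ordering (the config dict, field names, and the 'per_channel' value below are made up for illustration; the real kvquant_cfg comes from the llmc configuration):

SUPPORTED_GRANULARITY = ['per_token', 'per_tensor', 'per_group']

def build_kv_quantizer(cfg):
    # Validate first, as the moved assert does: reject the config before
    # any quantizer object is constructed from it.
    assert cfg['granularity'] in SUPPORTED_GRANULARITY, cfg['granularity']
    return ('quantizer', cfg['granularity'])  # stand-in for FloatQuantizer(**cfg)

try:
    build_kv_quantizer({'bit': 8, 'granularity': 'per_channel'})  # hypothetical bad config
except AssertionError as err:
    print('rejected before quantizer construction:', err)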
llmc/compression/quantization/module_utils.py (2 additions, 33 deletions)
@@ -634,23 +634,13 @@ def __init__(self, weight, bias, ori_module):
         self.rotater = ori_module.rotater
 
     @torch.no_grad()
-    def forward(self, x, dtype=None):
+    def forward(self, x):
         if hasattr(self, 'buf_rotate') and self.buf_rotate:
             x = self.rotater.rotate(x)
 
-        org_dtype = self.weight.data.dtype
-        if dtype is not None:
-            self.convert_dtype(dtype)
-
         x = torch.functional.F.linear(x, self.weight, self.bias)
-        self.convert_dtype(org_dtype)
         return x
 
-    def convert_dtype(self, dtype):
-        self.weight.data = self.weight.data.to(dtype)
-        if self.bias is not None:
-            self.bias.data = self.bias.data.to(dtype)
-
     @classmethod
     @torch.no_grad()
     def new(cls, module):
@@ -829,7 +819,7 @@ def __init__(self, weight, bias, ori_module, w_qdq, a_qdq):
         self.dynamic_quant_weight = False
         self.dynamic_quant_tmp_weight = False
 
-    def forward(self, x, dtype=None):
+    def forward(self, x):
         if hasattr(self, 'buf_rotate') and self.buf_rotate:
             x = self.rotater.rotate(x)
 
@@ -848,20 +838,9 @@ def forward(self, x, dtype=None):
         elif self.dynamic_quant_tmp_weight:
             self.tmp_weight = self.w_qdq(self)
 
-        org_dtype = self.tmp_weight.data.dtype
-        if dtype is not None:
-            self.convert_dtype(dtype)
-
         x = torch.functional.F.linear(x, self.tmp_weight, self.tmp_bias)
-
-        self.convert_dtype(org_dtype)
         return x
 
-    def convert_dtype(self, dtype):
-        self.tmp_weight.data = self.tmp_weight.data.to(dtype)
-        if self.tmp_bias is not None:
-            self.tmp_bias.data = self.tmp_bias.data.to(dtype)
-
     @classmethod
     @torch.no_grad()
     def new(cls, module, w_qdq, a_qdq):
@@ -924,19 +903,9 @@ def forward(self, x, dtype=None):
         if self.a_qdq is not None:
             x = self.a_qdq(x, self)
 
-        org_dtype = self.weight.data.dtype
-        if dtype is not None:
-            self.convert_dtype(dtype)
-
         x = torch.functional.F.linear(x, self.weight, self.bias)
-        self.convert_dtype(org_dtype)
         return x
 
-    def convert_dtype(self, dtype):
-        self.weight.data = self.weight.data.to(dtype)
-        if self.bias is not None:
-            self.bias.data = self.bias.data.to(dtype)
-
     @classmethod
     @torch.no_grad()
     def new(cls, module, w_qdq, a_qdq, debug_print={}):
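In each of the affected wrapper modules the removed pattern is the same: the optional dtype argument, the convert_dtype helper, and the casting before and after F.linear are gone, so forward now runs entirely in the module's stored dtype. A self-contained sketch of the simplified path (the class name and rotater stand-in below are hypothetical, for illustration only; the real classes live in module_utils.py):

import torch
import torch.nn.functional as F

class TinyRotateLinear(torch.nn.Module):
    # Stand-in for the simplified wrappers: optional rotation, then a plain linear.
    def __init__(self, weight, bias=None, rotater=None):
        super().__init__()
        self.weight = torch.nn.Parameter(weight)
        self.bias = bias
        self.rotater = rotater
        self.buf_rotate = rotater is not None

    @torch.no_grad()
    def forward(self, x):
        if self.buf_rotate:
            x = self.rotater.rotate(x)
        # No dtype round-trip: x is expected to already match self.weight.dtype.
        return F.linear(x, self.weight, self.bias)

layer = TinyRotateLinear(torch.randn(8, 16))
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 8])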
