Fix kvquant decode bugs #237

Merged 2 commits on Nov 29, 2024

llmc/compression/blockwise_optimization.py (7 additions, 1 deletion)

@@ -63,7 +63,13 @@ def kv_cache_input_hook(self):
         def hook_fn(module, args, kwargs):
             kvcache = getattr(module, 'kvcache')
             kwargs['past_key_value'] = kvcache
-            kwargs['use_cache'] = False
+            kwargs['use_cache'] = True
+            if kwargs['hidden_states'].shape[1] == 1:
+                kwargs['position_ids'] = kwargs['position_ids'][:, -1].unsqueeze(0).unsqueeze(0)
+                cos = kwargs['position_embeddings'][0][:, -1, :].unsqueeze(1)
+                sin = kwargs['position_embeddings'][1][:, -1, :].unsqueeze(1)
+                kwargs['position_embeddings'] = (cos, sin)
+
             return args, kwargs

         return hook_fn
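For context: the hunk switches `use_cache` on so the attached quantized cache actually accumulates keys and values, and when only a single decode token reaches the layer (`hidden_states.shape[1] == 1`) it trims `position_ids` and the rotary `cos`/`sin` tables down to the last position so their shapes match the one-token input. The snippet below is a minimal standalone sketch of that slicing wired into a kwargs-aware forward pre-hook on a dummy module; the module, tensor shapes, and the simpler `[:, -1:]` slicing are assumptions for illustration, not the llmc code.

```python
import torch
import torch.nn as nn

def make_decode_hook():
    # Kwargs-aware pre-hook: when only one decode token arrives, keep only the
    # last position of position_ids and of the rotary cos/sin tables.
    def hook_fn(module, args, kwargs):
        if kwargs['hidden_states'].shape[1] == 1:
            kwargs['position_ids'] = kwargs['position_ids'][:, -1:]
            cos, sin = kwargs['position_embeddings']
            kwargs['position_embeddings'] = (cos[:, -1:, :], sin[:, -1:, :])
        return args, kwargs
    return hook_fn

class DummyLayer(nn.Module):
    # Stand-in for a decoder layer; just reports the shapes it receives.
    def forward(self, hidden_states=None, position_ids=None, position_embeddings=None):
        cos, sin = position_embeddings
        return hidden_states.shape, position_ids.shape, cos.shape, sin.shape

layer = DummyLayer()
layer.register_forward_pre_hook(make_decode_hook(), with_kwargs=True)

bsz, seq, dim = 1, 5, 8
print(layer(
    hidden_states=torch.randn(bsz, 1, dim),           # one decode-step token
    position_ids=torch.arange(seq).unsqueeze(0),      # positions for the full sequence
    position_embeddings=(torch.randn(bsz, seq, dim),  # cos table, all positions
                         torch.randn(bsz, seq, dim)), # sin table, all positions
))
# (torch.Size([1, 1, 8]), torch.Size([1, 1]), torch.Size([1, 1, 8]), torch.Size([1, 1, 8]))
```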
llmc/compression/quantization/kvquant.py (1 addition, 1 deletion)

@@ -12,6 +12,7 @@ class NaiveQuantKVCache(DynamicCache):
     def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples, bsz):
         super().__init__()

+        assert kvquant_cfg.granularity in ['per_token', 'per_tensor', 'per_group']
         self.num_hidden_layers, self.num_samples, self.bsz = (
             num_hidden_layers,
             num_samples,

@@ -23,7 +24,6 @@ def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples, bsz)
             self.kvquantizer = FloatQuantizer(**kvquant_cfg)

         self.kvquant_cfg = kvquant_cfg
-        assert self.kvquant_cfg.granularity in ['per_token', 'per_tensor', 'per_group']
         self.static = kvquant_cfg.get('static', False)
         self._quantized_key_cache = []
         self._quantized_value_cache = []
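Moving the `granularity` assertion ahead of quantizer construction makes an unsupported value fail before the quantizer (e.g. `FloatQuantizer`) is ever built. As a rough illustration of what the three allowed granularities mean for a KV tensor of shape `[batch, heads, seq, head_dim]`, here is a standalone int8 fake-quantization sketch; the function name and the group size of 64 are made up for the example and are unrelated to llmc's quantizers.

```python
import torch

def fake_quant_kv_int8(x, granularity='per_token', group_size=64):
    # Symmetric int8 fake-quantization of a KV tensor [batch, heads, seq, head_dim].
    # Illustration of granularity only; not llmc's quantizer implementation.
    if granularity == 'per_tensor':
        scale = x.abs().amax().clamp(min=1e-8) / 127.0                       # one scale for the whole tensor
    elif granularity == 'per_token':
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0   # one scale per token position
    elif granularity == 'per_group':
        b, h, s, d = x.shape
        xg = x.reshape(b, h, s, d // group_size, group_size)                 # split head_dim into groups
        scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0  # one scale per group
        return (torch.round(xg / scale).clamp(-128, 127) * scale).reshape(b, h, s, d)
    else:
        raise ValueError(f'unsupported granularity: {granularity}')
    return torch.round(x / scale).clamp(-128, 127) * scale

k = torch.randn(1, 8, 16, 128)
for g in ['per_token', 'per_tensor', 'per_group']:
    print(g, (fake_quant_kv_int8(k, g) - k).abs().mean().item())
```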
llmc/compression/quantization/module_utils.py (2 additions, 33 deletions)

@@ -634,23 +634,13 @@ def __init__(self, weight, bias, ori_module):
         self.rotater = ori_module.rotater

     @torch.no_grad()
-    def forward(self, x, dtype=None):
+    def forward(self, x):
         if hasattr(self, 'buf_rotate') and self.buf_rotate:
             x = self.rotater.rotate(x)

-        org_dtype = self.weight.data.dtype
-        if dtype is not None:
-            self.convert_dtype(dtype)
-
         x = torch.functional.F.linear(x, self.weight, self.bias)
-        self.convert_dtype(org_dtype)
         return x

-    def convert_dtype(self, dtype):
-        self.weight.data = self.weight.data.to(dtype)
-        if self.bias is not None:
-            self.bias.data = self.bias.data.to(dtype)
-
     @classmethod
     @torch.no_grad()
     def new(cls, module):

@@ -829,7 +819,7 @@ def __init__(self, weight, bias, ori_module, w_qdq, a_qdq):
         self.dynamic_quant_weight = False
         self.dynamic_quant_tmp_weight = False

-    def forward(self, x, dtype=None):
+    def forward(self, x):
         if hasattr(self, 'buf_rotate') and self.buf_rotate:
             x = self.rotater.rotate(x)

@@ -848,20 +838,9 @@ def forward(self, x, dtype=None):
         elif self.dynamic_quant_tmp_weight:
             self.tmp_weight = self.w_qdq(self)

-        org_dtype = self.tmp_weight.data.dtype
-        if dtype is not None:
-            self.convert_dtype(dtype)
-
         x = torch.functional.F.linear(x, self.tmp_weight, self.tmp_bias)
-
-        self.convert_dtype(org_dtype)
         return x

-    def convert_dtype(self, dtype):
-        self.tmp_weight.data = self.tmp_weight.data.to(dtype)
-        if self.tmp_bias is not None:
-            self.tmp_bias.data = self.tmp_bias.data.to(dtype)
-
     @classmethod
     @torch.no_grad()
     def new(cls, module, w_qdq, a_qdq):

@@ -924,19 +903,9 @@ def forward(self, x, dtype=None):
         if self.a_qdq is not None:
             x = self.a_qdq(x, self)

-        org_dtype = self.weight.data.dtype
-        if dtype is not None:
-            self.convert_dtype(dtype)
-
         x = torch.functional.F.linear(x, self.weight, self.bias)
-        self.convert_dtype(org_dtype)
         return x

-    def convert_dtype(self, dtype):
-        self.weight.data = self.weight.data.to(dtype)
-        if self.bias is not None:
-            self.bias.data = self.bias.data.to(dtype)
-
     @classmethod
     @torch.no_grad()
     def new(cls, module, w_qdq, a_qdq, debug_print={}):
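With the `dtype` argument and the `convert_dtype` round-trip removed, these forwards reduce to optional rotation, optional activation quant-dequant, and a plain `F.linear` in the module's own parameter dtype. The sketch below shows that simplified shape with assumed names (`SketchFakeQuantLinear`, and `a_qdq`/`rotater` as stand-in interfaces); it is an illustration under those assumptions, not the llmc module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SketchFakeQuantLinear(nn.Module):
    # Assumed, simplified stand-in for the linear wrappers touched above.
    def __init__(self, weight, bias=None, a_qdq=None, rotater=None):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.bias = nn.Parameter(bias, requires_grad=False) if bias is not None else None
        self.a_qdq = a_qdq        # optional callable: activation quant-dequant
        self.rotater = rotater    # optional object exposing .rotate(x)

    @torch.no_grad()
    def forward(self, x):
        if self.rotater is not None:
            x = self.rotater.rotate(x)
        if self.a_qdq is not None:
            x = self.a_qdq(x, self)
        # No dtype juggling: run the matmul in the weight's own dtype.
        return F.linear(x, self.weight, self.bias)

layer = SketchFakeQuantLinear(torch.randn(4, 8), torch.zeros(4))
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```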