diff --git a/llmc/compression/blockwise_optimization.py b/llmc/compression/blockwise_optimization.py
index 19d275f1..c2210722 100644
--- a/llmc/compression/blockwise_optimization.py
+++ b/llmc/compression/blockwise_optimization.py
@@ -59,11 +59,17 @@ def cache_input_hook(self, m, x, y, name, feat_dict):
         else:
             feat_dict[name].append(tuple(inputs))
 
     def kv_cache_input_hook(self):
         def hook_fn(module, args, kwargs):
             kvcache = getattr(module, 'kvcache')
             kwargs['past_key_value'] = kvcache
-            kwargs['use_cache'] = False
+            kwargs['use_cache'] = True
+            if kwargs["hidden_states"].shape[1] == 1:
+                kwargs["position_ids"] = kwargs["position_ids"][:, -1].unsqueeze(0).unsqueeze(0)
+                cos = kwargs["position_embeddings"][0][:, -1, :].unsqueeze(1)
+                sin = kwargs["position_embeddings"][1][:, -1, :].unsqueeze(1)
+                kwargs["position_embeddings"] = (cos, sin)
+
             return args, kwargs
 
         return hook_fn
diff --git a/llmc/compression/quantization/kvquant.py b/llmc/compression/quantization/kvquant.py
index 5a56d173..0a247c73 100644
--- a/llmc/compression/quantization/kvquant.py
+++ b/llmc/compression/quantization/kvquant.py
@@ -12,6 +12,7 @@
 class NaiveQuantKVCache(DynamicCache):
     def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples, bsz):
         super().__init__()
+        assert kvquant_cfg.granularity in ['per_token', 'per_tensor', 'per_group']
         self.num_hidden_layers, self.num_samples, self.bsz = (
             num_hidden_layers,
             num_samples,
@@ -23,7 +24,6 @@ def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples, bsz):
             self.kvquantizer = FloatQuantizer(**kvquant_cfg)
 
         self.kvquant_cfg = kvquant_cfg
-        assert self.kvquant_cfg.granularity in ['per_token', 'per_tensor', 'per_group']
         self.static = kvquant_cfg.get('static', False)
         self._quantized_key_cache = []
         self._quantized_value_cache = []
diff --git a/llmc/compression/quantization/module_utils.py b/llmc/compression/quantization/module_utils.py
index 4fc74bda..5a51dc5c 100644
--- a/llmc/compression/quantization/module_utils.py
+++ b/llmc/compression/quantization/module_utils.py
@@ -634,23 +634,13 @@ def __init__(self, weight, bias, ori_module):
         self.rotater = ori_module.rotater
 
     @torch.no_grad()
-    def forward(self, x, dtype=None):
+    def forward(self, x):
         if hasattr(self, 'buf_rotate') and self.buf_rotate:
             x = self.rotater.rotate(x)
 
-        org_dtype = self.weight.data.dtype
-        if dtype is not None:
-            self.convert_dtype(dtype)
-
         x = torch.functional.F.linear(x, self.weight, self.bias)
-        self.convert_dtype(org_dtype)
         return x
 
-    def convert_dtype(self, dtype):
-        self.weight.data = self.weight.data.to(dtype)
-        if self.bias is not None:
-            self.bias.data = self.bias.data.to(dtype)
-
     @classmethod
     @torch.no_grad()
     def new(cls, module):
@@ -829,7 +819,7 @@ def __init__(self, weight, bias, ori_module, w_qdq, a_qdq):
         self.dynamic_quant_weight = False
         self.dynamic_quant_tmp_weight = False
 
-    def forward(self, x, dtype=None):
+    def forward(self, x):
         if hasattr(self, 'buf_rotate') and self.buf_rotate:
             x = self.rotater.rotate(x)
 
@@ -848,20 +838,9 @@ def forward(self, x, dtype=None):
         elif self.dynamic_quant_tmp_weight:
             self.tmp_weight = self.w_qdq(self)
 
-        org_dtype = self.tmp_weight.data.dtype
-        if dtype is not None:
-            self.convert_dtype(dtype)
-
         x = torch.functional.F.linear(x, self.tmp_weight, self.tmp_bias)
-
-        self.convert_dtype(org_dtype)
         return x
 
-    def convert_dtype(self, dtype):
-        self.tmp_weight.data = self.tmp_weight.data.to(dtype)
-        if self.tmp_bias is not None:
-            self.tmp_bias.data = self.tmp_bias.data.to(dtype)
-
     @classmethod
     @torch.no_grad()
     def new(cls, module, w_qdq, a_qdq):
@@ -924,19 +903,9 @@ def forward(self, x, dtype=None):
         if self.a_qdq is not None:
             x = self.a_qdq(x, self)
 
-        org_dtype = self.weight.data.dtype
-        if dtype is not None:
-            self.convert_dtype(dtype)
-
         x = torch.functional.F.linear(x, self.weight, self.bias)
-        self.convert_dtype(org_dtype)
         return x
 
-    def convert_dtype(self, dtype):
-        self.weight.data = self.weight.data.to(dtype)
-        if self.bias is not None:
-            self.bias.data = self.bias.data.to(dtype)
-
     @classmethod
     @torch.no_grad()
     def new(cls, module, w_qdq, a_qdq, debug_print={}):
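
Review note: the rewritten `kv_cache_input_hook` depends on PyTorch's kwargs-aware forward-pre-hook API, so the sketch below shows how such a hook is typically attached during block-wise calibration. The helper name `attach_kv_cache_hooks` and the `blocks`/`kvcache` arguments are illustrative assumptions, not the exact call sites touched by this patch; the essential detail is `with_kwargs=True`, without which `hook_fn(module, args, kwargs)` neither receives the keyword arguments it rewrites (`past_key_value`, `use_cache`, `position_ids`, `position_embeddings`) nor has its returned `(args, kwargs)` applied.

```python
# Hypothetical helper, not part of this patch: register the pre-hook returned by
# kv_cache_input_hook() on each decoder block before feeding calibration data.
def attach_kv_cache_hooks(blockwise_opt, blocks, kvcache):
    handles = []
    for block in blocks:
        block.kvcache = kvcache  # hook_fn looks this up via getattr(module, 'kvcache')
        handles.append(
            block.register_forward_pre_hook(
                blockwise_opt.kv_cache_input_hook(),
                with_kwargs=True,  # pass kwargs in and honor the returned (args, kwargs)
            )
        )
    return handles  # caller removes each handle via handle.remove() after calibration
```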