Skip to content

Commit

Permalink
Fix bugs
Browse files (browse the repository at this point in the history)
  • Loading branch information
gushiqiao committed Nov 28, 2024
1 parent e2ec48c commit 8702ec2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 16 deletions.
11 changes: 6 additions & 5 deletions llmc/compression/quantization/base_blockwise_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,12 @@ def set_quant_config(self):
self.quant_config['weight']['tp'] = self.tp

# set model config
- self.hidden_size = self.model.model_config.hidden_size
- self.num_heads = self.model.model_config.num_attention_heads
- self.head_dim = self.hidden_size // self.num_heads
- self.intermediate_size = self.model.model_config.intermediate_size
- self.num_hidden_layers = self.model.model_config.num_hidden_layers
+ self.hidden_size = self.model.model_config.get('hidden_size')
+ self.num_heads = self.model.model_config.get('num_attention_heads')
+ self.head_dim = self.hidden_size // self.num_heads \
+ if self.hidden_size and self.num_heads else None
+ self.intermediate_size = self.model.model_config.get('intermediate_size')
+ self.num_hidden_layers = self.model.model_config.get('num_hidden_layers')

# select quant module
self.quant_type = self.quant_config.get('quant_type', 'int-quant')
Expand Down
11 changes: 0 additions & 11 deletions llmc/compression/quantization/kvquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,6 @@ def _calibration(self, layer_idx, key_states, value_states):
self.calib_key_cache[layer_idx].clear()
self.calib_value_cache[layer_idx].clear()

- # def _reshape_states(self, tensor):
- # batch_size, num_heads, seq_len, head_dim = tensor.shape
- # if self.kvquant_cfg.granularity == "per_group":
- # group_size = self.kvquant_cfg.group_size
- # assert head_dim % group_size == 0
- # tensor = tensor.reshape(batch_size, -1, group_size)
- # return tensor
-
- # def _restore_states(self, tensor, org_shape):
- # return tensor.view(org_shape)
-
def _quantize(self, tensor, layer_idx, is_key):
org_shape = tensor.shape
tensor = self.kvquantizer.reshape_tensor(tensor)
Expand Down

0 comments on commit 8702ec2

Please sign in to comment.