Merge pull request #31 from ModelTC/qwen2_quarot
Support QuaRot for Qwen2
llmc-reviewer authored Aug 19, 2024
2 parents e8a9a7a + 3faa696 commit a449682
Showing 4 changed files with 29 additions and 5 deletions.
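For context: QuaRot quantizes a model after multiplying its weights by a random orthogonal matrix, which suppresses activation outliers while leaving the network's function unchanged. The identity it relies on can be checked in a few lines (a generic sketch, not llmc code; any orthogonal Q works, QuaRot specifically favors randomized Hadamard matrices):

```python
import torch

torch.manual_seed(0)
W = torch.randn(8, 8)   # weight of a linear layer computing y = x @ W.T
x = torch.randn(4, 8)   # a batch of activations

# Build an orthogonal matrix Q via QR decomposition.
Q, _ = torch.linalg.qr(torch.randn(8, 8))

# Rotating the weight by Q and counter-rotating the input is an exact identity:
# (x @ Q) @ (W @ Q).T == x @ Q @ Q.T @ W.T == x @ W.T
y_ref = x @ W.T
y_rot = (x @ Q) @ (W @ Q).T
assert torch.allclose(y_ref, y_rot, atol=1e-5)
```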
5 changes: 2 additions & 3 deletions llmc/compression/quantization/module_utils.py
@@ -8,10 +8,9 @@
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
 from transformers.models.mistral.modeling_mistral import MistralRMSNorm
 from transformers.models.mixtral.modeling_mixtral import MixtralRMSNorm
+from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 
-# from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
-
 try:
     import fast_hadamard_transform

@@ -615,7 +614,7 @@ def __repr__(self):
 _TRANSFORMERS_LN_TYPES_ = ALL_LAYERNORM_LAYERS + [
     MistralRMSNorm,
     MixtralRMSNorm,
-    # Qwen2RMSNorm,
+    Qwen2RMSNorm,
     LlamaRMSNorm,
     nn.LayerNorm,
 ]
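Registering Qwen2RMSNorm in _TRANSFORMERS_LN_TYPES_ lets the compression passes recognize Qwen2's normalization layers when they walk a model. A minimal sketch of how such a type registry is typically consumed (the helper below is illustrative, not llmc's actual API):

```python
import torch.nn as nn
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm

# Illustrative registry; llmc's _TRANSFORMERS_LN_TYPES_ also includes the
# Llama/Mistral/Mixtral norms and everything in ALL_LAYERNORM_LAYERS.
LN_TYPES = (Qwen2RMSNorm, nn.LayerNorm)

def find_norm_layers(model: nn.Module):
    # Collect (name, module) pairs whose type appears in the registry.
    return [(n, m) for n, m in model.named_modules() if isinstance(m, LN_TYPES)]
```

Without this registration, Qwen2's RMSNorm modules would be invisible to passes that fuse or replace normalization layers.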
10 changes: 9 additions & 1 deletion llmc/compression/quantization/quant.py
@@ -80,7 +80,15 @@ def get_minmax_range(self, tensor):
         return (min_val, max_val)
 
     def get_mse_range(self, tensor, grid=100, norm=2.4, maxshrink=0.8, bs=256):
-        assert tensor.shape[0] % bs == 0
+        # assert tensor.shape[0] % bs == 0
+        if tensor.shape[0] % bs != 0:
+            logger.warning(
+                'Tensor size is not a multiple of the batch size; '
+                'setting batch size to {}'.format(
+                    tensor.shape[0]
+                )
+            )
+            bs = tensor.shape[0]
         tensor = tensor.float()
         min_val, max_val = self.get_minmax_range(tensor)
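The change replaces a hard assertion with a graceful fallback: if the configured batch size does not divide the number of rows, the whole tensor is processed as a single batch. The fallback in isolation (a simplified sketch; the MSE grid search itself is unchanged and omitted here):

```python
import logging

logger = logging.getLogger(__name__)

def resolve_batch_size(num_rows: int, bs: int = 256) -> int:
    # Fall back to one batch spanning all rows when bs does not divide them.
    if num_rows % bs != 0:
        logger.warning(
            'Tensor size is not a multiple of the batch size; '
            'setting batch size to {}'.format(num_rows)
        )
        return num_rows
    return bs

# 300 rows cannot be split evenly into batches of 256, so one batch of 300 is used.
assert resolve_batch_size(300, 256) == 300
assert resolve_batch_size(512, 256) == 256
```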
11 changes: 10 additions & 1 deletion llmc/compression/quantization/quarot.py
@@ -21,8 +21,17 @@ def __init__(self, model, quant_config, input, config):
         self.preprocess()
 
     def preprocess(self):
-        assert self.config['model']['type'] in ['Opt', 'Llama']
+        assert self.config['model']['type'] in ['Opt', 'Llama', 'Qwen2']
+        # if self.config["model"]["type"] in ["Opt"]:
+        if torch.equal(
+            self.model.get_head_layers()[0].weight,
+            self.model.get_embed_layers()[0].weight,
+        ):
+            logger.info('Tie weight! Skip rotating head layer!')
+            del self.model.get_head_layers()[0].weight
+            w = self.model.get_embed_layers()[0].weight.clone()
+            self.model.get_head_layers()[0].weight = nn.Parameter(w)
 
         self.remove_mean_from_embed()
 
         self.Q = self.get_orthogonal_matrix()
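Some Qwen2 checkpoints tie lm_head.weight to the embedding matrix, so the head and embedding accessors return the same tensor; rotating through both would presumably apply the transform to the shared storage twice. The new branch unties them by giving the head its own copy. A toy reproduction of that step (the modules here are stand-ins for llmc's model wrappers):

```python
import torch
import torch.nn as nn

embed = nn.Embedding(1000, 64)
head = nn.Linear(64, 1000, bias=False)
head.weight = embed.weight  # tied weights: both names point at one Parameter

if torch.equal(head.weight, embed.weight):
    # Untie: drop the shared reference and give the head an independent copy.
    del head.weight
    head.weight = nn.Parameter(embed.weight.clone())

assert head.weight.data_ptr() != embed.weight.data_ptr()  # no longer shared
```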
8 changes: 8 additions & 0 deletions llmc/models/qwen2.py
@@ -21,6 +21,12 @@ def find_block_name(self):
     def get_embed_layers(self):
         return [self.embed_tokens]
 
+    def get_head_layers(self):
+        return [self.model.lm_head]
+
+    def get_pre_head_layernorm_layers(self):
+        return [self.model.model.norm]
+
     def get_layers_except_blocks(self):
         return [self.embed_tokens, self.model.model.norm, self.model.lm_head]

@@ -62,12 +68,14 @@ def get_subsets_in_block(self, block):
                 'input': ['mlp.gate_proj'],
                 'inspect': block.mlp,
                 'has_kwargs': False,
+                'is_mlp': True,
             },
             {
                 'layers': {'mlp.down_proj': block.mlp.down_proj},
                 'prev_op': [block.mlp.up_proj],
                 'input': ['mlp.down_proj'],
                 'inspect': block.mlp.down_proj,
                 'has_kwargs': False,
+                'is_mlp': True,
             },
         ]
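The new is_mlp flag tags the MLP subsets so that downstream passes can distinguish them from attention subsets; how llmc consumes the flag is not shown in this diff, so the dispatcher below is purely hypothetical:

```python
def process_block(block, get_subsets_in_block):
    # Hypothetical consumer of the subset dicts returned by get_subsets_in_block.
    for subset in get_subsets_in_block(block):
        if subset.get('is_mlp', False):
            print('MLP subset:', list(subset['layers']))  # e.g. special rotation
        else:
            print('attention subset:', list(subset['layers']))
```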
