From fb10f0952227765b2e39117f75645985ef4ef52a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Tue, 19 Nov 2024 12:08:09 +0100
Subject: [PATCH] Add support for Phi (3.5) MoE

---
 awq/models/__init__.py    |   1 +
 awq/models/auto.py        |   1 +
 awq/models/base.py        |   1 +
 awq/models/phimoe.py      | 175 ++++++++++++++++++++++++++++++++++++++
 awq/quantize/quantizer.py |   2 +-
 5 files changed, 179 insertions(+), 1 deletion(-)
 create mode 100644 awq/models/phimoe.py

diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index 79ca150e..5c8b21cf 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -19,6 +19,7 @@
 from .stablelm import StableLmAWQForCausalLM
 from .starcoder2 import Starcoder2AWQForCausalLM
 from .llava_next import LlavaNextAWQForCausalLM
+from .phimoe import PhiMoEAWQForCausalLM
 from .phi3 import Phi3AWQForCausalLM
 from .phi3_v import Phi3VAWQForCausalLM
 from .cohere import CohereAWQForCausalLM
diff --git a/awq/models/auto.py b/awq/models/auto.py
index 5f6378f7..ef4d4287 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -37,6 +37,7 @@
     "internlm2": InternLM2AWQForCausalLM,
     "minicpm3": MiniCPM3AWQForCausalLM,
     "qwen2_vl": Qwen2VLAWQForCausalLM,
+    "phimoe": PhiMoEAWQForCausalLM,
 }
 
 
diff --git a/awq/models/base.py b/awq/models/base.py
index 3a525f82..cf98a1c5 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -78,6 +78,7 @@
     "llava_next": "AutoModelForVision2Seq",
     "phi3": "AutoModelForCausalLM",
     "phi3_v": "AutoModelForCausalLM",
+    "phimoe": "AutoModelForCausalLM",
     "cohere": "AutoModelForCausalLM",
     "deepseek_v2": "AutoModelForCausalLM",
     "minicpm": "AutoModelForCausalLM",
diff --git a/awq/models/phimoe.py b/awq/models/phimoe.py
new file mode 100644
index 00000000..9e712001
--- /dev/null
+++ b/awq/models/phimoe.py
@@ -0,0 +1,175 @@
+import tqdm
+import torch
+from typing import List, Tuple
+from .base import BaseAWQForCausalLM
+from awq.modules.fused.block import MixtralBlock
+from awq.modules.fused.model import MixtralModel
+from awq.modules.fused.moe import FusedSparseMoeBlock
+from awq.utils.fused_utils import fuse_qkv, fuse_linears
+from awq.modules.linear import WQLinear_GEMM
+from awq.modules.fused.norm import FasterTransformerRMSNorm
+
+
+class PhiMoEAWQForCausalLM(BaseAWQForCausalLM):
+    layer_type = "PhiMoEDecoderLayer"
+    max_seq_len_key = "max_position_embeddings"
+    modules_to_not_convert = ["gate"]
+
+    @staticmethod
+    def get_model_layers(model):
+        return model.model.layers
+
+    @staticmethod
+    def get_act_for_scaling(module):
+        return dict(is_scalable=False)
+
+    @staticmethod
+    def move_embed(model, device: str):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(module, input_feat, module_kwargs):
+        layers = []
+
+        # attention input
+        layers.append(
+            dict(
+                prev_op=module.input_layernorm,
+                layers=[
+                    module.self_attn.q_proj,
+                    module.self_attn.k_proj,
+                    module.self_attn.v_proj,
+                ],
+                inp=input_feat["self_attn.q_proj"],
+                module2inspect=module.self_attn,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # attention out
+        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+            layers.append(
+                dict(
+                    prev_op=module.self_attn.v_proj,
+                    layers=[module.self_attn.o_proj],
+                    inp=input_feat["self_attn.o_proj"],
+                )
+            )
+
+        # linear in
+        layers.append(
+            dict(
+                prev_op=module.post_attention_layernorm,
+                layers=[
+                    w
+                    for expert in module.block_sparse_moe.experts
+                    for w in [expert.w1, expert.w3]
+                ],
+                inp=input_feat["block_sparse_moe"],
+                module2inspect=module.block_sparse_moe,
+            )
+        )
+
+        # linear out
+        for i, expert in enumerate(module.block_sparse_moe.experts):
+            layers.append(
+                dict(
+                    prev_op=expert.w3,
+                    layers=[expert.w2],
+                    inp=input_feat[f"block_sparse_moe.experts.{i}.w2"],
+                )
+            )
+
+        return layers
+
+
+class MixtralFuser:
+    def __init__(self, model):
+        self.model = model
+
+        self.mixtral_blocks: List[Tuple[str, object]] = [
+            (name, module)
+            for name, module in self.model.named_modules()
+            if "MixtralDecoderLayer".lower() in module.__class__.__name__.lower()
+        ]
+
+    def fuse_transformer(self):
+        blocks = []
+
+        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
+            device = next(iter(module.state_dict().values())).device
+
+            qkv = fuse_qkv(
+                module,
+                module.self_attn.q_proj,
+                module.self_attn.k_proj,
+                module.self_attn.v_proj,
+            )
+            norm_1 = FasterTransformerRMSNorm(
+                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
+            )
+
+            norm_2 = FasterTransformerRMSNorm(
+                module.post_attention_layernorm.weight,
+                module.post_attention_layernorm.variance_epsilon,
+            )
+
+            sparse_moe = module.block_sparse_moe
+            if isinstance(sparse_moe.experts[0].w1, WQLinear_GEMM):
+                fused_w1w3s = [
+                    fuse_linears(
+                        [
+                            sparse_moe.experts[i].w1,
+                            sparse_moe.experts[i].w3,
+                        ],
+                        device,
+                    )
+                    for i in range(len(sparse_moe.experts))
+                ]
+
+                stacked_w1w3s = fuse_linears(
+                    fused_w1w3s, device, dim=0, operation=torch.stack
+                )
+
+                stacked_w2s = fuse_linears(
+                    [expert.w2 for expert in sparse_moe.experts],
+                    device,
+                    dim=0,
+                    operation=torch.stack,
+                )
+
+                sparse_moe = FusedSparseMoeBlock(
+                    top_k=sparse_moe.top_k,
+                    gate=sparse_moe.gate,
+                    ws=stacked_w1w3s,
+                    w2s=stacked_w2s,
+                )
+
+            blocks.append(
+                MixtralBlock(
+                    hidden_size=self.model.config.hidden_size,
+                    n_heads=self.model.config.num_attention_heads,
+                    n_kv_heads=self.model.config.num_key_value_heads,
+                    qkv_layer=qkv,
+                    o_proj=module.self_attn.o_proj,
+                    moe=sparse_moe,
+                    norm_1=norm_1,
+                    norm_2=norm_2,
+                    dev=device,
+                    max_seq_len=self.model.config.max_seq_len,
+                    rope_theta=self.model.config.rope_theta,
+                )
+            )
+
+        model_norm = FasterTransformerRMSNorm(
+            self.model.model.norm.weight,
+            self.model.model.norm.variance_epsilon,
+        )
+
+        self.model.model = MixtralModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.model.embed_tokens,
+            model_norm,
+        )
+        setattr(self.model.model, "blocks", self.model.model.blocks)
diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py
index cd9fb0dd..3be3e603 100644
--- a/awq/quantize/quantizer.py
+++ b/awq/quantize/quantizer.py
@@ -604,7 +604,7 @@ def cache_input_hook(m, x, y, name, feat_dict):
         handles = []
 
         # FIXME: Workaround for Mixtral to use block_sparse_moe input features
-        if self.awq_model.model_type == "mixtral":
+        if self.awq_model.model_type in ["mixtral", "phimoe"]:
             named_linears = {
                 **named_linears,
                 "block_sparse_moe": layer.block_sparse_moe,
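
Usage note (commentary, not part of the patch): quantizing a PhiMoE checkpoint goes through the standard AutoAWQ API once this model type is registered. Below is a minimal sketch; the model id, output path, and quant_config values are illustrative assumptions, not taken from this change.

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    # Illustrative paths and settings; adjust to your environment.
    model_path = "microsoft/Phi-3.5-MoE-instruct"
    quant_path = "phi-3.5-moe-awq"
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # The MoE router stays unquantized (modules_to_not_convert = ["gate"]);
    # expert w1/w3/w2 projections are scaled per get_layers_for_scaling above.
    model.quantize(tokenizer, quant_config=quant_config)
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)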