diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 5b7aca531f5..97e2705d448 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -27,6 +27,7 @@
 from torch import nn
 from transformers.activations import ACT2FN
 
+from text_generation_server.layers.moe import SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -251,79 +252,24 @@ def forward(
         )
 
 
-def _load_experts(config, prefix: str, mat, weights):
-    if config.quantize is not None:
-        raise NotImplementedError("Mixtral does not support weight quantization yet.")
-
-    assert mat in ["w1", "w2", "w3"]
-
-    world_size = weights.process_group.size()
-    rank = weights.process_group.rank()
-
-    assert (
-        config.intermediate_size % world_size == 0
-    ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
-
-    block_size = config.intermediate_size // world_size
-    start = rank * block_size
-    stop = (rank + 1) * block_size
-
-    tensor = torch.empty(
-        (config.num_local_experts * block_size, config.hidden_size),
-        dtype=weights.dtype,
-        device=weights.device,
-    )
-
-    for i in range(config.num_local_experts):
-        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
-
-        if mat == "w2":
-            expert_slice = slice_[:, start:stop].t().contiguous()
-        else:
-            expert_slice = slice_[start:stop]
-        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
-            dtype=weights.dtype
-        ).to(device=weights.device)
-    return tensor
-
-
 class BlockSparseMoE(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.hidden_dim = config.hidden_size
-        self.ffn_dim = config.intermediate_size // weights.process_group.size()
-        self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-
-        act = config.hidden_act
-        if "gelu" in act:
-            self.act = lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        elif "silu" in act:
-            self.act = torch.nn.functional.silu
-        else:
-            self.act = ACT2FN[act]
 
         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
-        # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        self.w13 = torch.cat([w1, w3], dim=1)
-        self.w2 = (
-            _load_experts(config, f"{prefix}.experts", "w2", weights)
-            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
-            .transpose(1, 2)
-            .contiguous()
+        self.moe = SparseMoELayer(
+            prefix=f"{prefix}.experts",
+            n_experts=config.num_local_experts,
+            n_expert_group=None,
+            renormalize=True,
+            topk=config.num_experts_per_tok,
+            topk_group=None,
+            weights=weights,
+            gate_proj_name="w1",
+            up_proj_name="w3",
+            down_proj_name="w2",
         )
 
         self.process_group = weights.process_group
@@ -331,15 +277,7 @@ def __init__(self, prefix, config, weights):
     def forward(self, x, adapter_data) -> torch.Tensor:
        # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(x)
-        out = fused_moe(
-            x,
-            self.w13,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-        )
+        out = self.moe(x, gating_output=router_logits)
 
         # Reduce sum
         if self.process_group.size() > 1:
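Note for reviewers: the keyword arguments of the new SparseMoELayer mirror the tensors the deleted code assembled by hand (w1 and w3 merged into a gate+up projection, w2 as the down projection), and the gating projection itself stays outside the layer, so router logits are still produced by FastLinear and only expert dispatch moves behind the abstraction. The sketch below is a naive, unfused reference for what the removed fused_moe(...) call computed. It is an illustration only, not the fused implementation behind SparseMoELayer, and it assumes the deleted code's layouts (w13 of shape (n_experts, 2 * ffn_dim, hidden_dim), w2 of shape (n_experts, hidden_dim, ffn_dim)) plus a SiLU activation as used by Mixtral; the tensor-parallel all-reduce is still handled by the surrounding forward.

# Naive, unfused reference for the deleted fused_moe(...) call (illustration only).
import torch
import torch.nn.functional as F


def naive_sparse_moe(x, w13, w2, router_logits, top_k):
    n_experts = w13.shape[0]
    ffn_dim = w13.shape[1] // 2

    # Softmax over experts, keep the top-k and renormalize their weights
    # (this is what renormalize=True meant in the old call).
    probs = router_logits.softmax(dim=-1)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(x)
    for e in range(n_experts):
        for k in range(top_k):
            mask = topk_ids[:, k] == e
            if not mask.any():
                continue
            tokens = x[mask]                          # (n_sel, hidden_dim)
            gate_up = tokens @ w13[e].t()             # (n_sel, 2 * ffn_dim)
            gate, up = gate_up.split(ffn_dim, dim=-1)
            hidden = F.silu(gate) * up                # SwiGLU expert MLP
            expert_out = hidden @ w2[e].t()           # (n_sel, hidden_dim)
            out[mask] += topk_weights[mask, k].unsqueeze(-1) * expert_out
    return out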