Commit dc2c25b
Use SparseMoELayer
danieldk committed Sep 19, 2024
1 parent d1ac3d0 commit dc2c25b
Showing 1 changed file with 13 additions and 75 deletions.
@@ -27,6 +27,7 @@
from torch import nn
from transformers.activations import ACT2FN

+from text_generation_server.layers.moe import SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
    paged_attention,
@@ -251,95 +252,32 @@ def forward(
        )


-def _load_experts(config, prefix: str, mat, weights):
-    if config.quantize is not None:
-        raise NotImplementedError("Mixtral does not support weight quantization yet.")
-
-    assert mat in ["w1", "w2", "w3"]
-
-    world_size = weights.process_group.size()
-    rank = weights.process_group.rank()
-
-    assert (
-        config.intermediate_size % world_size == 0
-    ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
-
-    block_size = config.intermediate_size // world_size
-    start = rank * block_size
-    stop = (rank + 1) * block_size
-
-    tensor = torch.empty(
-        (config.num_local_experts * block_size, config.hidden_size),
-        dtype=weights.dtype,
-        device=weights.device,
-    )
-
-    for i in range(config.num_local_experts):
-        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
-
-        if mat == "w2":
-            expert_slice = slice_[:, start:stop].t().contiguous()
-        else:
-            expert_slice = slice_[start:stop]
-        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
-            dtype=weights.dtype
-        ).to(device=weights.device)
-    return tensor
-
-
class BlockSparseMoE(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
-        self.hidden_dim = config.hidden_size
-        self.ffn_dim = config.intermediate_size // weights.process_group.size()
-        self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-
-        act = config.hidden_act
-        if "gelu" in act:
-            self.act = lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        elif "silu" in act:
-            self.act = torch.nn.functional.silu
-        else:
-            self.act = ACT2FN[act]

        # gating
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

-        # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        self.w13 = torch.cat([w1, w3], dim=1)
-        self.w2 = (
-            _load_experts(config, f"{prefix}.experts", "w2", weights)
-            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
-            .transpose(1, 2)
-            .contiguous()
+        self.moe = SparseMoELayer(
+            prefix=f"{prefix}.experts",
+            n_experts=config.num_local_experts,
+            n_expert_group=None,
+            renormalize=True,
+            topk=config.num_experts_per_tok,
+            topk_group=None,
+            weights=weights,
+            gate_proj_name="w1",
+            up_proj_name="w3",
+            down_proj_name="w2",
        )

        self.process_group = weights.process_group

    def forward(self, x, adapter_data) -> torch.Tensor:
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(x)
-        out = fused_moe(
-            x,
-            self.w13,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-        )
+        out = self.moe(x, gating_output=router_logits)

        # Reduce sum
        if self.process_group.size() > 1:
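For context, here is what the removed path computed: a minimal PyTorch sketch of the top-k routing and SwiGLU expert MLP that the deleted fused_moe call (and the SparseMoELayer replacing it) implement. The helper name is made up for illustration, the SiLU activation is an assumption (Mixtral's hidden_act), and the per-token loop favors clarity over speed:

import torch

def moe_reference(x, w13, w2, router_logits, top_k):
    # Shapes as laid out by the removed loading code:
    # x: (num_tokens, hidden_dim), w13: (n_experts, 2 * ffn_dim, hidden_dim),
    # w2: (n_experts, hidden_dim, ffn_dim).
    ffn_dim = w13.shape[1] // 2

    # Top-k routing with renormalization (renormalize=True above):
    # softmax over all experts, keep the k largest, rescale to sum to 1.
    probs = torch.softmax(router_logits, dim=-1)
    weights, experts = torch.topk(probs, top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for w, e in zip(weights[t], experts[t]):
            gate_up = x[t] @ w13[e].t()  # merged w1/w3 projection
            gate, up = gate_up[:ffn_dim], gate_up[ffn_dim:]
            hidden = torch.nn.functional.silu(gate) * up
            out[t] += w * (hidden @ w2[e].t())  # w2 down-projection
    return out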

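And a toy illustration of the tensor-parallel sharding arithmetic from the removed _load_experts, which SparseMoELayer now handles internally; the 14336 intermediate size is Mixtral 8x7B's, and world_size=2 is only an example:

intermediate_size = 14336  # Mixtral 8x7B
world_size = 2             # number of tensor-parallel shards
assert intermediate_size % world_size == 0
block_size = intermediate_size // world_size  # 7168

for rank in range(world_size):
    start, stop = rank * block_size, (rank + 1) * block_size
    # w1/w3 keep rows [start:stop) of every expert on this rank;
    # w2 is sliced along columns and transposed instead.
    print(f"rank {rank} loads rows [{start}:{stop}) of each expert")
# rank 0 loads rows [0:7168) of each expert
# rank 1 loads rows [7168:14336) of each expert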