Add support for dense MoE
danieldk committed Sep 25, 2024
1 parent 2727221 commit b5fa8bd
Showing 1 changed file with 15 additions and 6 deletions.
@@ -19,15 +19,15 @@
 # limitations under the License.
 
 from contextlib import contextmanager
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type
 
 import torch
 import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
 
-from text_generation_server.layers.moe import SparseMoELayer
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -252,14 +252,16 @@ def forward(
         )
 
 
-class BlockSparseMoE(nn.Module):
-    def __init__(self, prefix, config, weights):
+class Phi3MoE(nn.Module):
+    def __init__(
+        self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights
+    ):
         super().__init__()
 
         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
-        self.moe = SparseMoELayer(
+        self.moe = moe_layer_cls(
             prefix=f"{prefix}.experts",
             n_experts=config.num_local_experts,
             n_expert_group=None,
@@ -395,7 +397,14 @@ def __init__(self, index, prefix, config, weights):
         )
 
         if config.model_type == "phimoe":
-            self.dense = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
+            moe_layer_cls = (
+                SparseMoELayer
+                if SparseMoELayer.is_supported(weights)
+                else DenseMoELayer
+            )
+            self.dense = Phi3MoE(
+                f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
+            )
             # with moe the layernorms are are not rmsnorms and they have bias
             self.input_layernorm = FastLayerNorm.load(
                 prefix=f"{prefix}.input_layernorm",
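
For orientation, the change reduces to a class-selection pattern: Phi3MoE now takes a Type[MoELayer] argument and instantiates whatever class it is given, while the caller picks SparseMoELayer when SparseMoELayer.is_supported(weights) holds and falls back to DenseMoELayer otherwise. Below is a minimal, self-contained sketch of that pattern; the stub classes, Phi3MoEish, DummyConfig, DummyWeights, and the supports_sparse_moe flag are illustrative stand-ins, not the real text_generation_server implementations.

# Sketch of the sparse/dense MoE selection pattern from this commit.
# All classes below are hypothetical stand-ins for illustration only.

from typing import Optional, Type

from torch import nn


class MoELayer(nn.Module):
    """Toy common interface shared by the sparse and dense variants."""

    def __init__(self, *, prefix: str, n_experts: int, n_expert_group: Optional[int]):
        super().__init__()
        self.prefix = prefix
        self.n_experts = n_experts


class SparseMoELayer(MoELayer):
    @staticmethod
    def is_supported(weights) -> bool:
        # The real check inspects the loaded weights; here it is just a flag
        # on the stand-in weights object.
        return getattr(weights, "supports_sparse_moe", False)


class DenseMoELayer(MoELayer):
    pass


class Phi3MoEish(nn.Module):
    """Mirrors the Phi3MoE constructor: the MoE class is injected by the caller."""

    def __init__(self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights):
        super().__init__()
        self.moe = moe_layer_cls(
            prefix=f"{prefix}.experts",
            n_experts=config.num_local_experts,
            n_expert_group=None,
        )


class DummyConfig:
    num_local_experts = 8


class DummyWeights:
    supports_sparse_moe = False  # e.g. no fused sparse kernel for this setup


if __name__ == "__main__":
    weights = DummyWeights()
    moe_layer_cls = (
        SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
    )
    layer = Phi3MoEish(
        "model.layers.0.block_sparse_moe", DummyConfig(), moe_layer_cls, weights
    )
    print(type(layer.moe).__name__)  # prints "DenseMoELayer" with these stand-ins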
