Refactoring for maintainability
DhruvaBansal00 committed Aug 7, 2024
1 parent b0c4671 commit e5c1a81
Showing 6 changed files with 665 additions and 575 deletions.
18 changes: 10 additions & 8 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,21 +1,23 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_marlin_moe,
single_marlin_moe)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.fused_moe_gptq import fused_moe_gptq
from vllm.model_executor.layers.fused_moe.fused_moe import single_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
from vllm.triton_utils import HAS_TRITON

__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"fused_marlin_moe",
"fused_moe_gptq",
"single_marlin_moe",
]

if HAS_TRITON:

from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
fused_experts,
fused_moe,
fused_topk,
get_config_file_name,
grouped_topk,
)

__all__ += [
"fused_moe",
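For orientation (a hedged sketch, not part of the commit): after this change the package re-exports fused_moe_gptq in place of fused_marlin_moe, while FusedMoE, FusedMoEMethodBase and single_marlin_moe remain in __all__. A downstream import would therefore look roughly like:

# Hypothetical call-site sketch; assumes a vLLM build where this package
# and its compiled extensions are importable.
from vllm.model_executor.layers.fused_moe import FusedMoE, fused_moe_gptq, single_marlin_moe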
102 changes: 1 addition & 101 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -704,104 +704,4 @@ def single_marlin_moe(
g_idx, rand_perm, workspace, M, N, K, True, E, topk, block_size_m,
True, False)

return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)


def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
rand_perm1: torch.Tensor,
rand_perm2: torch.Tensor,
topk: int,
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[
1] == w2.shape[2] // 2, "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)

get_config_func = functools.partial(try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)

block_size_m = config['BLOCK_SIZE_M']

sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)

intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)

intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
block_size_m, True, False)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
block_size_m, False, True)

return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
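To make the refactor easier to follow, here is a minimal unquantized PyTorch reference for what the removed fused_marlin_moe (and the new fused_moe_gptq below) computes. It is a sketch under stated assumptions, not the kernel: dense weights w1 of shape (E, 2N, K) and w2 of shape (E, K, N) stand in for the packed GPTQ/Marlin tensors, the gating path mirrors fused_topk (softmax, top-k, optional renormalization), and silu_and_mul is expanded inline.

import torch
import torch.nn.functional as F


def moe_reference(hidden_states: torch.Tensor,
                  w1: torch.Tensor,
                  w2: torch.Tensor,
                  gating_output: torch.Tensor,
                  topk: int,
                  renormalize: bool) -> torch.Tensor:
    # hidden_states: (M, K); w1: (E, 2N, K); w2: (E, K, N); gating: (M, E).
    M, K = hidden_states.shape
    N = w1.shape[1] // 2

    # Top-k gating: softmax over experts, take the top-k, and optionally
    # renormalize the selected weights to sum to 1 (as fused_topk does).
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(hidden_states)
    for token in range(M):
        for slot in range(topk):
            e = topk_ids[token, slot]
            x = hidden_states[token] @ w1[e].t()   # first GEMM -> (2N,)
            x = F.silu(x[:N]) * x[N:]              # silu_and_mul
            x = x @ w2[e].t()                      # second GEMM -> (K,)
            out[token] += topk_weights[token, slot].to(out.dtype) * x
    return out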
138 changes: 138 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe_gptq.py
@@ -0,0 +1,138 @@
"""Fused MoE utilities for GPTQ."""
import functools
import torch

from typing import Any, Dict, Optional
from vllm import _custom_ops as ops

from .fused_moe import fused_topk, moe_align_block_size, try_get_optimal_moe_config


def fused_moe_gptq(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
rand_perm1: torch.Tensor,
rand_perm2: torch.Tensor,
topk: int,
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and a top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (torch.Tensor): The GPTQ group indices for w1.
- g_idx2 (torch.Tensor): The GPTQ group indices for w2.
- rand_perm1 (torch.Tensor): The Marlin weight permutation for w1.
- rand_perm2 (torch.Tensor): The Marlin weight permutation for w2.
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // 2, "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize)

get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,
is_marlin=True,
)
config = get_config_func(M)

block_size_m = config["BLOCK_SIZE_M"]

sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
workspace = torch.zeros(
max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False
)

intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
hidden_states,
w1,
sorted_token_ids,
topk_weights,
topk_ids,
w1_scale,
g_idx1,
rand_perm1,
workspace,
M,
2 * N,
K,
True,
E,
topk,
block_size_m,
True,
False,
)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache2,
w2,
sorted_token_ids,
topk_weights,
topk_ids,
w2_scale,
g_idx2,
rand_perm2,
workspace,
M,
K,
N,
True,
E,
topk,
block_size_m,
False,
True,
)

return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
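As a quick sanity reference (illustrative numbers, not from the commit): the constraint checks above pin down the packed shapes that fused_moe_gptq expects. The sketch below spells out the implied relations under hypothetical model dimensions; the last dimension of w1 follows the packed layout and is not asserted, so it is omitted.

import torch

# Hypothetical sizes for illustration only.
M, K, N, E = 4, 4096, 14336, 8

hidden_states = torch.randn(M, K, dtype=torch.float16)
gating_output = torch.randn(M, E, dtype=torch.float16)

# Mirrors the checks at the top of fused_moe_gptq.
assert hidden_states.shape[0] == gating_output.shape[0]
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

# Relations implied by the remaining asserts:
#   K == w1.shape[1] * 16                            ->  w1.shape[:2] == (E, K // 16)
#   K == w2.shape[2] // 2 and N == w2.shape[1] * 16  ->  w2.shape == (E, N // 16, 2 * K)
print("expected w1 leading dims:", (E, K // 16))
print("expected w2 shape:", (E, N // 16, 2 * K))

# Workspace sizing used before the marlin_gemm_moe calls, reproduced verbatim.
max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
print("max_workspace_size:", max_workspace_size)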