[Misc] Fused MoE Marlin support for GPTQ (vllm-project#8217)
dsikka authored Sep 10, 2024
1 parent c7cb5c3 commit 6cd5e5b
Showing 19 changed files with 912 additions and 204 deletions.
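For orientation: this change adds a fused Marlin MoE kernel and a Python entry point for it. The sketch below is not part of the commit; it is a hypothetical helper (run_gptq_marlin_moe is an invented name) that mirrors the call pattern exercised in tests/kernels/test_moe.py further down, assuming the per-expert weights, scales, g_idx and sort-index tensors have already been produced by marlin_quantize and stacked along dim 0 as the test does.

import torch

from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk


def run_gptq_marlin_moe(a: torch.Tensor, score: torch.Tensor, topk: int,
                        qweight1: torch.Tensor, qweight2: torch.Tensor,
                        scales1: torch.Tensor, scales2: torch.Tensor,
                        g_idx1: torch.Tensor, g_idx2: torch.Tensor,
                        sort_indices1: torch.Tensor,
                        sort_indices2: torch.Tensor) -> torch.Tensor:
    # Route each token to its top-k experts, then run both expert projections
    # through the fused Marlin GEMM, as test_fused_marlin_moe does below.
    topk_weights, topk_ids = fused_topk(a, score, topk, False)
    return fused_marlin_moe(a, qweight1, qweight2, score,
                            g_idx1, g_idx2, sort_indices1, sort_indices2,
                            topk_weights, topk_ids,
                            w1_scale=scales1, w2_scale=scales2)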
13 changes: 12 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -386,7 +386,18 @@ steps:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt

- label: Weight Loading Multiple GPU Test - Large Models # optional
working_dir: "/vllm-workspace/tests"
num_gpus: 2
gpu: a100
optional: true
source_file_dependencies:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt


##### multi gpus test #####
2 changes: 1 addition & 1 deletion csrc/moe/marlin_moe_ops.cu
@@ -1737,4 +1737,4 @@ torch::Tensor marlin_gemm_moe(
moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
thread_n, sms, max_par, replicate_input, apply_weights);
return c;
}
}
2 changes: 1 addition & 1 deletion csrc/moe/marlin_moe_ops.h
@@ -9,4 +9,4 @@ torch::Tensor marlin_gemm_moe(
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
bool replicate_input, bool apply_weights);
1 change: 0 additions & 1 deletion csrc/moe/torch_bindings.cpp
@@ -16,7 +16,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor");

m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif
}
221 changes: 217 additions & 4 deletions tests/kernels/test_moe.py
@@ -2,14 +2,22 @@
Run `pytest tests/kernels/test_moe.py`.
"""
from typing import List

import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types


def torch_moe(a, w1, w2, score, topk):
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


def torch_moe_single(a, w, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
_, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.view(-1)
for i in range(w.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = a[mask] @ w[i].transpose(0, 1)
return (out.view(B, -1, w.shape[1])).sum(dim=1)


@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@@ -43,11 +65,11 @@ def test_fused_moe(
topk: int,
dtype: torch.dtype,
):
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])


def stack_and_dev(tensors: List[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)


def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))


@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_fused_marlin_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
):
torch.manual_seed(7)

if topk > e:
return

# Filter act_order
if act_order:
if group_size == -1:
return
if group_size in (k, n):
return

quant_type = scalar_types.uint4b8
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
for i in range(w2.shape[0]):
w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

w_ref1_l = []
qweight1_l = []
scales1_l = []
g_idx1_l = []
sort_indices1_l = []

for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref1_l.append(w_ref1)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)

w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l)
sort_indices1 = stack_and_dev(sort_indices1_l)

w_ref2_l = []
qweight2_l = []
scales2_l = []
g_idx2_l = []
sort_indices2_l = []

for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref2_l.append(w_ref2)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)

w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l)
sort_indices2 = stack_and_dev(sort_indices2_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)

topk_weights, topk_ids = fused_topk(a, score, topk, False)

triton_output = fused_moe(
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk_weights,
topk_ids,
w1_scale=scales1,
w2_scale=scales2,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2


@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_marlin_moe_mmm(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
):
if topk > e:
return

# Filter act_order
if act_order:
if group_size == -1:
return
if group_size == k:
return

quant_type = scalar_types.uint4b8
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

w_ref_l = []
qweights_l = []
scales_l = []
g_idx_l = []
sort_indices_l = []

for i in range(w.shape[0]):
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)

w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
scales = stack_and_dev(scales_l)
g_idx = stack_and_dev(g_idx_l)
sort_indices = stack_and_dev(sort_indices_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

assert compute_max_diff(marlin_output, torch_output) < 1e-2
3 changes: 3 additions & 0 deletions tests/weight_loading/models-large.txt
@@ -0,0 +1,3 @@
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
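These entries feed run_model_weight_loading_test.sh in the new 2-GPU Buildkite step above. As a rough illustration of what the gptq_marlin line covers end to end, a hedged sketch (checkpoint name taken from the list above, the two-GPU setting mirroring the A100 step; the engine flags themselves are an assumption, not part of this commit):

from vllm import LLM, SamplingParams

# Illustrative only: load a GPTQ Mixtral checkpoint on the Marlin fused-MoE path.
llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ",
          quantization="gptq_marlin",
          tensor_parallel_size=2)
print(llm.generate(["The quick brown fox"], SamplingParams(max_tokens=16)))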
2 changes: 0 additions & 2 deletions tests/weight_loading/models.txt
@@ -19,8 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
14 changes: 10 additions & 4 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -2,16 +2,22 @@
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON

__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"]
__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
]

if HAS_TRITON:

from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_marlin_moe, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)

__all__ += [
"fused_marlin_moe",
"single_marlin_moe",
"fused_moe",
"fused_topk",
"fused_experts",
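For downstream code, the net effect of this reorganization is that the Marlin MoE entry points now live in a dedicated fused_marlin_moe module and are re-exported from the package root when Triton is available. A minimal import sketch (guard and module paths as shown in this diff):

from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    # Both names originate from
    # vllm.model_executor.layers.fused_moe.fused_marlin_moe (see the diff above).
    from vllm.model_executor.layers.fused_moe import (fused_marlin_moe,
                                                      single_marlin_moe)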
