Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Marlin MoE integration #2

Closed
wants to merge 111 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
111 commits
Select commit Hold shift + click to select a range
5a2ab25
Moving branch to a different repo
ElizaWszola Aug 2, 2024
b39dba4
clean up the CPU code
ElizaWszola Aug 2, 2024
b0c4671
Fix build issues
ElizaWszola Aug 2, 2024
e5c1a81
Refactoring for maintainability
DhruvaBansal00 Aug 7, 2024
7da678e
Fixing tests
DhruvaBansal00 Aug 7, 2024
641696b
Addressing repacking comment
DhruvaBansal00 Aug 8, 2024
3cef667
gptq -> marlin renaming
DhruvaBansal00 Aug 8, 2024
a6710af
Undo formatting changes
DhruvaBansal00 Aug 8, 2024
e29107f
Final formatting change
DhruvaBansal00 Aug 8, 2024
099d61e
Switching to mixtral file for quantized mixtral
DhruvaBansal00 Aug 12, 2024
bdf6bdc
Bug fixes
DhruvaBansal00 Aug 12, 2024
19c5c59
is quantized change
DhruvaBansal00 Aug 12, 2024
3b7cc60
debug stat
DhruvaBansal00 Aug 12, 2024
d2c4754
replace wiehgt name with param name
DhruvaBansal00 Aug 12, 2024
f579cb2
typo
DhruvaBansal00 Aug 12, 2024
79394eb
debug
DhruvaBansal00 Aug 12, 2024
ec75f4e
more debug
DhruvaBansal00 Aug 12, 2024
91ca970
only relevant logging
DhruvaBansal00 Aug 12, 2024
1b9d5bb
log
DhruvaBansal00 Aug 12, 2024
ec06719
log
DhruvaBansal00 Aug 12, 2024
71d82e1
removing qzero weights
DhruvaBansal00 Aug 12, 2024
d3465d0
Qzeors in expert mapping
DhruvaBansal00 Aug 12, 2024
226ee26
Debug
DhruvaBansal00 Aug 12, 2024
21d7d27
Load qzero
DhruvaBansal00 Aug 12, 2024
2dabb4b
rm 2x
DhruvaBansal00 Aug 12, 2024
6366976
Mapping for scales
DhruvaBansal00 Aug 12, 2024
d63c096
rm logging
DhruvaBansal00 Aug 12, 2024
360fef4
Adding lyaer wise logging
DhruvaBansal00 Aug 12, 2024
c23d616
shard ids
DhruvaBansal00 Aug 12, 2024
8d81d14
Loading qzero correctly
DhruvaBansal00 Aug 12, 2024
22e1aa7
List operand
DhruvaBansal00 Aug 12, 2024
81e01f3
If clause
DhruvaBansal00 Aug 12, 2024
dcfd32d
Able to load layers
DhruvaBansal00 Aug 12, 2024
f04cbea
Setting load quant to false
DhruvaBansal00 Aug 12, 2024
a56821d
Disabling logging
DhruvaBansal00 Aug 12, 2024
7f961c6
Removing *2 in marlin moe repack
DhruvaBansal00 Aug 13, 2024
4a6c7ff
*4 in marlin moe repack
DhruvaBansal00 Aug 13, 2024
e6cd286
bits
DhruvaBansal00 Aug 13, 2024
90241c4
*4
DhruvaBansal00 Aug 13, 2024
67409e9
intermediate size
DhruvaBansal00 Aug 13, 2024
539032e
repeat keyword
DhruvaBansal00 Aug 13, 2024
57b1cbe
hidden size
DhruvaBansal00 Aug 13, 2024
87f1dd4
intermediate size back
DhruvaBansal00 Aug 13, 2024
4c073c2
permute scales w3
DhruvaBansal00 Aug 13, 2024
d732493
*2
DhruvaBansal00 Aug 13, 2024
fdc22c4
log
DhruvaBansal00 Aug 13, 2024
272822e
shape as 2
DhruvaBansal00 Aug 13, 2024
3ce045e
test
DhruvaBansal00 Aug 13, 2024
c4ba477
Increasing to 4 and changing assert
DhruvaBansal00 Aug 13, 2024
2ea8370
logging
DhruvaBansal00 Aug 13, 2024
8287025
marlin moe repack change
DhruvaBansal00 Aug 13, 2024
53b23b9
mult qweight shape by pack factor
DhruvaBansal00 Aug 13, 2024
bc40786
Potential support for 8 bit
DhruvaBansal00 Aug 13, 2024
bea13de
undo change
DhruvaBansal00 Aug 13, 2024
a3a9114
qzeros
DhruvaBansal00 Aug 13, 2024
eb916f9
switching traffic to mixtral quant
DhruvaBansal00 Aug 13, 2024
017d6f8
compat
DhruvaBansal00 Aug 13, 2024
eb9c087
Passing intermediate tensor into mixtral in quant file
DhruvaBansal00 Aug 13, 2024
ea3cf18
Removing intemediate tensors from forward
DhruvaBansal00 Aug 13, 2024
4f6b4ca
load weights from quant
DhruvaBansal00 Aug 13, 2024
7ec27d9
Mixtral load weights change:
DhruvaBansal00 Aug 13, 2024
aa1fe77
none shard id change
DhruvaBansal00 Aug 13, 2024
ae8fb15
Use class from mixtral_quant
DhruvaBansal00 Aug 15, 2024
b863981
Removing lora from mixtral model init
DhruvaBansal00 Aug 15, 2024
5556d28
Adding empty intermediate tensors
DhruvaBansal00 Aug 15, 2024
c484a37
Building quantMixtralModel
DhruvaBansal00 Aug 15, 2024
0344e72
fused moe test
DhruvaBansal00 Aug 15, 2024
8c8b3fa
Lora enabled mixtral
DhruvaBansal00 Aug 15, 2024
dff59cd
LoRAMixtralModel compat
DhruvaBansal00 Aug 15, 2024
33f7e51
remove prefix
DhruvaBansal00 Aug 15, 2024
fdba917
use fused moe
DhruvaBansal00 Aug 15, 2024
780471e
remove org num embeddings
DhruvaBansal00 Aug 15, 2024
c0970f1
pass use fused moe into decoder
DhruvaBansal00 Aug 15, 2024
6a1a838
Mixtral for causal lm load func
DhruvaBansal00 Aug 15, 2024
5c3e857
Copying over quant mixtral
DhruvaBansal00 Aug 15, 2024
8d327de
Passing prefix
DhruvaBansal00 Aug 15, 2024
d337aea
Weight load
DhruvaBansal00 Aug 15, 2024
379f3e8
Weight load back
DhruvaBansal00 Aug 15, 2024
a5d356e
Load with name not weight name
DhruvaBansal00 Aug 15, 2024
62c0135
params dict should load from old name
DhruvaBansal00 Aug 15, 2024
d23c00c
logging name and parmas
DhruvaBansal00 Aug 15, 2024
6dda447
log expert parmas map
DhruvaBansal00 Aug 15, 2024
67ce7b6
parity with prev commits
DhruvaBansal00 Aug 15, 2024
bd933c9
Adding qzeros to mapping
DhruvaBansal00 Aug 15, 2024
77cd095
Remove log
DhruvaBansal00 Aug 15, 2024
529191e
Remove is quantized
DhruvaBansal00 Aug 15, 2024
2450543
Assume fused true
DhruvaBansal00 Aug 15, 2024
8cba45e
rm fused true
DhruvaBansal00 Aug 15, 2024
10940a5
Switching to mixtral moe
DhruvaBansal00 Aug 15, 2024
895ffbe
Precision changes
DhruvaBansal00 Aug 15, 2024
e54b2e4
Cleanup
DhruvaBansal00 Aug 15, 2024
b4f23dc
Mixtral quant parity:
DhruvaBansal00 Aug 15, 2024
d59fe3b
fixing tests
DhruvaBansal00 Aug 15, 2024
0d9cbdc
Tests working and correctness verified
DhruvaBansal00 Aug 15, 2024
112aa40
Formating
DhruvaBansal00 Aug 15, 2024
1ca9098
Moving single marlin alongside fused marlin
DhruvaBansal00 Aug 19, 2024
4d41425
Removing unused imports
DhruvaBansal00 Aug 19, 2024
4907f43
single marlin moe import
DhruvaBansal00 Aug 19, 2024
8f4648c
Merge branch 'main' into marlin-moe-integration
ElizaWszola Aug 20, 2024
8225037
Merge branch 'marlin-moe-integration' into gptq-marlin-refactor
ElizaWszola Aug 20, 2024
315e3b6
Unify shard_id to be of str w[1-3] format
ElizaWszola Aug 21, 2024
34bb5b0
Merge pull request #4 from DhruvaBansal00/gptq-marlin-refactor
ElizaWszola Aug 22, 2024
fd4bb21
Merge branch 'main' into marlin-moe-integration
ElizaWszola Aug 22, 2024
7956a69
Unfused codepath for non-supported quant_types
ElizaWszola Aug 26, 2024
2511f78
uint8b128 support
ElizaWszola Aug 28, 2024
f875842
Merge branch 'main' into marlin-moe-integration
ElizaWszola Aug 29, 2024
d8feb8d
Cleanup, compressed tensors compatibility
ElizaWszola Aug 29, 2024
3676621
update todo
ElizaWszola Aug 29, 2024
75e3dd5
Fix merge
ElizaWszola Aug 30, 2024
a5f5a74
bad paste
ElizaWszola Aug 30, 2024
e305306
GPTQFusedMoE layer
ElizaWszola Sep 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 211 additions & 90 deletions csrc/moe/marlin_moe_ops.cu

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions csrc/moe/marlin_moe_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

#include <torch/all.h>

#include "core/scalar_type.hpp"

torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
9 changes: 5 additions & 4 deletions csrc/moe/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor");

"g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif
}
Expand Down
225 changes: 221 additions & 4 deletions tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@

Run `pytest tests/kernels/test_moe.py`.
"""
from typing import List

import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe_marlin import (
fused_moe_marlin, single_moe_marlin)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types


def torch_moe(a, w1, w2, score, topk):
Expand All @@ -29,6 +36,20 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


def torch_moe_single(a, w, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
_, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.view(-1)
for i in range(w.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = a[mask] @ w[i].transpose(0, 1)
return (out.view(B, -1, w.shape[1])).sum(dim=1)


@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
Expand All @@ -43,11 +64,11 @@ def test_fused_moe(
topk: int,
dtype: torch.dtype,
):
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
Expand Down Expand Up @@ -99,3 +120,199 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])


def stack_and_dev(tensors: List[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)


def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))


@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_fused_marlin_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
):
torch.manual_seed(7)

if topk > e:
return

# Filter act_order
if act_order:
if group_size == -1:
return
if group_size in (k, n):
return

quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
for i in range(w2.shape[0]):
w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

w_ref1_l = []
qweight1_l = []
scales1_l = []
g_idx1_l = []
sort_indices1_l = []

for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref1_l.append(w_ref1)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)

w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l)
sort_indices1 = stack_and_dev(sort_indices1_l)

w_ref2_l = []
qweight2_l = []
scales2_l = []
g_idx2_l = []
sort_indices2_l = []

for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref2_l.append(w_ref2)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)

w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l)
sort_indices2 = stack_and_dev(sort_indices2_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
marlin_output = fused_moe_marlin(
a,
qweight1,
qweight2,
score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk,
renormalize=False,
w1_scale=scales1,
w2_scale=scales2,
num_bits=num_bits,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2


@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_marlin_moe_mmm(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
):
if topk > e:
return

# Filter act_order
if act_order:
if group_size == -1:
return
if group_size == k:
return

quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

w_ref_l = []
qweights_l = []
scales_l = []
g_idx_l = []
sort_indices_l = []

for i in range(w.shape[0]):
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)

w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
scales = stack_and_dev(scales_l)
g_idx = stack_and_dev(g_idx_l)
sort_indices = stack_and_dev(sort_indices_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_moe_marlin(a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False,
num_bits=num_bits)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

assert compute_max_diff(marlin_output, torch_output) < 1e-2
2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
num_bits: int) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0
output = torch.empty((num_experts, size_k // 16, size_n * 2),
output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
device=b_q_weight.device,
dtype=b_q_weight.dtype)
for e in range(num_experts):
Expand Down
19 changes: 13 additions & 6 deletions vllm/model_executor/layers/fused_moe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
from vllm.model_executor.layers.fused_moe.fused_moe_marlin import (
fused_moe_marlin, single_moe_marlin)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, GPTQFusedMoE)
from vllm.triton_utils import HAS_TRITON

__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"]
__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
"GPTQFusedMoE",
"fused_moe_marlin",
"single_moe_marlin",
]

if HAS_TRITON:

from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_marlin_moe, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)

__all__ += [
"fused_marlin_moe",
"fused_moe",
"fused_topk",
"fused_experts",
Expand Down
Loading
Loading