Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring for maintainability #4

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
97 commits
Select commit Hold shift + click to select a range
e5c1a81
Refactoring for maintainability
DhruvaBansal00 Aug 7, 2024
7da678e
Fixing tests
DhruvaBansal00 Aug 7, 2024
641696b
Addressing repacking comment
DhruvaBansal00 Aug 8, 2024
3cef667
gptq -> marlin renaming
DhruvaBansal00 Aug 8, 2024
a6710af
Undo formatting changes
DhruvaBansal00 Aug 8, 2024
e29107f
Final formatting change
DhruvaBansal00 Aug 8, 2024
099d61e
Switching to mixtral file for quantized mixtral
DhruvaBansal00 Aug 12, 2024
bdf6bdc
Bug fixes
DhruvaBansal00 Aug 12, 2024
19c5c59
is quantized change
DhruvaBansal00 Aug 12, 2024
3b7cc60
debug stat
DhruvaBansal00 Aug 12, 2024
d2c4754
replace wiehgt name with param name
DhruvaBansal00 Aug 12, 2024
f579cb2
typo
DhruvaBansal00 Aug 12, 2024
79394eb
debug
DhruvaBansal00 Aug 12, 2024
ec75f4e
more debug
DhruvaBansal00 Aug 12, 2024
91ca970
only relevant logging
DhruvaBansal00 Aug 12, 2024
1b9d5bb
log
DhruvaBansal00 Aug 12, 2024
ec06719
log
DhruvaBansal00 Aug 12, 2024
71d82e1
removing qzero weights
DhruvaBansal00 Aug 12, 2024
d3465d0
Qzeors in expert mapping
DhruvaBansal00 Aug 12, 2024
226ee26
Debug
DhruvaBansal00 Aug 12, 2024
21d7d27
Load qzero
DhruvaBansal00 Aug 12, 2024
2dabb4b
rm 2x
DhruvaBansal00 Aug 12, 2024
6366976
Mapping for scales
DhruvaBansal00 Aug 12, 2024
d63c096
rm logging
DhruvaBansal00 Aug 12, 2024
360fef4
Adding lyaer wise logging
DhruvaBansal00 Aug 12, 2024
c23d616
shard ids
DhruvaBansal00 Aug 12, 2024
8d81d14
Loading qzero correctly
DhruvaBansal00 Aug 12, 2024
22e1aa7
List operand
DhruvaBansal00 Aug 12, 2024
81e01f3
If clause
DhruvaBansal00 Aug 12, 2024
dcfd32d
Able to load layers
DhruvaBansal00 Aug 12, 2024
f04cbea
Setting load quant to false
DhruvaBansal00 Aug 12, 2024
a56821d
Disabling logging
DhruvaBansal00 Aug 12, 2024
7f961c6
Removing *2 in marlin moe repack
DhruvaBansal00 Aug 13, 2024
4a6c7ff
*4 in marlin moe repack
DhruvaBansal00 Aug 13, 2024
e6cd286
bits
DhruvaBansal00 Aug 13, 2024
90241c4
*4
DhruvaBansal00 Aug 13, 2024
67409e9
intermediate size
DhruvaBansal00 Aug 13, 2024
539032e
repeat keyword
DhruvaBansal00 Aug 13, 2024
57b1cbe
hidden size
DhruvaBansal00 Aug 13, 2024
87f1dd4
intermediate size back
DhruvaBansal00 Aug 13, 2024
4c073c2
permute scales w3
DhruvaBansal00 Aug 13, 2024
d732493
*2
DhruvaBansal00 Aug 13, 2024
fdc22c4
log
DhruvaBansal00 Aug 13, 2024
272822e
shape as 2
DhruvaBansal00 Aug 13, 2024
3ce045e
test
DhruvaBansal00 Aug 13, 2024
c4ba477
Increasing to 4 and changing assert
DhruvaBansal00 Aug 13, 2024
2ea8370
logging
DhruvaBansal00 Aug 13, 2024
8287025
marlin moe repack change
DhruvaBansal00 Aug 13, 2024
53b23b9
mult qweight shape by pack factor
DhruvaBansal00 Aug 13, 2024
bc40786
Potential support for 8 bit
DhruvaBansal00 Aug 13, 2024
bea13de
undo change
DhruvaBansal00 Aug 13, 2024
a3a9114
qzeros
DhruvaBansal00 Aug 13, 2024
eb916f9
switching traffic to mixtral quant
DhruvaBansal00 Aug 13, 2024
017d6f8
compat
DhruvaBansal00 Aug 13, 2024
eb9c087
Passing intermediate tensor into mixtral in quant file
DhruvaBansal00 Aug 13, 2024
ea3cf18
Removing intemediate tensors from forward
DhruvaBansal00 Aug 13, 2024
4f6b4ca
load weights from quant
DhruvaBansal00 Aug 13, 2024
7ec27d9
Mixtral load weights change:
DhruvaBansal00 Aug 13, 2024
aa1fe77
none shard id change
DhruvaBansal00 Aug 13, 2024
ae8fb15
Use class from mixtral_quant
DhruvaBansal00 Aug 15, 2024
b863981
Removing lora from mixtral model init
DhruvaBansal00 Aug 15, 2024
5556d28
Adding empty intermediate tensors
DhruvaBansal00 Aug 15, 2024
c484a37
Building quantMixtralModel
DhruvaBansal00 Aug 15, 2024
0344e72
fused moe test
DhruvaBansal00 Aug 15, 2024
8c8b3fa
Lora enabled mixtral
DhruvaBansal00 Aug 15, 2024
dff59cd
LoRAMixtralModel compat
DhruvaBansal00 Aug 15, 2024
33f7e51
remove prefix
DhruvaBansal00 Aug 15, 2024
fdba917
use fused moe
DhruvaBansal00 Aug 15, 2024
780471e
remove org num embeddings
DhruvaBansal00 Aug 15, 2024
c0970f1
pass use fused moe into decoder
DhruvaBansal00 Aug 15, 2024
6a1a838
Mixtral for causal lm load func
DhruvaBansal00 Aug 15, 2024
5c3e857
Copying over quant mixtral
DhruvaBansal00 Aug 15, 2024
8d327de
Passing prefix
DhruvaBansal00 Aug 15, 2024
d337aea
Weight load
DhruvaBansal00 Aug 15, 2024
379f3e8
Weight load back
DhruvaBansal00 Aug 15, 2024
a5d356e
Load with name not weight name
DhruvaBansal00 Aug 15, 2024
62c0135
params dict should load from old name
DhruvaBansal00 Aug 15, 2024
d23c00c
logging name and parmas
DhruvaBansal00 Aug 15, 2024
6dda447
log expert parmas map
DhruvaBansal00 Aug 15, 2024
67ce7b6
parity with prev commits
DhruvaBansal00 Aug 15, 2024
bd933c9
Adding qzeros to mapping
DhruvaBansal00 Aug 15, 2024
77cd095
Remove log
DhruvaBansal00 Aug 15, 2024
529191e
Remove is quantized
DhruvaBansal00 Aug 15, 2024
2450543
Assume fused true
DhruvaBansal00 Aug 15, 2024
8cba45e
rm fused true
DhruvaBansal00 Aug 15, 2024
10940a5
Switching to mixtral moe
DhruvaBansal00 Aug 15, 2024
895ffbe
Precision changes
DhruvaBansal00 Aug 15, 2024
e54b2e4
Cleanup
DhruvaBansal00 Aug 15, 2024
b4f23dc
Mixtral quant parity:
DhruvaBansal00 Aug 15, 2024
d59fe3b
fixing tests
DhruvaBansal00 Aug 15, 2024
0d9cbdc
Tests working and correctness verified
DhruvaBansal00 Aug 15, 2024
112aa40
Formating
DhruvaBansal00 Aug 15, 2024
1ca9098
Moving single marlin alongside fused marlin
DhruvaBansal00 Aug 19, 2024
4d41425
Removing unused imports
DhruvaBansal00 Aug 19, 2024
4907f43
single marlin moe import
DhruvaBansal00 Aug 19, 2024
8225037
Merge branch 'marlin-moe-integration' into gptq-marlin-refactor
ElizaWszola Aug 20, 2024
315e3b6
Unify shard_id to be of str w[1-3] format
ElizaWszola Aug 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 46 additions & 40 deletions tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import (fused_marlin_moe, fused_moe,
single_marlin_moe)
from vllm.model_executor.layers.fused_moe import fused_moe, single_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_gptq import fused_moe_gptq
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
marlin_quantize, )
from vllm.model_executor.models.mixtral import MixtralMoE


Expand Down Expand Up @@ -62,11 +62,11 @@ def test_fused_moe(
topk: int,
dtype: torch.dtype,
):
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
Expand Down Expand Up @@ -114,10 +114,12 @@ def test_mixtral_moe(dtype: torch.dtype):
torch.bfloat16: 1e-2,
}

assert torch.allclose(hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])
assert torch.allclose(
hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype],
)


def stack_and_dev(tensors: List[torch.Tensor]):
Expand Down Expand Up @@ -165,11 +167,11 @@ def test_fused_marlin_moe(

num_bits = 4
dtype = torch.float16
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
for i in range(w2.shape[0]):
w2[0] = torch.eye(k, n, device='cuda', dtype=dtype)
w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

w_ref1_l = []
qweight1_l = []
Expand Down Expand Up @@ -215,27 +217,31 @@ def test_fused_marlin_moe(
g_idx2 = stack_and_dev(g_idx2_l)
sort_indices2 = stack_and_dev(sort_indices2_l)

score = torch.randn((m, e), device='cuda', dtype=dtype)
triton_output = fused_moe(a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False)
marlin_output = fused_marlin_moe(a,
qweight1,
qweight2,
score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk,
renormalize=False,
w1_scale=scales1,
w2_scale=scales2)

assert (compute_max_diff(marlin_output, triton_output) < 4e-2)
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
marlin_output = fused_moe_gptq(
a,
qweight1,
qweight2,
score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk,
renormalize=False,
w1_scale=scales1,
w2_scale=scales2,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2


# TODO: make sure this test works
Expand Down Expand Up @@ -272,8 +278,8 @@ def test_single_marlin_moe(

num_bits = 4
dtype = torch.float16
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w = torch.randn((e, n, k), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

w_ref_l = []
qweights_l = []
Expand All @@ -297,7 +303,7 @@ def test_single_marlin_moe(
g_idx = stack_and_dev(g_idx_l)
sort_indices = stack_and_dev(sort_indices_l)

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
Expand All @@ -308,4 +314,4 @@ def test_single_marlin_moe(
renormalize=False)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

assert (compute_max_diff(marlin_output, torch_output) < 1e-2)
assert compute_max_diff(marlin_output, torch_output) < 1e-2
18 changes: 10 additions & 8 deletions vllm/model_executor/layers/fused_moe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_marlin_moe,
single_marlin_moe)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.fused_moe_gptq import fused_moe_gptq
from vllm.model_executor.layers.fused_moe.fused_moe import single_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
from vllm.triton_utils import HAS_TRITON

__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"fused_marlin_moe",
"fused_moe_gptq",
"single_marlin_moe",
]

if HAS_TRITON:

from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
fused_experts,
fused_moe,
fused_topk,
get_config_file_name,
grouped_topk,
)

__all__ += [
"fused_moe",
Expand Down
102 changes: 1 addition & 101 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,104 +704,4 @@ def single_marlin_moe(
g_idx, rand_perm, workspace, M, N, K, True, E, topk, block_size_m,
True, False)

return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)


def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
rand_perm1: torch.Tensor,
rand_perm2: torch.Tensor,
topk: int,
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[
1] == w2.shape[2] // 2, "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)

get_config_func = functools.partial(try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)

block_size_m = config['BLOCK_SIZE_M']

sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)

intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)

intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
block_size_m, True, False)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
block_size_m, False, True)

return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
Loading