diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 83af2cd94b2e6..4d3568e09e40f 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -29,19 +29,6 @@ TOP_KS = [2, 6] -def permute_weight(x: torch.Tensor) -> torch.Tensor: - ## Hardcode BLOCK_K and BLOCK_N - BK = 128 - BN = 128 - x_ = x.clone() - x_ = x_.view(x.shape[0], x.shape[1] // BN, BN // 16, 16, x.shape[2] // BK, - BK // 32, 4, 8) - x_ = x_.permute(0, 1, 5, 2, 6, 4, 3, 7) - x_ = x_.contiguous() - x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]) - return x_ - - @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -65,9 +52,9 @@ def test_fused_moe( # Pad the input if use padding if envs.VLLM_MOE_PADDING: - w1 = F.pad(w1, (0, 128), "constant", 0)[..., :-128] + w1 = F.pad(w1, (0, 128), "constant", 0) torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0)[..., :-128] + w2 = F.pad(w2, (0, 128), "constant", 0) torch.cuda.empty_cache() triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) @@ -79,74 +66,6 @@ def test_fused_moe( rtol=0) -@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) -@pytest.mark.parametrize("n", [14336]) -@pytest.mark.parametrize("k", [4096]) -@pytest.mark.parametrize("e", [8]) -@pytest.mark.parametrize("topk", [2]) -@pytest.mark.parametrize("dtype", [torch.float16]) -def test_amd_moe_1( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, -): - if n == k: - pytest.skip() - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 - if envs.VLLM_MOE_SHUFFLE: - w1_shuffled = permute_weight(w1.data) - w2_shuffled = permute_weight(w2.data) - - score = torch.randn((m, e), device='cuda', dtype=dtype) - triton_output = fused_moe(a, - w1_shuffled, - w2_shuffled, - score, - topk, - renormalize=False) - torch_output = torch_moe(a, w1, w2, score, topk) - assert torch.allclose(triton_output, torch_output, atol=2e-2, rtol=0) - - -@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) -@pytest.mark.parametrize("n", [4096]) -@pytest.mark.parametrize("k", [14336]) -@pytest.mark.parametrize("e", [8]) -@pytest.mark.parametrize("topk", [2]) -@pytest.mark.parametrize("dtype", [torch.float16]) -def test_amd_moe_2( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, -): - if n == k: - pytest.skip() - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 - if envs.VLLM_MOE_SHUFFLE: - w1_shuffled = permute_weight(w1.data) - w2_shuffled = permute_weight(w2.data) - - score = torch.randn((m, e), device='cuda', dtype=dtype) - triton_output = fused_moe(a, - w1_shuffled, - w2_shuffled, - score, - topk, - renormalize=False) - torch_output = torch_moe(a, w1, w2, score, topk) - assert torch.allclose(triton_output, torch_output, atol=2e-1, rtol=0) - - @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @torch.inference_mode() @@ -181,13 +100,13 @@ def test_mixtral_moe(dtype: torch.dtype): # pad the weight if using padding if envs.VLLM_MOE_PADDING: - vllm_moe.experts.w13_weight = Parameter( - F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0), - requires_grad=False)[..., :-128] + vllm_moe.experts.w13_weight = Parameter(F.pad( + vllm_moe.experts.w13_weight, (0, 128), "constant", 0), + requires_grad=False) torch.cuda.empty_cache() vllm_moe.experts.w2_weight = Parameter(F.pad( vllm_moe.experts.w2_weight, (0, 128), "constant", 0), - requires_grad=False)[..., :-128] + requires_grad=False) torch.cuda.empty_cache() # Run forward passes for both MoE blocks diff --git a/vllm/envs.py b/vllm/envs.py index 28935a1fb7b76..e53e7108f953c 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -82,7 +82,6 @@ VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1 VLLM_MOE_PADDING: bool = True VLLM_FP8_PADDING: bool = True - VLLM_MOE_SHUFFLE: bool = False FUSED_MOE_PERSISTENT: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1