From f629593d2c320f6f10bb18b2283806089a88121c Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Thu, 4 Jul 2024 12:04:45 +0000
Subject: [PATCH] format

---
 tests/kernels/test_moe.py                         | 2 ++
 vllm/model_executor/layers/fused_moe/__init__.py  | 2 +-
 vllm/model_executor/layers/fused_moe/fused_moe.py | 2 ++
 3 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index d3ce70eda5725..7a3f2558f4699 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -38,6 +38,7 @@ def torch_moe(a, w1, w2, score, topk):
     return (out.view(B, -1, w2.shape[1]) *
             topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
 
+
 def torch_moe_single(a, w, score, topk):
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
@@ -51,6 +52,7 @@ def torch_moe_single(a, w, score, topk):
         out[mask] = a[mask] @ w[i].transpose(0, 1)
     return (out.view(B, -1, w.shape[1])).sum(dim=1)
 
+
 @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 511, 1024])
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 6c8af100884f9..6e3fde339bca1 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,6 +1,6 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_experts, fused_marlin_moe, fused_moe, fused_topk,
-    get_config_file_name, single_marlin_moe, grouped_topk)
+    get_config_file_name, grouped_topk, single_marlin_moe)
 
 __all__ = [
     "fused_moe",
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index ce71713986ba5..cf316c0f9afa4 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -367,6 +367,7 @@ def fused_topk(
     topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     return topk_weights, topk_ids
 
+
 # This is used by the Deepseek-V2 model
 def grouped_topk(
     hidden_states: torch.Tensor,
@@ -412,6 +413,7 @@ def get_expert_config(w1: torch.Tensor, w2: torch.Tensor, topk: int, M: int,
     return get_default_config(M, E, N, w1.shape[2],
                               topk, "float8" if use_fp8 else None)
 
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w2: torch.Tensor,