diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2fc373a855010..b750fc713b43f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -346,11 +346,15 @@ def fused_topk( topk, dtype=torch.float32, device=hidden_states.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) token_expert_indicies = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) - moe_kernels.topk_softmax( + ops.topk_softmax( topk_weights, topk_ids, token_expert_indicies,