diff --git a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp index e54e8669f2..ca1527a045 100644 --- a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp @@ -324,6 +324,18 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value, (char*)output_ptr); break; } + case MS_DEFORM_ATTN_FORWARD_FAST: { + CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardFast<<<" + << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; + KernelMsDeformAttnForwardFast( + k_dim, k_type, queue, data_type, (char*)value_ptr, + (char*)spatial_shapes_ptr, (char*)level_start_index_ptr, + (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys, + num_heads, channels, num_levels, num_queries, num_points, + (char*)output_ptr); + break; + } + } output = output.view({batch_size, num_queries, num_heads * channels});