Skip to content

Commit

Permalink
[Fix] add kernel api to ms_deform_attn_mlucpp
Browse files Browse the repository at this point in the history
  • Loading branch information
DanieeelLiu committed Jul 26, 2024
1 parent f3d33e2 commit 0cca35b
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,18 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
(char*)output_ptr);
break;
}
case MS_DEFORM_ATTN_FORWARD_FAST: {
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardFast<<<"
<< k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnForwardFast(
k_dim, k_type, queue, data_type, (char*)value_ptr,
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points,
(char*)output_ptr);
break;
}

}

output = output.view({batch_size, num_queries, num_heads * channels});
Expand Down

0 comments on commit 0cca35b

Please sign in to comment.