Fix conversion with rtmdet-inst, vit, conformer #2453

Merged · 2 commits · Sep 22, 2023
16 changes: 10 additions & 6 deletions mmdeploy/mmcv/ops/nms.py
@@ -610,19 +610,19 @@ def multiclass_nms__torchscript(boxes: Tensor,

    Use batched_nms from torchvision instead of custom nms.
    """
    assert not output_index, 'output_index is not supported on this backend.'
    # TODO: simplify inference for non-batch model
    from torchvision.ops import batched_nms
    batch_size = scores.shape[0]
    num_boxes = scores.shape[1]
    num_classes = scores.shape[2]
    box_per_cls = len(boxes.shape) == 4
    scores = torch.where(scores > score_threshold, scores, scores.new_zeros(1))

    pre_topk_inds = None
    # pre-topk
    if pre_top_k > 0:
        max_scores, _ = scores.max(-1)
        _, topk_inds = max_scores.topk(pre_top_k)
        pre_topk_inds = topk_inds
        batch_inds = torch.arange(batch_size).view(-1, 1).long()
        boxes = boxes[batch_inds, topk_inds, ...]
        scores = scores[batch_inds, topk_inds, :]
@@ -646,10 +646,14 @@ def multiclass_nms__torchscript(boxes: Tensor,

    keeps = torch.cat(keeps)
    scores = scores.permute(0, 2, 1)
    dets, labels = _select_nms_index(
        scores, boxes, keeps, batch_size, keep_top_k=keep_top_k)

    return dets, labels
    return _select_nms_index(
        scores,
        boxes,
        keeps,
        batch_size,
        keep_top_k=keep_top_k,
        pre_inds=pre_topk_inds,
        output_index=output_index)


class AscendBatchNMSOp(torch.autograd.Function):
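For reference, the rewriter above hands per-class suppression to torchvision's batched_nms, which suppresses overlaps only between boxes that share a class id and returns the indices of the kept boxes sorted by score. A minimal standalone sketch of that call (not part of the PR; the shapes and the 0.5 IoU threshold are illustrative assumptions):

import torch
from torchvision.ops import batched_nms

# 100 candidate boxes in (x1, y1, x2, y2) format with a class id per box.
boxes = torch.rand(100, 4) * 100
boxes[:, 2:] += boxes[:, :2]  # guarantee x2 > x1 and y2 > y1
scores = torch.rand(100)
class_ids = torch.randint(0, 3, (100,))

# Overlapping boxes are suppressed only within the same class id;
# `keep` holds the indices of the surviving boxes.
keep = batched_nms(boxes, scores, class_ids, iou_threshold=0.5)
kept_boxes, kept_scores = boxes[keep], scores[keep]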
26 changes: 26 additions & 0 deletions mmdeploy/pytorch/functions/multi_head_attention_forward.py
@@ -53,3 +53,29 @@
                                           **kwargs) -> Tuple[Tensor, Tensor]:
    """Rewrite for custom ops."""
    return ScaledDotProductAttentionTRT.apply(q, k, v, attn_mask)


@FUNCTION_REWRITER.register_rewriter(
    func_name='torch.nn.functional.scaled_dot_product_attention',
    backend=Backend.DEFAULT.value)
def scaled_dot_product_attention__default(query,
                                          key,
                                          value,
                                          attn_mask=None,
                                          dropout_p=0.,
                                          scale=None,
                                          is_causal=False):
    """Rewrite to export to onnx on torch>=2.0.0."""
    scale = scale or query.size(-1)**0.5
    if is_causal and attn_mask is not None:
        attn_mask = torch.ones(
            query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf'))

    attn_weight = query @ key.transpose(-2, -1) / scale
    if attn_mask is not None:
        attn_weight += attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    return attn_weight @ value

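As a sanity check on the rewrite above, the manual matmul/softmax/dropout decomposition should agree with the built-in fused op when no mask is given and dropout is disabled. A minimal sketch (my own, not part of the PR; shapes and the tolerance are illustrative assumptions, and it requires torch>=2.0):

import torch
import torch.nn.functional as F

query = torch.randn(2, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
key = torch.randn(2, 4, 8, 16)
value = torch.randn(2, 4, 8, 16)

# Manual decomposition mirroring the rewriter: scale by sqrt(head_dim),
# softmax over the key axis, then weight the values.
scale = query.size(-1)**0.5
attn_weight = torch.softmax(query @ key.transpose(-2, -1) / scale, dim=-1)
manual = attn_weight @ value

# Built-in fused op, no mask, dropout disabled.
builtin = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)

print(torch.allclose(manual, builtin, atol=1e-5))  # expected: True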