Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CodeCamp2023-555 #2469

Merged
merged 12 commits into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/mmdeploy/codebase/mmdet/instance_segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class ResizeInstanceMask : public ResizeBBox {
int resize_width = int(mask_width / scale_factor_[1] + 0.5);
// skip resize if scale_factor is 1.0
if (resize_height != mask_height || resize_width != mask_width) {
cv::resize(mask_mat, mask_mat, cv::Size(resize_height, resize_width), cv::INTER_LINEAR);
cv::resize(mask_mat, mask_mat, cv::Size(resize_width, resize_height), cv::INTER_LINEAR);
}
// crop masks
mask_mat = mask_mat(cv::Range(0, img_h), cv::Range(0, img_w)).clone();
Expand Down
5 changes: 5 additions & 0 deletions mmdeploy/codebase/mmdet/deploy/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,11 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
type = 'ResizeInstanceMask' # for instance-seg
# resize and crop mask to origin image
params['is_resize_mask'] = True
if 'mask_thr' in params:
type = 'ResizeInstanceMask' # for instance-seg
# resize and crop mask to origin image
params['mask_thr_binary'] = params['mask_thr']
params['is_resize_mask'] = True

if get_backend(self.deploy_cfg) == Backend.RKNN:
if 'YOLO' in self.model_cfg.model.type or \
Expand Down
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmdet/deploy/object_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def postprocessing_results(self,
masks = batch_masks[i]
img_h, img_w = img_metas[i]['img_shape'][:2]
ori_h, ori_w = img_metas[i]['ori_shape'][:2]
if model_type == 'RTMDet':
if model_type in ['RTMDet', 'CondInst']:
export_postprocess_mask = True
else:
export_postprocess_mask = False
Expand Down
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import base_dense_head # noqa: F401,F403
from . import centernet_head # noqa: F401,F403
from . import condinst_head # noqa: F401,F403
from . import detr_head # noqa: F401,F403
from . import fovea_head # noqa: F401,F403
from . import gfl_head # noqa: F401,F403
Expand Down
202 changes: 202 additions & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import torch
from mmdet.models.utils import aligned_bilinear
from mmengine.config import ConfigDict
from torch import Tensor

from mmdeploy.codebase.mmdet.deploy import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmcv.ops.nms import multiclass_nms


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.CondInstBboxHead.predict_by_feat')
def condinst_bbox_head__predict_by_feat(
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
score_factors: Optional[List[Tensor]] = None,
param_preds: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = False,
with_nms: bool = True,
):
ctx = FUNCTION_REWRITER.get_context()
deploy_cfg = ctx.cfg

assert len(cls_scores) == len(bbox_preds)
device = bbox_preds[0].device
cfg = self.test_cfg if cfg is None else cfg
batch_size = bbox_preds[0].shape[0]
featmap_sizes = [cls_score.shape[-2:] for cls_score in cls_scores]

all_level_points_strides = self.prior_generator.grid_priors(
featmap_sizes, device=device, with_stride=True)
all_level_points = [i[:, :2] for i in all_level_points_strides]
all_level_strides = [i[:, 2] for i in all_level_points_strides]

flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
self.cls_out_channels)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
for bbox_pred in bbox_preds
]
flatten_score_factors = [
score_factor.permute(0, 2, 3, 1).reshape(batch_size, -1, 1)
for score_factor in score_factors
]
flatten_param_preds = [
param_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, self.num_params)
for param_pred in param_preds
]
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
flatten_score_factors = torch.cat(flatten_score_factors, dim=1).sigmoid()
flatten_param_preds = torch.cat(flatten_param_preds, dim=1)

points = torch.cat(all_level_points)
strides = torch.cat(all_level_strides)
tl_x = points[..., 0] - flatten_bbox_preds[..., 0]
tl_y = points[..., 1] - flatten_bbox_preds[..., 1]
br_x = points[..., 0] + flatten_bbox_preds[..., 2]
br_y = points[..., 1] + flatten_bbox_preds[..., 3]

bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
scores = flatten_cls_scores
score_factors = flatten_score_factors
param_preds = flatten_param_preds
scores = scores * score_factors

# get post processing config
post_params = get_post_processing_params(deploy_cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

dets, labels, inds = multiclass_nms(
bboxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k,
output_index=True,
)

batch_inds = torch.arange(batch_size, device=bboxes.device).view(-1, 1)
points = points.unsqueeze(0).repeat(batch_size, 1, 1)
strides = strides.unsqueeze(0).repeat(batch_size, 1)
param_preds = param_preds[batch_inds, inds, :]
points = points[batch_inds, inds, :]
strides = strides[batch_inds, inds]
results = dict(
dets=dets,
labels=labels,
param_preds=param_preds,
points=points,
strides=strides)
return results


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.CondInstMaskHead.forward')
def condinst_mask_head__forward(self, x: tuple,
positive_infos: Dict[str, torch.Tensor]):
mask_feats = self.mask_feature_head(x)

param_preds = positive_infos['param_preds']
points = positive_infos['points']
strides = positive_infos['strides']

batch_size = points.shape[0]
num_insts = points.shape[1]
hw = mask_feats.size()[-2:]
mask_feats = mask_feats.unsqueeze(1).repeat(1, num_insts, 1, 1, 1)

points = points.reshape(-1, 1, 2).unsqueeze(0)
locations = self.prior_generator.single_level_grid_priors(
hw, level_idx=0, device=mask_feats.device)
locations = locations.unsqueeze(0).repeat(batch_size, 1,
1).reshape(batch_size, 1, -1, 2)
centers = points.reshape(batch_size, -1, 1, 2)
rel_coordinates = (centers - locations).permute(0, 1, 3, 2).float()
rel_coordinates /= (strides[:, :, None, None] * self.size_of_interest)
rel_coords = rel_coordinates.reshape(batch_size, -1, 2, hw[0], hw[1])
mask_head_inputs = torch.cat([rel_coords, mask_feats], dim=2)

weights, biases = _parse_dynamic_params(self, param_preds)
mask_preds = _dynamic_conv_forward(mask_head_inputs, weights, biases)
mask_preds = mask_preds.reshape(batch_size, num_insts, hw[0], hw[1])
mask_preds = aligned_bilinear(
mask_preds, int(self.mask_feat_stride / self.mask_out_stride))
return (mask_preds, )


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.CondInstMaskHead.predict_by_feat')
def condinst_mask_head__predict_by_feat(self,
mask_preds: Tensor,
results_list: Dict[str, torch.Tensor],
batch_img_metas: List[dict],
rescale: bool = True,
**kwargs):
cfg = self.test_cfg

dets = results_list['dets']
labels = results_list['labels']
img_hw = batch_img_metas[0]['img_shape'][:2]

mask_preds = mask_preds.sigmoid()
mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride)
mask_preds = mask_preds[:, :, :img_hw[0], :img_hw[1]]
masks = (mask_preds > cfg.mask_thr).float()

return dets, labels, masks


def _parse_dynamic_params(self, params: Tensor):
"""parse the dynamic params for dynamic conv."""
batch_size = params.shape[0]
num_insts = params.shape[1]
params = params.permute(1, 0, 2)
params_splits = list(
torch.split_with_sizes(
params, self.weight_nums + self.bias_nums, dim=2))

weight_splits = params_splits[:self.num_layers]
bias_splits = params_splits[self.num_layers:]

for idx in range(self.num_layers):
if idx < self.num_layers - 1:
weight_splits[idx] = weight_splits[idx].reshape(
batch_size, num_insts, self.in_channels, -1)
else:
weight_splits[idx] = weight_splits[idx].reshape(
batch_size, num_insts, 1, -1)

return weight_splits, bias_splits


def _dynamic_conv_forward(features: Tensor, weights: List[Tensor],
biases: List[Tensor]):
"""dynamic forward, each layer follow a relu."""
n_layers = len(weights)
x = features.flatten(0, 1).flatten(2)
for i, (w, b) in enumerate(zip(weights, biases)):
# replace dynamic conv with bmm
w = w.flatten(0, 1)
b = b.flatten(0, 1).unsqueeze(2)
x = torch.bmm(w, x)
x = x + b
if i < n_layers - 1:
x = x.clamp_(min=0)
return x
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,30 @@
'instance_segmentor_forward',
inputs=['input'],
outputs=['dets', 'labels', 'masks'])
def __forward_impl_instance_seg(self, batch_inputs, data_samples, **kwargs):
def __forward_impl_instance_seg(self,
batch_inputs,
data_samples,
rescale=True,
**kwargs):
"""Rewrite and adding mark for `forward`.

Encapsulate this function for rewriting `forward` of BaseDetector.
1. Add mark for BaseDetector.
2. Support both dynamic and static export to onnx.
"""
x = self.extract_feat(batch_inputs)
mask_outs = self.mask_head.predict(x, data_samples, rescale=False)
if self.with_bbox:
# the bbox branch does not need to be scaled to the original
# image scale, because the mask branch will scale both bbox
# and mask at the same time.
bbox_rescale = rescale if not self.with_mask else False
results_list = self.bbox_head.predict(
x, data_samples, rescale=bbox_rescale)
else:
results_list = None

mask_outs = self.mask_head.predict(
x, data_samples, rescale=rescale, results_list=results_list)
return mask_outs


Expand Down
3 changes: 3 additions & 0 deletions mmdeploy/pytorch/functions/repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def tensor__repeat__tensorrt(input: torch.Tensor, *size: Union[torch.Size,

origin_func = ctx.origin_func
if input.dim() == 1 and len(size) == 1:
if isinstance(*size, tuple):
return origin_func(input.unsqueeze(0),
*([1] + list(*size))).squeeze(0)
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
else:
return origin_func(input, *size)
10 changes: 10 additions & 0 deletions tests/regression/mmdet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,13 @@ models:
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp32

- name: CondInst
metafile: configs/condinst/metafile.yml
model_configs:
- configs/condinst/condinst_r50_fpn_ms-poly-90k_coco_instance.py
pipelines:
- deploy_config: configs/mmdet/instance-seg/instance-seg_onnxruntime_dynamic.py
backend_test: *default_backend_test
- deploy_config: configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py
backend_test: *default_backend_test
Loading