diff --git a/csrc/mmdeploy/codebase/mmdet/instance_segmentation.cpp b/csrc/mmdeploy/codebase/mmdet/instance_segmentation.cpp
index c24be17f2c..270f4c6641 100644
--- a/csrc/mmdeploy/codebase/mmdet/instance_segmentation.cpp
+++ b/csrc/mmdeploy/codebase/mmdet/instance_segmentation.cpp
@@ -149,7 +149,7 @@ class ResizeInstanceMask : public ResizeBBox {
       int resize_width = int(mask_width / scale_factor_[1] + 0.5);
       // skip resize if scale_factor is 1.0
       if (resize_height != mask_height || resize_width != mask_width) {
-        cv::resize(mask_mat, mask_mat, cv::Size(resize_height, resize_width), cv::INTER_LINEAR);
+        cv::resize(mask_mat, mask_mat, cv::Size(resize_width, resize_height), cv::INTER_LINEAR);
       }
       // crop masks
       mask_mat = mask_mat(cv::Range(0, img_h), cv::Range(0, img_w)).clone();
diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection.py b/mmdeploy/codebase/mmdet/deploy/object_detection.py
index 40c5a5b0cf..889af233eb 100644
--- a/mmdeploy/codebase/mmdet/deploy/object_detection.py
+++ b/mmdeploy/codebase/mmdet/deploy/object_detection.py
@@ -320,6 +320,11 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
             type = 'ResizeInstanceMask'  # for instance-seg
             # resize and crop mask to origin image
             params['is_resize_mask'] = True
+        if 'mask_thr' in params:
+            type = 'ResizeInstanceMask'  # for instance-seg
+            # resize and crop mask to origin image
+            params['mask_thr_binary'] = params['mask_thr']
+            params['is_resize_mask'] = True
 
         if get_backend(self.deploy_cfg) == Backend.RKNN:
             if 'YOLO' in self.model_cfg.model.type or \
diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py
index a1e52bffd3..c6a958e5eb 100644
--- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py
+++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py
@@ -241,7 +241,7 @@ def postprocessing_results(self,
             masks = batch_masks[i]
             img_h, img_w = img_metas[i]['img_shape'][:2]
             ori_h, ori_w = img_metas[i]['ori_shape'][:2]
-            if model_type == 'RTMDet':
+            if model_type in ['RTMDet', 'CondInst']:
                 export_postprocess_mask = True
             else:
                 export_postprocess_mask = False
diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py
index ee87f41715..062bc7de52 100644
--- a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py
+++ b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from . import base_dense_head  # noqa: F401,F403
 from . import centernet_head  # noqa: F401,F403
+from . import condinst_head  # noqa: F401,F403
 from . import detr_head  # noqa: F401,F403
 from . import fovea_head  # noqa: F401,F403
 from . import gfl_head  # noqa: F401,F403
diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py
new file mode 100644
index 0000000000..abfb6da56f
--- /dev/null
+++ b/mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py
@@ -0,0 +1,202 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional
+
+import torch
+from mmdet.models.utils import aligned_bilinear
+from mmengine.config import ConfigDict
+from torch import Tensor
+
+from mmdeploy.codebase.mmdet.deploy import get_post_processing_params
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.mmcv.ops.nms import multiclass_nms
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet.models.dense_heads.CondInstBboxHead.predict_by_feat')
+def condinst_bbox_head__predict_by_feat(
+        self,
+        cls_scores: List[Tensor],
+        bbox_preds: List[Tensor],
+        score_factors: Optional[List[Tensor]] = None,
+        param_preds: Optional[List[Tensor]] = None,
+        batch_img_metas: Optional[List[dict]] = None,
+        cfg: Optional[ConfigDict] = None,
+        rescale: bool = False,
+        with_nms: bool = True,
+):
+    ctx = FUNCTION_REWRITER.get_context()
+    deploy_cfg = ctx.cfg
+
+    assert len(cls_scores) == len(bbox_preds)
+    device = bbox_preds[0].device
+    cfg = self.test_cfg if cfg is None else cfg
+    batch_size = bbox_preds[0].shape[0]
+    featmap_sizes = [cls_score.shape[-2:] for cls_score in cls_scores]
+
+    all_level_points_strides = self.prior_generator.grid_priors(
+        featmap_sizes, device=device, with_stride=True)
+    all_level_points = [i[:, :2] for i in all_level_points_strides]
+    all_level_strides = [i[:, 2] for i in all_level_points_strides]
+
+    flatten_cls_scores = [
+        cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
+                                              self.cls_out_channels)
+        for cls_score in cls_scores
+    ]
+    flatten_bbox_preds = [
+        bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
+        for bbox_pred in bbox_preds
+    ]
+    flatten_score_factors = [
+        score_factor.permute(0, 2, 3, 1).reshape(batch_size, -1, 1)
+        for score_factor in score_factors
+    ]
+    flatten_param_preds = [
+        param_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, self.num_params)
+        for param_pred in param_preds
+    ]
+    flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
+    flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
+    flatten_score_factors = torch.cat(flatten_score_factors, dim=1).sigmoid()
+    flatten_param_preds = torch.cat(flatten_param_preds, dim=1)
+
+    points = torch.cat(all_level_points)
+    strides = torch.cat(all_level_strides)
+    tl_x = points[..., 0] - flatten_bbox_preds[..., 0]
+    tl_y = points[..., 1] - flatten_bbox_preds[..., 1]
+    br_x = points[..., 0] + flatten_bbox_preds[..., 2]
+    br_y = points[..., 1] + flatten_bbox_preds[..., 3]
+
+    bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
+    scores = flatten_cls_scores
+    score_factors = flatten_score_factors
+    param_preds = flatten_param_preds
+    scores = scores * score_factors
+
+    # get post processing config
+    post_params = get_post_processing_params(deploy_cfg)
+    max_output_boxes_per_class = post_params.max_output_boxes_per_class
+    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
+    score_threshold = cfg.get('score_thr', post_params.score_threshold)
+    pre_top_k = post_params.pre_top_k
+    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
+
+    dets, labels, inds = multiclass_nms(
+        bboxes,
+        scores,
+        max_output_boxes_per_class,
+        iou_threshold,
+        score_threshold,
+        pre_top_k=pre_top_k,
+        keep_top_k=keep_top_k,
+        output_index=True,
+    )
+
+    batch_inds = torch.arange(batch_size, device=bboxes.device).view(-1, 1)
+    points = points.unsqueeze(0).repeat(batch_size, 1, 1)
+    strides = strides.unsqueeze(0).repeat(batch_size, 1)
+    param_preds = param_preds[batch_inds, inds, :]
+    points = points[batch_inds, inds, :]
+    strides = strides[batch_inds, inds]
+    results = dict(
+        dets=dets,
+        labels=labels,
+        param_preds=param_preds,
+        points=points,
+        strides=strides)
+    return results
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet.models.dense_heads.CondInstMaskHead.forward')
+def condinst_mask_head__forward(self, x: tuple,
+                                positive_infos: Dict[str, torch.Tensor]):
+    mask_feats = self.mask_feature_head(x)
+
+    param_preds = positive_infos['param_preds']
+    points = positive_infos['points']
+    strides = positive_infos['strides']
+
+    batch_size = points.shape[0]
+    num_insts = points.shape[1]
+    hw = mask_feats.size()[-2:]
+    mask_feats = mask_feats.unsqueeze(1).repeat(1, num_insts, 1, 1, 1)
+
+    points = points.reshape(-1, 1, 2).unsqueeze(0)
+    locations = self.prior_generator.single_level_grid_priors(
+        hw, level_idx=0, device=mask_feats.device)
+    locations = locations.unsqueeze(0).repeat(batch_size, 1,
+                                              1).reshape(batch_size, 1, -1, 2)
+    centers = points.reshape(batch_size, -1, 1, 2)
+    rel_coordinates = (centers - locations).permute(0, 1, 3, 2).float()
+    rel_coordinates /= (strides[:, :, None, None] * self.size_of_interest)
+    rel_coords = rel_coordinates.reshape(batch_size, -1, 2, hw[0], hw[1])
+    mask_head_inputs = torch.cat([rel_coords, mask_feats], dim=2)
+
+    weights, biases = _parse_dynamic_params(self, param_preds)
+    mask_preds = _dynamic_conv_forward(mask_head_inputs, weights, biases)
+    mask_preds = mask_preds.reshape(batch_size, num_insts, hw[0], hw[1])
+    mask_preds = aligned_bilinear(
+        mask_preds, int(self.mask_feat_stride / self.mask_out_stride))
+    return (mask_preds, )
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet.models.dense_heads.CondInstMaskHead.predict_by_feat')
+def condinst_mask_head__predict_by_feat(self,
+                                        mask_preds: Tensor,
+                                        results_list: Dict[str, torch.Tensor],
+                                        batch_img_metas: List[dict],
+                                        rescale: bool = True,
+                                        **kwargs):
+    cfg = self.test_cfg
+
+    dets = results_list['dets']
+    labels = results_list['labels']
+    img_hw = batch_img_metas[0]['img_shape'][:2]
+
+    mask_preds = mask_preds.sigmoid()
+    mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride)
+    mask_preds = mask_preds[:, :, :img_hw[0], :img_hw[1]]
+    masks = (mask_preds > cfg.mask_thr).float()
+
+    return dets, labels, masks
+
+
+def _parse_dynamic_params(self, params: Tensor):
+    """Parse the dynamic params for dynamic conv."""
+    batch_size = params.shape[0]
+    num_insts = params.shape[1]
+    params = params.permute(1, 0, 2)
+    params_splits = list(
+        torch.split_with_sizes(
+            params, self.weight_nums + self.bias_nums, dim=2))
+
+    weight_splits = params_splits[:self.num_layers]
+    bias_splits = params_splits[self.num_layers:]
+
+    for idx in range(self.num_layers):
+        if idx < self.num_layers - 1:
+            weight_splits[idx] = weight_splits[idx].reshape(
+                batch_size, num_insts, self.in_channels, -1)
+        else:
+            weight_splits[idx] = weight_splits[idx].reshape(
+                batch_size, num_insts, 1, -1)
+
+    return weight_splits, bias_splits
+
+
+def _dynamic_conv_forward(features: Tensor, weights: List[Tensor],
+                          biases: List[Tensor]):
+    """Dynamic forward; each layer is followed by a ReLU."""
+    n_layers = len(weights)
+    x = features.flatten(0, 1).flatten(2)
+    for i, (w, b) in enumerate(zip(weights, biases)):
+        # replace dynamic conv with bmm
+        w = w.flatten(0, 1)
+        b = b.flatten(0, 1).unsqueeze(2)
+        x = torch.bmm(w, x)
+        x = x + b
+        if i < n_layers - 1:
+            x = x.clamp_(min=0)
+    return x
diff --git a/mmdeploy/codebase/mmdet/models/detectors/single_stage_instance_seg.py b/mmdeploy/codebase/mmdet/models/detectors/single_stage_instance_seg.py
index bdff6e6369..f173f20ea7 100644
--- a/mmdeploy/codebase/mmdet/models/detectors/single_stage_instance_seg.py
+++ b/mmdeploy/codebase/mmdet/models/detectors/single_stage_instance_seg.py
@@ -12,7 +12,11 @@
     'instance_segmentor_forward',
     inputs=['input'],
     outputs=['dets', 'labels', 'masks'])
-def __forward_impl_instance_seg(self, batch_inputs, data_samples, **kwargs):
+def __forward_impl_instance_seg(self,
+                                batch_inputs,
+                                data_samples,
+                                rescale=True,
+                                **kwargs):
     """Rewrite and adding mark for `forward`.
 
     Encapsulate this function for rewriting `forward` of BaseDetector.
@@ -20,7 +24,18 @@ def __forward_impl_instance_seg(self, batch_inputs, data_samples, **kwargs):
     2. Support both dynamic and static export to onnx.
     """
     x = self.extract_feat(batch_inputs)
-    mask_outs = self.mask_head.predict(x, data_samples, rescale=False)
+    if self.with_bbox:
+        # the bbox branch does not need to be scaled to the original
+        # image scale, because the mask branch will scale both bbox
+        # and mask at the same time.
+        bbox_rescale = rescale if not self.with_mask else False
+        results_list = self.bbox_head.predict(
+            x, data_samples, rescale=bbox_rescale)
+    else:
+        results_list = None
+
+    mask_outs = self.mask_head.predict(
+        x, data_samples, rescale=rescale, results_list=results_list)
     return mask_outs
 
 
diff --git a/mmdeploy/pytorch/functions/repeat.py b/mmdeploy/pytorch/functions/repeat.py
index edb6efc3a5..b3a5e09b68 100644
--- a/mmdeploy/pytorch/functions/repeat.py
+++ b/mmdeploy/pytorch/functions/repeat.py
@@ -19,6 +19,9 @@ def tensor__repeat__tensorrt(input: torch.Tensor, *size: Union[torch.Size,
     origin_func = ctx.origin_func
     if input.dim() == 1 and len(size) == 1:
+        if isinstance(*size, tuple):
+            return origin_func(input.unsqueeze(0),
+                               *([1] + list(*size))).squeeze(0)
         return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
     else:
         return origin_func(input, *size)
diff --git a/tests/regression/mmdet.yml b/tests/regression/mmdet.yml
index ec81d7b6c1..679715afa2 100644
--- a/tests/regression/mmdet.yml
+++ b/tests/regression/mmdet.yml
@@ -446,3 +446,13 @@ models:
     pipelines:
       - *pipeline_ort_dynamic_fp32
       - *pipeline_trt_dynamic_fp32
+
+  - name: CondInst
+    metafile: configs/condinst/metafile.yml
+    model_configs:
+      - configs/condinst/condinst_r50_fpn_ms-poly-90k_coco_instance.py
+    pipelines:
+      - deploy_config: configs/mmdet/instance-seg/instance-seg_onnxruntime_dynamic.py
+        backend_test: *default_backend_test
+      - deploy_config: configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py
+        backend_test: *default_backend_test
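
The `_dynamic_conv_forward` helper added in condinst_head.py replaces the per-instance dynamic 1x1 convolutions of CondInst's mask head with batched matrix multiplies, so the exported graph does not rely on per-instance grouped convolution. Below is a minimal standalone sketch of that equivalence; the channel sizes and variable names are illustrative and not taken from any CondInst config.

# Standalone equivalence check: per-instance 1x1 dynamic conv vs. the bmm
# formulation used in _dynamic_conv_forward. All sizes below are made up.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_insts, h, w = 3, 16, 16
chans = [10, 8, 8, 1]  # e.g. 8 mask channels + 2 rel-coord channels as input
weight_nums = [chans[j + 1] * chans[j] for j in range(3)]
bias_nums = [chans[j + 1] for j in range(3)]

feats = torch.randn(num_insts, chans[0], h, w)
params = torch.randn(num_insts, sum(weight_nums) + sum(bias_nums))

# Split the flat controller output into per-layer weights and biases.
splits = list(torch.split_with_sizes(params, weight_nums + bias_nums, dim=1))
w_splits, b_splits = splits[:3], splits[3:]

# Reference: loop over instances and run ordinary 1x1 convolutions.
ref = []
for i in range(num_insts):
    x = feats[i:i + 1]
    for j in range(3):
        w = w_splits[j][i].reshape(chans[j + 1], chans[j], 1, 1)
        x = F.conv2d(x, w, b_splits[j][i])
        if j < 2:
            x = x.relu()
    ref.append(x)
ref = torch.cat(ref)

# bmm formulation: flatten H*W and process every instance in one batch.
x = feats.flatten(2)  # (num_insts, C_in, H*W)
for j in range(3):
    w = w_splits[j].reshape(num_insts, chans[j + 1], chans[j])
    b = b_splits[j].unsqueeze(2)
    x = torch.bmm(w, x) + b
    if j < 2:
        x = x.relu()
bmm_out = x.reshape(num_insts, 1, h, w)

assert torch.allclose(ref, bmm_out, atol=1e-5)

Because every dynamic layer uses a 1x1 kernel, convolving over (H, W) is exactly a matrix product over the flattened H*W axis, which is why torch.bmm can stand in for the dynamic conv during ONNX/TensorRT export.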