Skip to content

Commit

Permalink
[Fix] fix the onnx exportation for yoloxpose in mmpose (#2466)
Browse files Browse the repository at this point in the history
* fix the onnx exportation for yoloxpose

* remove deprecated func

* refine code

* fix the rescaling process of top-down models

* fix ut

* add yoloxpose in regression test

* fix comment

* rebase & fix conflict
  • Loading branch information
Ben-Louis authored Oct 23, 2023
1 parent 1090fb6 commit 6edd802
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 131 deletions.
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmpose/deploy/pose_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def process_model_config(
type='Normalize',
mean=data_preprocessor.mean,
std=data_preprocessor.std,
to_rgb=data_preprocessor.bgr_to_rgb))
to_rgb=data_preprocessor.get('bgr_to_rgb', False)))
test_pipeline.append(dict(type='ImageToTensor', keys=['img']))
test_pipeline.append(
dict(
Expand Down
36 changes: 23 additions & 13 deletions mmdeploy/codebase/mmpose/deploy/pose_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,15 @@ def forward(self,
inputs = inputs.contiguous().to(self.device)
batch_outputs = self.wrapper({self.input_name: inputs})
batch_outputs = self.wrapper.output_to_list(batch_outputs)
if self.model_cfg.model.type == 'YOLODetector':
return self.pack_yolox_pose_result(batch_outputs, data_samples)

codebase_cfg = get_codebase_config(self.deploy_cfg)
codec = self.model_cfg.codec
if isinstance(codec, (list, tuple)):
codec = codec[-1]
if codec.type == 'SimCCLabel':

if codec.type == 'YOLOXPoseAnnotationProcessor':
return self.pack_yolox_pose_result(batch_outputs, data_samples)
elif codec.type == 'SimCCLabel':
export_postprocess = codebase_cfg.get('export_postprocess', False)
if export_postprocess:
keypoints, scores = [_.cpu().numpy() for _ in batch_outputs]
Expand Down Expand Up @@ -134,7 +135,7 @@ def pack_result(self,
convert_coordinate (bool): Whether to convert keypoints
coordinates to original image space. Default is True.
Returns:
data_samples (List[BaseDataElement])
data_samples (List[BaseDataElement]):
updated data_samples with predictions.
"""
if isinstance(preds, tuple):
Expand All @@ -153,11 +154,11 @@ def pack_result(self,
# convert keypoint coordinates from input space to image space
if convert_coordinate:
input_size = data_sample.metainfo['input_size']
bbox_centers = gt_instances.bbox_centers
bbox_scales = gt_instances.bbox_scales
input_center = data_sample.metainfo['input_center']
input_scale = data_sample.metainfo['input_scale']
keypoints = pred_instances.keypoints
keypoints = keypoints / input_size * bbox_scales
keypoints += bbox_centers - 0.5 * bbox_scales
keypoints = keypoints / input_size * input_scale
keypoints += input_center - 0.5 * input_scale
pred_instances.keypoints = keypoints

pred_instances.bboxes = gt_instances.bboxes
Expand All @@ -178,7 +179,7 @@ def pack_yolox_pose_result(self, preds: List[torch.Tensor],
data_samples (List[BaseDataElement]): A list of meta info for
image(s).
Returns:
data_samples (List[BaseDataElement])
data_samples (List[BaseDataElement]):
updated data_samples with predictions.
"""
assert preds[0].shape[0] == len(data_samples)
Expand All @@ -197,11 +198,20 @@ def pack_yolox_pose_result(self, preds: List[torch.Tensor],
keypoint_scores = keypoint_scores[inds]

pred_instances = InstanceData()

# rescale
scale_factor = data_sample.scale_factor
scale_factor = keypoints.new_tensor(scale_factor)
keypoints /= keypoints.new_tensor(scale_factor).reshape(1, 1, 2)
bboxes /= keypoints.new_tensor(scale_factor).repeat(1, 2)
input_size = data_sample.metainfo['input_size']
input_center = data_sample.metainfo['input_center']
input_scale = data_sample.metainfo['input_scale']

rescale = keypoints.new_tensor(input_scale) / keypoints.new_tensor(
input_size)
translation = keypoints.new_tensor(
input_center) - 0.5 * keypoints.new_tensor(input_scale)

keypoints = keypoints * rescale.reshape(
1, 1, 2) + translation.reshape(1, 1, 2)
bboxes = bboxes * rescale.repeat(1, 2) + translation.repeat(1, 2)
pred_instances.bboxes = bboxes.cpu().numpy()
pred_instances.bbox_scores = bbox_scores
# the precision test requires keypoints to be np.ndarray
Expand Down
145 changes: 28 additions & 117 deletions mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import List, Optional, Tuple

import torch
from mmengine.config import ConfigDict
from torch import Tensor

from mmdeploy.codebase.mmdet import get_post_processing_params
Expand All @@ -11,18 +10,18 @@
from mmdeploy.utils import Backend, get_backend


@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
'YOLOXPoseHead.predict')
@FUNCTION_REWRITER.register_rewriter(
func_name='mmpose.models.heads.hybrid_heads.'
'yoloxpose_head.YOLOXPoseHead.forward')
def predict(self,
x: Tuple[Tensor],
batch_data_samples=None,
rescale: bool = True):
batch_data_samples: List = [],
test_cfg: Optional[dict] = None):
"""Get predictions and transform to bbox and keypoints results.
Args:
x (Tuple[Tensor]): The input tensor from upstream network.
batch_data_samples: Batch image meta info. Defaults to None.
rescale: If True, return boxes in original image space.
Defaults to False.
test_cfg: The runtime config for testing process.
Returns:
Tuple[Tensor]: Predict bbox and keypoint results.
Expand All @@ -33,73 +32,17 @@ def predict(self,
Tensor, has shape (batch_size, num_instances, num_keypoints, 5),
the last dimension 3 arrange as (x, y, score).
"""
outs = self(x)
predictions = self.predict_by_feat(
*outs, batch_img_metas=batch_data_samples, rescale=rescale)
return predictions


@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
'YOLOXPoseHead.predict_by_feat')
def yolox_pose_head__predict_by_feat(
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
objectnesses: Optional[List[Tensor]] = None,
kpt_preds: Optional[List[Tensor]] = None,
vis_preds: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = True,
with_nms: bool = True) -> Tuple[Tensor]:
"""Transform a batch of output features extracted by the head into bbox and
keypoint results.
In addition to the base class method, keypoint predictions are also
calculated in this method.

Args:
cls_scores (List[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (List[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
objectnesses (Optional[List[Tensor]]): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, 1, H, W).
kpt_preds (Optional[List[Tensor]]): Keypoints for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_keypoints * 2, H, W)
vis_preds (Optional[List[Tensor]]): Keypoints scores for
all scale levels, each is a 4D-tensor, has shape
(batch_size, num_keypoints, H, W)
batch_img_metas (Optional[List[dict]]): Batch image meta
info. Defaults to None.
cfg (Optional[ConfigDict]): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
Tuple[Tensor]: Predict bbox and keypoint results.
- dets (Tensor): Predict bboxes and scores, which is a 3D Tensor,
has shape (batch_size, num_instances, 5), the last dimension 5
arrange as (x1, y1, x2, y2, score).
- pred_kpts (Tensor): Predict keypoints and scores, which is a 4D
Tensor, has shape (batch_size, num_instances, num_keypoints, 5),
the last dimension 3 arrange as (x, y, score).
"""
cls_scores, objectnesses, bbox_preds, kpt_offsets, \
kpt_vis = self.head_module(x)[:5]

ctx = FUNCTION_REWRITER.get_context()
deploy_cfg = ctx.cfg
dtype = cls_scores[0].dtype
device = cls_scores[0].device
bbox_decoder = self.bbox_coder.decode

assert len(cls_scores) == len(bbox_preds)
cfg = self.test_cfg if cfg is None else cfg
cfg = self.test_cfg if test_cfg is None else test_cfg

num_imgs = cls_scores[0].shape[0]
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
Expand All @@ -110,60 +53,27 @@ def yolox_pose_head__predict_by_feat(
flatten_priors = torch.cat(self.mlvl_priors)

mlvl_strides = [
flatten_priors.new_full(
(featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
stride)
flatten_priors.new_full((featmap_size.numel(), ), stride)
for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
]
flatten_stride = torch.cat(mlvl_strides)

# flatten cls_scores, bbox_preds and objectness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
for cls_score in cls_scores
]
cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()

flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)

if objectnesses is not None:
flatten_objectness = [
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
for objectness in objectnesses
]
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))

scores = cls_scores
bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
flatten_stride)

# deal with key-poinsts
priors = torch.cat(self.mlvl_priors)
strides = [
priors.new_full((featmap_size.numel() * self.num_base_priors, ),
stride)
for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
]
strides = torch.cat(strides)
kpt_preds = torch.cat([
kpt_pred.permute(0, 2, 3, 1).reshape(
num_imgs, -1, self.num_keypoints * 2) for kpt_pred in kpt_preds
],
dim=1)
flatten_decoded_kpts = self.decode_pose(priors, kpt_preds, strides)

vis_preds = torch.cat([
vis_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_keypoints,
1) for vis_pred in vis_preds
],
dim=1).sigmoid()

pred_kpts = torch.cat([flatten_decoded_kpts, vis_preds], dim=3)
flatten_cls_scores = self._flatten_predictions(cls_scores).sigmoid()
flatten_bbox_preds = self._flatten_predictions(bbox_preds)
flatten_objectness = self._flatten_predictions(objectnesses).sigmoid()
flatten_kpt_offsets = self._flatten_predictions(kpt_offsets)
flatten_kpt_vis = self._flatten_predictions(kpt_vis).sigmoid()
bboxes = self.decode_bbox(flatten_bbox_preds, flatten_priors,
flatten_stride)
flatten_decoded_kpts = self.decode_kpt_reg(flatten_kpt_offsets,
flatten_priors, flatten_stride)

scores = flatten_cls_scores * flatten_objectness

pred_kpts = torch.cat([flatten_decoded_kpts,
flatten_kpt_vis.unsqueeze(3)],
dim=3)

backend = get_backend(deploy_cfg)
if backend == Backend.TENSORRT:
Expand All @@ -184,10 +94,11 @@ def yolox_pose_head__predict_by_feat(
# nms
post_params = get_post_processing_params(deploy_cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
iou_threshold = cfg.get('nms_thr', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.get('pre_top_k', -1)
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

# do nms
_, _, nms_indices = multiclass_nms(
bboxes,
Expand Down
10 changes: 10 additions & 0 deletions tests/regression/mmpose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,13 @@ models:
sdk_config: configs/mmpose/pose-detection_simcc_sdk_static-256x192.py
- convert_image: *convert_image
deploy_config: configs/mmpose/pose-detection_simcc_ncnn_static-256x192.py

- name: YOLOX-Pose
metafile: configs/body_2d_keypoint/yoloxpose/coco/yoloxpose_coco.yml
model_configs:
- configs/body_2d_keypoint/yoloxpose/coco/yoloxpose_s_8xb32-300e_coco-640.py
pipelines:
- convert_image:
input_img: *img_human_pose
test_img: *img_human_pose
deploy_config: configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py
2 changes: 2 additions & 0 deletions tests/test_codebase/test_mmpose/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def generate_datasample(img_size, heatmap_size=(64, 48)):
img_shape=(h, w, 3),
crop_size=(h, w),
input_size=(h, w),
input_center=numpy.asarray((h / 2, w / 2)),
input_scale=numpy.asarray((h, w)),
heatmap_size=heatmap_size)
pred_instances = InstanceData()
pred_instances.bboxes = numpy.array([[0.0, 0.0, 1.0, 1.0]])
Expand Down

0 comments on commit 6edd802

Please sign in to comment.