From fade1bbcfab6e43aaa0ae86b8e30992619d668dc Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Fri, 8 Sep 2023 14:05:20 +0800 Subject: [PATCH 1/4] update --- ...ose-detection_simcc_onnxruntime_dynamic.py | 4 +++ mmdeploy/codebase/mmpose/codecs/__init__.py | 5 +++ .../codebase/mmpose/codecs/post_processing.py | 33 +++++++++++++++++++ .../codebase/mmpose/deploy/pose_detection.py | 6 +++- .../mmpose/deploy/pose_detection_model.py | 12 +++++-- .../codebase/mmpose/models/heads/__init__.py | 4 +-- .../mmpose/models/heads/simcc_head.py | 28 ++++++++++++++++ 7 files changed, 87 insertions(+), 5 deletions(-) create mode 100644 mmdeploy/codebase/mmpose/codecs/__init__.py create mode 100644 mmdeploy/codebase/mmpose/codecs/post_processing.py create mode 100644 mmdeploy/codebase/mmpose/models/heads/simcc_head.py diff --git a/configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py b/configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py index 1ec49b90f9..6b9c3c891e 100644 --- a/configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py +++ b/configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py @@ -14,3 +14,7 @@ 0: 'batch' } }) + +codebase_config = dict( + export_postprocess=False # do not export get_simcc_maximum +) diff --git a/mmdeploy/codebase/mmpose/codecs/__init__.py b/mmdeploy/codebase/mmpose/codecs/__init__.py new file mode 100644 index 0000000000..ad861c8c27 --- /dev/null +++ b/mmdeploy/codebase/mmpose/codecs/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .post_processing import get_simcc_maximum + +__all__ = ['get_simcc_maximum'] diff --git a/mmdeploy/codebase/mmpose/codecs/post_processing.py b/mmdeploy/codebase/mmpose/codecs/post_processing.py new file mode 100644 index 0000000000..aae5bc94be --- /dev/null +++ b/mmdeploy/codebase/mmpose/codecs/post_processing.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def get_simcc_maximum(simcc_x: torch.Tensor, + simcc_y: torch.Tensor) -> torch.Tensor: + """Get maximum response location and value from simcc representations. + + rewrite to support `torch.Tensor` input type. + + Args: + simcc_x (torch.Tensor): x-axis SimCC in shape (N, K, Wx) + simcc_y (torch.Tensor): y-axis SimCC in shape (N, K, Wy) + + Returns: + tuple: + - locs (torch.Tensor): locations of maximum heatmap responses in shape + (N, K, 2) + - vals (torch.Tensor): values of maximum heatmap responses in shape + (N, K) + """ + N, K, _ = simcc_x.shape + simcc_x = simcc_x.flatten(0, 1) + simcc_y = simcc_y.flatten(0, 1) + x_locs = simcc_x.argmax(dim=1, keepdim=True) + y_locs = simcc_y.argmax(dim=1, keepdim=True) + locs = torch.cat((x_locs, y_locs), dim=1).to(torch.float32) + max_val_x, _ = simcc_x.max(dim=1, keepdim=True) + max_val_y, _ = simcc_y.max(dim=1, keepdim=True) + vals, _ = torch.cat([max_val_x, max_val_y], dim=1).min(dim=1) + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + return locs, vals diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py index 6584d995fb..86f2e4d09a 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py @@ -13,7 +13,8 @@ from mmengine.registry import Registry from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase -from mmdeploy.utils import Codebase, Task, get_input_shape, get_root_logger +from mmdeploy.utils import (Codebase, Task, get_codebase_config, + get_input_shape, get_root_logger) def process_model_config( @@ -362,6 +363,9 @@ def get_postprocess(self, *args, **kwargs) -> Dict: params['post_process'] = 'megvii' params['modulate_kernel'] = self.model_cfg.kernel_sizes[-1] elif codec.type == 'SimCCLabel': + export_postprocess = get_codebase_config(self.deploy_cfg).get( + 'export_postprocess', False) + params['export_postprocess'] = export_postprocess component = 'SimCCLabelDecode' elif codec.type == 'RegressionLabel': component = 'DeepposeRegressionHeadDecode' diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py index 1686a089fe..be10a80e32 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py @@ -101,12 +101,20 @@ def forward(self, if self.model_cfg.model.type == 'YOLODetector': return self.pack_yolox_pose_result(batch_outputs, data_samples) + codebase_cfg = get_codebase_config(self.deploy_cfg) codec = self.model_cfg.codec if isinstance(codec, (list, tuple)): codec = codec[-1] if codec.type == 'SimCCLabel': - batch_pred_x, batch_pred_y = batch_outputs - preds = self.head.decode((batch_pred_x, batch_pred_y)) + export_postprocess = codebase_cfg.get('export_postprocess', False) + if export_postprocess: + keypoints, scores = [_.cpu().numpy() for _ in batch_outputs] + preds = [ + InstanceData(keypoints=keypoints, keypoint_scores=scores) + ] + else: + batch_pred_x, batch_pred_y = batch_outputs + preds = self.head.decode((batch_pred_x, batch_pred_y)) elif codec.type in ['RegressionLabel', 'IntegralRegressionLabel']: preds = self.head.decode(batch_outputs) else: diff --git a/mmdeploy/codebase/mmpose/models/heads/__init__.py b/mmdeploy/codebase/mmpose/models/heads/__init__.py index 9fb6239cdb..10bd18a0d9 100644 --- a/mmdeploy/codebase/mmpose/models/heads/__init__.py +++ b/mmdeploy/codebase/mmpose/models/heads/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import mspn_head, yolox_pose_head # noqa: F401,F403 +from . import mspn_head, simcc_head, yolox_pose_head # noqa: F401,F403 -__all__ = ['mspn_head', 'yolox_pose_head'] +__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head'] diff --git a/mmdeploy/codebase/mmpose/models/heads/simcc_head.py b/mmdeploy/codebase/mmpose/models/heads/simcc_head.py new file mode 100644 index 0000000000..d568e42d62 --- /dev/null +++ b/mmdeploy/codebase/mmpose/models/heads/simcc_head.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdeploy.codebase.mmpose.codecs import get_simcc_maximum +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import get_codebase_config + + +@FUNCTION_REWRITER.register_rewriter('mmpose.models.heads.RTMCCHead.forward') +@FUNCTION_REWRITER.register_rewriter('mmpose.models.heads.SimCCHead.forward') +def simcc_head__forward(self, feats): + """Rewrite `forward` of SimCCHead for default backend. + + Args: + feats (tuple[Tensor]): Input features. + Returns: + key-points (torch.Tensor): Output keypoints in + shape of (N, K, 3) + """ + ctx = FUNCTION_REWRITER.get_context() + simcc_x, simcc_y = ctx.origin_func(self, feats) + codebase_cfg = get_codebase_config(ctx.cfg) + export_postprocess = codebase_cfg.get('export_postprocess', False) + if not export_postprocess: + return simcc_x, simcc_y + assert self.decoder.use_dark is False, \ + 'Do not support SimCCLabel with use_dark=True' + pts, scores = get_simcc_maximum(simcc_x, simcc_y) + pts /= self.decoder.simcc_split_ratio + return pts, scores From de4ca00a4f9a880ff8dc0471b74cab2ad58453e1 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Mon, 25 Sep 2023 10:28:18 +0800 Subject: [PATCH 2/4] update for simcc csrc --- csrc/mmdeploy/codebase/mmpose/simcc_label.cpp | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp b/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp index ffa0eebf25..fb4af47126 100644 --- a/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp +++ b/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp @@ -26,6 +26,7 @@ class SimCCLabelDecode : public MMPose { auto& params = config["params"]; flip_test_ = params.value("flip_test", flip_test_); simcc_split_ratio_ = params.value("simcc_split_ratio", simcc_split_ratio_); + export_postprocess_ = params.value("export_postprocess", export_postprocess_); if (params.contains("input_size")) { from_value(params["input_size"], input_size_); } @@ -52,7 +53,9 @@ class SimCCLabelDecode : public MMPose { Tensor keypoints({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 2}}); Tensor scores({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 1}}); - get_simcc_maximum(simcc_x, simcc_y, keypoints, scores); + if (!export_postprocess_) { + get_simcc_maximum(simcc_x, simcc_y, keypoints, scores); + } std::vector center; std::vector scale; @@ -61,17 +64,25 @@ class SimCCLabelDecode : public MMPose { PoseDetectorOutput output; float* keypoints_data = keypoints.data(); + float* simcc_x_data = simcc_x.data(); + float* simcc_y_data = simcc_y.data(); + float* scores_data = scores.data(); float scale_value = 200, x = -1, y = -1, s = 0; for (int i = 0; i < simcc_x.shape(1); i++) { - x = *(keypoints_data + 0) / simcc_split_ratio_; - y = *(keypoints_data + 1) / simcc_split_ratio_; + if (export_postprocess_) { + x = *(simcc_x_data++); + y = *(simcc_x_data++); + s = *(scores_data++); + } else { + x = *(keypoints_data++) / simcc_split_ratio_; + y = *(keypoints_data++) / simcc_split_ratio_; + s = *(scores_data++); + } + x = x * scale[0] * scale_value / input_size_[0] + center[0] - scale[0] * scale_value * 0.5; y = y * scale[1] * scale_value / input_size_[1] + center[1] - scale[1] * scale_value * 0.5; - s = *(scores_data + 0); output.key_points.push_back({{x, y}, s}); - keypoints_data += 2; - scores_data += 1; } return to_value(output); } @@ -104,6 +115,7 @@ class SimCCLabelDecode : public MMPose { private: bool flip_test_{false}; + bool export_postprocess_{false}; bool shift_heatmap_{false}; float simcc_split_ratio_{2.0}; std::vector input_size_{192, 256}; From fa4674abdb172fb0077085bbce64c8b2b289f007 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Fri, 22 Sep 2023 13:03:53 +0800 Subject: [PATCH 3/4] fix docker ci --- .github/workflows/docker.yml | 2 +- .github/workflows/publish.yml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index e49170cce7..8ef4e85932 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -53,7 +53,7 @@ jobs: export TAG=$TAG_PREFIX echo "TAG=${TAG}" >> $GITHUB_ENV echo $TAG - docker ./docker/Release/ -t ${TAG} --no-cache + docker build ./docker/Release/ -t ${TAG} --no-cache docker push $TAG - name: Push docker image with released tag if: startsWith(github.ref, 'refs/tags/') == true diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 871c2a025f..3f4585175a 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -29,19 +29,19 @@ jobs: echo $MMDEPLOY_VERSION echo "MMDEPLOY_VERSION=$MMDEPLOY_VERSION" >> $GITHUB_ENV echo "OUTPUT_DIR=$PREBUILD_DIR/$MMDEPLOY_VERSION" >> $GITHUB_ENV - pip install twine + python3 -m pip install twine --user - name: Upload mmdeploy continue-on-error: true run: | cd $OUTPUT_DIR/mmdeploy ls -sha *.whl - twine upload *.whl -u __token__ -p ${{ secrets.pypi_password }} + python3 -m twine upload *.whl -u __token__ -p ${{ secrets.pypi_password }} - name: Upload mmdeploy_runtime continue-on-error: true run: | cd $OUTPUT_DIR/mmdeploy_runtime ls -sha *.whl - twine upload *.whl -u __token__ -p ${{ secrets.pypi_password }} + python3 -m twine upload *.whl -u __token__ -p ${{ secrets.pypi_password }} - name: Check assets run: | ls -sha $OUTPUT_DIR/sdk From 848f7862a083ab452dd1265fe6bb0c2b199bba55 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Tue, 26 Sep 2023 14:48:35 +0800 Subject: [PATCH 4/4] update simcc_label --- csrc/mmdeploy/codebase/mmpose/simcc_label.cpp | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp b/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp index fb4af47126..6ad142f6fa 100644 --- a/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp +++ b/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp @@ -27,6 +27,9 @@ class SimCCLabelDecode : public MMPose { flip_test_ = params.value("flip_test", flip_test_); simcc_split_ratio_ = params.value("simcc_split_ratio", simcc_split_ratio_); export_postprocess_ = params.value("export_postprocess", export_postprocess_); + if (export_postprocess_) { + simcc_split_ratio_ = 1.0; + } if (params.contains("input_size")) { from_value(params["input_size"], input_size_); } @@ -53,8 +56,14 @@ class SimCCLabelDecode : public MMPose { Tensor keypoints({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 2}}); Tensor scores({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 1}}); + float *keypoints_data = nullptr, *scores_data = nullptr; if (!export_postprocess_) { get_simcc_maximum(simcc_x, simcc_y, keypoints, scores); + keypoints_data = keypoints.data(); + scores_data = scores.data(); + } else { + keypoints_data = simcc_x.data(); + scores_data = simcc_y.data(); } std::vector center; @@ -63,22 +72,11 @@ class SimCCLabelDecode : public MMPose { from_value(img_metas["scale"], scale); PoseDetectorOutput output; - float* keypoints_data = keypoints.data(); - float* simcc_x_data = simcc_x.data(); - float* simcc_y_data = simcc_y.data(); - - float* scores_data = scores.data(); float scale_value = 200, x = -1, y = -1, s = 0; for (int i = 0; i < simcc_x.shape(1); i++) { - if (export_postprocess_) { - x = *(simcc_x_data++); - y = *(simcc_x_data++); - s = *(scores_data++); - } else { - x = *(keypoints_data++) / simcc_split_ratio_; - y = *(keypoints_data++) / simcc_split_ratio_; - s = *(scores_data++); - } + x = *(keypoints_data++) / simcc_split_ratio_; + y = *(keypoints_data++) / simcc_split_ratio_; + s = *(scores_data++); x = x * scale[0] * scale_value / input_size_[0] + center[0] - scale[0] * scale_value * 0.5; y = y * scale[1] * scale_value / input_size_[1] + center[1] - scale[1] * scale_value * 0.5;