Skip to content

Commit

Permalink
add single depth to 3d hand keypoints, add nyu hand dataset and awr n…
Browse files Browse the repository at this point in the history
…etwork
  • Loading branch information
walsvid committed Jul 15, 2022
1 parent afb37d4 commit 4b1771c
Show file tree
Hide file tree
Showing 19 changed files with 2,060 additions and 15 deletions.
92 changes: 92 additions & 0 deletions configs/_base_/datasets/nyu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
dataset_info = dict(
dataset_name='nyu',
paper_info=dict(
author='Jonathan Tompson and Murphy Stein and Yann Lecun and '
'Ken Perlin',
title='Real-Time Continuous Pose Recovery of Human Hands '
'Using Convolutional Networks',
container='ACM Transactions on Graphics',
year='2014',
homepage='https://jonathantompson.github.io/NYU_Hand_Pose_Dataset.htm',
),
keypoint_info={
0: dict(name='F1_KNU3_A', id=0, color=[255, 128, 0], type='', swap=''),
1: dict(name='F1_KNU3_B', id=1, color=[255, 128, 0], type='', swap=''),
2: dict(name='F1_KNU2_A', id=2, color=[255, 128, 0], type='', swap=''),
3: dict(name='F1_KNU2_B', id=3, color=[255, 128, 0], type='', swap=''),
4:
dict(name='F1_KNU1_A', id=4, color=[255, 153, 255], type='', swap=''),
5:
dict(name='F1_KNU1_B', id=5, color=[255, 153, 255], type='', swap=''),
6:
dict(name='F2_KNU3_A', id=6, color=[255, 153, 255], type='', swap=''),
7:
dict(name='F2_KNU3_B', id=7, color=[255, 153, 255], type='', swap=''),
8:
dict(name='F2_KNU2_A', id=8, color=[102, 178, 255], type='', swap=''),
9:
dict(name='F2_KNU2_B', id=9, color=[102, 178, 255], type='', swap=''),
10:
dict(name='F2_KNU1_A', id=10, color=[102, 178, 255], type='', swap=''),
11:
dict(name='F2_KNU1_B', id=11, color=[102, 178, 255], type='', swap=''),
12:
dict(name='F3_KNU3_A', id=12, color=[255, 51, 51], type='', swap=''),
13:
dict(name='F3_KNU3_B', id=13, color=[255, 51, 51], type='', swap=''),
14:
dict(name='F3_KNU2_A', id=14, color=[255, 51, 51], type='', swap=''),
15:
dict(name='F3_KNU2_B', id=15, color=[255, 51, 51], type='', swap=''),
16: dict(name='F3_KNU1_A', id=16, color=[0, 255, 0], type='', swap=''),
17: dict(name='F3_KNU1_B', id=17, color=[0, 255, 0], type='', swap=''),
18: dict(name='F4_KNU3_A', id=18, color=[0, 255, 0], type='', swap=''),
19: dict(name='F4_KNU3_B', id=19, color=[0, 255, 0], type='', swap=''),
20:
dict(name='F4_KNU2_A', id=20, color=[255, 255, 255], type='', swap=''),
21:
dict(name='F4_KNU2_B', id=21, color=[255, 128, 0], type='', swap=''),
22:
dict(name='F4_KNU1_A', id=22, color=[255, 128, 0], type='', swap=''),
23:
dict(name='F4_KNU1_B', id=23, color=[255, 128, 0], type='', swap=''),
24:
dict(name='TH_KNU3_A', id=24, color=[255, 128, 0], type='', swap=''),
25:
dict(name='TH_KNU3_B', id=25, color=[255, 153, 255], type='', swap=''),
26:
dict(name='TH_KNU2_A', id=26, color=[255, 153, 255], type='', swap=''),
27:
dict(name='TH_KNU2_B', id=27, color=[255, 153, 255], type='', swap=''),
28:
dict(name='TH_KNU1_A', id=28, color=[255, 153, 255], type='', swap=''),
29:
dict(name='TH_KNU1_B', id=29, color=[102, 178, 255], type='', swap=''),
30:
dict(name='PALM_1', id=30, color=[102, 178, 255], type='', swap=''),
31:
dict(name='PALM_2', id=31, color=[102, 178, 255], type='', swap=''),
32:
dict(name='PALM_3', id=32, color=[102, 178, 255], type='', swap=''),
33: dict(name='PALM_4', id=33, color=[255, 51, 51], type='', swap=''),
34: dict(name='PALM_5', id=34, color=[255, 51, 51], type='', swap=''),
35: dict(name='PALM_6', id=35, color=[255, 51, 51], type='', swap=''),
},
skeleton_info={
0: dict(link=('PALM_3', 'F1_KNU2_B'), id=0, color=[255, 128, 0]),
1: dict(link=('F1_KNU2_B', 'F1_KNU3_A'), id=1, color=[255, 128, 0]),
2: dict(link=('PALM_3', 'F2_KNU2_B'), id=2, color=[255, 128, 0]),
3: dict(link=('F2_KNU2_B', 'F2_KNU3_A'), id=3, color=[255, 128, 0]),
4: dict(link=('PALM_3', 'F3_KNU2_B'), id=4, color=[255, 153, 255]),
5: dict(link=('F3_KNU2_B', 'F3_KNU3_A'), id=5, color=[255, 153, 255]),
6: dict(link=('PALM_3', 'F4_KNU2_B'), id=6, color=[255, 153, 255]),
7: dict(link=('F4_KNU2_B', 'F4_KNU3_A'), id=7, color=[255, 153, 255]),
8: dict(link=('PALM_3', 'TH_KNU2_B'), id=8, color=[102, 178, 255]),
9: dict(link=('TH_KNU2_B', 'TH_KNU3_B'), id=9, color=[102, 178, 255]),
10:
dict(link=('TH_KNU3_B', 'TH_KNU3_A'), id=10, color=[102, 178, 255]),
11: dict(link=('PALM_3', 'PALM_1'), id=11, color=[102, 178, 255]),
12: dict(link=('PALM_3', 'PALM_2'), id=12, color=[255, 51, 51]),
},
joint_weights=[1.] * 36,
sigmas=[])
177 changes: 177 additions & 0 deletions configs/hand/3d_kpt_sview_depth_img/awr/nyu/res50_nyu_all_128x128.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
_base_ = [
'../../../../_base_/default_runtime.py',
'../../../../_base_/datasets/nyu.py'
]
checkpoint_config = dict(interval=1)
# TODO: metric
evaluation = dict(
interval=1,
metric=['MRRPE', 'MPJPE', 'Handedness_acc'],
save_best='MPJPE_all')

optimizer = dict(
type='Adam',
lr=2e-4,
)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[15, 17])
total_epochs = 20
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])

load_from = '/root/mmpose/data/ckpt/new_res50.pth'
used_keypoints_index = [0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 27, 30, 31, 32]

channel_cfg = dict(
num_output_channels=14,
dataset_joints=36,
dataset_channel=used_keypoints_index,
inference_channel=used_keypoints_index)

# model settings
model = dict(
type='Depthhand3D', # pretrained=None
backbone=dict(
type='AWRResNet',
depth=50,
frozen_stages=-1,
zero_init_residual=False,
in_channels=1),
keypoint_head=dict(
type='AdaptiveWeightingRegression3DHead',
offset_head_cfg=dict(
in_channels=256,
out_channels_vector=42,
out_channels_scalar=14,
heatmap_kernel_size=1.0,
),
deconv_head_cfg=dict(
in_channels=2048,
out_channels=256,
depth_size=64,
num_deconv_layers=3,
num_deconv_filters=(256, 256, 256),
num_deconv_kernels=(4, 4, 4),
extra=dict(final_conv_kernel=0, )),
loss_offset=dict(type='AWRSmoothL1Loss', use_target_weight=False),
loss_keypoint=dict(type='AWRSmoothL1Loss', use_target_weight=True),
),
train_cfg=dict(use_img_for_head=True),
test_cfg=dict(use_img_for_head=True, flip_test=False))

data_cfg = dict(
image_size=[128, 128],
heatmap_size=[64, 64, 56],
cube_size=[300, 300, 300],
heatmap_size_root=64,
num_output_channels=channel_cfg['num_output_channels'],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'])

train_pipeline = [
dict(type='LoadImageFromFile', color_type='unchanged'),
dict(type='TopDownGetBboxCenterScale', padding=1.0),
dict(type='TopDownAffine'),
dict(type='DepthToTensor'),
dict(
type='MultitaskGatherTarget',
pipeline_list=[
[
dict(
type='TopDownGenerateTargetRegression',
use_zero_mean=True,
joint_indices=used_keypoints_index,
is_3d=True,
normalize_depth=True,
),
dict(
type='HandGenerateJointToOffset',
heatmap_kernel_size=1.0,
)
],
[
dict(
type='TopDownGenerateTargetRegression',
use_zero_mean=True,
joint_indices=used_keypoints_index,
is_3d=True,
normalize_depth=True,
)
],
],
pipeline_indices=[0, 1],
),
dict(
type='Collect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
'rotation', 'flip_pairs', 'cube_size', 'center_depth', 'focal',
'princpt', 'image_size', 'joints_cam', 'dataset_channel',
'joints_uvd'
]),
]

val_pipeline = [
dict(type='LoadImageFromFile', color_type='unchanged'),
dict(type='TopDownGetBboxCenterScale', padding=1.0),
dict(type='TopDownAffine'),
dict(type='DepthToTensor'),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
'rotation', 'flip_pairs', 'cube_size', 'center_depth', 'focal',
'princpt', 'image_size', 'joints_cam', 'dataset_channel',
'joints_uvd'
])
]

test_pipeline = val_pipeline

data_root = 'data/nyu'
data = dict(
samples_per_gpu=4,
workers_per_gpu=0,
shuffle=False,
train=dict(
type='NYUHandDataset',
ann_file=f'{data_root}/annotations/nyu_test_data.json',
camera_file=f'{data_root}/annotations/nyu_test_camera.json',
joint_file=f'{data_root}/annotations/nyu_test_joint_3d.json',
img_prefix=f'{data_root}/images/test/',
data_cfg=data_cfg,
use_refined_center=False,
align_uvd_xyz_direction=True,
pipeline=train_pipeline,
dataset_info={{_base_.dataset_info}}),
val=dict(
type='NYUHandDataset',
ann_file=f'{data_root}/annotations/nyu_test_data.json',
camera_file=f'{data_root}/annotations/nyu_test_camera.json',
joint_file=f'{data_root}/annotations/nyu_test_joint_3d.json',
img_prefix=f'{data_root}/images/test/',
data_cfg=data_cfg,
use_refined_center=False,
align_uvd_xyz_direction=True,
pipeline=val_pipeline,
dataset_info={{_base_.dataset_info}}),
test=dict(
type='NYUHandDataset',
ann_file=f'{data_root}/annotations/nyu_test_data.json',
camera_file=f'{data_root}/annotations/nyu_test_camera.json',
joint_file=f'{data_root}/annotations/nyu_test_joint_3d.json',
img_prefix=f'{data_root}/images/test/',
data_cfg=data_cfg,
use_refined_center=False,
align_uvd_xyz_direction=True,
pipeline=test_pipeline,
dataset_info={{_base_.dataset_info}}),
)
35 changes: 35 additions & 0 deletions mmpose/core/evaluation/top_down_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,41 @@ def keypoints_from_heatmaps3d(heatmaps, center, scale):
return preds, maxvals


def keypoints_from_joint_uvd(joint_uvd, center, scale, image_size):
"""Get final keypoint predictions from 3d heatmaps and transform them back
to the image.
Note:
- batch size: N
- num keypoints: K
- heatmap depth size: D
- heatmap height: H
- heatmap width: W
Args:
heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps.
center (np.ndarray[N, 2]): Center of the bounding box (x, y).
scale (np.ndarray[N, 2]): Scale of the bounding box
wrt height/width.
Returns:
tuple: A tuple containing keypoint predictions and scores.
- preds (np.ndarray[N, K, 3]): Predicted 3d keypoint location \
in images.
- maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints.
"""
N, K, D = joint_uvd.shape
preds = joint_uvd
maxvals = np.ones((N, K, 1), dtype=np.float32)
# Transform back to the image
for i in range(N):
preds[i, :, :2] = transform_preds(
(preds[i, :, :2] + 1) * image_size[i] / 2, center[i], scale[i],
[image_size[i, 1], image_size[i, 0]])
return preds, maxvals


def multilabel_classification_accuracy(pred, gt, mask, thr=0.5):
"""Get multi-label classification accuracy.
Expand Down
12 changes: 9 additions & 3 deletions mmpose/datasets/datasets/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@
from .kpt_2d_sview_rgb_vid_top_down_dataset import \
Kpt2dSviewRgbVidTopDownDataset
from .kpt_3d_mview_rgb_img_direct_dataset import Kpt3dMviewRgbImgDirectDataset
from .kpt_3d_sview_depth_img_top_down_dataset import \
Kpt3dSviewDepthImgTopDownDataset
from .kpt_3d_sview_kpt_2d_dataset import Kpt3dSviewKpt2dDataset
from .kpt_3d_sview_rgb_img_top_down_dataset import \
Kpt3dSviewRgbImgTopDownDataset

__all__ = [
'Kpt3dMviewRgbImgDirectDataset', 'Kpt2dSviewRgbImgTopDownDataset',
'Kpt3dSviewRgbImgTopDownDataset', 'Kpt2dSviewRgbImgBottomUpDataset',
'Kpt3dSviewKpt2dDataset', 'Kpt2dSviewRgbVidTopDownDataset'
'Kpt3dMviewRgbImgDirectDataset',
'Kpt2dSviewRgbImgTopDownDataset',
'Kpt3dSviewRgbImgTopDownDataset',
'Kpt2dSviewRgbImgBottomUpDataset',
'Kpt3dSviewKpt2dDataset',
'Kpt2dSviewRgbVidTopDownDataset',
'Kpt3dSviewDepthImgTopDownDataset',
]
Loading

0 comments on commit 4b1771c

Please sign in to comment.