picture_demo.py
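# Demo script: run the pretrained OpenPose model (VGG-19 backbone, 'rtpose'
# preprocessing) over every image yielded by Loader, decode the detected
# poses, and store the rendered images and joint annotations through Saver.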
import os
import re
import sys
import cv2
import math
import time
import scipy
import argparse
import matplotlib
import numpy as np
import pylab as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from scipy.ndimage.filters import gaussian_filter
from network.rtpose_vgg import get_model
from network.post import decode_pose
from training.datasets.coco_data.preprocessing import (inception_preprocess,
                                                        rtpose_preprocess,
                                                        ssd_preprocess, vgg_preprocess)
from network import im_transform
from evaluate.coco_eval import get_multiplier, get_outputs, handle_paf_and_heat
from openpose_utils import *
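# Hard-coded Windows paths: repository root, test-image directory, and the
# pretrained OpenPose weight file.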
dir = 'D:\\Documents\\Source\\MW-Pose\\'
datadir = 'D:\\Documents\\Source\\MW-Pose\\test\\'
weight_name = 'D:\\Documents\\Source\\MW-Pose\\openpose\\pose_model.pth'
loader = Loader(datadir)
saver = Saver(datadir)
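# Build the VGG-19 rtpose network, load the pretrained weights, and put the
# model on the GPU in evaluation mode.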
model = get_model('vgg19')
model.load_state_dict(torch.load(weight_name))
model = torch.nn.DataParallel(model).cuda()
model.float()
model.eval()
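# For each test image: run multi-scale inference on the original and on a
# horizontally flipped copy, average the resulting PAFs and heatmaps, decode
# the poses, then save the rendering and the joint annotations.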
for i, fname in enumerate(loader):
    oriImg = cv2.imread(fname)  # B, G, R order
    shape_dst = np.min(oriImg.shape[0:2])

    # Get results of original image
    multiplier = get_multiplier(oriImg)
    with torch.no_grad():
        orig_paf, orig_heat = get_outputs(
            multiplier, oriImg, model, 'rtpose')

        # Get results of flipped image
        swapped_img = oriImg[:, ::-1, :]
        flipped_paf, flipped_heat = get_outputs(multiplier, swapped_img,
                                                model, 'rtpose')

        # Compute averaged heatmap and paf
        paf, heatmap = handle_paf_and_heat(
            orig_heat, flipped_heat, orig_paf, flipped_paf)

    param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
    canvas, to_plot, joint_list, person_to_joint_assoc = decode_pose(
        oriImg, param, heatmap, paf)

    saver.crawl(fname, joint_list, person_to_joint_assoc)
    cv2.imwrite(dir + 'done\\' + str(i) + '.png', to_plot)
    print('%d images have been annotated!' % (i + 1))

print('Annotation completed!')
saver.distribute()
exit()