-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathdemo_video.py
106 lines (82 loc) · 2.96 KB
/
demo_video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import copy
import os
import sys

import torch
import torchvision.transforms as transforms
import numpy as np
import cv2

from utils.imutils import *
from utils.transforms import *
from datasets import W300LP, VW300, AFLW2000, LS3DW
import models
from models.fan_model import FAN
from utils.evaluation import get_preds, final_preds
from faceboxes import face_detector_init, detect
# Directory that receives the annotated frames when SAVE is enabled.
SAVE_DIR = "./save_pics"
# When True, annotated frames are written to SAVE_DIR as numbered JPEGs
# (e.g. to assemble a GIF from them afterwards).
SAVE = False
# Checkpoint file whose 'state_dict' entry is loaded into the model.
CHECKPOINT_PATH = "./checkpoint/fan3d_wo_norm_att/model_best.pth.tar"
# The first command-line argument selects the run mode: "cpu" loads the
# checkpoint onto the CPU, any other value takes the CUDA/DataParallel path.
if len(sys.argv) < 2:
    # sys.exit(message) writes the message to stderr and terminates with
    # status 1, so a missing argument is reported to the caller as a failure
    # instead of exiting with the success status 0.
    sys.exit("please specify run model...")
# Sorted names of every callable attribute of the `models` package whose name
# is all lowercase and does not start with a double underscore; printed for
# information only.
model_names = sorted(
    attr for attr, value in vars(models).items()
    if callable(value) and attr.islower() and not attr.startswith("__"))
print(model_names)
# Build the network and restore its weights from CHECKPOINT_PATH.
# NOTE(review): the meaning of the constructor argument 2 is defined by
# models.fan_model.FAN and is not visible here -- confirm before changing it.
model = FAN(2)
if sys.argv[1] == "cpu":
    # Load every tensor onto the CPU regardless of the device it was saved from.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=lambda storage, loc: storage)['state_dict']
    # The bare (unwrapped) model is used here, so a leading 'module.' on a
    # checkpoint key has to go. Only the leading prefix is stripped:
    # str.replace would also mangle keys that merely contain 'module.'
    # somewhere in the middle (e.g. 'submodule.weight').
    prefix = 'module.'
    model_dict = model.state_dict()
    for k, v in checkpoint.items():
        target_key = k[len(prefix):] if k.startswith(prefix) else k
        model_dict[target_key] = v
    model.load_state_dict(model_dict)
else:
    # GPU path: wrap in DataParallel first so the checkpoint keys are loaded
    # without any renaming.
    model = torch.nn.DataParallel(model).cuda()
    checkpoint = torch.load(CHECKPOINT_PATH)
    model.load_state_dict(checkpoint['state_dict'])
# Inference only: switch the model to evaluation mode.
model.eval()
# Files handed to face_detector_init(): a network description (.prototxt)
# and the matching weights (.caffemodel), both given as relative paths.
proto, mdl = "faceboxes_deploy.prototxt", "faceboxes_iter_120000.caffemodel"
face_detector = face_detector_init(proto, mdl)
# Create the output directory up front, but only when frame saving is
# switched on and the directory is not there yet.
if SAVE == True and not os.path.exists(SAVE_DIR):
    os.mkdir(SAVE_DIR)
# Live demo loop: grab frames from capture device 0, detect faces, run the
# landmark model on a crop around each face, draw the predicted points and
# the face box onto the frame, and show it in the "landmark" window.
# Press 'q' or Esc in the preview window to quit.

# Index used in saved file names; advances once per frame that had a face.
count = 0
# Box side length (pixels) that corresponds to scale 1.0 for crop()/final_preds().
reference_scale = 200
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    sys.exit("failed to open video capture device 0")

try:
    while True:
        grabbed, img_ori = cap.read()
        if not grabbed or img_ori is None:
            # The device stopped delivering frames; stop here rather than
            # handing None to the detector.
            break

        rects = detect(img_ori, face_detector)
        found_faces = len(rects) > 0
        if found_faces:
            print(rects)

        for rect in rects:
            # Detection box padded by 10 px on every side: [left, top, right, bottom].
            d = [rect.left() - 10, rect.top() - 10, rect.right() + 10, rect.bottom() + 10]
            center = torch.FloatTensor([d[2] - (d[2] - d[0]) / 2.0, d[3] - (d[3] - d[1]) / 2.0])
            hw = max(d[2] - d[0], d[3] - d[1])
            # Convert before dividing so integer box sizes are not floored
            # under Python 2 division rules.
            scale = float(hw) / reference_scale

            # Reverse the channel axis on a private copy of the frame
            # (presumably BGR -> RGB -- confirm against crop()), then move
            # the channel axis first.
            img_chn = copy.deepcopy(img_ori[:, :, ::-1])
            img_trans = np.transpose(img_chn, (2, 0, 1))
            inp = crop(img_trans, center, scale)
            inp.unsqueeze_(0)  # add a leading batch dimension in place
            output = model(inp)
            if sys.argv[1] == "cpu":
                score_map = output[-1].data
            else:
                score_map = output[-1].data.cpu()

            # NOTE(review): [64, 64] is presumably the score-map resolution
            # expected by final_preds -- confirm in utils.evaluation.
            pts_img = final_preds(score_map, [center], [scale], [64, 64])
            pts_img = np.squeeze(pts_img.numpy())
            for pts in pts_img:
                # OpenCV drawing functions require integer pixel coordinates;
                # the predictions are floating point.
                point = (int(round(float(pts[0]))), int(round(float(pts[1]))))
                cv2.circle(img_ori, point, 2, (0, 255, 0), -1, 2)
            cv2.rectangle(img_ori, (int(d[0]), int(d[1])), (int(d[2]), int(d[3])), (255, 255, 255))

        # Always refresh the preview, including frames without a face, so
        # the window keeps tracking the camera and stays responsive.
        cv2.imshow("landmark", img_ori)
        if found_faces:
            # As before, only frames that contained a face are saved/counted.
            if SAVE:
                cv2.imwrite(os.path.join(SAVE_DIR, "image_{}.jpg".format(count)), img_ori)
            count += 1

        # waitKey also pumps the GUI event loop; 27 is the Esc key code.
        key = cv2.waitKey(1) & 0xFF
        if key in (27, ord("q")):
            break
finally:
    # Release the camera and close the preview window on every exit path.
    cap.release()
    cv2.destroyAllWindows()