-
Notifications
You must be signed in to change notification settings - Fork 41
/
demo.py
64 lines (54 loc) · 1.86 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from functools import partial
from collections import namedtuple
import os
from fire import Fire
from pytorch_toolbelt.utils import read_rgb_image
from predictor import FaceMeshPredictor
from demo_utils import (
draw_landmarks,
draw_3d_landmarks,
draw_mesh,
draw_pose,
get_uv_texture,
get_pncc,
get_mesh,
get_flame_params,
get_output_path,
MeshSaver,
ImageSaver,
JsonSaver,
)
# Lightweight record binding one demo output type to its two handlers:
#   processor -- callable invoked as processor(predictions, image) to build the result
#   saver     -- Saver class; instantiated with no arguments, then called as
#                saver(result, output_path) to persist the result
DemoFuncs = namedtuple("DemoFuncs", "processor saver")
# Registry mapping each ``type_of_output`` CLI value to its processor and Saver.
# Most outputs are rendered images, so those share ImageSaver; the mesh and the
# FLAME parameters use dedicated savers and are appended afterwards, which keeps
# the original key order.
demo_funcs = {
    output_name: DemoFuncs(processor_fn, ImageSaver)
    for output_name, processor_fn in (
        ("68_landmarks", draw_landmarks),
        ("191_landmarks", partial(draw_3d_landmarks, subset="191")),
        ("445_landmarks", partial(draw_3d_landmarks, subset="445")),
        ("head_mesh", partial(draw_mesh, subset="head")),
        ("face_mesh", partial(draw_mesh, subset="face")),
        ("pose", draw_pose),
        ("uv_texture", get_uv_texture),
        ("pncc", get_pncc),
    )
}
demo_funcs.update(
    {
        "3d_mesh": DemoFuncs(get_mesh, MeshSaver),
        "flame_params": DemoFuncs(get_flame_params, JsonSaver),
    }
)
def demo(
    input_image_path: str = 'images/demo_heads/1.jpeg',
    outputs_folder: str = "outputs",
    type_of_output: str = "68_landmarks",
) -> None:
    """Run the DAD-3DNet face-mesh predictor on one image and save one demo output.

    Args:
        input_image_path: Path of the input image (loaded with ``read_rgb_image``).
        outputs_folder: Directory the output is written into; created if missing.
        type_of_output: Key of the ``demo_funcs`` registry selecting which
            processor and Saver to use (e.g. ``"68_landmarks"``, ``"3d_mesh"``).

    Raises:
        KeyError: If ``type_of_output`` is not a key of ``demo_funcs``. This is
            raised before any folder is created, the image is read, or the
            network is loaded.
    """
    # Resolve the requested output first. A typo then fails immediately with the
    # list of valid choices, instead of surfacing as a bare KeyError after the
    # (slow) model load and inference.
    try:
        funcs = demo_funcs[type_of_output]
    except KeyError:
        valid_choices = ", ".join(demo_funcs)
        raise KeyError(
            f"Unknown type_of_output {type_of_output!r}; expected one of: {valid_choices}"
        ) from None

    os.makedirs(outputs_folder, exist_ok=True)

    # Preprocess and get predictions.
    image = read_rgb_image(input_image_path)
    predictor = FaceMeshPredictor.dad_3dnet()
    predictions = predictor(image)

    # Get the resulting output.
    result = funcs.processor(predictions, image)

    # Save the demo output. The registry stores the Saver class; instantiate it here.
    saver = funcs.saver()
    output_path = get_output_path(input_image_path, outputs_folder, type_of_output, saver.extension)
    saver(result, output_path)
if __name__ == "__main__":
    # Expose ``demo``'s keyword arguments as command-line flags, e.g.
    #   python demo.py --type_of_output=3d_mesh
    Fire(component=demo)