demo_all_cls.py
# --------------------------------------------------------
# Written by Yufei Ye (https://github.com/JudyYe)
# --------------------------------------------------------
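# Per-class reconstruction demo: for every masked image <name>_m.png under
# FLAGS.demo_image, load the matching per-class checkpoint, predict a coarse
# voxel shape and camera pose, refine a mesh against the mask, and write
# rendered snapshots to FLAGS.demo_out.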
from __future__ import print_function
import os
import glob
import numpy as np
import torch
from PIL import Image

from models.generator import ReconstructModel
from models.evaluator import Evaluator
from nnutils import mesh_utils
from nnutils.utils import load_my_state_dict

from absl import app
from config.config_flag import *
flags.DEFINE_string("demo_image", "examples/demo_images/", "path to input")
flags.DEFINE_string("demo_out", "outputs/demo_out", "dir of output")
flags.DEFINE_string("ckpt_dir", "weights", "dir of output")
FLAGS = flags.FLAGS
# optimization lambda
FLAGS.lap_loss = 100
FLAGS.lap_norm_loss = .5
FLAGS.cyc_mask_loss = 10
FLAGS.cyc_perc_loss = 0
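# Note: these override the defaults from config.config_flag for the mesh
# refinement step below; judging by the flag names, lap_loss / lap_norm_loss
# weight Laplacian smoothness terms, cyc_mask_loss the mask re-projection
# (cycle-consistency) term, and cyc_perc_loss a perceptual term (disabled here).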
def demo_all_cls(_):
    for image_file in glob.iglob(os.path.join(FLAGS.demo_image, '*_m.png')):
        # load demo data (image + mask) and preprocess
        data = load_image(image_file)
        cls = data['cls']

        # load the pretrained per-class model
        ckpt_file = os.path.join(FLAGS.ckpt_dir, cls, 'model.pth')
        try:
            model, cfg = load_model(ckpt_file)
        except FileNotFoundError:
            print('checkpoint not found, skipping:', ckpt_file)
            continue
        # Evaluator provides the mesh optimization and visualization utils
        evaluator = Evaluator(cfg)

        # step 1: infer coarse shape (voxels) and camera pose
        vox_world, camera_param = model.forward_image(data['image'])
        # initialize the mesh from the voxel prediction
        vox_mesh = mesh_utils.cubify(vox_world).clone()

        # step 2: optimize the mesh against the mask
        mesh_inputs = {'mesh': vox_mesh, 'view': camera_param}
        with torch.enable_grad():
            mesh_outputs, record = evaluator.opt_mask(model, mesh_inputs, data, True, 300)

        # step 3: visualize the optimized mesh
        vis_mesh(mesh_outputs, camera_param, evaluator.snapshot_mesh,
                 os.path.basename(image_file).split('_m.')[0])
def load_image(mask_file):
    # the RGB image is expected next to the mask: <name>.png for <name>_m.png
    image_file = mask_file.replace('_m.', '.')
    image = np.asarray(Image.open(image_file))
    image = image[:, :, :3] / 127.5 - 1  # drop alpha if RGBA, rescale to [-1, 1]

    mask = np.asarray(Image.open(mask_file))
    if mask.ndim < 3:
        mask = mask[..., np.newaxis]
    mask = (mask > 0).astype(np.float32)  # np.float is deprecated in NumPy >= 1.20
    fg = image * mask + (1 - mask)  # composite onto a white background

    fg = to_tensor(fg)
    image = to_tensor(image)
    mask = to_tensor(mask)
    # class name is the file-name prefix before the first underscore
    cls = os.path.basename(mask_file).split('_')[0]
    return {'bg': image, 'image': fg, 'mask': mask, 'cls': cls}
def to_tensor(image):
    # HxWxC numpy array -> 1xCxHxW float tensor on the GPU
    image = np.transpose(image, [2, 0, 1])
    image = image[np.newaxis]
    return torch.FloatTensor(image).cuda()
def load_model(ckpt_file):
    print('Init...', ckpt_file)
    pretrained_dict = torch.load(ckpt_file)
    cfg = pretrained_dict['cfg']
    model = ReconstructModel()
    load_my_state_dict(model, pretrained_dict['G'])
    model.eval()
    model.cuda()
    return model, cfg
def vis_mesh(cano_mesh, pred_view, snapshot_func, prefix, f=375):
    """
    :param cano_mesh: optimized mesh outputs (dict with key 'mesh')
    :param pred_view: predicted camera parameters
    :param snapshot_func: snapshots the mesh for a given pose list and generates a gif
    :param prefix: prefix for the output files
    :param f: unused here
    """
    # snapshot the final mesh, without and with texture
    snapshot_func(cano_mesh['mesh'][-1], [], None,
                  FLAGS.demo_out, prefix, 'mesh', pred_view=pred_view)
    snapshot_func(cano_mesh['mesh'][-1], [], cano_mesh['mesh'].textures,
                  FLAGS.demo_out, prefix, 'meshTexture', pred_view=pred_view)
if __name__ == '__main__':
app.run(demo_all_cls)
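# Example usage (paths are the flag defaults; assumes per-class checkpoints are
# laid out as <ckpt_dir>/<class>/model.pth and demo images are named
# <class>_<name>.png with masks <class>_<name>_m.png):
#   python demo_all_cls.py --demo_image examples/demo_images/ \
#       --ckpt_dir weights --demo_out outputs/demo_out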