forked from AkankshaNarula/Myntra
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCPVTON.py
102 lines (86 loc) · 3.55 KB
/
CPVTON.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw
import argparse
import os
import time
import sys
from networks import GMM, UnetGenerator, load_checkpoint
import json
# Shared input normalization for 3-channel images: ToTensor converts a PIL image
# or HxWxC array into a CxHxW tensor (scaled to [0, 1] for 8-bit input), then
# Normalize maps each of the three channels through (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1].
transformer = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
class CPVTON(object):
    """Single-pair inference wrapper around the GMM warping and TOM rendering networks.

    ``predict`` warps the cloth image with a sampling grid produced by ``self.gmm``,
    feeds the person representation plus the warped cloth to ``self.tom`` (a U-Net),
    and blends the rendered person with the warped cloth through the predicted
    composition mask.
    """

    def __init__(self, gmm_path, tom_path, use_cuda=True):
        """Build both networks, load their checkpoints and put them in eval mode.

        Args:
            gmm_path: checkpoint file loaded into the GMM network.
            tom_path: checkpoint file loaded into the TOM (U-Net) network.
            use_cuda: when True, both networks and ``predict``'s inputs are moved
                to the GPU with ``.cuda()``.
        """
        opt = self.get_opt()
        self.use_cuda = use_cuda
        self.gmm = GMM(opt, use_cuda=use_cuda)
        load_checkpoint(self.gmm, gmm_path, use_cuda=use_cuda)
        self.gmm.eval()
        # ``predict`` splits this network's output into a 3-channel rendering and
        # a composition mask.
        # NOTE(review): 26 is presumably the input channel count; ``predict`` feeds
        # cat([agnostic, c_warp]) = 1 + 3 + pose channels + 3 -- confirm this
        # matches 26 for the pose_map format callers supply.
        self.tom = UnetGenerator(26, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        load_checkpoint(self.tom, tom_path, use_cuda=use_cuda)
        self.tom.eval()
        if use_cuda:
            self.gmm.cuda()
            self.tom.cuda()
        print("use_cuda = " + str(self.use_cuda))

    def _to_device(self, tensor):
        """Return ``tensor`` moved to the GPU when ``use_cuda`` is set, else unchanged."""
        return tensor.cuda() if self.use_cuda else tensor

    @staticmethod
    def _build_agnostic(parse_array, pose_map, im):
        """Concatenate blurred body shape, kept person pixels and pose along dim 0.

        Args:
            parse_array: integer label map (numpy array). It must be 256x192 (HxW)
                so the masks line up with the fixed 192x256 shape image and ``im``.
            pose_map: tensor concatenated after the first 4 channels.
            im: person image normalized to [-1, 1], shape (3, H, W).

        Returns:
            Tensor of 1 (shape) + 3 (person pixels) + pose_map channels.

        NOTE: the 192x256 size is hard-coded here and does not follow the
        ``--fine_width`` / ``--fine_height`` options.
        """
        # parse -> shape: body silhouette blurred by a 16x downsample + upsample.
        body = (parse_array > 0).astype(np.float32)
        shape_img = Image.fromarray((body * 255).astype(np.uint8))
        shape_img = shape_img.resize((192 // 16, 256 // 16), Image.BILINEAR)
        shape_img = shape_img.resize((192, 256), Image.BILINEAR)
        # The silhouette is a single-channel image, so normalize it with
        # single-channel statistics. The shared 3-channel ``transformer`` fails on
        # it with newer torchvision (in-place broadcast of a (3,1,1) mean onto a
        # (1,H,W) tensor). The math is the same: (x - 0.5) / 0.5.
        shape = transforms.Normalize((0.5,), (0.5,))(transforms.ToTensor()(shape_img))

        # Pixels labelled 1, 2, 4, 9 or 13 keep their person-image values;
        # everything else is set to -1.
        head = np.isin(parse_array, (1, 2, 4, 9, 13)).astype(np.float32)
        phead = torch.from_numpy(head)  # [0,1]
        im_h = im * phead - (1 - phead)
        return torch.cat([shape, im_h, pose_map], 0)

    def predict(self, parse_array, pose_map, human, c):
        """Run the two-stage try-on for one person/cloth pair.

        Args:
            parse_array: integer label map (numpy array, 256x192 HxW); see
                ``_build_agnostic``.
            pose_map: pose tensor appended to the person representation.
            human: person image accepted by ``transformer`` (3 channels).
            c: cloth image accepted by ``transformer`` (3 channels).

        Returns:
            ``(p_tryon, c_warp)``:
            - ``p_tryon`` is the blended try-on image with a batch dim of 1, on the
              GPU when ``use_cuda``.
            - ``c_warp`` is the warped cloth after an 8-bit round trip, re-normalized
              to [-1, 1], as a 1xCxHxW CPU tensor.
        """
        im = transformer(human)
        c = transformer(c)  # [-1,1]
        agnostic = self._build_agnostic(parse_array, pose_map, im)

        # Pure inference: skip autograd bookkeeping for both networks.
        with torch.no_grad():
            # batch == 1
            agnostic = self._to_device(agnostic.unsqueeze(0))
            c = self._to_device(c.unsqueeze(0))

            # Stage 1: warp the cloth with the GMM's sampling grid.
            # NOTE(review): grid_sample's align_corners default changed in PyTorch
            # 1.3; confirm which setting the GMM checkpoint expects.
            grid, _theta = self.gmm(agnostic, c)
            c_warp = F.grid_sample(c, grid, padding_mode='border')

            # Quantize the warp to uint8 (values truncated, not rounded) and
            # re-normalize it -- presumably to match warped-cloth images saved to
            # disk during training; confirm.
            pixels = ((c_warp.detach() + 1) * 0.5 * 255).cpu().clamp(0, 255)
            array = pixels.numpy().astype('uint8')
            c_warp = transformer(np.transpose(array[0], axes=(1, 2, 0))).unsqueeze(0)

            # Stage 2: render the person and a composition mask, then blend.
            c_warp_dev = self._to_device(c_warp)
            outputs = self.tom(torch.cat([agnostic, c_warp_dev], 1))
            p_rendered, m_composite = torch.split(outputs, 3, 1)
            p_rendered = torch.tanh(p_rendered)
            m_composite = torch.sigmoid(m_composite)
            p_tryon = c_warp_dev * m_composite + p_rendered * (1 - m_composite)

        # c_warp is returned on the CPU, as before.
        return (p_tryon, c_warp)

    def get_opt(self):
        """Return the options used to build the GMM, overridable from the command line.

        ``parse_known_args`` ignores arguments meant for the host program (a web
        server, a Jupyter kernel's ``-f``, ...) instead of exiting with
        "unrecognized arguments".

        Returns:
            argparse.Namespace with ``fine_width`` (default 192), ``fine_height``
            (default 256) and ``grid_size`` (default 5).
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("--fine_width", type=int, default=192)
        parser.add_argument("--fine_height", type=int, default=256)
        parser.add_argument("--grid_size", type=int, default=5)
        opt, _unknown = parser.parse_known_args()
        return opt