-
Notifications
You must be signed in to change notification settings - Fork 31
/
util.py
117 lines (93 loc) · 3.98 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import itertools, imageio, torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets
from scipy.misc import imresize
def show_result(G, x_, y_, num_epoch, show = False, save = False, path = 'result.png'):
    """Plot input / generated / target images side by side, one row per sample.

    Args:
        G: callable applied as ``G(x_)`` to produce the generated batch.
        x_: input batch; indexed per sample and transposed with (1, 2, 0),
            so each sample is treated as channel-first.
        y_: target batch, laid out like ``x_``.
        num_epoch: value written into the ``'Epoch N'`` caption.
        show: if true, display the figure with ``plt.show()``; otherwise the
            figure is closed after the optional save.
        save: if true, write the figure to ``path``.
        path: output file name used when ``save`` is true.
    """
    def _to_image(sample):
        # Channel-first tensor -> channel-last array for imshow. detach() and
        # cpu() make this work for CUDA or grad-tracking tensors too.
        # (v + 1) / 2 presumably maps [-1, 1] data onto [0, 1] -- confirm the
        # normalisation used by the caller.
        return (sample.detach().cpu().numpy().transpose(1, 2, 0) + 1) / 2

    # G.eval() is deliberately not called here (it was commented out in the
    # original); G runs in whatever mode the caller left it in.
    test_images = G(x_)

    n_rows = x_.size()[0]
    size_figure_grid = 3
    # squeeze=False keeps `ax` two-dimensional even for a single-sample batch;
    # without it ax[i, j] raises when n_rows == 1.
    fig, ax = plt.subplots(n_rows, size_figure_grid, figsize=(5, 5), squeeze=False)
    for i, j in itertools.product(range(n_rows), range(size_figure_grid)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    for i in range(n_rows):
        # Column order: input, generator output, target.
        for j, batch in enumerate((x_, test_images, y_)):
            ax[i, j].cla()
            ax[i, j].imshow(_to_image(batch[i]))

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')

    if save:
        plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()
def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
    """Plot the discriminator and generator loss curves from ``hist``.

    Args:
        hist: mapping with ``'D_losses'`` and ``'G_losses'`` sequences; the
            x axis is ``range(len(hist['D_losses']))``.
        show: if true, display the plot; otherwise close it after the
            optional save.
        save: if true, write the plot to ``path``.
        path: output file name used when ``save`` is true.
    """
    d_losses = hist['D_losses']
    g_losses = hist['G_losses']
    iterations = range(len(d_losses))

    # Both curves share the same x axis and go onto the current pyplot figure.
    for series, legend_name in ((d_losses, 'D_loss'), (g_losses, 'G_loss')):
        plt.plot(iterations, series, label=legend_name)

    plt.xlabel('Iter')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)

    if not show:
        plt.close()
        return
    plt.show()
def generate_animation(root, model, opt):
    """Assemble the per-epoch result images into an animated GIF.

    Reads ``root + 'Fixed_results/' + model + '<epoch>.png'`` for epochs
    1..``opt.train_epoch`` and writes
    ``root + model + 'generate_animation.gif'`` at 5 frames per second.

    Args:
        root: path prefix, concatenated as-is (no separator is added).
        model: file-name prefix of the per-epoch images and of the GIF.
        opt: object exposing ``train_epoch``, the number of frames to read.
    """
    frame_dir = root + 'Fixed_results/'
    # Epoch files are numbered from 1, hence the +1 on the zero-based index.
    frames = [
        imageio.imread(frame_dir + model + str(epoch + 1) + '.png')
        for epoch in range(opt.train_epoch)
    ]
    gif_path = root + model + 'generate_animation.gif'
    imageio.mimsave(gif_path, frames, fps=5)
def data_load(path, subfolder, transform, batch_size, shuffle=True):
    """Build a DataLoader over one class sub-folder of an ImageFolder tree.

    Args:
        path: root directory handed to ``datasets.ImageFolder``.
        subfolder: class name to keep; samples of every other class are
            dropped from the dataset.
        transform: transform handed to ``datasets.ImageFolder``.
        batch_size: forwarded to ``torch.utils.data.DataLoader``.
        shuffle: forwarded to ``torch.utils.data.DataLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the filtered dataset.

    Raises:
        KeyError: if ``subfolder`` is not a key of ``dset.class_to_idx``.
    """
    dset = datasets.ImageFolder(path, transform)
    ind = dset.class_to_idx[subfolder]

    # Filter in one pass instead of deleting entries while indexing (each
    # `del` shifts the tail of the list, making the old loop quadratic).
    # Slice assignment mutates the existing list object, exactly like the
    # original `del`, so any other reference to the same list sees the result.
    dset.imgs[:] = [sample for sample in dset.imgs if sample[1] == ind]

    # NOTE(review): some torchvision versions also keep a parallel `targets`
    # list; keep it consistent with the filtered samples when it exists.
    if hasattr(dset, 'targets'):
        dset.targets = [sample[1] for sample in dset.imgs]

    return torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=shuffle)
def imgs_resize(imgs, resize_scale = 286):
    """Resize every image of a batch to ``resize_scale`` x ``resize_scale``.

    Args:
        imgs: batch tensor; ``size()[0]`` is the batch size and ``size()[1]``
            the channel count. Each sample is converted with ``.numpy()``.
        resize_scale: side length of the square output images.

    Returns:
        A new ``torch.FloatTensor`` of shape
        ``(batch, channels, resize_scale, resize_scale)``.
    """
    # NOTE(review): `imresize` is imported from `scipy.misc` at file level;
    # recent SciPy releases no longer ship it -- confirm the pinned version.
    batch_size = imgs.size()[0]
    channels = imgs.size()[1]
    outputs = torch.FloatTensor(batch_size, channels, resize_scale, resize_scale)

    for idx in range(batch_size):
        resized = imresize(imgs[idx].numpy(), [resize_scale, resize_scale])
        # The resized array is channel-last; move channels to the front.
        chw = resized.transpose(2, 0, 1).astype(np.float32)
        chw = chw.reshape(-1, channels, resize_scale, resize_scale)
        # (v - 127.5) / 127.5 presumably maps 8-bit pixel values onto
        # [-1, 1] -- confirm imresize's output range.
        outputs[idx] = torch.FloatTensor((chw - 127.5) / 127.5)

    return outputs
def random_crop(imgs1, imgs2, crop_size = 256):
    """Crop the same random square window out of each image pair.

    For every batch index one (row, column) offset is drawn with
    ``np.random.randint`` and applied to both ``imgs1[i]`` and ``imgs2[i]``,
    so paired images stay spatially aligned.

    Args:
        imgs1: batch tensor indexed as (batch, channel, row, column).
        imgs2: second batch, cropped with the same offsets as ``imgs1``.
        crop_size: side length of the square crop.

    Returns:
        ``(outputs1, outputs2)``: new ``torch.FloatTensor`` batches of shape
        ``(batch, channels, crop_size, crop_size)``.

    Raises:
        ValueError: if ``crop_size`` exceeds the height or width of either
            batch.
    """
    outputs1 = torch.FloatTensor(imgs1.size()[0], imgs1.size()[1], crop_size, crop_size)
    outputs2 = torch.FloatTensor(imgs2.size()[0], imgs2.size()[1], crop_size, crop_size)

    # The window must fit inside both batches, so bound the offsets by the
    # smaller extent along each axis. Rows are bounded by the height (dim 2)
    # and columns by the width (dim 3); the original bounded both by dim 2.
    height = min(imgs1.size()[2], imgs2.size()[2])
    width = min(imgs1.size()[3], imgs2.size()[3])
    if crop_size > height or crop_size > width:
        raise ValueError(
            'crop_size {0} is larger than the image size {1}x{2}'.format(
                crop_size, height, width))

    # randint's upper bound is exclusive: +1 makes the last valid offset
    # reachable and allows images that are already exactly crop_size.
    max_top = height - crop_size + 1
    max_left = width - crop_size + 1

    for i in range(imgs1.size()[0]):
        top = np.random.randint(0, max_top)
        left = np.random.randint(0, max_left)
        outputs1[i] = imgs1[i][:, top: top + crop_size, left: left + crop_size]
        outputs2[i] = imgs2[i][:, top: top + crop_size, left: left + crop_size]

    return outputs1, outputs2
def random_fliplr(imgs1, imgs2):
    """Randomly mirror image pairs left-to-right.

    For every batch index one value is drawn with ``torch.rand(1)``; when it
    is below 0.5 both ``imgs1[i]`` and ``imgs2[i]`` are flipped along their
    last (column) axis, otherwise both are copied unchanged. The inputs are
    not modified.

    Args:
        imgs1: batch tensor indexed as (batch, channel, row, column).
        imgs2: second batch, flipped together with ``imgs1``.

    Returns:
        ``(outputs1, outputs2)``: new ``torch.FloatTensor`` batches with the
        same sizes as the inputs.
    """
    outputs1 = torch.FloatTensor(imgs1.size())
    outputs2 = torch.FloatTensor(imgs2.size())

    for i in range(imgs1.size()[0]):
        # One draw per sample so both images of a pair share the decision.
        if torch.rand(1)[0] < 0.5:
            # Flip directly on the tensor. The original went through numpy
            # and rescaled (x + 1) / 2 then (x - 0.5) / 0.5, which only added
            # float rounding and failed for non-CPU / grad-tracking tensors.
            # Each sample is (channel, row, column), so dim 2 is the column
            # axis.
            outputs1[i] = torch.flip(imgs1[i], dims=[2])
            outputs2[i] = torch.flip(imgs2[i], dims=[2])
        else:
            outputs1[i] = imgs1[i]
            outputs2[i] = imgs2[i]

    return outputs1, outputs2