-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
executable file
·118 lines (97 loc) · 4.07 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# ARTGAN neural network implementation
# Training and evaluating
import torch
import random
import torchvision
import os
import argparse
import utils as ut
from pathlib import Path
from torchvision import transforms, utils
from nn.ArtGAN import ArtGAN
from nn.Generator import Generator, Dec, zNet
from nn.Discriminator import Discriminator, Enc, clsNet
from WikiartDataset import WikiartDataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def main():
    """Train an ArtGAN model on a WikiArt subset (style / artist / genre).

    Command line:
        class_dataset  (str)  which label set to train on, e.g. "style",
                              "artist" or "genre" — selects the CSV/class files
                              under ../wikiart/.
        version        (int)  run number; checkpoints/outputs go to
                              ../save/<class_dataset>_v<version>.
        --retrain      (str)  optional path to a checkpoint file to resume from.

    Side effects: creates the save folder, reads the dataset CSVs and class
    list from ../wikiart/, and runs training (prints progress to stdout).
    """
    # Fix the RNG seeds so runs are reproducible.
    seed = 1000
    random.seed(seed)
    torch.manual_seed(seed)

    # Parse the dataset choice, the save-folder version, and the optional
    # checkpoint path used when resuming a previous run.
    parser = argparse.ArgumentParser()
    parser.add_argument("class_dataset", type=str)
    parser.add_argument("version", type=int)
    parser.add_argument("--retrain", type=str, default=None)
    args = parser.parse_args()

    class_dataset = args.class_dataset  # style - artist - genre
    version = args.version  # number of the version
    # Relative path: assumes the script is launched from inside the project
    # folder (sibling of ../save and ../wikiart).
    num_folder = f"../save/{class_dataset}_v{version}"
    # exist_ok avoids the check-then-create race of the exists()/makedirs pair.
    os.makedirs(num_folder, exist_ok=True)

    # Resize the short side to 64 then random-crop to a square 64x64 tensor.
    transform = transforms.Compose(
        [
            transforms.Resize(64),
            transforms.RandomCrop(64),
            transforms.ToTensor(),
        ])

    # Build the train/validation datasets for the chosen label set.
    print("Creating dataset with wikiart")
    trainset_wikiart = WikiartDataset(0, class_dataset + "_train.csv", "../wikiart/", 'Train', transform)
    testset_wikiart = WikiartDataset(0, class_dataset + "_val.csv", "../wikiart/", 'Test', transform)

    # Class names come one-per-line from the *_class.txt file; an extra
    # "fake" class is appended for the discriminator's fake-image output.
    with open('../wikiart/' + class_dataset + '_class.txt', 'r') as f:
        cl = [line.strip() for line in f]
    cl.append("fake")
    classes = tuple(cl)
    n_classes = len(classes) - 1  # real classes only; "fake" is excluded

    batch_size = 128
    print("Calling loader")
    trainloader_wikiart = torch.utils.data.DataLoader(trainset_wikiart, batch_size=batch_size, shuffle=True)
    # NOTE(review): the test loader is built but never passed to net.train
    # below (the second argument is None) — kept for parity with the
    # original behavior; confirm whether evaluation was intended.
    testloader_wikiart = torch.utils.data.DataLoader(testset_wikiart, batch_size=batch_size, shuffle=True)

    use_cuda = True
    # Evaluate the CUDA decision once instead of re-querying it four times.
    cuda = use_cuda and torch.cuda.is_available()

    epo = None
    g_op = d_op = None
    if args.retrain:
        # Resuming: rebuild generator/discriminator and their optimizers,
        # then restore all four state dicts from the checkpoint.
        # NOTE(review): torch.load has no map_location here, so a checkpoint
        # saved on GPU cannot be resumed on a CPU-only machine — confirm
        # whether map_location='cpu' should be added.
        checkpoint = torch.load(args.retrain)
        gen = Generator(zNet(input_size=100 + n_classes), Dec())
        dis = Discriminator(clsNet(num_classes=n_classes), Enc())
        # Move to GPU before creating the optimizers so optimizer state
        # lives on the same device as the parameters.
        if cuda:
            gen.cuda()
            dis.cuda()
        g_op = torch.optim.RMSprop(gen.parameters(), lr=0.001, alpha=0.9)
        d_op = torch.optim.RMSprop(dis.parameters(), lr=0.001, alpha=0.9)
        epo = checkpoint['epoch']
        gen.load_state_dict(checkpoint["G"])
        dis.load_state_dict(checkpoint["D"])
        d_op.load_state_dict(checkpoint["opt_D"])
        g_op.load_state_dict(checkpoint["opt_G"])
        net = ArtGAN(img_size=64, input_dim_enc=3,
                     z_dim=100, num_classes=n_classes,
                     out_dim_zNet=1024, G=gen, D=dis, retrain=True)
    else:
        # Fresh run: ArtGAN constructs its own generator/discriminator.
        net = ArtGAN(img_size=64, input_dim_enc=3,
                     z_dim=100, num_classes=n_classes,
                     out_dim_zNet=1024)

    if cuda:
        print("using cuda")
        net.cuda()

    print("Beginning training . . .")
    if args.retrain:
        # Resume with the restored optimizers from the epoch after the one
        # stored in the checkpoint.
        d_loss_l, g_loss_l = net.train(trainloader_wikiart, None, classes, epochs=100,
                                       cuda=cuda, path=num_folder, g_op=g_op,
                                       d_op=d_op, init_epoch=epo + 1)
    else:
        d_loss_l, g_loss_l = net.train(trainloader_wikiart, None, classes, epochs=100,
                                       cuda=cuda, path=num_folder)
    print("Ended!")
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()