train.py
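
# train.py — training script for the GRASS recursive autoencoder (GRASSEncoder/GRASSDecoder).
# It loads a GRASSDataset, dynamically batches the tree-structured examples with torchfold
# (FoldExt), and optimizes a VAE-style objective that combines a reconstruction loss with a
# weighted KL-divergence term. Model snapshots, a text log, and a live loss plot are optional.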
import time
import os
from time import gmtime, strftime
from datetime import datetime
import torch
from torch import nn
from torch.autograd import Variable
import torch.utils.data
from torchfoldext import FoldExt
import util
from dynamicplot import DynamicPlot
from grassdata import GRASSDataset
from grassmodel import GRASSEncoder
from grassmodel import GRASSDecoder
import grassmodel
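
# Command-line configuration (parsed by util.get_args); the fields read below include
# no_cuda, gpu, data_path, batch_size, epochs, show_log_every, no_plot,
# save_snapshot, save_snapshot_every, save_path, save_log, and save_log_every.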
config = util.get_args()
config.cuda = not config.no_cuda
if config.gpu < 0 and config.cuda:
    config.gpu = 0
torch.cuda.set_device(config.gpu)
if config.cuda and torch.cuda.is_available():
    print("Using CUDA on GPU ", config.gpu)
else:
    print("Not using CUDA.")
encoder = GRASSEncoder(config)
decoder = GRASSDecoder(config)
if config.cuda:
    encoder.cuda()
    decoder.cuda()
print("Loading data ...... ", end='', flush=True)
grass_data = GRASSDataset(config.data_path)
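
# Identity collate: keep each batch as a plain Python list, since the examples are
# hierarchical structures rather than fixed-size tensors that the default collate_fn could stack.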
def my_collate(batch):
    return batch
train_iter = torch.utils.data.DataLoader(grass_data, batch_size=config.batch_size, shuffle=True, collate_fn=my_collate)
print("DONE")
encoder_opt = torch.optim.Adam(encoder.parameters(), lr=1e-3)
decoder_opt = torch.optim.Adam(decoder.parameters(), lr=1e-3)
print("Start training ...... ")
start = time.time()
if config.save_snapshot:
    if not os.path.exists(config.save_path):
        os.makedirs(config.save_path)
    snapshot_folder = os.path.join(config.save_path, 'snapshots_'+strftime("%Y-%m-%d_%H-%M-%S", gmtime()))
    if not os.path.exists(snapshot_folder):
        os.makedirs(snapshot_folder)
if config.save_log:
    fd_log = open('training_log.log', mode='a')
    fd_log.write('\n\nTraining log at '+datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    fd_log.write('\n#epoch: {}'.format(config.epochs))
    fd_log.write('\nbatch_size: {}'.format(config.batch_size))
    fd_log.write('\ncuda: {}'.format(config.cuda))
    fd_log.flush()
header = '     Time    Epoch     Iteration    Progress(%)  ReconLoss  KLDivLoss  TotalLoss'
log_template = ' '.join('{:>9s},{:>5.0f}/{:<5.0f},{:>5.0f}/{:<5.0f},{:>9.1f}%,{:>11.2f},{:>10.2f},{:>10.2f}'.split(','))
total_iter = config.epochs * len(train_iter)
if not config.no_plot:
    plot_x = [x for x in range(total_iter)]
    plot_total_loss = [None for x in range(total_iter)]
    plot_recon_loss = [None for x in range(total_iter)]
    plot_kldiv_loss = [None for x in range(total_iter)]
    dyn_plot = DynamicPlot(title='Training loss over epochs (GRASS)', xdata=plot_x, ydata={'Total_loss':plot_total_loss, 'Reconstruction_loss':plot_recon_loss, 'KL_divergence_loss':plot_kldiv_loss})
iter_id = 0
max_loss = 0
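
# Training loop: every batch is encoded and decoded through torchfold computation graphs,
# so all examples in the batch are evaluated with dynamic batching, and both models are
# then updated from the combined reconstruction + KL-divergence loss.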
for epoch in range(config.epochs):
    print(header)
    for batch_idx, batch in enumerate(train_iter):
        # Initialize torchfold for *encoding*
        enc_fold = FoldExt(cuda=config.cuda)
        enc_fold_nodes = []     # list of fold nodes for encoding
        # Collect computation nodes recursively from the encoding process
        for example in batch:
            enc_fold_nodes.append(grassmodel.encode_structure_fold(enc_fold, example))
        # Apply the computations on the encoder model
        enc_fold_nodes = enc_fold.apply(encoder, [enc_fold_nodes])
        # Split into a list of fold nodes per example
        enc_fold_nodes = torch.split(enc_fold_nodes[0], 1, 0)
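        # Each per-example encoder output packs the root latent code together with a
        # KL-divergence term; the two halves are separated below with torch.chunk.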
        # Initialize torchfold for *decoding*
        dec_fold = FoldExt(cuda=config.cuda)
        # Collect computation nodes recursively from the decoding process
        dec_fold_nodes = []
        kld_fold_nodes = []
        for example, fnode in zip(batch, enc_fold_nodes):
            root_code, kl_div = torch.chunk(fnode, 2, 1)
            dec_fold_nodes.append(grassmodel.decode_structure_fold(dec_fold, root_code, example))
            kld_fold_nodes.append(kl_div)
        # Apply the computations on the decoder model
        total_loss = dec_fold.apply(decoder, [dec_fold_nodes, kld_fold_nodes])
        # total_loss[0] holds the reconstruction losses and total_loss[1] the KL-divergence terms
        recon_loss = total_loss[0].sum() / len(batch)              # avg. reconstruction loss per example
        kldiv_loss = total_loss[1].sum().mul(-0.05) / len(batch)   # avg. weighted KL-divergence loss per example
        total_loss = recon_loss + kldiv_loss
        # Do parameter optimization
        encoder_opt.zero_grad()
        decoder_opt.zero_grad()
        total_loss.backward()
        encoder_opt.step()
        decoder_opt.step()
        # Report statistics
        if batch_idx % config.show_log_every == 0:
            print(log_template.format(strftime("%H:%M:%S", time.gmtime(time.time()-start)),
                  epoch, config.epochs, 1+batch_idx, len(train_iter),
                  100. * (1+batch_idx+len(train_iter)*epoch) / (len(train_iter)*config.epochs),
                  recon_loss.data[0], kldiv_loss.data[0], total_loss.data[0]))
        # Plot losses
        if not config.no_plot:
            plot_total_loss[iter_id] = total_loss.data[0]
            plot_recon_loss[iter_id] = recon_loss.data[0]
            plot_kldiv_loss[iter_id] = kldiv_loss.data[0]
            max_loss = max(max_loss, total_loss.data[0], recon_loss.data[0], kldiv_loss.data[0])
            dyn_plot.setxlim(0., (iter_id+1)*1.05)
            dyn_plot.setylim(0., max_loss*1.05)
            dyn_plot.update_plots(ydata={'Total_loss':plot_total_loss, 'Reconstruction_loss':plot_recon_loss, 'KL_divergence_loss':plot_kldiv_loss})
        iter_id += 1
    # Save snapshots of the models being trained
    if config.save_snapshot and (epoch+1) % config.save_snapshot_every == 0:
        print("Saving snapshots of the models ...... ", end='', flush=True)
        torch.save(encoder, os.path.join(snapshot_folder, 'vae_encoder_model_epoch_{}_loss_{:.2f}.pkl'.format(epoch+1, total_loss.data[0])))
        torch.save(decoder, os.path.join(snapshot_folder, 'vae_decoder_model_epoch_{}_loss_{:.2f}.pkl'.format(epoch+1, total_loss.data[0])))
        print("DONE")
    # Save training log
    if config.save_log and (epoch+1) % config.save_log_every == 0:
        fd_log = open('training_log.log', mode='a')
        fd_log.write('\nepoch:{} recon_loss:{:.2f} kld_loss:{:.2f} total_loss:{:.2f}'.format(epoch+1, recon_loss.data[0], kldiv_loss.data[0], total_loss.data[0]))
        fd_log.close()
# Save the final models
print("Saving final models ...... ", end='', flush=True)
if not os.path.exists(config.save_path):   # ensure the output folder exists even if no snapshots were saved
    os.makedirs(config.save_path)
torch.save(encoder, os.path.join(config.save_path, 'vae_encoder_model.pkl'))
torch.save(decoder, os.path.join(config.save_path, 'vae_decoder_model.pkl'))
print("DONE")