-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
133 lines (112 loc) · 4.36 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import datetime
import matplotlib.pyplot as plt
import os
import random
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.init as init
from pprint import pprint
########################################################################################################################
# configuration #
########################################################################################################################
def mkdir_p(path):
    """Create ``path`` and any missing parents, tolerating an existing directory.

    A message is printed telling whether the directory was created or was
    already there. Any other ``OSError`` (including the case where ``path``
    exists but is not a directory) is propagated to the caller.
    """
    import errno

    try:
        os.makedirs(path)
    except OSError as err:
        already_a_dir = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise
        print('Directory {} already exists.'.format(path))
    else:
        print('Created directory {}'.format(path))
def date_filename(base_dir='./'):
    """Return ``base_dir`` joined with a timestamp taken from the current local time.

    The timestamp has the form ``<date>_<HH>-<MM>-<SS>`` where ``<date>`` is the
    string form of today's ``datetime.date`` and the time fields are
    zero-padded to two digits.
    """
    now = datetime.datetime.now()
    stamp = '{}_{:02d}-{:02d}-{:02d}'.format(
        now.date(), now.hour, now.minute, now.second
    )
    return os.path.join(base_dir, stamp)
def setup_log_dir(opts):
    """Create a timestamped directory under ``opts['log_dir']`` and return its path.

    The directory name comes from ``date_filename`` and the directory is
    created with ``mkdir_p``.
    """
    run_dir = str(date_filename(opts['log_dir']))
    mkdir_p(run_dir)
    return run_dir
def save_arg_dict(opts, filename='settings.txt'):
    """Write every ``key<TAB>value`` pair of ``opts`` to a text file.

    The file is created as ``opts['log_dir']/filename`` (overwriting any
    existing file), one option per line, and the destination is printed
    afterwards.

    Value formatting:
      * floats are written with four decimals,
      * ints are written in decimal form -- this includes booleans, which
        appear as ``1``/``0`` because ``bool`` is a subclass of ``int``,
      * everything else uses its default string formatting.
    """
    def _render(value):
        # The float check must come first: it is the only branch with a
        # fixed precision.
        if isinstance(value, float):
            return '{:.4f}'.format(value)
        if isinstance(value, int):
            return '{:d}'.format(value)
        return '{}'.format(value)

    target = os.path.join(opts['log_dir'], filename)
    with open(target, 'w') as handle:
        # Generator keeps the write incremental, one line per option.
        handle.writelines(
            '{}\t{}\n'.format(name, _render(val)) for name, val in opts.items()
        )
    print('Saved settings to {}'.format(target))
def setup(args):
    """Turn parsed arguments into a fully prepared options dict.

    Steps, in order:
      1. Copy ``args.__dict__`` into a new dict ``opts``.
      2. Disable cuDNN benchmark mode and enable its deterministic flag.
      3. Pick a seed in [1, 10000] if ``opts['seed']`` is None, then seed
         ``random`` and ``torch`` with it.
      4. Apply ``dataset_based_configure`` and, if the dataset path does not
         exist, generate the 'cycles' dataset (other datasets raise
         ``ValueError``).
      5. Create a timestamped log directory with a ``samples`` subdirectory,
         store it in ``opts['log_dir']``, switch matplotlib to the 'Agg'
         backend, save the options to disk and pretty-print them.

    Returns the resulting ``opts`` dict.
    """
    opts = dict(args.__dict__)

    cudnn.deterministic = True
    cudnn.benchmark = False

    # Seed
    seed = opts['seed']
    if seed is None:
        seed = random.randint(1, 10000)
        opts['seed'] = seed
    random.seed(seed)
    torch.manual_seed(seed)

    # Dataset
    from configure import dataset_based_configure
    opts = dataset_based_configure(opts)
    assert opts['path_to_dataset'] is not None, 'Expect path to dataset to be set.'
    dataset_path = opts['path_to_dataset']
    if not os.path.exists(dataset_path):
        # Only the 'cycles' dataset can be generated on demand.
        if opts['dataset'] != 'cycles':
            raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))
        from cycles import generate_dataset
        generate_dataset(opts['min_size'], opts['max_size'],
                         opts['ds_size'], dataset_path)

    # Optimization
    # NOTE(review): this assertion can never fail as written -- it checks the
    # same key that the enclosing ``if`` already found truthy. It presumably
    # was meant to validate a separate clipping-threshold option; confirm the
    # intended key before changing it.
    if opts['clip_grad']:
        assert opts['clip_grad'] is not None, 'Expect the gradient norm constraint to be set.'

    # Log
    print('Prepare logging directory...')
    opts['log_dir'] = setup_log_dir(opts)
    mkdir_p(opts['log_dir'] + '/samples')
    plt.switch_backend('Agg')
    save_arg_dict(opts)
    pprint(opts)

    return opts
########################################################################################################################
# model #
########################################################################################################################
def weights_init(m):
    """
    Initialize the parameters of a single module in place.

    Code adapted from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5

    * ``nn.Linear``: Xavier-normal weight and standard-normal bias. The bias
      step is skipped when the layer was built without a bias.
    * ``nn.GRUCell``: orthogonal initialization for parameters with two or
      more dimensions, standard-normal for the remaining ones.
    * Any other module type is left untouched.

    Usage:
        model = Model()
        model.apply(weights_init)
    """
    if isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        # A bias-free layer stores ``bias`` as None; dereferencing ``.data``
        # on it would raise AttributeError.
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.GRUCell):
        for param in m.parameters():
            if param.dim() >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
def dgmg_message_weight_init(m):
    """
    Initialize linear layers from a normal distribution with std 1./10.

    This is similar to ``weights_init``, but uses the std of 1./10 suggested
    by the author. It should only be used for the message passing functions,
    i.e. the fe's in the paper.

    ``m`` is either a module or an ``nn.ModuleList``; for a ``ModuleList`` the
    initializer is applied to each entry. Because ``Module.apply`` visits
    every submodule as well as the module itself, everything reached must be
    an ``nn.Linear``.

    Raises:
        ValueError: if a visited module is not an ``nn.Linear``.
    """
    def _weight_init(layer):
        if not isinstance(layer, nn.Linear):
            raise ValueError('Expected the input to be of type nn.Linear!')
        init.normal_(layer.weight.data, std=1./10)
        # A bias-free layer stores ``bias`` as None; skip it instead of
        # failing on ``None.data``.
        if layer.bias is not None:
            init.normal_(layer.bias.data, std=1./10)

    if isinstance(m, nn.ModuleList):
        for layer in m:
            layer.apply(_weight_init)
    else:
        m.apply(_weight_init)