# train.py re-exports everything this script uses (os, shutil, random,
# torch, the model and data-sampler classes, and the train* routines).
from train import *

if __name__ == '__main__':
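    # Pipeline: load/generate graphs, split them, build the BFS-ordered
    # sequence sampler, pick a model variant from args.note, then train.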
    # All necessary arguments are defined in args.py
    args = Args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
    print('CUDA', args.cuda)
    print('File name prefix', args.fname)

    # create the output directories if they do not exist yet
    for path in [args.model_save_path, args.graph_save_path, args.figure_save_path,
                 args.timing_save_path, args.figure_prediction_save_path,
                 args.nll_save_path]:
        if not os.path.isdir(path):
            os.makedirs(path)
    time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    # logging.basicConfig(filename='logs/train' + time + '.log', level=logging.DEBUG)
    if args.clean_tensorboard:
        if os.path.isdir("tensorboard"):
            shutil.rmtree("tensorboard")
    # tensorboard_logger: each run is logged under its own timestamped directory
    configure("tensorboard/run" + time, flush_secs=5)
    graphs = create_graphs.create(args)

    # split datasets: last 20% for test; note that the validation set is the
    # first 20% of the shuffled list, so it overlaps the training split
    random.seed(123)
    shuffle(graphs)
    graphs_len = len(graphs)
    graphs_test = graphs[int(0.8 * graphs_len):]
    graphs_train = graphs[0:int(0.8 * graphs_len)]
    graphs_validate = graphs[0:int(0.2 * graphs_len)]
    # to load pre-saved graphs instead, uncomment the block below
    # dir_input = "/dfs/scratch0/jiaxuany0/graphs/"
    # fname_test = dir_input + args.note + '_' + args.graph_type + '_' + str(args.num_layers) + '_' + str(
    #     args.hidden_size_rnn) + '_test_' + str(0) + '.dat'
    # graphs = load_graph_list(fname_test, is_real=True)
    # graphs_test = graphs[int(0.8 * graphs_len):]
    # graphs_train = graphs[0:int(0.8 * graphs_len)]
    # graphs_validate = graphs[int(0.2 * graphs_len):int(0.4 * graphs_len)]
    # despite the names, these hold the *average node count* of the
    # validation and test graphs, used later for NLL evaluation
    graph_validate_len = 0
    for graph in graphs_validate:
        graph_validate_len += graph.number_of_nodes()
    graph_validate_len /= len(graphs_validate)
    print('graph_validate_len', graph_validate_len)

    graph_test_len = 0
    for graph in graphs_test:
        graph_test_len += graph.number_of_nodes()
    graph_test_len /= len(graphs_test)
    print('graph_test_len', graph_test_len)
    args.max_num_node = max(graph.number_of_nodes() for graph in graphs)
    max_num_edge = max(graph.number_of_edges() for graph in graphs)
    min_num_edge = min(graph.number_of_edges() for graph in graphs)
    # args.max_num_node = 2000

    # show graph statistics
    print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train)))
    print('max number node: {}'.format(args.max_num_node))
    print('max/min number edge: {}; {}'.format(max_num_edge, min_num_edge))
    print('max previous node: {}'.format(args.max_prev_node))
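    # max_prev_node is the 'M' parameter from the GraphRNN paper: under the
    # BFS node ordering, each new node's adjacency is encoded only against
    # its M most recent predecessors, which bounds the edge-model width.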
    # save ground truth graphs
    ## both files contain the *full* graph list; after loading, slice out
    ## the train/test portions manually using the same 80/20 split
    save_graph_list(graphs, args.graph_save_path + args.fname_train + '0.dat')
    save_graph_list(graphs, args.graph_save_path + args.fname_test + '0.dat')
    print('train and test graphs saved at: ', args.graph_save_path + args.fname_test + '0.dat')
    ### keep commented for normal training; uncomment for graph completion only
    # p = 0.5
    # for graph in graphs_train:
    #     for node in list(graph.nodes()):
    #         # print('node', node)
    #         if np.random.rand() > p:
    #             graph.remove_node(node)
    #     for edge in list(graph.edges()):
    #         # print('edge', edge)
    #         if np.random.rand() > p:
    #             graph.remove_edge(edge[0], edge[1])
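    # The commented block above randomly deletes roughly half of each
    # training graph's nodes and edges (keep probability p), producing the
    # partial graphs used by train_graph_completion() further below.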
    ### dataset initialization
    if 'nobfs' in args.note:
        print('nobfs')
        dataset = Graph_sequence_sampler_pytorch_nobfs(graphs_train, max_num_node=args.max_num_node)
        args.max_prev_node = args.max_num_node - 1
    elif 'barabasi_noise' in args.graph_type:
        print('barabasi_noise')
        dataset = Graph_sequence_sampler_pytorch_canonical(graphs_train, max_prev_node=args.max_prev_node)
        args.max_prev_node = args.max_num_node - 1
    else:
        dataset = Graph_sequence_sampler_pytorch(graphs_train, max_prev_node=args.max_prev_node,
                                                 max_num_node=args.max_num_node)
    sample_strategy = torch.utils.data.sampler.WeightedRandomSampler(
        [1.0 / len(dataset)] * len(dataset),
        num_samples=args.batch_size * args.batch_ratio, replacement=True)
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                                 num_workers=args.num_workers,
                                                 sampler=sample_strategy)
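    # The sampler weights are uniform (1/len(dataset)), so it simply draws
    # batch_size * batch_ratio examples uniformly with replacement; one
    # "epoch" of the loader therefore yields batch_ratio batches.
    # Quick sanity check (a sketch; assumes each sample is a dict with
    # 'x', 'y' and 'len' keys, as returned by the samplers in data.py):
    # batch = next(iter(dataset_loader))
    # print(batch['x'].shape, batch['y'].shape, batch['len'])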
    ### model initialization
    ## Graph RNN VAE model
    # lstm = LSTM_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_lstm,
    #                   hidden_size=args.hidden_size, num_layers=args.num_layers).cuda()
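    # args.note selects the output module: a conditional-VAE head, a plain
    # MLP head, or a second edge-level GRU (the full two-level GraphRNN).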
    if 'GraphRNN_VAE_conditional' in args.note:
        rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                        hidden_size=args.hidden_size_rnn, num_layers=args.num_layers,
                        has_input=True, has_output=False).cuda()
        output = MLP_VAE_conditional_plain(h_size=args.hidden_size_rnn,
                                           embedding_size=args.embedding_size_output,
                                           y_size=args.max_prev_node).cuda()
    elif 'GraphRNN_MLP' in args.note:
        rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                        hidden_size=args.hidden_size_rnn, num_layers=args.num_layers,
                        has_input=True, has_output=False).cuda()
        output = MLP_plain(h_size=args.hidden_size_rnn, embedding_size=args.embedding_size_output,
                           y_size=args.max_prev_node).cuda()
    elif 'GraphRNN_RNN' in args.note:
        rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                        hidden_size=args.hidden_size_rnn, num_layers=args.num_layers,
                        has_input=True, has_output=True,
                        output_size=args.hidden_size_rnn_output).cuda()
        output = GRU_plain(input_size=1, embedding_size=args.embedding_size_rnn_output,
                           hidden_size=args.hidden_size_rnn_output, num_layers=args.num_layers,
                           has_input=True, has_output=True, output_size=1).cuda()
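    # train() is defined in train.py; in this repo it runs the optimization
    # loop, periodically sampling graphs for evaluation and saving model
    # checkpoints under args.model_save_path.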
    ### start training
    train(args, dataset_loader, rnn, output)

    ### graph completion
    # train_graph_completion(args, dataset_loader, rnn, output)

    ### nll evaluation
    # train_nll(args, dataset_loader, dataset_loader, rnn, output, max_iter=200,
    #           graph_validate_len=graph_validate_len, graph_test_len=graph_test_len)