-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmain.py
74 lines (61 loc) · 2.19 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from itertools import product
import sys
import argparse
from utils import logger
from datasets import get_dataset
from train_eval import cross_validation_with_val_set
from param_parser import parameter_parser
from utils import tab_printer
import torch
import random
import numpy as np
# Command-line options are parsed once, at import time; run_exp_lib and main
# read this module-level ``args``.
args = parameter_parser()

# Seed Python's, NumPy's and PyTorch's RNGs with the same value so that runs
# are repeatable. The CUDA generator is seeded only when CUDA is not disabled.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(args.seed)
del _seed_rng  # keep the loop variable out of the module namespace
if not args.no_cuda:
    torch.cuda.manual_seed(args.seed)
def create_n_filter_triples(dataset, feat_str, gfn_add_ak3=True,
                            gfn_reall=False, reddit_odeg10=True,
                            dd_odeg10_ak1=True):
    """Return the experiment spec for ``dataset`` with dataset-adjusted features.

    Despite the name, the result is a one-element list holding a single
    ``(dataset, feat_str)`` pair, where ``feat_str`` has been adjusted as
    follows, in this order:

    * ``gfn_add_ak3``: append ``'+ak3'`` to the feature string.
    * ``reddit_odeg10``: for the REDDIT-BINARY / REDDIT-MULTI-5K /
      REDDIT-MULTI-12K datasets, replace ``'odeg100'`` with ``'odeg10'``.
    * ``dd_odeg10_ak1``: for the DD dataset, replace ``'odeg100'`` with
      ``'odeg10'`` and ``'ak3'`` with ``'ak1'``.

    Args:
        dataset: dataset name.
        feat_str: base feature string, e.g. ``'deg+odeg100'``.
        gfn_add_ak3: whether to append ``'+ak3'``.
        gfn_reall: accepted for signature compatibility; this function does
            not use it.
        reddit_odeg10: whether to apply the REDDIT ``odeg`` substitution.
        dd_odeg10_ak1: whether to apply the DD ``odeg``/``ak`` substitution.

    Returns:
        ``[(dataset, adjusted_feat_str)]``.
    """
    features = feat_str + '+ak3' if gfn_add_ak3 else feat_str

    reddit_names = ('REDDIT-BINARY', 'REDDIT-MULTI-5K', 'REDDIT-MULTI-12K')
    if reddit_odeg10 and dataset in reddit_names:
        features = features.replace('odeg100', 'odeg10')

    if dd_odeg10_ak1 and dataset in ('DD',):
        features = features.replace('odeg100', 'odeg10').replace('ak3', 'ak1')

    return [(dataset, features)]
def run_exp_lib(dataset_feat_net_triples):
    """Run 10-fold cross-validation for each ``(dataset_name, feat_str)`` pair.

    Each pair is loaded with ``get_dataset`` and evaluated with
    ``cross_validation_with_val_set``. Hyper-parameters are taken from the
    module-level ``args``.

    Args:
        dataset_feat_net_triples: iterable of ``(dataset_name, feat_str)``
            pairs, e.g. the list returned by ``create_n_filter_triples``.

    Returns:
        A list with one dict per input pair, in input order. Each dict has the
        keys ``dataset`` and ``feat_str``, plus ``train_acc``, ``acc``, ``std``
        and ``duration`` exactly as returned by
        ``cross_validation_with_val_set``.
    """
    results = []
    for dataset_name, feat_str in dataset_feat_net_triples:
        # Push pending output to the terminal before the (possibly slow)
        # dataset load and training run.
        sys.stdout.flush()
        dataset = get_dataset(
            dataset_name, sparse=True, feat_str=feat_str, root=args.data_root)
        # NOTE(review): assumes dataset.data.num_nodes is an iterable of
        # per-graph node counts, so max() gives the largest graph size — confirm.
        max_node_num = max(dataset.data.num_nodes)
        print('Data: {}, Max Node Num: {}'.format(dataset_name, max_node_num))
        train_acc, acc, std, duration = cross_validation_with_val_set(
            args,
            dataset,
            max_node_num=max_node_num,
            folds=10,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            lr_decay_factor=args.lr_decay_factor,
            lr_decay_step_size=args.lr_decay_step_size,
            weight_decay=args.weight_decay,
            epoch_select=args.epoch_select,
            with_eval_mode=args.with_eval_mode,
            logger=logger)
        # Keep the metrics instead of discarding them, so callers can report
        # or aggregate them.
        results.append({
            'dataset': dataset_name,
            'feat_str': feat_str,
            'train_acc': train_acc,
            'acc': acc,
            'std': std,
            'duration': duration,
        })
    return results
def main():
    """Run the experiment for the dataset named on the command line.

    Starts from the ``'deg+odeg100'`` feature string. ``create_n_filter_triples``
    then adjusts it for ``args.dataset``, and the resulting spec is passed to
    ``run_exp_lib``.
    """
    experiments = create_n_filter_triples(args.dataset, 'deg+odeg100')
    run_exp_lib(experiments)


if __name__ == '__main__':
    main()