forked from CCIIPLab/GCE-GNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
110 lines (95 loc) · 4.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import time
import argparse
import pickle
from model import *
from utils import *
def init_seed(seed=None):
    """Seed the NumPy and PyTorch random number generators.

    Seeds ``np.random``, the PyTorch CPU generator, the current CUDA device
    and then every CUDA device, all with the same value.

    Args:
        seed: Integer seed. When ``None``, a seed derived from the current
            wall-clock time (roughly whole Unix seconds) is used instead, so
            repeated runs are not reproducible.
    """
    chosen = int(time.time() * 1000 // 1000) if seed is None else seed
    np.random.seed(chosen)
    torch.manual_seed(chosen)
    # Current CUDA device first, then all devices (same order as before).
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(chosen)
# Command-line options, parsed once at import time into the module-level
# ``opt`` namespace that main() reads. Keep the add_argument order stable:
# it determines the order of entries printed by --help.
parser = argparse.ArgumentParser()

# Dataset and model size.
parser.add_argument('--dataset', help='diginetica/Nowplaying/Tmall', default='diginetica')
parser.add_argument('--hiddenSize', default=100, type=int)
parser.add_argument('--epoch', default=20, type=int)  # maximum epochs run by main()
parser.add_argument('--activate', default='relu', type=str)
# n_sample_all selects datasets/<dataset>/adj_<n>.pkl and num_<n>.pkl and is
# passed to handle_adj(). n_sample is not read in this file (presumably the
# model consumes it through opt -- confirm).
parser.add_argument('--n_sample_all', default=12, type=int)
parser.add_argument('--n_sample', default=12, type=int)
parser.add_argument('--batch_size', default=100, type=int)

# Optimisation.
parser.add_argument('--lr', default=0.001, type=float, help='learning rate.')
parser.add_argument('--lr_dc', default=0.1, type=float, help='learning rate decay.')
parser.add_argument('--lr_dc_step', default=3, type=int,
                    help='the number of steps after which the learning rate decay.')
parser.add_argument('--l2', default=1e-5, type=float, help='l2 penalty ')

# Graph depth and dropout. main() overwrites n_iter, dropout_gcn and
# dropout_local for the diginetica, Nowplaying and Tmall datasets, so these
# three flags only take effect for other dataset names. The bracketed lists
# are presumably the values tried during tuning.
parser.add_argument('--n_iter', default=1, type=int)  # [1, 2]
parser.add_argument('--dropout_gcn', default=0, type=float, help='Dropout rate.')  # [0, 0.2, 0.4, 0.6, 0.8]
parser.add_argument('--dropout_local', default=0, type=float, help='Dropout rate.')  # [0, 0.5]
parser.add_argument('--dropout_global', default=0.5, type=float, help='Dropout rate.')

# With --validation, main() evaluates on a valid_portion hold-out of
# train.txt instead of loading test.txt.
parser.add_argument('--validation', action='store_true', help='validation')
parser.add_argument('--valid_portion', default=0.1, type=float, help='split the portion')
# Negative slope of the LeakyReLU.
parser.add_argument('--alpha', default=0.2, type=float, help='Alpha for the leaky_relu.')
# Early stopping: main() stops once this many epochs fail to improve.
parser.add_argument('--patience', default=3, type=int)

opt = parser.parse_args()
# Per-dataset settings applied by main(): (num_node, n_iter, dropout_gcn,
# dropout_local). They silently override the corresponding CLI flags.
_DATASET_SETTINGS = {
    'diginetica': (43098, 2, 0.2, 0.0),
    'Nowplaying': (60417, 1, 0.0, 0.0),
    'Tmall': (40728, 1, 0.6, 0.5),
}
# num_node used for any dataset name not listed above (CLI values are kept).
_DEFAULT_NUM_NODE = 310


def _configure_dataset(options):
    """Apply the tuned settings for options.dataset in place; return num_node."""
    settings = _DATASET_SETTINGS.get(options.dataset)
    if settings is None:
        return _DEFAULT_NUM_NODE
    num_node, options.n_iter, options.dropout_gcn, options.dropout_local = settings
    return num_node


def _load_pickle(dataset, filename):
    """Unpickle datasets/<dataset>/<filename>, closing the file afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code -- only point this
    # at dataset files you trust.
    with open('datasets/' + dataset + '/' + filename, 'rb') as handle:
        return pickle.load(handle)


def _load_sessions(options):
    """Return the raw (train_data, test_data) objects for options.dataset.

    With options.validation set, the evaluation data is the hold-out part
    produced by split_validation(train, options.valid_portion) rather than
    test.txt.
    """
    train_data = _load_pickle(options.dataset, 'train.txt')
    if options.validation:
        train_data, valid_data = split_validation(train_data, options.valid_portion)
        return train_data, valid_data
    return train_data, _load_pickle(options.dataset, 'test.txt')


def main():
    """Train and evaluate CombineGraph on ``opt.dataset`` with early stopping.

    Reads the module-level ``opt`` parsed from the command line. For the
    diginetica / Nowplaying / Tmall datasets, ``opt.n_iter``,
    ``opt.dropout_gcn`` and ``opt.dropout_local`` are overwritten with the
    values in ``_DATASET_SETTINGS`` before the model is built.

    Each epoch calls ``train_test`` and prints the current and best values of
    the two metrics it returns (labelled Recall@20 and MRR@20). Training
    stops after ``opt.epoch`` epochs, or earlier once ``opt.patience`` epochs
    have failed to improve either metric. The total run time is printed at
    the end.
    """
    init_seed(2020)
    num_node = _configure_dataset(opt)
    train_data, test_data = _load_sessions(opt)
    adj = _load_pickle(opt.dataset, 'adj_' + str(opt.n_sample_all) + '.pkl')
    num = _load_pickle(opt.dataset, 'num_' + str(opt.n_sample_all) + '.pkl')
    train_data = Data(train_data)
    test_data = Data(test_data)
    adj, num = handle_adj(adj, num_node, opt.n_sample_all, num)
    model = trans_to_cuda(CombineGraph(opt, num_node, adj, num))
    print(opt)

    start = time.time()
    best_result = [0, 0]  # best [Recall@20, MRR@20] seen so far
    best_epoch = [0, 0]   # epoch at which each best value was reached
    bad_counter = 0
    for epoch in range(opt.epoch):
        print('-------------------------------------------------------')
        print('epoch: ', epoch)
        hit, mrr = train_test(model, train_data, test_data)
        improved = False
        # ">=" means a tie with the current best also counts as improvement.
        if hit >= best_result[0]:
            best_result[0] = hit
            best_epoch[0] = epoch
            improved = True
        if mrr >= best_result[1]:
            best_result[1] = mrr
            best_epoch[1] = epoch
            improved = True
        print('Current Result:')
        print('\tRecall@20:\t%.4f\tMRR@20:\t%.4f' % (hit, mrr))
        print('Best Result:')
        print('\tRecall@20:\t%.4f\tMRR@20:\t%.4f\tEpoch:\t%d,\t%d' % (
            best_result[0], best_result[1], best_epoch[0], best_epoch[1]))
        # The counter is never reset, so patience bounds the TOTAL number of
        # non-improving epochs, not consecutive ones (original behavior).
        if not improved:
            bad_counter += 1
        if bad_counter >= opt.patience:
            break
    print('-------------------------------------------------------')
    end = time.time()
    print("Run time: %f s" % (end - start))
# Script entry point: run training only when this file is executed directly.
# Note that the module-level parser.parse_args() call runs on import as well.
if __name__ == '__main__':
    main()