-
Notifications
You must be signed in to change notification settings - Fork 0
/
explain_main.py
92 lines (72 loc) · 3.56 KB
/
explain_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# The major idea of the overall GNN model explanation
import argparse
import os
import dgl
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import load_graphs
from models import dummy_gnn_model
from NodeExplainerModule import NodeExplainerModule
from utils_graph import extract_subgraph, visualize_sub_graph
def main(args):
    """Train a GNN explainer for one node of ``args.target_class`` and plot it.

    Steps:
      1. Load the state dict of a previously trained ``dummy_gnn_model`` from
         ``./dummy_model_<dataset>.pth``.
      2. Load the graph stored in ``./<dataset>.bin``. Its node features come
         from ``ndata['feat']`` and its labels from ``ndata['label']``. The
         hidden size is read from the file's label dict under ``'hid_dim'``.
      3. Pick the first (lowest-index) node whose label equals
         ``args.target_class``. Extract its ``args.hop``-hop computation
         subgraph with ``extract_subgraph``.
      4. Optimise the explainer's edge mask and node-feature mask with Adam
         for ``args.epochs`` steps. The training target is the model's own
         one-hot prediction for that node.
      5. Visualise ``sigmoid(edge_mask)`` on the subgraph.

    Args:
        args: parsed command-line namespace providing ``dataset``,
            ``target_class``, ``hop``, ``epochs``, ``lr`` and ``wd``.

    Raises:
        FileNotFoundError: if ``./dummy_model_<dataset>.pth`` does not exist.
        ValueError: if no node carries the label ``args.target_class``.
    """
    # Load an existing model; training one is a separate step.
    model_path = os.path.join('./', 'dummy_model_{}.pth'.format(args.dataset))
    try:
        # th.load unpickles the file, so only point this at trusted checkpoints.
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
        # host; nothing in this script moves the model or data to a GPU.
        model_state_dict = th.load(model_path, map_location='cpu')
    except FileNotFoundError as err:
        raise FileNotFoundError(
            'No saved model file at {}. Please train a GNN model first...'.format(model_path)
        ) from err

    # Load the graph together with its node features and labels.
    graph_list, label_dict = load_graphs(os.path.join('./', args.dataset + '.bin'))
    graph = graph_list[0]
    labels = graph.ndata['label']
    feats = graph.ndata['feat']
    # Assumes class ids are 0..C-1 (TODO confirm for new datasets), so C = max + 1.
    num_classes = int(labels.max().item()) + 1
    feat_dim = feats.shape[1]
    hid_dim = label_dict['hid_dim'].item()

    # Rebuild the model with the saved sizes and restore its weights.
    dummy_model = dummy_gnn_model(feat_dim, hid_dim, num_classes)
    dummy_model.load_state_dict(model_state_dict)

    # Explain the first node (lowest index) whose label is the target class.
    target_nodes = th.nonzero(labels == args.target_class, as_tuple=True)[0]
    if target_nodes.numel() == 0:
        raise ValueError('No node with label {} in dataset {}.'.format(
            args.target_class, args.dataset))
    n_idx = target_nodes[:1]  # 1-element index tensor, as extract_subgraph receives it

    # Extract the k-hop computation graph around the target node; the
    # explanation is computed on this subgraph only.
    sub_graph, ori_n_idxes, new_n_idx = extract_subgraph(graph, n_idx, hops=args.hop)
    # Features of the subgraph nodes, selected by their original node ids.
    sub_feats = feats[ori_n_idxes, :]

    # The explainer owns the learnable edge mask and node-feature mask;
    # only those two tensors are optimised, never the model weights.
    explainer = NodeExplainerModule(model=dummy_model,
                                    num_edges=sub_graph.number_of_edges(),
                                    node_feat_dim=feat_dim)
    optim = th.optim.Adam([explainer.edge_mask, explainer.node_feat_mask],
                          lr=args.lr, weight_decay=args.wd)

    # The model's own (one-hot) prediction is the fixed target the explainer
    # must reproduce; no autograd graph is needed to compute it.
    dummy_model.eval()
    with th.no_grad():
        model_logits = dummy_model(sub_graph, sub_feats)
    model_predict = F.one_hot(th.argmax(model_logits, dim=-1), num_classes)

    for _ in range(args.epochs):
        # NOTE(review): if NodeExplainerModule registers `model` as a submodule,
        # train() also puts dummy_model back into training mode, so dropout /
        # batch-norm layers (if any) would differ from the prediction above.
        # Confirm against NodeExplainerModule.
        explainer.train()
        exp_logits = explainer(sub_graph, sub_feats)
        loss = explainer._loss(exp_logits[new_n_idx], model_predict[new_n_idx])
        optim.zero_grad()
        loss.backward()
        optim.step()

    # Visualise the learned importance of each subgraph edge, mapped to (0, 1).
    edge_weights = explainer.edge_mask.sigmoid().detach()
    visualize_sub_graph(sub_graph, edge_weights.cpu().numpy(), ori_n_idxes, n_idx)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Demo of GNN explainer in DGL')
parser.add_argument('--dataset', type=str, default='syn1',
help='The dataset to be explained.')
parser.add_argument('--target_class', type=int, default='1',
help='The class to be explained. In the synthetic 1 dataset, Valid option is from 0 to 4'
'Will choose the first node in this class to explain')
parser.add_argument('--hop', type=int, default='2',
help='The hop number of the computation sub-graph. For syn1 and syn2, k=2. For syn3, syn4, and syn5, k=4.')
parser.add_argument('--epochs', type=int, default=200, help='The number of epochs.')
parser.add_argument('--lr', type=float, default=0.01, help='The learning rate.')
parser.add_argument('--wd', type=float, default=0.0, help='Weight decay.')
args = parser.parse_args()
print(args)
main(args)