diff --git a/README.md b/README.md
index 5e0cbf6..a38ea13 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,62 @@
-# GNNLens2
+
+![GNNLens2 logo](resources/logo.png)
+
+GNNLens2 is an interactive visualization tool for graph neural networks (GNNs). It allows seamless integration with [Deep Graph Library (DGL)](https://github.com/dmlc/dgl) and can meet your various visualization requirements for presentation, analysis and model explanation. It is an open-source version of [GNNLens](https://arxiv.org/abs/2011.11048) with simplifications and extensions.
+
+## Installation
+
+### Requirements
+
+- [PyTorch](https://pytorch.org/)
+- [DGL](https://www.dgl.ai/pages/start.html)
+- Flask-CORS
+
+You can install Flask-CORS with
+
+```bash
+pip install -U flask-cors
+```
+
+### Installation for the latest stable version
+
+```bash
+pip install gnnlens
+```
+
+### Installation from source
+
+If you want to try experimental features, you can install from source as follows:
+
+```bash
+git clone https://github.com/dmlc/GNNLens2.git
+cd GNNLens2/python
+python setup.py install
+```
+
+### Verifying successful installation
+
+Once you have installed the package, you can verify that the installation succeeded with
+
+```python
+import gnnlens
+
+print(gnnlens.__version__)
+# 0.1.0
+```
+
+## Tutorials
+
+We provide a set of tutorials to get you started with the library:
+- [Tutorial 1: Graph structure](resources/tutorials/tutorial_1_graph.md)
+- [Tutorial 2: Ground truth and predicted node labels](resources/tutorials/tutorial_2_nlabel.md)
+- [Tutorial 3: Edge weights and attention](resources/tutorials/tutorial_3_eweight.md)
+- [Tutorial 4: Weighted subgraphs and explanation methods](resources/tutorials/tutorial_4_subgraph.md)
+
+## Team
+
+**HKUST VisLab**: [Zhihua Jin](https://github.com/jnzhihuoo1), [Huamin Qu](http://huamin.org/)
+
+**AWS Shanghai AI Lab**: [Mufei Li](https://github.com/mufeili), [Wanru Zhao](https://github.com/Ryan0v0) (work done during internship), [Jian Zhang](https://github.com/zhjwy9343), [Minjie Wang](https://jermainewang.github.io/)
+
+**SMU**: [Yong Wang](http://yong-wang.org/)
diff --git a/python/gnnlens/__init__.py b/python/gnnlens/__init__.py
index 6046be8..76d5c64 100644
--- a/python/gnnlens/__init__.py
+++ b/python/gnnlens/__init__.py
@@ -15,4 +15,5 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from .libinfo import __version__
 from .writer import *
diff --git a/python/gnnlens/libinfo.py b/python/gnnlens/libinfo.py
new file mode 100644
index 0000000..831c97c
--- /dev/null
+++ b/python/gnnlens/libinfo.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# current version
+__version__ = '0.1.0'
diff --git a/python/gnnlens/writer.py b/python/gnnlens/writer.py
index 154c29d..a425a55 100644
--- a/python/gnnlens/writer.py
+++ b/python/gnnlens/writer.py
@@ -32,7 +32,7 @@ class Writer():
         be either a relative path or an absolute path.
     """
     def __init__(self, logdir):
-        os.makedirs(logdir, exist_ok=False)
+        os.makedirs(logdir)
         self.logdir = logdir
         self.graph_names = []
         self.graph_data = dict()
diff --git a/resources/README.png b/resources/README.png
new file mode 100644
index 0000000..e531e9c
Binary files /dev/null and b/resources/README.png differ
diff --git a/resources/figures/tutorial_1/control_panel.png b/resources/figures/tutorial_1/control_panel.png
new file mode 100644
index 0000000..cc83bd8
Binary files /dev/null and b/resources/figures/tutorial_1/control_panel.png differ
diff --git a/resources/figures/tutorial_1/drag.png b/resources/figures/tutorial_1/drag.png
new file mode 100644
index 0000000..27381ee
Binary files /dev/null and b/resources/figures/tutorial_1/drag.png differ
diff --git a/resources/figures/tutorial_1/empty_interface.png b/resources/figures/tutorial_1/empty_interface.png
new file mode 100644
index 0000000..ac3b1bc
Binary files /dev/null and b/resources/figures/tutorial_1/empty_interface.png differ
diff --git a/resources/figures/tutorial_1/hover_node.png b/resources/figures/tutorial_1/hover_node.png
new file mode 100644
index 0000000..02d8891
Binary files /dev/null and b/resources/figures/tutorial_1/hover_node.png differ
diff --git a/resources/figures/tutorial_1/overview.png b/resources/figures/tutorial_1/overview.png
new file mode 100644
index 0000000..1fa75b7
Binary files /dev/null and b/resources/figures/tutorial_1/overview.png differ
diff --git a/resources/figures/tutorial_1/stop_simulation.png b/resources/figures/tutorial_1/stop_simulation.png
new file mode 100644
index 0000000..db31905
Binary files /dev/null and b/resources/figures/tutorial_1/stop_simulation.png differ
diff --git a/resources/figures/tutorial_1/subgraph_options.png b/resources/figures/tutorial_1/subgraph_options.png
new file mode 100644
index 0000000..b9ec8f9
Binary files /dev/null and b/resources/figures/tutorial_1/subgraph_options.png differ
diff --git a/resources/figures/tutorial_1/two_hop.png b/resources/figures/tutorial_1/two_hop.png
new file mode 100644
index 0000000..c812e54
Binary files /dev/null and b/resources/figures/tutorial_1/two_hop.png differ
diff --git a/resources/figures/tutorial_1/zoom.png b/resources/figures/tutorial_1/zoom.png
new file mode 100644
index 0000000..470f011
Binary files /dev/null and b/resources/figures/tutorial_1/zoom.png differ
diff --git a/resources/figures/tutorial_2/glyph.png b/resources/figures/tutorial_2/glyph.png
new file mode 100644
index 0000000..994386a
Binary files /dev/null and b/resources/figures/tutorial_2/glyph.png differ
diff --git a/resources/figures/tutorial_2/nlabel_selector.png b/resources/figures/tutorial_2/nlabel_selector.png
new file mode 100644
index 0000000..b99b20b
Binary files /dev/null and b/resources/figures/tutorial_2/nlabel_selector.png differ
diff --git a/resources/figures/tutorial_2/real_color.png b/resources/figures/tutorial_2/real_color.png
new file mode 100644
index 0000000..c26d644
Binary files /dev/null and b/resources/figures/tutorial_2/real_color.png differ
diff --git a/resources/figures/tutorial_2/subgraph_nlabel.png b/resources/figures/tutorial_2/subgraph_nlabel.png
new file mode 100644
index 0000000..1f7c32c
Binary files /dev/null and b/resources/figures/tutorial_2/subgraph_nlabel.png differ
diff --git a/resources/figures/tutorial_3/eweight_options.png b/resources/figures/tutorial_3/eweight_options.png
new file mode 100644
index 0000000..d77a6b1
Binary files /dev/null and b/resources/figures/tutorial_3/eweight_options.png differ
diff --git a/resources/figures/tutorial_3/eweight_subgraph.png b/resources/figures/tutorial_3/eweight_subgraph.png
new file mode 100644
index 0000000..ff05634
Binary files /dev/null and b/resources/figures/tutorial_3/eweight_subgraph.png differ
diff --git a/resources/figures/tutorial_4/1_ig.png b/resources/figures/tutorial_4/1_ig.png
new file mode 100644
index 0000000..3eb3a33
Binary files /dev/null and b/resources/figures/tutorial_4/1_ig.png differ
diff --git a/resources/figures/tutorial_4/ig_subgraph.png b/resources/figures/tutorial_4/ig_subgraph.png
new file mode 100644
index 0000000..f9e9e21
Binary files /dev/null and b/resources/figures/tutorial_4/ig_subgraph.png differ
diff --git a/resources/figures/tutorial_4/subgraph.png b/resources/figures/tutorial_4/subgraph.png
new file mode 100644
index 0000000..7711c56
Binary files /dev/null and b/resources/figures/tutorial_4/subgraph.png differ
diff --git a/resources/logo.png b/resources/logo.png
new file mode 100644
index 0000000..3c042e7
Binary files /dev/null and b/resources/logo.png differ
diff --git a/resources/tutorials/tutorial_1_graph.md b/resources/tutorials/tutorial_1_graph.md
new file mode 100644
index 0000000..8b669ae
--- /dev/null
+++ b/resources/tutorials/tutorial_1_graph.md
@@ -0,0 +1,103 @@
+# Tutorial 1: Graph structure
+
+Graph structure plays a critical role in developing a GNN. You may want to visualize the whole graph to get a rough sense of its sparsity and to check whether it contains subgraphs of a particular pattern.
+
+GNNLens2 allows you to do this with a simple API and very little effort.
+
+## Data preparation
+
+First, we load DGL’s built-in Cora and Citeseer datasets and retrieve their graph structures.
+
+```python
+from dgl.data import CoraGraphDataset, CiteseerGraphDataset
+
+cora_dataset = CoraGraphDataset()
+cora_graph = cora_dataset[0]
+citeseer_dataset = CiteseerGraphDataset()
+citeseer_graph = citeseer_dataset[0]
+```
+
+Next, we need to dump the graph structures to local files that GNNLens2 can read. GNNLens2 provides a built-in class `Writer` for this purpose. You can add an arbitrary number of graphs, one at a time.
+
+Once you finish adding data, you need to call **`writer.close()`**.
+
+```python
+from gnnlens import Writer
+
+# Specify the path to create a new directory for dumping data files.
+writer = Writer('tutorial_graph')
+writer.add_graph(name='Cora', graph=cora_graph)
+writer.add_graph(name='Citeseer', graph=citeseer_graph)
+# Finish dumping
+writer.close()
+```
+
+## Launch GNNLens2
+
+To launch GNNLens2, run the following command.
+
+```bash
+gnnlens --logdir tutorial_graph
+```
+
+By entering `localhost:7777` in your web browser address bar, you can see the GNNLens2 interface shown below. `7777` is the default port GNNLens2 uses. You can specify an alternative port by appending `--port xxxx` to the command and changing the address in the web browser accordingly.
+
+## GNNLens2 Interface
+
+![The empty GNNLens2 interface](../figures/tutorial_1/empty_interface.png)
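+
+If port 7777 is already in use on your machine, you can relaunch on another port before moving on; the port number below is just an example:
+
+```bash
+gnnlens --logdir tutorial_graph --port 9999
+```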
+
+The interface is empty since no graph has been selected yet. The control panel on the left contains multiple selectors. The first one is the graph selector: click it and choose a graph to visualize from the drop-down list. The options in the list are the names you passed to `add_graph`.
+
+![The graph selector in the control panel](../figures/tutorial_1/control_panel.png)
+
+After you select a graph, GNNLens2 will plot it as shown below. GNNLens2 determines the graph layout (node positions) on the fly with a force-directed graph drawing algorithm, which simulates physical forces on the nodes. The simulation stops when you click the “Stop Simulation” button and resumes when you click the same button again.
+
+![Graph visualization with the Stop Simulation button](../figures/tutorial_1/stop_simulation.png)
+
+For a large graph, you can view different parts of it by clicking on the overview box at the lower-right corner.
+
+![The overview box](../figures/tutorial_1/overview.png)
+
+You can drag the graph by pressing and holding the mouse button. The figure below is the result of dragging the graph to the right.
+
+![Dragging the graph to the right](../figures/tutorial_1/drag.png)
+
+You can also zoom in or out on the graph. The figure below is the result of zooming in on the graph.
+
+![Zooming in on the graph](../figures/tutorial_1/zoom.png)
+
+As you move the cursor to a particular node, GNNLens2 will display its node ID and highlight its one-hop neighborhood.
+
+![Hovering over a node](../figures/tutorial_1/hover_node.png)
+
+If you want to examine a subgraph centered at a particular node, you can simply click on it. GNNLens2 will then display its two-hop subgraph by default, and the node you clicked on will be highlighted. You can click the overview box to center the subgraph.
+
+![Two-hop subgraph of a clicked node](../figures/tutorial_1/two_hop.png)
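+
+If you want to cross-check what the two-hop view shows, you can also extract this kind of neighborhood programmatically. A minimal sketch, assuming DGL >= 0.9, where `dgl.khop_in_subgraph` is available:
+
+```python
+import dgl
+
+# Two-hop in-neighborhood of node 0, mirroring the default subgraph view
+sg, inverse_indices = dgl.khop_in_subgraph(cora_graph, 0, k=2)
+print(sg.num_nodes(), sg.num_edges())
+```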
+
+You can switch the subgraph option in the “Subgraph” drop-down list.
+
+![Subgraph options](../figures/tutorial_1/subgraph_options.png)
+
+To terminate GNNLens2, use `ctrl + c`.
+
+## Next
+
+So far, we've seen how to visualize a graph structure. Now let us look at how to [use node labels to color nodes in visualization](./tutorial_2_nlabel.md).
diff --git a/resources/tutorials/tutorial_2_nlabel.md b/resources/tutorials/tutorial_2_nlabel.md
new file mode 100644
index 0000000..04ce61d
--- /dev/null
+++ b/resources/tutorials/tutorial_2_nlabel.md
@@ -0,0 +1,133 @@
+# Tutorial 2: Ground truth and predicted node labels
+
+The nodes in a graph can be associated with a label, such as a node type or class. For the task of multiclass node classification, you can have ground truth node labels and node labels predicted by different models. GNNLens2 allows coloring nodes based on node labels in graph visualization and comparing node labels from different sources.
+
+## Data preparation
+
+First, we load DGL’s built-in Cora dataset and retrieve its graph structure, node labels (classes) and number of node classes.
+
+```python
+from dgl.data import CoraGraphDataset
+
+dataset = CoraGraphDataset()
+graph = dataset[0]
+nlabels = graph.ndata['label']
+num_classes = dataset.num_classes
+```
+
+We dump them to local files that GNNLens2 can read. Compared with [the previous tutorial](./tutorial_1_graph.md), we additionally dump the node classes and the number of node classes.
+
+```python
+from gnnlens import Writer
+
+# Specify the path to create a new directory for dumping data files.
+writer = Writer('tutorial_nlabel')
+writer.add_graph(name='Cora', graph=graph,
+                 nlabels=nlabels, num_nlabel_types=num_classes)
+```
+
+Next, we train two graph convolutional networks (GCNs) for node classification, `GCN_L1` (GCN with one layer) and `GCN_L2` (GCN with two layers). Once trained, we retrieve the predicted node classes and dump them to local files.
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from dgl.nn.pytorch import GraphConv
+
+# Define a class for GCN
+class GCN(nn.Module):
+    def __init__(self,
+                 in_feats,
+                 num_classes,
+                 num_layers):
+        super(GCN, self).__init__()
+        self.layers = nn.ModuleList()
+        self.layers.append(GraphConv(in_feats, num_classes))
+        for _ in range(num_layers - 1):
+            self.layers.append(GraphConv(num_classes, num_classes))
+
+    def forward(self, g, h):
+        for layer in self.layers:
+            h = layer(g, h)
+        return h
+
+# Define a function to train a GCN with the specified number of layers
+# and return the predictions
+def train_gcn(g, num_layers, num_classes):
+    features = g.ndata['feat']
+    labels = g.ndata['label']
+    train_mask = g.ndata['train_mask']
+    model = GCN(in_feats=features.shape[1],
+                num_classes=num_classes,
+                num_layers=num_layers)
+    loss_func = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+
+    num_epochs = 200
+    model.train()
+    for _ in range(num_epochs):
+        logits = model(g, features)
+        loss = loss_func(logits[train_mask], labels[train_mask])
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+    model.eval()
+    predictions = model(g, features)
+    _, predicted_classes = torch.max(predictions, dim=1)
+    return predicted_classes
+
+print("Training GCN with one layer...")
+predictions_one_layer = train_gcn(graph, num_layers=1, num_classes=num_classes)
+print("Training GCN with two layers...")
+predictions_two_layers = train_gcn(graph, num_layers=2, num_classes=num_classes)
+# Dump the predictions to local files
+writer.add_model(graph_name='Cora', model_name='GCN_L1',
+                 nlabels=predictions_one_layer)
+writer.add_model(graph_name='Cora', model_name='GCN_L2',
+                 nlabels=predictions_two_layers)
+# Finish dumping
+writer.close()
+```
+
+## Launch GNNLens2
+
+To launch GNNLens2, run the following command.
+
+```bash
+gnnlens --logdir tutorial_nlabel
+```
+
+By entering `localhost:7777` in your web browser address bar, you can see the GNNLens2 interface. `7777` is the default port GNNLens2 uses. You can specify an alternative port by appending `--port xxxx` to the command and changing the address in the web browser accordingly.
+
+## GNNLens2 Interface
+
+The second selector in the control panel on the left is the nlabel selector. After you select a graph and click the nlabel selector, it will display the available node labels from different sources. The options include `ground_truth` for the ground truth node labels and the model names passed to `add_model` for the model predictions.
+
+![The nlabel selector](../figures/tutorial_2/nlabel_selector.png)
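+
+When deciding which prediction source to inspect, it can help to know how accurate each model actually is. A quick check using the tensors from the training code above (`test_mask` ships with DGL's Cora dataset):
+
+```python
+test_mask = graph.ndata['test_mask']
+for name, preds in [('GCN_L1', predictions_one_layer),
+                    ('GCN_L2', predictions_two_layers)]:
+    acc = (preds[test_mask] == nlabels[test_mask]).float().mean().item()
+    print('{}: test accuracy {:.4f}'.format(name, acc))
+```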
+
+You can select an option to color nodes by a source of node labels. The color legend is in the lower-left corner.
+
+![Nodes colored by a selected nlabel source](../figures/tutorial_2/real_color.png)
+
+The node coloring also applies to subgraphs if you click on a node.
+
+![Node coloring in a subgraph](../figures/tutorial_2/subgraph_nlabel.png)
+
+You can even select multiple options and simultaneously color nodes using multiple sources of node labels. In this case, the circles representing the nodes are replaced by glyphs. The center of a glyph is colored based on the first selected nlabel source, and the outer pie chart is colored based on the remaining nlabel sources, clockwise from the top. This allows a direct comparison between the ground truth node labels and the node labels predicted by various models.
+
+![Glyphs comparing multiple nlabel sources](../figures/tutorial_2/glyph.png)
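+
+Glyphs make such disagreements stand out visually. To pick nodes worth inspecting in the first place, you can also list the disagreements programmatically, reusing the prediction tensors from above:
+
+```python
+# Nodes where the one-layer and two-layer GCNs predict different classes
+disagree = (predictions_one_layer != predictions_two_layers).nonzero(as_tuple=True)[0]
+print('{} nodes with disagreeing predictions'.format(len(disagree)))
+print(disagree[:10])  # some node IDs to look up in the interface
+```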
+
+To terminate GNNLens2, use `ctrl + c`.
+
+## Next
+
+So far, we've seen how to visualize node labels. Now let us look at how to [use edge weights in visualization](./tutorial_3_eweight.md).
diff --git a/resources/tutorials/tutorial_3_eweight.md b/resources/tutorials/tutorial_3_eweight.md
new file mode 100644
index 0000000..dae1901
--- /dev/null
+++ b/resources/tutorials/tutorial_3_eweight.md
@@ -0,0 +1,160 @@
+# Tutorial 3: Edge weights and attention
+
+The edges in a graph can be associated with a weight, e.g., connectivity strength. Attention-based graph neural networks such as graph attention networks (GATs) are widely used, and their learned attention weights can also be viewed as edge weights. GNNLens2 can visualize edge weights, including the attention weights of a GAT.
+
+## Data preparation
+
+First, we load DGL’s built-in Cora dataset and retrieve its graph structure, node labels (classes) and number of node classes.
+
+```python
+from dgl.data import CoraGraphDataset
+
+dataset = CoraGraphDataset()
+graph = dataset[0]
+nlabels = graph.ndata['label']
+num_classes = dataset.num_classes
+```
+
+As the Cora graph is not weighted, we generate two types of random edge weights for demonstration purposes. The edge weights are expected to be in the range `[0, 1]`.
+
+```python
+import torch
+
+confidence = torch.rand(graph.num_edges())
+strength = torch.rand(graph.num_edges())
+```
+
+We dump them to local files that GNNLens2 can read.
+
+```python
+from gnnlens import Writer
+
+# Specify the path to create a new directory for dumping data files.
+writer = Writer('tutorial_eweight')
+writer.add_graph(name='Cora', graph=graph,
+                 nlabels=nlabels, num_nlabel_types=num_classes,
+                 eweights={'confidence': confidence, 'strength': strength})
+```
+
+Next, we train two graph attention networks (GATs) for node classification, `GAT_L2` (GAT with two layers) and `GAT_L3` (GAT with three layers). Once trained, we retrieve the attention weights and dump them to local files.
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import dgl.function as fn
+from dgl.nn import GATConv
+
+class GAT(nn.Module):
+    def __init__(self,
+                 num_layers,
+                 in_dim,
+                 num_hidden,
+                 num_classes,
+                 heads):
+        super(GAT, self).__init__()
+        self.num_layers = num_layers
+        self.gat_layers = nn.ModuleList()
+        # input projection (no residual)
+        self.gat_layers.append(GATConv(in_dim, num_hidden, heads[0]))
+        # hidden layers
+        for l in range(1, num_layers - 1):
+            # due to multi-head, in_dim = num_hidden * number of heads in the previous layer
+            self.gat_layers.append(GATConv(num_hidden * heads[l-1], num_hidden, heads[l]))
+        # output projection
+        self.gat_layers.append(GATConv(num_hidden * heads[-2], num_classes, heads[-1]))
+
+    def forward(self, g, h):
+        attns = []
+        for l in range(self.num_layers - 1):
+            h, attn = self.gat_layers[l](g, h, get_attention=True)
+            h = h.flatten(1)
+            attns.append(attn)
+        # output projection
+        logits, attn = self.gat_layers[-1](g, h, get_attention=True)
+        logits = logits.mean(1)
+        attns.append(attn)
+        return logits, attns
+
+def convert_attns_to_dict(attns):
+    attn_dict = {}
+    for layer, attn_list in enumerate(attns):
+        # (num_edges, num_heads, 1) -> (num_heads, num_edges)
+        attn_list = attn_list.squeeze(2).transpose(0, 1)
+        for head, attn in enumerate(attn_list):
+            head_name = "L{}_H{}".format(layer, head)
+            attn_dict[head_name] = attn
+    return attn_dict
+
+def train_gat(g, num_layers, heads, num_classes):
+    features = g.ndata['feat']
+    labels = g.ndata['label']
+    train_mask = g.ndata['train_mask']
+    model = GAT(num_layers=num_layers,
+                in_dim=features.shape[1],
+                num_hidden=8,
+                num_classes=num_classes,
+                heads=heads)
+    loss_func = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+
+    num_epochs = 35
+    model.train()
+    for _ in range(num_epochs):
+        logits, _ = model(g, features)
+        loss = loss_func(logits[train_mask], labels[train_mask])
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+    model.eval()
+    predictions, attns = model(g, features)
+    _, predicted_classes = torch.max(predictions, dim=1)
+    attn_dict = convert_attns_to_dict(attns)
+
+    return predicted_classes, attn_dict
+
+print("Training GAT with two layers...")
+predictions_gat_two_layers, attn_dict_two_layers = train_gat(
+    graph, num_layers=2, heads=[2, 1], num_classes=num_classes)
+writer.add_model(graph_name='Cora', model_name='GAT_L2',
+                 nlabels=predictions_gat_two_layers, eweights=attn_dict_two_layers)
+
+print("Training GAT with three layers...")
+predictions_gat_three_layers, attn_dict_three_layers = train_gat(
+    graph, num_layers=3, heads=[4, 2, 1], num_classes=num_classes)
+writer.add_model(graph_name='Cora', model_name='GAT_L3',
+                 nlabels=predictions_gat_three_layers, eweights=attn_dict_three_layers)
+# Finish dumping
+writer.close()
+```
+
+## Launch GNNLens2
+
+To launch GNNLens2, run the following command.
+
+```bash
+gnnlens --logdir tutorial_eweight
+```
+
+By entering `localhost:7777` in your web browser address bar, you can see the GNNLens2 interface. `7777` is the default port GNNLens2 uses. You can specify an alternative port by appending `--port xxxx` to the command and changing the address in the web browser accordingly.
+
+## GNNLens2 Interface
+
+The third selector in the control panel on the left is the eweight selector. After you select a graph and click the eweight selector, it will display the available edge weights from different sources. The options take the form `A/B`, where `A` is the eweight source and `B` is the eweight name.
+
+![The eweight selector](../figures/tutorial_3/eweight_options.png)
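+
+For the models trained above, the eweight names are the keys produced by `convert_attns_to_dict`. For instance, `GAT_L2` was trained with `heads=[2, 1]`:
+
+```python
+print(list(attn_dict_two_layers.keys()))
+# ['L0_H0', 'L0_H1', 'L1_H0'], listed in the interface as
+# 'GAT_L2/L0_H0', 'GAT_L2/L0_H1' and 'GAT_L2/L1_H0'
+```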
+
+GNNLens2 visualizes edge weights with edge thickness: the thicker an edge, the higher its weight. The following figure visualizes the attention weights of the first head in the first GAT layer of `GAT_L2` (GAT with two layers).
+
+![Attention weights visualized with edge thickness](../figures/tutorial_3/eweight_subgraph.png)
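+
+Since GATConv normalizes attention with a softmax over the incoming edges of each node, you can sanity-check a dumped head before visualizing it. A small sketch using `dgl.function` (imported as `fn` in the training code above):
+
+```python
+graph.edata['a'] = attn_dict_two_layers['L0_H0']
+graph.update_all(fn.copy_e('a', 'm'), fn.sum('m', 'a_sum'))
+# Each entry should be close to 1 for nodes with incoming edges
+print(graph.ndata['a_sum'][:5])
+```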
+
+To terminate GNNLens2, use `ctrl + c`.
+
+## Next
+
+So far, we've seen how to visualize edge weights. Now let us look at how to [visualize subgraphs and model explanations](./tutorial_4_subgraph.md).
diff --git a/resources/tutorials/tutorial_4_subgraph.md b/resources/tutorials/tutorial_4_subgraph.md
new file mode 100644
index 0000000..9c35f31
--- /dev/null
+++ b/resources/tutorials/tutorial_4_subgraph.md
@@ -0,0 +1,140 @@
+# Tutorial 4: Weighted subgraphs and explanation methods
+
+Node-centered subgraphs play a critical role in analyzing GNNs. The k-hop subgraph of a node fully determines the information a k-layer GNN exploits to generate its final node representation. Many GNN explanation methods provide explanations by extracting a subgraph and assigning importance weights to its nodes and edges. GNNLens2 allows visualizing node-centered weighted subgraphs. This is beneficial for debugging and understanding GNNs and GNN explanation methods.
+
+For this demonstration, we will use IntegratedGradients from [Captum](https://github.com/pytorch/captum) to explain the predictions of a graph convolutional network (GCN). Captum is a model interpretability and understanding library for PyTorch. You can install it with
+
+```bash
+pip install captum
+```
+
+## Data preparation
+
+First, we load DGL’s built-in Cora dataset and retrieve its graph structure, node labels (classes) and number of node classes.
+
+```python
+import dgl
+from dgl.data import CoraGraphDataset
+
+dataset = CoraGraphDataset()
+graph = dataset[0]
+nlabels = graph.ndata['label']
+num_classes = dataset.num_classes
+```
+
+We dump them to local files that GNNLens2 can read.
+
+```python
+from gnnlens import Writer
+
+# Specify the path to create a new directory for dumping data files.
+writer = Writer('tutorial_subgraph')
+writer.add_graph(name='Cora', graph=graph,
+                 nlabels=nlabels, num_nlabel_types=num_classes)
+```
+
+We attribute the model predictions to the input node features with IntegratedGradients. For simplicity, we skip model training and attribute the predictions of a randomly initialized GCN.
+
+```python
+import torch.nn as nn
+
+from captum.attr import IntegratedGradients
+from dgl.nn import GraphConv
+from functools import partial
+
+# Define a class for GCN
+class GCN(nn.Module):
+    def __init__(self,
+                 in_feats,
+                 num_classes):
+        super(GCN, self).__init__()
+        self.conv = GraphConv(in_feats, num_classes)
+
+    def forward(self, h, g):
+        # Interchange the order of g and h due to the behavior of partial
+        return self.conv(g, h)
+
+# Required by IntegratedGradients
+h = graph.ndata['feat'].clone().requires_grad_(True)
+model = GCN(h.shape[1], num_classes)
+ig = IntegratedGradients(partial(model.forward, g=graph))
+# Attribute the predictions for node class 0 to the input features. Setting
+# internal_batch_size to the number of nodes keeps all nodes in a single
+# batch so that the input stays consistent with the fixed graph g.
+feat_attr = ig.attribute(h, target=0, internal_batch_size=graph.num_nodes(), n_steps=50)
+```
+
+We compute the node importance weights from the input feature attributions and min-max normalize them to `[0, 1]`.
+
+```python
+node_weights = feat_attr.abs().sum(dim=1)
+node_weights = (node_weights - node_weights.min()) / (node_weights.max() - node_weights.min())
+```
+
+We extract the 2-hop subgraphs of nodes 0 and 1 and dump them to local files that GNNLens2 can read. The subgraph name corresponds to a group of subgraphs. In a subgraph group, each node can be associated with at most one subgraph. For each subgraph, we dump its node and edge IDs in the original graph and optionally the subgraph node and edge weights.
+
+```python
+import dgl
+import torch
+
+def extract_subgraph(g, node):
+    # Take the in-edges of the seed node for the first hop
+    seed_nodes = [node]
+    sg = dgl.in_subgraph(g, seed_nodes)
+    # Expand the seeds to all endpoints of those edges for the second hop
+    src, dst = sg.edges()
+    seed_nodes = torch.cat([src, dst]).unique()
+    sg = dgl.in_subgraph(g, seed_nodes, relabel_nodes=True)
+    return sg
+
+graph.ndata['weight'] = node_weights
+# Random edge weights in [0, 1] for demonstration purposes
+graph.edata['weight'] = torch.rand(graph.num_edges())
+first_subgraph = extract_subgraph(graph, 0)
+writer.add_subgraph(graph_name='Cora', subgraph_name='IntegratedGradients', node_id=0,
+                    subgraph_nids=first_subgraph.ndata[dgl.NID],
+                    subgraph_eids=first_subgraph.edata[dgl.EID],
+                    subgraph_nweights=first_subgraph.ndata['weight'],
+                    subgraph_eweights=first_subgraph.edata['weight'])
+
+second_subgraph = extract_subgraph(graph, 1)
+writer.add_subgraph(graph_name='Cora', subgraph_name='IntegratedGradients', node_id=1,
+                    subgraph_nids=second_subgraph.ndata[dgl.NID],
+                    subgraph_eids=second_subgraph.edata[dgl.EID],
+                    subgraph_nweights=second_subgraph.ndata['weight'],
+                    subgraph_eweights=second_subgraph.edata['weight'])
+
+# Finish dumping
+writer.close()
+```
+
+## Launch GNNLens2
+
+To launch GNNLens2, run the following command.
+
+```bash
+gnnlens --logdir tutorial_subgraph
+```
+
+By entering `localhost:7777` in your web browser address bar, you can see the GNNLens2 interface. `7777` is the default port GNNLens2 uses. You can specify an alternative port by appending `--port xxxx` to the command and changing the address in the web browser accordingly.
+
+## GNNLens2 Interface
+
+After you select a graph and a node label option, you can click on an arbitrary node.
+
+![Subgraph of a clicked node](../figures/tutorial_4/subgraph.png)
+
+You can enter node ID 1 in the Id box at the top and click on the subgraph drop-down list. The subgraph options now include `IntegratedGradients`.
+
+![Subgraph options including IntegratedGradients](../figures/tutorial_4/1_ig.png)
+
+After you select `IntegratedGradients`, GNNLens2 will display the subgraph associated with node 1 that you dumped earlier. It visualizes the node importance weights through the opacity of the node color.
+
+![Weighted subgraph from IntegratedGradients](../figures/tutorial_4/ig_subgraph.png)
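+
+We only dumped subgraphs for nodes 0 and 1. To make the `IntegratedGradients` option available for more nodes, you can dump subgraphs in a loop before calling `writer.close()`, reusing the helper defined above:
+
+```python
+for node_id in range(2, 10):
+    sg = extract_subgraph(graph, node_id)
+    writer.add_subgraph(graph_name='Cora', subgraph_name='IntegratedGradients',
+                        node_id=node_id,
+                        subgraph_nids=sg.ndata[dgl.NID],
+                        subgraph_eids=sg.edata[dgl.EID],
+                        subgraph_nweights=sg.ndata['weight'],
+                        subgraph_eweights=sg.edata['weight'])
+```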
+
+If you enter a node ID for which you did not dump a subgraph, GNNLens2 won’t display anything.
+
+To terminate GNNLens2, use `ctrl + c`.