Commit: Add GCN implementation with loops
- Add dataset wrapper for arxiv
- To do: Add MAG dataset
- rebased
Showing 4 changed files with 251 additions and 0 deletions.
@@ -0,0 +1,91 @@
import lbann
from lbann.modules import Module


class GCN(Module):
    """
    Graph convolutional kernel
    """

    def __init__(
        self,
        num_nodes,
        num_edges,
        input_features,
        output_features,
        activation=lbann.Relu,
        distconv_enabled=True,
        num_groups=4,
    ):
        super().__init__()
        self._input_dims = input_features
        self._output_dims = output_features
        self._num_nodes = num_nodes
        self._num_edges = num_edges
        self._activation = activation
        # Dims of the output feature tensor, used to size the Scatter below
        self._ft_dims = [num_nodes, output_features]

    def forward(self, node_features, source_indices, target_indices):
        # Gather the features of each edge's target node, transform them,
        # then scatter the transformed messages back onto the source nodes
        x = lbann.Gather(node_features, target_indices, axis=0)
        x = lbann.ChannelwiseFullyConnected(x, output_channel_dims=self._output_dims)
        x = self._activation(x)
        x = lbann.Scatter(x, source_indices, dims=self._ft_dims)
        return x


def create_model(num_nodes, num_edges, input_features, output_features, num_layers=3):
    """
    Create a GCN model
    """
    # Layer graph: slice the flat input sample into node features,
    # source indices, target indices, and labels
    input_ = lbann.Input()
    split_indices = [0, num_nodes * input_features]
    split_indices += [split_indices[-1] + num_edges]
    split_indices += [split_indices[-1] + num_edges]
    split_indices += [split_indices[-1] + num_nodes]

    sliced_input = lbann.Slice(input_, axis=0, slice_points=split_indices)

    node_features = lbann.Reshape(
        lbann.Identity(sliced_input), dims=[num_nodes, input_features]
    )

    source_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges])
    target_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges])
    label = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_nodes])

    x = GCN(
        num_nodes,
        num_edges,
        input_features,
        output_features,
        activation=lbann.Relu,
        distconv_enabled=False,
        num_groups=4,
    )(node_features, source_indices, target_indices)

    # Subsequent layers consume the previous layer's output, so their
    # input width is output_features, not input_features
    for _ in range(num_layers - 1):
        x = GCN(
            num_nodes,
            num_edges,
            output_features,
            output_features,
            activation=lbann.Relu,
            distconv_enabled=False,
            num_groups=4,
        )(x, source_indices, target_indices)

    # Loss function
    loss = lbann.CrossEntropy([x, label])

    # Metrics
    acc = lbann.CategoricalAccuracy([x, label])

    # Callbacks
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Construct model
    return lbann.Model(
        num_epochs=1,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=[acc],
        callbacks=callbacks,
    )
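For reference, a minimal usage sketch with the ogbn-arxiv dimensions from DATASET_CONFIG below. The module name `gcn` is an assumption (this file's path is not shown in the capture); the 40-class output width is a property of ogbn-arxiv, not something set in this diff.

# Usage sketch; "gcn" is an assumed module name for the file above
from gcn import create_model

model = create_model(
    num_nodes=169343,     # ogbn-arxiv node count (matches DATASET_CONFIG)
    num_edges=1166243,    # ogbn-arxiv edge count
    input_features=128,   # node-feature width
    output_features=40,   # ogbn-arxiv has 40 subject-area classes
    num_layers=3,
)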
@@ -0,0 +1,27 @@
import lbann
import os.path as osp


current_dir = osp.dirname(osp.realpath(__file__))

DATASET_CONFIG = {
    "ARXIV": {
        "num_nodes": 169343,
        "num_edges": 1166243,
        "input_features": 128,
    }
}


def make_data_reader(dataset):
    # Build a DataReader message containing a single Python reader that
    # pulls samples from datasets/<dataset>_dataset.py
    message = lbann.reader_pb2.DataReader()
    reader = message.reader.add()
    reader.name = "python"
    reader.role = "train"
    reader.shuffle = True
    reader.percent_of_data_to_use = 1.0
    reader.python.module = f"{dataset}_dataset"
    reader.python.module_dir = osp.join(current_dir, "datasets")
    reader.python.sample_function = "get_train_sample"
    reader.python.num_samples_function = "num_train_samples"
    reader.python.sample_dims_function = "sample_dims"
    return message
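A short usage sketch: the argument is interpolated into the module name, so it must match the lowercase dataset file under datasets/.

# Builds a train reader backed by datasets/arxiv_dataset.py
data_reader = make_data_reader("arxiv")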
applications/graph/NodePropPrediction/datasets/arxiv_dataset.py (42 additions, 0 deletions)
@@ -0,0 +1,42 @@
import numpy as np
import os


# load the dataset

data_dir = "/p/vast1/lbann/datasets/OpenGraphBenchmarks/dataset/ogbn_arxiv"

connectivity_data = np.load(data_dir + "/edges.npy")
node_data = (
    np.load(data_dir + "/node_feats.npy")
    if os.path.exists(data_dir + "/node_feats.npy")
    else np.random.rand(169343, 128)  # random node features
)

labels_data = (
    np.load(data_dir + "/labels.npy")
    if os.path.exists(data_dir + "/labels.npy")
    else np.random.randint(0, 40, 169343)  # random labels
)

num_edges = 1166243
num_nodes = 169343

assert connectivity_data.shape == (num_edges, 2)
assert node_data.shape == (num_nodes, 128)


def get_train_sample(index):
    # Return the complete graph as one flat sample: node features, then
    # source indices, then target indices, then labels, matching the
    # slice points the model uses (assumes column 0 of edges.npy holds
    # source ids and column 1 holds target ids)
    return np.concatenate(
        (
            node_data.flatten(),
            connectivity_data[:, 0].flatten(),
            connectivity_data[:, 1].flatten(),
            labels_data.flatten(),
        )
    )


def sample_dims():
    # Total flat length: node features + two index vectors + labels
    return (node_data.size + connectivity_data.size + labels_data.size,)


def num_train_samples():
    return 1
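The reader and the model agree on this flat layout only by construction, so a quick length check against the model's slice points is cheap insurance. A minimal sketch, illustrative only and not part of the commit, run in the context of this module:

# Expected layout: node feats, source ids, target ids, labels
sample = get_train_sample(0)
expected_len = num_nodes * 128 + 2 * num_edges + num_nodes
assert sample.shape == (expected_len,)
assert sample_dims() == (expected_len,)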
@@ -0,0 +1,91 @@
from dataset_wrapper import DATASET_CONFIG, make_data_reader
from gcn import create_model  # "gcn" is an assumed module name; the model file's path is not shown in this diff
import lbann
import lbann.contrib.launcher
import lbann.contrib.args

import argparse

desc = "Training a Graph Convolutional Model using LBANN"
parser = argparse.ArgumentParser(description=desc)

lbann.contrib.args.add_scheduler_arguments(parser, "GNN")
lbann.contrib.args.add_optimizer_arguments(parser)

parser.add_argument(
    "--num-epochs",
    action="store",
    default=100,
    type=int,
    help="number of epochs (default: 100)",
    metavar="NUM",
)

parser.add_argument(
    "--model",
    action="store",
    default="GCN",
    type=str,
    help="The type of model to use",
    metavar="NAME",
)

parser.add_argument(
    "--dataset",
    action="store",
    default="ARXIV",
    type=str,
    help="The dataset to use",
    metavar="NAME",
)

parser.add_argument(
    "--latent-dim",
    action="store",
    default=16,
    type=int,
    help="The latent dimension of the model",
    metavar="NUM",
)

parser.add_argument(
    "--num-layers",
    action="store",
    default=3,
    type=int,
    help="The number of layers in the model",
    metavar="NUM",
)


SUPPORTED_MODELS = ["GCN", "GAT"]
SUPPORTED_DATASETS = ["ARXIV", "PRODUCTS", "MAG240M"]


def main():
    args = parser.parse_args()

    kwargs = lbann.contrib.args.get_scheduler_kwargs(args)

    num_epochs = args.num_epochs
    mini_batch_size = 1
    job_name = args.job_name
    model_arch = args.model
    dataset = args.dataset

    if model_arch not in SUPPORTED_MODELS:
        raise ValueError(
            f"Model {model_arch} not supported. Supported models are {SUPPORTED_MODELS}"
        )

    if dataset not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Dataset {dataset} not supported. Supported datasets are {SUPPORTED_DATASETS}"
        )
    dataset_config = DATASET_CONFIG[dataset]
    num_nodes = dataset_config["num_nodes"]
    num_edges = dataset_config["num_edges"]
    input_features = dataset_config["input_features"]

    # Only the GCN builder exists in this commit; model_arch is validated
    # above but not yet used to dispatch between architectures
    model = create_model(
        num_nodes, num_edges, input_features, args.latent_dim, args.num_layers
    )

    # The dataset module is named in lowercase (datasets/arxiv_dataset.py)
    data_reader = make_data_reader(dataset.lower())

    trainer = lbann.Trainer(mini_batch_size=mini_batch_size)
    optimizer = lbann.SGD(learn_rate=0.01, momentum=0.0)
    lbann.contrib.launcher.run(
        trainer, model, data_reader, optimizer, job_name=job_name, **kwargs
    )


if __name__ == "__main__":
    main()
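Since only the GCN path is wired up, model selection could become a table lookup once the GAT kernel lands. A minimal sketch, assuming a future create_gat_model with the same signature (a hypothetical name, not in this commit):

MODEL_BUILDERS = {
    "GCN": create_model,
    # "GAT": create_gat_model,  # hypothetical: to be added with the GAT kernel
}
model = MODEL_BUILDERS[model_arch](
    num_nodes, num_edges, input_features, args.latent_dim, args.num_layers
)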