diff --git a/applications/graph/NodePropPrediction/GCN.py b/applications/graph/NodePropPrediction/GCN.py
new file mode 100644
index 00000000000..f77649e252e
--- /dev/null
+++ b/applications/graph/NodePropPrediction/GCN.py
@@ -0,0 +1,107 @@
+import lbann
+from lbann.modules import Module
+
+
+class GCN(Module):
+    """
+    Graph convolutional kernel
+    """
+
+    def __init__(
+        self,
+        num_nodes,
+        num_edges,
+        input_features,
+        output_features,
+        activation=lbann.Relu,
+        distconv_enabled=True,
+        num_groups=4,
+    ):
+        super().__init__()
+        self._input_dims = input_features
+        self._output_dims = output_features
+        self._num_nodes = num_nodes
+        self._num_edges = num_edges
+        self._activation = activation
+        # Dimensions of the scattered output node-feature matrix
+        self._ft_dims = [num_nodes, output_features]
+        # distconv_enabled and num_groups are accepted for API compatibility
+        # but are not used by this kernel yet
+        self._distconv_enabled = distconv_enabled
+        self._num_groups = num_groups
+
+    def forward(self, node_features, source_indices, target_indices):
+        # Gather neighbor features, transform them, and scatter the messages
+        # back onto the node-feature matrix
+        x = lbann.Gather(node_features, target_indices, axis=0)
+        x = lbann.ChannelwiseFullyConnected(
+            x, output_channel_dims=[self._output_dims]
+        )
+        x = self._activation(x)
+        x = lbann.Scatter(x, source_indices, dims=self._ft_dims)
+        return x
+
+
+def create_model(
+    num_nodes, num_edges, input_features, output_features, num_layers=3, num_epochs=1
+):
+    """
+    Create a GCN model
+    """
+    # Layer graph: slice the flattened sample into node features,
+    # source indices, target indices, and labels
+    input_ = lbann.Input(data_field="samples")
+    split_indices = [0, num_nodes * input_features]
+    split_indices += [split_indices[-1] + num_edges]
+    split_indices += [split_indices[-1] + num_edges]
+    split_indices += [split_indices[-1] + num_nodes]
+
+    sliced_input = lbann.Slice(input_, axis=0, slice_points=split_indices)
+
+    node_features = lbann.Reshape(
+        lbann.Identity(sliced_input), dims=[num_nodes, input_features]
+    )
+
+    source_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges])
+    target_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges])
+    label = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_nodes])
+
+    x = GCN(
+        num_nodes,
+        num_edges,
+        input_features,
+        output_features,
+        activation=lbann.Relu,
+        distconv_enabled=False,
+        num_groups=4,
+    )(node_features, source_indices, target_indices)
+
+    # Later layers operate on output_features-sized node embeddings
+    for _ in range(num_layers - 1):
+        x = GCN(
+            num_nodes,
+            num_edges,
+            output_features,
+            output_features,
+            activation=lbann.Relu,
+            distconv_enabled=False,
+            num_groups=4,
+        )(x, source_indices, target_indices)
+
+    # Loss function
+    loss = lbann.CrossEntropy([x, label])
+
+    # Metrics
+    acc = lbann.CategoricalAccuracy([x, label])
+
+    # Callbacks
+    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
+
+    # Construct model
+    return lbann.Model(
+        num_epochs=num_epochs,
+        layers=lbann.traverse_layer_graph(input_),
+        objective_function=loss,
+        metrics=[acc],
+        callbacks=callbacks,
+    )
diff --git a/applications/graph/NodePropPrediction/dataset_wrapper.py b/applications/graph/NodePropPrediction/dataset_wrapper.py
new file mode 100644
index 00000000000..06f3c7ee349
--- /dev/null
+++ b/applications/graph/NodePropPrediction/dataset_wrapper.py
@@ -0,0 +1,30 @@
+import lbann
+import os.path as osp
+
+
+current_dir = osp.dirname(osp.realpath(__file__))
+
+DATASET_CONFIG = {
+    "ARXIV": {
+        "num_nodes": 169343,
+        "num_edges": 1166243,
+        "input_features": 128,
+        "num_classes": 40,
+    }
+}
+
+
+def make_data_reader(dataset):
+    reader = lbann.reader_pb2.DataReader()
+    _reader = reader.reader.add()
+    _reader.name = "python"
+    _reader.role = "train"
+    _reader.shuffle = True
+    _reader.percent_of_data_to_use = 1.0
+    # Dataset modules live in datasets/ and are named <dataset>_dataset.py
+    _reader.python.module = f"{dataset.lower()}_dataset"
+    _reader.python.module_dir = osp.join(current_dir, "datasets")
+    _reader.python.sample_function = "get_train_sample"
+    _reader.python.num_samples_function = "num_train_samples"
+    _reader.python.sample_dims_function = "sample_dims"
+    return reader
diff --git a/applications/graph/NodePropPrediction/datasets/arxiv_dataset.py b/applications/graph/NodePropPrediction/datasets/arxiv_dataset.py
new file mode 100644
index 00000000000..e2c2d06e7fe
--- /dev/null
+++ b/applications/graph/NodePropPrediction/datasets/arxiv_dataset.py
@@ -0,0 +1,48 @@
+import numpy as np
+import os
+
+
+# Load the dataset (fall back to random data if the files are unavailable)
+
+data_dir = "/p/vast1/lbann/datasets/OpenGraphBenchmarks/dataset/ogbn_arxiv"
+
+connectivity_data = np.load(data_dir + "/edges.npy")
+node_data = (
+    np.load(data_dir + "/node_feats.npy")
+    if os.path.exists(data_dir + "/node_feats.npy")
+    else np.random.rand(169343, 128)  # random node features
+)
+
+labels_data = (
+    np.load(data_dir + "/labels.npy")
+    if os.path.exists(data_dir + "/labels.npy")
+    else np.random.randint(0, 40, 169343)  # random labels
+)
+
+num_edges = 1166243
+num_nodes = 169343
+
+assert connectivity_data.shape == (num_edges, 2)
+assert node_data.shape == (num_nodes, 128)
+
+
+def get_train_sample(index):
+    # Return the whole graph as one flattened sample in the layout expected
+    # by the model: node features, source indices, target indices, labels
+    return np.concatenate(
+        (
+            node_data.flatten(),
+            connectivity_data[:, 0].flatten(),
+            connectivity_data[:, 1].flatten(),
+            labels_data.flatten(),
+        )
+    ).astype(np.float32)
+
+
+def sample_dims():
+    # Total length of one flattened sample
+    return (node_data.size + connectivity_data.size + labels_data.size,)
+
+
+def num_train_samples():
+    return 1
diff --git a/applications/graph/NodePropPrediction/main.py b/applications/graph/NodePropPrediction/main.py
new file mode 100644
index 00000000000..605614ccd1b
--- /dev/null
+++ b/applications/graph/NodePropPrediction/main.py
@@ -0,0 +1,112 @@
+from dataset_wrapper import DATASET_CONFIG, make_data_reader
+from GCN import create_model
+import lbann
+import lbann.contrib.launcher
+import lbann.contrib.args
+
+import argparse
+
+desc = "Training a Graph Convolutional Model using LBANN"
+parser = argparse.ArgumentParser(description=desc)
+
+lbann.contrib.args.add_scheduler_arguments(parser, "GNN")
+lbann.contrib.args.add_optimizer_arguments(parser)
+
+parser.add_argument(
+    "--num-epochs",
+    action="store",
+    default=100,
+    type=int,
+    help="number of epochs (default: 100)",
+    metavar="NUM",
+)
+
+parser.add_argument(
+    "--model",
+    action="store",
+    default="GCN",
+    type=str,
+    help="The type of model to use",
+    metavar="NAME",
+)
+
+parser.add_argument(
+    "--dataset",
+    action="store",
+    default="ARXIV",
+    type=str,
+    help="The dataset to use",
+    metavar="NAME",
+)
+
+parser.add_argument(
+    "--latent-dim",
+    action="store",
+    default=16,
+    type=int,
+    help="The latent dimension of the model",
+    metavar="NUM",
+)
+
+parser.add_argument(
+    "--num-layers",
+    action="store",
+    default=3,
+    type=int,
+    help="The number of layers in the model",
+    metavar="NUM",
+)
+
+
+SUPPORTED_MODELS = ["GCN", "GAT"]
+SUPPORTED_DATASETS = ["ARXIV", "PRODUCTS", "MAG240M"]
+
+
+def main():
+    args = parser.parse_args()
+
+    kwargs = lbann.contrib.args.get_scheduler_kwargs(args)
+
+    num_epochs = args.num_epochs
+    mini_batch_size = 1  # the full graph is a single sample
+    job_name = args.job_name
+    model_arch = args.model
+    dataset = args.dataset
+
+    if model_arch not in SUPPORTED_MODELS:
+        raise ValueError(
+            f"Model {model_arch} not supported. Supported models are {SUPPORTED_MODELS}"
+        )
+
+    if dataset not in SUPPORTED_DATASETS:
+        raise ValueError(
+            f"Dataset {dataset} not supported. Supported datasets are {SUPPORTED_DATASETS}"
+        )
+
+    dataset_config = DATASET_CONFIG[dataset]
+    num_nodes = dataset_config["num_nodes"]
+    num_edges = dataset_config["num_edges"]
+    input_features = dataset_config["input_features"]
+    num_classes = dataset_config["num_classes"]
+
+    # Only the GCN kernel is implemented in this directory so far
+    model = create_model(
+        num_nodes,
+        num_edges,
+        input_features,
+        num_classes,
+        num_layers=args.num_layers,
+        num_epochs=num_epochs,
+    )
+
+    data_reader = make_data_reader(dataset)
+    optimizer = lbann.SGD(learn_rate=0.01, momentum=0.0)
+    trainer = lbann.Trainer(mini_batch_size=mini_batch_size)
+
+    lbann.contrib.launcher.run(
+        trainer, model, data_reader, optimizer, job_name=job_name, **kwargs
+    )
+
+
+if __name__ == "__main__":
+    main()
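Note on the sample layout: the flattened sample produced by get_train_sample has to line up with the split_indices computed in GCN.create_model. The standalone sketch below is not part of the patch; it just restates that layout using the ARXIV sizes from DATASET_CONFIG so the offsets can be checked by eye.

# Sanity check of the flat-sample layout assumed by create_model:
# [node features | source indices | target indices | labels]
num_nodes, num_edges, input_features = 169343, 1166243, 128

split_indices = [0, num_nodes * input_features]
split_indices += [split_indices[-1] + num_edges]
split_indices += [split_indices[-1] + num_edges]
split_indices += [split_indices[-1] + num_nodes]

# Must match the value returned by arxiv_dataset.sample_dims()
expected_length = num_nodes * input_features + 2 * num_edges + num_nodes
assert split_indices[-1] == expected_length  # 24,177,733 values per sample

If the on-disk feature width or edge count ever changes, DATASET_CONFIG and the dataset module need to be updated together, otherwise the Slice layer will cut the sample at the wrong offsets.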