Added trainer files
- Add GCN implementation with loops
- Add dataset wrapper for arxiv
  - To do: Add MAG dataset
- rebased
szaman19 committed Dec 20, 2023
1 parent a879802 commit f6d61c7
Showing 4 changed files with 251 additions and 0 deletions.
91 changes: 91 additions & 0 deletions applications/graph/NodePropPrediction/GCN.py
@@ -0,0 +1,91 @@
import lbann
from lbann.modules import Module


class GCN(Module):
"""
Graph convolutional kernel
"""

def __init__(
self,
num_nodes,
num_edges,
input_features,
output_features,
activation=lbann.Relu,
distconv_enabled=True,
num_groups=4,
):
super().__init__()
        self._input_dims = input_features
        self._output_dims = output_features
        self._num_nodes = num_nodes
        self._num_edges = num_edges
        self._activation = activation
        # Dimensions of the per-node feature tensor produced by this layer,
        # used when scattering edge messages back onto the nodes
        self._ft_dims = [num_nodes, output_features]
        # Note: distconv_enabled and num_groups are accepted but not used yet

    def forward(self, node_features, source_indices, target_indices):
        # Gather the feature vector of one endpoint of every edge
        x = lbann.Gather(node_features, target_indices, axis=0)
        # Per-edge linear transform of the gathered features
        x = lbann.ChannelwiseFullyConnected(
            x, output_channel_dims=[self._output_dims]
        )
        x = self._activation(x)
        # Scatter the transformed messages back onto the nodes
        x = lbann.Scatter(x, source_indices, dims=self._ft_dims)
        return x


def create_model(
    num_nodes, num_edges, input_features, output_features, num_layers=3, num_epochs=1
):
"""
Create a GCN model
"""
# Layer graph
    input_ = lbann.Input(data_field="samples")

    # The flattened sample is laid out as
    # [node features | source indices | target indices | labels]
    split_indices = [0, num_nodes * input_features]
    split_indices += [split_indices[-1] + num_edges]
    split_indices += [split_indices[-1] + num_edges]
    split_indices += [split_indices[-1] + num_nodes]
    slices = lbann.Slice(input_, axis=0, slice_points=split_indices)

    node_features = lbann.Reshape(
        lbann.Identity(slices), dims=[num_nodes, input_features]
    )
    source_indices = lbann.Reshape(lbann.Identity(slices), dims=[num_edges])
    target_indices = lbann.Reshape(lbann.Identity(slices), dims=[num_edges])
    label = lbann.Reshape(lbann.Identity(slices), dims=[num_nodes])

x = GCN(
num_nodes,
num_edges,
input_features,
output_features,
activation=lbann.Relu,
distconv_enabled=False,
num_groups=4,
)(node_features, source_indices, target_indices)

    # Stack the remaining graph-convolution layers; after the first layer the
    # node features have output_features channels
    for _ in range(num_layers - 1):
        x = GCN(
            num_nodes,
            num_edges,
            output_features,
            output_features,
            activation=lbann.Relu,
            distconv_enabled=False,
            num_groups=4,
        )(x, source_indices, target_indices)

# Loss function
loss = lbann.CrossEntropy([x, label])

# Metrics
acc = lbann.CategoricalAccuracy([x, label])

# Callbacks
callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

# Construct model
return lbann.Model(
        num_epochs=num_epochs,
layers=lbann.traverse_layer_graph(input_),
objective_function=loss,
metrics=[acc],
callbacks=callbacks,
)
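

# Example usage (sketch, assuming the ogbn-arxiv sizes from
# dataset_wrapper.DATASET_CONFIG and the 40 output classes assumed in
# datasets/arxiv_dataset.py):
#
#   model = create_model(
#       num_nodes=169343,
#       num_edges=1166243,
#       input_features=128,
#       output_features=40,
#       num_layers=3,
#   )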
27 changes: 27 additions & 0 deletions applications/graph/NodePropPrediction/dataset_wrapper.py
@@ -0,0 +1,27 @@
import lbann
import os.path as osp


current_dir = osp.dirname(osp.realpath(__file__))

DATASET_CONFIG = {
"ARXIV": {
"num_nodes": 169343,
"num_edges": 1166243,
"input_features": 128,
}
}
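
# The ogbn-arxiv statistics above (169,343 nodes, 1,166,243 edges, 128-dim
# node features) are used by main.py to size the model inputs.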


def make_data_reader(dataset):
    # A DataReader protobuf message holds a list of readers; register a single
    # Python data reader for the training role
    reader = lbann.reader_pb2.DataReader()
    _reader = reader.reader.add()
    _reader.name = "python"
    _reader.role = "train"
    _reader.shuffle = True
    _reader.percent_of_data_to_use = 1.0
    # e.g. "ARXIV" -> datasets/arxiv_dataset.py
    _reader.python.module = f"{dataset.lower()}_dataset"
    _reader.python.module_dir = osp.join(current_dir, "datasets")
    _reader.python.sample_function = "get_train_sample"
    _reader.python.num_samples_function = "num_train_samples"
    _reader.python.sample_dims_function = "sample_dims"
    return reader
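

# Example usage (sketch): the returned protobuf message is passed to
# lbann.contrib.launcher.run as the data reader, e.g. in main.py:
#
#   data_reader = make_data_reader("ARXIV")  # loads datasets/arxiv_dataset.py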
42 changes: 42 additions & 0 deletions applications/graph/NodePropPrediction/datasets/arxiv_dataset.py
@@ -0,0 +1,42 @@
import numpy as np
import os


# load the dataset

data_dir = "/p/vast1/lbann/datasets/OpenGraphBenchmarks/dataset/ogbn_arxiv"
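# The directory is expected to contain edges.npy (an edge list of shape
# [num_edges, 2]); node_feats.npy and labels.npy are optional and are replaced
# with random placeholders below when absent.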

connectivity_data = np.load(data_dir + "/edges.npy")
node_data = (
np.load(data_dir + "/node_feats.npy")
if os.path.exists(data_dir + "/node_feats.npy")
else np.random.rand(169343, 128) # random node features
)

labels_data = (
np.load(data_dir + "/labels.npy")
if os.path.exists(data_dir + "/labels.npy")
else np.random.randint(0, 40, 169343) # random labels
)

num_edges = 1166243
num_nodes = 169343

assert connectivity_data.shape == (num_edges, 2)
assert node_data.shape == (num_nodes, 128)


def get_train_sample(index):
    # Return the whole graph as a single flattened sample:
    # [node features | source indices | target indices | labels].
    # Assuming edges.npy stores one (source, target) pair per row, the edge
    # list is transposed so all source indices precede all target indices,
    # matching the slice points used in GCN.create_model.
    return np.concatenate(
        (
            node_data.flatten(),
            connectivity_data.T.flatten(),
            labels_data.flatten(),
        )
    ).astype(np.float32)  # keep a consistent float32 dtype for the data reader


def sample_dims():
    # Total length of the flattened sample returned by get_train_sample
    return (node_data.size + connectivity_data.size + labels_data.size,)


def num_train_samples():
return 1
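

# Quick consistency check (sketch): the flattened sample should match the
# advertised sample dimensions, e.g.
#
#   assert get_train_sample(0).shape[0] == sample_dims()[0]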
91 changes: 91 additions & 0 deletions applications/graph/NodePropPrediction/main.py
@@ -0,0 +1,91 @@
from dataset_wrapper import DATASET_CONFIG, make_data_reader
from GCN import create_model

import lbann
import lbann.contrib.launcher
import lbann.contrib.args

import argparse

desc = "Training a Graph Convolutional Model using LBANN"
parser = argparse.ArgumentParser(description=desc)

lbann.contrib.args.add_scheduler_arguments(parser, "GNN")
lbann.contrib.args.add_optimizer_arguments(parser)

parser.add_argument(
"--num-epochs",
action="store",
default=100,
type=int,
help="number of epochs (deafult: 100)",
metavar="NUM",
)

parser.add_argument(
"--model",
action="store",
default="GCN",
type=str,
help="The type of model to use",
metavar="NAME",
)

parser.add_argument(
"--dataset",
action="store",
default="ARXIV",
type=str,
help="The dataset to use",
metavar="NAME",
)

parser.add_argument(
"--latent-dim",
action="store",
default=16,
type=int,
help="The latent dimension of the model",
metavar="NUM",
)

parser.add_argument(
"--num-layers",
action="store",
default=3,
type=int,
help="The number of layers in the model",
metavar="NUM",
)


SUPPORTED_MODELS = ["GCN", "GAT"]
SUPPORTED_DATASETS = ["ARXIV", "PRODUCTS", "MAG240M"]


def main():
args = parser.parse_args()

kwargs = lbann.contrib.args.get_scheduler_kwargs(args)

num_epochs = args.num_epochs
mini_batch_size = 1
job_name = args.job_name
model_arch = args.model
dataset = args.dataset

if model_arch not in SUPPORTED_MODELS:
raise ValueError(
f"Model {model_arch} not supported. Supported models are {SUPPORTED_MODELS}"
)

if dataset not in SUPPORTED_DATASETS:
raise ValueError(
f"Dataset {dataset} not supported. Supported datasets are {SUPPORTED_DATASETS}"
)
dataset_config = DATASET_CONFIG[dataset]
num_nodes = dataset_config["num_nodes"]
num_edges = dataset_config["num_edges"]
input_features = dataset_config["input_features"]


    # Build the remaining experiment components and launch the job.
    # The parsed latent dimension serves as the per-layer output width, since
    # create_model uses a single width for every layer.
    model = create_model(
        num_nodes,
        num_edges,
        input_features,
        args.latent_dim,
        num_layers=args.num_layers,
        num_epochs=num_epochs,
    )
    trainer = lbann.Trainer(mini_batch_size=mini_batch_size)
    data_reader = make_data_reader(dataset)
    optimizer = lbann.SGD(learn_rate=0.01, momentum=0.0)

    lbann.contrib.launcher.run(
        trainer,
        model,
        data_reader,
        optimizer,
        job_name=job_name,
        **kwargs,
    )


if __name__ == "__main__":
    main()
