diff --git a/applications/graph/NodePropPrediction/GAT.py b/applications/graph/NodePropPrediction/GAT.py
new file mode 100644
index 00000000000..2faf4a94b58
--- /dev/null
+++ b/applications/graph/NodePropPrediction/GAT.py
@@ -0,0 +1,179 @@
+import lbann
+from lbann.modules import Module, ChannelwiseFullyConnectedModule, ConvolutionModule
+import lbann.modules
+import math
+
+
+def ContractHeads(lbann_graph_layer, shape):
+    """
+    A utility function that contracts an (N, M, H) tensor to an (N, M) matrix
+    using grouped 2D convolution. The contraction computes the average along
+    the last dimension, so the output is scaled by 1 / H.
+
+    Args:
+        lbann_graph_layer (Layer): Graph layer tensor with shape (N, M, H)
+
+        shape ((int, int, int)): Shape of the graph layer tensor
+
+    Returns:
+        (Layer): Contracted and rescaled output with shape (N, M)
+    """
+    num_nodes, output_channels, num_heads = shape
+    weights = lbann.Weights(
+        initializer=lbann.ConstantInitializer(value=1 / num_heads),
+        optimizer=lbann.NoOptimizer(),
+    )
+    kernel_shape = (1, num_heads)
+    contraction = lbann.Convolution(
+        lbann_graph_layer,
+        num_dims=2,
+        output_channels=num_nodes,
+        kernel_size=kernel_shape,
+        stride=1,
+        padding=0,
+        groups=num_nodes,
+        has_bias=False,
+        weights=weights,
+    )
+    output = lbann.Reshape(contraction, dims=[num_nodes, output_channels])
+    return output
+
+
+class GAT(Module):
+    """Graph Attention Network layer. For kernel details, see:
+
+    https://arxiv.org/abs/1710.10903
+
+    """
+
+    global_count = 0
+
+    def __init__(
+        self,
+        input_channels,
+        output_channels,
+        num_nodes,
+        num_edges,
+        num_heads=1,
+        name=None,
+    ):
+        """Initialize the GAT layer
+        Args:
+            input_channels (int): The size of the input node features
+            output_channels (int): The output size of the node features
+            num_nodes (int): Number of vertices in the graph
+            num_edges (int): Number of edges in the graph
+            num_heads (int): Number of attention heads (default: 1)
+            name (str): Name and prefix to use for the layers
+        """
+        super().__init__()
+
+        # Add a name for the components of the layer
+        GAT.global_count += 1
+        self.name = name if name else "GAT_{}".format(GAT.global_count)
+        # Add variables
+        self.output_channel_size = output_channels
+        self.input_channel_size = input_channels
+        self.num_nodes = num_nodes
+        self.num_edges = num_edges
+        self.num_heads = num_heads
+
+        weights = lbann.Weights(
+            initializer=lbann.UniformInitializer(
+                min=-1 / (math.sqrt(output_channels)),
+                max=1 / (math.sqrt(output_channels)),
+            )
+        )
+        self.W_nn = ChannelwiseFullyConnectedModule(
+            self.output_channel_size * num_heads,
+            bias=False,
+            weights=[weights],
+            name=f"{self.name}_nn_{1}",
+        )
+
+        # Per-edge attention scoring: a grouped convolution reduces the
+        # concatenated feature dimension (2 * output_channels) to a single
+        # score per edge and per head
+        self.a_vec = ConvolutionModule(
+            num_dims=2,
+            out_channels=self.num_edges,
+            kernel_size=[2 * self.output_channel_size, 1],
+            groups=self.num_edges,
+            bias=False,
+            name=f"{self.name}_nn_{2}",
+        )
+
+    def forward(
+        self, node_feature_mat, source_indices, target_indices, reduction="concat"
+    ):
+        """Apply the GAT layer
+        Args:
+            node_feature_mat (Layer): Node feature matrix with the shape of (num_nodes, input_channels)
+            source_indices (Layer): Source node indices of the edges with shape (num_edges)
+            target_indices (Layer): Target node indices of the edges with shape (num_edges)
+            reduction (str: ["concat" | "average"]): The type of reduction to use across heads
+        Returns:
+            (Layer): The output after kernel ops.
+                The shape of the layer is
+                (num_nodes, num_heads * output_channels) if reduction is "concat" and
+                (num_nodes, output_channels) if reduction is "average"
+        """
+        # (N x [self.output_channel_size * self.num_heads])
+        transform_node_features = self.W_nn(
+            node_feature_mat, name=f"{self.name}_transform"
+        )
+        # (E x [self.output_channel_size * self.num_heads])
+        e_i = lbann.Gather(transform_node_features, source_indices, axis=0)
+        e_j = lbann.Gather(transform_node_features, target_indices, axis=0)
+        # (E x self.output_channel_size x self.num_heads)
+        e_i = lbann.Reshape(
+            e_i, dims=[self.num_edges, self.output_channel_size, self.num_heads]
+        )
+        e_j = lbann.Reshape(
+            e_j, dims=[self.num_edges, self.output_channel_size, self.num_heads]
+        )
+        # (E x [2 * self.output_channel_size] x self.num_heads)
+        messages = lbann.Concatenation([e_i, e_j], axis=1)
+        # (E x self.num_heads)
+        m_ij = lbann.Reshape(
+            self.a_vec(messages), dims=[self.num_edges, self.num_heads]
+        )
+        # LeakyReLU slope of 0.2 follows the GAT paper
+        m_ij = lbann.ExpOperator(lbann.LeakyRelu(m_ij, negative_slope=0.2))
+        # Softmax denominator: per-node sums of the exponentiated scores
+        # (N x self.num_heads)
+        contraction = lbann.Scatter(
+            m_ij, target_indices, axis=0, dims=[self.num_nodes, self.num_heads]
+        )
+        # (E x self.num_heads): each edge's denominator, gathered from its target node
+        normalize = lbann.Gather(contraction, target_indices, axis=0)
+        # (E x self.num_heads): normalized attention coefficients
+        alpha_ij = lbann.DivideOperator(m_ij, normalize)
+        # (E x 1 x self.num_heads) -> (E x self.output_channel_size x self.num_heads)
+        alpha_ij = lbann.Reshape(alpha_ij, dims=[self.num_edges, 1, self.num_heads])
+        alpha_ij = lbann.Tessellate(
+            alpha_ij, dims=[self.num_edges, self.output_channel_size, self.num_heads]
+        )
+        # Weight the transformed neighbor features by the attention coefficients
+        h_ij = lbann.MultiplyOperator(alpha_ij, e_j)
+        # Aggregate the weighted messages at the target nodes
+        # (N x self.output_channel_size x self.num_heads)
+        h_i = lbann.Scatter(
+            h_ij,
+            target_indices,
+            axis=0,
+            dims=[self.num_nodes, self.output_channel_size, self.num_heads],
+        )
+
+        if reduction.lower() == "concat":
+            node_feature_mat = lbann.Reshape(
+                h_i, dims=[self.num_nodes, self.output_channel_size * self.num_heads]
+            )
+        elif reduction.lower() == "average":
+            node_feature_mat = ContractHeads(
+                h_i, (self.num_nodes, self.output_channel_size, self.num_heads)
+            )
+        else:
+            raise ValueError("Expected reduction argument to be 'concat' or 'average'")
+
+        return node_feature_mat
diff --git a/applications/graph/NodePropPrediction/GCN.py b/applications/graph/NodePropPrediction/GCN.py
new file mode 100644
index 00000000000..f77649e252e
--- /dev/null
+++ b/applications/graph/NodePropPrediction/GCN.py
@@ -0,0 +1,102 @@
+import lbann
+from lbann.modules import Module, ChannelwiseFullyConnectedModule, ConvolutionModule
+import lbann.modules
+
+
+class GCN(Module):
+    """
+    Graph convolutional kernel
+    """
+
+    def __init__(
+        self,
+        num_nodes,
+        num_edges,
+        input_features,
+        output_features,
+        activation=lbann.Relu,
+        distconv_enabled=True,
+        num_groups=4,
+    ):
+        super().__init__()
+        self._input_dims = input_features
+        self._output_dims = output_features
+        self._num_nodes = num_nodes
+        self._num_edges = num_edges
+        self._activation = activation
+        # Output feature-matrix dims used by the scatter-based aggregation
+        self._ft_dims = [num_nodes, output_features]
+        # distconv_enabled and num_groups are accepted but currently unused
+
+    def forward(self, node_features, source_indices, target_indices):
+        x = lbann.Gather(node_features, target_indices, axis=0)
+        x = lbann.ChannelwiseFullyConnected(x, output_channel_dims=[self._output_dims])
+        x = self._activation(x)
+        x = lbann.Scatter(x, source_indices, dims=self._ft_dims)
+        return x
+
+
+def create_model(
+    num_nodes, num_edges, input_features, output_features, num_layers=3, num_epochs=1
+):
+    """
+    Create a GCN model
+    """
+    # Layer graph
+    input_ = lbann.Input(data_field="samples")
+    split_indices = [0, num_nodes * input_features]
+    split_indices += [split_indices[-1] + num_edges]
+    split_indices += [split_indices[-1] + num_edges]
+    split_indices += [split_indices[-1] + num_nodes]
+
+    # Slice the flat input sample into node features, edge indices, and labels
+    sliced_input = lbann.Slice(input_, slice_points=split_indices)
+
+    node_features = lbann.Reshape(
+        lbann.Identity(sliced_input),
+        dims=[num_nodes, input_features],
+    )
+
+    source_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges])
+    target_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges])
+    label = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_nodes])
+
+    x = GCN(
+        num_nodes,
+        num_edges,
+        input_features,
+        output_features,
+        activation=lbann.Relu,
+        distconv_enabled=False,
+        num_groups=4,
+    )(node_features, source_indices, target_indices)
+
+    # Subsequent layers consume the previous layer's output features
+    for _ in range(num_layers - 1):
+        x = GCN(
+            num_nodes,
+            num_edges,
+            output_features,
+            output_features,
+            activation=lbann.Relu,
+            distconv_enabled=False,
+            num_groups=4,
+        )(x, source_indices, target_indices)
+
+    # Loss function
+    loss = lbann.CrossEntropy([x, label])
+
+    # Metrics
+    acc = lbann.CategoricalAccuracy([x, label])
+
+    # Callbacks
+    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
+
+    # Construct model
+    return lbann.Model(
+        num_epochs=num_epochs,
+        layers=lbann.traverse_layer_graph(input_),
+        objective_function=loss,
+        metrics=[acc],
+        callbacks=callbacks,
+    )
diff --git a/applications/graph/NodePropPrediction/dataset_wrapper.py b/applications/graph/NodePropPrediction/dataset_wrapper.py
new file mode 100644
index 00000000000..06f3c7ee349
--- /dev/null
+++ b/applications/graph/NodePropPrediction/dataset_wrapper.py
@@ -0,0 +1,30 @@
+import lbann
+import os.path as osp
+
+
+current_dir = osp.dirname(osp.realpath(__file__))
+
+DATASET_CONFIG = {
+    "ARXIV": {
+        "num_nodes": 169343,
+        "num_edges": 1166243,
+        "input_features": 128,
+    }
+}
+
+
+def make_data_reader(dataset):
+    message = lbann.reader_pb2.DataReader()
+    # The reader attributes live on the nested Reader message, not on DataReader
+    reader = message.reader.add()
+    reader.name = "python"
+    reader.role = "train"
+    reader.shuffle = True
+    reader.percent_of_data_to_use = 1.0
+    # Module file names are lowercase (e.g. datasets/arxiv_dataset.py)
+    reader.python.module = f"{dataset.lower()}_dataset"
+    reader.python.module_dir = osp.join(current_dir, "datasets")
+    reader.python.sample_function = "get_train_sample"
+    reader.python.num_samples_function = "num_train_samples"
+    reader.python.sample_dims_function = "sample_dims"
+    return message
diff --git a/applications/graph/NodePropPrediction/datasets/arxiv_dataset.py b/applications/graph/NodePropPrediction/datasets/arxiv_dataset.py
new file mode 100644
index 00000000000..e2c2d06e7fe
--- /dev/null
+++ b/applications/graph/NodePropPrediction/datasets/arxiv_dataset.py
@@ -0,0 +1,44 @@
+import numpy as np
+import os
+
+
+# Load the dataset
+
+data_dir = "/p/vast1/lbann/datasets/OpenGraphBenchmarks/dataset/ogbn_arxiv"
+
+connectivity_data = np.load(data_dir + "/edges.npy")
+node_data = (
+    np.load(data_dir + "/node_feats.npy")
+    if os.path.exists(data_dir + "/node_feats.npy")
+    else np.random.rand(169343, 128)  # random node features
+)
+
+labels_data = (
+    np.load(data_dir + "/labels.npy")
+    if os.path.exists(data_dir + "/labels.npy")
+    else np.random.randint(0, 40, 169343)  # random labels
+)
+
+num_edges = 1166243
+num_nodes = 169343
+
+assert connectivity_data.shape == (num_edges, 2)
+assert node_data.shape == (num_nodes, 128)
+
+
+def get_train_sample(index):
+    # Concatenate node features, edge indices, and labels into one flat sample.
+    # NOTE: assumes edges.npy stores [source, target] pairs; the transpose puts
+    # all source indices before all target indices, matching the Slice layout
+    # used in GCN.create_model.
+    return np.concatenate(
+        [node_data.flatten(), connectivity_data.T.flatten(), labels_data.flatten()]
+    )
+
+
+def sample_dims():
+    return (node_data.size + connectivity_data.size + labels_data.size,)
+
+
+def num_train_samples():
+    return 1
diff --git a/applications/graph/NodePropPrediction/main.py b/applications/graph/NodePropPrediction/main.py
new file mode 100644
index 00000000000..605614ccd1b
--- /dev/null
+++ b/applications/graph/NodePropPrediction/main.py
@@ -0,0 +1,113 @@
+from dataset_wrapper import DATASET_CONFIG, make_data_reader
+from GCN import create_model
+import lbann
+import lbann.contrib.launcher
+import lbann.contrib.args
+
+import argparse
+
+desc = "Training a Graph Convolutional Model using LBANN"
+parser = argparse.ArgumentParser(description=desc)
+
+lbann.contrib.args.add_scheduler_arguments(parser, "GNN")
+lbann.contrib.args.add_optimizer_arguments(parser)
+
+parser.add_argument(
+    "--num-epochs",
+    action="store",
+    default=100,
+    type=int,
+    help="number of epochs (default: 100)",
+    metavar="NUM",
+)
+
+parser.add_argument(
+    "--model",
+    action="store",
+    default="GCN",
+    type=str,
+    help="The type of model to use",
+    metavar="NAME",
+)
+
+parser.add_argument(
+    "--dataset",
+    action="store",
+    default="ARXIV",
+    type=str,
+    help="The dataset to use",
+    metavar="NAME",
+)
+
+parser.add_argument(
+    "--latent-dim",
+    action="store",
+    default=16,
+    type=int,
+    help="The latent dimension of the model",
+    metavar="NUM",
+)
+
+parser.add_argument(
+    "--num-layers",
+    action="store",
+    default=3,
+    type=int,
+    help="The number of layers in the model",
+    metavar="NUM",
+)
+
+
+SUPPORTED_MODELS = ["GCN", "GAT"]
+SUPPORTED_DATASETS = ["ARXIV", "PRODUCTS", "MAG240M"]
+
+
+def main():
+    args = parser.parse_args()
+
+    kwargs = lbann.contrib.args.get_scheduler_kwargs(args)
+
+    num_epochs = args.num_epochs
+    # Full-graph training: each sample is the entire graph
+    mini_batch_size = 1
+    job_name = args.job_name
+    model_arch = args.model
+    dataset = args.dataset
+
+    if model_arch not in SUPPORTED_MODELS:
+        raise ValueError(
+            f"Model {model_arch} not supported. Supported models are {SUPPORTED_MODELS}"
+        )
+
+    if dataset not in SUPPORTED_DATASETS:
+        raise ValueError(
+            f"Dataset {dataset} not supported. Supported datasets are {SUPPORTED_DATASETS}"
+        )
+
+    dataset_config = DATASET_CONFIG[dataset]
+    num_nodes = dataset_config["num_nodes"]
+    num_edges = dataset_config["num_edges"]
+    input_features = dataset_config["input_features"]
+
+    # Only the GCN model builder is implemented in this application so far
+    if model_arch != "GCN":
+        raise NotImplementedError(f"No model builder available for {model_arch}")
+
+    model = create_model(
+        num_nodes,
+        num_edges,
+        input_features,
+        args.latent_dim,
+        num_layers=args.num_layers,
+        num_epochs=num_epochs,
+    )
+    data_reader = make_data_reader(dataset)
+    trainer = lbann.Trainer(mini_batch_size=mini_batch_size)
+    optimizer = lbann.SGD(learn_rate=0.01, momentum=0.0)
+    lbann.contrib.launcher.run(
+        trainer, model, data_reader, optimizer, job_name=job_name, **kwargs
+    )
+
+
+if __name__ == "__main__":
+    main()
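
Usage sketch (not part of the diff): main.py currently wires up only the GCN builder. Below is a minimal sketch of how the GAT module could be stacked the same way GCN.create_model stacks its layers, assuming the same (node_features, source_indices, target_indices) wiring produced by the Slice/Reshape block in GCN.py. The helper name stack_gat_layers, the head count, and the "average" reduction (chosen because it keeps the feature width at latent_dim so consecutive layers compose) are illustrative assumptions, not APIs introduced by this PR.

    # Hypothetical helper, for illustration only: stacks GAT layers using the
    # same graph wiring that GCN.create_model builds.
    from GAT import GAT

    def stack_gat_layers(node_features, source_indices, target_indices,
                         num_nodes, num_edges, input_features, latent_dim,
                         num_layers=3, num_heads=4):
        # First layer maps input_features -> latent_dim
        x = GAT(input_features, latent_dim, num_nodes, num_edges,
                num_heads=num_heads)(
            node_features, source_indices, target_indices, reduction="average"
        )
        for _ in range(num_layers - 1):
            # "average" keeps the width at latent_dim, so each subsequent
            # layer's input_channels matches the previous layer's output
            x = GAT(latent_dim, latent_dim, num_nodes, num_edges,
                    num_heads=num_heads)(
                x, source_indices, target_indices, reduction="average"
            )
        return x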