Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Node Property Prediction Models on LBANN #2364

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions applications/graph/NodePropPrediction/GAT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import lbann
from lbann.modules import Module, ChannelwiseFullyConnectedModule, ConvolutionModule
import lbann.modules
import math


def ContractHeads(lbann_graph_layer, shape):
    """Average a (N, M, H) tensor over its last (head) axis.

    The input is fed to a grouped 2D convolution with one group per leading
    index and a (1, H) kernel. The kernel is a constant, non-trainable weight
    whose entries are all 1 / H, so every output entry is the mean of the H
    values it covers.

    Args:
        lbann_graph_layer (Layer): Graph layer tensor with shape (N, M, H)

        shape ((int, int, int)): Shape (N, M, H) of the graph layer tensor

    Returns:
        (Layer): Contracted and rescaled output with shape (N, M)
    """
    num_nodes, output_channels, num_heads = shape

    # Fixed averaging kernel: constant 1 / H and excluded from optimization.
    averaging_weights = lbann.Weights(
        initializer=lbann.ConstantInitializer(value=1 / num_heads),
        optimizer=lbann.NoOptimizer(),
    )

    # The input layer must be the convolution's parent; without it the
    # contraction is disconnected from the tensor it is meant to reduce.
    # NOTE(review): `out_channels` mirrors the keyword ConvolutionModule is
    # called with elsewhere in this file -- confirm against the LBANN version.
    contraction = lbann.Convolution(
        lbann_graph_layer,
        num_dims=2,
        out_channels=num_nodes,
        kernel_size=(1, num_heads),
        stride=1,
        padding=0,
        groups=num_nodes,
        has_bias=False,
        weights=averaging_weights,
    )

    # Drop the collapsed head axis.
    return lbann.Reshape(contraction, dims=[num_nodes, output_channels])


class GAT(Module):
    """Graph Attention Network layer. For kernel details, see:

    https://arxiv.org/abs/1710.10903

    """

    # Number of GAT instances created so far; used for default names.
    global_count = 0

    def __init__(
        self,
        input_channels,
        output_channels,
        num_nodes,
        num_edges,
        num_heads=1,
        name=None,
    ):
        """Initialize GAT layer
        Args:
            input_channels (int): The size of the input node features
            output_channels (int): The output size of the node features
            num_nodes (int): Number of vertices in the graph
            num_edges (int): Number of edges in the graph
            num_heads (int): Number of attention heads (default: 1)
            name (str): Name of the layers and prefix to use for the layers.
        """
        super().__init__()

        # Add Name for the components for the layer
        GAT.global_count += 1
        self.name = name if name else "GAT_{}".format(GAT.global_count)
        # Add variables
        self.output_channel_size = output_channels
        self.input_channel_size = input_channels
        self.num_nodes = num_nodes
        self.num_edges = num_edges
        self.num_heads = num_heads

        bound = 1 / math.sqrt(output_channels)
        weights = lbann.Weights(
            initializer=lbann.UniformInitializer(min=-bound, max=bound)
        )
        # Linear transform producing all heads at once.
        self.W_k = ChannelwiseFullyConnectedModule(
            self.output_channel_size * num_heads,
            bias=False,
            weights=[weights],
            name=f"{self.name}_nn_{1}",
        )

        # Attention scoring. forward() feeds this module a tensor of shape
        # (num_edges, 2 * output_channels, num_heads), so the convolution is
        # 2D (matching the two-element kernel) and is grouped over the edge
        # axis, which is the channel axis of that input.
        # NOTE(review): one group per edge implies one filter per edge rather
        # than a single shared attention vector -- confirm this is intended.
        self.a_vec = ConvolutionModule(
            num_dims=2,
            out_channels=self.num_edges,
            kernel_size=[2 * self.output_channel_size, 1],
            groups=self.num_edges,
            bias=False,
            name=f"{self.name}_nn_{2}",
        )

    def forward(
        self, node_feature_mat, source_indices, target_indices, reduction="concat"
    ):
        """Call GATGraphConv
        Args:
            node_feature_mat (Layer): Node feature matrix with the shape of (num_nodes, input_channels)
            source_indices (Layer): Source node indices of the edges with shape (num_edges)
            target_indices (Layer): Target node indices of the edges with shape (num_edges)
            reduction (string: [concat| average]): The type of reductions to use for multiple heads
        Returns:
            (Layer) : The output after kernel ops. The shape of the layer is
                      (num_nodes, num_heads * num_output_channels) if reduction is "concat"
                      (num_nodes, num_output_channels) if reduction is "average"
        Raises:
            ValueError: If reduction is neither "concat" nor "average"
        """
        reduction = reduction.lower()
        if reduction not in ("concat", "average"):
            raise ValueError("Expected reduction arguments are: concat or average")

        num_nodes = self.num_nodes
        num_edges = self.num_edges
        channels = self.output_channel_size
        heads = self.num_heads
        edge_dims = [num_edges, channels, heads]

        # (N x [self.output_channel * self.num_heads])
        # NOTE(review): the per-call `name=` keyword was dropped on the
        # assumption that the module's forward only takes the input; the
        # module itself was already named in __init__.
        transform_node_features = self.W_k(node_feature_mat)

        # (E x [self.output_channel * self.num_heads])
        e_i = lbann.Gather(transform_node_features, source_indices, axis=0)
        e_j = lbann.Gather(transform_node_features, target_indices, axis=0)
        # (E x self.output_channel x self.num_heads)
        e_i = lbann.Reshape(e_i, dims=edge_dims)
        e_j = lbann.Reshape(e_j, dims=edge_dims)

        # (E x 2 * self.output_channel x self.num_heads)
        messages = lbann.Concatenation([e_i, e_j], axis=1)

        # (E x self.num_heads) unnormalized attention scores
        m_ij = lbann.Reshape(self.a_vec(messages), dims=[num_edges, heads])
        # NOTE(review): the GAT paper uses a LeakyReLU slope of 0.2; 0.02 is
        # kept from the original code -- confirm it is not a typo.
        # NOTE(review): Exp / Divide / Multiply are used as layer
        # constructors here in place of the *Operator classes -- confirm.
        m_ij = lbann.Exp(lbann.LeakyRelu(m_ij, negative_slope=0.02))

        # Softmax denominator. Messages are accumulated at the source node
        # below, so the scores are summed over the same grouping.
        # (N x self.num_heads)
        score_sum = lbann.Scatter(
            m_ij, source_indices, dims=[num_nodes, heads], axis=0
        )
        # (E x self.num_heads)
        score_sum = lbann.Gather(score_sum, source_indices, axis=0)
        alpha_ij = lbann.Divide(m_ij, score_sum)

        # Repeat each per-head coefficient across the channel axis so it can
        # be multiplied element-wise with the messages.
        # (E x self.output_channel_size x self.num_heads)
        alpha_ij = lbann.Tessellate(
            lbann.Reshape(alpha_ij, dims=[num_edges, 1, heads]), dims=edge_dims
        )
        h_ij = lbann.Multiply(alpha_ij, e_j)

        # Accumulate the weighted messages per source node on a 2D view.
        # (N x [self.output_channel * self.num_heads])
        h_ij = lbann.Reshape(h_ij, dims=[num_edges, channels * heads])
        h_i = lbann.Scatter(
            h_ij, source_indices, dims=[num_nodes, channels * heads], axis=0
        )

        if reduction == "concat":
            return lbann.Reshape(h_i, dims=[num_nodes, channels * heads])

        # "average": restore the head axis and contract it.
        node_shape = (num_nodes, channels, heads)
        return ContractHeads(
            lbann.Reshape(h_i, dims=list(node_shape)), node_shape
        )
91 changes: 91 additions & 0 deletions applications/graph/NodePropPrediction/GCN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import lbann
from lbann.modules import Module, ChannelwiseFullyConnectedModule, ConvolutionModule
import lbann.modules


class GCN(Module):
    """
    Graph convolutional kernel

    Gathers node features along the edges, applies a channel-wise fully
    connected transform followed by an activation, and scatters the result
    back onto the nodes.
    """

    def __init__(
        self,
        num_nodes,
        num_edges,
        input_features,
        output_features,
        activation=lbann.Relu,
        distconv_enabled=True,
        num_groups=4,
    ):
        """Initialize GCN layer
        Args:
            num_nodes (int): Number of vertices in the graph
            num_edges (int): Number of edges in the graph
            input_features (int): Size of the input node features
            output_features (int): Size of the output node features
            activation (callable): Layer constructor applied after the linear
                transform; pass None to skip it (default: lbann.Relu)
            distconv_enabled (bool): Stored on the instance; not used by
                forward() at present
            num_groups (int): Stored on the instance; not used by forward()
                at present
        """
        super().__init__()
        self._input_dims = input_features
        self._output_dims = output_features
        self._num_nodes = num_nodes
        self._num_edges = num_edges
        # forward() reads both of these; they were previously never assigned.
        self._activation = activation
        self._ft_dims = [num_nodes, output_features]
        self._distconv_enabled = distconv_enabled
        self._num_groups = num_groups

    def forward(self, node_features, source_indices, target_indices):
        """Apply the graph convolution
        Args:
            node_features (Layer): Node feature matrix, indexed by node along
                axis 0
            source_indices (Layer): Source node indices of the edges
            target_indices (Layer): Target node indices of the edges
        Returns:
            (Layer): Output scattered to dims (num_nodes, output_features)
        """
        # One row per edge, taken from the edge's target node.
        x = lbann.Gather(node_features, target_indices, axis=0)
        x = lbann.ChannelwiseFullyConnected(
            x, output_channel_dims=[self._output_dims]
        )
        if self._activation is not None:
            x = self._activation(x)
        # Accumulate edge rows at their source node, along the same axis the
        # gather used.
        x = lbann.Scatter(x, source_indices, dims=self._ft_dims, axis=0)
        return x


def create_model(num_nodes, num_edges, input_features, output_features, num_layers=3):
    """
    Create a GCN model

    Each sample is one flat vector laid out as
    [node features | source indices | target indices | labels] with sizes
    num_nodes * input_features, num_edges, num_edges and num_nodes.

    Args:
        num_nodes (int): Number of vertices in the graph
        num_edges (int): Number of edges in the graph
        input_features (int): Size of the input node features
        output_features (int): Size of the node features produced by each layer
        num_layers (int): Total number of GCN layers (default: 3)

    Returns:
        (lbann.Model): Model with a cross-entropy objective and a categorical
            accuracy metric
    """
    # Layer graph
    # NOTE(review): `data_field="samples"` assumes an LBANN version whose
    # Input layer takes a data field -- confirm for the targeted version.
    input_ = lbann.Input(data_field="samples")

    segment_sizes = (num_nodes * input_features, num_edges, num_edges, num_nodes)
    split_indices = [0]
    for size in segment_sizes:
        split_indices.append(split_indices[-1] + size)

    # Cut the flat sample at the computed offsets. Each Identity attached to
    # the slice layer consumes the next segment, so creation order matters.
    sample_slice = lbann.Slice(input_, slice_points=split_indices)

    node_features = lbann.Reshape(
        lbann.Identity(sample_slice), dims=[num_nodes, input_features]
    )
    source_indices = lbann.Reshape(lbann.Identity(sample_slice), dims=[num_edges])
    target_indices = lbann.Reshape(lbann.Identity(sample_slice), dims=[num_edges])
    label = lbann.Reshape(lbann.Identity(sample_slice), dims=[num_nodes])

    x = node_features
    layer_input_width = input_features
    for _ in range(max(num_layers, 1)):
        x = GCN(
            num_nodes,
            num_edges,
            layer_input_width,
            output_features,
            activation=lbann.Relu,
            distconv_enabled=False,
            num_groups=4,
        )(x, source_indices, target_indices)
        # Every layer after the first consumes the previous layer's output.
        layer_input_width = output_features

    # Loss function
    # NOTE(review): x is produced per node with output_features values while
    # label is reshaped to (num_nodes); the loss and metric below are handed
    # differently shaped inputs. A softmax over x and a one-hot encoding of
    # the labels are probably still needed -- confirm.
    loss = lbann.CrossEntropy([x, label])

    # Metrics
    acc = lbann.CategoricalAccuracy([x, label])

    # Callbacks
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Construct model
    return lbann.Model(
        num_epochs=1,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=[acc],
        callbacks=callbacks,
    )
27 changes: 27 additions & 0 deletions applications/graph/NodePropPrediction/dataset_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import lbann
import os.path as osp


# Directory containing this file (symlinks resolved).
current_dir = osp.split(osp.realpath(__file__))[0]

# Per-dataset graph sizes, keyed by dataset name.
DATASET_CONFIG = {
    "ARXIV": dict(
        num_nodes=169343,
        num_edges=1166243,
        input_features=128,
    ),
}


def make_data_reader(dataset):
    """Build the data reader message for a Python dataset module.

    Args:
        dataset (str): Dataset name. It is lower-cased to form the module
            name "<dataset>_dataset", which is looked up in the "datasets"
            directory next to this file.

    Returns:
        (lbann.reader_pb2.DataReader): Message holding a single training
            reader backed by the dataset module
    """
    message = lbann.reader_pb2.DataReader()
    # The per-reader settings belong on an entry of the repeated `reader`
    # field, not on the container message itself.
    reader = message.reader.add()
    reader.name = "python"
    reader.role = "train"
    reader.shuffle = True
    # NOTE(review): some LBANN versions may name this field
    # `fraction_of_data_to_use` -- confirm against the targeted version.
    reader.percent_of_data_to_use = 1.0
    # Lower-case so the upper-case DATASET_CONFIG keys resolve to module
    # files such as arxiv_dataset.py.
    reader.python.module = f"{dataset.lower()}_dataset"
    reader.python.module_dir = osp.join(current_dir, "datasets")
    reader.python.sample_function = "get_train_sample"
    reader.python.num_samples_function = "num_train_samples"
    reader.python.sample_dims_function = "sample_dims"
    return message
42 changes: 42 additions & 0 deletions applications/graph/NodePropPrediction/datasets/arxiv_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
import os


# Load the dataset once, at import time.

data_dir = "/p/vast1/lbann/datasets/OpenGraphBenchmarks/dataset/ogbn_arxiv"

# Expected graph sizes. Defined before loading so the random fallbacks and
# the validation below share one source of truth.
num_edges = 1166243
num_nodes = 169343
_num_node_features = 128
_num_classes = 40

_node_feats_path = os.path.join(data_dir, "node_feats.npy")
_labels_path = os.path.join(data_dir, "labels.npy")

# The edge list has no fallback: it must exist on disk.
connectivity_data = np.load(os.path.join(data_dir, "edges.npy"))

node_data = (
    np.load(_node_feats_path)
    if os.path.exists(_node_feats_path)
    else np.random.rand(num_nodes, _num_node_features)  # random node features
)

labels_data = (
    np.load(_labels_path)
    if os.path.exists(_labels_path)
    else np.random.randint(0, _num_classes, num_nodes)  # random labels
)

assert connectivity_data.shape == (num_edges, 2), connectivity_data.shape
assert node_data.shape == (num_nodes, _num_node_features), node_data.shape
# Compare element counts so both (N,) and (N, 1) label arrays are accepted.
assert labels_data.size == num_nodes, labels_data.shape


def get_train_sample(index):
    """Return the whole graph as one flat float32 sample.

    The vector is laid out as
    [node features | edge column 0 | edge column 1 | labels].
    The edge array is flattened column by column so that each endpoint list
    is contiguous, matching the two consecutive num_edges-long index blocks
    that create_model in GCN.py slices out of the sample.

    Args:
        index (int): Sample index; ignored because the dataset consists of a
            single sample.

    Returns:
        (numpy.ndarray): 1D float32 array of length
            node_data.size + connectivity_data.size + labels_data.size
    """
    # NOTE(review): assumes column 0 of the edge array holds the source node
    # and column 1 the target node -- confirm against how edges.npy is built.
    segments = (
        node_data.ravel(),
        connectivity_data.T.ravel(),
        labels_data.ravel(),
    )
    # Concatenate; `+` on arrays would attempt element-wise addition.
    return np.concatenate(segments).astype(np.float32)


def sample_dims():
    """Return the shape of one flat sample as a 1-tuple.

    The length is the combined element count of the node features, the edge
    list and the labels, i.e. the three arrays get_train_sample flattens
    into a single vector.

    Returns:
        ((int,)): Total number of values in a sample
    """
    total = node_data.size + connectivity_data.size + labels_data.size
    return (int(total),)


def num_train_samples():
    """Return the number of training samples, which is always one."""
    sample_count = 1
    return sample_count
Loading
Loading