Commit

Added GAT implementation
szaman19 committed Oct 26, 2023
1 parent 02b7f79 commit f5e43aa
Showing 1 changed file with 164 additions and 0 deletions.
164 changes: 164 additions & 0 deletions applications/graph/NodePropPrediction/GAT.py
@@ -0,0 +1,164 @@
import math

import lbann
from lbann.modules import Module, ChannelwiseFullyConnectedModule, ConvolutionModule


def ContractHeads(lbann_graph_layer, shape):
"""
A utility function that contracts the rows of a (N, M, H) matrix to an (N, M) matrix using grouped 2D convolution.
The contration computes the average along the first dimension so the output is scaled by 1 / H.
Args:
lbann_graph_layer (layer): Graph layer tensor with shape (N, M, H)
shape ((int, int, int)): Shape of graph layer tensor
Returns:
(Layer): Contracted and rescaled output with shape (N, M)
"""
num_nodes, output_channels, num_heads = shape
weights = lbann.Weights(
initializer=lbann.ConstantInitializer(value=1 / num_heads),
optimizer=lbann.NoOptimizer(),
)
kernel_shape = (1, num_heads)
    contraction = lbann.Convolution(
        lbann_graph_layer,
        num_dims=2,
        output_channels=num_nodes,
        kernel_size=kernel_shape,
        stride=1,
        padding=0,
        groups=num_nodes,
        has_bias=False,
        weights=weights,
    )
output = lbann.Reshape(contraction, dims=[num_nodes, output_channels])
return output
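
# A minimal sanity check of ContractHeads' semantics (hypothetical shapes):
# for an input x of shape (N=4, M=8, H=2), the function above is equivalent to
#   output[n, m] = (x[n, m, 0] + x[n, m, 1]) / 2,
# i.e. a mean over the head dimension, implemented as a grouped convolution so
# that each of the N rows is contracted by its own fixed averaging kernel.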


class GAT(Module):
"""Graph Attention Network layer. For kernel details, see:
https://arxiv.org/abs/1710.10903
"""

global_count = 0

def __init__(
self,
input_channels,
output_channels,
num_nodes,
num_edges,
num_heads=1,
name=None,
):
"""Initialize GatedGraph layer
Args:
input_channels (int): The size of the input node features
output_channels (int): The output size of the node features
num_nodes (int): Number of vertices in the graph
num_edges (int): Number of edges in the graph
num_heads (int): Number of attention heads (default: 1)
name (str): Name of the layers and prefix to use for the layers.
data_layout (str): Data layout (default: data parallel)
"""
super().__init__()

# Add Name for the components for the layer
GAT.global_count += 1
self.name = name if name else "GAT_{}".format(GAT.global_count)
# Add variables
self.output_channel_size = output_channels
self.input_channel_size = input_channels
self.num_nodes = num_nodes
self.num_edges = num_edges
self.num_heads = num_heads

weights = lbann.Weights(
initializer=lbann.UniformInitializer(
min=-1 / (math.sqrt(output_channels)),
max=1 / (math.sqrt(output_channels)),
)
)
        # Shared linear transform W, applied to every node's feature row
        self.W_nn = ChannelwiseFullyConnectedModule(
            self.output_channel_size * num_heads,
            bias=False,
            weights=[weights],
            name=f"{self.name}_nn_1",
        )

        # Attention vector `a`, realized as a grouped convolution over the
        # per-edge message tensor of shape (num_edges, 2 * output_channels,
        # num_heads), mirroring the grouped-convolution pattern in ContractHeads
        self.a_vec = ConvolutionModule(
            num_dims=2,
            out_channels=self.num_edges,
            kernel_size=[2 * self.output_channel_size, 1],
            groups=self.num_edges,
            bias=False,
            name=f"{self.name}_nn_2",
        )

def forward(
self, node_feature_mat, source_indices, target_indices, reduction="concat"
):
"""Call GATGraphConv
Args:
node_feature_mat (Layer): Node feature matrix with the shape of (num_nodes, input_channels)
source_indices (Layer): Source node indices of the edges with shape (num_edges)
target_indices (Layer): Target node indices of the edges with shape (num_edges)
reduction (string: [concat| average]): The type of reductions to use for multiple heads
Returns:
(Layer) : The output after kernel ops. The shape of the layer is
(num_nodes, num_heads * num_output_channels) if reduction is "concat"
(num_nodes, num_output_channels) if reduction is "average"
"""
# (N x [self.output_channel * self.num_heads])
transform_node_features = self.W_nn(
node_feature_mat, name=f"{self.name}_transform"
)
# (E x [self.output_channel * self.num_heads])
e_i = lbann.Gather(transform_node_features, source_indices, axis=0)
e_j = lbann.Gather(transform_node_features, target_indices, axis=0)
# (E x self.output_channel x self.num_heads)
e_i = lbann.Reshape(
e_i, dims=[self.num_edges, self.output_channel_size, self.num_heads]
)
e_j = lbann.Reshape(
e_j, dims=[self.num_edges, self.output_channel_size, self.num_heads]
)
# (E x 2 * self.output_channel x self.num_heads)
messages = lbann.Concatenation([e_i, e_j], axis=1)
# (E x self.num_heads)
m_ij = lbann.Reshape(
self.a_vec(messages), dims=[self.num_edges, self.num_heads]
)
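        # exp(LeakyReLU(.)) gives the softmax numerator for each edge score;
        # the denominators are accumulated by the Scatter/Gather pair below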
m_ij = lbann.ExpOperator(lbann.LeakyRelu(m_ij, negative_slope=0.02))
        # (N x self.num_heads): softmax denominators, summed over the incoming
        # edges of each target node
        contraction = lbann.Scatter(
            m_ij,
            target_indices,
            dims=[self.num_nodes, self.num_heads],
            axis=0,
        )
        # (E x self.num_heads): denominator gathered back onto each edge
        normalize = lbann.Gather(contraction, target_indices, axis=0)
        # (E x self.num_heads): normalized attention coefficients
        alpha_ij = lbann.DivideOperator(m_ij, normalize)
        # (E x 1 x self.num_heads)
        alpha_ij = lbann.Reshape(alpha_ij, dims=[self.num_edges, 1, self.num_heads])
        # (E x self.output_channel_size x self.num_heads)
        alpha_ij = lbann.Tessellate(
            alpha_ij, dims=[self.num_edges, self.output_channel_size, self.num_heads]
        )

        # Weight the transformed source-node features by their attention
        # coefficients, then sum the weighted messages at the target nodes
        h_ij = lbann.MultiplyOperator(alpha_ij, e_i)

        h_i = lbann.Scatter(
            h_ij,
            target_indices,
            dims=[self.num_nodes, self.output_channel_size, self.num_heads],
            axis=0,
        )

        if reduction.lower() == "concat":
            node_feature_mat = lbann.Reshape(
                h_i, dims=[self.num_nodes, self.output_channel_size * self.num_heads]
            )
elif reduction.lower() == "average":
node_feature_mat = ContractHeads(
h_i, (self.num_nodes, self.output_channel_size, self.num_heads)
)
else:
raise ValueError("Expected reduction arguments are: concat or average")

return node_feature_mat
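
As a quick orientation, here is a minimal usage sketch. The graph sizes, feature widths, and input wiring below are hypothetical placeholders, not part of this commit; the sketch only illustrates how the module above is meant to be called.

import lbann
from GAT import GAT

# Hypothetical problem sizes
num_nodes, num_edges = 100, 400
in_channels, out_channels, num_heads = 32, 16, 4

# Placeholder input wiring: a flat sample assumed to hold the node-feature
# matrix followed by the two edge-index vectors
sample = lbann.Input(data_field="samples")
sliced = lbann.Slice(
    sample,
    slice_points=[
        0,
        num_nodes * in_channels,
        num_nodes * in_channels + num_edges,
        num_nodes * in_channels + 2 * num_edges,
    ],
)
# Each layer that references the Slice consumes the next partition in order
node_features = lbann.Reshape(sliced, dims=[num_nodes, in_channels])
source_indices = lbann.Identity(sliced)
target_indices = lbann.Identity(sliced)

gat = GAT(in_channels, out_channels, num_nodes, num_edges, num_heads=num_heads)
# (num_nodes, num_heads * out_channels) with the "concat" reduction
node_embeddings = gat(node_features, source_indices, target_indices, reduction="concat")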
