Commit 2fc5df1

use PyG's ParameterDict
tingyu66 committed Dec 5, 2023
1 parent e83fe9d commit 2fc5df1
Showing 3 changed files with 41 additions and 83 deletions.
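The motivation for the change: plain `torch.nn.ParameterDict` requires string keys, which is why the previous code flattened every canonical edge type with `"__".join(...)` and kept a parallel `edge_types_str` list. PyG's `ParameterDict` accepts the `(src, relation, dst)` tuple directly, so the string mangling can go away. A minimal sketch of the difference — the edge type and the `2 * heads * out_channels` size (here `heads=4`, `out_channels=16`) are illustrative, not from the commit:

```python
import torch
from torch_geometric.nn.parameter_dict import ParameterDict

edge_type = ("author", "writes", "paper")  # hypothetical canonical edge type

# Before: tuple keys flattened into strings for torch.nn.ParameterDict.
before = torch.nn.ParameterDict(
    {"__".join(edge_type): torch.nn.Parameter(torch.empty(2 * 4 * 16))}
)

# After: PyG's ParameterDict mangles tuple keys internally, so the module
# can be indexed with the canonical edge type itself.
after = ParameterDict(
    {edge_type: torch.nn.Parameter(torch.empty(2 * 4 * 16))}
)

assert after[edge_type].shape == before["author__writes__paper"].shape
```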
53 changes: 26 additions & 27 deletions python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py
```diff
@@ -84,7 +84,6 @@ def __init__(
 
         self.node_types = node_types
         self.edge_types = edge_types
-        self.edge_types_str = ["__".join(etype) for etype in self.edge_types]
         self.num_heads = heads
         self.concat_heads = concat
 
@@ -95,27 +94,28 @@ def __init__(
 
         lin_weights = dict.fromkeys(self.node_types)
 
-        attn_weights = dict.fromkeys(self.edge_types_str)
+        attn_weights = dict.fromkeys(self.edge_types)
 
-        biases = dict.fromkeys(self.edge_types_str)
+        biases = dict.fromkeys(self.edge_types)
 
+        ParameterDict = torch_geometric.nn.parameter_dict.ParameterDict
+
         for edge_type in self.edge_types:
             src_type, _, dst_type = edge_type
-            etype_str = "__".join(edge_type)
-            self.relations_per_ntype[src_type][0].append(etype_str)
+            self.relations_per_ntype[src_type][0].append(edge_type)
             if src_type != dst_type:
-                self.relations_per_ntype[dst_type][1].append(etype_str)
+                self.relations_per_ntype[dst_type][1].append(edge_type)
 
-            attn_weights[etype_str] = torch.empty(
+            attn_weights[edge_type] = torch.empty(
                 2 * self.num_heads * self.out_channels
             )
 
             if bias and concat:
-                biases[etype_str] = torch.empty(self.num_heads * out_channels)
+                biases[edge_type] = torch.empty(self.num_heads * out_channels)
             elif bias:
-                biases[etype_str] = torch.empty(out_channels)
+                biases[edge_type] = torch.empty(out_channels)
             else:
-                biases[etype_str] = None
+                biases[edge_type] = None
 
         for ntype in self.node_types:
             n_src_rel = len(self.relations_per_ntype[ntype][0])
@@ -126,11 +126,11 @@ def __init__(
                 (n_rel * self.num_heads * self.out_channels, self.in_channels[ntype])
             )
 
-        self.lin_weights = nn.ParameterDict(lin_weights)
-        self.attn_weights = nn.ParameterDict(attn_weights)
+        self.lin_weights = ParameterDict(lin_weights)
+        self.attn_weights = ParameterDict(attn_weights)
 
         if bias:
-            self.bias = nn.ParameterDict(biases)
+            self.bias = ParameterDict(biases)
         else:
             self.register_parameter("bias", None)
 
@@ -159,8 +159,8 @@ def split_tensors(
         x_dst_dict : dict[str, torch.Tensor]
             A dictionary to hold destination node feature for each relation graph.
         """
-        x_src_dict = dict.fromkeys(self.edge_types_str)
-        x_dst_dict = dict.fromkeys(self.edge_types_str)
+        x_src_dict = dict.fromkeys(self.edge_types)
+        x_dst_dict = dict.fromkeys(self.edge_types)
 
         for ntype, t in x_fused_dict.items():
             n_src_rel = len(self.relations_per_ntype[ntype][0])
@@ -182,24 +182,24 @@ def reset_parameters(self, seed: Optional[int] = None):
 
         w_src, w_dst = self.split_tensors(self.lin_weights, dim=0)
 
-        for _, edge_type in enumerate(self.edge_types):
+        for edge_type in self.edge_types:
             src_type, _, dst_type = edge_type
-            etype_str = "__".join(edge_type)
 
             # lin_src
-            torch_geometric.nn.inits.glorot(w_src[etype_str])
+            torch_geometric.nn.inits.glorot(w_src[edge_type])
 
             # lin_dst
             if src_type != dst_type:
-                torch_geometric.nn.inits.glorot(w_dst[etype_str])
+                torch_geometric.nn.inits.glorot(w_dst[edge_type])
 
             # attn_weights
             torch_geometric.nn.inits.glorot(
-                self.attn_weights[etype_str].view(-1, self.num_heads, self.out_channels)
+                self.attn_weights[edge_type].view(-1, self.num_heads, self.out_channels)
             )
 
             # bias
             if self.bias is not None:
-                torch_geometric.nn.inits.zeros(self.bias[etype_str])
+                torch_geometric.nn.inits.zeros(self.bias[edge_type])
 
     def forward(
         self,
@@ -217,7 +217,6 @@ def forward(
 
         for edge_type, edge_index in edge_index_dict.items():
             src_type, _, dst_type = edge_type
-            etype_str = "__".join(edge_type)
 
             csc = BaseConv.to_csc(
                 edge_index, (x_dict[src_type].size(0), x_dict[dst_type].size(0))
@@ -229,8 +228,8 @@ def forward(
                     bipartite=False,
                 )
                 out = mha_gat_n2n(
-                    x_src_dict[etype_str],
-                    self.attn_weights[etype_str],
+                    x_src_dict[edge_type],
+                    self.attn_weights[edge_type],
                     graph,
                     num_heads=self.num_heads,
                     activation="LeakyReLU",
@@ -244,8 +243,8 @@ def forward(
                     bipartite=True,
                 )
                 out = mha_gat_n2n(
-                    (x_src_dict[etype_str], x_dst_dict[etype_str]),
-                    self.attn_weights[etype_str],
+                    (x_src_dict[edge_type], x_dst_dict[edge_type]),
+                    self.attn_weights[edge_type],
                     graph,
                     num_heads=self.num_heads,
                     activation="LeakyReLU",
@@ -254,7 +253,7 @@ def forward(
                 )
 
             if self.bias is not None:
-                out = out + self.bias[etype_str]
+                out = out + self.bias[edge_type]
 
             out_dict[dst_type].append(out)
 
```
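A detail worth noting in the `reset_parameters` hunk above: each relation's attention vector has length `2 * heads * out_channels`, and Glorot initialization is applied through a `(2, heads, out_channels)` view so the source and destination halves see the same fan-in/fan-out as PyG's separate `att_src`/`att_dst` parameters. A sketch of that idiom, with made-up `heads` and `out_channels` values:

```python
import math
import torch
from torch_geometric.nn.inits import glorot

heads, out_channels = 4, 16
attn = torch.nn.Parameter(torch.empty(2 * heads * out_channels))

# glorot() reads fan-in/fan-out from the last two dims, so a (2, heads,
# out_channels) view initializes both halves like att_src / att_dst.
# The view shares storage, so the in-place init writes through to attn.
glorot(attn.view(-1, heads, out_channels))

# glorot samples from U(-a, a) with a = sqrt(6 / (heads + out_channels)).
bound = math.sqrt(6.0 / (heads + out_channels))
assert attn.abs().max().item() <= bound
```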
35 changes: 15 additions & 20 deletions python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py
```diff
@@ -15,7 +15,6 @@
 
 from cugraph_pyg.nn import HeteroGATConv as CuGraphHeteroGATConv
 from cugraph.utilities.utils import import_optional, MissingModule
-from utils import convert_edge_type_key
 
 torch = import_optional("torch")
 torch_geometric = import_optional("torch_geometric")
@@ -67,18 +66,17 @@ def test_hetero_gat_conv_equality(sample_pyg_hetero_data, aggr, heads):
     # copy over linear and attention weights
     w_src, w_dst = conv2.split_tensors(conv2.lin_weights, dim=0)
     with torch.no_grad():
-        for etype_str in conv2.edge_types_str:
-            src_t, _, dst_t = etype_str.split("__")
-            pyg_etype_str = convert_edge_type_key(etype_str)
-            w_src[etype_str][:, :] = conv1.convs[pyg_etype_str].lin_src.weight[:, :]
-            if w_dst[etype_str] is not None:
-                w_dst[etype_str][:, :] = conv1.convs[pyg_etype_str].lin_dst.weight[:, :]
+        for edge_type in conv2.edge_types:
+            src_t, _, dst_t = edge_type
+            w_src[edge_type][:, :] = conv1.convs[edge_type].lin_src.weight[:, :]
+            if w_dst[edge_type] is not None:
+                w_dst[edge_type][:, :] = conv1.convs[edge_type].lin_dst.weight[:, :]
 
-            conv2.attn_weights[etype_str][: heads * out_channels] = conv1.convs[
-                pyg_etype_str
+            conv2.attn_weights[edge_type][: heads * out_channels] = conv1.convs[
+                edge_type
             ].att_src.data.flatten()
-            conv2.attn_weights[etype_str][heads * out_channels :] = conv1.convs[
-                pyg_etype_str
+            conv2.attn_weights[edge_type][heads * out_channels :] = conv1.convs[
+                edge_type
             ].att_dst.data.flatten()
 
     out1 = conv1(data.x_dict, data.edge_index_dict)
@@ -98,16 +96,15 @@ def test_hetero_gat_conv_equality(sample_pyg_hetero_data, aggr, heads):
 
     # check gradient w.r.t attention weights
    out_dim = heads * out_channels
-    for etype_str in conv2.edge_types_str:
-        pyg_etype_str = convert_edge_type_key(etype_str)
+    for edge_type in conv2.edge_types:
         assert torch.allclose(
-            conv1.convs[pyg_etype_str].att_src.grad.flatten(),
-            conv2.attn_weights[etype_str].grad[:out_dim],
+            conv1.convs[edge_type].att_src.grad.flatten(),
+            conv2.attn_weights[edge_type].grad[:out_dim],
             atol=ATOL,
         )
         assert torch.allclose(
-            conv1.convs[pyg_etype_str].att_dst.grad.flatten(),
-            conv2.attn_weights[etype_str].grad[out_dim:],
+            conv1.convs[edge_type].att_dst.grad.flatten(),
+            conv2.attn_weights[edge_type].grad[out_dim:],
             atol=ATOL,
         )
 
@@ -116,10 +113,8 @@ def test_hetero_gat_conv_equality(sample_pyg_hetero_data, aggr, heads):
     for node_t, (rels_as_src, rels_as_dst) in conv2.relations_per_ntype.items():
         grad_list = []
         for rel_t in rels_as_src:
-            rel_t = convert_edge_type_key(rel_t)
             grad_list.append(conv1.convs[rel_t].lin_src.weight.grad.clone())
         for rel_t in rels_as_dst:
-            rel_t = convert_edge_type_key(rel_t)
             grad_list.append(conv1.convs[rel_t].lin_dst.weight.grad.clone())
         assert len(grad_list) > 0
         grad_lin_weights_ref[node_t] = torch.vstack(grad_list)
```
36 changes: 0 additions & 36 deletions python/cugraph-pyg/cugraph_pyg/tests/nn/utils.py

This file was deleted.
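The deleted helper, `convert_edge_type_key`, apparently existed only to translate this module's `"__"`-joined keys into the key format used by `torch_geometric.nn.HeteroConv`'s module dict. With tuple keys on both sides, the test can index both modules the same way, so the helper goes. A sketch of the lookup the updated test now relies on — the schema and channel sizes are invented, and this assumes a PyG version whose `HeteroConv` container accepts tuple lookups (which the updated test depends on):

```python
from torch_geometric.nn import GATConv, HeteroConv

edge_type = ("author", "writes", "paper")  # hypothetical schema

conv1 = HeteroConv(
    {edge_type: GATConv((32, 64), 16, heads=4, add_self_loops=False)},
    aggr="sum",
)

# The sub-module is reachable with the canonical tuple, exactly as in
# conv1.convs[edge_type] in the updated test.
sub = conv1.convs[edge_type]
print(sub.lin_src.weight.shape)  # (heads * out_channels, in_channels_src) = (64, 32)
```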
