docstring, allow different in_channels
tingyu66 committed Oct 5, 2023
1 parent 5538418 commit 32f5b12
Showing 1 changed file with 59 additions and 18 deletions.
77 changes: 59 additions & 18 deletions python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py
@@ -25,11 +25,47 @@
 
 
 class HeteroGATConv(BaseConv):
-    r"""Heterogeneous graph."""
+    r"""The graph attentional operator on heterogeneous graphs, where a separate
+    `GATConv` is applied on the homogeneous graph for each edge type. Compared
+    with directly wrapping `GATConv`s with `HeteroConv`, `HeteroGATConv` fuses
+    all the linear transformations associated with each node type into a single
+    GEMM call, to improve performance on GPUs.
+
+    Parameters
+    ----------
+    in_channels : int or Dict[str, int]
+        Size of each input sample of every node type.
+
+    out_channels : int
+        Size of each output sample.
+
+    node_types : List[str]
+        List of node types.
+
+    edge_types : List[Tuple[str, str, str]]
+        List of edge types.
+
+    heads : int, optional (default=1)
+        Number of multi-head attentions.
+
+    concat : bool, optional (default=True)
+        If set to :obj:`False`, the multi-head attentions are averaged instead
+        of concatenated.
+
+    negative_slope : float, optional (default=0.2)
+        LeakyReLU angle of the negative slope.
+
+    bias : bool, optional (default=True)
+        If set to :obj:`False`, the layer will not learn an additive bias.
+
+    aggr : str, optional (default="sum")
+        The aggregation scheme to use for grouping node embeddings generated by
+        different relations. Choose from "sum", "mean", "min", "max".
+    """
 
     def __init__(
         self,
-        in_channels: int,
+        in_channels: int | dict[str, int],
         out_channels: int,
         node_types: list[str],
         edge_types: list[tuple[str, str, str]],
@@ -41,25 +41,27 @@ def __init__(
     ):
         super().__init__()
 
+        if isinstance(in_channels, int):
+            in_channels = dict.fromkeys(node_types, in_channels)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
         self.node_types = node_types
         self.edge_types = edge_types
 
+        self.edge_types_str = ["__".join(etype) for etype in self.edge_types]
         self.num_heads = heads
         self.concat_heads = concat
-        self.in_channels = in_channels
-        self.out_channels = out_channels
 
         self.negative_slope = negative_slope
         self.aggr = aggr
 
-        edge_types_str = ["__".join(etype) for etype in self.edge_types]
-
         self.relations_per_ntype = defaultdict(lambda: ([], []))
 
         lin_weights = dict.fromkeys(self.node_types)
 
-        attn_weights = dict.fromkeys(edge_types_str)
+        attn_weights = dict.fromkeys(self.edge_types_str)
 
-        biases = dict.fromkeys(edge_types_str)
+        biases = dict.fromkeys(self.edge_types_str)
 
         for edge_type in self.edge_types:
             src_type, _, dst_type = edge_type
@@ -85,7 +123,7 @@ def __init__(
             n_rel = n_src_rel + n_dst_rel
 
             lin_weights[ntype] = torch.empty(
-                (n_rel * self.num_heads * self.out_channels, self.in_channels)
+                (n_rel * self.num_heads * self.out_channels, self.in_channels[ntype])
             )
 
         self.lin_weights = nn.ParameterDict(lin_weights)
@@ -98,9 +136,12 @@ def __init__(
 
         self.reset_parameters()
 
-    def split_tensors(self, x_dict: torch.Tensor, dim: int):
-        x_src_dict = {"__".join(etype): None for etype in self.edge_types}
-        x_dst_dict = {"__".join(etype): None for etype in self.edge_types}
+    def split_tensors(
+        self, x_dict: torch.Tensor, dim: int
+    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+        """Split fused tensors into chunks based on edge types."""
+        x_src_dict = dict.fromkeys(self.edge_types_str)
+        x_dst_dict = dict.fromkeys(self.edge_types_str)
 
         for ntype, t in x_dict.items():
             n_src_rel = len(self.relations_per_ntype[ntype][0])
@@ -118,14 +159,14 @@ def split_tensors(self, x_dict: torch.Tensor, dim: int):
 
         return x_src_dict, x_dst_dict
 
-    def reset_parameters(self, seed: Optional[int] = None):
+    def reset_parameters(self, seed: Optional[int] = None) -> None:
         if seed is not None:
             torch.manual_seed(seed)
 
         w_src, w_dst = self.split_tensors(self.lin_weights, dim=0)
 
         for i, edge_type in enumerate(self.edge_types):
-            src_type, etype, dst_type = edge_type
+            src_type, _, dst_type = edge_type
             etype_str = "__".join(edge_type)
             # lin_src
             torch_geometric.nn.inits.glorot(w_src[etype_str])
@@ -143,8 +184,8 @@ def reset_parameters(self, seed: Optional[int] = None):
             if self.bias is not None:
                 torch_geometric.nn.inits.zeros(self.bias[etype_str])
 
-    def forward(self, x_dict: dict, edge_index_dict: dict):
-        feat_dict = {ntype: None for ntype in x_dict.keys()}
+    def forward(self, x_dict: dict, edge_index_dict: dict) -> dict[str, torch.Tensor]:
+        feat_dict = dict.fromkeys(x_dict.keys())
 
         for ntype, x in x_dict.items():
             feat_dict[ntype] = x @ self.lin_weights[ntype].T
@@ -154,7 +195,7 @@ def forward(self, x_dict: dict, edge_index_dict: dict):
         out_dict = defaultdict(list)
 
         for edge_type, edge_index in edge_index_dict.items():
-            src_type, etype, dst_type = edge_type
+            src_type, _, dst_type = edge_type
             etype_str = "__".join(edge_type)
 
             csc = BaseConv.to_csc(
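
Note on the fused layout: each node type owns one weight matrix of shape (n_rel * heads * out_channels, in_channels[ntype]), where n_rel counts the relations in which that type appears as source or destination, so all per-relation projections for a node type come out of a single GEMM (feat_dict[ntype] = x @ W.T) and are then carved apart by split_tensors. A minimal standalone sketch of that idea, with illustrative sizes not taken from the commit:

    import torch

    heads, out_channels, in_channels = 2, 16, 32
    n_rel = 3  # this node type appears in 3 relations (as source or destination)

    # One fused weight per node type, mirroring lin_weights[ntype] above.
    W = torch.empty(n_rel * heads * out_channels, in_channels)
    torch.nn.init.xavier_uniform_(W)  # stand-in for torch_geometric.nn.inits.glorot

    x = torch.randn(100, in_channels)  # features for 100 nodes of this type
    feat = x @ W.T                     # a single GEMM covers all 3 relations

    # split_tensors then slices the result into per-relation chunks along dim=-1
    chunks = torch.split(feat, heads * out_channels, dim=-1)
    assert len(chunks) == n_rel
    assert chunks[0].shape == (100, heads * out_channels)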

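For reference, a hypothetical end-to-end use of the layer after this change, with a different in_channels per node type. The type names, sizes, and edge indices are made up, the import path is assumed from the file's location, and running it requires a GPU environment with cugraph-ops available:

    import torch
    from cugraph_pyg.nn import HeteroGATConv  # import path assumed from file location

    node_types = ["author", "paper"]
    edge_types = [("author", "writes", "paper"), ("paper", "rev_writes", "author")]

    conv = HeteroGATConv(
        in_channels={"author": 32, "paper": 64},  # per-type sizes, enabled by this commit
        out_channels=16,
        node_types=node_types,
        edge_types=edge_types,
        heads=4,
    ).cuda()

    x_dict = {
        "author": torch.randn(10, 32, device="cuda"),
        "paper": torch.randn(20, 64, device="cuda"),
    }
    edge_index_dict = {  # row 0: source node ids, row 1: destination node ids
        ("author", "writes", "paper"): torch.stack(
            [torch.randint(0, 10, (50,)), torch.randint(0, 20, (50,))]
        ).cuda(),
        ("paper", "rev_writes", "author"): torch.stack(
            [torch.randint(0, 20, (50,)), torch.randint(0, 10, (50,))]
        ).cuda(),
    }

    out_dict = conv(x_dict, edge_index_dict)
    print(out_dict["paper"].shape)  # (20, heads * out_channels) when concat=True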