Commit

GNN link
zechengz committed Sep 30, 2024
1 parent beed5d8 commit b8e6265
Showing 4 changed files with 631 additions and 0 deletions.
Empty file added examples/baseline/__init__.py
Empty file.
148 changes: 148 additions & 0 deletions examples/baseline/model.py
@@ -0,0 +1,148 @@
from typing import Any, Dict, List

import torch
from .nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType


class Model(torch.nn.Module):
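    r"""Baseline model: a heterogeneous encoder (tabular features plus
    shallow ID embeddings for the src/dst entity tables), a temporal
    encoder, a HeteroGraphSAGE backbone, and an MLP prediction head.
    """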
def __init__(
self,
data: HeteroData,
col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
num_layers: int,
channels: int,
out_channels: int,
num_src_nodes: int,
num_dst_nodes: int,
src_entity_table: str,
dst_entity_table: str,
aggr: str,
norm: str,
        # List of node types to add shallow embeddings to the input
shallow_list: List[NodeType] = [],
# ID awareness
id_awareness: bool = False,
):
super().__init__()

self.src_entity_table = src_entity_table
self.dst_entity_table = dst_entity_table
self.encoder = HeteroEncoder(
channels=channels,
src_entity_table=src_entity_table,
dst_entity_table=dst_entity_table,
num_src_nodes=num_src_nodes,
num_dst_nodes=num_dst_nodes,
node_to_col_names_dict={
node_type: data[node_type].tf.col_names_dict
for node_type in data.node_types
},
node_to_col_stats=col_stats_dict,
)
self.temporal_encoder = HeteroTemporalEncoder(
node_types=[
node_type for node_type in data.node_types
if "time" in data[node_type]
],
channels=channels,
)
self.gnn = HeteroGraphSAGE(
node_types=data.node_types,
edge_types=data.edge_types,
channels=channels,
aggr=aggr,
num_layers=num_layers,
)
self.head = MLP(
channels,
out_channels=out_channels,
norm=norm,
num_layers=1,
)
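        # Optional shallow per-node embeddings, added to the encoder output
        # for the node types listed in `shallow_list`.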
self.embedding_dict = ModuleDict({
node:
Embedding(data.num_nodes_dict[node], channels)
for node in shallow_list
})

self.id_awareness_emb = None
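        # If enabled, a single learnable vector is added to the seed (root)
        # nodes in forward_dst_readout to make them identifiable.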
if id_awareness:
self.id_awareness_emb = torch.nn.Embedding(1, channels)
self.reset_parameters()

def reset_parameters(self):
self.encoder.reset_parameters()
self.temporal_encoder.reset_parameters()
self.gnn.reset_parameters()
self.head.reset_parameters()
for embedding in self.embedding_dict.values():
torch.nn.init.normal_(embedding.weight, std=0.1)
if self.id_awareness_emb is not None:
self.id_awareness_emb.reset_parameters()

def forward(
self,
batch: HeteroData,
entity_table: NodeType,
) -> Tensor:
seed_time = batch[entity_table].seed_time
#######################################################################
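        # The baseline encoder also receives the mini-batch so that the
        # src/dst entity tables can be encoded from their sampled node IDs.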
x_dict = self.encoder(batch.tf_dict, batch)
#######################################################################

rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

for node_type, embedding in self.embedding_dict.items():
x_dict[node_type] = x_dict[node_type] + embedding(
batch[node_type].n_id)

x_dict = self.gnn(
x_dict,
batch.edge_index_dict,
batch.num_sampled_nodes_dict,
batch.num_sampled_edges_dict,
)

return self.head(x_dict[entity_table][:seed_time.size(0)])

def forward_dst_readout(
self,
batch: HeteroData,
entity_table: NodeType,
dst_table: NodeType,
) -> Tensor:
if self.id_awareness_emb is None:
raise RuntimeError(
"id_awareness must be set True to use forward_dst_readout")
seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict, batch)
# Add ID-awareness to the root node
        x_dict[entity_table][:seed_time.size(0)] += \
            self.id_awareness_emb.weight

rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

for node_type, embedding in self.embedding_dict.items():
x_dict[node_type] = x_dict[node_type] + embedding(
batch[node_type].n_id)

x_dict = self.gnn(
x_dict,
batch.edge_index_dict,
)

return self.head(x_dict[dst_table])
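For orientation, here is a minimal sketch of how this Model could be driven in a node-level training step. It assumes a RelBench-style binary classification setup in which model, loader, optimizer, and task have already been constructed as in the surrounding examples; none of these names are defined in this commit:

import torch.nn.functional as F

model.train()
for batch in loader:  # NeighborLoader yielding HeteroData mini-batches
    optimizer.zero_grad()
    # Logits for the seed nodes of the task's entity table.
    pred = model(batch, task.entity_table)
    loss = F.binary_cross_entropy_with_logits(
        pred.view(-1), batch[task.entity_table].y.float())
    loss.backward()
    optimizer.step()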
228 changes: 228 additions & 0 deletions examples/baseline/nn.py
@@ -0,0 +1,228 @@
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_frame
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_frame.nn.models import ResNet
from torch_geometric.data import HeteroData
from torch_geometric.nn import (
HeteroConv,
LayerNorm,
PositionalEncoding,
SAGEConv,
)
from torch_geometric.typing import EdgeType, NodeType


class HeteroEncoder(torch.nn.Module):
r"""HeteroEncoder based on PyTorch Frame.
Args:
channels (int): The output channels for each node type.
src_entity_table (str): Source entity table name.
dst_entity_table (str): Destination entity table name.
num_src_nodes (int): Number of source nodes.
num_dst_nodes (int): Number of destination nodes.
node_to_col_names_dict
(Dict[NodeType, Dict[torch_frame.stype, List[str]]]):
A dictionary mapping from node type to column names dictionary
compatible to PyTorch Frame.
torch_frame_model_cls: Model class for PyTorch Frame. The class object
takes :class:`TensorFrame` object as input and outputs
:obj:`channels`-dimensional embeddings. Default to
:class:`torch_frame.nn.ResNet`.
torch_frame_model_kwargs (Dict[str, Any]): Keyword arguments for
:class:`torch_frame_model_cls` class. Default keyword argument is
set specific for :class:`torch_frame.nn.ResNet`. Expect it to
be changed for different :class:`torch_frame_model_cls`.
default_stype_encoder_cls_kwargs (Dict[torch_frame.stype, Any]):
A dictionary mapping from :obj:`torch_frame.stype` object into a
tuple specifying :class:`torch_frame.nn.StypeEncoder` class and its
keyword arguments :obj:`kwargs`.
"""
def __init__(
self,
channels: int,
#######################################################################
src_entity_table: str,
dst_entity_table: str,
num_src_nodes: int,
num_dst_nodes: int,
#######################################################################
node_to_col_names_dict: Dict[NodeType, Dict[torch_frame.stype,
List[str]]],
node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]],
torch_frame_model_cls=ResNet,
torch_frame_model_kwargs: Dict[str, Any] = {
"channels": 128,
"num_layers": 4,
},
default_stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any] = {
torch_frame.categorical: (torch_frame.nn.EmbeddingEncoder, {}),
torch_frame.numerical: (torch_frame.nn.LinearEncoder, {}),
torch_frame.multicategorical: (
torch_frame.nn.MultiCategoricalEmbeddingEncoder,
{},
),
torch_frame.embedding: (torch_frame.nn.LinearEmbeddingEncoder, {}),
torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}),
},
):
super().__init__()

#######################################################################
self.src_entity_table = src_entity_table
self.dst_entity_table = dst_entity_table
#######################################################################

self.encoders = torch.nn.ModuleDict()

for node_type in node_to_col_names_dict.keys():
###################################################################
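            # The src/dst entity tables are represented by plain learnable
            # per-node embeddings instead of tabular (PyTorch Frame) encoders.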
if node_type == self.src_entity_table:
self.encoders[node_type] = nn.Embedding(
num_src_nodes, channels)
elif node_type == self.dst_entity_table:
self.encoders[node_type] = nn.Embedding(
num_dst_nodes, channels)
###################################################################
else:
stype_encoder_dict = {
stype:
default_stype_encoder_cls_kwargs[stype][0](
**default_stype_encoder_cls_kwargs[stype][1])
for stype in node_to_col_names_dict[node_type].keys()
}
torch_frame_model = torch_frame_model_cls(
**torch_frame_model_kwargs,
out_channels=channels,
col_stats=node_to_col_stats[node_type],
col_names_dict=node_to_col_names_dict[node_type],
stype_encoder_dict=stype_encoder_dict,
)
self.encoders[node_type] = torch_frame_model

def reset_parameters(self):
for node_type, encoder in self.encoders.items():
###################################################################
if node_type in {self.src_entity_table, self.dst_entity_table}:
nn.init.xavier_uniform_(encoder.weight)
###################################################################
else:
encoder.reset_parameters()

def forward(
self,
tf_dict: Dict[NodeType, torch_frame.TensorFrame],
batch: HeteroData,
) -> Dict[NodeType, Tensor]:
x_dict = {}
for node_type, tf in tf_dict.items():
if node_type not in {self.src_entity_table, self.dst_entity_table}:
x_dict[node_type] = self.encoders[node_type](tf)
else:
###############################################################
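                # For the src/dst entity tables, look up the shallow
                # embeddings via the node IDs carried in the sampled batch.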
x_dict[node_type] = self.encoders[node_type](
batch[node_type].n_id)
###############################################################
return x_dict


class HeteroTemporalEncoder(torch.nn.Module):
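    r"""Encodes the time of each node relative to its seed node's time,
    using a positional encoding followed by a per-node-type linear layer.
    """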
def __init__(self, node_types: List[NodeType], channels: int):
super().__init__()

self.encoder_dict = torch.nn.ModuleDict({
node_type:
PositionalEncoding(channels)
for node_type in node_types
})
self.lin_dict = torch.nn.ModuleDict({
node_type:
torch.nn.Linear(channels, channels)
for node_type in node_types
})

def reset_parameters(self):
for encoder in self.encoder_dict.values():
encoder.reset_parameters()
for lin in self.lin_dict.values():
lin.reset_parameters()

def forward(
self,
seed_time: Tensor,
time_dict: Dict[NodeType, Tensor],
batch_dict: Dict[NodeType, Tensor],
) -> Dict[NodeType, Tensor]:
out_dict: Dict[NodeType, Tensor] = {}

for node_type, time in time_dict.items():
rel_time = seed_time[batch_dict[node_type]] - time
rel_time = rel_time / (60 * 60 * 24) # Convert seconds to days.

x = self.encoder_dict[node_type](rel_time)
x = self.lin_dict[node_type](x)
out_dict[node_type] = x

return out_dict


class HeteroGraphSAGE(torch.nn.Module):
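    r"""Heterogeneous GraphSAGE: one SAGEConv per edge type, combined via
    :class:`HeteroConv`, followed by per-node-type LayerNorm and LeakyReLU.
    """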
def __init__(
self,
node_types: List[NodeType],
edge_types: List[EdgeType],
channels: int,
aggr: str = "mean",
num_layers: int = 2,
):
super().__init__()

self.convs = torch.nn.ModuleList()
for _ in range(num_layers):
conv = HeteroConv(
{
edge_type: SAGEConv(
(channels, channels), channels, aggr=aggr)
for edge_type in edge_types
},
aggr="sum",
)
self.convs.append(conv)

self.norms = torch.nn.ModuleList()
for _ in range(num_layers):
norm_dict = torch.nn.ModuleDict()
for node_type in node_types:
norm_dict[node_type] = LayerNorm(channels, mode="node")
self.norms.append(norm_dict)

def reset_parameters(self):
for conv in self.convs:
conv.reset_parameters()
for norm_dict in self.norms:
for norm in norm_dict.values():
norm.reset_parameters()

def forward(
self,
x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
) -> Dict[NodeType, Tensor]:
        for conv, norm_dict in zip(self.convs, self.norms):
x_dict = conv(x_dict, edge_index_dict)
x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}
###################################################################
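            # Apply a LeakyReLU non-linearity after every conv + norm block.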
x_dict = {
key: F.leaky_relu(x, negative_slope=0.2)
for key, x in x_dict.items()
}
###################################################################

return x_dict
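A small smoke test for HeteroTemporalEncoder, assuming the module above is importable as examples.baseline.nn; the node-type name and timestamps are illustrative:

import torch

from examples.baseline.nn import HeteroTemporalEncoder

# Two seed examples; three sampled "transactions" nodes mapped onto them.
encoder = HeteroTemporalEncoder(node_types=["transactions"], channels=64)
seed_time = torch.tensor([1_000_000, 2_000_000])
time_dict = {"transactions": torch.tensor([900_000, 1_900_000, 950_000])}
batch_dict = {"transactions": torch.tensor([0, 1, 0])}

out_dict = encoder(seed_time, time_dict, batch_dict)
print(out_dict["transactions"].shape)  # torch.Size([3, 64])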
