Showing 4 changed files with 631 additions and 0 deletions.
@@ -0,0 +1,148 @@

from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from .nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder

class Model(torch.nn.Module):
    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        num_src_nodes: int,
        num_dst_nodes: int,
        src_entity_table: str,
        dst_entity_table: str,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to the input:
        shallow_list: List[NodeType] = [],
        # ID awareness:
        id_awareness: bool = False,
    ):
        super().__init__()

        self.src_entity_table = src_entity_table
        self.dst_entity_table = dst_entity_table
        # The encoder receives the entity table names and node counts so that
        # the source/destination entity tables can be encoded as plain ID
        # embeddings (see HeteroEncoder below):
        self.encoder = HeteroEncoder(
            channels=channels,
            src_entity_table=src_entity_table,
            dst_entity_table=dst_entity_table,
            num_src_nodes=num_src_nodes,
            num_dst_nodes=num_dst_nodes,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types
                if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict({
            node: Embedding(data.num_nodes_dict[node], channels)
            for node in shallow_list
        })

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        #######################################################################
        # The encoder also receives the batch, so that entity tables can be
        # looked up by their local node IDs (batch[node_type].n_id):
        x_dict = self.encoder(batch.tf_dict, batch)
        #######################################################################

        rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
                                              batch.batch_dict)

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(
                batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        # The first seed_time.size(0) rows of the entity table are the seed
        # (root) nodes of the sampled subgraph:
        return self.head(x_dict[entity_table][:seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout")
        seed_time = batch[entity_table].seed_time
        # The updated HeteroEncoder.forward also requires the batch for the
        # entity-table ID lookups:
        x_dict = self.encoder(batch.tf_dict, batch)
        # Add ID-awareness to the root nodes:
        x_dict[entity_table][:seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
                                              batch.batch_dict)

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(
                batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])
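For intuition, here is a minimal, self-contained sketch of the readout above, in plain PyTorch. All sizes are illustrative assumptions, not values from this commit: seed (root) nodes occupy the first seed_time.size(0) rows of the entity table, and the single-row ID-awareness embedding broadcasts across exactly those rows.

import torch

channels, num_sampled, batch_size = 8, 5, 2  # illustrative sizes

x = torch.randn(num_sampled, channels)  # encoder output for one node type
seed_time = torch.zeros(batch_size)     # one timestamp per seed node

# Seed nodes come first in the sampled subgraph, so this slice selects
# exactly the root rows:
assert x[:seed_time.size(0)].shape == (batch_size, channels)

# Embedding(1, channels).weight has shape [1, channels] and broadcasts
# over all seed rows on in-place addition:
id_emb = torch.nn.Embedding(1, channels)
x[:seed_time.size(0)] += id_emb.weight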
@@ -0,0 +1,228 @@
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_frame
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_frame.nn.models import ResNet
from torch_geometric.data import HeteroData
from torch_geometric.nn import (
    HeteroConv,
    LayerNorm,
    PositionalEncoding,
    SAGEConv,
)
from torch_geometric.typing import EdgeType, NodeType

class HeteroEncoder(torch.nn.Module):
    r"""HeteroEncoder based on PyTorch Frame.

    Args:
        channels (int): The output channels for each node type.
        src_entity_table (str): Source entity table name.
        dst_entity_table (str): Destination entity table name.
        num_src_nodes (int): Number of source nodes.
        num_dst_nodes (int): Number of destination nodes.
        node_to_col_names_dict
            (Dict[NodeType, Dict[torch_frame.stype, List[str]]]):
            A dictionary mapping from node type to a column names dictionary
            compatible with PyTorch Frame.
        node_to_col_stats (Dict[NodeType, Dict[str, Dict[StatType, Any]]]):
            A dictionary mapping from node type to column statistics.
        torch_frame_model_cls: Model class for PyTorch Frame. The class object
            takes a :class:`TensorFrame` object as input and outputs
            :obj:`channels`-dimensional embeddings. Defaults to
            :class:`torch_frame.nn.ResNet`.
        torch_frame_model_kwargs (Dict[str, Any]): Keyword arguments for the
            :class:`torch_frame_model_cls` class. The default keyword
            arguments are specific to :class:`torch_frame.nn.ResNet` and are
            expected to change for a different :class:`torch_frame_model_cls`.
        default_stype_encoder_cls_kwargs (Dict[torch_frame.stype, Any]):
            A dictionary mapping from a :obj:`torch_frame.stype` object to a
            tuple specifying a :class:`torch_frame.nn.StypeEncoder` class and
            its keyword arguments :obj:`kwargs`.
    """
    def __init__(
        self,
        channels: int,
        #######################################################################
        src_entity_table: str,
        dst_entity_table: str,
        num_src_nodes: int,
        num_dst_nodes: int,
        #######################################################################
        node_to_col_names_dict: Dict[NodeType, Dict[torch_frame.stype,
                                                    List[str]]],
        node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]],
        torch_frame_model_cls=ResNet,
        torch_frame_model_kwargs: Dict[str, Any] = {
            "channels": 128,
            "num_layers": 4,
        },
        default_stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any] = {
            torch_frame.categorical: (torch_frame.nn.EmbeddingEncoder, {}),
            torch_frame.numerical: (torch_frame.nn.LinearEncoder, {}),
            torch_frame.multicategorical: (
                torch_frame.nn.MultiCategoricalEmbeddingEncoder,
                {},
            ),
            torch_frame.embedding: (torch_frame.nn.LinearEmbeddingEncoder, {}),
            torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}),
        },
    ):
        super().__init__()

        #######################################################################
        self.src_entity_table = src_entity_table
        self.dst_entity_table = dst_entity_table
        #######################################################################

        self.encoders = torch.nn.ModuleDict()

        for node_type in node_to_col_names_dict.keys():
            ###################################################################
            # The source and destination entity tables are encoded as plain
            # ID embeddings rather than through a PyTorch Frame model:
            if node_type == self.src_entity_table:
                self.encoders[node_type] = nn.Embedding(
                    num_src_nodes, channels)
            elif node_type == self.dst_entity_table:
                self.encoders[node_type] = nn.Embedding(
                    num_dst_nodes, channels)
            ###################################################################
            else:
                stype_encoder_dict = {
                    stype:
                    default_stype_encoder_cls_kwargs[stype][0](
                        **default_stype_encoder_cls_kwargs[stype][1])
                    for stype in node_to_col_names_dict[node_type].keys()
                }
                torch_frame_model = torch_frame_model_cls(
                    **torch_frame_model_kwargs,
                    out_channels=channels,
                    col_stats=node_to_col_stats[node_type],
                    col_names_dict=node_to_col_names_dict[node_type],
                    stype_encoder_dict=stype_encoder_dict,
                )
                self.encoders[node_type] = torch_frame_model

    def reset_parameters(self):
        for node_type, encoder in self.encoders.items():
            ###################################################################
            if node_type in {self.src_entity_table, self.dst_entity_table}:
                nn.init.xavier_uniform_(encoder.weight)
            ###################################################################
            else:
                encoder.reset_parameters()

    def forward(
        self,
        tf_dict: Dict[NodeType, torch_frame.TensorFrame],
        batch: HeteroData,
    ) -> Dict[NodeType, Tensor]:
        x_dict = {}
        for node_type, tf in tf_dict.items():
            if node_type not in {self.src_entity_table, self.dst_entity_table}:
                x_dict[node_type] = self.encoders[node_type](tf)
            else:
                ###############################################################
                # Entity tables are looked up by the local node IDs of the
                # sampled subgraph instead of by their tensor frames:
                x_dict[node_type] = self.encoders[node_type](
                    batch[node_type].n_id)
                ###############################################################
        return x_dict
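Here is a minimal sketch of the ID-embedding branch above, in plain PyTorch. The table names "driver" and "race" and all sizes are hypothetical, for illustration only: feature-less entity tables are embedded purely from their node IDs, while every other table would go through its PyTorch Frame model.

import torch
import torch.nn as nn

channels = 4
encoders = nn.ModuleDict({
    "driver": nn.Embedding(10, channels),  # hypothetical src entity table
    "race": nn.Embedding(20, channels),    # hypothetical dst entity table
})

# Local node IDs of the sampled subgraph, as in batch[node_type].n_id:
n_id = {"driver": torch.tensor([3, 7]), "race": torch.tensor([0, 5, 9])}

x_dict = {nt: encoders[nt](ids) for nt, ids in n_id.items()}
assert x_dict["driver"].shape == (2, channels)
assert x_dict["race"].shape == (3, channels)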

class HeteroTemporalEncoder(torch.nn.Module):
    def __init__(self, node_types: List[NodeType], channels: int):
        super().__init__()

        self.encoder_dict = torch.nn.ModuleDict({
            node_type: PositionalEncoding(channels)
            for node_type in node_types
        })
        self.lin_dict = torch.nn.ModuleDict({
            node_type: torch.nn.Linear(channels, channels)
            for node_type in node_types
        })

    def reset_parameters(self):
        for encoder in self.encoder_dict.values():
            encoder.reset_parameters()
        for lin in self.lin_dict.values():
            lin.reset_parameters()

    def forward(
        self,
        seed_time: Tensor,
        time_dict: Dict[NodeType, Tensor],
        batch_dict: Dict[NodeType, Tensor],
    ) -> Dict[NodeType, Tensor]:
        out_dict: Dict[NodeType, Tensor] = {}

        for node_type, time in time_dict.items():
            rel_time = seed_time[batch_dict[node_type]] - time
            rel_time = rel_time / (60 * 60 * 24)  # Convert seconds to days.

            x = self.encoder_dict[node_type](rel_time)
            x = self.lin_dict[node_type](x)
            out_dict[node_type] = x

        return out_dict
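The relative-time computation in the loop above can be checked in isolation. A small sketch with made-up timestamps (plain PyTorch; the values are illustrative assumptions): each node's timestamp is subtracted from the seed time of the example it was sampled for, then converted from seconds to days.

import torch

day = 60 * 60 * 24
seed_time = torch.tensor([10 * day, 20 * day])     # one entry per seed node
time = torch.tensor([8 * day, 19 * day, 5 * day])  # node timestamps (seconds)
batch = torch.tensor([0, 1, 0])                    # maps nodes to seed index

rel_time = (seed_time[batch] - time) / day         # elapsed time in days
assert rel_time.tolist() == [2.0, 1.0, 5.0]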

class HeteroGraphSAGE(torch.nn.Module):
    def __init__(
        self,
        node_types: List[NodeType],
        edge_types: List[EdgeType],
        channels: int,
        aggr: str = "mean",
        num_layers: int = 2,
    ):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv(
                {
                    edge_type: SAGEConv(
                        (channels, channels), channels, aggr=aggr)
                    for edge_type in edge_types
                },
                aggr="sum",
            )
            self.convs.append(conv)

        self.norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            norm_dict = torch.nn.ModuleDict()
            for node_type in node_types:
                norm_dict[node_type] = LayerNorm(channels, mode="node")
            self.norms.append(norm_dict)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for norm_dict in self.norms:
            for norm in norm_dict.values():
                norm.reset_parameters()

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
        # Accepted for interface compatibility but unused in this
        # implementation:
        num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
        num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
    ) -> Dict[NodeType, Tensor]:
        for conv, norm_dict in zip(self.convs, self.norms):
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}
            ###################################################################
            # Apply a LeakyReLU non-linearity after each layer:
            x_dict = {
                key: F.leaky_relu(x, negative_slope=0.2)
                for key, x in x_dict.items()
            }
            ###################################################################

        return x_dict
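As a sanity check of the message-passing stack, here is a minimal, self-contained run of one HeteroConv layer of SAGEConv over a toy two-type graph. The node counts and the ("user", "to", "item") edge type are assumptions for illustration, not from this commit:

import torch
from torch_geometric.nn import HeteroConv, SAGEConv

channels = 4
edge_type = ("user", "to", "item")  # hypothetical edge type

conv = HeteroConv(
    {edge_type: SAGEConv((channels, channels), channels, aggr="mean")},
    aggr="sum",
)
x_dict = {"user": torch.randn(3, channels), "item": torch.randn(2, channels)}
edge_index_dict = {edge_type: torch.tensor([[0, 1, 2], [0, 1, 1]])}

# Only destination node types that receive messages appear in the output:
out = conv(x_dict, edge_index_dict)
assert out["item"].shape == (2, channels)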