-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add IDGNN #2
Merged
Merged
add IDGNN #2
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .encoder import HeteroEncoder | ||
|
||
# Public names exported by this package; `classes` is bound to the same list.
__all__ = classes = ['HeteroEncoder']
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from typing import Any, Dict, List, Optional

import torch
import torch_frame
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_frame.nn.models import ResNet
from torch_geometric.nn import PositionalEncoding
from torch_geometric.typing import NodeType
|
||
# Default semantic-type -> (stype encoder class, constructor kwargs) mapping.
# Every entry uses an empty kwargs dict, i.e. the encoder's own defaults.
DEFAULT_STYPE_ENCODER_DICT: Dict[torch_frame.stype, Any] = {
    stype: (encoder_cls, {})
    for stype, encoder_cls in (
        (torch_frame.categorical, torch_frame.nn.EmbeddingEncoder),
        (torch_frame.numerical, torch_frame.nn.LinearEncoder),
        (
            torch_frame.multicategorical,
            torch_frame.nn.MultiCategoricalEmbeddingEncoder,
        ),
        (torch_frame.embedding, torch_frame.nn.LinearEmbeddingEncoder),
        (torch_frame.timestamp, torch_frame.nn.TimestampEncoder),
    )
}
# Divisor used to rescale relative times (presumably seconds -> days).
SECONDS_IN_A_DAY = 24 * 60 * 60
|
||
|
||
class HeteroEncoder(torch.nn.Module):
    r"""HeteroEncoder is a simple encoder to encode multi-modal data from
    different node types.

    One PyTorch Frame model is built per node type; each maps the node
    type's :class:`TensorFrame` to :obj:`channels`-dimensional embeddings.

    Args:
        channels (int): The output channels for each node type.
        node_to_col_names_dict (Dict[NodeType, Dict[torch_frame.stype, List[str]]]):  # noqa: E501
            A dictionary mapping from node type to column names dictionary
            compatible to PyTorch Frame.
        node_to_col_stats (Dict[NodeType, Dict[str, Dict[StatType, Any]]]):
            A dictionary mapping from node type to column statistics dictionary
            compatible to PyTorch Frame.
        stype_encoder_cls_kwargs (Dict[torch_frame.stype, Any]):
            A dictionary mapping from :obj:`torch_frame.stype` object into a
            tuple specifying :class:`torch_frame.nn.StypeEncoder` class and its
            keyword arguments :obj:`kwargs`.
        torch_frame_model_cls: Model class for PyTorch Frame. The class object
            takes :class:`TensorFrame` object as input and outputs
            :obj:`channels`-dimensional embeddings. Default to
            :class:`torch_frame.nn.ResNet`.
        torch_frame_model_kwargs (Dict[str, Any], optional): Keyword arguments
            for :class:`torch_frame_model_cls` class. If :obj:`None`, falls
            back to :obj:`{"channels": 128, "num_layers": 4}`, which is
            specific to :class:`torch_frame.nn.ResNet`. Expect it to be
            changed for different :class:`torch_frame_model_cls`.
    """
    def __init__(
        self,
        channels: int,
        node_to_col_names_dict: Dict[NodeType, Dict[torch_frame.stype,
                                                    List[str]]],
        node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]],
        stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any],
        torch_frame_model_cls=ResNet,
        torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()

        # Build the default per call instead of sharing one mutable dict
        # object across every instantiation of this class.
        if torch_frame_model_kwargs is None:
            torch_frame_model_kwargs = {
                "channels": 128,
                "num_layers": 4,
            }

        self.encoders = torch.nn.ModuleDict()

        for node_type, col_names_dict in node_to_col_names_dict.items():
            # Fresh stype encoder instances for every node type so that no
            # encoder module is shared between node types.
            stype_encoder_dict = {}
            for stype in col_names_dict:
                encoder_spec = stype_encoder_cls_kwargs[stype]
                stype_encoder_dict[stype] = encoder_spec[0](**encoder_spec[1])

            self.encoders[node_type] = torch_frame_model_cls(
                **torch_frame_model_kwargs,
                out_channels=channels,
                col_stats=node_to_col_stats[node_type],
                col_names_dict=col_names_dict,
                stype_encoder_dict=stype_encoder_dict,
            )

    def reset_parameters(self) -> None:
        r"""Resets the parameters of every per-node-type encoder."""
        for encoder in self.encoders.values():
            encoder.reset_parameters()

    def forward(
        self,
        tf_dict: Dict[NodeType, torch_frame.TensorFrame],
    ) -> Dict[NodeType, Tensor]:
        r"""Encodes each node type's :class:`TensorFrame`.

        Args:
            tf_dict (Dict[NodeType, torch_frame.TensorFrame]): A dictionary
                mapping from node type to its :class:`TensorFrame`. Every key
                must have an encoder registered at construction time.

        Returns:
            Dict[NodeType, Tensor]: A dictionary mapping from node type to the
            output of that node type's encoder.
        """
        x_dict = {
            node_type: self.encoders[node_type](tf)
            for node_type, tf in tf_dict.items()
        }
        return x_dict
|
||
|
||
class HeteroTemporalEncoder(torch.nn.Module):
    r"""Encodes, per node type, the time of each node relative to the seed
    time of the sample it belongs to.

    Each node type gets its own :class:`PositionalEncoding` followed by its
    own :class:`torch.nn.Linear` layer.

    Args:
        node_types (List[NodeType]): The node types to build encoders for.
        channels (int): The size of the positional encoding and of the linear
            layer's input and output.
    """
    def __init__(self, node_types: List[NodeType], channels: int) -> None:
        super().__init__()

        # All positional encoders are created first, then all linear layers.
        self.encoder_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            self.encoder_dict[node_type] = PositionalEncoding(channels)

        self.lin_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            self.lin_dict[node_type] = torch.nn.Linear(channels, channels)

    def reset_parameters(self) -> None:
        r"""Resets all positional encoders, then all linear layers."""
        for module_dict in (self.encoder_dict, self.lin_dict):
            for module in module_dict.values():
                module.reset_parameters()

    def forward(
        self,
        seed_time: Tensor,
        time_dict: Dict[NodeType, Tensor],
        batch_dict: Dict[NodeType, Tensor],
    ) -> Dict[NodeType, Tensor]:
        r"""Computes relative-time embeddings for every entry of
        :obj:`time_dict`.

        Args:
            seed_time (Tensor): Seed times, indexed by the values stored in
                :obj:`batch_dict`.
            time_dict (Dict[NodeType, Tensor]): Time of each node, per node
                type.
            batch_dict (Dict[NodeType, Tensor]): For each node type, the index
                into :obj:`seed_time` of every node.

        Returns:
            Dict[NodeType, Tensor]: The encoded relative time per node type.
        """
        out_dict: Dict[NodeType, Tensor] = {}

        for node_type in time_dict:
            node_seed_time = seed_time[batch_dict[node_type]]
            # Presumably timestamps are in seconds, so this expresses the
            # offset from the seed time in days.
            rel_days = (node_seed_time - time_dict[node_type])
            rel_days = rel_days / SECONDS_IN_A_DAY

            encoded = self.encoder_dict[node_type](rel_days)
            out_dict[node_type] = self.lin_dict[node_type](encoded)

        return out_dict
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .graphsage import HeteroGraphSAGE | ||
from .idgnn import IDGNN | ||
|
||
# Public names exported by this package; `classes` is bound to the same list.
__all__ = classes = ['HeteroGraphSAGE', 'IDGNN']
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from typing import Dict, List, Optional | ||
|
||
import torch | ||
from torch import Tensor | ||
from torch_geometric.nn import HeteroConv, LayerNorm, SAGEConv | ||
from torch_geometric.typing import EdgeType, NodeType | ||
|
||
|
||
class HeteroGraphSAGE(torch.nn.Module):
    r"""A stack of heterogeneous GraphSAGE layers.

    Every layer applies one :class:`SAGEConv` per edge type (wrapped in a
    :class:`HeteroConv`), followed by a per-node-type :class:`LayerNorm` and
    a ReLU.

    Args:
        node_types (List[NodeType]): Node types to create normalization
            layers for.
        edge_types (List[EdgeType]): Edge types to create convolutions for.
        channels (int): Input, hidden and output feature size.
        aggr (str): Neighbor aggregation used inside each
            :class:`SAGEConv`. (default: :obj:`"mean"`)
        num_layers (int): Number of message passing layers.
            (default: :obj:`2`)
    """
    def __init__(
        self,
        node_types: List[NodeType],
        edge_types: List[EdgeType],
        channels: int,
        aggr: str = "mean",
        num_layers: int = 2,
    ) -> None:
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv(
                {
                    edge_type: SAGEConv(
                        (channels, channels), channels, aggr=aggr)
                    for edge_type in edge_types
                },
                # Deliberately fixed and independent of the `aggr` argument:
                # `aggr` configures neighbor aggregation inside SAGEConv,
                # whereas this one combines the per-edge-type outputs that
                # land on the same node type.
                aggr="sum",
            )
            self.convs.append(conv)

        self.norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            norm_dict = torch.nn.ModuleDict()
            for node_type in node_types:
                norm_dict[node_type] = LayerNorm(channels, mode="node")
            self.norms.append(norm_dict)

    def reset_parameters(self) -> None:
        r"""Resets all convolutions, then all normalization layers."""
        for conv in self.convs:
            conv.reset_parameters()
        for norm_dict in self.norms:
            for norm in norm_dict.values():
                norm.reset_parameters()

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
        num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
        num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
    ) -> Dict[NodeType, Tensor]:
        r"""Runs all message passing layers.

        Args:
            x_dict (Dict[NodeType, Tensor]): Node features per node type.
            edge_index_dict (Dict[EdgeType, Tensor]): Edge indices per edge
                type.
            num_sampled_nodes_dict (Dict[NodeType, List[int]], optional):
                Accepted for interface compatibility; not used by this
                implementation. (default: :obj:`None`)
            num_sampled_edges_dict (Dict[EdgeType, List[int]], optional):
                Accepted for interface compatibility; not used by this
                implementation. (default: :obj:`None`)

        Returns:
            Dict[NodeType, Tensor]: Node embeddings per node type after the
            last layer.
        """
        for conv, norm_dict in zip(self.convs, self.norms):
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}
            x_dict = {key: x.relu() for key, x in x_dict.items()}

        return x_dict
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from typing import Any, Dict | ||
|
||
import torch | ||
from torch import Tensor | ||
from torch_frame.data.stats import StatType | ||
from torch_geometric.data import HeteroData | ||
from torch_geometric.nn import MLP | ||
from torch_geometric.typing import NodeType | ||
|
||
from hybridgnn.nn.encoder import ( | ||
DEFAULT_STYPE_ENCODER_DICT, | ||
HeteroEncoder, | ||
HeteroTemporalEncoder, | ||
) | ||
from hybridgnn.nn.models import HeteroGraphSAGE | ||
|
||
|
||
class IDGNN(torch.nn.Module):
    r"""ID-aware heterogeneous GNN.

    The model encodes every node type's :class:`TensorFrame` with a
    :class:`HeteroEncoder`, adds a learnable ID-awareness embedding to the
    seed rows of the entity table, adds relative-time embeddings from a
    :class:`HeteroTemporalEncoder`, runs :class:`HeteroGraphSAGE` and applies
    an :class:`MLP` head on the destination table's embeddings.

    Args:
        data (HeteroData): Graph whose node stores expose a :obj:`tf`
            attribute with a :obj:`col_names_dict`; used to read node types,
            edge types and column names.
        col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): Column
            statistics per node type.
        num_layers (int): Number of GNN layers.
        channels (int): Hidden feature size.
        out_channels (int): Output size of the head.
        aggr (str): Aggregation passed to :class:`HeteroGraphSAGE`.
        norm (str): Normalization passed to the :class:`MLP` head.
    """
    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
    ) -> None:
        super().__init__()

        col_names_per_node_type = {}
        for node_type in data.node_types:
            col_names_per_node_type[node_type] = (
                data[node_type].tf.col_names_dict)
        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict=col_names_per_node_type,
            node_to_col_stats=col_stats_dict,
            stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT,
        )

        # Only node types that carry a "time" attribute get a temporal
        # encoder.
        temporal_node_types = []
        for node_type in data.node_types:
            if "time" in data[node_type]:
                temporal_node_types.append(node_type)
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=temporal_node_types,
            channels=channels,
        )

        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )

        # A single learnable vector that marks the seed (root) nodes.
        self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""Resets the parameters of all sub-modules."""
        for module in (
                self.encoder,
                self.temporal_encoder,
                self.gnn,
                self.head,
                self.id_awareness_emb,
        ):
            module.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        r"""Computes the head output for the nodes of :obj:`dst_table`.

        Args:
            batch (HeteroData): A batch providing :obj:`tf_dict`,
                :obj:`time_dict`, :obj:`batch_dict`, :obj:`edge_index_dict`
                and :obj:`batch[entity_table].seed_time`.
            entity_table (NodeType): Node type holding the seed nodes.
            dst_table (NodeType): Node type whose embeddings are passed to
                the head.

        Returns:
            Tensor: Output of the head applied to the :obj:`dst_table`
            embeddings.
        """
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        # Add ID-awareness to the root node. A standard GNN is basically
        # this model without this step.
        # Assumes the first `seed_time.size(0)` rows of the entity table are
        # the seed nodes -- TODO confirm against the loader.
        num_seeds = seed_time.size(0)
        x_dict[entity_table][:num_seeds] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time,
            batch.time_dict,
            batch.batch_dict,
        )
        x_dict.update({
            node_type: x_dict[node_type] + rel_time
            for node_type, rel_time in rel_time_dict.items()
        })

        x_dict = self.gnn(x_dict, batch.edge_index_dict)

        return self.head(x_dict[dst_table])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
from relbench.datasets.fake import FakeDataset | ||
from relbench.modeling.graph import make_pkey_fkey_graph | ||
from relbench.modeling.utils import get_stype_proposal | ||
from torch_frame.config.text_embedder import TextEmbedderConfig | ||
from torch_frame.testing.text_embedder import HashTextEmbedder | ||
|
||
from hybridgnn.nn.encoder import DEFAULT_STYPE_ENCODER_DICT, HeteroEncoder | ||
|
||
|
||
def test_encoder(tmp_path):
    """Check that HeteroEncoder embeds every node type of the fake dataset."""
    db = FakeDataset().get_db()

    text_embedder_cfg = TextEmbedderConfig(
        text_embedder=HashTextEmbedder(8),
        batch_size=None,
    )
    data, col_stats_dict = make_pkey_fkey_graph(
        db,
        get_stype_proposal(db),
        text_embedder_cfg=text_embedder_cfg,
        cache_dir=tmp_path,
    )

    node_to_col_names_dict = {}
    for node_type in data.node_types:
        node_to_col_names_dict[node_type] = data[node_type].tf.col_names_dict

    # Ensure that full-batch model works as expected ##########################

    encoder = HeteroEncoder(
        64,
        node_to_col_names_dict,
        col_stats_dict,
        stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT,
    )

    x_dict = encoder(data.tf_dict)

    for node_type in ('product', 'customer', 'review', 'relations'):
        assert node_type in x_dict

    assert x_dict['relations'].shape == torch.Size([20, 64])
    assert x_dict['product'].shape == torch.Size([30, 64])
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the input argument has `aggr="mean"` but here `aggr` is hard coded.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's intended to use `sum` here? cc @zechengz
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can use `sum` here for now. The `aggr = "mean"` argument is used for the `SAGEConv` aggregation. Here the `aggr` seems to have a different meaning: it aggregates embeddings for the same node type together (if I remember correctly).