Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add IDGNN #2

Merged
merged 3 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions examples/gnn_link.py

This file was deleted.

5 changes: 5 additions & 0 deletions hybridgnn/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .encoder import HeteroEncoder

# Names exported by this package; ``classes`` is bound to the same list
# object as ``__all__``.
__all__ = classes = ['HeteroEncoder']
134 changes: 134 additions & 0 deletions hybridgnn/nn/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from typing import Any, Dict, List, Optional

import torch
import torch_frame
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_frame.nn.models import ResNet
from torch_geometric.nn import PositionalEncoding
from torch_geometric.typing import NodeType

# Default mapping from a PyTorch Frame semantic type to a tuple of
# ``(stype encoder class, keyword arguments for that class)``.
DEFAULT_STYPE_ENCODER_DICT: Dict[torch_frame.stype, Any] = {
    torch_frame.categorical: (torch_frame.nn.EmbeddingEncoder, {}),
    torch_frame.numerical: (torch_frame.nn.LinearEncoder, {}),
    torch_frame.multicategorical:
    (torch_frame.nn.MultiCategoricalEmbeddingEncoder, {}),
    torch_frame.embedding: (torch_frame.nn.LinearEmbeddingEncoder, {}),
    torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}),
}

# Number of seconds in one day (24 hours * 60 minutes * 60 seconds).
SECONDS_IN_A_DAY = 24 * 60 * 60


class HeteroEncoder(torch.nn.Module):
    r"""HeteroEncoder is a simple encoder to encode multi-modal data from
    different node types.

    One PyTorch Frame model is instantiated per node type and stored in
    :obj:`self.encoders`, keyed by node type.

    Args:
        channels (int): The output channels for each node type.
        node_to_col_names_dict (Dict[NodeType, Dict[torch_frame.stype, List[str]]]): # noqa: E501
            A dictionary mapping from node type to column names dictionary
            compatible to PyTorch Frame.
        node_to_col_stats (Dict[NodeType, Dict[str, Dict[StatType, Any]]]):
            A dictionary mapping from node type to column statistics
            dictionary compatible to PyTorch Frame.
        stype_encoder_cls_kwargs (Dict[torch_frame.stype, Any]):
            A dictionary mapping from :obj:`torch_frame.stype` object into a
            tuple specifying :class:`torch_frame.nn.StypeEncoder` class and
            its keyword arguments :obj:`kwargs`. A :class:`KeyError` is
            raised if a semantic type used by a node type is missing.
        torch_frame_model_cls: Model class for PyTorch Frame. The class
            object takes :class:`TensorFrame` object as input and outputs
            :obj:`channels`-dimensional embeddings. Default to
            :class:`torch_frame.nn.ResNet`.
        torch_frame_model_kwargs (Dict[str, Any], optional): Keyword
            arguments for :class:`torch_frame_model_cls` class. If set to
            :obj:`None`, ``{"channels": 128, "num_layers": 4}`` is used,
            which is specific to :class:`torch_frame.nn.ResNet`. Expect it
            to be changed for different :class:`torch_frame_model_cls`.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        channels: int,
        node_to_col_names_dict: Dict[NodeType, Dict[torch_frame.stype,
                                                    List[str]]],
        node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]],
        stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any],
        torch_frame_model_cls=ResNet,
        torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()

        # Build the default here rather than in the signature so that no
        # dictionary object is shared between calls.
        if torch_frame_model_kwargs is None:
            torch_frame_model_kwargs = {
                "channels": 128,
                "num_layers": 4,
            }

        self.encoders = torch.nn.ModuleDict()

        for node_type, col_names_dict in node_to_col_names_dict.items():
            # Stype encoders are constructed inside the loop, so every node
            # type receives its own encoder instances.
            stype_encoder_dict = {}
            for stype in col_names_dict:
                encoder_spec = stype_encoder_cls_kwargs[stype]
                stype_encoder_dict[stype] = encoder_spec[0](**encoder_spec[1])

            self.encoders[node_type] = torch_frame_model_cls(
                **torch_frame_model_kwargs,
                out_channels=channels,
                col_stats=node_to_col_stats[node_type],
                col_names_dict=col_names_dict,
                stype_encoder_dict=stype_encoder_dict,
            )

    def reset_parameters(self) -> None:
        r"""Resets the parameters of every per-node-type encoder."""
        for encoder in self.encoders.values():
            encoder.reset_parameters()

    def forward(
        self,
        tf_dict: Dict[NodeType, torch_frame.TensorFrame],
    ) -> Dict[NodeType, Tensor]:
        r"""Encodes each :class:`TensorFrame` with the encoder of its node
        type.

        Args:
            tf_dict (Dict[NodeType, torch_frame.TensorFrame]): A dictionary
                mapping from node type to its :class:`TensorFrame`.

        Returns:
            Dict[NodeType, Tensor]: A dictionary mapping from each node type
            in :obj:`tf_dict` to the output of its encoder.
        """
        x_dict = {
            node_type: self.encoders[node_type](tf)
            for node_type, tf in tf_dict.items()
        }
        return x_dict


class HeteroTemporalEncoder(torch.nn.Module):
    r"""Encodes node timestamps relative to the seed time, with a separate
    :class:`PositionalEncoding` and linear layer per node type.

    Args:
        node_types (List[NodeType]): The node types to create encoders for.
        channels (int): The size passed to :class:`PositionalEncoding` and
            used as both input and output size of the linear layers.
    """
    def __init__(self, node_types: List[NodeType], channels: int) -> None:
        super().__init__()

        # All positional encodings are created before any linear layer, so
        # the construction order of the modules is preserved.
        self.encoder_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            self.encoder_dict[node_type] = PositionalEncoding(channels)

        self.lin_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            self.lin_dict[node_type] = torch.nn.Linear(channels, channels)

    def reset_parameters(self) -> None:
        r"""Resets the parameters of all positional encodings and linear
        layers.
        """
        for pos_enc in self.encoder_dict.values():
            pos_enc.reset_parameters()
        for linear in self.lin_dict.values():
            linear.reset_parameters()

    def forward(
        self,
        seed_time: Tensor,
        time_dict: Dict[NodeType, Tensor],
        batch_dict: Dict[NodeType, Tensor],
    ) -> Dict[NodeType, Tensor]:
        r"""Computes a temporal embedding for every node type in
        :obj:`time_dict`.

        Args:
            seed_time (Tensor): The seed times, indexed by the entries of
                :obj:`batch_dict`.
            time_dict (Dict[NodeType, Tensor]): The timestamps of the nodes
                of each node type.
            batch_dict (Dict[NodeType, Tensor]): For each node type, the
                indices into :obj:`seed_time` for its nodes.

        Returns:
            Dict[NodeType, Tensor]: The temporal embedding per node type.
        """
        out_dict: Dict[NodeType, Tensor] = {}

        for node_type, time in time_dict.items():
            # Difference between the matching seed time and the node time,
            # scaled by SECONDS_IN_A_DAY (i.e. expressed in days, assuming
            # the timestamps are given in seconds -- TODO confirm).
            elapsed = seed_time[batch_dict[node_type]] - time
            elapsed = elapsed / SECONDS_IN_A_DAY

            encoded = self.encoder_dict[node_type](elapsed)
            out_dict[node_type] = self.lin_dict[node_type](encoded)

        return out_dict
7 changes: 7 additions & 0 deletions hybridgnn/nn/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .graphsage import HeteroGraphSAGE
from .idgnn import IDGNN

# Names exported by this package; ``classes`` is bound to the same list
# object as ``__all__``.
__all__ = classes = ['HeteroGraphSAGE', 'IDGNN']
58 changes: 58 additions & 0 deletions hybridgnn/nn/models/graphsage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Dict, List, Optional

import torch
from torch import Tensor
from torch_geometric.nn import HeteroConv, LayerNorm, SAGEConv
from torch_geometric.typing import EdgeType, NodeType


class HeteroGraphSAGE(torch.nn.Module):
    r"""A heterogeneous GraphSAGE model built from :class:`HeteroConv`
    layers wrapping one :class:`SAGEConv` per edge type, each followed by a
    per-node-type :class:`LayerNorm` and a ReLU.

    Args:
        node_types (List[NodeType]): The node types to create normalization
            layers for.
        edge_types (List[EdgeType]): The edge types to create
            :class:`SAGEConv` layers for.
        channels (int): The input and output size of every layer.
        aggr (str): The aggregation scheme passed to each
            :class:`SAGEConv`. (default: :obj:`"mean"`)
        num_layers (int): The number of message passing layers.
            (default: :obj:`2`)
    """
    def __init__(
        self,
        node_types: List[NodeType],
        edge_types: List[EdgeType],
        channels: int,
        aggr: str = "mean",
        num_layers: int = 2,
    ) -> None:
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv(
                {
                    edge_type: SAGEConv(
                        (channels, channels), channels, aggr=aggr)
                    for edge_type in edge_types
                },
                # Intentionally fixed to "sum" (confirmed in review): the
                # `aggr` argument only controls the neighbor aggregation
                # inside each SAGEConv, whereas this one combines the
                # embeddings that different edge types produce for the same
                # node type.
                aggr="sum",
            )
            self.convs.append(conv)

        self.norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            norm_dict = torch.nn.ModuleDict()
            for node_type in node_types:
                norm_dict[node_type] = LayerNorm(channels, mode="node")
            self.norms.append(norm_dict)

    def reset_parameters(self) -> None:
        r"""Resets the parameters of all convolution and normalization
        layers.
        """
        for conv in self.convs:
            conv.reset_parameters()
        for norm_dict in self.norms:
            for norm in norm_dict.values():
                norm.reset_parameters()

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Tensor],
        num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
        num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
    ) -> Dict[NodeType, Tensor]:
        r"""Runs all message passing layers.

        Args:
            x_dict (Dict[NodeType, Tensor]): The node features per node
                type.
            edge_index_dict (Dict[EdgeType, Tensor]): The edge indices per
                edge type.
            num_sampled_nodes_dict (Dict[NodeType, List[int]], optional):
                Currently unused by this implementation.
                (default: :obj:`None`)
            num_sampled_edges_dict (Dict[EdgeType, List[int]], optional):
                Currently unused by this implementation.
                (default: :obj:`None`)

        Returns:
            Dict[NodeType, Tensor]: The node embeddings per node type after
            the last layer.
        """
        for conv, norm_dict in zip(self.convs, self.norms):
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}
            x_dict = {key: x.relu() for key, x in x_dict.items()}

        return x_dict
93 changes: 93 additions & 0 deletions hybridgnn/nn/models/idgnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Any, Dict

import torch
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from hybridgnn.nn.encoder import (
DEFAULT_STYPE_ENCODER_DICT,
HeteroEncoder,
HeteroTemporalEncoder,
)
from hybridgnn.nn.models import HeteroGraphSAGE


class IDGNN(torch.nn.Module):
    r"""A model that combines a :class:`HeteroEncoder`, a
    :class:`HeteroTemporalEncoder`, a :class:`HeteroGraphSAGE` GNN and an
    :class:`MLP` head, and adds a learnable ID-awareness embedding to the
    leading rows of the entity table.

    Args:
        data (HeteroData): The graph used to read node types, edge types,
            column names and the presence of a :obj:`"time"` attribute.
        col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): The
            column statistics per node type, forwarded to the encoder.
        num_layers (int): The number of GNN layers.
        channels (int): The hidden size used by all sub-modules.
        out_channels (int): The output size of the head.
        aggr (str): The aggregation scheme forwarded to the GNN.
        norm (str): The normalization forwarded to the head.
    """
    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
    ) -> None:
        super().__init__()

        col_names_per_node_type = {}
        for node_type in data.node_types:
            col_names_per_node_type[node_type] = (
                data[node_type].tf.col_names_dict)

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict=col_names_per_node_type,
            node_to_col_stats=col_stats_dict,
            stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT,
        )

        # Only node types that carry a "time" attribute get a temporal
        # encoder.
        temporal_node_types = []
        for node_type in data.node_types:
            if "time" in data[node_type]:
                temporal_node_types.append(node_type)

        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=temporal_node_types,
            channels=channels,
        )
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )

        self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""Resets the parameters of all sub-modules."""
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        r"""Computes the head output for the nodes of :obj:`dst_table`.

        Args:
            batch (HeteroData): The mini-batch to process.
            entity_table (NodeType): The node type holding the seed nodes
                and their :obj:`seed_time`.
            dst_table (NodeType): The node type whose embeddings are passed
                to the head.

        Returns:
            Tensor: The output of the head for the nodes of
            :obj:`dst_table`.
        """
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        # Add ID-awareness to the root node. This assumes that the seed
        # nodes are the leading rows of the entity table -- TODO confirm.
        # NOTE(review): without this embedding the model is a standard GNN;
        # it was suggested in review to make ID-awareness optional through
        # an argument so that this class can be reused for that case.
        num_seed_nodes = seed_time.size(0)
        x_dict[entity_table][:num_seed_nodes] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time,
            batch.time_dict,
            batch.batch_dict,
        )
        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        x_dict = self.gnn(x_dict, batch.edge_index_dict)

        return self.head(x_dict[dst_table])
39 changes: 39 additions & 0 deletions test/nn/test_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch
from relbench.datasets.fake import FakeDataset
from relbench.modeling.graph import make_pkey_fkey_graph
from relbench.modeling.utils import get_stype_proposal
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.testing.text_embedder import HashTextEmbedder

from hybridgnn.nn.encoder import DEFAULT_STYPE_ENCODER_DICT, HeteroEncoder


def test_encoder(tmp_path):
    """Check that ``HeteroEncoder`` encodes every table of the fake dataset
    and produces embeddings of the requested size.
    """
    db = FakeDataset().get_db()
    col_to_stype_dict = get_stype_proposal(db)
    text_embedder_cfg = TextEmbedderConfig(
        text_embedder=HashTextEmbedder(8),
        batch_size=None,
    )
    data, col_stats_dict = make_pkey_fkey_graph(
        db,
        col_to_stype_dict,
        text_embedder_cfg=text_embedder_cfg,
        cache_dir=tmp_path,
    )

    node_to_col_names_dict = {}
    for node_type in data.node_types:
        node_to_col_names_dict[node_type] = data[node_type].tf.col_names_dict

    # Ensure that the full-batch model works as expected.
    encoder = HeteroEncoder(
        64,
        node_to_col_names_dict,
        col_stats_dict,
        stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT,
    )
    x_dict = encoder(data.tf_dict)

    for node_type in ('product', 'customer', 'review', 'relations'):
        assert node_type in x_dict

    expected_shapes = {
        'relations': torch.Size([20, 64]),
        'product': torch.Size([30, 64]),
    }
    for node_type, expected_shape in expected_shapes.items():
        assert x_dict[node_type].shape == expected_shape
Loading
Loading