Skip to content

Commit

Permalink
add idgnn
Browse files Browse the repository at this point in the history
  • Loading branch information
yiweny committed Jul 20, 2024
1 parent 30968e4 commit 7fed12d
Show file tree
Hide file tree
Showing 8 changed files with 370 additions and 40 deletions.
40 changes: 0 additions & 40 deletions examples/gnn_link.py

This file was deleted.

5 changes: 5 additions & 0 deletions hybridgnn/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .encoder import HeteroStypeWiseEncoder

__all__ = classes = [
'HeteroStypeWiseEncoder',
]
119 changes: 119 additions & 0 deletions hybridgnn/nn/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from typing import Any, Dict, List

import torch
import torch_frame
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder
from torch_geometric.nn import PositionalEncoding
from torch_geometric.typing import NodeType

DEFAULT_STYPE_ENCODER_DICT: Dict[torch_frame.stype, Any] = {
torch_frame.categorical: (torch_frame.nn.EmbeddingEncoder, {}),
torch_frame.numerical: (torch_frame.nn.LinearEncoder, {}),
torch_frame.multicategorical: (
torch_frame.nn.MultiCategoricalEmbeddingEncoder,
{},
),
torch_frame.embedding: (torch_frame.nn.LinearEmbeddingEncoder, {}),
torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}),
}


class HeteroStypeWiseEncoder(torch.nn.Module):
r"""StypeWiseEncoder based on PyTorch Frame.
Args:
channels (int): The output channels for each node type.
node_to_col_names_dict (Dict[NodeType, Dict[torch_frame.stype, List[str]]]): # noqa
A dictionary mapping from node type to column names dictionary
compatible to PyTorch Frame.
node_to_col_stats (Dict[NodeType, Dict[str, Dict[StatType, Any]]]):
A dictionary mapping from node type to column statistics dictionary
compatible to PyTorch Frame.
stype_encoder_cls_kwargs (Dict[torch_frame.stype, Any]):
A dictionary mapping from :obj:`torch_frame.stype` object into a
tuple specifying :class:`torch_frame.nn.StypeEncoder` class and its
keyword arguments :obj:`kwargs`.
"""
def __init__(
self,
channels: int,
node_to_col_names_dict: Dict[NodeType, Dict[torch_frame.stype,
List[str]]],
node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]],
stype_encoder_cls_kwargs: Dict[torch_frame.stype,
Any] = DEFAULT_STYPE_ENCODER_DICT,
):
super().__init__()

self.encoders = torch.nn.ModuleDict()

for node_type in node_to_col_names_dict.keys():
stype_encoder_dict = {
stype:
stype_encoder_cls_kwargs[stype][0](
**stype_encoder_cls_kwargs[stype][1])
for stype in node_to_col_names_dict[node_type].keys()
}

self.encoders[node_type] = StypeWiseFeatureEncoder(
out_channels=channels,
col_stats=node_to_col_stats[node_type],
col_names_dict=node_to_col_names_dict[node_type],
stype_encoder_dict=stype_encoder_dict,
)

def reset_parameters(self):
for encoder in self.encoders.values():
encoder.reset_parameters()

def forward(
self,
tf_dict: Dict[NodeType, torch_frame.TensorFrame],
) -> Dict[NodeType, Tensor]:
x_dict = {
node_type: self.encoders[node_type](tf)[0].sum(axis=1)
for node_type, tf in tf_dict.items()
}
return x_dict


class HeteroTemporalEncoder(torch.nn.Module):
def __init__(self, node_types: List[NodeType], channels: int):
super().__init__()

self.encoder_dict = torch.nn.ModuleDict({
node_type:
PositionalEncoding(channels)
for node_type in node_types
})
self.lin_dict = torch.nn.ModuleDict({
node_type:
torch.nn.Linear(channels, channels)
for node_type in node_types
})

def reset_parameters(self):
for encoder in self.encoder_dict.values():
encoder.reset_parameters()
for lin in self.lin_dict.values():
lin.reset_parameters()

def forward(
self,
seed_time: Tensor,
time_dict: Dict[NodeType, Tensor],
batch_dict: Dict[NodeType, Tensor],
) -> Dict[NodeType, Tensor]:
out_dict: Dict[NodeType, Tensor] = {}

for node_type, time in time_dict.items():
rel_time = seed_time[batch_dict[node_type]] - time
rel_time = rel_time / (60 * 60 * 24) # Convert seconds to days.

x = self.encoder_dict[node_type](rel_time)
x = self.lin_dict[node_type](x)
out_dict[node_type] = x

return out_dict
7 changes: 7 additions & 0 deletions hybridgnn/nn/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .graphsage import HeteroGraphSAGE
from .idgnn import IDGNN

__all__ = classes = [
'HeteroGraphSAGE',
'IDGNN',
]
58 changes: 58 additions & 0 deletions hybridgnn/nn/models/graphsage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Dict, List, Optional

import torch
from torch import Tensor
from torch_geometric.nn import HeteroConv, LayerNorm, SAGEConv
from torch_geometric.typing import EdgeType, NodeType


class HeteroGraphSAGE(torch.nn.Module):
def __init__(
self,
node_types: List[NodeType],
edge_types: List[EdgeType],
channels: int,
aggr: str = "mean",
num_layers: int = 2,
):
super().__init__()

self.convs = torch.nn.ModuleList()
for _ in range(num_layers):
conv = HeteroConv(
{
edge_type: SAGEConv(
(channels, channels), channels, aggr=aggr)
for edge_type in edge_types
},
aggr="sum",
)
self.convs.append(conv)

self.norms = torch.nn.ModuleList()
for _ in range(num_layers):
norm_dict = torch.nn.ModuleDict()
for node_type in node_types:
norm_dict[node_type] = LayerNorm(channels, mode="node")
self.norms.append(norm_dict)

def reset_parameters(self):
for conv in self.convs:
conv.reset_parameters()
for norm_dict in self.norms:
for norm in norm_dict.values():
norm.reset_parameters()

def forward(
self,
x_dict: Dict[NodeType, Tensor],
edge_index_dict: Dict[NodeType, Tensor],
num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
) -> Dict[NodeType, Tensor]:
for i, (conv, norm_dict) in enumerate(zip(self.convs, self.norms)):
x_dict = conv(x_dict, edge_index_dict)
x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}
x_dict = {key: x.relu() for key, x in x_dict.items()}

return x_dict
88 changes: 88 additions & 0 deletions hybridgnn/nn/models/idgnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from typing import Any, Dict

import torch
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from hybridgnn.nn.encoder import HeteroStypeWiseEncoder, HeteroTemporalEncoder
from hybridgnn.nn.models import HeteroGraphSAGE


class IDGNN(torch.nn.Module):
def __init__(
self,
data: HeteroData,
col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
num_layers: int,
channels: int,
out_channels: int,
aggr: str,
norm: str,
):
super().__init__()

self.encoder = HeteroStypeWiseEncoder(
channels=channels,
node_to_col_names_dict={
node_type: data[node_type].tf.col_names_dict
for node_type in data.node_types
},
node_to_col_stats=col_stats_dict,
)
self.temporal_encoder = HeteroTemporalEncoder(
node_types=[
node_type for node_type in data.node_types
if "time" in data[node_type]
],
channels=channels,
)
self.gnn = HeteroGraphSAGE(
node_types=data.node_types,
edge_types=data.edge_types,
channels=channels,
aggr=aggr,
num_layers=num_layers,
)
self.head = MLP(
channels,
out_channels=out_channels,
norm=norm,
num_layers=1,
)

self.id_awareness_emb = torch.nn.Embedding(1, channels)
self.reset_parameters()

def reset_parameters(self):
self.encoder.reset_parameters()
self.temporal_encoder.reset_parameters()
self.gnn.reset_parameters()
self.head.reset_parameters()
self.id_awareness_emb.reset_parameters()

def forward(
self,
batch: HeteroData,
entity_table: NodeType,
dst_table: NodeType,
) -> Tensor:
seed_time = batch[entity_table].seed_time
x_dict = self.encoder(batch.tf_dict)
# Add ID-awareness to the root node
x_dict[entity_table][:seed_time.size(0
)] += self.id_awareness_emb.weight
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

x_dict = self.gnn(
x_dict,
batch.edge_index_dict,
)

return self.head(x_dict[dst_table])
38 changes: 38 additions & 0 deletions test/nn/test_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch
from relbench.datasets.fake import FakeDataset
from relbench.modeling.graph import make_pkey_fkey_graph
from relbench.modeling.utils import get_stype_proposal
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.testing.text_embedder import HashTextEmbedder

from hybridgnn.nn.encoder import HeteroStypeWiseEncoder


def test_encoder(tmp_path):
dataset = FakeDataset()

db = dataset.get_db()
data, col_stats_dict = make_pkey_fkey_graph(
db,
get_stype_proposal(db),
text_embedder_cfg=TextEmbedderConfig(text_embedder=HashTextEmbedder(8),
batch_size=None),
cache_dir=tmp_path,
)
node_to_col_names_dict = {
node_type: data[node_type].tf.col_names_dict
for node_type in data.node_types
}

# Ensure that full-batch model works as expected ##########################

encoder = HeteroStypeWiseEncoder(64, node_to_col_names_dict,
col_stats_dict)

x_dict = encoder(data.tf_dict)
assert 'product' in x_dict.keys()
assert 'customer' in x_dict.keys()
assert 'review' in x_dict.keys()
assert 'relations' in x_dict.keys()
assert x_dict['relations'].shape == torch.Size([20, 64])
assert x_dict['product'].shape == torch.Size([30, 64])
Loading

0 comments on commit 7fed12d

Please sign in to comment.