From 7fed12dea108b30736fe7ec14d4da15879d7475e Mon Sep 17 00:00:00 2001 From: yiweny Date: Sat, 20 Jul 2024 21:39:38 +0000 Subject: [PATCH 1/3] add idgnn --- examples/gnn_link.py | 40 ----------- hybridgnn/nn/__init__.py | 5 ++ hybridgnn/nn/encoder.py | 119 +++++++++++++++++++++++++++++++ hybridgnn/nn/models/__init__.py | 7 ++ hybridgnn/nn/models/graphsage.py | 58 +++++++++++++++ hybridgnn/nn/models/idgnn.py | 88 +++++++++++++++++++++++ test/nn/test_encoder.py | 38 ++++++++++ test/nn/test_model.py | 55 ++++++++++++++ 8 files changed, 370 insertions(+), 40 deletions(-) delete mode 100644 examples/gnn_link.py create mode 100644 hybridgnn/nn/encoder.py create mode 100644 hybridgnn/nn/models/__init__.py create mode 100644 hybridgnn/nn/models/graphsage.py create mode 100644 hybridgnn/nn/models/idgnn.py create mode 100644 test/nn/test_encoder.py create mode 100644 test/nn/test_model.py diff --git a/examples/gnn_link.py b/examples/gnn_link.py deleted file mode 100644 index 1cf1ef3..0000000 --- a/examples/gnn_link.py +++ /dev/null @@ -1,40 +0,0 @@ -import argparse -import os - -import torch -from torch_geometric.seed import seed_everything - -from relbench.datasets import get_dataset -from relbench.tasks import get_task - - -parser = argparse.ArgumentParser() -parser.add_argument("--dataset", type=str, default="rel-hm") -parser.add_argument("--task", type=str, default="user-item-purchase") -parser.add_argument("--lr", type=float, default=0.001) -parser.add_argument("--epochs", type=int, default=20) -parser.add_argument("--eval_epochs_interval", type=int, default=1) -parser.add_argument("--batch_size", type=int, default=512) -parser.add_argument("--channels", type=int, default=128) -parser.add_argument("--aggr", type=str, default="sum") -parser.add_argument("--num_layers", type=int, default=2) -parser.add_argument("--num_neighbors", type=int, default=128) -parser.add_argument("--temporal_strategy", type=str, default="last") -parser.add_argument("--max_steps_per_epoch", 
type=int, default=2000) -parser.add_argument("--num_workers", type=int, default=0) -parser.add_argument("--seed", type=int, default=42) -parser.add_argument( - "--cache_dir", type=str, default=os.path.expanduser("~/.cache/relbench_examples") -) -args = parser.parse_args() - - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -if torch.cuda.is_available(): - torch.set_num_threads(1) -seed_everything(args.seed) - -dataset = get_dataset(args.dataset) -task = get_task(args.dataset, args.task, download=True) -tune_metric = "link_prediction_map" -print(task.task_type) \ No newline at end of file diff --git a/hybridgnn/nn/__init__.py b/hybridgnn/nn/__init__.py index e69de29..21b5c4c 100644 --- a/hybridgnn/nn/__init__.py +++ b/hybridgnn/nn/__init__.py @@ -0,0 +1,5 @@ +from .encoder import HeteroStypeWiseEncoder + +__all__ = classes = [ + 'HeteroStypeWiseEncoder', +] diff --git a/hybridgnn/nn/encoder.py b/hybridgnn/nn/encoder.py new file mode 100644 index 0000000..9222747 --- /dev/null +++ b/hybridgnn/nn/encoder.py @@ -0,0 +1,119 @@ +from typing import Any, Dict, List + +import torch +import torch_frame +from torch import Tensor +from torch_frame.data.stats import StatType +from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder +from torch_geometric.nn import PositionalEncoding +from torch_geometric.typing import NodeType + +DEFAULT_STYPE_ENCODER_DICT: Dict[torch_frame.stype, Any] = { + torch_frame.categorical: (torch_frame.nn.EmbeddingEncoder, {}), + torch_frame.numerical: (torch_frame.nn.LinearEncoder, {}), + torch_frame.multicategorical: ( + torch_frame.nn.MultiCategoricalEmbeddingEncoder, + {}, + ), + torch_frame.embedding: (torch_frame.nn.LinearEmbeddingEncoder, {}), + torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}), +} + + +class HeteroStypeWiseEncoder(torch.nn.Module): + r"""StypeWiseEncoder based on PyTorch Frame. + + Args: + channels (int): The output channels for each node type. 
+ node_to_col_names_dict (Dict[NodeType, Dict[torch_frame.stype, List[str]]]): # noqa + A dictionary mapping from node type to column names dictionary + compatible to PyTorch Frame. + node_to_col_stats (Dict[NodeType, Dict[str, Dict[StatType, Any]]]): + A dictionary mapping from node type to column statistics dictionary + compatible to PyTorch Frame. + stype_encoder_cls_kwargs (Dict[torch_frame.stype, Any]): + A dictionary mapping from :obj:`torch_frame.stype` object into a + tuple specifying :class:`torch_frame.nn.StypeEncoder` class and its + keyword arguments :obj:`kwargs`. + """ + def __init__( + self, + channels: int, + node_to_col_names_dict: Dict[NodeType, Dict[torch_frame.stype, + List[str]]], + node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]], + stype_encoder_cls_kwargs: Dict[torch_frame.stype, + Any] = DEFAULT_STYPE_ENCODER_DICT, + ): + super().__init__() + + self.encoders = torch.nn.ModuleDict() + + for node_type in node_to_col_names_dict.keys(): + stype_encoder_dict = { + stype: + stype_encoder_cls_kwargs[stype][0]( + **stype_encoder_cls_kwargs[stype][1]) + for stype in node_to_col_names_dict[node_type].keys() + } + + self.encoders[node_type] = StypeWiseFeatureEncoder( + out_channels=channels, + col_stats=node_to_col_stats[node_type], + col_names_dict=node_to_col_names_dict[node_type], + stype_encoder_dict=stype_encoder_dict, + ) + + def reset_parameters(self): + for encoder in self.encoders.values(): + encoder.reset_parameters() + + def forward( + self, + tf_dict: Dict[NodeType, torch_frame.TensorFrame], + ) -> Dict[NodeType, Tensor]: + x_dict = { + node_type: self.encoders[node_type](tf)[0].sum(axis=1) + for node_type, tf in tf_dict.items() + } + return x_dict + + +class HeteroTemporalEncoder(torch.nn.Module): + def __init__(self, node_types: List[NodeType], channels: int): + super().__init__() + + self.encoder_dict = torch.nn.ModuleDict({ + node_type: + PositionalEncoding(channels) + for node_type in node_types + }) + self.lin_dict = 
torch.nn.ModuleDict({ + node_type: + torch.nn.Linear(channels, channels) + for node_type in node_types + }) + + def reset_parameters(self): + for encoder in self.encoder_dict.values(): + encoder.reset_parameters() + for lin in self.lin_dict.values(): + lin.reset_parameters() + + def forward( + self, + seed_time: Tensor, + time_dict: Dict[NodeType, Tensor], + batch_dict: Dict[NodeType, Tensor], + ) -> Dict[NodeType, Tensor]: + out_dict: Dict[NodeType, Tensor] = {} + + for node_type, time in time_dict.items(): + rel_time = seed_time[batch_dict[node_type]] - time + rel_time = rel_time / (60 * 60 * 24) # Convert seconds to days. + + x = self.encoder_dict[node_type](rel_time) + x = self.lin_dict[node_type](x) + out_dict[node_type] = x + + return out_dict diff --git a/hybridgnn/nn/models/__init__.py b/hybridgnn/nn/models/__init__.py new file mode 100644 index 0000000..bdb0fe7 --- /dev/null +++ b/hybridgnn/nn/models/__init__.py @@ -0,0 +1,7 @@ +from .graphsage import HeteroGraphSAGE +from .idgnn import IDGNN + +__all__ = classes = [ + 'HeteroGraphSAGE', + 'IDGNN', +] diff --git a/hybridgnn/nn/models/graphsage.py b/hybridgnn/nn/models/graphsage.py new file mode 100644 index 0000000..a17b66d --- /dev/null +++ b/hybridgnn/nn/models/graphsage.py @@ -0,0 +1,58 @@ +from typing import Dict, List, Optional + +import torch +from torch import Tensor +from torch_geometric.nn import HeteroConv, LayerNorm, SAGEConv +from torch_geometric.typing import EdgeType, NodeType + + +class HeteroGraphSAGE(torch.nn.Module): + def __init__( + self, + node_types: List[NodeType], + edge_types: List[EdgeType], + channels: int, + aggr: str = "mean", + num_layers: int = 2, + ): + super().__init__() + + self.convs = torch.nn.ModuleList() + for _ in range(num_layers): + conv = HeteroConv( + { + edge_type: SAGEConv( + (channels, channels), channels, aggr=aggr) + for edge_type in edge_types + }, + aggr="sum", + ) + self.convs.append(conv) + + self.norms = torch.nn.ModuleList() + for _ in 
range(num_layers): + norm_dict = torch.nn.ModuleDict() + for node_type in node_types: + norm_dict[node_type] = LayerNorm(channels, mode="node") + self.norms.append(norm_dict) + + def reset_parameters(self): + for conv in self.convs: + conv.reset_parameters() + for norm_dict in self.norms: + for norm in norm_dict.values(): + norm.reset_parameters() + + def forward( + self, + x_dict: Dict[NodeType, Tensor], + edge_index_dict: Dict[NodeType, Tensor], + num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None, + num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None, + ) -> Dict[NodeType, Tensor]: + for i, (conv, norm_dict) in enumerate(zip(self.convs, self.norms)): + x_dict = conv(x_dict, edge_index_dict) + x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()} + x_dict = {key: x.relu() for key, x in x_dict.items()} + + return x_dict diff --git a/hybridgnn/nn/models/idgnn.py b/hybridgnn/nn/models/idgnn.py new file mode 100644 index 0000000..fe8c952 --- /dev/null +++ b/hybridgnn/nn/models/idgnn.py @@ -0,0 +1,88 @@ +from typing import Any, Dict + +import torch +from torch import Tensor +from torch_frame.data.stats import StatType +from torch_geometric.data import HeteroData +from torch_geometric.nn import MLP +from torch_geometric.typing import NodeType + +from hybridgnn.nn.encoder import HeteroStypeWiseEncoder, HeteroTemporalEncoder +from hybridgnn.nn.models import HeteroGraphSAGE + + +class IDGNN(torch.nn.Module): + def __init__( + self, + data: HeteroData, + col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]], + num_layers: int, + channels: int, + out_channels: int, + aggr: str, + norm: str, + ): + super().__init__() + + self.encoder = HeteroStypeWiseEncoder( + channels=channels, + node_to_col_names_dict={ + node_type: data[node_type].tf.col_names_dict + for node_type in data.node_types + }, + node_to_col_stats=col_stats_dict, + ) + self.temporal_encoder = HeteroTemporalEncoder( + node_types=[ + node_type for node_type in 
data.node_types + if "time" in data[node_type] + ], + channels=channels, + ) + self.gnn = HeteroGraphSAGE( + node_types=data.node_types, + edge_types=data.edge_types, + channels=channels, + aggr=aggr, + num_layers=num_layers, + ) + self.head = MLP( + channels, + out_channels=out_channels, + norm=norm, + num_layers=1, + ) + + self.id_awareness_emb = torch.nn.Embedding(1, channels) + self.reset_parameters() + + def reset_parameters(self): + self.encoder.reset_parameters() + self.temporal_encoder.reset_parameters() + self.gnn.reset_parameters() + self.head.reset_parameters() + self.id_awareness_emb.reset_parameters() + + def forward( + self, + batch: HeteroData, + entity_table: NodeType, + dst_table: NodeType, + ) -> Tensor: + seed_time = batch[entity_table].seed_time + x_dict = self.encoder(batch.tf_dict) + # Add ID-awareness to the root node + x_dict[entity_table][:seed_time.size(0 + )] += self.id_awareness_emb.weight + rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict, + batch.batch_dict) + + for node_type, rel_time in rel_time_dict.items(): + x_dict[node_type] = x_dict[node_type] + rel_time + + x_dict = self.gnn( + x_dict, + batch.edge_index_dict, + ) + + return self.head(x_dict[dst_table]) diff --git a/test/nn/test_encoder.py b/test/nn/test_encoder.py new file mode 100644 index 0000000..a9ff847 --- /dev/null +++ b/test/nn/test_encoder.py @@ -0,0 +1,38 @@ +import torch +from relbench.datasets.fake import FakeDataset +from relbench.modeling.graph import make_pkey_fkey_graph +from relbench.modeling.utils import get_stype_proposal +from torch_frame.config.text_embedder import TextEmbedderConfig +from torch_frame.testing.text_embedder import HashTextEmbedder + +from hybridgnn.nn.encoder import HeteroStypeWiseEncoder + + +def test_encoder(tmp_path): + dataset = FakeDataset() + + db = dataset.get_db() + data, col_stats_dict = make_pkey_fkey_graph( + db, + get_stype_proposal(db), + text_embedder_cfg=TextEmbedderConfig(text_embedder=HashTextEmbedder(8), + 
batch_size=None), + cache_dir=tmp_path, + ) + node_to_col_names_dict = { + node_type: data[node_type].tf.col_names_dict + for node_type in data.node_types + } + + # Ensure that full-batch model works as expected ########################## + + encoder = HeteroStypeWiseEncoder(64, node_to_col_names_dict, + col_stats_dict) + + x_dict = encoder(data.tf_dict) + assert 'product' in x_dict.keys() + assert 'customer' in x_dict.keys() + assert 'review' in x_dict.keys() + assert 'relations' in x_dict.keys() + assert x_dict['relations'].shape == torch.Size([20, 64]) + assert x_dict['product'].shape == torch.Size([30, 64]) diff --git a/test/nn/test_model.py b/test/nn/test_model.py new file mode 100644 index 0000000..a566a91 --- /dev/null +++ b/test/nn/test_model.py @@ -0,0 +1,55 @@ +from relbench.base.task_base import TaskType +from relbench.datasets.fake import FakeDataset +from relbench.modeling.graph import ( + get_link_train_table_input, + make_pkey_fkey_graph, +) +from relbench.modeling.utils import get_stype_proposal +from relbench.tasks.amazon import UserItemPurchaseTask +from torch_frame.config.text_embedder import TextEmbedderConfig +from torch_frame.testing.text_embedder import HashTextEmbedder +from torch_geometric.loader import NeighborLoader + +from hybridgnn.nn.models import IDGNN + + +def test_idgnn(tmp_path): + dataset = FakeDataset() + + db = dataset.get_db() + data, col_stats_dict = make_pkey_fkey_graph( + db, + get_stype_proposal(db), + text_embedder_cfg=TextEmbedderConfig(text_embedder=HashTextEmbedder(8), + batch_size=None), + cache_dir=tmp_path, + ) + task = UserItemPurchaseTask(dataset) + assert task.task_type == TaskType.LINK_PREDICTION + + train_table = task.get_table("train") + + train_table = task.get_table("train") + train_table_input = get_link_train_table_input(train_table, task) + batch_size = 16 + train_loader = NeighborLoader( + data, + num_neighbors=[128, 128], + time_attr="time", + input_nodes=train_table_input.src_nodes, + 
input_time=train_table_input.src_time, + subgraph_type="bidirectional", + batch_size=batch_size, + temporal_strategy='last', + shuffle=True, + ) + + batch = next(iter(train_loader)) + + assert len(batch[task.dst_entity_table].batch) > 0 + model = IDGNN(data=data, col_stats_dict=col_stats_dict, num_layers=2, + channels=64, out_channels=1, aggr="sum", norm="layer_norm") + model.train() + + out = model(batch, task.src_entity_table, task.dst_entity_table).flatten() + assert len(out) == len(batch[task.dst_entity_table].n_id) From a01f3bd47230a93454151e4ffdf342b4227485ec Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 22 Jul 2024 20:37:30 +0000 Subject: [PATCH 2/3] nit --- hybridgnn/nn/encoder.py | 10 +++++----- hybridgnn/nn/models/graphsage.py | 4 ++-- hybridgnn/nn/models/idgnn.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/hybridgnn/nn/encoder.py b/hybridgnn/nn/encoder.py index 9222747..69ca67d 100644 --- a/hybridgnn/nn/encoder.py +++ b/hybridgnn/nn/encoder.py @@ -25,7 +25,7 @@ class HeteroStypeWiseEncoder(torch.nn.Module): Args: channels (int): The output channels for each node type. - node_to_col_names_dict (Dict[NodeType, Dict[torch_frame.stype, List[str]]]): # noqa + node_to_col_names_dict (Dict[NodeType, Dict[torch_frame.stype, List[str]]]): # noqa: E501 A dictionary mapping from node type to column names dictionary compatible to PyTorch Frame. 
node_to_col_stats (Dict[NodeType, Dict[str, Dict[StatType, Any]]]): @@ -44,7 +44,7 @@ def __init__( node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]], stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any] = DEFAULT_STYPE_ENCODER_DICT, - ): + ) -> None: super().__init__() self.encoders = torch.nn.ModuleDict() @@ -64,7 +64,7 @@ def __init__( stype_encoder_dict=stype_encoder_dict, ) - def reset_parameters(self): + def reset_parameters(self) -> None: for encoder in self.encoders.values(): encoder.reset_parameters() @@ -80,7 +80,7 @@ def forward( class HeteroTemporalEncoder(torch.nn.Module): - def __init__(self, node_types: List[NodeType], channels: int): + def __init__(self, node_types: List[NodeType], channels: int) -> None: super().__init__() self.encoder_dict = torch.nn.ModuleDict({ @@ -94,7 +94,7 @@ def __init__(self, node_types: List[NodeType], channels: int): for node_type in node_types }) - def reset_parameters(self): + def reset_parameters(self) -> None: for encoder in self.encoder_dict.values(): encoder.reset_parameters() for lin in self.lin_dict.values(): diff --git a/hybridgnn/nn/models/graphsage.py b/hybridgnn/nn/models/graphsage.py index a17b66d..c1bee59 100644 --- a/hybridgnn/nn/models/graphsage.py +++ b/hybridgnn/nn/models/graphsage.py @@ -14,7 +14,7 @@ def __init__( channels: int, aggr: str = "mean", num_layers: int = 2, - ): + ) -> None: super().__init__() self.convs = torch.nn.ModuleList() @@ -36,7 +36,7 @@ def __init__( norm_dict[node_type] = LayerNorm(channels, mode="node") self.norms.append(norm_dict) - def reset_parameters(self): + def reset_parameters(self) -> None: for conv in self.convs: conv.reset_parameters() for norm_dict in self.norms: diff --git a/hybridgnn/nn/models/idgnn.py b/hybridgnn/nn/models/idgnn.py index fe8c952..50be2fd 100644 --- a/hybridgnn/nn/models/idgnn.py +++ b/hybridgnn/nn/models/idgnn.py @@ -21,7 +21,7 @@ def __init__( out_channels: int, aggr: str, norm: str, - ): + ) -> None: super().__init__() 
self.encoder = HeteroStypeWiseEncoder( @@ -56,7 +56,7 @@ def __init__( self.id_awareness_emb = torch.nn.Embedding(1, channels) self.reset_parameters() - def reset_parameters(self): + def reset_parameters(self) -> None: self.encoder.reset_parameters() self.temporal_encoder.reset_parameters() self.gnn.reset_parameters() From a2c6a27e752d3144b70a6c3f2df35a76bb472af9 Mon Sep 17 00:00:00 2001 From: yiweny Date: Mon, 22 Jul 2024 23:02:35 +0000 Subject: [PATCH 3/3] fix code based on review comments --- hybridgnn/nn/__init__.py | 4 ++-- hybridgnn/nn/encoder.py | 31 +++++++++++++++++++++++-------- hybridgnn/nn/models/idgnn.py | 9 +++++++-- test/nn/test_encoder.py | 7 ++++--- test/nn/test_model.py | 2 -- 5 files changed, 36 insertions(+), 17 deletions(-) diff --git a/hybridgnn/nn/__init__.py b/hybridgnn/nn/__init__.py index 21b5c4c..e07b8ed 100644 --- a/hybridgnn/nn/__init__.py +++ b/hybridgnn/nn/__init__.py @@ -1,5 +1,5 @@ -from .encoder import HeteroStypeWiseEncoder +from .encoder import HeteroEncoder __all__ = classes = [ - 'HeteroStypeWiseEncoder', + 'HeteroEncoder', ] diff --git a/hybridgnn/nn/encoder.py b/hybridgnn/nn/encoder.py index 69ca67d..022b30a 100644 --- a/hybridgnn/nn/encoder.py +++ b/hybridgnn/nn/encoder.py @@ -4,7 +4,7 @@ import torch_frame from torch import Tensor from torch_frame.data.stats import StatType -from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder +from torch_frame.nn.models import ResNet from torch_geometric.nn import PositionalEncoding from torch_geometric.typing import NodeType @@ -18,10 +18,12 @@ torch_frame.embedding: (torch_frame.nn.LinearEmbeddingEncoder, {}), torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}), } +SECONDS_IN_A_DAY = 60 * 60 * 24 -class HeteroStypeWiseEncoder(torch.nn.Module): - r"""StypeWiseEncoder based on PyTorch Frame. +class HeteroEncoder(torch.nn.Module): + r"""HeteroEncoder is a simple encoder to encode multi-modal + data from different node types.
Args: channels (int): The output channels for each node type. @@ -35,6 +37,14 @@ class HeteroStypeWiseEncoder(torch.nn.Module): A dictionary mapping from :obj:`torch_frame.stype` object into a tuple specifying :class:`torch_frame.nn.StypeEncoder` class and its keyword arguments :obj:`kwargs`. + torch_frame_model_cls: Model class for PyTorch Frame. The class object + takes :class:`TensorFrame` object as input and outputs + :obj:`channels`-dimensional embeddings. Default to + :class:`torch_frame.nn.ResNet`. + torch_frame_model_kwargs (Dict[str, Any]): Keyword arguments for + :class:`torch_frame_model_cls` class. Default keyword argument is + set specific for :class:`torch_frame.nn.ResNet`. Expect it to + be changed for different :class:`torch_frame_model_cls`. """ def __init__( self, @@ -42,8 +52,12 @@ def __init__( node_to_col_names_dict: Dict[NodeType, Dict[torch_frame.stype, List[str]]], node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]], - stype_encoder_cls_kwargs: Dict[torch_frame.stype, - Any] = DEFAULT_STYPE_ENCODER_DICT, + stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any], + torch_frame_model_cls=ResNet, + torch_frame_model_kwargs: Dict[str, Any] = { + "channels": 128, + "num_layers": 4, + }, ) -> None: super().__init__() @@ -57,7 +71,8 @@ def __init__( for stype in node_to_col_names_dict[node_type].keys() } - self.encoders[node_type] = StypeWiseFeatureEncoder( + self.encoders[node_type] = torch_frame_model_cls( + **torch_frame_model_kwargs, out_channels=channels, col_stats=node_to_col_stats[node_type], col_names_dict=node_to_col_names_dict[node_type], @@ -73,7 +88,7 @@ def forward( tf_dict: Dict[NodeType, torch_frame.TensorFrame], ) -> Dict[NodeType, Tensor]: x_dict = { - node_type: self.encoders[node_type](tf)[0].sum(axis=1) + node_type: self.encoders[node_type](tf) for node_type, tf in tf_dict.items() } return x_dict @@ -110,7 +125,7 @@ def forward( for node_type, time in time_dict.items(): rel_time = seed_time[batch_dict[node_type]] - 
time - rel_time = rel_time / (60 * 60 * 24) # Convert seconds to days. + rel_time = rel_time / SECONDS_IN_A_DAY x = self.encoder_dict[node_type](rel_time) x = self.lin_dict[node_type](x) diff --git a/hybridgnn/nn/models/idgnn.py b/hybridgnn/nn/models/idgnn.py index 50be2fd..98aced6 100644 --- a/hybridgnn/nn/models/idgnn.py +++ b/hybridgnn/nn/models/idgnn.py @@ -7,7 +7,11 @@ from torch_geometric.nn import MLP from torch_geometric.typing import NodeType -from hybridgnn.nn.encoder import HeteroStypeWiseEncoder, HeteroTemporalEncoder +from hybridgnn.nn.encoder import ( + DEFAULT_STYPE_ENCODER_DICT, + HeteroEncoder, + HeteroTemporalEncoder, +) from hybridgnn.nn.models import HeteroGraphSAGE @@ -24,13 +28,14 @@ def __init__( ) -> None: super().__init__() - self.encoder = HeteroStypeWiseEncoder( + self.encoder = HeteroEncoder( channels=channels, node_to_col_names_dict={ node_type: data[node_type].tf.col_names_dict for node_type in data.node_types }, node_to_col_stats=col_stats_dict, + stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT, ) self.temporal_encoder = HeteroTemporalEncoder( node_types=[ diff --git a/test/nn/test_encoder.py b/test/nn/test_encoder.py index a9ff847..28609d8 100644 --- a/test/nn/test_encoder.py +++ b/test/nn/test_encoder.py @@ -5,7 +5,7 @@ from torch_frame.config.text_embedder import TextEmbedderConfig from torch_frame.testing.text_embedder import HashTextEmbedder -from hybridgnn.nn.encoder import HeteroStypeWiseEncoder +from hybridgnn.nn.encoder import DEFAULT_STYPE_ENCODER_DICT, HeteroEncoder def test_encoder(tmp_path): @@ -26,8 +26,9 @@ def test_encoder(tmp_path): # Ensure that full-batch model works as expected ########################## - encoder = HeteroStypeWiseEncoder(64, node_to_col_names_dict, - col_stats_dict) + encoder = HeteroEncoder( + 64, node_to_col_names_dict, col_stats_dict, + stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT) x_dict = encoder(data.tf_dict) assert 'product' in x_dict.keys() diff --git a/test/nn/test_model.py 
b/test/nn/test_model.py index a566a91..a2a82cb 100644 --- a/test/nn/test_model.py +++ b/test/nn/test_model.py @@ -27,8 +27,6 @@ def test_idgnn(tmp_path): task = UserItemPurchaseTask(dataset) assert task.task_type == TaskType.LINK_PREDICTION - train_table = task.get_table("train") - train_table = task.get_table("train") train_table_input = get_link_train_table_input(train_table, task) batch_size = 16