Add node reconstruction
Xiang Song committed Jun 8, 2024
1 parent c79c7ce commit c62c3e3
Showing 5 changed files with 162 additions and 52 deletions.
6 changes: 6 additions & 0 deletions inference_scripts/mt_infer/ml_nc_ec_er_lp_only_infer.yaml
@@ -72,3 +72,9 @@ gsf:
reverse_edge_types_map:
- user,rating,rating-rev,movie
batch_size: 128 # will overwrite the global batch_size
- reconstruct_node_feat:
reconstruct_nfeat_name: "title"
target_ntype: "movie"
batch_size: 128
eval_metric:
- "mse"
12 changes: 11 additions & 1 deletion inference_scripts/mt_infer/ml_nc_ec_er_lp_with_mask_infer.yaml
Original file line number Diff line number Diff line change
@@ -91,4 +91,14 @@ gsf:
mask_fields:
- "train_mask_field_l"
- null # empty means there is no validation mask
- "test_mask_field_l"
- "test_mask_field_l"
- reconstruct_node_feat:
reconstruct_nfeat_name: "title"
target_ntype: "movie"
batch_size: 128
mask_fields:
- "train_mask_c0" # node classification mask 0
- "val_mask_c0"
- "test_mask_c0"
eval_metric:
- "mse"
112 changes: 91 additions & 21 deletions python/graphstorm/inference/mt_infer.py
@@ -17,13 +17,14 @@
"""
import os
import time
import logging
import torch as th

from ..config import (BUILTIN_TASK_NODE_CLASSIFICATION,
BUILTIN_TASK_NODE_REGRESSION,
BUILTIN_TASK_EDGE_CLASSIFICATION,
BUILTIN_TASK_EDGE_REGRESSION,
                      BUILTIN_TASK_LINK_PREDICTION,
BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
from .graphstorm_infer import GSInferrer
from ..model.utils import save_full_node_embeddings as save_gsgnn_embeddings
from ..model.utils import (save_node_prediction_results,
@@ -32,7 +33,9 @@
from ..model.utils import NodeIDShuffler
from ..model import do_full_graph_inference, do_mini_batch_inference
from ..model.multitask_gnn import multi_task_mini_batch_predict
from ..model.node_gnn import run_node_mini_batch_predict
from ..model.lp_gnn import run_lp_mini_batch_predict
from ..model.gnn_encoder_base import GSgnnGNNEncoderInterface

from ..model.edge_decoder import LinkPredictDistMultDecoder

@@ -51,8 +54,8 @@ class GSgnnMultiTaskLearningInferer(GSInferrer):
"""

# pylint: disable=unused-argument
    def infer(self, data,
              mt_test_loader,
save_embed_path=None,
save_prediction_path=None,
use_mini_batch_infer=False,
@@ -73,8 +76,16 @@ def infer(self, data, mt_loader,
----------
data: GSgnnData
Graph data.
        mt_test_loader: tuple of GSgnnMultiTaskDataLoader
            A tuple of mini-batch samplers for inference, in the format
            (test_dataloader, lp_test_dataloader, recon_nfeat_test_dataloader).
            The first dataloader contains the test dataloaders of all other
            tasks; the second contains the test dataloaders for link
            prediction tasks; the third contains the test dataloaders for
            node feature reconstruction tasks. Different message passing
            strategies are applied when evaluating link prediction and node
            feature reconstruction tasks, which is why their test dataloaders
            are kept separate.
save_embed_path: str
The path to save the node embeddings.
save_prediction_path: str
@@ -98,19 +109,26 @@
do_eval = self.evaluator is not None
sys_tracker.check('start inferencing')
self._model.eval()
mt_loader, lp_test_loader, recon_nfeat_test_loader = mt_test_loader

fanout = None
for task_fanout in mt_loader.fanout:
if task_fanout is not None:
fanout = task_fanout
break
        def gen_embs():
            # Generate node embeddings.
            if use_mini_batch_infer:
                embs = do_mini_batch_inference(
                    self._model, data, batch_size=infer_batch_size,
                    fanout=fanout, task_tracker=self.task_tracker)
            else:
                embs = do_full_graph_inference(
                    self._model, data, fanout=fanout,
                    task_tracker=self.task_tracker)
            return embs
        embs = gen_embs()
sys_tracker.check('compute embeddings')
device = self.device
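gen_embs is defined as a closure so that embeddings can be regenerated later under a different encoder state (see the self-loop handling below). A self-contained toy illustration of that design choice; the Encoder class is a stand-in, not GraphStorm code:

class Encoder:
    """Stand-in encoder; not GraphStorm code."""
    def __init__(self):
        self.self_loop = True

    def encode(self, x):
        # Pretend the last-layer self-loop mixes in the node's own signal.
        return x + 1 if self.self_loop else x

encoder = Encoder()

def gen_embs():
    # Reads the encoder state on every call, like the closure above.
    return encoder.encode(10)

embs = gen_embs()            # 11: self-loop on
encoder.self_loop = False    # mimics skip_last_selfloop()
embs = gen_embs()            # 10: regenerated under the new state
encoder.self_loop = True     # mimics reset_last_selfloop()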

@@ -123,28 +141,79 @@ def infer(self, data, mt_loader,
return_label=do_eval)

if lp_test_loader is not None:
            # We also need to compute test scores for link prediction tasks.
dataloaders = lp_test_loader.dataloaders
task_infos = lp_test_loader.task_infos

with th.no_grad():
for dataloader, task_info in zip(dataloaders, task_infos):
                    if dataloader is None:
                        # No test set for this task.
                        pre_results[task_info.task_id] = None
                        continue

                    if use_mini_batch_infer:
                        lp_test_embs = do_mini_batch_inference(
                            self._model, data, batch_size=infer_batch_size,
                            fanout=fanout,
                            edge_mask=task_info.task_config.train_mask,
                            task_tracker=self.task_tracker)
                    else:
                        lp_test_embs = do_full_graph_inference(
                            self._model, data, fanout=fanout,
                            edge_mask=task_info.task_config.train_mask,
                            task_tracker=self.task_tracker)
decoder = self._model.task_decoders[task_info.task_id]
ranking = run_lp_mini_batch_predict(decoder, lp_test_embs, dataloader, device)
pre_results[task_info.task_id] = ranking
if recon_nfeat_test_loader is not None:
# We also need to compute test scores for node feature reconstruction tasks.
            dataloaders = recon_nfeat_test_loader.dataloaders
            task_infos = recon_nfeat_test_loader.task_infos

with th.no_grad():
for dataloader, task_info in zip(dataloaders, task_infos):
                    if dataloader is None:
                        # No test set for this task.
                        pre_results[task_info.task_id] = (None, None)
                        continue

                    if isinstance(self._model.gnn_encoder, GSgnnGNNEncoderInterface):
                        if self._model.has_sparse_params():
                            # When there are learnable embeddings, we cannot
                            # simply skip the self-loop in the last GNN layer.
                            # Keep the self-loop, print a warning and use
                            # the computed embs directly.
                            logging.warning("When doing %s inference, we need to "
                                            "avoid adding the self-loop in the last "
                                            "GNN layer to avoid potential node "
                                            "feature leakage. When there are "
                                            "learnable embeddings on nodes, "
                                            "GraphStorm cannot automatically skip "
                                            "the last-layer self-loop. "
                                            "Please set use_self_loop to False.",
                                            BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
else:
                        # Skip the self-loop of the last layer to
                        # avoid information leakage.
self._model.gnn_encoder.skip_last_selfloop()
embs = gen_embs()
self._model.gnn_encoder.reset_last_selfloop()
                    else:
                        # We will use the computed embs directly.
                        logging.warning("The GNN encoder %s does not support "
                                        "skipping the last-layer self-loop "
                                        "(skip_last_selfloop). There is a "
                                        "potential node feature leakage risk "
                                        "when doing %s inference.",
                                        type(self._model.gnn_encoder),
                                        BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
decoder = self._model.task_decoders[task_info.task_id]
preds, labels = \
run_node_mini_batch_predict(decoder,
embs,
dataloader,
device=device,
return_proba=return_proba,
return_label=do_eval)
ntype = list(preds.keys())[0]
pre_results[task_info.task_id] = (preds[ntype], labels[ntype] \
if labels is not None else None)
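The skip/reset pair above mutates encoder state around the call to gen_embs. A hedged sketch of the same pattern wrapped in try/finally, so the state is restored even if embedding generation fails; only skip_last_selfloop and reset_last_selfloop come from this diff, the helper itself is illustrative:

def embs_without_last_selfloop(model, gen_embs):
    # Skip the last-layer self-loop, regenerate embeddings and always
    # restore the encoder state afterwards.
    model.gnn_encoder.skip_last_selfloop()
    try:
        return gen_embs()
    finally:
        model.gnn_encoder.reset_last_selfloop()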

if do_eval:
test_start = time.time()
@@ -229,6 +298,7 @@ def infer(self, data, mt_loader,

else:
                # There are no prediction results for link prediction
                # and feature reconstruction tasks.
continue
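For orientation, an illustrative picture of what pre_results holds by the time evaluation runs: rankings for link prediction tasks and (prediction, label) pairs for node feature reconstruction tasks. The task ids and tensor values below are hypothetical:

import torch as th

ranking = th.tensor([1, 5, 2])  # per-edge ranks, as produced by run_lp_mini_batch_predict
preds = th.randn(4, 16)         # reconstructed "title" features for 4 movie nodes
labels = th.randn(4, 16)        # ground-truth node features
pre_results = {
    "lp-user_rating_movie": ranking,
    "reconstruct_node_feat-movie-title": (preds, labels),
}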

# save relation embedding if any
21 changes: 5 additions & 16 deletions python/graphstorm/run/gsgnn_mt/gsgnn_mt.py
@@ -152,13 +152,13 @@ def create_task_val_dataloader(task, config, train_data):
return None
    # All tasks share the same input encoder, so the node feats must be the same.
node_feats = config.node_feat_name
# All tasks share the same GNN model, so the fanout should be the global fanout
fanout = config.eval_fanout if task_config.use_mini_batch_infer else []
if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]:
eval_ntype = task_config.eval_target_ntype \
if task_config.eval_target_ntype is not None \
else task_config.target_ntype
val_idxs = train_data.get_node_val_set(eval_ntype, mask=task_config.val_mask)
if len(val_idxs) > 0:
# TODO(xiangsx): Support construct feat
return GSgnnNodeDataLoader(train_data,
@@ -170,8 +170,6 @@ def create_task_val_dataloader(task, config, train_data):
label_field=task_config.label_field)
elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]:
val_idxs = train_data.get_edge_val_set(task_config.target_etype, mask=task_config.val_mask)
if len(val_idxs) > 0:
# TODO(xiangsx): Support construct feat
return GSgnnEdgeDataLoader(train_data,
@@ -187,8 +185,6 @@ def create_task_val_dataloader(task, config, train_data):
elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
val_idxs = train_data.get_edge_val_set(task_config.eval_etype, mask=task_config.val_mask)
dataloader_cls = gs.get_builtin_lp_eval_dataloader_class(task_config)
if len(val_idxs) > 0:
# TODO(xiangsx): Support construct feat
if task_config.eval_etypes_negative_dstnode is not None:
@@ -212,8 +208,6 @@ def create_task_val_dataloader(task, config, train_data):
if task_config.eval_target_ntype is not None \
else task_config.target_ntype
val_idxs = train_data.get_node_val_set(eval_ntype, mask=task_config.val_mask)
if len(val_idxs) > 0:
# TODO(xiangsx): Support construct feat
return GSgnnNodeDataLoader(train_data,
@@ -248,13 +242,14 @@ def create_task_test_dataloader(task, config, train_data):
return None
    # All tasks share the same input encoder, so the node feats must be the same.
node_feats = config.node_feat_name
# All tasks share the same GNN model, so the fanout should be the global fanout
fanout = config.eval_fanout if task_config.use_mini_batch_infer else []

if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]:
eval_ntype = task_config.eval_target_ntype \
if task_config.eval_target_ntype is not None \
else task_config.target_ntype
test_idxs = train_data.get_node_test_set(eval_ntype, mask=task_config.test_mask)
if len(test_idxs) > 0:
# TODO(xiangsx): Support construct feat
return GSgnnNodeDataLoader(train_data,
@@ -269,8 +264,6 @@ def create_task_test_dataloader(task, config, train_data):
test_idxs = train_data.get_edge_test_set(
task_config.target_etype,
mask=task_config.test_mask)
if len(test_idxs) > 0:
# TODO(xiangsx): Support construct feat
return GSgnnEdgeDataLoader(train_data,
@@ -286,8 +279,6 @@ def create_task_test_dataloader(task, config, train_data):
elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
        test_idxs = train_data.get_edge_test_set(task_config.eval_etype, mask=task_config.test_mask)
dataloader_cls = gs.get_builtin_lp_eval_dataloader_class(task_config)
if len(test_idxs) > 0:
# TODO(xiangsx): Support construct feat
if task_config.eval_etypes_negative_dstnode is not None:
@@ -311,8 +302,6 @@ def create_task_test_dataloader(task, config, train_data):
if task_config.eval_target_ntype is not None \
else task_config.target_ntype
test_idxs = train_data.get_node_test_set(eval_ntype, mask=task_config.test_mask)
if len(test_idxs) > 0:
# TODO(xiangsx): Support construct feat
return GSgnnNodeDataLoader(train_data,
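The call above is truncated in this view. A hedged sketch of how the reconstruction branch plausibly completes, mirroring the GSgnnNodeDataLoader calls visible elsewhere in this diff; the import path and the use of reconstruct_nfeat_name as the label field are assumptions:

from graphstorm.dataloading import GSgnnNodeDataLoader  # import path assumed

def make_recon_test_dataloader(train_data, test_idxs, fanout, node_feats, task_config):
    # Mirrors the other branches: shared node_feats and fanout, a per-task
    # batch size, and the reconstructed feature serving as the label.
    return GSgnnNodeDataLoader(train_data,
                               test_idxs,
                               fanout=fanout,
                               batch_size=task_config.batch_size,
                               node_feats=node_feats,
                               label_field=task_config.reconstruct_nfeat_name,
                               train_task=False)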