
Update multi-task evaluation logic to avoid information leakage issue in lp and nfeat reconstruct task evaluation. (#871)

*Issue #, if available:*
#789 

*Description of changes:*
Previously, in the eval() function of GSgnnMultiTaskLearningTrainer,
both the link prediction and node feature reconstruction tasks used node
embeddings computed over the entire graph. This causes test-edge leakage
for link prediction tasks and target-node feature leakage for node
feature reconstruction tasks. This PR fixes the issue.
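
The shape of the fix, as a hedged sketch (`gen_embs` and `compute_node_embs` are illustrative stand-ins for the trainer's real embedding-generation path, not the exact GraphStorm API):

```python
from graphstorm.model import gen_emb_for_nfeat_reconstruct

# Hedged sketch: regenerate embeddings per task in eval() instead of
# sharing one set computed over the entire graph.
def gen_embs(skip_last_self_loop=False):
    # compute_node_embs is a hypothetical helper; the real trainer uses
    # its own mini-batch/full-graph inference routine here.
    return compute_node_embs(model, data,
                             skip_last_self_loop=skip_last_self_loop)

# Node feature reconstruction: drop the last-layer self-loop (when the
# encoder supports it) so a node's embedding is not computed from the
# very feature it has to reconstruct.
nfeat_embs = gen_emb_for_nfeat_reconstruct(model, gen_embs)

# Link prediction likewise needs embeddings whose message passing
# excludes the validation/test edges being scored; the trainer's eval
# path recomputes them instead of reusing the full-graph embeddings.
lp_embs = gen_embs(skip_last_self_loop=False)
```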


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
classicsong and Xiang Song authored Jun 13, 2024
1 parent 71bc404 commit 0af2213
Showing 8 changed files with 655 additions and 103 deletions.
python/graphstorm/model/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -37,7 +37,8 @@
                         run_lp_mini_batch_predict)
 from .multitask_gnn import (GSgnnMultiTaskModelInterface,
                             GSgnnMultiTaskSharedEncoderModel)
-from .multitask_gnn import multi_task_mini_batch_predict
+from .multitask_gnn import (multi_task_mini_batch_predict,
+                            gen_emb_for_nfeat_reconstruct)
 from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer
 from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer
 from .sage_encoder import SAGEEncoder, SAGEConv
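
With this re-export, downstream code can import the new helper from the package root alongside the existing predict utility:

```python
from graphstorm.model import (multi_task_mini_batch_predict,
                              gen_emb_for_nfeat_reconstruct)
```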
python/graphstorm/model/multitask_gnn.py (69 changes: 61 additions & 8 deletions)
@@ -381,17 +381,22 @@ def predict(self, task_id, mini_batch, return_proba=False):
             raise TypeError(f"Unknow task type {task_type}")
 
 def multi_task_mini_batch_predict(
-    model, emb, loader, device, return_proba=True, return_label=False):
-    """ conduct mini batch prediction on multiple tasks
+    model, emb, dataloaders, task_infos, device, return_proba=True, return_label=False):
+    """ conduct mini batch prediction on multiple tasks.
 
+        The task infos are passed in as task_infos.
+        The task dataloaders are passed in as dataloaders.
+
     Parameters
     ----------
     model: GSgnnMultiTaskModelInterface, GSgnnModel
         Multi-task learning model
     emb : dict of Tensor
         The GNN embeddings
-    loader: GSgnnMultiTaskDataLoader
-        The mini-batch dataloader.
+    dataloaders: list
+        List of val or test dataloaders.
+    task_infos: list
+        List of task info
     device: th.device
         Device used to compute test scores.
     return_proba: bool
@@ -403,8 +408,6 @@
     -------
     dict: prediction results of each task
     """
-    dataloaders = loader.dataloaders
-    task_infos = loader.task_infos
     task_decoders = model.task_decoders
     res = {}
     with th.no_grad():
@@ -457,7 +460,7 @@
                     etype = list(preds.keys())[0]
                     res[task_info.task_id] = (preds[etype], labels[etype] \
                         if labels is not None else None)
-            elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION:
+            elif task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
                 if dataloader is None:
                     # In cases when there is no validation or test set.
                     res[task_info.task_id] = None
@@ -466,6 +469,56 @@
                     ranking = run_lp_mini_batch_predict(decoder, emb, dataloader, device)
                     res[task_info.task_id] = ranking
             else:
-                raise TypeError(f"Unknown task {task_info}")
+                raise TypeError(f"Unsupported task {task_info}")
 
     return res
+
+def gen_emb_for_nfeat_reconstruct(model, gen_embs):
+    """ Generate node embeddings for node feature reconstruction.
+
+        In theory, we should skip the self-loop of the last GNN layer.
+        However, there are some exceptions. This function handles
+        those exceptions.
+
+    Parameters
+    ----------
+    model: GSgnnMultiTaskSharedEncoderModel
+        Multi-task model
+    gen_embs: func
+        The function used to generate node embeddings.
+        It should accept a bool flag indicating whether
+        the last GNN layer self-loop should be removed.
+
+    Return
+    ------
+    embs: node embeddings
+    """
+    if isinstance(model.gnn_encoder, GSgnnGNNEncoderInterface):
+        if model.has_sparse_params():
+            # When there are learnable embeddings, we can not
+            # simply skip the last layer self-loop.
+            # Keep the self-loop and print a warning;
+            # we will use the computed embs directly.
+            logging.warning("When doing %s inference, we need to "
+                            "avoid adding self loop in the last GNN layer "
+                            "to avoid the potential node "
+                            "feature leakage issue. "
+                            "When there are learnable embeddings on "
+                            "nodes, GraphStorm can not automatically "
+                            "skip the last layer self-loop. "
+                            "Please set use_self_loop to False.",
+                            BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
+            embs = gen_embs(skip_last_self_loop=False)
+        else:
+            # Skip the self-loop of the last layer to
+            # avoid information leakage.
+            embs = gen_embs(skip_last_self_loop=True)
+    else:
+        # We will use the computed embs directly.
+        logging.warning("The gnn encoder %s does not support skipping "
+                        "the last self-loop operation "
+                        "(skip_last_selfloop). There is a potential "
+                        "node feature leakage risk when doing %s training.",
+                        type(model.gnn_encoder),
+                        BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
+        embs = gen_embs(skip_last_self_loop=False)
+    return embs
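
A hedged usage sketch of the new helper together with the revised multi_task_mini_batch_predict signature; `compute_embs` is a hypothetical stand-in for the trainer's embedding routine, and the encoder calls assume the `skip_last_selfloop`/`reset_last_selfloop` interface the warning message refers to:

```python
def gen_embs(skip_last_self_loop=False):
    # compute_embs stands in for the trainer's embedding-generation
    # routine; only the skip_last_self_loop handling is the point here.
    if skip_last_self_loop:
        model.gnn_encoder.skip_last_selfloop()   # assumed interface
        embs = compute_embs(model)
        model.gnn_encoder.reset_last_selfloop()  # assumed interface
        return embs
    return compute_embs(model)

emb = gen_emb_for_nfeat_reconstruct(model, gen_embs)

# The revised signature takes the per-task dataloaders and task infos
# directly (previously unpacked from a GSgnnMultiTaskDataLoader inside
# the function), so callers can pair each task with leakage-safe
# embeddings.
res = multi_task_mini_batch_predict(
    model, emb, val_dataloader.dataloaders, val_dataloader.task_infos,
    device, return_proba=False, return_label=True)
```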
python/graphstorm/run/gsgnn_mt/gsgnn_mt.py (8 changes: 6 additions & 2 deletions)

@@ -444,8 +444,12 @@ def main(config_args):
             logging.warning("The training data do not have validation set.")
         if test_loader is None:
             logging.warning("The training data do not have test set.")
-        task_evaluators[task.task_id] = \
-            create_evaluator(task)
+
+        if val_loader is None and test_loader is None:
+            logging.warning("Task %s does not have validation and test sets.", task.task_id)
+        else:
+            task_evaluators[task.task_id] = \
+                create_evaluator(task)
 
     train_dataloader = GSgnnMultiTaskDataLoader(train_data, tasks, train_dataloaders)
     val_dataloader = GSgnnMultiTaskDataLoader(train_data, tasks, val_dataloaders)
