Commit 7eba2e3

Update multi-task evaluation logic to avoid information leakage issue in lp and nfeat reconstruct task evaluation

Previously, in the eval() function of GSgnnMultiTaskLearningTrainer, both the link prediction and
node feature reconstruction tasks used node embeddings computed on the entire graph. This causes
test-edge leakage for link prediction tasks and node-feature leakage for the target nodes of node
feature reconstruction tasks. This PR fixes the issue.
Xiang Song committed Jun 11, 2024
1 parent e1d128e commit 7eba2e3
Showing 3 changed files with 285 additions and 35 deletions.
4 changes: 3 additions & 1 deletion python/graphstorm/model/__init__.py
@@ -37,7 +37,9 @@
run_lp_mini_batch_predict)
from .multitask_gnn import (GSgnnMultiTaskModelInterface,
GSgnnMultiTaskSharedEncoderModel)
from .multitask_gnn import multi_task_mini_batch_predict
from .multitask_gnn import (multi_prediction_task_mini_batch_predict,
multi_nfeat_recon_task_mini_batch_predict,
gen_emb_for_nfeat_reconstruct)
from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer
from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer
from .sage_encoder import SAGEEncoder, SAGEConv
116 changes: 115 additions & 1 deletion python/graphstorm/model/multitask_gnn.py
@@ -32,6 +32,7 @@
from .node_gnn import run_node_mini_batch_predict
from .edge_gnn import run_edge_mini_batch_predict
from .lp_gnn import run_lp_mini_batch_predict
from ..utils import is_distributed


class GSgnnMultiTaskModelInterface:
@@ -380,7 +381,7 @@ def predict(self, task_id, mini_batch, return_proba=False):
else:
raise TypeError(f"Unknown task type {task_type}")

def multi_task_mini_batch_predict(
def multi_prediction_task_mini_batch_predict(
model, emb, loader, device, return_proba=True, return_label=False):
""" conduct mini batch prediction on multiple tasks
@@ -469,3 +470,116 @@ def multi_task_mini_batch_predict(
raise TypeError(f"Unknown task {task_info}")

return res

def gen_emb_for_nfeat_reconstruct(model, gen_embs):
""" Generate node embeddings for node feature reconstruction.
In theory, we should skip the self-loop of the last GNN layer.
However, there are some exceptions. This function handles
those exceptions.
Parameters
----------
model: GSgnnMultiTaskSharedEncoderModel
Multi-task model
gen_embs: func
The function used to generate node embeddings.
It should accept a bool flag indicating whether
the last GNN layer self-loop should be removed.
Return
------
embs: node embeddings
"""
if isinstance(model.gnn_encoder, GSgnnGNNEncoderInterface):
if model.has_sparse_params():
# When there are learnable embeddings, we cannot
# simply skip the last layer self-loop.
# Keep the self-loop and print a warning;
# we will use the computed embs directly.
logging.warning("When doing %s inference, we need to "
"avoid adding the self-loop in the last GNN layer "
"to avoid the potential node "
"feature leakage issue. "
"When there are learnable embeddings on "
"nodes, GraphStorm can not automatically "
"skip the last layer self-loop. "
"Please set use_self_loop to False.",
BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
embs = gen_embs(last_self_loop=True)
else:
# Skip the self-loop of the last layer to
# avoid information leakage.
embs = gen_embs(last_self_loop=False)
else:
# we will use the computed embs directly
logging.warning("The gnn encoder %s does not support skip "
"the last self-loop operation"
"(skip_last_selfloop). There is a potential "
"node feature leakage risk when doing %s training.",
type(model.gnn_encoder),
BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
embs = gen_embs(last_self_loop=True)
return embs

def multi_nfeat_recon_task_mini_batch_predict(
model, embs,
nfeat_recon_val_loaders,
nfeat_recon_test_loaders,
task_infos,
device,
return_label=False):
""" conduct mini batch prediction on node feature
reconstruction tasks
Parameters
----------
model: GSgnnMultiTaskModelInterface, GSgnnModel
Multi-task learning model
embs : dict of Tensor
The GNN embeddings
nfeat_recon_val_loaders: list
List of validation dataloaders
nfeat_recon_test_loaders: list
List of test dataloaders
task_infos: list
List of task info
device: th.device
Device used to compute test scores.
return_label : bool
Whether or not to return labels.
Return
------
dict: Validation results
dict: Test results
"""
val_results = {}
test_results = {}
for val_loader, test_loader, task_info in \
zip(nfeat_recon_val_loaders, nfeat_recon_test_loaders, task_infos):
decoder = model.task_decoders[task_info.task_id]
if val_loader is None:
val_results[task_info.task_id] = (None, None)
else:
val_preds, val_labels = \
run_node_mini_batch_predict(decoder,
embs,
val_loader,
device=device,
return_proba=False,
return_label=return_label)
val_results[task_info.task_id] = (val_preds, val_labels)

if test_loader is None:
test_results[task_info.task_id] = (None, None)
else:
test_preds, test_labels = \
run_node_mini_batch_predict(decoder,
embs,
test_loader,
device=device,
return_proba=False,
return_label=return_label)
test_results[task_info.task_id] = (test_preds, test_labels)

return val_results, test_results
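Taken together, a hypothetical caller would wire these two new helpers roughly as follows. This is an untested sketch: `model`, `data`, the `recon_*` loader lists, and the availability of `do_full_graph_inference` plus the `skip_last_selfloop`/`reset_last_selfloop` encoder methods are assumptions of the example, not part of this diff.

```python
import torch as th
from graphstorm.model import (do_full_graph_inference,
                              gen_emb_for_nfeat_reconstruct,
                              multi_nfeat_recon_task_mini_batch_predict)

def gen_embs(last_self_loop=False):
    # gen_emb_for_nfeat_reconstruct decides which flag value is safe;
    # the callback only implements the self-loop toggle around inference.
    if not last_self_loop:
        model.gnn_encoder.skip_last_selfloop()
    embs = do_full_graph_inference(model, data, fanout=None)
    if not last_self_loop:
        model.gnn_encoder.reset_last_selfloop()
    return embs

embs = gen_emb_for_nfeat_reconstruct(model, gen_embs)
val_res, test_res = multi_nfeat_recon_task_mini_batch_predict(
    model, embs,
    recon_val_loaders, recon_test_loaders, recon_task_infos,
    device=th.device("cuda:0"),
    return_label=True)
```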
200 changes: 167 additions & 33 deletions python/graphstorm/trainer/mt_trainer.py
@@ -33,7 +33,10 @@
do_mini_batch_inference,
GSgnnModelBase, GSgnnModel,
GSgnnMultiTaskModelInterface,
multi_task_mini_batch_predict)
multi_prediction_task_mini_batch_predict,
multi_nfeat_recon_task_mini_batch_predict,
gen_emb_for_nfeat_reconstruct)
from ..model.lp_gnn import run_lp_mini_batch_predict
from .gsgnn_trainer import GSgnnTrainer

from ..utils import sys_tracker, rt_profiler, print_mem, get_rank
@@ -506,42 +509,173 @@ def eval(self, model, data, val_loader, test_loader, total_steps,
sys_tracker.check('before prediction')
model.eval()

if val_loader is None and test_loader is None:
# no need to do validation and test
# do nothing.
return None

val_dataloaders = val_loader.dataloaders \
if val_loader is not None else None
test_dataloaders = test_loader.dataloaders \
if test_loader is not None else None
task_infos = val_loader.task_infos \
if val_loader is not None else test_loader.task_infos

# All the tasks share the same GNN encoder so the fanouts are same
# for different tasks.
fanout = None
for task_fanout in val_loader.fanout:
if task_fanout is not None:
fanout = task_fanout
break
assert fanout is not None, \
"There is no validation dataloader. eval() function should not be called"
if use_mini_batch_infer:
emb = do_mini_batch_inference(model, data,
fanout=fanout,
task_tracker=self.task_tracker)
if val_loader is not None:
for task_fanout in val_loader.fanout:
if task_fanout is not None:
fanout = task_fanout
break
else:
emb = do_full_graph_inference(model, data,
fanout=fanout,
task_tracker=self.task_tracker)
sys_tracker.check('compute embeddings')

val_results = \
multi_task_mini_batch_predict(model,
emb=emb,
loader=val_loader,
device=self.device,
return_proba=return_proba,
return_label=True) \
if val_loader is not None else None

test_results = \
multi_task_mini_batch_predict(model,
emb=emb,
loader=test_loader,
device=self.device,
return_proba=return_proba,
return_label=True) \
if test_loader is not None else None
for task_fanout in test_loader.fanout:
if task_fanout is not None:
fanout = task_fanout
break
assert fanout is not None, \
"There is no validation or test dataloader. eval() should not be called."

# Node prediction and edge prediction
# do not have information leakage problem
predict_tasks = []
predict_val_loaders = []
predict_test_loaders = []
# For link prediction tasks, we need to
# exclude valid and test edges during message
# passing
lp_tasks = []
lp_val_loaders = []
lp_test_loaders = []
# For node feature reconstruction tasks,
# we need to avoid self-loop in the last
# GNN layer
nfeat_recon_tasks = []
nfeat_recon_val_loaders = []
nfeat_recon_test_loaders = []

for task_val_loader, task_test_loader, task_info \
in zip(val_dataloaders, test_dataloaders, task_infos):
if task_val_loader is None and task_test_loader is None:
# For this task, there is no need to compute test or val scores.
# Skip this task.
continue

if task_info.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION,
BUILTIN_TASK_NODE_REGRESSION,
BUILTIN_TASK_EDGE_CLASSIFICATION,
BUILTIN_TASK_EDGE_REGRESSION]:
predict_tasks.append(task_info)
predict_val_loaders.append(task_val_loader)
predict_test_loaders.append(task_test_loader)

if task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
lp_tasks.append(task_info)
lp_val_loaders.append(task_val_loader)
lp_test_loaders.append(task_test_loader)

if task_info.task_type in [BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]:
nfeat_recon_tasks.append(task_info)
nfeat_recon_val_loaders.append(task_val_loader)
nfeat_recon_test_loaders.append(task_test_loader)

def gen_embs(edge_mask=None):
""" Compute node embeddings
"""
if use_mini_batch_infer:
emb = do_mini_batch_inference(model, data,
fanout=fanout,
edge_mask=edge_mask,
task_tracker=self.task_tracker)
else:
emb = do_full_graph_inference(model, data,
fanout=fanout,
edge_mask=edge_mask,
task_tracker=self.task_tracker)
return emb

embs = None
val_results = None
test_results = None
if len(predict_tasks) > 0:
# do validation and test for prediction tasks.
sys_tracker.check('compute embeddings')
embs = gen_embs()
val_results = \
multi_prediction_task_mini_batch_predict(
model,
emb=embs,
loader=val_loader,
device=self.device,
return_proba=return_proba,
return_label=True) \
if val_loader is not None else None

test_results = \
multi_prediction_task_mini_batch_predict(
model,
emb=embs,
loader=test_loader,
device=self.device,
return_proba=return_proba,
return_label=True) \
if test_loader is not None else None

if len(lp_tasks) > 0:
for lp_val_loader, lp_test_loader, task_info \
in zip(lp_val_loaders, lp_test_loaders, lp_tasks):

lp_test_embs = gen_embs(edge_mask=task_info.task_config.train_mask)

decoder = model.task_decoders[task_info.task_id]
val_scores = run_lp_mini_batch_predict(decoder, lp_test_embs, lp_val_loader, self.device) \
if lp_val_loader is not None else None
test_scores = run_lp_mini_batch_predict(decoder, lp_test_embs, lp_test_loader, self.device) \
if lp_test_loader is not None else None
if val_results is not None:
val_results[task_info.task_id] = val_scores
else:
val_results = {task_info.task_id: val_scores}
if test_results is not None:
test_results[task_info.task_id] = test_scores
else:
test_results = {task_info.task_id: test_scores}

if len(nfeat_recon_tasks) > 0:
def nfrecon_gen_embs(last_self_loop=False):
""" Generate node embeddings for node feature reconstruction
"""
if last_self_loop is False:
model.gnn_encoder.skip_last_selfloop()
new_embs = gen_embs()
model.gnn_encoder.reset_last_selfloop()
return new_embs
else:
# if last_self_loop is True
# we can reuse the computed embs if any
return embs if embs is not None else gen_embs()

nfeat_embs = gen_emb_for_nfeat_reconstruct(model, nfrecon_gen_embs)

nfeat_recon_val_results, nfeat_recon_test_results = \
multi_nfeat_recon_task_mini_batch_predict(
model,
nfeat_embs,
nfeat_recon_val_loaders,
nfeat_recon_test_loaders,
nfeat_recon_tasks,
device=self.device,
return_label=True)

if val_results is not None:
val_results.update(nfeat_recon_val_results)
else:
val_results = nfeat_recon_val_results
if test_results is not None:
test_results.update(nfeat_recon_test_results)
else:
test_results = nfeat_recon_test_results

sys_tracker.check('after_test_score')
val_score, test_score = self.evaluator.evaluate(
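As a side note on the self-loop handling above: skipping the last layer's self-loop matters because that term feeds a node's own input feature straight into its output embedding, which is exactly the value a feature-reconstruction decoder is asked to predict. Below is a toy illustration; this is not GraphStorm's RelGraphConvLayer, and only the skip_last_selfloop/reset_last_selfloop method names mirror the API used in the diff.

```python
import torch as th
import torch.nn as nn
import dgl
from dgl.nn import GraphConv

class ToyLayer(nn.Module):
    """Neighbor aggregation plus an optional self-loop term."""
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.conv = GraphConv(in_feats, out_feats, allow_zero_in_degree=True)
        self.loop_weight = nn.Linear(in_feats, out_feats, bias=False)
        self.use_self_loop = True

    def skip_last_selfloop(self):
        self.use_self_loop = False

    def reset_last_selfloop(self):
        self.use_self_loop = True

    def forward(self, g, feat):
        h = self.conv(g, feat)              # messages from neighbors only
        if self.use_self_loop:
            h = h + self.loop_weight(feat)  # mixes the node's own feature back in
        return h

g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))
x = th.randn(3, 4)
layer = ToyLayer(4, 4)

layer.skip_last_selfloop()
h_safe = layer(g, x)   # safe target for reconstructing x: h_safe[i] carries no direct copy of x[i]
layer.reset_last_selfloop()
```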
