From 7eba2e30b3372277a4c206bed8fdfc852e3ae59f Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 10 Jun 2024 23:17:54 -0700 Subject: [PATCH] Update multi-task evaluation logic to avoid information leakage issue in lp and nfeat reconstruct task evaluation. Previously, in the eval() function of GSgnnMultiTaskLearningTrainer, both link prediction and node feature reconstruction tasks use the node embeddings computed with the entire graph. This will cause test edge leakage for link prediction tasks and target node node feature leakage for node feature reconstruction tasks. This PR fixes this issue. --- python/graphstorm/model/__init__.py | 4 +- python/graphstorm/model/multitask_gnn.py | 116 ++++++++++++- python/graphstorm/trainer/mt_trainer.py | 200 +++++++++++++++++++---- 3 files changed, 285 insertions(+), 35 deletions(-) diff --git a/python/graphstorm/model/__init__.py b/python/graphstorm/model/__init__.py index 34cce011ce..bfc2c1a7d3 100644 --- a/python/graphstorm/model/__init__.py +++ b/python/graphstorm/model/__init__.py @@ -37,7 +37,9 @@ run_lp_mini_batch_predict) from .multitask_gnn import (GSgnnMultiTaskModelInterface, GSgnnMultiTaskSharedEncoderModel) -from .multitask_gnn import multi_task_mini_batch_predict +from .multitask_gnn import (multi_prediction_task_mini_batch_predict, + multi_nfeat_recon_task_mini_batch_predict, + gen_emb_for_nfeat_reconstruct) from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer from .sage_encoder import SAGEEncoder, SAGEConv diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index 58e28064b6..9ab425dead 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -32,6 +32,7 @@ from .node_gnn import run_node_mini_batch_predict from .edge_gnn import run_edge_mini_batch_predict from .lp_gnn import run_lp_mini_batch_predict +from ..utils import is_distributed class GSgnnMultiTaskModelInterface: @@ -380,7 +381,7 @@ def predict(self, task_id, mini_batch, return_proba=False): else: raise TypeError(f"Unknow task type {task_type}") -def multi_task_mini_batch_predict( +def multi_prediction_task_mini_batch_predict( model, emb, loader, device, return_proba=True, return_label=False): """ conduct mini batch prediction on multiple tasks @@ -469,3 +470,116 @@ def multi_task_mini_batch_predict( raise TypeError(f"Unknown task {task_info}") return res + +def gen_emb_for_nfeat_reconstruct(model, gen_embs): + """ Generate node embeddings for node feature reconstruction. + In theory, we should skip the self-loop of the last GNN layer. + However, there are some exceptions. This function handles + those exceptions. + + Parameters + ---------- + model: GSgnnMultiTaskSharedEncoderModel + Multi-task model + gen_embs: func + The function used to generate node embeddings. + It should accept a bool flag indicating whether + the last GNN layer self-loop should be removed. + + Return + ------ + embs: node embedings + """ + if isinstance(model.gnn_encoder, GSgnnGNNEncoderInterface): + if model.has_sparse_params(): + # When there are learnable embeddings, we can not + # just simply skip the last layer self-loop. + # Keep the self-loop and print a warning + # we will use the computed embs directly + logging.warning("When doing %s inference, we need to " + "avoid adding self loop in the last GNN layer " + "to avoid the potential node " + "feature leakage issue. " + "When there are learnable embeddings on " + "nodes, GraphStorm can not automatically" + "skip the last layer self-loop" + "Please set use_self_loop to False", + BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) + embs = gen_embs(last_self_loop=True) + else: + # skip the selfloop of the last layer to + # avoid information leakage. + embs = gen_embs(last_self_loop=False) + else: + # we will use the computed embs directly + logging.warning("The gnn encoder %s does not support skip " + "the last self-loop operation" + "(skip_last_selfloop). There is a potential " + "node feature leakage risk when doing %s training.", + type(model.gnn_encoder), + BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) + embs = gen_embs(last_self_loop=True) + return embs + +def multi_nfeat_recon_task_mini_batch_predict( + model, embs, + nfeat_recon_val_loaders, + nfeat_recon_test_loaders, + task_infos, + device, + return_label=False): + """ conduct mini batch prediction on node feature + reconstruction tasks + + Parameters + ---------- + model: GSgnnMultiTaskModelInterface, GSgnnModel + Multi-task learning model + embs : dict of Tensor + The GNN embeddings + nfeat_recon_val_loaders: list + List of validation datalaoders + nfeat_recon_test_loaders: list + List of test dataloaders + task_infos: list + List of task info + device: th.device + Device used to compute test scores. + return_label : bool + Whether or not to return labels. + + Return + ------ + dict: Validatoin results + dict: test results + """ + val_results = {} + test_results = {} + for val_loader, test_loader, task_info in \ + zip(nfeat_recon_val_loaders, nfeat_recon_test_loaders, task_infos): + decoder = model.task_decoders[task_info.task_id] + if val_loader is None: + val_results[task_info.task_id] = (None, None) + else: + val_preds, val_labels = \ + run_node_mini_batch_predict(decoder, + embs, + val_loader, + device=device, + return_proba=False, + return_label=return_label) + val_results[task_info.task_id] = (val_preds, val_labels) + + if test_loader is None: + test_results[task_info.task_id] = (None, None) + else: + test_preds, test_labels = \ + run_node_mini_batch_predict(decoder, + embs, + test_loader, + device=device, + return_proba=False, + return_label=return_label) + test_results[task_info.task_id] = (test_preds, test_labels) + + return val_results, test_results diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index e6244fd38e..9950cc348e 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -33,7 +33,10 @@ do_mini_batch_inference, GSgnnModelBase, GSgnnModel, GSgnnMultiTaskModelInterface, - multi_task_mini_batch_predict) + multi_prediction_task_mini_batch_predict, + multi_nfeat_recon_task_mini_batch_predict, + gen_emb_for_nfeat_reconstruct) +from ..model.lp_gnn import run_lp_mini_batch_predict from .gsgnn_trainer import GSgnnTrainer from ..utils import sys_tracker, rt_profiler, print_mem, get_rank @@ -506,42 +509,173 @@ def eval(self, model, data, val_loader, test_loader, total_steps, sys_tracker.check('before prediction') model.eval() + if val_loader is None and test_loader is None: + # no need to do validation and test + # do nothing. + return None + + val_dataloaders = val_loader.dataloaders \ + if val_loader is not None else None + test_dataloaders = test_loader.dataloaders \ + if test_loader is not None else None + task_infos = val_loader.task_infos \ + if val_loader is not None else test_loader.task_infos + # All the tasks share the same GNN encoder so the fanouts are same # for different tasks. fanout = None - for task_fanout in val_loader.fanout: - if task_fanout is not None: - fanout = task_fanout - break - assert fanout is not None, \ - "There is no validation dataloader. eval() function should not be called" - if use_mini_batch_infer: - emb = do_mini_batch_inference(model, data, - fanout=fanout, - task_tracker=self.task_tracker) + if val_loader is not None: + for task_fanout in val_loader.fanout: + if task_fanout is not None: + fanout = task_fanout + break else: - emb = do_full_graph_inference(model, data, - fanout=fanout, - task_tracker=self.task_tracker) - sys_tracker.check('compute embeddings') - - val_results = \ - multi_task_mini_batch_predict(model, - emb=emb, - loader=val_loader, - device=self.device, - return_proba=return_proba, - return_label=True) \ - if val_loader is not None else None - - test_results = \ - multi_task_mini_batch_predict(model, - emb=emb, - loader=test_loader, - device=self.device, - return_proba=return_proba, - return_label=True) \ - if test_loader is not None else None + for task_fanout in test_loader.fanout: + if task_fanout is not None: + fanout = task_fanout + break + assert fanout is not None, \ + "There is no validation dataloader.eval() function should not be called" + + # Node prediction and edge prediction + # do not have information leakage problem + predict_tasks = [] + predict_val_loaders = [] + predict_test_loaders = [] + # For link prediction tasks, we need to + # exclude valid and test edges during message + # passk + lp_tasks = [] + lp_val_loaders = [] + lp_test_loaders = [] + # For node feature reconstruction tasks, + # we need to avoid self-loop in the last + # GNN layer + nfeat_recon_tasks = [] + nfeat_recon_val_loaders = [] + nfeat_recon_test_loaders = [] + + for val_loader, test_loader, task_info \ + in zip(val_dataloaders, test_dataloaders, task_infos): + if val_loader is None and test_loader is None: + # For this task, these is no need to do compute test or val score + # skip this task + continue + + if task_info.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION]: + predict_tasks.append(task_info) + predict_val_loaders.append(val_loader) + predict_test_loaders.append(test_loader) + + if task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]: + lp_tasks.append(task_info) + lp_val_loaders.append(val_loader) + lp_test_loaders.append(test_loader) + + if task_info.task_type in [BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]: + nfeat_recon_tasks.append(task_info) + nfeat_recon_val_loaders.append(val_loader) + nfeat_recon_test_loaders.append(test_loader) + + def gen_embs(edge_mask=None): + """ Compute node embeddings + """ + if use_mini_batch_infer: + emb = do_mini_batch_inference(model, data, + fanout=fanout, + edge_mask=edge_mask, + task_tracker=self.task_tracker) + else: + emb = do_full_graph_inference(model, data, + fanout=fanout, + edge_mask=edge_mask, + task_tracker=self.task_tracker) + return emb + + embs = None + val_results = None + test_results = None + if len(predict_tasks) > 0: + # do validation and test for prediciton tasks. + sys_tracker.check('compute embeddings') + embs = gen_embs() + val_results = \ + multi_prediction_task_mini_batch_predict( + model, + emb=embs, + loader=val_loader, + device=self.device, + return_proba=return_proba, + return_label=True) \ + if val_loader is not None else None + + test_results = \ + multi_prediction_task_mini_batch_predict( + model, + emb=embs, + loader=test_loader, + device=self.device, + return_proba=return_proba, + return_label=True) \ + if test_loader is not None else None + + if len(lp_tasks) > 0: + for lp_val_loader, lp_test_loader, task_info \ + in zip(lp_val_loaders, lp_test_loaders, task_infos): + + lp_test_embs = gen_embs(edge_mask=task_info.task_config.train_mask) + + decoder = model.task_decoders[task_info.task_id] + val_scores = run_lp_mini_batch_predict(decoder, lp_test_embs, lp_val_loader, self.device) \ + if val_loader is not None else None + test_scores = run_lp_mini_batch_predict(decoder, lp_test_embs, lp_test_loader, self.device) \ + if val_loader is not None else None + if val_results is not None: + val_results[task_info.task_id] = val_scores + else: + val_results = {task_info.task_id: val_scores} + if test_results is not None: + test_results[task_info.task_id] = test_scores + else: + test_results = {task_info.task_id: test_scores} + + if len(nfeat_recon_tasks) > 0: + def nfrecon_gen_embs(model, last_self_loop=False): + """ Generate node embeddings for node feature reconstruction + """ + if last_self_loop is False: + model.gnn_encoder.skip_last_selfloop() + new_embs = gen_embs() + model.gnn_encoder.reset_last_selfloop() + return new_embs + else: + # if lask_self_loop is True + # we can reuse the computed embs if any + return embs if embs is not None else gen_embs() + + nfeat_embs = gen_emb_for_nfeat_reconstruct(model, nfrecon_gen_embs) + + nfeat_recon_val_results, nfeat_recon_test_results = \ + multi_nfeat_recon_task_mini_batch_predict( + model, + nfeat_embs, + nfeat_recon_val_loaders, + nfeat_recon_test_loaders, + task_infos, + device=self.device, + return_label=True) + + if val_results is not None: + val_results.update(nfeat_recon_val_results) + else: + val_results = nfeat_recon_val_results + if test_results is not None: + test_results.update(nfeat_recon_val_results) + else: + test_results = nfeat_recon_test_results sys_tracker.check('after_test_score') val_score, test_score = self.evaluator.evaluate(