def gen_emb_for_nfeat_reconstruct(model, gen_embs):
    """ Generate node embeddings for node feature reconstruction.

        In theory, we should skip the self-loop of the last GNN layer.
        However, there are some exceptions. This function handles
        those exceptions:

        * The GNN encoder is not a ``GSgnnGNNEncoderInterface``: the
          self-loop is kept and a warning is logged.
        * The model has sparse (learnable embedding) parameters: the
          self-loop is kept and a warning is logged.

    Parameters
    ----------
    model: GSgnnMultiTaskSharedEncoderModel
        Multi-task model
    gen_embs: func
        The function used to generate node embeddings.
        It is called with a single keyword argument ``last_self_loop``:
        ``True`` keeps the self-loop of the last GNN layer,
        ``False`` skips it.

    Return
    ------
    embs: node embeddings, i.e., whatever ``gen_embs`` returns.
    """
    if not isinstance(model.gnn_encoder, GSgnnGNNEncoderInterface):
        # The encoder can not turn the last-layer self-loop off,
        # so we will use the computed embs directly.
        logging.warning("The gnn encoder %s does not support skip "
                        "the last self-loop operation "
                        "(skip_last_selfloop). There is a potential "
                        "node feature leakage risk when doing %s training.",
                        type(model.gnn_encoder),
                        BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
        return gen_embs(last_self_loop=True)

    if model.has_sparse_params():
        # When there are learnable embeddings, we can not
        # just simply skip the last layer self-loop.
        # Keep the self-loop and print a warning.
        logging.warning("When doing %s inference, we need to "
                        "avoid adding self loop in the last GNN layer "
                        "to avoid the potential node "
                        "feature leakage issue. "
                        "When there are learnable embeddings on "
                        "nodes, GraphStorm can not automatically "
                        "skip the last layer self-loop. "
                        "Please set use_self_loop to False",
                        BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
        return gen_embs(last_self_loop=True)

    # Skip the self-loop of the last layer to
    # avoid information leakage.
    return gen_embs(last_self_loop=False)
def multi_nfeat_recon_task_mini_batch_predict(
        model, embs,
        nfeat_recon_val_loaders,
        nfeat_recon_test_loaders,
        task_infos,
        device,
        return_label=False):
    """ Conduct mini batch prediction on node feature
        reconstruction tasks.

        The three lists are walked in lockstep: the i-th validation
        dataloader, the i-th test dataloader and the i-th task info
        are treated as belonging to the same task.

    Parameters
    ----------
    model: GSgnnMultiTaskModelInterface, GSgnnModel
        Multi-task learning model
    embs : dict of Tensor
        The GNN embeddings
    nfeat_recon_val_loaders: list
        List of validation dataloaders. An entry can be None.
    nfeat_recon_test_loaders: list
        List of test dataloaders. An entry can be None.
    task_infos: list
        List of task info
    device: th.device
        Device used to compute test scores.
    return_label : bool
        Whether or not to return labels.

    Return
    ------
    dict: Validation results, task id -> (predictions, labels).
          The value is (None, None) when the task has no
          validation dataloader.
    dict: Test results, in the same format as the validation results.
    """
    def _predict(task_decoder, dataloader):
        # A missing dataloader yields an empty (None, None) entry
        # instead of running the prediction.
        if dataloader is None:
            return (None, None)
        preds, labels = run_node_mini_batch_predict(task_decoder,
                                                    embs,
                                                    dataloader,
                                                    device=device,
                                                    return_proba=False,
                                                    return_label=return_label)
        return (preds, labels)

    val_results, test_results = {}, {}
    grouped = zip(nfeat_recon_val_loaders, nfeat_recon_test_loaders, task_infos)
    for val_dataloader, test_dataloader, info in grouped:
        task_decoder = model.task_decoders[info.task_id]
        val_results[info.task_id] = _predict(task_decoder, val_dataloader)
        test_results[info.task_id] = _predict(task_decoder, test_dataloader)

    return val_results, test_results
b/python/graphstorm/trainer/mt_trainer.py @@ -33,7 +33,10 @@ do_mini_batch_inference, GSgnnModelBase, GSgnnModel, GSgnnMultiTaskModelInterface, - multi_task_mini_batch_predict) + multi_prediction_task_mini_batch_predict, + multi_nfeat_recon_task_mini_batch_predict, + gen_emb_for_nfeat_reconstruct) +from ..model.lp_gnn import run_lp_mini_batch_predict from .gsgnn_trainer import GSgnnTrainer from ..utils import sys_tracker, rt_profiler, print_mem, get_rank @@ -506,42 +509,173 @@ def eval(self, model, data, val_loader, test_loader, total_steps, sys_tracker.check('before prediction') model.eval() + if val_loader is None and test_loader is None: + # no need to do validation and test + # do nothing. + return None + + val_dataloaders = val_loader.dataloaders \ + if val_loader is not None else None + test_dataloaders = test_loader.dataloaders \ + if test_loader is not None else None + task_infos = val_loader.task_infos \ + if val_loader is not None else test_loader.task_infos + # All the tasks share the same GNN encoder so the fanouts are same # for different tasks. fanout = None - for task_fanout in val_loader.fanout: - if task_fanout is not None: - fanout = task_fanout - break - assert fanout is not None, \ - "There is no validation dataloader. 
eval() function should not be called" - if use_mini_batch_infer: - emb = do_mini_batch_inference(model, data, - fanout=fanout, - task_tracker=self.task_tracker) + if val_loader is not None: + for task_fanout in val_loader.fanout: + if task_fanout is not None: + fanout = task_fanout + break else: - emb = do_full_graph_inference(model, data, - fanout=fanout, - task_tracker=self.task_tracker) - sys_tracker.check('compute embeddings') - - val_results = \ - multi_task_mini_batch_predict(model, - emb=emb, - loader=val_loader, - device=self.device, - return_proba=return_proba, - return_label=True) \ - if val_loader is not None else None - - test_results = \ - multi_task_mini_batch_predict(model, - emb=emb, - loader=test_loader, - device=self.device, - return_proba=return_proba, - return_label=True) \ - if test_loader is not None else None + for task_fanout in test_loader.fanout: + if task_fanout is not None: + fanout = task_fanout + break + assert fanout is not None, \ + "There is no validation dataloader.eval() function should not be called" + + # Node prediction and edge prediction + # do not have information leakage problem + predict_tasks = [] + predict_val_loaders = [] + predict_test_loaders = [] + # For link prediction tasks, we need to + # exclude valid and test edges during message + # passk + lp_tasks = [] + lp_val_loaders = [] + lp_test_loaders = [] + # For node feature reconstruction tasks, + # we need to avoid self-loop in the last + # GNN layer + nfeat_recon_tasks = [] + nfeat_recon_val_loaders = [] + nfeat_recon_test_loaders = [] + + for val_loader, test_loader, task_info \ + in zip(val_dataloaders, test_dataloaders, task_infos): + if val_loader is None and test_loader is None: + # For this task, these is no need to do compute test or val score + # skip this task + continue + + if task_info.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION]: + 
def gen_embs(edge_mask=None):
    """ Compute node embeddings

        Uses mini-batch inference when ``use_mini_batch_infer`` is set
        and full-graph inference otherwise. ``model``, ``data``,
        ``fanout`` and ``self`` come from the enclosing scope.

    Parameters
    ----------
    edge_mask: passed through unchanged as the ``edge_mask``
        argument of the selected inference function.

    Return
    ------
    The embeddings returned by the selected inference function.
    """
    # Both inference functions are called with the same arguments,
    # so pick the callable first and call it once.
    infer_fn = do_mini_batch_inference if use_mini_batch_infer \
        else do_full_graph_inference
    return infer_fn(model, data,
                    fanout=fanout,
                    edge_mask=edge_mask,
                    task_tracker=self.task_tracker)
def nfrecon_gen_embs(model=model, last_self_loop=False):
    """ Generate node embeddings for node feature reconstruction

    Parameters
    ----------
    model:
        Model whose ``gnn_encoder`` self-loop is toggled.
        It defaults to the ``model`` of the enclosing scope (bound
        when this function is defined) so that the function can be
        called with only the ``last_self_loop`` keyword, which is how
        ``gen_emb_for_nfeat_reconstruct`` calls it.
    last_self_loop: bool
        If False, the self-loop of the last GNN layer is skipped
        while computing new embeddings. Otherwise, the embeddings
        already computed in the enclosing scope (``embs``) are
        reused when available.

    Return
    ------
    Node embeddings.
    """
    if last_self_loop is not False:
        # The self-loop is kept, so we can reuse the
        # computed embs if any.
        return embs if embs is not None else gen_embs()

    model.gnn_encoder.skip_last_selfloop()
    try:
        return gen_embs()
    finally:
        # Always restore the self-loop, even when the embedding
        # computation fails, so that the encoder is not left in the
        # skip-self-loop state.
        model.gnn_encoder.reset_last_selfloop()