diff --git a/python/graphstorm/model/__init__.py b/python/graphstorm/model/__init__.py index bfc2c1a7d3..0edf31b599 100644 --- a/python/graphstorm/model/__init__.py +++ b/python/graphstorm/model/__init__.py @@ -37,8 +37,7 @@ run_lp_mini_batch_predict) from .multitask_gnn import (GSgnnMultiTaskModelInterface, GSgnnMultiTaskSharedEncoderModel) -from .multitask_gnn import (multi_prediction_task_mini_batch_predict, - multi_nfeat_recon_task_mini_batch_predict, +from .multitask_gnn import (multi_task_mini_batch_predict, gen_emb_for_nfeat_reconstruct) from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index 9ab425dead..fb62c7e1d7 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -381,8 +381,8 @@ def predict(self, task_id, mini_batch, return_proba=False): else: raise TypeError(f"Unknow task type {task_type}") -def multi_prediction_task_mini_batch_predict( - model, emb, loader, device, return_proba=True, return_label=False): +def multi_task_mini_batch_predict( + model, emb, dataloaders, task_infos, device, return_proba=True, return_label=False): """ conduct mini batch prediction on multiple tasks Parameters @@ -393,6 +393,8 @@ def multi_prediction_task_mini_batch_predict( The GNN embeddings - loader: GSgnnMultiTaskDataLoader - The mini-batch dataloader. + dataloaders: list + List of mini-batch dataloaders. + task_infos: list + List of task info device: th.device Device used to compute test scores. 
return_proba: bool @@ -404,8 +406,6 @@ def multi_prediction_task_mini_batch_predict( ------- dict: prediction results of each task """ - dataloaders = loader.dataloaders - task_infos = loader.task_infos task_decoders = model.task_decoders res = {} with th.no_grad(): @@ -458,16 +458,8 @@ def multi_prediction_task_mini_batch_predict( etype = list(preds.keys())[0] res[task_info.task_id] = (preds[etype], labels[etype] \ if labels is not None else None) - elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION: - if dataloader is None: - # In cases when there is no validation or test set. - res[task_info.task_id] = None - else: - decoder = task_decoders[task_info.task_id] - ranking = run_lp_mini_batch_predict(decoder, emb, dataloader, device) - res[task_info.task_id] = ranking else: - raise TypeError(f"Unknown task {task_info}") + raise TypeError(f"Unsupported task {task_info}") return res @@ -520,66 +512,3 @@ def gen_emb_for_nfeat_reconstruct(model, gen_embs): BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) embs = gen_embs(last_self_loop=True) return embs - -def multi_nfeat_recon_task_mini_batch_predict( - model, embs, - nfeat_recon_val_loaders, - nfeat_recon_test_loaders, - task_infos, - device, - return_label=False): - """ conduct mini batch prediction on node feature - reconstruction tasks - - Parameters - ---------- - model: GSgnnMultiTaskModelInterface, GSgnnModel - Multi-task learning model - embs : dict of Tensor - The GNN embeddings - nfeat_recon_val_loaders: list - List of validation datalaoders - nfeat_recon_test_loaders: list - List of test dataloaders - task_infos: list - List of task info - device: th.device - Device used to compute test scores. - return_label : bool - Whether or not to return labels. 
- - Return - ------ - dict: Validatoin results - dict: test results - """ - val_results = {} - test_results = {} - for val_loader, test_loader, task_info in \ - zip(nfeat_recon_val_loaders, nfeat_recon_test_loaders, task_infos): - decoder = model.task_decoders[task_info.task_id] - if val_loader is None: - val_results[task_info.task_id] = (None, None) - else: - val_preds, val_labels = \ - run_node_mini_batch_predict(decoder, - embs, - val_loader, - device=device, - return_proba=False, - return_label=return_label) - val_results[task_info.task_id] = (val_preds, val_labels) - - if test_loader is None: - test_results[task_info.task_id] = (None, None) - else: - test_preds, test_labels = \ - run_node_mini_batch_predict(decoder, - embs, - test_loader, - device=device, - return_proba=False, - return_label=return_label) - test_results[task_info.task_id] = (test_preds, test_labels) - - return val_results, test_results diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index 9950cc348e..af609297ca 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -33,8 +33,7 @@ do_mini_batch_inference, GSgnnModelBase, GSgnnModel, GSgnnMultiTaskModelInterface, - multi_prediction_task_mini_batch_predict, - multi_nfeat_recon_task_mini_batch_predict, + multi_task_mini_batch_predict, gen_emb_for_nfeat_reconstruct) from ..model.lp_gnn import run_lp_mini_batch_predict from .gsgnn_trainer import GSgnnTrainer @@ -603,24 +602,26 @@ def gen_embs(edge_mask=None): sys_tracker.check('compute embeddings') embs = gen_embs() val_results = \ - multi_prediction_task_mini_batch_predict( + multi_task_mini_batch_predict( model, emb=embs, - loader=val_loader, + dataloaders=predict_val_loaders, + task_infos=predict_tasks, device=self.device, return_proba=return_proba, return_label=True) \ - if val_loader is not None else None + if len(predict_val_loaders) > 0 else None test_results = \ - 
multi_prediction_task_mini_batch_predict( + multi_task_mini_batch_predict( model, emb=embs, - loader=test_loader, + dataloaders=predict_test_loaders, + task_infos=predict_tasks, device=self.device, return_proba=return_proba, return_label=True) \ - if test_loader is not None else None + if len(predict_test_loaders) > 0 else None if len(lp_tasks) > 0: for lp_val_loader, lp_test_loader, task_info \ @@ -658,24 +659,40 @@ def nfrecon_gen_embs(model, last_self_loop=False): nfeat_embs = gen_emb_for_nfeat_reconstruct(model, nfrecon_gen_embs) - nfeat_recon_val_results, nfeat_recon_test_results = \ - multi_nfeat_recon_task_mini_batch_predict( + nfeat_recon_val_results = \ + multi_task_mini_batch_predict( model, - nfeat_embs, - nfeat_recon_val_loaders, - nfeat_recon_test_loaders, - task_infos, + emb=nfeat_embs, + dataloaders=nfeat_recon_val_loaders, + task_infos=predict_tasks, device=self.device, - return_label=True) + return_proba=return_proba, + return_label=True) \ + if len(nfeat_recon_val_loaders) > 0 else None - if val_results is not None: - val_results.update(nfeat_recon_val_results) - else: + nfeat_recon_test_results = \ + multi_task_mini_batch_predict( + model, + emb=nfeat_embs, + dataloaders=nfeat_recon_test_loaders, + task_infos=predict_tasks, + device=self.device, + return_proba=return_proba, + return_label=True) \ + if len(nfeat_recon_test_loaders) > 0 else None + + if val_results is None: val_results = nfeat_recon_val_results - if test_results is not None: - test_results.update(nfeat_recon_val_results) else: + if nfeat_recon_val_results is not None: + val_results.update(nfeat_recon_val_results) + + if test_results is None: test_results = nfeat_recon_test_results + else: + if nfeat_recon_test_results is not None: + test_results.update(nfeat_recon_test_results) + sys_tracker.check('after_test_score') val_score, test_score = self.evaluator.evaluate(