diff --git a/python/graphstorm/inference/mt_infer.py b/python/graphstorm/inference/mt_infer.py index 5cb1597808..301a1374c7 100644 --- a/python/graphstorm/inference/mt_infer.py +++ b/python/graphstorm/inference/mt_infer.py @@ -182,18 +182,26 @@ def gen_embs(edge_mask=None): task_infos = recon_nfeat_test_loader.task_infos with th.no_grad(): - def nfrecon_gen_embs(last_self_loop=False): + def nfrecon_gen_embs(skip_last_self_loop=False): """ Generate node embeddings for node feature reconstruction """ - if last_self_loop is False: - self._model.gnn_encoder.skip_last_selfloop() + if skip_last_self_loop is True: + # Turn off the last layer GNN's self-loop + # to compute node embeddings. + model.gnn_encoder.skip_last_selfloop() new_embs = gen_embs() - self._model.gnn_encoder.reset_last_selfloop() + model.gnn_encoder.reset_last_selfloop() return new_embs else: - # if lask_self_loop is True - # we can reuse the computed embs if any - return embs if embs is not None else gen_embs() + # If skip_last_self_loop is False + # we will not change the way we compute + # node embeddings. 
+ if embs is not None: + # The embeddings have been computed + # earlier when handling the prediction tasks + return embs + else: + return gen_embs() nfeat_embs = gen_emb_for_nfeat_reconstruct(self._model, nfrecon_gen_embs) diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index e891dd2be0..5b6b9f7b5c 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -490,7 +490,7 @@ def gen_emb_for_nfeat_reconstruct(model, gen_embs): Return ------ - embs: node embedings + embs: node embeddings """ if isinstance(model.gnn_encoder, GSgnnGNNEncoderInterface): if model.has_sparse_params(): @@ -507,11 +507,11 @@ def gen_emb_for_nfeat_reconstruct(model, gen_embs): "skip the last layer self-loop" "Please set use_self_loop to False", BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) - embs = gen_embs(last_self_loop=True) + embs = gen_embs(skip_last_self_loop=False) else: # skip the selfloop of the last layer to # avoid information leakage. - embs = gen_embs(last_self_loop=False) + embs = gen_embs(skip_last_self_loop=True) else: # we will use the computed embs directly logging.warning("The gnn encoder %s does not support skip " @@ -520,5 +520,5 @@ def gen_emb_for_nfeat_reconstruct(model, gen_embs): "node feature leakage risk when doing %s training.", type(model.gnn_encoder), BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) - embs = gen_embs(last_self_loop=True) + embs = gen_embs(skip_last_self_loop=False) return embs diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index 7353d87bb8..ef97392cf0 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -429,7 +429,7 @@ def fit(self, train_loader, # TODO(xiangsx): Add early stop support # Every n iterations, save the model and keep - # the lask k models. + # the last k models. # TODO(xiangsx): support saving the best top k model. 
if save_model_frequency > 0 and \ total_steps % save_model_frequency == 0 and \ @@ -489,9 +489,9 @@ def eval(self, model, data, mt_val_loader, mt_test_loader, total_steps, The GNN model. data : GSgnnData The training dataset - mt_val_loader: GSNodeDataLoader + mt_val_loader: GSgnnMultiTaskDataLoader The dataloader for validation data - mt_test_loader : GSNodeDataLoader + mt_test_loader : GSgnnMultiTaskDataLoader The dataloader for test data. total_steps: int Total number of iterations. @@ -657,18 +657,26 @@ def gen_embs(edge_mask=None): test_results = {task_info.task_id: test_scores} if len(nfeat_recon_tasks) > 0: - def nfrecon_gen_embs(last_self_loop=False): + def nfrecon_gen_embs(skip_last_self_loop=False): """ Generate node embeddings for node feature reconstruction """ - if last_self_loop is False: + if skip_last_self_loop is True: + # Turn off the last layer GNN's self-loop + # to compute node embeddings. model.gnn_encoder.skip_last_selfloop() new_embs = gen_embs() model.gnn_encoder.reset_last_selfloop() return new_embs else: - # if lask_self_loop is True - # we can reuse the computed embs if any - return embs if embs is not None else gen_embs() + # If skip_last_self_loop is False + # we will not change the way we compute + # node embeddings. 
+ if embs is not None: + # The embeddings have been computed + # when handling predict_tasks in L608 + return embs + else: + return gen_embs() nfeat_embs = gen_emb_for_nfeat_reconstruct(model, nfrecon_gen_embs) diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index 7277cd3908..27015796c0 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -2301,21 +2301,20 @@ def check_forward(mock_run_lp_mini_batch_predict, def test_gen_emb_for_nfeat_recon(): encoder_model = DummyGSgnnEncoderModel() - model = DummyGSgnnModel(encoder_model, has_sparse=True) - call_self_loop = True - def check_call_gen_embs(last_self_loop): - assert last_self_loop == call_self_loop + def check_call_gen_embs(skip_last_self_loop): + assert skip_last_self_loop == skip_self_loop + model = DummyGSgnnModel(encoder_model, has_sparse=True) + skip_self_loop = False gen_emb_for_nfeat_reconstruct(model, check_call_gen_embs) - call_self_loop = False + skip_self_loop = True model = DummyGSgnnModel(encoder_model, has_sparse=False) gen_emb_for_nfeat_reconstruct(model, check_call_gen_embs) model = DummyGSgnnModel(None) - call_self_loop = True - def check_call_gen_embs(last_self_loop): - assert last_self_loop == call_self_loop + skip_self_loop = False + gen_emb_for_nfeat_reconstruct(model, check_call_gen_embs) if __name__ == '__main__': diff --git a/tests/unit-tests/test_trainer.py b/tests/unit-tests/test_trainer.py index f66a83f586..dbe8770096 100644 --- a/tests/unit-tests/test_trainer.py +++ b/tests/unit-tests/test_trainer.py @@ -460,41 +460,41 @@ def task_tracker(self): return "dummy tracker" def test_mtask_eval(): - tast_info_nc = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION, + task_info_nc = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION, task_id='nc_task', task_config=None) nc_dataloader = DummyGSgnnNodeDataLoader() - tast_info_nr = TaskInfo(task_type=BUILTIN_TASK_NODE_REGRESSION, + task_info_nr = TaskInfo(task_type=BUILTIN_TASK_NODE_REGRESSION, 
task_id='nr_task', task_config=None) nr_dataloader = DummyGSgnnNodeDataLoader() - tast_info_ec = TaskInfo(task_type=BUILTIN_TASK_EDGE_CLASSIFICATION, + task_info_ec = TaskInfo(task_type=BUILTIN_TASK_EDGE_CLASSIFICATION, task_id='ec_task', task_config=None) ec_dataloader = DummyGSgnnEdgeDataLoader() - tast_info_er = TaskInfo(task_type=BUILTIN_TASK_EDGE_REGRESSION, + task_info_er = TaskInfo(task_type=BUILTIN_TASK_EDGE_REGRESSION, task_id='er_task', task_config=None) er_dataloader = DummyGSgnnEdgeDataLoader() task_config = GSConfig.__new__(GSConfig) setattr(task_config, "train_mask", "train_mask") - tast_info_lp = TaskInfo(task_type=BUILTIN_TASK_LINK_PREDICTION, + task_info_lp = TaskInfo(task_type=BUILTIN_TASK_LINK_PREDICTION, task_id='lp_task', task_config=task_config) encoder_model = DummyGSgnnEncoderModel() - model = DummyGSgnnMTModel(encoder_model, decoders={tast_info_lp.task_id: "dummy"}, has_sparse=True) + model = DummyGSgnnMTModel(encoder_model, decoders={task_info_lp.task_id: "dummy"}, has_sparse=True) mt_trainer = GSgnnMultiTaskLearningTrainer(model) mt_trainer._device = 'cpu' lp_dataloader = DummyGSgnnLinkPredictionDataLoader() - tast_info_nfr = TaskInfo(task_type=BUILTIN_TASK_RECONSTRUCT_NODE_FEAT, + task_info_nfr = TaskInfo(task_type=BUILTIN_TASK_RECONSTRUCT_NODE_FEAT, task_id='nfr_task', task_config=None) nfr_dataloader = DummyGSgnnNodeDataLoader() - task_infos = [tast_info_nc, tast_info_nr, tast_info_ec, - tast_info_er, tast_info_lp, tast_info_nfr] + task_infos = [task_info_nc, task_info_nr, task_info_ec, + task_info_er, task_info_lp, task_info_nfr] data = None res = mt_trainer.eval(model, data, None, None, 100) @@ -595,7 +595,6 @@ def check_eval(mock_do_mini_batch_inference, } evaluator = MTaskCheckerEvaluator(target_res, target_res, 100) mt_trainer.setup_evaluator(evaluator) - # test when val_loader is None mt_trainer.eval(model, data, val_loader, test_loader, 100) # lp tasks are empty @@ -614,7 +613,6 @@ def check_eval(mock_do_mini_batch_inference, } 
evaluator = MTaskCheckerEvaluator(target_res, target_res, 200) mt_trainer.setup_evaluator(evaluator) - # test when val_loader is None mt_trainer.eval(model, data, val_loader, test_loader, 200) # node feature reconstruct tasks are empty @@ -633,7 +631,6 @@ def check_eval(mock_do_mini_batch_inference, } evaluator = MTaskCheckerEvaluator(target_res, target_res, 200) mt_trainer.setup_evaluator(evaluator) - # test when val_loader is None mt_trainer.eval(model, data, val_loader, test_loader, 200) check_eval()