From 47f7e88ee87712357deb4d2fbd847af6870b0bad Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 12 Jun 2024 09:19:03 -0700 Subject: [PATCH 1/3] Fix lint --- python/graphstorm/trainer/mt_trainer.py | 36 ++++++++++++++----------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index a398cf356f..afd03bd55a 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -479,7 +479,7 @@ def fit(self, train_loader, self.get_best_model_path() if save_model_path is not None else None} self.log_params(output) - def eval(self, model, data, val_loader, test_loader, total_steps, +def eval(self, model, data, mt_val_loader, mt_test_loader, total_steps, use_mini_batch_infer=False, return_proba=True): """ do the model evaluation using validation and test sets @@ -489,9 +489,9 @@ def eval(self, model, data, val_loader, test_loader, total_steps, The GNN model. data : GSgnnData The training dataset - val_loader: GSNodeDataLoader + mt_val_loader: GSNodeDataLoader The dataloader for validation data - test_loader : GSNodeDataLoader + mt_test_loader : GSNodeDataLoader The dataloader for test data. total_steps: int Total number of iterations. @@ -508,17 +508,17 @@ def eval(self, model, data, val_loader, test_loader, total_steps, sys_tracker.check('before prediction') model.eval() - if val_loader is None and test_loader is None: + if mt_val_loader is None and mt_test_loader is None: # no need to do validation and test # do nothing. return None - val_dataloaders = val_loader.dataloaders \ - if val_loader is not None else None - test_dataloaders = test_loader.dataloaders \ - if test_loader is not None else None - task_infos = val_loader.task_infos \ - if val_loader is not None else test_loader.task_infos + val_dataloaders = mt_val_loader.dataloaders \ + if mt_val_loader is not None else None + test_dataloaders = mt_test_loader.dataloaders \ + if mt_test_loader is not None else None + task_infos = mt_val_loader.task_infos \ + if mt_val_loader is not None else mt_test_loader.task_infos if val_dataloaders is None: val_dataloaders = [None] * len(task_infos) if test_dataloaders is None: @@ -527,13 +527,13 @@ def eval(self, model, data, val_loader, test_loader, total_steps, # All the tasks share the same GNN encoder so the fanouts are same # for different tasks. fanout = None - if val_loader is not None: - for task_fanout in val_loader.fanout: + if mt_val_loader is not None: + for task_fanout in mt_val_loader.fanout: if task_fanout is not None: fanout = task_fanout break else: - for task_fanout in test_loader.fanout: + for task_fanout in mt_test_loader.fanout: if task_fanout is not None: fanout = task_fanout break @@ -636,9 +636,15 @@ def gen_embs(edge_mask=None): lp_test_embs = gen_embs(edge_mask=task_info.task_config.train_mask) decoder = model.task_decoders[task_info.task_id] - val_scores = run_lp_mini_batch_predict(decoder, lp_test_embs, lp_val_loader, self.device) \ + val_scores = run_lp_mini_batch_predict(decoder, + lp_test_embs, + lp_val_loader, + self.device) \ if lp_val_loader is not None else None - test_scores = run_lp_mini_batch_predict(decoder, lp_test_embs, lp_test_loader, self.device) \ + test_scores = run_lp_mini_batch_predict(decoder, + lp_test_embs, + lp_test_loader, + self.device) \ if lp_test_loader is not None else None if val_results is not None: From 7d91049560897adbe84bd1cc3bf608d083ff4e89 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 12 Jun 2024 10:14:35 -0700 Subject: [PATCH 2/3] Fix --- python/graphstorm/trainer/mt_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index afd03bd55a..7353d87bb8 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -479,7 +479,7 @@ def fit(self, train_loader, self.get_best_model_path() if save_model_path is not None else None} self.log_params(output) -def eval(self, model, data, mt_val_loader, mt_test_loader, total_steps, + def eval(self, model, data, mt_val_loader, mt_test_loader, total_steps, use_mini_batch_infer=False, return_proba=True): """ do the model evaluation using validation and test sets From ac333d844c51976ab646b14135c224c4c08bb9e2 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 13 Jun 2024 00:28:33 -0700 Subject: [PATCH 3/3] resolve comments --- python/graphstorm/model/multitask_gnn.py | 8 ++++---- python/graphstorm/trainer/mt_trainer.py | 24 ++++++++++++++++-------- tests/unit-tests/test_gnn.py | 15 +++++++-------- tests/unit-tests/test_trainer.py | 21 +++++++++------------ 4 files changed, 36 insertions(+), 32 deletions(-) diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index e891dd2be0..5b6b9f7b5c 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -490,7 +490,7 @@ def gen_emb_for_nfeat_reconstruct(model, gen_embs): Return ------ - embs: node embedings + embs: node embeddings """ if isinstance(model.gnn_encoder, GSgnnGNNEncoderInterface): if model.has_sparse_params(): @@ -507,11 +507,11 @@ def gen_emb_for_nfeat_reconstruct(model, gen_embs): "skip the last layer self-loop" "Please set use_self_loop to False", BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) - embs = gen_embs(last_self_loop=True) + embs = gen_embs(skip_last_self_loop=False) else: # skip the selfloop of the last layer to # avoid information leakage. - embs = gen_embs(last_self_loop=False) + embs = gen_embs(skip_last_self_loop=True) else: # we will use the computed embs directly logging.warning("The gnn encoder %s does not support skip " @@ -520,5 +520,5 @@ def gen_emb_for_nfeat_reconstruct(model, gen_embs): "node feature leakage risk when doing %s training.", type(model.gnn_encoder), BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) - embs = gen_embs(last_self_loop=True) + embs = gen_embs(skip_last_self_loop=False) return embs diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index 7353d87bb8..ef97392cf0 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -429,7 +429,7 @@ def fit(self, train_loader, # TODO(xiangsx): Add early stop support # Every n iterations, save the model and keep - # the lask k models. + # the last k models. # TODO(xiangsx): support saving the best top k model. if save_model_frequency > 0 and \ total_steps % save_model_frequency == 0 and \ @@ -489,9 +489,9 @@ def eval(self, model, data, mt_val_loader, mt_test_loader, total_steps, The GNN model. data : GSgnnData The training dataset - mt_val_loader: GSNodeDataLoader + mt_val_loader: GSgnnMultiTaskDataLoader The dataloader for validation data - mt_test_loader : GSNodeDataLoader + mt_test_loader : GSgnnMultiTaskDataLoader The dataloader for test data. total_steps: int Total number of iterations. @@ -657,18 +657,26 @@ def gen_embs(edge_mask=None): test_results = {task_info.task_id: test_scores} if len(nfeat_recon_tasks) > 0: - def nfrecon_gen_embs(last_self_loop=False): + def nfrecon_gen_embs(skip_last_self_loop=False): """ Generate node embeddings for node feature reconstruction """ - if last_self_loop is False: + if skip_last_self_loop is True: + # Turn off the last layer GNN's self-loop + # to compute node embeddings. model.gnn_encoder.skip_last_selfloop() new_embs = gen_embs() model.gnn_encoder.reset_last_selfloop() return new_embs else: - # if lask_self_loop is True - # we can reuse the computed embs if any - return embs if embs is not None else gen_embs() + # If skip_last_self_loop is False + # we will not change the way we compute + # node embeddings. + if embs is not None: + # The embeddings have been computed + # when handling predict_tasks in L608 + return embs + else: + return gen_embs() nfeat_embs = gen_emb_for_nfeat_reconstruct(model, nfrecon_gen_embs) diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index 7277cd3908..27015796c0 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -2301,21 +2301,20 @@ def check_forward(mock_run_lp_mini_batch_predict, def test_gen_emb_for_nfeat_recon(): encoder_model = DummyGSgnnEncoderModel() - model = DummyGSgnnModel(encoder_model, has_sparse=True) - call_self_loop = True - def check_call_gen_embs(last_self_loop): - assert last_self_loop == call_self_loop + def check_call_gen_embs(skip_last_self_loop): + assert skip_last_self_loop == skip_self_loop + model = DummyGSgnnModel(encoder_model, has_sparse=True) + skip_self_loop = False gen_emb_for_nfeat_reconstruct(model, check_call_gen_embs) - call_self_loop = False + skip_self_loop = True model = DummyGSgnnModel(encoder_model, has_sparse=False) gen_emb_for_nfeat_reconstruct(model, check_call_gen_embs) model = DummyGSgnnModel(None) - call_self_loop = True - def check_call_gen_embs(last_self_loop): - assert last_self_loop == call_self_loop + skip_self_loop = False + gen_emb_for_nfeat_reconstruct(model, check_call_gen_embs) if __name__ == '__main__': diff --git a/tests/unit-tests/test_trainer.py b/tests/unit-tests/test_trainer.py index f66a83f586..dbe8770096 100644 --- a/tests/unit-tests/test_trainer.py +++ b/tests/unit-tests/test_trainer.py @@ -460,41 +460,41 @@ def task_tracker(self): return "dummy tracker" def test_mtask_eval(): - tast_info_nc = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION, + task_info_nc = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION, task_id='nc_task', task_config=None) nc_dataloader = DummyGSgnnNodeDataLoader() - tast_info_nr = TaskInfo(task_type=BUILTIN_TASK_NODE_REGRESSION, + task_info_nr = TaskInfo(task_type=BUILTIN_TASK_NODE_REGRESSION, task_id='nr_task', task_config=None) nr_dataloader = DummyGSgnnNodeDataLoader() - tast_info_ec = TaskInfo(task_type=BUILTIN_TASK_EDGE_CLASSIFICATION, + task_info_ec = TaskInfo(task_type=BUILTIN_TASK_EDGE_CLASSIFICATION, task_id='ec_task', task_config=None) ec_dataloader = DummyGSgnnEdgeDataLoader() - tast_info_er = TaskInfo(task_type=BUILTIN_TASK_EDGE_REGRESSION, + task_info_er = TaskInfo(task_type=BUILTIN_TASK_EDGE_REGRESSION, task_id='er_task', task_config=None) er_dataloader = DummyGSgnnEdgeDataLoader() task_config = GSConfig.__new__(GSConfig) setattr(task_config, "train_mask", "train_mask") - tast_info_lp = TaskInfo(task_type=BUILTIN_TASK_LINK_PREDICTION, + task_info_lp = TaskInfo(task_type=BUILTIN_TASK_LINK_PREDICTION, task_id='lp_task', task_config=task_config) encoder_model = DummyGSgnnEncoderModel() - model = DummyGSgnnMTModel(encoder_model, decoders={tast_info_lp.task_id: "dummy"}, has_sparse=True) + model = DummyGSgnnMTModel(encoder_model, decoders={task_info_lp.task_id: "dummy"}, has_sparse=True) mt_trainer = GSgnnMultiTaskLearningTrainer(model) mt_trainer._device = 'cpu' lp_dataloader = DummyGSgnnLinkPredictionDataLoader() - tast_info_nfr = TaskInfo(task_type=BUILTIN_TASK_RECONSTRUCT_NODE_FEAT, + task_info_nfr = TaskInfo(task_type=BUILTIN_TASK_RECONSTRUCT_NODE_FEAT, task_id='nfr_task', task_config=None) nfr_dataloader = DummyGSgnnNodeDataLoader() - task_infos = [tast_info_nc, tast_info_nr, tast_info_ec, - tast_info_er, tast_info_lp, tast_info_nfr] + task_infos = [task_info_nc, task_info_nr, task_info_ec, + task_info_er, task_info_lp, task_info_nfr] data = None res = mt_trainer.eval(model, data, None, None, 100) @@ -595,7 +595,6 @@ def check_eval(mock_do_mini_batch_inference, } evaluator = MTaskCheckerEvaluator(target_res, target_res, 100) mt_trainer.setup_evaluator(evaluator) - # test when val_loader is None mt_trainer.eval(model, data, val_loader, test_loader, 100) # lp tasks are empty @@ -614,7 +613,6 @@ def check_eval(mock_do_mini_batch_inference, } evaluator = MTaskCheckerEvaluator(target_res, target_res, 200) mt_trainer.setup_evaluator(evaluator) - # test when val_loader is None mt_trainer.eval(model, data, val_loader, test_loader, 200) # node feature reconstruct tasks are empty @@ -633,7 +631,6 @@ def check_eval(mock_do_mini_batch_inference, } evaluator = MTaskCheckerEvaluator(target_res, target_res, 200) mt_trainer.setup_evaluator(evaluator) - # test when val_loader is None mt_trainer.eval(model, data, val_loader, test_loader, 200) check_eval()