Commit a7b1872

Merge branch 'fix-multi-task-eval' into multi-task-infer
Xiang Song committed Jun 13, 2024
2 parents 75dbd74 + ac333d8 · commit a7b1872
Showing 5 changed files with 51 additions and 39 deletions.
22 changes: 15 additions & 7 deletions python/graphstorm/inference/mt_infer.py

@@ -182,18 +182,26 @@ def gen_embs(edge_mask=None):
             task_infos = recon_nfeat_test_loader.task_infos
 
             with th.no_grad():
-                def nfrecon_gen_embs(last_self_loop=False):
+                def nfrecon_gen_embs(skip_last_self_loop=False):
                     """ Generate node embeddings for node feature reconstruction
                     """
-                    if last_self_loop is False:
-                        self._model.gnn_encoder.skip_last_selfloop()
+                    if skip_last_self_loop is True:
+                        # Turn off the last layer GNN's self-loop
+                        # to compute node embeddings.
+                        model.gnn_encoder.skip_last_selfloop()
                         new_embs = gen_embs()
-                        self._model.gnn_encoder.reset_last_selfloop()
+                        model.gnn_encoder.reset_last_selfloop()
                         return new_embs
                     else:
-                        # if lask_self_loop is True
-                        # we can reuse the computed embs if any
-                        return embs if embs is not None else gen_embs()
+                        # If skip_last_self_loop is False
+                        # we will not change the way we compute
+                        # node embeddings.
+                        if embs is not None:
+                            # The embeddings have been computed
+                            # when handling predict_tasks in L608
+                            return embs
+                        else:
+                            return gen_embs()
 
                 nfeat_embs = gen_emb_for_nfeat_reconstruct(self._model, nfrecon_gen_embs)
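
For readers skimming the hunk above, here is a minimal, self-contained sketch of the toggle-and-cache pattern it introduces. ToyEncoder and make_nfrecon_gen_embs are invented for illustration and are not GraphStorm APIs; only the skip_last_selfloop/reset_last_selfloop hooks mirror the encoder interface used in the diff.

import torch as th

class ToyEncoder:
    """Stand-in for a GNN encoder exposing the two hooks used in the diff."""
    def __init__(self):
        self.last_layer_self_loop = True

    def skip_last_selfloop(self):
        # Disable the last layer's self-loop so a node's own input
        # feature cannot flow straight into its output embedding.
        self.last_layer_self_loop = False

    def reset_last_selfloop(self):
        # Restore the default behavior for subsequent tasks.
        self.last_layer_self_loop = True

def make_nfrecon_gen_embs(encoder, gen_embs, cached_embs=None):
    """Build a closure with the same contract as nfrecon_gen_embs."""
    def nfrecon_gen_embs(skip_last_self_loop=False):
        if skip_last_self_loop:
            encoder.skip_last_selfloop()
            new_embs = gen_embs()          # recompute without the self-loop
            encoder.reset_last_selfloop()  # restore state for later tasks
            return new_embs
        # Embeddings computed for the prediction tasks can be reused as-is.
        return cached_embs if cached_embs is not None else gen_embs()
    return nfrecon_gen_embs

encoder = ToyEncoder()
gen = make_nfrecon_gen_embs(encoder, lambda: th.zeros(3, 4))
print(gen(skip_last_self_loop=True).shape)  # torch.Size([3, 4])
print(encoder.last_layer_self_loop)         # True: state was restored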
8 changes: 4 additions & 4 deletions python/graphstorm/model/multitask_gnn.py

@@ -490,7 +490,7 @@ def gen_emb_for_nfeat_reconstruct(model, gen_embs):
     Return
     ------
-    embs: node embedings
+    embs: node embeddings
     """
     if isinstance(model.gnn_encoder, GSgnnGNNEncoderInterface):
         if model.has_sparse_params():
@@ -507,11 +507,11 @@ def gen_emb_for_nfeat_reconstruct(model, gen_embs):
                 "skip the last layer self-loop"
                 "Please set use_self_loop to False",
                 BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
-            embs = gen_embs(last_self_loop=True)
+            embs = gen_embs(skip_last_self_loop=False)
         else:
             # skip the selfloop of the last layer to
             # avoid information leakage.
-            embs = gen_embs(last_self_loop=False)
+            embs = gen_embs(skip_last_self_loop=True)
     else:
         # we will use the computed embs directly
         logging.warning("The gnn encoder %s does not support skip "
@@ -520,5 +520,5 @@ def gen_emb_for_nfeat_reconstruct(model, gen_embs):
                         "node feature leakage risk when doing %s training.",
                         type(model.gnn_encoder),
                         BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
-    embs = gen_embs(last_self_loop=True)
+    embs = gen_embs(skip_last_self_loop=False)
     return embs
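
Pulled together, the dispatch logic of gen_emb_for_nfeat_reconstruct after this change reads roughly as below. This is a sketch with abbreviated warning messages; GSgnnGNNEncoderInterface and BUILTIN_TASK_RECONSTRUCT_NODE_FEAT are defined as placeholders here because the exact import paths may differ across GraphStorm versions.

import logging

class GSgnnGNNEncoderInterface:  # placeholder marker interface
    pass

BUILTIN_TASK_RECONSTRUCT_NODE_FEAT = "reconstruct_node_feat"  # placeholder

def gen_emb_for_nfeat_reconstruct_sketch(model, gen_embs):
    """Choose how to generate embeddings for node feature reconstruction."""
    if isinstance(model.gnn_encoder, GSgnnGNNEncoderInterface):
        if model.has_sparse_params():
            # Learnable sparse embeddings make skipping the last
            # self-loop unsafe, so keep the normal computation.
            logging.warning("Cannot skip the last self-loop for %s",
                            BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
            return gen_embs(skip_last_self_loop=False)
        # Skip the last layer's self-loop to avoid leaking the very
        # features the task is asked to reconstruct.
        return gen_embs(skip_last_self_loop=True)
    # The encoder cannot toggle its self-loop at all; warn and proceed.
    logging.warning("Encoder %s does not support skipping the last "
                    "self-loop; feature leakage risk for %s",
                    type(model.gnn_encoder),
                    BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
    return gen_embs(skip_last_self_loop=False)

The rename is the substantive fix here: the old flag last_self_loop was inverted relative to the action it controlled, while skip_last_self_loop=True now plainly means "drop the last self-loop".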
24 changes: 16 additions & 8 deletions python/graphstorm/trainer/mt_trainer.py

@@ -429,7 +429,7 @@ def fit(self, train_loader,
                 # TODO(xiangsx): Add early stop support
 
                 # Every n iterations, save the model and keep
-                # the lask k models.
+                # the last k models.
                 # TODO(xiangsx): support saving the best top k model.
                 if save_model_frequency > 0 and \
                     total_steps % save_model_frequency == 0 and \
@@ -489,9 +489,9 @@ def eval(self, model, data, mt_val_loader, mt_test_loader, total_steps,
             The GNN model.
         data : GSgnnData
             The training dataset
-        mt_val_loader: GSNodeDataLoader
+        mt_val_loader: GSgnnMultiTaskDataLoader
             The dataloader for validation data
-        mt_test_loader : GSNodeDataLoader
+        mt_test_loader : GSgnnMultiTaskDataLoader
             The dataloader for test data.
         total_steps: int
             Total number of iterations.
@@ -657,18 +657,26 @@ def gen_embs(edge_mask=None):
                 test_results = {task_info.task_id: test_scores}
 
             if len(nfeat_recon_tasks) > 0:
-                def nfrecon_gen_embs(last_self_loop=False):
+                def nfrecon_gen_embs(skip_last_self_loop=False):
                     """ Generate node embeddings for node feature reconstruction
                     """
-                    if last_self_loop is False:
+                    if skip_last_self_loop is True:
+                        # Turn off the last layer GNN's self-loop
+                        # to compute node embeddings.
                         model.gnn_encoder.skip_last_selfloop()
                         new_embs = gen_embs()
                         model.gnn_encoder.reset_last_selfloop()
                         return new_embs
                     else:
-                        # if lask_self_loop is True
-                        # we can reuse the computed embs if any
-                        return embs if embs is not None else gen_embs()
+                        # If skip_last_self_loop is False
+                        # we will not change the way we compute
+                        # node embeddings.
+                        if embs is not None:
+                            # The embeddings have been computed
+                            # when handling predict_tasks in L608
+                            return embs
+                        else:
+                            return gen_embs()
 
                 nfeat_embs = gen_emb_for_nfeat_reconstruct(model, nfrecon_gen_embs)
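
Why the last self-loop matters for this task: with a self-loop in the final layer, the output embedding contains a direct linear image of the node's own input feature, so a reconstruction decoder can copy the feature instead of learning it from the neighborhood. A toy numeric illustration (not GraphStorm code):

import torch as th

th.manual_seed(0)
x = th.randn(4, 8)        # the node features the task must reconstruct
neigh = th.randn(4, 8)    # aggregated messages from neighbors
W_self = th.eye(8)        # self-loop weight; identity = a perfect copy
W_neigh = th.randn(8, 8)

h_leaky = x @ W_self + neigh @ W_neigh  # self-loop: h contains x directly
h_safe = neigh @ W_neigh                # no self-loop: x only via neighbors

# With the self-loop, x is trivially recoverable from the embedding:
print(th.allclose(h_leaky - neigh @ W_neigh, x))  # True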
15 changes: 7 additions & 8 deletions tests/unit-tests/test_gnn.py

@@ -2301,21 +2301,20 @@ def check_forward(mock_run_lp_mini_batch_predict,
 
 def test_gen_emb_for_nfeat_recon():
     encoder_model = DummyGSgnnEncoderModel()
-    model = DummyGSgnnModel(encoder_model, has_sparse=True)
-    call_self_loop = True
-    def check_call_gen_embs(last_self_loop):
-        assert last_self_loop == call_self_loop
+    def check_call_gen_embs(skip_last_self_loop):
+        assert skip_last_self_loop == skip_self_loop
+
+    model = DummyGSgnnModel(encoder_model, has_sparse=True)
+    skip_self_loop = False
     gen_emb_for_nfeat_reconstruct(model, check_call_gen_embs)
 
-    call_self_loop = False
+    skip_self_loop = True
     model = DummyGSgnnModel(encoder_model, has_sparse=False)
     gen_emb_for_nfeat_reconstruct(model, check_call_gen_embs)
 
     model = DummyGSgnnModel(None)
-    call_self_loop = True
-    def check_call_gen_embs(last_self_loop):
-        assert last_self_loop == call_self_loop
+    skip_self_loop = False
     gen_emb_for_nfeat_reconstruct(model, check_call_gen_embs)
 
 
 if __name__ == '__main__':
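
One detail worth noting in the rewritten test: check_call_gen_embs closes over skip_self_loop, which Python resolves at call time rather than definition time, so a single definition covers all three scenarios even though skip_self_loop is assigned after the function and reassigned between calls. A standalone illustration of the same pattern:

def demo():
    def check(flag):
        # `expected` is looked up when check() runs, not when it is
        # defined, so the reassignments below are observed (late binding).
        assert flag == expected
    expected = False  # bound after check() is defined -- still fine
    check(False)
    expected = True
    check(True)
    print("late-binding checks passed")

demo()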
21 changes: 9 additions & 12 deletions tests/unit-tests/test_trainer.py

@@ -460,41 +460,41 @@ def task_tracker(self):
         return "dummy tracker"
 
 def test_mtask_eval():
-    tast_info_nc = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION,
+    task_info_nc = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION,
                             task_id='nc_task',
                             task_config=None)
     nc_dataloader = DummyGSgnnNodeDataLoader()
-    tast_info_nr = TaskInfo(task_type=BUILTIN_TASK_NODE_REGRESSION,
+    task_info_nr = TaskInfo(task_type=BUILTIN_TASK_NODE_REGRESSION,
                             task_id='nr_task',
                             task_config=None)
     nr_dataloader = DummyGSgnnNodeDataLoader()
-    tast_info_ec = TaskInfo(task_type=BUILTIN_TASK_EDGE_CLASSIFICATION,
+    task_info_ec = TaskInfo(task_type=BUILTIN_TASK_EDGE_CLASSIFICATION,
                             task_id='ec_task',
                             task_config=None)
     ec_dataloader = DummyGSgnnEdgeDataLoader()
-    tast_info_er = TaskInfo(task_type=BUILTIN_TASK_EDGE_REGRESSION,
+    task_info_er = TaskInfo(task_type=BUILTIN_TASK_EDGE_REGRESSION,
                             task_id='er_task',
                             task_config=None)
     er_dataloader = DummyGSgnnEdgeDataLoader()
     task_config = GSConfig.__new__(GSConfig)
     setattr(task_config, "train_mask", "train_mask")
-    tast_info_lp = TaskInfo(task_type=BUILTIN_TASK_LINK_PREDICTION,
+    task_info_lp = TaskInfo(task_type=BUILTIN_TASK_LINK_PREDICTION,
                             task_id='lp_task',
                             task_config=task_config)
 
     encoder_model = DummyGSgnnEncoderModel()
-    model = DummyGSgnnMTModel(encoder_model, decoders={tast_info_lp.task_id: "dummy"}, has_sparse=True)
+    model = DummyGSgnnMTModel(encoder_model, decoders={task_info_lp.task_id: "dummy"}, has_sparse=True)
 
     mt_trainer = GSgnnMultiTaskLearningTrainer(model)
    mt_trainer._device = 'cpu'
 
     lp_dataloader = DummyGSgnnLinkPredictionDataLoader()
-    tast_info_nfr = TaskInfo(task_type=BUILTIN_TASK_RECONSTRUCT_NODE_FEAT,
+    task_info_nfr = TaskInfo(task_type=BUILTIN_TASK_RECONSTRUCT_NODE_FEAT,
                              task_id='nfr_task',
                              task_config=None)
     nfr_dataloader = DummyGSgnnNodeDataLoader()
-    task_infos = [tast_info_nc, tast_info_nr, tast_info_ec,
-                  tast_info_er, tast_info_lp, tast_info_nfr]
+    task_infos = [task_info_nc, task_info_nr, task_info_ec,
+                  task_info_er, task_info_lp, task_info_nfr]
 
     data = None
     res = mt_trainer.eval(model, data, None, None, 100)
@@ -595,7 +595,6 @@ def check_eval(mock_do_mini_batch_inference,
         }
         evaluator = MTaskCheckerEvaluator(target_res, target_res, 100)
         mt_trainer.setup_evaluator(evaluator)
-        # test when val_loader is None
         mt_trainer.eval(model, data, val_loader, test_loader, 100)
 
         # lp tasks are empty
@@ -614,7 +613,6 @@ def check_eval(mock_do_mini_batch_inference,
         }
         evaluator = MTaskCheckerEvaluator(target_res, target_res, 200)
         mt_trainer.setup_evaluator(evaluator)
-        # test when val_loader is None
         mt_trainer.eval(model, data, val_loader, test_loader, 200)
 
         # node feature reconstruct tasks are empty
@@ -633,7 +631,6 @@ def check_eval(mock_do_mini_batch_inference,
         }
         evaluator = MTaskCheckerEvaluator(target_res, target_res, 200)
         mt_trainer.setup_evaluator(evaluator)
-        # test when val_loader is None
        mt_trainer.eval(model, data, val_loader, test_loader, 200)
 
     check_eval()
