Commit

update

Xiang Song committed Jun 11, 2024
1 parent 7eba2e3 commit 14f6cfc
Showing 3 changed files with 43 additions and 98 deletions.
3 changes: 1 addition & 2 deletions python/graphstorm/model/__init__.py
@@ -37,8 +37,7 @@
run_lp_mini_batch_predict)
from .multitask_gnn import (GSgnnMultiTaskModelInterface,
GSgnnMultiTaskSharedEncoderModel)
from .multitask_gnn import (multi_prediction_task_mini_batch_predict,
multi_nfeat_recon_task_mini_batch_predict,
from .multitask_gnn import (multi_task_mini_batch_predict,
gen_emb_for_nfeat_reconstruct)
from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer
from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer
81 changes: 5 additions & 76 deletions python/graphstorm/model/multitask_gnn.py
@@ -381,8 +381,8 @@ def predict(self, task_id, mini_batch, return_proba=False):
else:
raise TypeError(f"Unknow task type {task_type}")

def multi_prediction_task_mini_batch_predict(
model, emb, loader, device, return_proba=True, return_label=False):
def multi_task_mini_batch_predict(
model, emb, dataloaders, task_infos, device, return_proba=True, return_label=False):
""" conduct mini batch prediction on multiple tasks
Parameters
@@ -393,6 +393,8 @@ def multi_prediction_task_mini_batch_predict(
The GNN embeddings
loader: GSgnnMultiTaskDataLoader
The mini-batch dataloader.
task_infos: list
List of task info
device: th.device
Device used to compute test scores.
return_proba: bool
@@ -404,8 +406,6 @@ def multi_prediction_task_mini_batch_predict(
-------
dict: prediction results of each task
"""
dataloaders = loader.dataloaders
task_infos = loader.task_infos
task_decoders = model.task_decoders
res = {}
with th.no_grad():
@@ -458,16 +458,8 @@ def multi_prediction_task_mini_batch_predict(
etype = list(preds.keys())[0]
res[task_info.task_id] = (preds[etype], labels[etype] \
if labels is not None else None)
elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION:
if dataloader is None:
# In cases when there is no validation or test set.
res[task_info.task_id] = None
else:
decoder = task_decoders[task_info.task_id]
ranking = run_lp_mini_batch_predict(decoder, emb, dataloader, device)
res[task_info.task_id] = ranking
else:
raise TypeError(f"Unknown task {task_info}")
raise TypeError(f"Unsupported task {task_info}")

return res

@@ -520,66 +512,3 @@ def gen_emb_for_nfeat_reconstruct(model, gen_embs):
BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
embs = gen_embs(last_self_loop=True)
return embs

def multi_nfeat_recon_task_mini_batch_predict(
model, embs,
nfeat_recon_val_loaders,
nfeat_recon_test_loaders,
task_infos,
device,
return_label=False):
""" conduct mini batch prediction on node feature
reconstruction tasks
Parameters
----------
model: GSgnnMultiTaskModelInterface, GSgnnModel
Multi-task learning model
embs : dict of Tensor
The GNN embeddings
nfeat_recon_val_loaders: list
List of validation dataloaders
nfeat_recon_test_loaders: list
List of test dataloaders
task_infos: list
List of task info
device: th.device
Device used to compute test scores.
return_label : bool
Whether or not to return labels.
Return
------
dict: Validation results
dict: test results
"""
val_results = {}
test_results = {}
for val_loader, test_loader, task_info in \
zip(nfeat_recon_val_loaders, nfeat_recon_test_loaders, task_infos):
decoder = model.task_decoders[task_info.task_id]
if val_loader is None:
val_results[task_info.task_id] = (None, None)
else:
val_preds, val_labels = \
run_node_mini_batch_predict(decoder,
embs,
val_loader,
device=device,
return_proba=False,
return_label=return_label)
val_results[task_info.task_id] = (val_preds, val_labels)

if test_loader is None:
test_results[task_info.task_id] = (None, None)
else:
test_preds, test_labels = \
run_node_mini_batch_predict(decoder,
embs,
test_loader,
device=device,
return_proba=False,
return_label=return_label)
test_results[task_info.task_id] = (test_preds, test_labels)

return val_results, test_results
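
For reference, a minimal usage sketch of the consolidated multi_task_mini_batch_predict helper introduced in this file, based only on the signature and docstring shown above (and its export via graphstorm.model in the __init__.py change). The model, embeddings, dataloaders, and task-info objects below are hypothetical placeholders, not code from this commit:

import torch as th

# Exported from graphstorm.model per the __init__.py change above.
from graphstorm.model import multi_task_mini_batch_predict

# Placeholders: in a real run these come from the multi-task trainer/dataset.
# `model` is the multi-task GNN model, `embs` the precomputed node embeddings,
# `val_dataloaders` a list with one dataloader (or None) per task, and
# `val_task_infos` the matching list of task-info objects.
device = th.device("cuda" if th.cuda.is_available() else "cpu")

val_results = multi_task_mini_batch_predict(
    model,
    embs,
    val_dataloaders,
    val_task_infos,
    device,
    return_proba=False,
    return_label=True)

# `val_results` is a dict keyed by task_id: node/edge prediction tasks map to a
# (predictions, labels) pair, link prediction tasks map to rankings, and a task
# without a validation/test dataloader maps to None.
for task_id, result in val_results.items():
    print(task_id, "no loader" if result is None else "evaluated")
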
57 changes: 37 additions & 20 deletions python/graphstorm/trainer/mt_trainer.py
@@ -33,8 +33,7 @@
do_mini_batch_inference,
GSgnnModelBase, GSgnnModel,
GSgnnMultiTaskModelInterface,
multi_prediction_task_mini_batch_predict,
multi_nfeat_recon_task_mini_batch_predict,
multi_task_mini_batch_predict,
gen_emb_for_nfeat_reconstruct)
from ..model.lp_gnn import run_lp_mini_batch_predict
from .gsgnn_trainer import GSgnnTrainer
@@ -603,24 +602,26 @@ def gen_embs(edge_mask=None):
sys_tracker.check('compute embeddings')
embs = gen_embs()
val_results = \
multi_prediction_task_mini_batch_predict(
multi_task_mini_batch_predict(
model,
emb=embs,
loader=val_loader,
loader=predict_val_loaders,
task_infos=predict_tasks,
device=self.device,
return_proba=return_proba,
return_label=True) \
if val_loader is not None else None
if len(predict_val_loaders) > 0 else None

test_results = \
multi_prediction_task_mini_batch_predict(
multi_task_mini_batch_predict(
model,
emb=embs,
loader=test_loader,
loader=predict_test_loaders,
task_infos=predict_tasks,
device=self.device,
return_proba=return_proba,
return_label=True) \
if test_loader is not None else None
if len(predict_test_loaders) > 0 else None

if len(lp_tasks) > 0:
for lp_val_loader, lp_test_loader, task_info \
@@ -658,24 +659,40 @@ def nfrecon_gen_embs(model, last_self_loop=False):

nfeat_embs = gen_emb_for_nfeat_reconstruct(model, nfrecon_gen_embs)

nfeat_recon_val_results, nfeat_recon_test_results = \
multi_nfeat_recon_task_mini_batch_predict(
nfeat_recon_val_results = \
multi_task_mini_batch_predict(
model,
nfeat_embs,
nfeat_recon_val_loaders,
nfeat_recon_test_loaders,
task_infos,
emb=nfeat_embs,
loader=nfeat_recon_val_loaders,
task_infos=predict_tasks,
device=self.device,
return_label=True)
return_proba=return_proba,
return_label=True) \
if len(nfeat_recon_val_loaders) > 0 else None

if val_results is not None:
val_results.update(nfeat_recon_val_results)
else:
nfeat_recon_test_results = \
multi_task_mini_batch_predict(
model,
emb=nfeat_embs,
loader=nfeat_recon_test_loaders,
task_infos=predict_tasks,
device=self.device,
return_proba=return_proba,
return_label=True) \
if len(nfeat_recon_test_loaders) > 0 else None

if val_results is None:
val_results = nfeat_recon_val_results
if test_results is not None:
test_results.update(nfeat_recon_val_results)
else:
if nfeat_recon_val_results is not None:
val_results.update(nfeat_recon_val_results)

if test_results is None:
test_results = nfeat_recon_test_results
else:
if nfeat_recon_test_results is not None:
test_results.update(nfeat_recon_test_results)


sys_tracker.check('after_test_score')
val_score, test_score = self.evaluator.evaluate(
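
The trainer-side change above boils down to a None-safe merge of per-task result dicts: prediction-task results first, node-feature-reconstruction results merged in afterwards. A small illustrative sketch of that merge pattern; the helper name and task IDs are made up for illustration and are not GraphStorm API:

def merge_task_results(base, extra):
    """None-safe merge of two {task_id: result} dicts.

    Mirrors the logic added in mt_trainer.py: if there is nothing to merge
    into, the other dict becomes the result; otherwise merge it in only when
    it actually exists.
    """
    if base is None:
        return extra
    if extra is not None:
        base.update(extra)
    return base

# Made-up task IDs for illustration.
val_results = {"node_classify_task": ("preds", "labels")}
nfeat_recon_val_results = {"nfeat_reconstruct_task": ("preds", "labels")}
print(merge_task_results(val_results, nfeat_recon_val_results))
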
