Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Jun 12, 2024
1 parent 14f6cfc commit cbb0596
Show file tree
Hide file tree
Showing 7 changed files with 420 additions and 72 deletions.
18 changes: 14 additions & 4 deletions python/graphstorm/model/multitask_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from .node_gnn import run_node_mini_batch_predict
from .edge_gnn import run_edge_mini_batch_predict
from .lp_gnn import run_lp_mini_batch_predict
from ..utils import is_distributed


class GSgnnMultiTaskModelInterface:
Expand Down Expand Up @@ -383,16 +382,19 @@ def predict(self, task_id, mini_batch, return_proba=False):

def multi_task_mini_batch_predict(
model, emb, dataloaders, task_infos, device, return_proba=True, return_label=False):
""" conduct mini batch prediction on multiple tasks
""" conduct mini batch prediction on multiple tasks.
The task infos are passed in as task_infos.
The task dataloaders are passed in as dataloaders.
Parameters
----------
model: GSgnnMultiTaskModelInterface, GSgnnModel
Multi-task learning model
emb : dict of Tensor
The GNN embeddings
loader: GSgnnMultiTaskDataLoader
The mini-batch dataloader.
dataloaders: list
List of val or test dataloaders.
task_infos: list
List of task info
device: th.device
Expand Down Expand Up @@ -458,6 +460,14 @@ def multi_task_mini_batch_predict(
etype = list(preds.keys())[0]
res[task_info.task_id] = (preds[etype], labels[etype] \
if labels is not None else None)
elif task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
if dataloader is None:
# In cases when there is no validation or test set.
res[task_info.task_id] = None
else:
decoder = task_decoders[task_info.task_id]
ranking = run_lp_mini_batch_predict(decoder, emb, dataloader, device)
res[task_info.task_id] = ranking
else:
raise TypeError(f"Unsupported task {task_info}")

Expand Down
8 changes: 6 additions & 2 deletions python/graphstorm/run/gsgnn_mt/gsgnn_mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,12 @@ def main(config_args):
logging.warning("The training data do not have validation set.")
if test_loader is None:
logging.warning("The training data do not have test set.")
task_evaluators[task.task_id] = \
create_evaluator(task)

if val_loader is None and test_loader is None:
logging.warning("Task %s does not have validation and test sets.", task.task_id)
else:
task_evaluators[task.task_id] = \
create_evaluator(task)

train_dataloader = GSgnnMultiTaskDataLoader(train_data, tasks, train_dataloaders)
val_dataloader = GSgnnMultiTaskDataLoader(train_data, tasks, val_dataloaders)
Expand Down
29 changes: 18 additions & 11 deletions python/graphstorm/trainer/mt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,10 @@ def eval(self, model, data, val_loader, test_loader, total_steps,
if test_loader is not None else None
task_infos = val_loader.task_infos \
if val_loader is not None else test_loader.task_infos
if val_dataloaders is None:
val_dataloaders = [None] * len(task_infos)
if test_dataloaders is None:
test_dataloaders = [None] * len(task_infos)

# All the tasks share the same GNN encoder so the fanouts are same
# for different tasks.
Expand Down Expand Up @@ -556,6 +560,7 @@ def eval(self, model, data, val_loader, test_loader, total_steps,

for val_loader, test_loader, task_info \
in zip(val_dataloaders, test_dataloaders, task_infos):

if val_loader is None and test_loader is None:
# For this task, there is no need to compute a test or val score,
# so skip this task
Expand Down Expand Up @@ -605,7 +610,7 @@ def gen_embs(edge_mask=None):
multi_task_mini_batch_predict(
model,
emb=embs,
loader=predict_val_loaders,
dataloaders=predict_val_loaders,
task_infos=predict_tasks,
device=self.device,
return_proba=return_proba,
Expand All @@ -616,7 +621,7 @@ def gen_embs(edge_mask=None):
multi_task_mini_batch_predict(
model,
emb=embs,
loader=predict_test_loaders,
dataloaders=predict_test_loaders,
task_infos=predict_tasks,
device=self.device,
return_proba=return_proba,
Expand All @@ -625,15 +630,17 @@ def gen_embs(edge_mask=None):

if len(lp_tasks) > 0:
for lp_val_loader, lp_test_loader, task_info \
in zip(lp_val_loaders, lp_test_loaders, task_infos):

in zip(lp_val_loaders, lp_test_loaders, lp_tasks):
# For link prediction, run evaluation
# one task at a time.
lp_test_embs = gen_embs(edge_mask=task_info.task_config.train_mask)

decoder = model.task_decoders[task_info.task_id]
val_scores = run_lp_mini_batch_predict(decoder, lp_test_embs, lp_val_loader, self.device) \
if val_loader is not None else None
if lp_val_loader is not None else None
test_scores = run_lp_mini_batch_predict(decoder, lp_test_embs, lp_test_loader, self.device) \
if val_loader is not None else None
if lp_test_loader is not None else None

if val_results is not None:
val_results[task_info.task_id] = val_scores
else:
Expand All @@ -644,7 +651,7 @@ def gen_embs(edge_mask=None):
test_results = {task_info.task_id: test_scores}

if len(nfeat_recon_tasks) > 0:
def nfrecon_gen_embs(model, last_self_loop=False):
def nfrecon_gen_embs(last_self_loop=False):
""" Generate node embeddings for node feature reconstruction
"""
if last_self_loop is False:
Expand All @@ -663,8 +670,8 @@ def nfrecon_gen_embs(model, last_self_loop=False):
multi_task_mini_batch_predict(
model,
emb=nfeat_embs,
loader=nfeat_recon_val_loaders,
task_infos=predict_tasks,
dataloaders=nfeat_recon_val_loaders,
task_infos=nfeat_recon_tasks,
device=self.device,
return_proba=return_proba,
return_label=True) \
Expand All @@ -674,8 +681,8 @@ def nfrecon_gen_embs(model, last_self_loop=False):
multi_task_mini_batch_predict(
model,
emb=nfeat_embs,
loader=nfeat_recon_test_loaders,
task_infos=predict_tasks,
dataloaders=nfeat_recon_test_loaders,
task_infos=nfeat_recon_tasks,
device=self.device,
return_proba=return_proba,
return_label=True) \
Expand Down
114 changes: 67 additions & 47 deletions tests/unit-tests/test_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
BUILTIN_TASK_NODE_REGRESSION,
BUILTIN_TASK_EDGE_CLASSIFICATION,
BUILTIN_TASK_EDGE_REGRESSION,
BUILTIN_TASK_LINK_PREDICTION)
BUILTIN_TASK_LINK_PREDICTION,
BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
from graphstorm.model import GSNodeEncoderInputLayer, RelationalGCNEncoder
from graphstorm.model import GSgnnNodeModel, GSgnnEdgeModel
from graphstorm.model import GSLMNodeEncoderInputLayer, GSPureLMNodeInputLayer
Expand All @@ -59,7 +60,7 @@
from graphstorm.model.node_decoder import EntityRegression, EntityClassifier
from graphstorm.model.loss_func import RegressionLossFunc
from graphstorm.dataloading import GSgnnData
from graphstorm.dataloading import GSgnnNodeDataLoader, GSgnnEdgeDataLoader, GSgnnMultiTaskDataLoader
from graphstorm.dataloading import GSgnnNodeDataLoader, GSgnnEdgeDataLoader
from graphstorm.dataloading.dataset import prepare_batch_input
from graphstorm import (create_builtin_edge_gnn_model,
create_builtin_node_gnn_model,
Expand All @@ -75,13 +76,17 @@
from graphstorm.model.edge_gnn import (edge_mini_batch_predict,
run_edge_mini_batch_predict,
edge_mini_batch_gnn_predict)
from graphstorm.model.multitask_gnn import multi_task_mini_batch_predict
from graphstorm.model.multitask_gnn import (multi_task_mini_batch_predict,
gen_emb_for_nfeat_reconstruct)
from graphstorm.model.gnn_with_reconstruct import construct_node_feat, get_input_embeds_combined
from graphstorm.model.utils import load_model, save_model
from graphstorm.model import GSgnnMultiTaskSharedEncoderModel
from graphstorm.dataloading import (GSgnnEdgeDataLoaderBase,
GSgnnLinkPredictionDataLoaderBase,
GSgnnNodeDataLoaderBase)

from util import (DummyGSgnnNodeDataLoader,
DummyGSgnnEdgeDataLoader,
DummyGSgnnLinkPredictionDataLoader,
DummyGSgnnModel,
DummyGSgnnEncoderModel)

from data_utils import generate_dummy_dist_graph, generate_dummy_dist_graph_multi_target_ntypes
from data_utils import generate_dummy_dist_graph_reconstruct
Expand Down Expand Up @@ -2082,36 +2087,6 @@ def input_embed_side_effect_func(input_nodes, node_feats):

check_forward()

class DummyGSgnnNodeDataLoader(GSgnnNodeDataLoaderBase):
def __init__(self):
pass # do nothing

def __len__(self):
return 10

def __iter__(self):
return self

class DummyGSgnnEdgeDataLoader(GSgnnEdgeDataLoaderBase):
def __init__(self):
pass # do nothing

def __len__(self):
return 10

def __iter__(self):
return self

class DummyGSgnnLinkPredictionDataLoader(GSgnnLinkPredictionDataLoaderBase):
def __init__(self):
pass # do nothing

def __len__(self):
return 10

def __iter__(self):
return self

def test_multi_task_mini_batch_predict():
mt_model = GSgnnMultiTaskSharedEncoderModel(0.1)

Expand Down Expand Up @@ -2144,6 +2119,11 @@ def pred_lp_loss_func(pos_score, neg_score):
DummyLPDecoder(),
pred_lp_loss_func)

mt_model.add_task("nfr_task",
BUILTIN_TASK_RECONSTRUCT_NODE_FEAT,
DummyLPDecoder(),
pred_lp_loss_func)

tast_info_nc = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION,
task_id='nc_task',
task_config=None)
Expand All @@ -2164,8 +2144,14 @@ def pred_lp_loss_func(pos_score, neg_score):
task_id='lp_task',
task_config=None)
lp_dataloader = DummyGSgnnLinkPredictionDataLoader()
task_infos = [tast_info_nc, tast_info_nr, tast_info_ec, tast_info_er, tast_info_lp]
dataloaders = [nc_dataloader, nr_dataloader, ec_dataloader, er_dataloader, lp_dataloader]
tast_info_nfr = TaskInfo(task_type=BUILTIN_TASK_RECONSTRUCT_NODE_FEAT,
task_id='nfr_task',
task_config=None)
nfr_dataloader = DummyGSgnnNodeDataLoader()
task_infos = [tast_info_nc, tast_info_nr, tast_info_ec,
tast_info_er, tast_info_lp, tast_info_nfr]
dataloaders = [nc_dataloader, nr_dataloader, ec_dataloader,
er_dataloader, lp_dataloader, nfr_dataloader]

node_pred = {"n0": th.arange(10)}
node_prob = {"n0": th.arange(10)/10}
Expand Down Expand Up @@ -2206,10 +2192,10 @@ def check_forward(mock_run_lp_mini_batch_predict,
mock_run_edge_mini_batch_predict,
mock_run_node_mini_batch_predict):

mt_dataloader = GSgnnMultiTaskDataLoader(None, task_infos, dataloaders)
res = multi_task_mini_batch_predict(mt_model,
None,
mt_dataloader,
dataloaders,
task_infos,
device=th.device('cpu'),
return_proba=False,
return_label=False)
Expand All @@ -2227,10 +2213,15 @@ def check_forward(mock_run_lp_mini_batch_predict,
assert res["er_task"][1] is None
assert_equal(res["lp_task"][("n0", "r0", "n1")].numpy(), lp_pred[("n0", "r0", "n1")].numpy())
assert_equal(res["lp_task"][("n0", "r0", "n2")].numpy(), lp_pred[("n0", "r0", "n2")].numpy())
# node feature reconstruction also calls
# run_node_mini_batch_predict
assert len(res["nfr_task"]) == 2
assert_equal(res["nfr_task"][0].numpy(), node_pred["n0"].numpy())

res = multi_task_mini_batch_predict(mt_model,
None,
mt_dataloader,
dataloaders,
task_infos,
device=th.device('cpu'),
return_proba=True,
return_label=False)
Expand All @@ -2248,10 +2239,15 @@ def check_forward(mock_run_lp_mini_batch_predict,
assert res["er_task"][1] is None
assert_equal(res["lp_task"][("n0", "r0", "n1")].numpy(), lp_pred[("n0", "r0", "n1")].numpy())
assert_equal(res["lp_task"][("n0", "r0", "n2")].numpy(), lp_pred[("n0", "r0", "n2")].numpy())
# node feature reconstruction also calls
# run_node_mini_batch_predict
assert len(res["nfr_task"]) == 2
assert_equal(res["nfr_task"][0].numpy(), node_prob["n0"].numpy())

res = multi_task_mini_batch_predict(mt_model,
None,
mt_dataloader,
dataloaders,
task_infos,
device=th.device('cpu'),
return_proba=False,
return_label=True)
Expand All @@ -2269,14 +2265,18 @@ def check_forward(mock_run_lp_mini_batch_predict,
assert_equal(res["ec_task"][0].numpy(), edge_label[("n0", "r0", "n1")].numpy())
assert_equal(res["lp_task"][("n0", "r0", "n1")].numpy(), lp_pred[("n0", "r0", "n1")].numpy())
assert_equal(res["lp_task"][("n0", "r0", "n2")].numpy(), lp_pred[("n0", "r0", "n2")].numpy())
# node feature reconstruction also calls
# run_node_mini_batch_predict
assert len(res["nfr_task"]) == 2
assert_equal(res["nfr_task"][0].numpy(), node_pred["n0"].numpy())
assert_equal(res["nfr_task"][1].numpy(), node_label["n0"].numpy())


new_dataloaders = [nc_dataloader, None, ec_dataloader, None, None]
mt_dataloader = GSgnnMultiTaskDataLoader(None, task_infos, new_dataloaders)

new_dataloaders = [nc_dataloader, None, ec_dataloader, None, None, None]
res = multi_task_mini_batch_predict(mt_model,
None,
mt_dataloader,
new_dataloaders,
task_infos,
device=th.device('cpu'),
return_proba=False,
return_label=False)
Expand All @@ -2293,10 +2293,29 @@ def check_forward(mock_run_lp_mini_batch_predict,
assert res["er_task"][0] is None
assert res["er_task"][1] is None
assert res["lp_task"] is None
assert len(res["nfr_task"]) == 2
assert res["nfr_task"][0] is None
assert res["nfr_task"][1] is None

check_forward()

def test_gen_emb_for_nfeat_recon():
encoder_model = DummyGSgnnEncoderModel()
model = DummyGSgnnModel(encoder_model, has_sparse=True)
call_self_loop = True
def check_call_gen_embs(last_self_loop):
assert last_self_loop == call_self_loop

check_forward()
gen_emb_for_nfeat_reconstruct(model, check_call_gen_embs)

call_self_loop = False
model = DummyGSgnnModel(encoder_model, has_sparse=False)
gen_emb_for_nfeat_reconstruct(model, check_call_gen_embs)

model = DummyGSgnnModel(None)
call_self_loop = True
def check_call_gen_embs(last_self_loop):
assert last_self_loop == call_self_loop


if __name__ == '__main__':
Expand All @@ -2305,6 +2324,7 @@ def check_forward(mock_run_lp_mini_batch_predict,
test_multi_task_forward()
test_multi_task_predict()
test_multi_task_mini_batch_predict()
test_gen_emb_for_nfeat_recon()

test_lm_rgcn_node_prediction_with_reconstruct()
test_rgcn_node_prediction_with_reconstruct(True)
Expand Down
Loading

0 comments on commit cbb0596

Please sign in to comment.