From 08e3fe6b7db79dcf14da8047fa886a5514d16494 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 16:06:52 -0700 Subject: [PATCH 1/8] update init --- python/graphstorm/model/__init__.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/graphstorm/model/__init__.py b/python/graphstorm/model/__init__.py index 08a8391f56..18a741e200 100644 --- a/python/graphstorm/model/__init__.py +++ b/python/graphstorm/model/__init__.py @@ -24,12 +24,17 @@ from .gnn import do_full_graph_inference from .gnn import do_mini_batch_inference from .node_gnn import GSgnnNodeModel, GSgnnNodeModelBase, GSgnnNodeModelInterface -from .node_gnn import node_mini_batch_gnn_predict, node_mini_batch_predict +from .node_gnn import (node_mini_batch_gnn_predict, + node_mini_batch_predict, + run_node_mini_batch_predict) from .edge_gnn import GSgnnEdgeModel, GSgnnEdgeModelBase, GSgnnEdgeModelInterface -from .edge_gnn import edge_mini_batch_gnn_predict, edge_mini_batch_predict +from .edge_gnn import (edge_mini_batch_gnn_predict, + edge_mini_batch_predict, + run_edge_mini_batch_predict) from .lp_gnn import (GSgnnLinkPredictionModel, GSgnnLinkPredictionModelBase, - GSgnnLinkPredictionModelInterface) + GSgnnLinkPredictionModelInterface, + run_lp_mini_batch_predict) from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer from .sage_encoder import SAGEEncoder, SAGEConv From 4945c2c22e31225ce252699014b41841b4453bb2 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 16:08:09 -0700 Subject: [PATCH 2/8] update ep_gnn.py --- python/graphstorm/model/edge_gnn.py | 39 ++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/python/graphstorm/model/edge_gnn.py b/python/graphstorm/model/edge_gnn.py index 536e61f311..0a4b6a38e5 100644 --- a/python/graphstorm/model/edge_gnn.py +++ b/python/graphstorm/model/edge_gnn.py @@ -311,6 +311,44 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label= model.eval() decoder = model.decoder device = model.device + + preds, labels = run_edge_mini_batch_predict(decoder, + loader, + device, + return_proba, + return_label) + model.train() + return preds, labels + +def run_edge_mini_batch_predict(decoder, emb, loader, device, + return_proba=True, return_label=False): + """ Perform mini-batch prediction using edge decoder + + This function usually follows full-grain GNN embedding inference. After having + the GNN embeddings, we need to perform mini-batch computation to make predictions + on the GNN embeddings. + + Parameters + ---------- + decoder : GSEdgeDecoder + The GraphStorm edge decoder + emb : dict of Tensor + The GNN embeddings + loader : GSgnnEdgeDataLoader + The GraphStorm dataloader + device: th.device + Device used to compute prediction result + return_proba: bool + Whether to return all the predictions or the maximum prediction + return_label : bool + Whether or not to return labels + + Returns + ------- + dict of Tensor : GNN prediction results. Return all the results when return_proba is true + otherwise return the maximum result. + dict of Tensor : labels if return_labels is True + """ data = loader.data g = data.g preds = {} @@ -379,7 +417,6 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label= append_to_dict(lbl, labels) barrier() - model.train() for target_etype, pred in preds.items(): preds[target_etype] = th.cat(pred) if return_label: From d0b37b4d3891c2ab1e8f2ea3e9c9cbd896d56ce0 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 16:09:06 -0700 Subject: [PATCH 3/8] update lp_gnn.py --- python/graphstorm/model/lp_gnn.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/graphstorm/model/lp_gnn.py b/python/graphstorm/model/lp_gnn.py index 91c2c3317c..ebf8449a43 100644 --- a/python/graphstorm/model/lp_gnn.py +++ b/python/graphstorm/model/lp_gnn.py @@ -154,6 +154,30 @@ def lp_mini_batch_predict(model, emb, loader, device): Rankings of positive scores in format of {etype: ranking} """ decoder = model.decoder + return run_lp_mini_batch_predict(decoder, + emb, + loader, + device) + +def run_lp_mini_batch_predict(decoder, emb, loader, device): + """ Perform mini-batch link prediction. + + Parameters + ---------- + decoder : LinkPredictNoParamDecoder or LinkPredictLearnableDecoder + The GraphStorm link prediction decoder model + emb : dict of Tensor + The GNN embeddings + loader : GSgnnEdgeDataLoader + The GraphStorm dataloader + device: th.device + Device used to compute test scores + + Returns + ------- + rankings: dict of tensors + Rankings of positive scores in format of {etype: ranking} + """ with th.no_grad(): ranking = {} for pos_neg_tuple, neg_sample_type in loader: From 46da6cabea66a834de0d8c1a746a3710608e10fd Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 16:12:42 -0700 Subject: [PATCH 4/8] update --- python/graphstorm/model/edge_gnn.py | 5 +++- python/graphstorm/model/lp_gnn.py | 11 ++++++-- python/graphstorm/model/node_gnn.py | 43 +++++++++++++++++++++++++++-- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/python/graphstorm/model/edge_gnn.py b/python/graphstorm/model/edge_gnn.py index 0a4b6a38e5..7525e2428f 100644 --- a/python/graphstorm/model/edge_gnn.py +++ b/python/graphstorm/model/edge_gnn.py @@ -322,12 +322,15 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label= def run_edge_mini_batch_predict(decoder, emb, loader, device, return_proba=True, return_label=False): - """ Perform mini-batch prediction using edge decoder + """ Perform mini-batch prediction with the given decoder. This function usually follows full-grain GNN embedding inference. After having the GNN embeddings, we need to perform mini-batch computation to make predictions on the GNN embeddings. + Note: caller should call model.eval() before calling this function + and call model.train() after when doing training. + Parameters ---------- decoder : GSEdgeDecoder diff --git a/python/graphstorm/model/lp_gnn.py b/python/graphstorm/model/lp_gnn.py index ebf8449a43..1e08443755 100644 --- a/python/graphstorm/model/lp_gnn.py +++ b/python/graphstorm/model/lp_gnn.py @@ -133,7 +133,7 @@ def forward(self, blocks, pos_graph, def lp_mini_batch_predict(model, emb, loader, device): """ Perform mini-batch prediction. - This function follows full-grain GNN embedding inference. + This function follows full-graph GNN embedding inference. After having the GNN embeddings, we need to perform mini-batch computation to make predictions on the GNN embeddings. @@ -160,7 +160,14 @@ def lp_mini_batch_predict(model, emb, loader, device): device) def run_lp_mini_batch_predict(decoder, emb, loader, device): - """ Perform mini-batch link prediction. + """ Perform mini-batch link prediction with the given decoder. + + This function follows full-graph GNN embedding inference. + After having the GNN embeddings, we need to perform mini-batch + computation to make predictions on the GNN embeddings. + + Note: caller should call model.eval() before calling this function + and call model.train() after when doing training. Parameters ---------- diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index fca05ada24..442582bf4f 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -311,6 +311,47 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= Labels if return_labels is True """ device = model.device + decoder = model.decoder + model.eval() + preds, labels = \ + run_node_mini_batch_predict(decoder, + emb, + loader, + device, + return_proba, + return_label) + model.train() + return preds, labels + +def run_node_mini_batch_predict(decoder, emb, loader, device, + return_proba=True, return_label=False): + """ Perform mini-batch prediction with the given decoder. + + Note: caller should call model.eval() before calling this function + and call model.train() after when doing training. + + Parameters + ---------- + decoder : GSNodeDecoder + The GraphStorm node decoder + emb : dict of Tensor + The GNN embeddings + loader : GSgnnNodeDataLoader + The GraphStorm dataloader + device: th.device + Device used to compute prediction result + return_proba : bool + Whether or not to return all the predictions or the maximum prediction + return_label : bool + Whether or not to return labels. + + Returns + ------- + dict of Tensor : + Prediction results. + dict of Tensor : + Labels if return_labels is True + """ data = loader.data if return_label: @@ -321,7 +362,6 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= preds = {} labels = {} # TODO(zhengda) I need to check if the data loader only returns target nodes. - model.eval() with th.no_grad(): for _, seeds, _ in loader: # seeds are target nodes for ntype, seed_nodes in seeds.items(): @@ -345,7 +385,6 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= labels[ntype].append(lbl[ntype]) else: labels[ntype] = [lbl[ntype]] - model.train() for ntype, ntype_pred in preds.items(): preds[ntype] = th.cat(ntype_pred) From 28ec04cd4bb7f5bec72aaa923d602c43533adc80 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 17:07:02 -0700 Subject: [PATCH 5/8] Add unitests --- python/graphstorm/model/edge_gnn.py | 1 + python/graphstorm/model/node_gnn.py | 9 ++- tests/unit-tests/test_gnn.py | 95 ++++++++++++++++++++++++++--- 3 files changed, 92 insertions(+), 13 deletions(-) diff --git a/python/graphstorm/model/edge_gnn.py b/python/graphstorm/model/edge_gnn.py index 7525e2428f..a2e023ba81 100644 --- a/python/graphstorm/model/edge_gnn.py +++ b/python/graphstorm/model/edge_gnn.py @@ -313,6 +313,7 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label= device = model.device preds, labels = run_edge_mini_batch_predict(decoder, + emb, loader, device, return_proba, diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index 442582bf4f..c432f07f5d 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -365,11 +365,10 @@ def run_node_mini_batch_predict(decoder, emb, loader, device, with th.no_grad(): for _, seeds, _ in loader: # seeds are target nodes for ntype, seed_nodes in seeds.items(): - if isinstance(model.decoder, th.nn.ModuleDict): - assert ntype in model.decoder, f"Node type {ntype} not in decoder" - decoder = model.decoder[ntype] - else: - decoder = model.decoder + if isinstance(decoder, th.nn.ModuleDict): + assert ntype in decoder, f"Node type {ntype} not in decoder" + decoder = decoder[ntype] + if return_proba: pred = decoder.predict_proba(emb[ntype][seed_nodes].to(device)) else: diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index 399015a2ae..d2746fd264 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -60,9 +60,13 @@ from graphstorm import get_node_feat_size from graphstorm.gsf import get_rel_names_for_reconstruct from graphstorm.model import do_full_graph_inference, do_mini_batch_inference -from graphstorm.model.node_gnn import node_mini_batch_predict, node_mini_batch_gnn_predict +from graphstorm.model.node_gnn import (node_mini_batch_predict, + run_node_mini_batch_predict, + node_mini_batch_gnn_predict) from graphstorm.model.node_gnn import GSgnnNodeModelInterface -from graphstorm.model.edge_gnn import edge_mini_batch_predict, edge_mini_batch_gnn_predict +from graphstorm.model.edge_gnn import (edge_mini_batch_predict, + run_edge_mini_batch_predict, + edge_mini_batch_gnn_predict) from graphstorm.model.gnn_with_reconstruct import construct_node_feat, get_input_embeds_combined from graphstorm.model.utils import load_model, save_model @@ -279,9 +283,13 @@ def require_cache_embed(self): pred2_gnn_pred, _, labels2_gnn_pred, = node_mini_batch_gnn_predict(model, dataloader2, return_label=True) # Call last layer mini-batch inference with the GNN dataloader pred2_pred, labels2_pred = node_mini_batch_predict(model, embs, dataloader2, return_label=True) + + pred2_d_pred, labels2_d_pred = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_label=True) + if isinstance(pred1,dict): assert len(pred1) == len(pred2_gnn_pred) and len(labels1) == len(labels2_gnn_pred) assert len(pred1) == len(pred2_pred) and len(labels1) == len(labels2_pred) + assert len(pred1) == len(pred2_d_pred) and len(labels1) == len(labels2_gnn_pred) for ntype in pred1: assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), pred2_gnn_pred[ntype][0:len(pred2_gnn_pred)].numpy(), decimal=5) @@ -289,6 +297,9 @@ def require_cache_embed(self): assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), pred2_pred[ntype][0:len(pred2_pred)].numpy(), decimal=5) assert_equal(labels1[ntype].numpy(), labels2_pred[ntype].numpy()) + assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), + pred2_d_pred[ntype][0:len(pred2_d_pred)].numpy()) + assert_equal(labels1[ntype].numpy(), labels2_d_pred[ntype].numpy()) else: assert_almost_equal(pred1[0:len(pred1)].numpy(), pred2_gnn_pred[0:len(pred2_gnn_pred)].numpy(), decimal=5) @@ -296,24 +307,42 @@ def require_cache_embed(self): assert_almost_equal(pred1[0:len(pred1)].numpy(), pred2_pred[0:len(pred2_pred)].numpy(), decimal=5) assert_equal(labels1.numpy(), labels2_pred.numpy()) + assert_almost_equal(pred1[0:len(pred1)].numpy(), + labels2_d_pred[0:len(labels2_d_pred)].numpy()) + assert_equal(labels1.numpy(), labels2_d_pred.numpy()) # Test the return_proba argument. pred3, labels3 = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) + pred3_d, labels3_d = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) + pred4, labels4 = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) + pred4_d, labels4_d = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) if isinstance(pred3, dict): assert len(pred3) == len(pred4) and len(labels3) == len(labels4) + assert len(pred3) == len(pred3_d) and len(labels3) == len(labels3_d) + assert len(pred4) == len(pred4_d) and len(labels4) == len(labels4_d) for key in pred3: assert pred3[key].dim() == 2 # returns all predictions (2D tensor) when return_proba is true assert(th.is_floating_point(pred3[key])) + assert pred3_d[key].dim() == 2 + assert(th.is_floating_point(pred3_d[key])) assert(pred4[key].dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False assert(is_int(pred4[key])) assert(th.equal(pred3[key].argmax(dim=1), pred4[key])) + assert(pred4_d[key].dim() == 1) + assert(is_int(pred4_d[key])) + assert(th.equal(pred3[key].argmax(dim=1), pred4_d[key])) else: assert pred3.dim() == 2 # returns all predictions (2D tensor) when return_proba is true assert(th.is_floating_point(pred3)) + assert pred3_d.dim() == 2 + assert(th.is_floating_point(pred3_d)) assert(pred4.dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False assert(is_int(pred4)) assert(th.equal(pred3.argmax(dim=1), pred4)) + assert(labels4_d.dim() == 1) + assert(is_int(labels4_d)) + assert(th.equal(pred3.argmax(dim=1), labels4_d)) def check_node_prediction_with_reconstruct(model, data, construct_feat_ntype, train_ntypes, node_feat_field=None): """ Check whether full graph inference and mini batch inference generate the same @@ -416,32 +445,51 @@ def check_mlp_node_prediction(model, data): batch_size=10, label_field='label', node_feats='feat', train_task=False) pred2, _, labels2 = node_mini_batch_gnn_predict(model, dataloader2, return_label=True) + pred1_d, labels1_d = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_label=True) if isinstance(pred1, dict): assert len(pred1) == len(pred2) and len(labels1) == len(labels2) + assert len(pred1) == len(pred1_d) and len(labels1) == len(labels1_d) for ntype in pred1: assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), pred2[ntype][0:len(pred2)].numpy(), decimal=5) assert_equal(labels1[ntype].numpy(), labels2[ntype].numpy()) + assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), pred1_d[ntype][0:len(pred1_d)].numpy()) + assert_equal(labels1[ntype].numpy(), labels1_d[ntype].numpy()) else: assert_almost_equal(pred1[0:len(pred1)].numpy(), pred2[0:len(pred2)].numpy(), decimal=5) assert_equal(labels1.numpy(), labels2.numpy()) + assert_almost_equal(pred1[0:len(pred1)].numpy(), pred1_d[0:len(pred1_d)].numpy()) + assert_equal(labels1.numpy(), labels1_d.numpy()) # Test the return_proba argument. - pred3, labels3 = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) - pred4, labels4 = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) + pred3, _ = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) + pred4, _ = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) + pred3_d, _ = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) + pred4_d, _ = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) if isinstance(pred3, dict): assert len(pred3) == len(pred4) + assert len(pred3) == len(pred3_d) + assert len(pred4) == len(pred4_d) for ntype in pred3: assert pred3[ntype].dim() == 2 # returns all predictions (2D tensor) when return_proba is true assert(th.is_floating_point(pred3[ntype])) + assert pred3_d[ntype].dim() == 2 + assert(th.is_floating_point(pred3_d[ntype])) assert(pred4[ntype].dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False assert(is_int(pred4[ntype])) assert(th.equal(pred3[ntype].argmax(dim=1), pred4[ntype])) + assert(is_int(pred4_d[ntype])) + assert(th.equal(pred3[ntype].argmax(dim=1), pred4_d[ntype])) else: assert pred3.dim() == 2 # returns all predictions (2D tensor) when return_proba is true assert(th.is_floating_point(pred3)) + assert pred3_d.dim() == 2 + assert(th.is_floating_point(pred3_d)) assert(pred4.dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False assert(is_int(pred4)) assert(th.equal(pred3.argmax(dim=1), pred4)) + assert(pred4_d.dim() == 1) + assert(is_int(pred4_d)) + assert(th.equal(pred3.argmax(dim=1), pred4_d)) @pytest.mark.parametrize("norm", [None, 'batch', 'layer']) def test_rgcn_node_prediction(norm): @@ -752,15 +800,31 @@ def check_edge_prediction(model, data): pred2[("n0", "r1", "n1")][0:len(pred2[("n0", "r1", "n1")])].numpy(), decimal=5) assert_equal(labels1[("n0", "r1", "n1")].numpy(), labels2[("n0", "r1", "n1")].numpy()) + pred1_d, labels1_d = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_label=True) + assert_almost_equal(pred1[("n0", "r1", "n1")][0:len(pred1[("n0", "r1", "n1")])].numpy(), + pred1_d[("n0", "r1", "n1")][0:len(pred1_d[("n0", "r1", "n1")])].numpy()) + assert_equal(labels1[("n0", "r1", "n1")].numpy(), labels1_d[("n0", "r1", "n1")].numpy()) + + # Test the return_proba argument. - pred3, labels3 = edge_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) + pred3, _ = edge_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) assert(th.is_floating_point(pred3[("n0", "r1", "n1")])) assert pred3[("n0", "r1", "n1")].dim() == 2 # returns all predictions (2D tensor) when return_proba is true - pred4, labels4 = edge_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) + + pred3_d, _ = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) + assert(th.is_floating_point(pred3_d[("n0", "r1", "n1")])) + assert pred3_d[("n0", "r1", "n1")].dim() == 2 # returns all predictions (2D tensor) when return_proba is true + + pred4, _ = edge_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) assert(pred4[("n0", "r1", "n1")].dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False assert(is_int(pred4[("n0", "r1", "n1")])) assert(th.equal(pred3[("n0", "r1", "n1")].argmax(dim=1), pred4[("n0", "r1", "n1")])) + pred4_d, _ = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) + assert(pred4[("n0", "r1", "n1")].dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False + assert(is_int(pred4_d[("n0", "r1", "n1")])) + assert(th.equal(pred3[("n0", "r1", "n1")].argmax(dim=1), pred4_d[("n0", "r1", "n1")])) + def check_mlp_edge_prediction(model, data): """ Check whether full graph inference and mini batch inference generate the same prediction result for GSgnnEdgeModel without GNN layers. @@ -793,15 +857,30 @@ def check_mlp_edge_prediction(model, data): pred2[("n0", "r1", "n1")][0:len(pred2[("n0", "r1", "n1")])].numpy(), decimal=5) assert_equal(labels1[("n0", "r1", "n1")].numpy(), labels2[("n0", "r1", "n1")].numpy()) + pred1_d, labels1_d = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_label=True) + assert_almost_equal(pred1[("n0", "r1", "n1")][0:len(pred1[("n0", "r1", "n1")])].numpy(), + pred1_d[("n0", "r1", "n1")][0:len(pred1_d[("n0", "r1", "n1")])].numpy()) + assert_equal(labels1[("n0", "r1", "n1")].numpy(), labels1_d[("n0", "r1", "n1")].numpy()) + # Test the return_proba argument. - pred3, labels3 = edge_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) + pred3, _ = edge_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) assert pred3[("n0", "r1", "n1")].dim() == 2 # returns all predictions (2D tensor) when return_proba is true assert(th.is_floating_point(pred3[("n0", "r1", "n1")])) - pred4, labels4 = edge_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) + + pred3_d, _ = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) + assert(th.is_floating_point(pred3_d[("n0", "r1", "n1")])) + assert pred3_d[("n0", "r1", "n1")].dim() == 2 # returns all predictions (2D tensor) when return_proba is true + + pred4, _ = edge_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) assert(pred4[("n0", "r1", "n1")].dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False assert(is_int(pred4[("n0", "r1", "n1")])) assert(th.equal(pred3[("n0", "r1", "n1")].argmax(dim=1), pred4[("n0", "r1", "n1")])) + pred4_d, _ = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) + assert(pred4[("n0", "r1", "n1")].dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False + assert(is_int(pred4_d[("n0", "r1", "n1")])) + assert(th.equal(pred3[("n0", "r1", "n1")].argmax(dim=1), pred4_d[("n0", "r1", "n1")])) + @pytest.mark.parametrize("num_ffn_layers", [0, 2]) def test_rgcn_edge_prediction(num_ffn_layers): """ Test edge prediction logic correctness with a edge prediction model From 2675b6dfa190f6a330f9888217ba962518c59cc4 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 17:12:37 -0700 Subject: [PATCH 6/8] Update docstr --- python/graphstorm/model/node_gnn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index c432f07f5d..526be5b4b6 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -332,8 +332,9 @@ def run_node_mini_batch_predict(decoder, emb, loader, device, Parameters ---------- - decoder : GSNodeDecoder - The GraphStorm node decoder + decoder : GSNodeDecoder or th.nn.ModuleDict + The GraphStorm node decoder. + It can be a GSNodeDecoder or a dict of GSNodeDecoders emb : dict of Tensor The GNN embeddings loader : GSgnnNodeDataLoader From 44c5357940bdbe3cb2523e43ce32c26083b37d25 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 17:36:37 -0700 Subject: [PATCH 7/8] update --- tests/unit-tests/test_gnn.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index d2746fd264..4ac8a479bc 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -313,10 +313,10 @@ def require_cache_embed(self): # Test the return_proba argument. pred3, labels3 = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) - pred3_d, labels3_d = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) + pred3_d, labels3_d = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) pred4, labels4 = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) - pred4_d, labels4_d = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) + pred4_d, labels4_d = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) if isinstance(pred3, dict): assert len(pred3) == len(pred4) and len(labels3) == len(labels4) assert len(pred3) == len(pred3_d) and len(labels3) == len(labels3_d) @@ -445,7 +445,7 @@ def check_mlp_node_prediction(model, data): batch_size=10, label_field='label', node_feats='feat', train_task=False) pred2, _, labels2 = node_mini_batch_gnn_predict(model, dataloader2, return_label=True) - pred1_d, labels1_d = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_label=True) + pred1_d, labels1_d = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_label=True) if isinstance(pred1, dict): assert len(pred1) == len(pred2) and len(labels1) == len(labels2) assert len(pred1) == len(pred1_d) and len(labels1) == len(labels1_d) @@ -463,8 +463,8 @@ def check_mlp_node_prediction(model, data): # Test the return_proba argument. pred3, _ = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) pred4, _ = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) - pred3_d, _ = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) - pred4_d, _ = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) + pred3_d, _ = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) + pred4_d, _ = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) if isinstance(pred3, dict): assert len(pred3) == len(pred4) assert len(pred3) == len(pred3_d) From 958e2b9638a297b6259a86913afaa48fdbb63e3e Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Tue, 21 May 2024 11:25:49 -0700 Subject: [PATCH 8/8] Update --- python/graphstorm/model/edge_gnn.py | 6 +++--- python/graphstorm/model/lp_gnn.py | 5 ++++- python/graphstorm/model/node_gnn.py | 4 ++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/graphstorm/model/edge_gnn.py b/python/graphstorm/model/edge_gnn.py index a2e023ba81..75eea6791c 100644 --- a/python/graphstorm/model/edge_gnn.py +++ b/python/graphstorm/model/edge_gnn.py @@ -323,13 +323,13 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label= def run_edge_mini_batch_predict(decoder, emb, loader, device, return_proba=True, return_label=False): - """ Perform mini-batch prediction with the given decoder. + """ Perform mini-batch edge prediction with the given decoder. - This function usually follows full-grain GNN embedding inference. After having + This function usually follows full-graph GNN embedding inference. After having the GNN embeddings, we need to perform mini-batch computation to make predictions on the GNN embeddings. - Note: caller should call model.eval() before calling this function + Note: callers should call model.eval() before calling this function and call model.train() after when doing training. Parameters diff --git a/python/graphstorm/model/lp_gnn.py b/python/graphstorm/model/lp_gnn.py index 1e08443755..2ba2f9322d 100644 --- a/python/graphstorm/model/lp_gnn.py +++ b/python/graphstorm/model/lp_gnn.py @@ -137,6 +137,9 @@ def lp_mini_batch_predict(model, emb, loader, device): After having the GNN embeddings, we need to perform mini-batch computation to make predictions on the GNN embeddings. + Note: callers should call model.eval() before calling this function + and call model.train() after when doing training. + Parameters ---------- model : GSgnnModel @@ -166,7 +169,7 @@ def run_lp_mini_batch_predict(decoder, emb, loader, device): After having the GNN embeddings, we need to perform mini-batch computation to make predictions on the GNN embeddings. - Note: caller should call model.eval() before calling this function + Note: callers should call model.eval() before calling this function and call model.train() after when doing training. Parameters diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index 526be5b4b6..f1c79cd968 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -325,9 +325,9 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= def run_node_mini_batch_predict(decoder, emb, loader, device, return_proba=True, return_label=False): - """ Perform mini-batch prediction with the given decoder. + """ Perform mini-batch node prediction with the given decoder. - Note: caller should call model.eval() before calling this function + Note: callers should call model.eval() before calling this function and call model.train() after when doing training. Parameters