diff --git a/python/graphstorm/model/__init__.py b/python/graphstorm/model/__init__.py
index 08a8391f56..18a741e200 100644
--- a/python/graphstorm/model/__init__.py
+++ b/python/graphstorm/model/__init__.py
@@ -24,12 +24,17 @@
 from .gnn import do_full_graph_inference
 from .gnn import do_mini_batch_inference
 from .node_gnn import GSgnnNodeModel, GSgnnNodeModelBase, GSgnnNodeModelInterface
-from .node_gnn import node_mini_batch_gnn_predict, node_mini_batch_predict
+from .node_gnn import (node_mini_batch_gnn_predict,
+                       node_mini_batch_predict,
+                       run_node_mini_batch_predict)
 from .edge_gnn import GSgnnEdgeModel, GSgnnEdgeModelBase, GSgnnEdgeModelInterface
-from .edge_gnn import edge_mini_batch_gnn_predict, edge_mini_batch_predict
+from .edge_gnn import (edge_mini_batch_gnn_predict,
+                       edge_mini_batch_predict,
+                       run_edge_mini_batch_predict)
 from .lp_gnn import (GSgnnLinkPredictionModel,
                      GSgnnLinkPredictionModelBase,
-                     GSgnnLinkPredictionModelInterface)
+                     GSgnnLinkPredictionModelInterface,
+                     run_lp_mini_batch_predict)
 from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer
 from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer
 from .sage_encoder import SAGEEncoder, SAGEConv
diff --git a/python/graphstorm/model/edge_gnn.py b/python/graphstorm/model/edge_gnn.py
index 536e61f311..75eea6791c 100644
--- a/python/graphstorm/model/edge_gnn.py
+++ b/python/graphstorm/model/edge_gnn.py
@@ -311,6 +311,48 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
     model.eval()
     decoder = model.decoder
     device = model.device
+
+    preds, labels = run_edge_mini_batch_predict(decoder,
+                                                emb,
+                                                loader,
+                                                device,
+                                                return_proba,
+                                                return_label)
+    model.train()
+    return preds, labels
+
+def run_edge_mini_batch_predict(decoder, emb, loader, device,
+                                return_proba=True, return_label=False):
+    """ Perform mini-batch edge prediction with the given decoder.
+
+    This function usually follows full-graph GNN embedding inference. After
+    computing the GNN embeddings, we perform mini-batch computation to make
+    predictions on the GNN embeddings.
+
+    Note: Callers should call model.eval() before calling this function and
+    model.train() afterwards when training.
+
+    Parameters
+    ----------
+    decoder : GSEdgeDecoder
+        The GraphStorm edge decoder
+    emb : dict of Tensor
+        The GNN embeddings
+    loader : GSgnnEdgeDataLoader
+        The GraphStorm dataloader
+    device : th.device
+        Device used to compute prediction results
+    return_proba : bool
+        Whether to return all the predictions or the maximum prediction
+    return_label : bool
+        Whether or not to return labels
+
+    Returns
+    -------
+    dict of Tensor : GNN prediction results. Return all the results when
+        return_proba is True, otherwise return the maximum result.
+    dict of Tensor : Labels if return_label is True
+    """
     data = loader.data
     g = data.g
     preds = {}
@@ -379,7 +421,6 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
             append_to_dict(lbl, labels)
     barrier()
 
-    model.train()
    for target_etype, pred in preds.items():
         preds[target_etype] = th.cat(pred)
     if return_label:
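# ---------------------------------------------------------------------------
# Usage sketch for the new decoder-level edge API (illustration only, not
# part of the patch). `model`, `data`, and `test_dataloader` are hypothetical
# stand-ins for a trained GSgnnEdgeModel, its dataset, and a
# GSgnnEdgeDataLoader over the target edges; the GraphStorm calls themselves
# are the ones introduced or already exported by this change set.
model.eval()                                 # callers own eval()/train(), per the docstring
embs = do_full_graph_inference(model, data)  # full-graph GNN embeddings
preds, labels = run_edge_mini_batch_predict(model.decoder, embs, test_dataloader,
                                            model.device, return_proba=False,
                                            return_label=True)
model.train()
# ---------------------------------------------------------------------------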
diff --git a/python/graphstorm/model/lp_gnn.py b/python/graphstorm/model/lp_gnn.py
index 91c2c3317c..2ba2f9322d 100644
--- a/python/graphstorm/model/lp_gnn.py
+++ b/python/graphstorm/model/lp_gnn.py
@@ -133,10 +133,13 @@ def forward(self, blocks, pos_graph,
 def lp_mini_batch_predict(model, emb, loader, device):
     """ Perform mini-batch prediction.
 
-    This function follows full-grain GNN embedding inference.
+    This function follows full-graph GNN embedding inference.
     After having the GNN embeddings, we need to perform mini-batch
     computation to make predictions on the GNN embeddings.
 
+    Note: Callers should call model.eval() before calling this function and
+    model.train() afterwards when training.
+
     Parameters
     ----------
     model : GSgnnModel
@@ -154,6 +157,37 @@ def lp_mini_batch_predict(model, emb, loader, device):
         Rankings of positive scores in format of {etype: ranking}
     """
     decoder = model.decoder
+    return run_lp_mini_batch_predict(decoder,
+                                     emb,
+                                     loader,
+                                     device)
+
+def run_lp_mini_batch_predict(decoder, emb, loader, device):
+    """ Perform mini-batch link prediction with the given decoder.
+
+    This function follows full-graph GNN embedding inference.
+    After having the GNN embeddings, we need to perform mini-batch
+    computation to make predictions on the GNN embeddings.
+
+    Note: Callers should call model.eval() before calling this function and
+    model.train() afterwards when training.
+
+    Parameters
+    ----------
+    decoder : LinkPredictNoParamDecoder or LinkPredictLearnableDecoder
+        The GraphStorm link prediction decoder model
+    emb : dict of Tensor
+        The GNN embeddings
+    loader : GSgnnLinkPredictionDataLoader
+        The GraphStorm dataloader
+    device : th.device
+        Device used to compute test scores
+
+    Returns
+    -------
+    rankings: dict of tensors
+        Rankings of positive scores in format of {etype: ranking}
+    """
     with th.no_grad():
         ranking = {}
         for pos_neg_tuple, neg_sample_type in loader:
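# ---------------------------------------------------------------------------
# Usage sketch for the new decoder-level link-prediction API (illustration
# only). `model`, `data`, and `test_dataloader` are hypothetical stand-ins
# for a trained GSgnnLinkPredictionModel, its dataset, and a test-time
# link-prediction dataloader.
model.eval()
embs = do_full_graph_inference(model, data)
rankings = run_lp_mini_batch_predict(model.decoder, embs, test_dataloader, model.device)
model.train()
# `rankings` maps each evaluated edge type to the ranks of its positive
# edges, which downstream evaluators can turn into metrics such as MRR.
# ---------------------------------------------------------------------------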
diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py
index fca05ada24..f1c79cd968 100644
--- a/python/graphstorm/model/node_gnn.py
+++ b/python/graphstorm/model/node_gnn.py
@@ -311,6 +311,48 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
         Labels if return_labels is True
     """
     device = model.device
+    decoder = model.decoder
+    model.eval()
+    preds, labels = \
+        run_node_mini_batch_predict(decoder,
+                                    emb,
+                                    loader,
+                                    device,
+                                    return_proba,
+                                    return_label)
+    model.train()
+    return preds, labels
+
+def run_node_mini_batch_predict(decoder, emb, loader, device,
+                                return_proba=True, return_label=False):
+    """ Perform mini-batch node prediction with the given decoder.
+
+    Note: Callers should call model.eval() before calling this function and
+    model.train() afterwards when training.
+
+    Parameters
+    ----------
+    decoder : GSNodeDecoder or th.nn.ModuleDict
+        The GraphStorm node decoder.
+        It can be a GSNodeDecoder or a dict of GSNodeDecoders
+    emb : dict of Tensor
+        The GNN embeddings
+    loader : GSgnnNodeDataLoader
+        The GraphStorm dataloader
+    device : th.device
+        Device used to compute prediction results
+    return_proba : bool
+        Whether to return all the predictions or the maximum prediction
+    return_label : bool
+        Whether or not to return labels
+
+    Returns
+    -------
+    dict of Tensor :
+        Prediction results.
+    dict of Tensor :
+        Labels if return_label is True
+    """
     data = loader.data
 
     if return_label:
@@ -321,15 +363,13 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
     preds = {}
     labels = {}
     # TODO(zhengda) I need to check if the data loader only returns target nodes.
-    model.eval()
     with th.no_grad():
         for _, seeds, _ in loader:  # seeds are target nodes
             for ntype, seed_nodes in seeds.items():
-                if isinstance(model.decoder, th.nn.ModuleDict):
-                    assert ntype in model.decoder, f"Node type {ntype} not in decoder"
-                    decoder = model.decoder[ntype]
-                else:
-                    decoder = model.decoder
+                # Look up the per-node-type decoder without rebinding `decoder`,
+                # so later node types still see the full ModuleDict.
+                if isinstance(decoder, th.nn.ModuleDict):
+                    assert ntype in decoder, f"Node type {ntype} not in decoder"
+                    nt_decoder = decoder[ntype]
+                else:
+                    nt_decoder = decoder
                 if return_proba:
-                    pred = decoder.predict_proba(emb[ntype][seed_nodes].to(device))
+                    pred = nt_decoder.predict_proba(emb[ntype][seed_nodes].to(device))
                 else:
-                    pred = decoder.predict(emb[ntype][seed_nodes].to(device))
+                    pred = nt_decoder.predict(emb[ntype][seed_nodes].to(device))
@@ -345,7 +385,6 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
                     labels[ntype].append(lbl[ntype])
                 else:
                     labels[ntype] = [lbl[ntype]]
-    model.train()
 
     for ntype, ntype_pred in preds.items():
         preds[ntype] = th.cat(ntype_pred)
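# ---------------------------------------------------------------------------
# Usage sketch for the new decoder-level node API (illustration only).
# `model`, `data`, and `test_dataloader` are hypothetical stand-ins for a
# trained GSgnnNodeModel, its dataset, and a GSgnnNodeDataLoader. Note the
# decoder may be a single GSNodeDecoder or a th.nn.ModuleDict keyed by node
# type; run_node_mini_batch_predict handles both.
model.eval()
embs = do_full_graph_inference(model, data)
preds, labels = run_node_mini_batch_predict(model.decoder, embs, test_dataloader,
                                            model.device, return_proba=True,
                                            return_label=True)
model.train()
# ---------------------------------------------------------------------------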
diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py
index 399015a2ae..4ac8a479bc 100644
--- a/tests/unit-tests/test_gnn.py
+++ b/tests/unit-tests/test_gnn.py
@@ -60,9 +60,13 @@
 from graphstorm import get_node_feat_size
 from graphstorm.gsf import get_rel_names_for_reconstruct
 from graphstorm.model import do_full_graph_inference, do_mini_batch_inference
-from graphstorm.model.node_gnn import node_mini_batch_predict, node_mini_batch_gnn_predict
+from graphstorm.model.node_gnn import (node_mini_batch_predict,
+                                       run_node_mini_batch_predict,
+                                       node_mini_batch_gnn_predict)
 from graphstorm.model.node_gnn import GSgnnNodeModelInterface
-from graphstorm.model.edge_gnn import edge_mini_batch_predict, edge_mini_batch_gnn_predict
+from graphstorm.model.edge_gnn import (edge_mini_batch_predict,
+                                       run_edge_mini_batch_predict,
+                                       edge_mini_batch_gnn_predict)
 from graphstorm.model.gnn_with_reconstruct import construct_node_feat, get_input_embeds_combined
 from graphstorm.model.utils import load_model, save_model
@@ -279,9 +283,13 @@ def require_cache_embed(self):
     pred2_gnn_pred, _, labels2_gnn_pred, = node_mini_batch_gnn_predict(model, dataloader2,
                                                                        return_label=True)
     # Call last layer mini-batch inference with the GNN dataloader
     pred2_pred, labels2_pred = node_mini_batch_predict(model, embs, dataloader2, return_label=True)
+
+    pred2_d_pred, labels2_d_pred = run_node_mini_batch_predict(model.decoder, embs, dataloader1,
+                                                               model.device, return_label=True)
+
     if isinstance(pred1, dict):
         assert len(pred1) == len(pred2_gnn_pred) and len(labels1) == len(labels2_gnn_pred)
         assert len(pred1) == len(pred2_pred) and len(labels1) == len(labels2_pred)
+        assert len(pred1) == len(pred2_d_pred) and len(labels1) == len(labels2_d_pred)
         for ntype in pred1:
             assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(),
                                 pred2_gnn_pred[ntype][0:len(pred2_gnn_pred)].numpy(), decimal=5)
@@ -289,6 +297,9 @@ def require_cache_embed(self):
             assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(),
                                 pred2_pred[ntype][0:len(pred2_pred)].numpy(), decimal=5)
             assert_equal(labels1[ntype].numpy(), labels2_pred[ntype].numpy())
+            assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(),
+                                pred2_d_pred[ntype][0:len(pred2_d_pred)].numpy())
+            assert_equal(labels1[ntype].numpy(), labels2_d_pred[ntype].numpy())
     else:
         assert_almost_equal(pred1[0:len(pred1)].numpy(),
                             pred2_gnn_pred[0:len(pred2_gnn_pred)].numpy(), decimal=5)
@@ -296,24 +307,42 @@ def require_cache_embed(self):
         assert_almost_equal(pred1[0:len(pred1)].numpy(),
                             pred2_pred[0:len(pred2_pred)].numpy(), decimal=5)
         assert_equal(labels1.numpy(), labels2_pred.numpy())
+        assert_almost_equal(pred1[0:len(pred1)].numpy(),
+                            pred2_d_pred[0:len(pred2_d_pred)].numpy())
+        assert_equal(labels1.numpy(), labels2_d_pred.numpy())
 
     # Test the return_proba argument.
     pred3, labels3 = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True)
+    pred3_d, labels3_d = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device,
+                                                     return_proba=True, return_label=True)
+
     pred4, labels4 = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True)
+    pred4_d, labels4_d = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device,
+                                                     return_proba=False, return_label=True)
     if isinstance(pred3, dict):
         assert len(pred3) == len(pred4) and len(labels3) == len(labels4)
+        assert len(pred3) == len(pred3_d) and len(labels3) == len(labels3_d)
+        assert len(pred4) == len(pred4_d) and len(labels4) == len(labels4_d)
         for key in pred3:
             assert pred3[key].dim() == 2  # returns all predictions (2D tensor) when return_proba is true
             assert(th.is_floating_point(pred3[key]))
+            assert pred3_d[key].dim() == 2
+            assert(th.is_floating_point(pred3_d[key]))
             assert(pred4[key].dim() == 1)  # returns maximum prediction (1D tensor) when return_proba is False
             assert(is_int(pred4[key]))
             assert(th.equal(pred3[key].argmax(dim=1), pred4[key]))
+            assert(pred4_d[key].dim() == 1)
+            assert(is_int(pred4_d[key]))
+            assert(th.equal(pred3[key].argmax(dim=1), pred4_d[key]))
     else:
         assert pred3.dim() == 2  # returns all predictions (2D tensor) when return_proba is true
         assert(th.is_floating_point(pred3))
+        assert pred3_d.dim() == 2
+        assert(th.is_floating_point(pred3_d))
         assert(pred4.dim() == 1)  # returns maximum prediction (1D tensor) when return_proba is False
         assert(is_int(pred4))
         assert(th.equal(pred3.argmax(dim=1), pred4))
+        assert(pred4_d.dim() == 1)
+        assert(is_int(pred4_d))
+        assert(th.equal(pred3.argmax(dim=1), pred4_d))
 
 def check_node_prediction_with_reconstruct(model, data, construct_feat_ntype, train_ntypes, node_feat_field=None):
     """ Check whether full graph inference and mini batch inference generate the same
@@ -416,32 +445,51 @@ def check_mlp_node_prediction(model, data):
                                      batch_size=10, label_field='label',
                                      node_feats='feat', train_task=False)
     pred2, _, labels2 = node_mini_batch_gnn_predict(model, dataloader2, return_label=True)
+    pred1_d, labels1_d = run_node_mini_batch_predict(model.decoder, embs, dataloader1,
+                                                     model.device, return_label=True)
     if isinstance(pred1, dict):
         assert len(pred1) == len(pred2) and len(labels1) == len(labels2)
+        assert len(pred1) == len(pred1_d) and len(labels1) == len(labels1_d)
         for ntype in pred1:
             assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(),
                                 pred2[ntype][0:len(pred2)].numpy(), decimal=5)
             assert_equal(labels1[ntype].numpy(), labels2[ntype].numpy())
+            assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(),
+                                pred1_d[ntype][0:len(pred1_d)].numpy())
+            assert_equal(labels1[ntype].numpy(), labels1_d[ntype].numpy())
     else:
         assert_almost_equal(pred1[0:len(pred1)].numpy(), pred2[0:len(pred2)].numpy(), decimal=5)
         assert_equal(labels1.numpy(), labels2.numpy())
+        assert_almost_equal(pred1[0:len(pred1)].numpy(), pred1_d[0:len(pred1_d)].numpy())
+        assert_equal(labels1.numpy(), labels1_d.numpy())
 
     # Test the return_proba argument.
-    pred3, labels3 = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True)
-    pred4, labels4 = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True)
+    pred3, _ = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True)
+    pred4, _ = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True)
+    pred3_d, _ = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device,
+                                             return_proba=True, return_label=True)
+    pred4_d, _ = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device,
+                                             return_proba=False, return_label=True)
     if isinstance(pred3, dict):
         assert len(pred3) == len(pred4)
+        assert len(pred3) == len(pred3_d)
+        assert len(pred4) == len(pred4_d)
         for ntype in pred3:
             assert pred3[ntype].dim() == 2  # returns all predictions (2D tensor) when return_proba is true
             assert(th.is_floating_point(pred3[ntype]))
+            assert pred3_d[ntype].dim() == 2
+            assert(th.is_floating_point(pred3_d[ntype]))
             assert(pred4[ntype].dim() == 1)  # returns maximum prediction (1D tensor) when return_proba is False
             assert(is_int(pred4[ntype]))
             assert(th.equal(pred3[ntype].argmax(dim=1), pred4[ntype]))
+            assert(pred4_d[ntype].dim() == 1)
+            assert(is_int(pred4_d[ntype]))
+            assert(th.equal(pred3[ntype].argmax(dim=1), pred4_d[ntype]))
     else:
         assert pred3.dim() == 2  # returns all predictions (2D tensor) when return_proba is true
         assert(th.is_floating_point(pred3))
+        assert pred3_d.dim() == 2
+        assert(th.is_floating_point(pred3_d))
         assert(pred4.dim() == 1)  # returns maximum prediction (1D tensor) when return_proba is False
         assert(is_int(pred4))
         assert(th.equal(pred3.argmax(dim=1), pred4))
+        assert(pred4_d.dim() == 1)
+        assert(is_int(pred4_d))
+        assert(th.equal(pred3.argmax(dim=1), pred4_d))
 
 @pytest.mark.parametrize("norm", [None, 'batch', 'layer'])
 def test_rgcn_node_prediction(norm):
@@ -752,15 +800,31 @@ def check_edge_prediction(model, data):
                         pred2[("n0", "r1", "n1")][0:len(pred2[("n0", "r1", "n1")])].numpy(),
                         decimal=5)
     assert_equal(labels1[("n0", "r1", "n1")].numpy(), labels2[("n0", "r1", "n1")].numpy())
+    pred1_d, labels1_d = run_edge_mini_batch_predict(model.decoder, embs, dataloader1,
+                                                     model.device, return_label=True)
+    assert_almost_equal(pred1[("n0", "r1", "n1")][0:len(pred1[("n0", "r1", "n1")])].numpy(),
+                        pred1_d[("n0", "r1", "n1")][0:len(pred1_d[("n0", "r1", "n1")])].numpy())
+    assert_equal(labels1[("n0", "r1", "n1")].numpy(), labels1_d[("n0", "r1", "n1")].numpy())
+
     # Test the return_proba argument.
-    pred3, labels3 = edge_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True)
+    pred3, _ = edge_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True)
     assert(th.is_floating_point(pred3[("n0", "r1", "n1")]))
     assert pred3[("n0", "r1", "n1")].dim() == 2  # returns all predictions (2D tensor) when return_proba is true
-    pred4, labels4 = edge_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True)
+
+    pred3_d, _ = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device,
+                                             return_proba=True, return_label=True)
+    assert(th.is_floating_point(pred3_d[("n0", "r1", "n1")]))
+    assert pred3_d[("n0", "r1", "n1")].dim() == 2  # returns all predictions (2D tensor) when return_proba is true
+
+    pred4, _ = edge_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True)
     assert(pred4[("n0", "r1", "n1")].dim() == 1)  # returns maximum prediction (1D tensor) when return_proba is False
     assert(is_int(pred4[("n0", "r1", "n1")]))
     assert(th.equal(pred3[("n0", "r1", "n1")].argmax(dim=1), pred4[("n0", "r1", "n1")]))
+    pred4_d, _ = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device,
+                                             return_proba=False, return_label=True)
+    assert(pred4_d[("n0", "r1", "n1")].dim() == 1)  # returns maximum prediction (1D tensor) when return_proba is False
+    assert(is_int(pred4_d[("n0", "r1", "n1")]))
+    assert(th.equal(pred3[("n0", "r1", "n1")].argmax(dim=1), pred4_d[("n0", "r1", "n1")]))
+
 def check_mlp_edge_prediction(model, data):
     """ Check whether full graph inference and mini batch inference generate the same
     prediction result for GSgnnEdgeModel without GNN layers.
@@ -793,15 +857,30 @@ def check_mlp_edge_prediction(model, data):
                         pred2[("n0", "r1", "n1")][0:len(pred2[("n0", "r1", "n1")])].numpy(),
                         decimal=5)
     assert_equal(labels1[("n0", "r1", "n1")].numpy(), labels2[("n0", "r1", "n1")].numpy())
+    pred1_d, labels1_d = run_edge_mini_batch_predict(model.decoder, embs, dataloader1,
+                                                     model.device, return_label=True)
+    assert_almost_equal(pred1[("n0", "r1", "n1")][0:len(pred1[("n0", "r1", "n1")])].numpy(),
+                        pred1_d[("n0", "r1", "n1")][0:len(pred1_d[("n0", "r1", "n1")])].numpy())
+    assert_equal(labels1[("n0", "r1", "n1")].numpy(), labels1_d[("n0", "r1", "n1")].numpy())
+
     # Test the return_proba argument.
-    pred3, labels3 = edge_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True)
+    pred3, _ = edge_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True)
     assert pred3[("n0", "r1", "n1")].dim() == 2  # returns all predictions (2D tensor) when return_proba is true
     assert(th.is_floating_point(pred3[("n0", "r1", "n1")]))
-    pred4, labels4 = edge_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True)
+
+    pred3_d, _ = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device,
+                                             return_proba=True, return_label=True)
+    assert(th.is_floating_point(pred3_d[("n0", "r1", "n1")]))
+    assert pred3_d[("n0", "r1", "n1")].dim() == 2  # returns all predictions (2D tensor) when return_proba is true
+
+    pred4, _ = edge_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True)
     assert(pred4[("n0", "r1", "n1")].dim() == 1)  # returns maximum prediction (1D tensor) when return_proba is False
     assert(is_int(pred4[("n0", "r1", "n1")]))
     assert(th.equal(pred3[("n0", "r1", "n1")].argmax(dim=1), pred4[("n0", "r1", "n1")]))
+    pred4_d, _ = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device,
+                                             return_proba=False, return_label=True)
+    assert(pred4_d[("n0", "r1", "n1")].dim() == 1)  # returns maximum prediction (1D tensor) when return_proba is False
+    assert(is_int(pred4_d[("n0", "r1", "n1")]))
+    assert(th.equal(pred3[("n0", "r1", "n1")].argmax(dim=1), pred4_d[("n0", "r1", "n1")]))
+
 @pytest.mark.parametrize("num_ffn_layers", [0, 2])
 def test_rgcn_edge_prediction(num_ffn_layers):
     """ Test edge prediction logic correctness with a edge prediction model