diff --git a/python/graphstorm/model/edge_gnn.py b/python/graphstorm/model/edge_gnn.py index 0a4b6a38e5..7525e2428f 100644 --- a/python/graphstorm/model/edge_gnn.py +++ b/python/graphstorm/model/edge_gnn.py @@ -322,12 +322,15 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label= def run_edge_mini_batch_predict(decoder, emb, loader, device, return_proba=True, return_label=False): - """ Perform mini-batch prediction using edge decoder + """ Perform mini-batch prediction with the given decoder. This function usually follows full-grain GNN embedding inference. After having the GNN embeddings, we need to perform mini-batch computation to make predictions on the GNN embeddings. + Note: caller should call model.eval() before calling this function + and call model.train() after when doing training. + Parameters ---------- decoder : GSEdgeDecoder diff --git a/python/graphstorm/model/lp_gnn.py b/python/graphstorm/model/lp_gnn.py index ebf8449a43..1e08443755 100644 --- a/python/graphstorm/model/lp_gnn.py +++ b/python/graphstorm/model/lp_gnn.py @@ -133,7 +133,7 @@ def forward(self, blocks, pos_graph, def lp_mini_batch_predict(model, emb, loader, device): """ Perform mini-batch prediction. - This function follows full-grain GNN embedding inference. + This function follows full-graph GNN embedding inference. After having the GNN embeddings, we need to perform mini-batch computation to make predictions on the GNN embeddings. @@ -160,7 +160,14 @@ def lp_mini_batch_predict(model, emb, loader, device): device) def run_lp_mini_batch_predict(decoder, emb, loader, device): - """ Perform mini-batch link prediction. + """ Perform mini-batch link prediction with the given decoder. + + This function follows full-graph GNN embedding inference. + After having the GNN embeddings, we need to perform mini-batch + computation to make predictions on the GNN embeddings. + + Note: caller should call model.eval() before calling this function + and call model.train() after when doing training. Parameters ---------- diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index fca05ada24..442582bf4f 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -311,6 +311,47 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= Labels if return_labels is True """ device = model.device + decoder = model.decoder + model.eval() + preds, labels = \ + run_node_mini_batch_predict(decoder, + emb, + loader, + device, + return_proba, + return_label) + model.train() + return preds, labels + +def run_node_mini_batch_predict(decoder, emb, loader, device, + return_proba=True, return_label=False): + """ Perform mini-batch prediction with the given decoder. + + Note: caller should call model.eval() before calling this function + and call model.train() after when doing training. + + Parameters + ---------- + decoder : GSNodeDecoder + The GraphStorm node decoder + emb : dict of Tensor + The GNN embeddings + loader : GSgnnNodeDataLoader + The GraphStorm dataloader + device: th.device + Device used to compute prediction result + return_proba : bool + Whether or not to return all the predictions or the maximum prediction + return_label : bool + Whether or not to return labels. + + Returns + ------- + dict of Tensor : + Prediction results. + dict of Tensor : + Labels if return_labels is True + """ data = loader.data if return_label: @@ -321,7 +362,6 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= preds = {} labels = {} # TODO(zhengda) I need to check if the data loader only returns target nodes. - model.eval() with th.no_grad(): for _, seeds, _ in loader: # seeds are target nodes for ntype, seed_nodes in seeds.items(): @@ -345,7 +385,6 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= labels[ntype].append(lbl[ntype]) else: labels[ntype] = [lbl[ntype]] - model.train() for ntype, ntype_pred in preds.items(): preds[ntype] = th.cat(ntype_pred)