Commit 46da6ca

update

Xiang Song committed May 20, 2024
1 parent d0b37b4 commit 46da6ca

Showing 3 changed files with 54 additions and 5 deletions.
5 changes: 4 additions & 1 deletion python/graphstorm/model/edge_gnn.py
@@ -322,12 +322,15 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label=

 def run_edge_mini_batch_predict(decoder, emb, loader, device,
                                 return_proba=True, return_label=False):
-    """ Perform mini-batch prediction using edge decoder
+    """ Perform mini-batch prediction with the given decoder.

     This function usually follows full-graph GNN embedding inference. After having
     the GNN embeddings, we need to perform mini-batch computation to make predictions
     on the GNN embeddings.
+
+    Note: the caller should call model.eval() before calling this function
+    and call model.train() after it when doing training.

     Parameters
     ----------
     decoder : GSEdgeDecoder
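The net effect in edge_gnn.py is that decoder-only prediction is callable without the full model object. A minimal usage sketch under that reading — `model`, `emb`, and `test_loader` are hypothetical stand-ins for an existing GraphStorm edge-task setup, not names from this commit:

from graphstorm.model.edge_gnn import run_edge_mini_batch_predict

# Assumed to already exist: `model` (a trained GraphStorm edge model exposing
# `.decoder` and `.device`), `emb` (precomputed full-graph GNN embeddings),
# and `test_loader` (an edge dataloader over the test edges).
model.eval()  # per the new docstring note, the caller toggles eval/train
preds, labels = run_edge_mini_batch_predict(model.decoder,
                                            emb,
                                            test_loader,
                                            model.device,
                                            return_proba=True,
                                            return_label=True)
model.train()
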
11 changes: 9 additions & 2 deletions python/graphstorm/model/lp_gnn.py
@@ -133,7 +133,7 @@ def forward(self, blocks, pos_graph,

 def lp_mini_batch_predict(model, emb, loader, device):
     """ Perform mini-batch prediction.

-    This function follows full-grain GNN embedding inference.
+    This function follows full-graph GNN embedding inference.
     After having the GNN embeddings, we need to perform mini-batch
     computation to make predictions on the GNN embeddings.

@@ -160,7 +160,14 @@ def lp_mini_batch_predict(model, emb, loader, device):
                                device)

 def run_lp_mini_batch_predict(decoder, emb, loader, device):
-    """ Perform mini-batch link prediction.
+    """ Perform mini-batch link prediction with the given decoder.
+
+    This function follows full-graph GNN embedding inference.
+    After having the GNN embeddings, we need to perform mini-batch
+    computation to make predictions on the GNN embeddings.
+
+    Note: the caller should call model.eval() before calling this function
+    and call model.train() after it when doing training.

     Parameters
     ----------
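lp_gnn.py gets the same split for link prediction: the model-based entry point keeps its signature while the decoder-based helper does the scoring. A sketch along the same lines, with `model`, `emb`, `lp_test_loader`, and `device` as hypothetical stand-ins:

from graphstorm.model.lp_gnn import run_lp_mini_batch_predict

# Assumed to already exist: `model` (a trained link-prediction model with a
# `.decoder` attribute), `emb` (precomputed full-graph GNN embeddings),
# `lp_test_loader` (a link-prediction test dataloader), and `device`.
model.eval()
rankings = run_lp_mini_batch_predict(model.decoder, emb,
                                     lp_test_loader, device)
model.train()
# `rankings` is expected to hold per-edge-type ranking results, as with
# lp_mini_batch_predict itself.
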
43 changes: 41 additions & 2 deletions python/graphstorm/model/node_gnn.py
@@ -311,6 +311,47 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
         Labels if return_label is True
     """
     device = model.device
+    decoder = model.decoder
+    model.eval()
+    preds, labels = \
+        run_node_mini_batch_predict(decoder,
+                                    emb,
+                                    loader,
+                                    device,
+                                    return_proba,
+                                    return_label)
+    model.train()
+    return preds, labels
+
+def run_node_mini_batch_predict(decoder, emb, loader, device,
+                                return_proba=True, return_label=False):
+    """ Perform mini-batch prediction with the given decoder.
+
+    Note: the caller should call model.eval() before calling this function
+    and call model.train() after it when doing training.
+
+    Parameters
+    ----------
+    decoder : GSNodeDecoder
+        The GraphStorm node decoder
+    emb : dict of Tensor
+        The GNN embeddings
+    loader : GSgnnNodeDataLoader
+        The GraphStorm dataloader
+    device : th.device
+        Device used to compute the prediction results
+    return_proba : bool
+        Whether to return all the predictions or only the maximum prediction
+    return_label : bool
+        Whether to return labels
+
+    Returns
+    -------
+    dict of Tensor :
+        Prediction results
+    dict of Tensor :
+        Labels if return_label is True
+    """
     data = loader.data

     if return_label:
@@ -321,7 +362,6 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
     preds = {}
     labels = {}
     # TODO(zhengda) I need to check if the data loader only returns target nodes.
-    model.eval()
     with th.no_grad():
         for _, seeds, _ in loader:  # seeds are target nodes
             for ntype, seed_nodes in seeds.items():
@@ -345,7 +385,6 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
                     labels[ntype].append(lbl[ntype])
                 else:
                     labels[ntype] = [lbl[ntype]]
-    model.train()

     for ntype, ntype_pred in preds.items():
         preds[ntype] = th.cat(ntype_pred)
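node_gnn.py shows the complete pattern in one place: node_mini_batch_predict now extracts model.decoder, wraps the call in model.eval()/model.train(), and delegates to run_node_mini_batch_predict, which no longer toggles modes itself. Callers that already hold a decoder and precomputed embeddings can take the direct route; a sketch with `emb` and `nc_test_loader` as hypothetical stand-ins:

from graphstorm.model.node_gnn import run_node_mini_batch_predict

# Assumed to already exist: `model` (a trained GraphStorm node model),
# `emb` (precomputed full-graph GNN embeddings), and `nc_test_loader`
# (a node dataloader over the test nodes).
model.eval()  # the helper no longer calls eval()/train() internally
preds, labels = run_node_mini_batch_predict(model.decoder,
                                            emb,
                                            nc_test_loader,
                                            model.device,
                                            return_proba=False,
                                            return_label=True)
model.train()
# `preds` and `labels` are dicts keyed by node type, matching the
# docstring above.
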
