Skip to content

Commit

Permalink
[Multi-task Learning] Refactor graphstorm.model for multi-task learni…
Browse files Browse the repository at this point in the history
…ng. (#852)

*Issue #, if available:*
#789 

*Description of changes:*
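Because the multi-task learning trainer invokes edge_mini_batch_predict,
lp_mini_batch_predict, and node_mini_batch_predict during evaluation and
testing, this change refactors that code so the prediction logic can work
with different decoders: the core loops are extracted into
run_edge_mini_batch_predict, run_lp_mini_batch_predict, and
run_node_mini_batch_predict, which take the decoder as an explicit argument.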

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
  • Loading branch information
classicsong and Xiang Song authored May 21, 2024
1 parent 5db0f74 commit 78f3458
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 20 deletions.
11 changes: 8 additions & 3 deletions python/graphstorm/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,17 @@
from .gnn import do_full_graph_inference
from .gnn import do_mini_batch_inference
from .node_gnn import GSgnnNodeModel, GSgnnNodeModelBase, GSgnnNodeModelInterface
from .node_gnn import node_mini_batch_gnn_predict, node_mini_batch_predict
from .node_gnn import (node_mini_batch_gnn_predict,
node_mini_batch_predict,
run_node_mini_batch_predict)
from .edge_gnn import GSgnnEdgeModel, GSgnnEdgeModelBase, GSgnnEdgeModelInterface
from .edge_gnn import edge_mini_batch_gnn_predict, edge_mini_batch_predict
from .edge_gnn import (edge_mini_batch_gnn_predict,
edge_mini_batch_predict,
run_edge_mini_batch_predict)
from .lp_gnn import (GSgnnLinkPredictionModel,
GSgnnLinkPredictionModelBase,
GSgnnLinkPredictionModelInterface)
GSgnnLinkPredictionModelInterface,
run_lp_mini_batch_predict)
from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer
from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer
from .sage_encoder import SAGEEncoder, SAGEConv
Expand Down
43 changes: 42 additions & 1 deletion python/graphstorm/model/edge_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,48 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
model.eval()
decoder = model.decoder
device = model.device

preds, labels = run_edge_mini_batch_predict(decoder,
emb,
loader,
device,
return_proba,
return_label)
model.train()
return preds, labels

def run_edge_mini_batch_predict(decoder, emb, loader, device,
return_proba=True, return_label=False):
""" Perform mini-batch edge prediction with the given decoder.
This function usually follows full-graph GNN embedding inference. After having
the GNN embeddings, we need to perform mini-batch computation to make predictions
on the GNN embeddings.
Note: callers should call model.eval() before calling this function
and call model.train() after when doing training.
Parameters
----------
decoder : GSEdgeDecoder
The GraphStorm edge decoder
emb : dict of Tensor
The GNN embeddings
loader : GSgnnEdgeDataLoader
The GraphStorm dataloader
device: th.device
Device used to compute prediction result
return_proba: bool
Whether to return all the predictions or the maximum prediction
return_label : bool
Whether or not to return labels
Returns
-------
dict of Tensor : GNN prediction results. Return all the results when return_proba is true
otherwise return the maximum result.
dict of Tensor : labels if return_labels is True
"""
data = loader.data
g = data.g
preds = {}
Expand Down Expand Up @@ -379,7 +421,6 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
append_to_dict(lbl, labels)
barrier()

model.train()
for target_etype, pred in preds.items():
preds[target_etype] = th.cat(pred)
if return_label:
Expand Down
36 changes: 35 additions & 1 deletion python/graphstorm/model/lp_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,13 @@ def forward(self, blocks, pos_graph,
def lp_mini_batch_predict(model, emb, loader, device):
""" Perform mini-batch prediction.
This function follows full-grain GNN embedding inference.
This function follows full-graph GNN embedding inference.
After having the GNN embeddings, we need to perform mini-batch
computation to make predictions on the GNN embeddings.
Note: callers should call model.eval() before calling this function
and call model.train() after when doing training.
Parameters
----------
model : GSgnnModel
Expand All @@ -154,6 +157,37 @@ def lp_mini_batch_predict(model, emb, loader, device):
Rankings of positive scores in format of {etype: ranking}
"""
decoder = model.decoder
return run_lp_mini_batch_predict(decoder,
emb,
loader,
device)

def run_lp_mini_batch_predict(decoder, emb, loader, device):
""" Perform mini-batch link prediction with the given decoder.
This function follows full-graph GNN embedding inference.
After having the GNN embeddings, we need to perform mini-batch
computation to make predictions on the GNN embeddings.
Note: callers should call model.eval() before calling this function
and call model.train() after when doing training.
Parameters
----------
decoder : LinkPredictNoParamDecoder or LinkPredictLearnableDecoder
The GraphStorm link prediction decoder model
emb : dict of Tensor
The GNN embeddings
loader : GSgnnEdgeDataLoader
The GraphStorm dataloader
device: th.device
Device used to compute test scores
Returns
-------
rankings: dict of tensors
Rankings of positive scores in format of {etype: ranking}
"""
with th.no_grad():
ranking = {}
for pos_neg_tuple, neg_sample_type in loader:
Expand Down
53 changes: 46 additions & 7 deletions python/graphstorm/model/node_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,48 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
Labels if return_labels is True
"""
device = model.device
decoder = model.decoder
model.eval()
preds, labels = \
run_node_mini_batch_predict(decoder,
emb,
loader,
device,
return_proba,
return_label)
model.train()
return preds, labels

def run_node_mini_batch_predict(decoder, emb, loader, device,
return_proba=True, return_label=False):
""" Perform mini-batch node prediction with the given decoder.
Note: callers should call model.eval() before calling this function
and call model.train() after when doing training.
Parameters
----------
decoder : GSNodeDecoder or th.nn.ModuleDict
The GraphStorm node decoder.
It can be a GSNodeDecoder or a dict of GSNodeDecoders
emb : dict of Tensor
The GNN embeddings
loader : GSgnnNodeDataLoader
The GraphStorm dataloader
device: th.device
Device used to compute prediction result
return_proba : bool
Whether or not to return all the predictions or the maximum prediction
return_label : bool
Whether or not to return labels.
Returns
-------
dict of Tensor :
Prediction results.
dict of Tensor :
Labels if return_labels is True
"""
data = loader.data

if return_label:
Expand All @@ -321,15 +363,13 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
preds = {}
labels = {}
# TODO(zhengda) I need to check if the data loader only returns target nodes.
model.eval()
with th.no_grad():
for _, seeds, _ in loader: # seeds are target nodes
for ntype, seed_nodes in seeds.items():
if isinstance(model.decoder, th.nn.ModuleDict):
assert ntype in model.decoder, f"Node type {ntype} not in decoder"
decoder = model.decoder[ntype]
else:
decoder = model.decoder
if isinstance(decoder, th.nn.ModuleDict):
assert ntype in decoder, f"Node type {ntype} not in decoder"
decoder = decoder[ntype]

if return_proba:
pred = decoder.predict_proba(emb[ntype][seed_nodes].to(device))
else:
Expand All @@ -345,7 +385,6 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
labels[ntype].append(lbl[ntype])
else:
labels[ntype] = [lbl[ntype]]
model.train()

for ntype, ntype_pred in preds.items():
preds[ntype] = th.cat(ntype_pred)
Expand Down
Loading

0 comments on commit 78f3458

Please sign in to comment.