Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Multi-task Learning] Refactor graphstorm.model for multi-task learning. #852

Merged
merged 8 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions python/graphstorm/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,17 @@
from .gnn import do_full_graph_inference
from .gnn import do_mini_batch_inference
from .node_gnn import GSgnnNodeModel, GSgnnNodeModelBase, GSgnnNodeModelInterface
from .node_gnn import node_mini_batch_gnn_predict, node_mini_batch_predict
from .node_gnn import (node_mini_batch_gnn_predict,
node_mini_batch_predict,
run_node_mini_batch_predict)
from .edge_gnn import GSgnnEdgeModel, GSgnnEdgeModelBase, GSgnnEdgeModelInterface
from .edge_gnn import edge_mini_batch_gnn_predict, edge_mini_batch_predict
from .edge_gnn import (edge_mini_batch_gnn_predict,
edge_mini_batch_predict,
run_edge_mini_batch_predict)
from .lp_gnn import (GSgnnLinkPredictionModel,
GSgnnLinkPredictionModelBase,
GSgnnLinkPredictionModelInterface)
GSgnnLinkPredictionModelInterface,
run_lp_mini_batch_predict)
from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer
from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer
from .sage_encoder import SAGEEncoder, SAGEConv
Expand Down
43 changes: 42 additions & 1 deletion python/graphstorm/model/edge_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,48 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
model.eval()
decoder = model.decoder
device = model.device

preds, labels = run_edge_mini_batch_predict(decoder,
emb,
loader,
device,
return_proba,
return_label)
model.train()
return preds, labels

def run_edge_mini_batch_predict(decoder, emb, loader, device,
return_proba=True, return_label=False):
""" Perform mini-batch prediction with the given decoder.
classicsong marked this conversation as resolved.
Show resolved Hide resolved

This function usually follows full-grain GNN embedding inference. After having
classicsong marked this conversation as resolved.
Show resolved Hide resolved
the GNN embeddings, we need to perform mini-batch computation to make predictions
on the GNN embeddings.

Note: caller should call model.eval() before calling this function
classicsong marked this conversation as resolved.
Show resolved Hide resolved
and call model.train() after when doing training.

Parameters
----------
decoder : GSEdgeDecoder
The GraphStorm edge decoder
emb : dict of Tensor
The GNN embeddings
loader : GSgnnEdgeDataLoader
The GraphStorm dataloader
device: th.device
Device used to compute prediction result
return_proba: bool
Whether to return all the predictions or the maximum prediction
return_label : bool
Whether or not to return labels

Returns
-------
dict of Tensor : GNN prediction results. Return all the results when return_proba is true
otherwise return the maximum result.
dict of Tensor : labels if return_labels is True
"""
data = loader.data
g = data.g
preds = {}
Expand Down Expand Up @@ -379,7 +421,6 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
append_to_dict(lbl, labels)
barrier()

model.train()
for target_etype, pred in preds.items():
preds[target_etype] = th.cat(pred)
if return_label:
Expand Down
33 changes: 32 additions & 1 deletion python/graphstorm/model/lp_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def forward(self, blocks, pos_graph,
def lp_mini_batch_predict(model, emb, loader, device):
""" Perform mini-batch prediction.

This function follows full-grain GNN embedding inference.
This function follows full-graph GNN embedding inference.
After having the GNN embeddings, we need to perform mini-batch
computation to make predictions on the GNN embeddings.

Expand All @@ -154,6 +154,37 @@ def lp_mini_batch_predict(model, emb, loader, device):
Rankings of positive scores in format of {etype: ranking}
"""
decoder = model.decoder
return run_lp_mini_batch_predict(decoder,
emb,
loader,
device)

def run_lp_mini_batch_predict(decoder, emb, loader, device):
""" Perform mini-batch link prediction with the given decoder.

This function follows full-graph GNN embedding inference.
After having the GNN embeddings, we need to perform mini-batch
computation to make predictions on the GNN embeddings.

Note: caller should call model.eval() before calling this function
zhjwy9343 marked this conversation as resolved.
Show resolved Hide resolved
classicsong marked this conversation as resolved.
Show resolved Hide resolved
and call model.train() after when doing training.

Parameters
----------
decoder : LinkPredictNoParamDecoder or LinkPredictLearnableDecoder
The GraphStorm link prediction decoder model
emb : dict of Tensor
The GNN embeddings
loader : GSgnnEdgeDataLoader
The GraphStorm dataloader
device: th.device
Device used to compute test scores

Returns
-------
rankings: dict of tensors
Rankings of positive scores in format of {etype: ranking}
"""
with th.no_grad():
ranking = {}
for pos_neg_tuple, neg_sample_type in loader:
Expand Down
53 changes: 46 additions & 7 deletions python/graphstorm/model/node_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,48 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
Labels if return_labels is True
"""
device = model.device
decoder = model.decoder
model.eval()
preds, labels = \
run_node_mini_batch_predict(decoder,
emb,
loader,
device,
return_proba,
return_label)
model.train()
return preds, labels

def run_node_mini_batch_predict(decoder, emb, loader, device,
return_proba=True, return_label=False):
""" Perform mini-batch prediction with the given decoder.
classicsong marked this conversation as resolved.
Show resolved Hide resolved

Note: caller should call model.eval() before calling this function
classicsong marked this conversation as resolved.
Show resolved Hide resolved
and call model.train() after when doing training.

Parameters
----------
decoder : GSNodeDecoder or th.nn.ModuleDict
The GraphStorm node decoder.
It can be a GSNodeDecoder or a dict of GSNodeDecoders
emb : dict of Tensor
The GNN embeddings
loader : GSgnnNodeDataLoader
The GraphStorm dataloader
device: th.device
Device used to compute prediction result
return_proba : bool
Whether or not to return all the predictions or the maximum prediction
return_label : bool
Whether or not to return labels.

Returns
-------
dict of Tensor :
Prediction results.
dict of Tensor :
Labels if return_labels is True
"""
data = loader.data

if return_label:
Expand All @@ -321,15 +363,13 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
preds = {}
labels = {}
# TODO(zhengda) I need to check if the data loader only returns target nodes.
model.eval()
with th.no_grad():
for _, seeds, _ in loader: # seeds are target nodes
for ntype, seed_nodes in seeds.items():
if isinstance(model.decoder, th.nn.ModuleDict):
assert ntype in model.decoder, f"Node type {ntype} not in decoder"
decoder = model.decoder[ntype]
else:
decoder = model.decoder
if isinstance(decoder, th.nn.ModuleDict):
assert ntype in decoder, f"Node type {ntype} not in decoder"
decoder = decoder[ntype]

if return_proba:
pred = decoder.predict_proba(emb[ntype][seed_nodes].to(device))
else:
Expand All @@ -345,7 +385,6 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label=
labels[ntype].append(lbl[ntype])
else:
labels[ntype] = [lbl[ntype]]
model.train()

for ntype, ntype_pred in preds.items():
preds[ntype] = th.cat(ntype_pred)
Expand Down
Loading
Loading