Skip to content

Commit

Permalink
[Multi-task learning] Support reconstruct node feature supervision (#863
Browse files Browse the repository at this point in the history
)

*Issue #, if available:*
#789
#862 

*Description of changes:*
Add a new task type `BUILTIN_TASK_RECONSTRUCT_NODE_FEAT =
"reconstruct_node_feat"`. Users can use node feature reconstruction to
supervise model training in multi-task learning.

Users can define a reconstruct_node_feat task as follows:

```
...
  multi_task_learning:
    - reconstruct_node_feat:
        reconstruct_nfeat_name: "title"
        target_ntype: "movie"
        batch_size: 128
        mask_fields:
          - "train_mask_c0" # node classification mask 0
          - "val_mask_c0"
          - "test_mask_c0"
        task_weight: 1.0
        eval_metric:
          - "mse"
```
`reconstruct_node_feat` is the task name. `target_ntype` defines the
node type to which node feature reconstruction learning is applied.
`reconstruct_nfeat_name` defines the name of the feature to be
reconstructed. The other configs are the same as for node regression tasks.

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
  • Loading branch information
classicsong and Xiang Song authored Jun 5, 2024
1 parent aff7275 commit 5e189df
Show file tree
Hide file tree
Showing 21 changed files with 761 additions and 33 deletions.
3 changes: 2 additions & 1 deletion python/graphstorm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from .gsf import (create_builtin_node_decoder,
create_builtin_edge_decoder,
create_builtin_lp_decoder)
create_builtin_lp_decoder,
create_builtin_reconstruct_nfeat_decoder)
from .gsf import (get_builtin_lp_train_dataloader_class,
get_builtin_lp_eval_dataloader_class)
3 changes: 2 additions & 1 deletion python/graphstorm/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
BUILTIN_TASK_EDGE_CLASSIFICATION,
BUILTIN_TASK_EDGE_REGRESSION,
BUILTIN_TASK_LINK_PREDICTION,
BUILTIN_TASK_COMPUTE_EMB)
BUILTIN_TASK_COMPUTE_EMB,
BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
from .config import SUPPORTED_TASKS

from .config import BUILTIN_LP_DOT_DECODER
Expand Down
61 changes: 59 additions & 2 deletions python/graphstorm/config/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .config import BUILTIN_TASK_EDGE_REGRESSION
from .config import (BUILTIN_TASK_LINK_PREDICTION,
LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL)
from .config import BUILTIN_TASK_RECONSTRUCT_NODE_FEAT
from .config import BUILTIN_GNN_NORM
from .config import EARLY_STOP_CONSECUTIVE_INCREASE_STRATEGY
from .config import EARLY_STOP_AVERAGE_INCREASE_STRATEGY
Expand Down Expand Up @@ -439,6 +440,39 @@ def _parse_link_prediction_task(self, task_config):
task_id=task_id,
task_config=task_info)

def _parse_reconstruct_node_feat(self, task_config):
    """ Build the TaskInfo describing one reconstruct node feature task.

    Parameters
    ----------
    task_config: dict
        Reconstruct node feature task config. It is updated in place:
        its ``batch_size`` entry is overwritten with the value returned
        by ``_parse_general_task_config``.

    Returns
    -------
    TaskInfo: The task type, the task id and the per-task config object.
    """
    task_type = BUILTIN_TASK_RECONSTRUCT_NODE_FEAT
    mask_fields, task_weight, batch_size = \
        self._parse_general_task_config(task_config)
    task_config["batch_size"] = batch_size

    # Allocate the per-task config without running GSConfig.__init__;
    # its attributes come from task_config instead.
    task_info = GSConfig.__new__(GSConfig)
    task_info.set_task_attributes(task_config)
    setattr(task_info, "_task_type", task_type)
    task_info.verify_node_feat_reconstruct_arguments()

    # The task id is derived from the task type, the target node type
    # and the name of the feature being reconstructed.
    task_id = get_mttask_id(task_type=task_type,
                            ntype=task_info.target_ntype,
                            label=task_info.reconstruct_nfeat_name)

    # mask_fields is read positionally: train, validation, test.
    task_info.train_mask = mask_fields[0]
    task_info.val_mask = mask_fields[1]
    task_info.test_mask = mask_fields[2]
    task_info.task_weight = task_weight

    return TaskInfo(task_type=task_type,
                    task_id=task_id,
                    task_config=task_info)

def _parse_multi_tasks(self, multi_task_config):
""" Parse multi-task configuration
Expand Down Expand Up @@ -500,6 +534,9 @@ def _parse_multi_tasks(self, multi_task_config):
elif "link_prediction" in task_config:
task = self._parse_link_prediction_task(
task_config["link_prediction"])
elif "reconstruct_node_feat" in task_config:
task = self._parse_reconstruct_node_feat(
task_config["reconstruct_node_feat"])
else:
raise ValueError(f"Invalid task type in multi-task learning {task_config}.")
tasks.append(task)
Expand Down Expand Up @@ -530,6 +567,14 @@ def override_arguments(self, cmd_args):
# for basic attributes
setattr(self, f"_{arg_key}", arg_val)

def verify_node_feat_reconstruct_arguments(self):
    """Verify the correctness of arguments for node feature reconstruction tasks.

    Every required setting is read once, in a fixed order, and the value
    is discarded. Any exception raised while reading a setting propagates
    to the caller.
    """
    required_settings = ("target_ntype",
                         "batch_size",
                         "eval_metric",
                         "reconstruct_nfeat_name")
    for setting in required_settings:
        getattr(self, setting)

def verify_node_class_arguments(self):
""" Verify the correctness of arguments for node classification tasks.
"""
Expand Down Expand Up @@ -2545,7 +2590,7 @@ def eval_metric(self):
else:
eval_metric = ["accuracy"]
elif self.task_type in [BUILTIN_TASK_NODE_REGRESSION, \
BUILTIN_TASK_EDGE_REGRESSION]:
BUILTIN_TASK_EDGE_REGRESSION, BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]:
if hasattr(self, "_eval_metric"):
if isinstance(self._eval_metric, str):
eval_metric = self._eval_metric.lower()
Expand All @@ -2568,7 +2613,10 @@ def eval_metric(self):
"should be a string or a list of string"
# no eval_metric
else:
eval_metric = ["rmse"]
if self.task_type == BUILTIN_TASK_RECONSTRUCT_NODE_FEAT:
eval_metric = ["mse"]
else:
eval_metric = ["rmse"]
elif self.task_type == BUILTIN_TASK_LINK_PREDICTION:
if hasattr(self, "_eval_metric"):
if isinstance(self._eval_metric, str):
Expand Down Expand Up @@ -2650,6 +2698,15 @@ def num_ffn_layers_in_decoder(self):
# Set default mlp layer number between gnn layer to 0
return 0

################## Reconstruct node feats ###############
@property
def reconstruct_nfeat_name(self):
    """ Name of the node feature to reconstruct.

    The value is read from ``_reconstruct_nfeat_name``. There is no
    default: an assertion fails when the attribute has not been set.
    """
    is_configured = hasattr(self, "_reconstruct_nfeat_name")
    assert is_configured, \
        "reconstruct_nfeat_name must be provided under reconstruct_node_feat task "
    return self._reconstruct_nfeat_name

################## Multi task learning ##################
@property
def multi_tasks(self):
Expand Down
4 changes: 3 additions & 1 deletion python/graphstorm/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@
BUILTIN_TASK_EDGE_REGRESSION = "edge_regression"
BUILTIN_TASK_LINK_PREDICTION = "link_prediction"
BUILTIN_TASK_COMPUTE_EMB = "compute_emb"
BUILTIN_TASK_RECONSTRUCT_NODE_FEAT = "reconstruct_node_feat"

LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL = "ALL"

SUPPORTED_TASKS = [BUILTIN_TASK_NODE_CLASSIFICATION, \
BUILTIN_TASK_NODE_REGRESSION, \
BUILTIN_TASK_EDGE_CLASSIFICATION, \
BUILTIN_TASK_LINK_PREDICTION, \
BUILTIN_TASK_EDGE_REGRESSION]
BUILTIN_TASK_EDGE_REGRESSION, \
BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]

EARLY_STOP_CONSECUTIVE_INCREASE_STRATEGY = "consecutive_increase"
EARLY_STOP_AVERAGE_INCREASE_STRATEGY = "average_increase"
Expand Down
1 change: 1 addition & 0 deletions python/graphstorm/eval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@
GSgnnPerEtypeMrrLPEvaluator,
GSgnnClassificationEvaluator,
GSgnnRegressionEvaluator,
GSgnnRconstructFeatRegScoreEvaluator,
GSgnnMultiTaskEvaluator)
79 changes: 78 additions & 1 deletion python/graphstorm/eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,6 @@ def multilabel(self):
"""
return self._multilabel


class GSgnnRegressionEvaluator(GSgnnBaseEvaluator, GSgnnPredictionEvalInterface):
""" Regression Evaluator.
Expand Down Expand Up @@ -706,6 +705,84 @@ def compute_score(self, pred, labels, train=True):

return scores

class GSgnnRconstructFeatRegScoreEvaluator(GSgnnRegressionEvaluator):
    """ Evaluator for feature reconstruction using regression scores.

    We treat the prediction results as a 2D float tensor and
    the label is also a 2D float tensor.
    We compute mse or rmse for it.

    Parameters
    ----------
    eval_frequency: int
        The frequency (number of iterations) of doing evaluation.
    eval_metric_list: list of string
        Evaluation metric used during evaluation. Default: ["mse"].
    use_early_stop: bool
        Set true to use early stop.
    early_stop_burnin_rounds: int
        Burn-in rounds before start checking for the early stop condition.
    early_stop_rounds: int
        The number of rounds for validation scores used to decide early stop.
    early_stop_strategy: str
        The early stop strategy. GraphStorm supports two strategies:
        1) consecutive_increase and 2) average_increase.
    """
    def __init__(self, eval_frequency,
                 eval_metric_list=None,
                 use_early_stop=False,
                 early_stop_burnin_rounds=0,
                 early_stop_rounds=3,
                 early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY):
        # set default metric list
        if eval_metric_list is None:
            eval_metric_list = ["mse"]

        super().__init__(
            eval_frequency,
            eval_metric_list,
            use_early_stop,
            early_stop_burnin_rounds,
            early_stop_rounds,
            early_stop_strategy)

    def compute_score(self, pred, labels, train=True):
        """ Compute evaluation score

        Parameters
        ----------
        pred:
            Prediction result
        labels:
            Label
        train: boolean
            If in model training.

        Returns
        -------
        Evaluation metric values: dict
            Maps every metric in ``self.metric_list`` to its score, or to
            the string "N/A" when ``pred`` or ``labels`` is None.
        """
        # if the pred is None or the labels is None the metric can not be computed
        if pred is None or labels is None:
            return {metric: "N/A" for metric in self.metric_list}

        # The dtype conversion does not depend on the metric, so do it
        # once instead of once per metric.
        pred = pred.to(th.float32)
        labels = labels.to(th.float32)

        if train:
            # training expects always a single number to be
            # returned and has a different (potentially) evaluation function
            metric_funcs = self.metrics_obj.metric_function
        else:
            # validation or testing may have a different
            # evaluation function, in our case the evaluation code
            # may return a dictionary with the metric values for each metric
            metric_funcs = self.metrics_obj.metric_eval_function

        return {metric: metric_funcs[metric](pred, labels)
                for metric in self.metric_list}

class GSgnnMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface):
""" Link Prediction Evaluator using "mrr" as metric.
Expand Down
40 changes: 39 additions & 1 deletion python/graphstorm/gsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
BUILTIN_TASK_NODE_REGRESSION,
BUILTIN_TASK_EDGE_CLASSIFICATION,
BUILTIN_TASK_EDGE_REGRESSION,
BUILTIN_TASK_LINK_PREDICTION)
BUILTIN_TASK_LINK_PREDICTION,
BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
from .config import BUILTIN_LP_DOT_DECODER
from .config import BUILTIN_LP_DISTMULT_DECODER
from .config import (BUILTIN_LP_LOSS_CROSS_ENTROPY,
Expand Down Expand Up @@ -243,6 +244,41 @@ def create_builtin_node_gnn_model(g, config, train_task):
"""
return create_builtin_node_model(g, config, train_task)

# pylint: disable=unused-argument
def create_builtin_reconstruct_nfeat_decoder(g, decoder_input_dim, config, train_task):
    """ Create the builtin node feature reconstruction decoder
        and its loss function according to the task config.

    Parameters
    ----------
    g: DGLGraph
        The graph data. The feature named ``config.reconstruct_nfeat_name``
        on nodes of type ``config.target_ntype`` is looked up in it to
        size the decoder output.
    decoder_input_dim: int
        Input dimension size of the decoder
    config: GSConfig
        Configurations
    train_task : bool
        Whether this model is used for training. ``config.dropout`` is
        only used when this is true; otherwise dropout is 0.

    Returns
    -------
    decoder: The node feature reconstruction decoder
    loss_func: The loss function
    """
    if train_task:
        dropout = config.dropout
    else:
        dropout = 0

    ntype = config.target_ntype
    feat_name = config.reconstruct_nfeat_name
    # The decoder output size is the second dimension of the stored feature.
    target_feat = g.nodes[ntype].data[feat_name]
    out_dim = target_feat.shape[1]

    decoder = EntityRegression(decoder_input_dim,
                               dropout=dropout,
                               out_dim=out_dim)
    return decoder, RegressionLossFunc()

# pylint: disable=unused-argument
def create_builtin_node_decoder(g, decoder_input_dim, config, train_task):
""" create builtin node decoder according to task config
Expand Down Expand Up @@ -869,5 +905,7 @@ def create_task_decoder(task_info, g, decoder_input_dim, train_task):
return create_builtin_edge_decoder(g, decoder_input_dim, task_info.task_config, train_task)
elif task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
return create_builtin_lp_decoder(g, decoder_input_dim, task_info.task_config, train_task)
elif task_info.task_type in [BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]:
return create_builtin_reconstruct_nfeat_decoder(g, decoder_input_dim, task_info.task_config, train_task)
else:
raise TypeError(f"Unknown task type {task_info.task_type}")
11 changes: 11 additions & 0 deletions python/graphstorm/model/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,17 @@ def get_lm_params(self):

return params

def has_sparse_params(self):
    """ Return whether there are sparse parameters (learnable embeddings)
        in the model.

    Returns
    -------
    bool: True when ``self._optimizer.sparse_opts`` is not empty.
    """
    sparse_opts = self._optimizer.sparse_opts
    num_sparse_opts = len(sparse_opts)
    return num_sparse_opts > 0


def get_sparse_params(self):
""" get the sparse parameters of the model.
Expand Down
20 changes: 20 additions & 0 deletions python/graphstorm/model/gnn_encoder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from functools import partial
import logging

import abc
import dgl
import torch as th
from torch import nn
Expand All @@ -28,6 +29,25 @@
from ..utils import get_rank, barrier, is_distributed, create_dist_tensor, is_wholegraph
from ..distributed import flush_data

class GSgnnGNNEncoderInterface:
    """ The interface for builtin GraphStorm gnn encoder layer.

    The interface defines two functions that are useful in multi-task learning.
    Any GNN encoder that implements these two functions can work with
    GraphStorm multi-task learning pipeline.

    Note: We can define more functions when necessary.

    Note: this class declares no ``abc.ABCMeta`` metaclass of its own, so
    the ``abstractmethod`` markers below are not enforced by this class
    alone; both methods do nothing unless overridden.
    """

    @abc.abstractmethod
    def skip_last_selfloop(self):
        """ Skip the self-loop of the last GNN layer.
        """

    @abc.abstractmethod
    def reset_last_selfloop(self):
        """ Reset the self-loop setting of the last GNN layer.
        """

class GraphConvEncoder(GSLayer): # pylint: disable=abstract-method
r"""General encoder for graph data.
Expand Down
13 changes: 11 additions & 2 deletions python/graphstorm/model/hgt_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from dgl.nn.functional import edge_softmax
from ..config import BUILDIN_GNN_BATCH_NORM, BUILDIN_GNN_LAYER_NORM, BUILTIN_GNN_NORM
from .ngnn_mlp import NGNNMLP
from .gnn_encoder_base import GraphConvEncoder
from .gnn_encoder_base import (GraphConvEncoder,
GSgnnGNNEncoderInterface)


class HGTLayer(nn.Module):
Expand Down Expand Up @@ -280,7 +281,7 @@ def forward(self, g, h):
return new_h


class HGTEncoder(GraphConvEncoder):
class HGTEncoder(GraphConvEncoder, GSgnnGNNEncoderInterface):
r"""Heterogenous graph transformer (HGT) encoder
The HGTEncoder employs several HGTLayers as its encoding mechanism.
Expand Down Expand Up @@ -375,6 +376,14 @@ def __init__(self,
dropout=dropout,
norm=norm))

def skip_last_selfloop(self):
    """ Do nothing.
    """
    # HGT does not have explicit self-loop
    return None

def reset_last_selfloop(self):
    """ Do nothing.
    """
    # HGT does not have explicit self-loop
    return None

def forward(self, blocks, h):
"""Forward computation
Expand Down
Loading

0 comments on commit 5e189df

Please sign in to comment.