diff --git a/python/graphstorm/__init__.py b/python/graphstorm/__init__.py index 29b747104a..ce94f8de9d 100644 --- a/python/graphstorm/__init__.py +++ b/python/graphstorm/__init__.py @@ -32,6 +32,7 @@ from .gsf import (create_builtin_node_decoder, create_builtin_edge_decoder, - create_builtin_lp_decoder) + create_builtin_lp_decoder, + create_builtin_reconstruct_nfeat_decoder) from .gsf import (get_builtin_lp_train_dataloader_class, get_builtin_lp_eval_dataloader_class) diff --git a/python/graphstorm/config/__init__.py b/python/graphstorm/config/__init__.py index f14a3cd1b5..d04796afed 100644 --- a/python/graphstorm/config/__init__.py +++ b/python/graphstorm/config/__init__.py @@ -24,7 +24,8 @@ BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION, BUILTIN_TASK_LINK_PREDICTION, - BUILTIN_TASK_COMPUTE_EMB) + BUILTIN_TASK_COMPUTE_EMB, + BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) from .config import SUPPORTED_TASKS from .config import BUILTIN_LP_DOT_DECODER diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 3203220d84..95be1a935f 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -40,6 +40,7 @@ from .config import BUILTIN_TASK_EDGE_REGRESSION from .config import (BUILTIN_TASK_LINK_PREDICTION, LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL) +from .config import BUILTIN_TASK_RECONSTRUCT_NODE_FEAT from .config import BUILTIN_GNN_NORM from .config import EARLY_STOP_CONSECUTIVE_INCREASE_STRATEGY from .config import EARLY_STOP_AVERAGE_INCREASE_STRATEGY @@ -171,6 +172,8 @@ def __init__(self, cmd_args): # parse multi task learning config and save it into self._multi_tasks if multi_task_config is not None: self._parse_multi_tasks(multi_task_config) + else: + self._multi_tasks = None def set_attributes(self, configuration): """Set class attributes from 2nd level arguments in yaml config""" @@ -439,6 +442,39 @@ def _parse_link_prediction_task(self, task_config): task_id=task_id, task_config=task_info) + def _parse_reconstruct_node_feat(self, task_config): + """ Parse the reconstruct node feature task info + + Parameters + ---------- + task_config: dict + Reconstruct node feature task config + """ + task_type = BUILTIN_TASK_RECONSTRUCT_NODE_FEAT + mask_fields, task_weight, batch_size = \ + self._parse_general_task_config(task_config) + task_config["batch_size"] = batch_size + + task_info = GSConfig.__new__(GSConfig) + task_info.set_task_attributes(task_config) + setattr(task_info, "_task_type", task_type) + task_info.verify_node_feat_reconstruct_arguments() + + target_ntype = task_info.target_ntype + label_field = task_info.reconstruct_nfeat_name + + task_id = get_mttask_id(task_type=task_type, + ntype=target_ntype, + label=label_field) + setattr(task_info, "train_mask", mask_fields[0]) + setattr(task_info, "val_mask", mask_fields[1]) + setattr(task_info, "test_mask", mask_fields[2]) + setattr(task_info, "task_weight", task_weight) + + return TaskInfo(task_type=task_type, + task_id=task_id, + task_config=task_info) + def _parse_multi_tasks(self, multi_task_config): """ Parse multi-task configuration @@ -500,6 +536,9 @@ def _parse_multi_tasks(self, multi_task_config): elif "link_prediction" in task_config: task = self._parse_link_prediction_task( task_config["link_prediction"]) + elif "reconstruct_node_feat" in task_config: + task = self._parse_reconstruct_node_feat( + task_config["reconstruct_node_feat"]) else: raise ValueError(f"Invalid task type in multi-task learning {task_config}.") tasks.append(task) @@ -530,6 +569,14 @@ 
def override_arguments(self, cmd_args): # for basic attributes setattr(self, f"_{arg_key}", arg_val) + def verify_node_feat_reconstruct_arguments(self): + """Verify the correctness of arguments for node feature reconstruction tasks. + """ + _ = self.target_ntype + _ = self.batch_size + _ = self.eval_metric + _ = self.reconstruct_nfeat_name + def verify_node_class_arguments(self): """ Verify the correctness of arguments for node classification tasks. """ @@ -2545,7 +2592,7 @@ def eval_metric(self): else: eval_metric = ["accuracy"] elif self.task_type in [BUILTIN_TASK_NODE_REGRESSION, \ - BUILTIN_TASK_EDGE_REGRESSION]: + BUILTIN_TASK_EDGE_REGRESSION, BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]: if hasattr(self, "_eval_metric"): if isinstance(self._eval_metric, str): eval_metric = self._eval_metric.lower() @@ -2568,7 +2615,10 @@ def eval_metric(self): "should be a string or a list of string" # no eval_metric else: - eval_metric = ["rmse"] + if self.task_type == BUILTIN_TASK_RECONSTRUCT_NODE_FEAT: + eval_metric = ["mse"] + else: + eval_metric = ["rmse"] elif self.task_type == BUILTIN_TASK_LINK_PREDICTION: if hasattr(self, "_eval_metric"): if isinstance(self._eval_metric, str): @@ -2650,6 +2700,15 @@ def num_ffn_layers_in_decoder(self): # Set default mlp layer number between gnn layer to 0 return 0 + ################## Reconstruct node feats ############### + @property + def reconstruct_nfeat_name(self): + """ node feature name for reconstruction + """ + assert hasattr(self, "_reconstruct_nfeat_name"), \ + "reconstruct_nfeat_name must be provided under reconstruct_node_feat task " + return self._reconstruct_nfeat_name + ################## Multi task learning ################## @property def multi_tasks(self): diff --git a/python/graphstorm/config/config.py b/python/graphstorm/config/config.py index 90d713e471..a35fbfa311 100644 --- a/python/graphstorm/config/config.py +++ b/python/graphstorm/config/config.py @@ -53,6 +53,7 @@ BUILTIN_TASK_EDGE_REGRESSION = "edge_regression" BUILTIN_TASK_LINK_PREDICTION = "link_prediction" BUILTIN_TASK_COMPUTE_EMB = "compute_emb" +BUILTIN_TASK_RECONSTRUCT_NODE_FEAT = "reconstruct_node_feat" LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL = "ALL" @@ -60,7 +61,8 @@ BUILTIN_TASK_NODE_REGRESSION, \ BUILTIN_TASK_EDGE_CLASSIFICATION, \ BUILTIN_TASK_LINK_PREDICTION, \ - BUILTIN_TASK_EDGE_REGRESSION] + BUILTIN_TASK_EDGE_REGRESSION, \ + BUILTIN_TASK_RECONSTRUCT_NODE_FEAT] EARLY_STOP_CONSECUTIVE_INCREASE_STRATEGY = "consecutive_increase" EARLY_STOP_AVERAGE_INCREASE_STRATEGY = "average_increase" diff --git a/python/graphstorm/eval/__init__.py b/python/graphstorm/eval/__init__.py index 0f3c0aae1c..a2507661e4 100644 --- a/python/graphstorm/eval/__init__.py +++ b/python/graphstorm/eval/__init__.py @@ -27,4 +27,5 @@ GSgnnPerEtypeMrrLPEvaluator, GSgnnClassificationEvaluator, GSgnnRegressionEvaluator, + GSgnnRconstructFeatRegScoreEvaluator, GSgnnMultiTaskEvaluator) diff --git a/python/graphstorm/eval/evaluator.py b/python/graphstorm/eval/evaluator.py index b7774706ba..fe3bb8bf08 100644 --- a/python/graphstorm/eval/evaluator.py +++ b/python/graphstorm/eval/evaluator.py @@ -575,7 +575,6 @@ def multilabel(self): """ return self._multilabel - class GSgnnRegressionEvaluator(GSgnnBaseEvaluator, GSgnnPredictionEvalInterface): """ Regression Evaluator. @@ -706,6 +705,84 @@ def compute_score(self, pred, labels, train=True): return scores +class GSgnnRconstructFeatRegScoreEvaluator(GSgnnRegressionEvaluator): + """ Evaluator for feature reconstruction using regression scores. 
+
+    We treat the prediction results as a 2D float tensor and
+    the label is also a 2D float tensor.
+
+    We compute mse or rmse for it.
+
+    Parameters
+    ----------
+    eval_frequency: int
+        The frequency (number of iterations) of doing evaluation.
+    eval_metric_list: list of string
+        Evaluation metric used during evaluation. Default: ["mse"].
+    use_early_stop: bool
+        Set true to use early stop.
+    early_stop_burnin_rounds: int
+        Burn-in rounds before start checking for the early stop condition.
+    early_stop_rounds: int
+        The number of rounds for validation scores used to decide early stop.
+    early_stop_strategy: str
+        The early stop strategy. GraphStorm supports two strategies:
+        1) consecutive_increase and 2) average_increase.
+    """
+    def __init__(self, eval_frequency,
+                 eval_metric_list=None,
+                 use_early_stop=False,
+                 early_stop_burnin_rounds=0,
+                 early_stop_rounds=3,
+                 early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY):
+        # set default metric list
+        if eval_metric_list is None:
+            eval_metric_list = ["mse"]
+
+        super(GSgnnRconstructFeatRegScoreEvaluator, self).__init__(
+            eval_frequency,
+            eval_metric_list,
+            use_early_stop,
+            early_stop_burnin_rounds,
+            early_stop_rounds,
+            early_stop_strategy)
+
+    def compute_score(self, pred, labels, train=True):
+        """ Compute evaluation score
+
+        Parameters
+        ----------
+        pred:
+            Prediction result
+        labels:
+            Label
+        train: boolean
+            If in model training.
+
+        Returns
+        -------
+        Evaluation metric values: dict
+        """
+        scores = {}
+        for metric in self.metric_list:
+            if pred is not None and labels is not None:
+                pred = pred.to(th.float32)
+                labels = labels.to(th.float32)
+
+                if train:
+                    # Training always expects a single number to be
+                    # returned and may use a different evaluation function.
+                    scores[metric] = self.metrics_obj.metric_function[metric](pred, labels)
+                else:
+                    # Validation or testing may use a different evaluation
+                    # function; in our case the evaluation code may return
+                    # a dictionary with the metric values for each metric.
+                    scores[metric] = self.metrics_obj.metric_eval_function[metric](pred, labels)
+            else:
+                # If pred is None or labels is None, the metric cannot be computed.
+                scores[metric] = "N/A"
+
+        return scores
+
 class GSgnnMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface):
     """ Link Prediction Evaluator using "mrr" as metric.
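A minimal usage sketch for the evaluator added above, assuming the GSgnnRegressionEvaluator base class wires up metrics_obj for the "mse"/"rmse" metrics as it does for regression tasks; the tensor shapes below are illustrative, not taken from this patch:

import torch as th
from graphstorm.eval import GSgnnRconstructFeatRegScoreEvaluator

# Both the reconstructed features and the labels are 2D float tensors.
evaluator = GSgnnRconstructFeatRegScoreEvaluator(eval_frequency=100)
pred = th.rand(16, 8)    # reconstructed features for 16 nodes
labels = th.rand(16, 8)  # ground-truth node features
scores = evaluator.compute_score(pred, labels, train=True)
# scores is a dict keyed by metric name, e.g. {"mse": tensor(...)}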
diff --git a/python/graphstorm/gconstruct/construct_graph.py b/python/graphstorm/gconstruct/construct_graph.py
index 9ed376baeb..6ecbeed54a 100644
--- a/python/graphstorm/gconstruct/construct_graph.py
+++ b/python/graphstorm/gconstruct/construct_graph.py
@@ -710,6 +710,8 @@ def print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats
             logging.info("Train/val/test on %s with mask %s, %s, %s: %d, %d, %d",
                          ntype, train_mask, val_mask, test_mask,
                          num_train, num_val, num_test)
+    logging.info("Note: custom train/val/test mask "
+                 "information for nodes is not collected.")
     for etype in edge_data:
         feat_names = list(edge_data[etype].keys())
         logging.info("Edge type %s has features: %s.", str(etype), str(feat_names))
@@ -726,6 +728,8 @@ def print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats
             logging.info("Train/val/test on %s with mask %s, %s, %s: %d, %d, %d",
                          str(etype), train_mask, val_mask, test_mask,
                          num_train, num_val, num_test)
+    logging.info("Note: custom train/val/test mask "
+                 "information for edges is not collected.")
 
     for ntype in node_label_stats:
         for label_name, stats in node_label_stats[ntype].items():
diff --git a/python/graphstorm/gconstruct/transform.py b/python/graphstorm/gconstruct/transform.py
index 2ea7ac01c9..01089a7eb1 100644
--- a/python/graphstorm/gconstruct/transform.py
+++ b/python/graphstorm/gconstruct/transform.py
@@ -1817,8 +1817,19 @@ def parse_label_conf(label_conf):
         mask_names.append(ops.train_mask_name)
         mask_names.append(ops.val_mask_name)
         mask_names.append(ops.test_mask_name)
-    assert len(mask_names) == len(set(mask_names)), \
-        f"Some train/val/test mask field names are duplicated, please check: {mask_names}."
+    if len(mask_names) != len(set(mask_names)):
+        # In multi-task learning, we expect each task to have
+        # its own train, validation and test mask fields.
+        # But there can be exceptions, e.g., when users
+        # provide masks through node features or when
+        # some tasks share the same masks.
+        logging.warning("Some train/val/test mask field "
+                        "names are duplicated, please check: %s. "
+                        "If you provide masks as node/edge features, "
+                        "please ignore this warning. "
+                        "If you share train/val/test mask fields "
+                        "across different tasks, please ignore this warning.",
+                        mask_names)
 
     return label_ops
diff --git a/python/graphstorm/gsf.py b/python/graphstorm/gsf.py
index 9b9d0136e6..6d82102b62 100644
--- a/python/graphstorm/gsf.py
+++ b/python/graphstorm/gsf.py
@@ -32,7 +32,8 @@
                      BUILTIN_TASK_NODE_REGRESSION,
                      BUILTIN_TASK_EDGE_CLASSIFICATION,
                      BUILTIN_TASK_EDGE_REGRESSION,
-                     BUILTIN_TASK_LINK_PREDICTION)
+                     BUILTIN_TASK_LINK_PREDICTION,
+                     BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
 from .config import BUILTIN_LP_DOT_DECODER
 from .config import BUILTIN_LP_DISTMULT_DECODER
 from .config import (BUILTIN_LP_LOSS_CROSS_ENTROPY,
@@ -243,6 +244,41 @@ def create_builtin_node_gnn_model(g, config, train_task):
     """
     return create_builtin_node_model(g, config, train_task)
 
+# pylint: disable=unused-argument
+def create_builtin_reconstruct_nfeat_decoder(g, decoder_input_dim, config, train_task):
+    """ create builtin node feature reconstruction decoder
+        according to task config
+
+    Parameters
+    ----------
+    g: DGLGraph
+        The graph data.
+        Note(xiang): Make it consistent with create_builtin_edge_decoder.
+        Reserved for future.
+    decoder_input_dim: int
+        Input dimension size of the decoder
+    config: GSConfig
+        Configurations
+    train_task : bool
+        Whether this model is used for training.
+
+    Returns
+    -------
+    decoder: The node task decoder(s)
+    loss_func: The loss function(s)
+    """
+    dropout = config.dropout if train_task else 0
+    target_ntype = config.target_ntype
+    reconstruct_feat = config.reconstruct_nfeat_name
+    feat_dim = g.nodes[target_ntype].data[reconstruct_feat].shape[1]
+
+    decoder = EntityRegression(decoder_input_dim,
+                               dropout=dropout,
+                               out_dim=feat_dim)
+
+    loss_func = RegressionLossFunc()
+    return decoder, loss_func
+
 # pylint: disable=unused-argument
 def create_builtin_node_decoder(g, decoder_input_dim, config, train_task):
     """ create builtin node decoder according to task config
@@ -869,5 +905,7 @@ def create_task_decoder(task_info, g, decoder_input_dim, train_task):
         return create_builtin_edge_decoder(g, decoder_input_dim, task_info.task_config, train_task)
     elif task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
         return create_builtin_lp_decoder(g, decoder_input_dim, task_info.task_config, train_task)
+    elif task_info.task_type in [BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]:
+        return create_builtin_reconstruct_nfeat_decoder(g, decoder_input_dim,
+                                                        task_info.task_config, train_task)
     else:
         raise TypeError(f"Unknown task type {task_info.task_type}")
diff --git a/python/graphstorm/model/gnn.py b/python/graphstorm/model/gnn.py
index e6155d0897..8d2d7da8e8 100644
--- a/python/graphstorm/model/gnn.py
+++ b/python/graphstorm/model/gnn.py
@@ -543,6 +543,17 @@ def get_lm_params(self):
 
         return params
 
+    def has_sparse_params(self):
+        """ Return whether there are sparse parameters (learnable embeddings)
+            in the model.
+
+        Return
+        ------
+        bool: True if there are sparse parameters
+        """
+        return len(self._optimizer.sparse_opts) > 0
+
     def get_sparse_params(self):
         """ get the sparse parameters of the model.
diff --git a/python/graphstorm/model/gnn_encoder_base.py b/python/graphstorm/model/gnn_encoder_base.py
index 91b51bead2..a020528b88 100644
--- a/python/graphstorm/model/gnn_encoder_base.py
+++ b/python/graphstorm/model/gnn_encoder_base.py
@@ -19,6 +19,7 @@
 from functools import partial
 import logging
+import abc
 
 import dgl
 import torch as th
 from torch import nn
@@ -28,6 +29,25 @@
 from ..utils import get_rank, barrier, is_distributed, create_dist_tensor, is_wholegraph
 from ..distributed import flush_data
 
+class GSgnnGNNEncoderInterface:
+    """ The interface for built-in GraphStorm GNN encoder layers.
+
+    The interface defines two functions that are useful in multi-task learning.
+    Any GNN encoder that implements these two functions can work with
+    the GraphStorm multi-task learning pipeline.
+
+    Note: We can define more functions when necessary.
+    """
+    @abc.abstractmethod
+    def skip_last_selfloop(self):
+        """ Skip the self-loop of the last GNN layer.
+        """
+
+    @abc.abstractmethod
+    def reset_last_selfloop(self):
+        """ Reset the self-loop setting of the last GNN layer.
+        """
+
 class GraphConvEncoder(GSLayer):     # pylint: disable=abstract-method
     r"""General encoder for graph data.
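The RelationalGCNEncoder and RelationalGATEncoder changes later in this patch implement this interface by toggling the last layer's self_loop flag. As a condensed, hedged sketch of what the interface expects from an encoder (MyEncoder is an illustrative name; it assumes the encoder keeps its layers in self.layers and that each layer exposes a self_loop attribute, as the RGCN/RGAT layers do):

from graphstorm.model.gnn_encoder_base import (GraphConvEncoder,
                                               GSgnnGNNEncoderInterface)

class MyEncoder(GraphConvEncoder, GSgnnGNNEncoderInterface):
    def skip_last_selfloop(self):
        # Cache the current setting, then disable the last layer's
        # self-loop so reconstruct_node_feat cannot copy the target
        # feature through the self-loop.
        self.last_selfloop = self.layers[-1].self_loop
        self.layers[-1].self_loop = False

    def reset_last_selfloop(self):
        # Restore the cached setting after the forward pass.
        self.layers[-1].self_loop = self.last_selfloop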
diff --git a/python/graphstorm/model/hgt_encoder.py b/python/graphstorm/model/hgt_encoder.py
index de203335fb..b544f51561 100644
--- a/python/graphstorm/model/hgt_encoder.py
+++ b/python/graphstorm/model/hgt_encoder.py
@@ -25,7 +25,8 @@
 from dgl.nn.functional import edge_softmax
 from ..config import BUILDIN_GNN_BATCH_NORM, BUILDIN_GNN_LAYER_NORM, BUILTIN_GNN_NORM
 from .ngnn_mlp import NGNNMLP
-from .gnn_encoder_base import GraphConvEncoder
+from .gnn_encoder_base import (GraphConvEncoder,
+                               GSgnnGNNEncoderInterface)
 
 
 class HGTLayer(nn.Module):
@@ -280,7 +281,7 @@ def forward(self, g, h):
 
         return new_h
 
-class HGTEncoder(GraphConvEncoder):
+class HGTEncoder(GraphConvEncoder, GSgnnGNNEncoderInterface):
     r"""Heterogenous graph transformer (HGT) encoder
 
     The HGTEncoder employs several HGTLayers as its encoding mechanism.
@@ -375,6 +376,14 @@ def __init__(self,
                                        dropout=dropout,
                                        norm=norm))
 
+    def skip_last_selfloop(self):
+        # HGT does not have an explicit self-loop
+        pass
+
+    def reset_last_selfloop(self):
+        # HGT does not have an explicit self-loop
+        pass
+
     def forward(self, blocks, h):
         """Forward computation
diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py
index cd8d29bf99..58e28064b6 100644
--- a/python/graphstorm/model/multitask_gnn.py
+++ b/python/graphstorm/model/multitask_gnn.py
@@ -24,9 +24,10 @@
                      BUILTIN_TASK_NODE_REGRESSION,
                      BUILTIN_TASK_EDGE_CLASSIFICATION,
                      BUILTIN_TASK_EDGE_REGRESSION,
-                     BUILTIN_TASK_LINK_PREDICTION)
+                     BUILTIN_TASK_LINK_PREDICTION,
+                     BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
 from .gnn import GSgnnModel
-
+from .gnn_encoder_base import GSgnnGNNEncoderInterface
 from .node_gnn import run_node_mini_batch_predict
 from .edge_gnn import run_edge_mini_batch_predict
 
@@ -91,6 +92,7 @@ def __init__(self, alpha_l2norm):
         self._alpha_l2norm = alpha_l2norm
         self._task_pool = {}
         self._decoder = nn.ModuleDict()
+        self._warn_printed = False
 
     def add_task(self, task_id, task_type,
                  decoder, loss_func):
@@ -162,6 +164,12 @@ def _run_mini_batch(self, task_info, mini_batch):
             loss = self._forward(task_info.task_id,
                                  (blocks, node_feats, edge_feats, input_nodes),
                                  (pos_graph, neg_graph, pos_edge_feats, neg_edge_feats))
+        elif task_info.task_type == BUILTIN_TASK_RECONSTRUCT_NODE_FEAT:
+            # The argument order follows GSgnnNodeModelInterface.forward
+            blocks, input_feats, edge_feats, lbl, input_nodes = mini_batch
+            loss = self._forward(task_info.task_id,
+                                 (blocks, input_feats, edge_feats, input_nodes),
+                                 lbl)
         else:
             raise TypeError(f"Unknown task {task_info}")
 
@@ -212,19 +220,62 @@ def _forward(self, task_id, encoder_data, decoder_data):
         # message passing graph, node features, edge features, seed nodes
         blocks, node_feats, _, input_nodes = encoder_data
+        task_type, loss_func = self.task_pool[task_id]
+        task_decoder = self.decoder[task_id]
+
         if blocks is None or len(blocks) == 0:
             # no GNN message passing
+            if task_type == BUILTIN_TASK_RECONSTRUCT_NODE_FEAT:
+                logging.warning("Reconstructing node features with only "
+                                "an input embedding layer may not work.")
             encode_embs = self.comput_input_embed(input_nodes, node_feats)
         else:
             # GNN message passing
-            encode_embs = self.compute_embed_step(blocks, node_feats, input_nodes)
+            if task_type == BUILTIN_TASK_RECONSTRUCT_NODE_FEAT:
+                if isinstance(self.gnn_encoder, GSgnnGNNEncoderInterface):
+                    if self.has_sparse_params():
+                        # When there are learnable embeddings, we cannot
+                        # simply skip the last layer self-loop, as doing so
+                        # may break the sparse optimizer's backward logic.
+                        # Keep the self-loop and print a warning instead.
+                        encode_embs = self.compute_embed_step(
+                            blocks, node_feats, input_nodes)
+                        if self._warn_printed is False:
+                            logging.warning("When doing %s training, we need to "
+                                            "avoid adding a self-loop in the last "
+                                            "GNN layer to avoid the potential node "
+                                            "feature leakage issue. "
+                                            "When there are learnable embeddings on "
+                                            "nodes, GraphStorm cannot automatically "
+                                            "skip the last layer self-loop. "
+                                            "Please set use_self_loop to False.",
+                                            BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
+                            self._warn_printed = True
+                    else:
+                        # Skip the self-loop of the last layer to
+                        # avoid information leakage.
+                        self.gnn_encoder.skip_last_selfloop()
+                        encode_embs = self.compute_embed_step(
+                            blocks, node_feats, input_nodes)
+                        self.gnn_encoder.reset_last_selfloop()
+                else:
+                    if self._warn_printed is False:
+                        # Only print the warning once to avoid overwhelming the log.
+                        logging.warning("The gnn encoder %s does not support "
+                                        "skipping the last self-loop operation "
+                                        "(skip_last_selfloop). There is a potential "
+                                        "node feature leakage risk when doing %s training.",
+                                        type(self.gnn_encoder),
+                                        BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
+                        self._warn_printed = True
+                    encode_embs = self.compute_embed_step(
+                        blocks, node_feats, input_nodes)
+            else:
+                encode_embs = self.compute_embed_step(blocks, node_feats, input_nodes)
 
         # Call emb normalization.
         encode_embs = self.normalize_node_embs(encode_embs)
 
-        task_type, loss_func = self.task_pool[task_id]
-        task_decoder = self.decoder[task_id]
-
         if task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]:
             labels = decoder_data
             assert len(labels) == 1, \
@@ -261,6 +312,22 @@ def _forward(self, task_id, encoder_data, decoder_data):
                 "Positive scores and Negative scores must have edges of same" \
                 f"edge types, but get {pos_score.keys()} and {neg_score.keys()}"
             pred_loss = loss_func(pos_score, neg_score)
+
+            return pred_loss
+        elif task_type == BUILTIN_TASK_RECONSTRUCT_NODE_FEAT:
+            labels = decoder_data
+            assert len(labels) == 1, \
+                "In multi-task learning, we only support doing prediction " \
+                "on one node type for a single node task."
+            pred_loss = 0
+            target_ntype = list(labels.keys())[0]
+
+            assert target_ntype in encode_embs, f"Node type {target_ntype} not in encode_embs"
+            assert target_ntype in labels, f"Node type {target_ntype} not in labels"
+            emb = encode_embs[target_ntype]
+            ntype_labels = labels[target_ntype]
+            ntype_logits = task_decoder(emb)
+            pred_loss = loss_func(ntype_logits, ntype_labels)
+
+            return pred_loss
         else:
             raise TypeError(f"Unknow task type {task_type}")
 
@@ -307,6 +374,9 @@ def predict(self, task_id, mini_batch, return_proba=False):
         elif task_type == BUILTIN_TASK_LINK_PREDICTION:
             logging.warning("Prediction for link prediction is not implemented")
             return None
+        elif task_type == BUILTIN_TASK_RECONSTRUCT_NODE_FEAT:
+            logging.warning("Prediction for node feature reconstruction is not supported")
+            return None
         else:
             raise TypeError(f"Unknow task type {task_type}")
 
@@ -340,7 +410,9 @@ def multi_task_mini_batch_predict(
     with th.no_grad():
         for dataloader, task_info in zip(dataloaders, task_infos):
             if task_info.task_type in \
-            [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]:
+            [BUILTIN_TASK_NODE_CLASSIFICATION,
+             BUILTIN_TASK_NODE_REGRESSION,
+             BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]:
                 if dataloader is None:
                     # In cases when there is no validation or test set.
# set pred and labels to None diff --git a/python/graphstorm/model/node_decoder.py b/python/graphstorm/model/node_decoder.py index 1d3ecc5f4d..d5405b88fa 100644 --- a/python/graphstorm/model/node_decoder.py +++ b/python/graphstorm/model/node_decoder.py @@ -114,13 +114,16 @@ class EntityRegression(GSLayer): The hidden dimensions dropout : float The dropout + out_dim: int + The output dimension size ''' def __init__(self, h_dim, - dropout=0): + dropout=0, + out_dim=1): super(EntityRegression, self).__init__() self.h_dim = h_dim - self.decoder = nn.Parameter(th.Tensor(h_dim, 1)) + self.decoder = nn.Parameter(th.Tensor(h_dim, out_dim)) nn.init.xavier_uniform_(self.decoder) # TODO(zhengda): The dropout is not used. self.dropout = nn.Dropout(dropout) diff --git a/python/graphstorm/model/rgat_encoder.py b/python/graphstorm/model/rgat_encoder.py index 14c1ecc395..bc05ad1c14 100644 --- a/python/graphstorm/model/rgat_encoder.py +++ b/python/graphstorm/model/rgat_encoder.py @@ -23,7 +23,8 @@ import dgl.nn as dglnn from .ngnn_mlp import NGNNMLP -from .gnn_encoder_base import GraphConvEncoder +from .gnn_encoder_base import (GraphConvEncoder, + GSgnnGNNEncoderInterface) class RelationalAttLayer(nn.Module): @@ -207,7 +208,7 @@ def _apply(ntype, h): return {ntype : _apply(ntype, h) for ntype, h in hs.items()} -class RelationalGATEncoder(GraphConvEncoder): +class RelationalGATEncoder(GraphConvEncoder, GSgnnGNNEncoderInterface): r"""Relational graph attention encoder The RelationalGATEncoder employs several RelationalAttLayers as its encoding mechanism. @@ -293,6 +294,13 @@ def __init__(self, self.num_heads, activation=F.relu if last_layer_act else None, self_loop=use_self_loop, norm=norm if last_layer_act else None)) + def skip_last_selfloop(self): + self.last_selfloop = self.layers[-1].self_loop + self.layers[-1].self_loop = False + + def reset_last_selfloop(self): + self.layers[-1].self_loop = self.last_selfloop + def forward(self, blocks, h): """Forward computation diff --git a/python/graphstorm/model/rgcn_encoder.py b/python/graphstorm/model/rgcn_encoder.py index a708dfe289..040658670b 100644 --- a/python/graphstorm/model/rgcn_encoder.py +++ b/python/graphstorm/model/rgcn_encoder.py @@ -24,7 +24,8 @@ from dgl.nn.pytorch.hetero import get_aggregate_fn from .ngnn_mlp import NGNNMLP -from .gnn_encoder_base import GraphConvEncoder +from .gnn_encoder_base import (GraphConvEncoder, + GSgnnGNNEncoderInterface) class RelGraphConvLayer(nn.Module): @@ -254,7 +255,7 @@ def _apply(ntype, h): return {ntype : _apply(ntype, h) for ntype, h in hs.items()} -class RelationalGCNEncoder(GraphConvEncoder): +class RelationalGCNEncoder(GraphConvEncoder, GSgnnGNNEncoderInterface): r""" Relational graph conv encoder. The RelationalGCNEncoder employs several RelGraphConvLayer as its encoding mechanism. @@ -345,6 +346,13 @@ def __init__(self, self.num_bases, activation=F.relu if last_layer_act else None, self_loop=use_self_loop, norm=norm if last_layer_act else None)) + def skip_last_selfloop(self): + self.last_selfloop = self.layers[-1].self_loop + self.layers[-1].self_loop = False + + def reset_last_selfloop(self): + self.layers[-1].self_loop = self.last_selfloop + # TODO(zhengda) refactor this to support edge features. 
     def forward(self, blocks, h):
         """Forward computation
diff --git a/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py b/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py
index 98735739f4..3d2e8938d5 100644
--- a/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py
+++ b/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py
@@ -20,13 +20,18 @@
 from graphstorm.config import GSConfig
 from graphstorm.utils import rt_profiler, sys_tracker, get_device, use_wholegraph
 from graphstorm.dataloading import GSgnnData
-from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION,
-                               BUILTIN_TASK_NODE_REGRESSION,
-                               BUILTIN_TASK_EDGE_CLASSIFICATION,
-                               BUILTIN_TASK_EDGE_REGRESSION,
-                               BUILTIN_TASK_LINK_PREDICTION)
+from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION,
+                               BUILTIN_TASK_NODE_REGRESSION,
+                               BUILTIN_TASK_EDGE_CLASSIFICATION,
+                               BUILTIN_TASK_EDGE_REGRESSION,
+                               BUILTIN_TASK_LINK_PREDICTION,
+                               GRAPHSTORM_MODEL_ALL_LAYERS,
+                               GRAPHSTORM_MODEL_EMBED_LAYER,
+                               GRAPHSTORM_MODEL_GNN_LAYER,
+                               GRAPHSTORM_MODEL_DECODER_LAYER)
 from graphstorm.inference import GSgnnEmbGenInferer
 from graphstorm.utils import get_lm_ntypes
+from graphstorm.model.multitask_gnn import GSgnnMultiTaskSharedEncoderModel
 
 def main(config_args):
     """ main function
@@ -44,12 +49,14 @@ def main(config_args):
     if gs.get_rank() == 0:
         tracker.log_params(config.__dict__)
 
-    assert config.task_type in [BUILTIN_TASK_LINK_PREDICTION,
-                                BUILTIN_TASK_NODE_REGRESSION,
-                                BUILTIN_TASK_NODE_CLASSIFICATION,
-                                BUILTIN_TASK_EDGE_CLASSIFICATION,
-                                BUILTIN_TASK_EDGE_REGRESSION], \
-        f"Not supported for task type: {config.task_type}"
+    if config.multi_tasks is None:
+        # If not multi-task, check the task type.
+        assert config.task_type in [BUILTIN_TASK_LINK_PREDICTION,
+                                    BUILTIN_TASK_NODE_REGRESSION,
+                                    BUILTIN_TASK_NODE_CLASSIFICATION,
+                                    BUILTIN_TASK_EDGE_CLASSIFICATION,
+                                    BUILTIN_TASK_EDGE_REGRESSION], \
+            f"Not supported for task type: {config.task_type}"
 
     input_data = GSgnnData(config.part_config,
                            node_feat_field=config.node_feat_name,
@@ -63,14 +70,25 @@ def main(config_args):
         "restore model path cannot be none for gs_gen_node_embeddings"
 
     # load the model
-    if config.task_type == BUILTIN_TASK_LINK_PREDICTION:
-        model = gs.create_builtin_lp_gnn_model(input_data.g, config, train_task=False)
-    elif config.task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
-        model = gs.create_builtin_node_gnn_model(input_data.g, config, train_task=False)
-    elif config.task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
-        model = gs.create_builtin_edge_gnn_model(input_data.g, config, train_task=False)
+    if config.multi_tasks:
+        # Only support multi-task shared encoder model.
+        model = GSgnnMultiTaskSharedEncoderModel(config.alpha_l2norm)
+        gs.gsf.set_encoder(model, input_data.g, config, train_task=False)
+        assert config.restore_model_layers != GRAPHSTORM_MODEL_ALL_LAYERS, \
+            "When computing node embeddings with GSgnnMultiTaskSharedEncoderModel, " \
+            "please set --restore-model-layers to " \
+            f"{GRAPHSTORM_MODEL_EMBED_LAYER}, {GRAPHSTORM_MODEL_GNN_LAYER}."
\ + f"Please do not include {GRAPHSTORM_MODEL_DECODER_LAYER}, " \ + f"but we get {config.restore_model_layers}" else: - raise TypeError("Not supported for task type: ", config.task_type) + if config.task_type == BUILTIN_TASK_LINK_PREDICTION: + model = gs.create_builtin_lp_gnn_model(input_data.g, config, train_task=False) + elif config.task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}: + model = gs.create_builtin_node_gnn_model(input_data.g, config, train_task=False) + elif config.task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}: + model = gs.create_builtin_edge_gnn_model(input_data.g, config, train_task=False) + else: + raise TypeError("Not supported for task type: ", config.task_type) model.restore_model(config.restore_model_path, model_layer_to_load=config.restore_model_layers) @@ -78,21 +96,24 @@ def main(config_args): emb_generator = GSgnnEmbGenInferer(model) emb_generator.setup_device(device=get_device()) - task_type = config.task_type - # infer ntypes must be sorted for node embedding saving - if task_type == BUILTIN_TASK_LINK_PREDICTION: + if config.multi_tasks: + # infer_ntypes = None means all node types. infer_ntypes = None - elif task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}: - # TODO(xiangsx): Support multi-task on multiple node types. - infer_ntypes = [config.target_ntype] - elif task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}: - infer_ntypes = set() - for etype in config.target_etype: - infer_ntypes.add(etype[0]) - infer_ntypes.add(etype[2]) - infer_ntypes = sorted(list(infer_ntypes)) else: - raise TypeError("Not supported for task type: ", task_type) + task_type = config.task_type + # infer ntypes must be sorted for node embedding saving + if task_type == BUILTIN_TASK_LINK_PREDICTION: + infer_ntypes = None + elif task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}: + infer_ntypes = [config.target_ntype] + elif task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}: + infer_ntypes = set() + for etype in config.target_etype: + infer_ntypes.add(etype[0]) + infer_ntypes.add(etype[2]) + infer_ntypes = sorted(list(infer_ntypes)) + else: + raise TypeError("Not supported for task type: ", task_type) emb_generator.infer(input_data, infer_ntypes, save_embed_path=config.save_embed_path, diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index 958ef1b85b..0d8c6d5c6d 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -25,13 +25,15 @@ BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION, - BUILTIN_TASK_LINK_PREDICTION) + BUILTIN_TASK_LINK_PREDICTION, + BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) from graphstorm.dataloading import GSgnnData from graphstorm.dataloading import (GSgnnNodeDataLoader, GSgnnEdgeDataLoader, GSgnnMultiTaskDataLoader) from graphstorm.eval import (GSgnnClassificationEvaluator, GSgnnRegressionEvaluator, + GSgnnRconstructFeatRegScoreEvaluator, GSgnnPerEtypeMrrLPEvaluator, GSgnnMrrLPEvaluator, GSgnnMultiTaskEvaluator) @@ -113,6 +115,18 @@ def create_task_train_dataloader(task, config, train_data): exclude_training_targets=task_config.exclude_training_targets, edge_dst_negative_field=task_config.train_etypes_negative_dstnode, num_hard_negs=task_config.num_train_hard_negatives) + elif task.task_type in [BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]: + train_idxs = 
train_data.get_node_train_set( + task_config.target_ntype, + mask=task_config.train_mask) + # TODO(xiangsx): Support construct feat + return GSgnnNodeDataLoader(train_data, + train_idxs, + fanout=fanout, + batch_size=task_config.batch_size, + train_task=True, + node_feats=node_feats, + label_field=task_config.reconstruct_nfeat_name) return None @@ -191,6 +205,22 @@ def create_task_val_dataloader(task, config, train_data): fixed_test_size=task_config.fixed_test_size, node_feats=node_feats, pos_graph_edge_feats=task_config.lp_edge_weight_for_loss) + elif task.task_type in [BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]: + eval_ntype = task_config.eval_target_ntype \ + if task_config.eval_target_ntype is not None \ + else task_config.target_ntype + val_idxs = train_data.get_node_val_set(eval_ntype, mask=task_config.val_mask) + # All tasks share the same GNN model, so the fanout should be the global fanout + fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] + if len(val_idxs) > 0: + # TODO(xiangsx): Support construct feat + return GSgnnNodeDataLoader(train_data, + val_idxs, + fanout=fanout, + batch_size=task_config.eval_batch_size, + train_task=False, + node_feats=node_feats, + label_field=task_config.reconstruct_nfeat_name) return None @@ -274,6 +304,22 @@ def create_task_test_dataloader(task, config, train_data): fixed_test_size=task_config.fixed_test_size, node_feats=node_feats, pos_graph_edge_feats=task_config.lp_edge_weight_for_loss) + elif task.task_type in [BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]: + eval_ntype = task_config.eval_target_ntype \ + if task_config.eval_target_ntype is not None \ + else task_config.target_ntype + test_idxs = train_data.get_node_test_set(eval_ntype, mask=task_config.test_mask) + # All tasks share the same GNN model, so the fanout should be the global fanout + fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] + if len(test_idxs) > 0: + # TODO(xiangsx): Support construct feat + return GSgnnNodeDataLoader(train_data, + test_idxs, + fanout=fanout, + batch_size=task_config.eval_batch_size, + train_task=False, + node_feats=node_feats, + label_field=task_config.reconstruct_nfeat_name) return None def create_evaluator(task): @@ -340,6 +386,14 @@ def create_evaluator(task): early_stop_burnin_rounds=config.early_stop_burnin_rounds, early_stop_rounds=config.early_stop_rounds, early_stop_strategy=config.early_stop_strategy) + elif task.task_type in [BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]: + return GSgnnRconstructFeatRegScoreEvaluator( + config.eval_frequency, + config.eval_metric, + config.use_early_stop, + config.early_stop_burnin_rounds, + config.early_stop_rounds, + config.early_stop_strategy) return None def main(config_args): @@ -362,6 +416,8 @@ def main(config_args): gs.gsf.set_encoder(model, train_data.g, config, train_task=True) tasks = config.multi_tasks + assert tasks is not None, \ + "The multi_task_learning configure block should not be empty." 
     train_dataloaders = []
     val_dataloaders = []
     test_dataloaders = []
diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py
index b4a3df7354..e6244fd38e 100644
--- a/python/graphstorm/trainer/mt_trainer.py
+++ b/python/graphstorm/trainer/mt_trainer.py
@@ -27,7 +27,8 @@
                      BUILTIN_TASK_NODE_REGRESSION,
                      BUILTIN_TASK_EDGE_CLASSIFICATION,
                      BUILTIN_TASK_EDGE_REGRESSION,
-                     BUILTIN_TASK_LINK_PREDICTION)
+                     BUILTIN_TASK_LINK_PREDICTION,
+                     BUILTIN_TASK_RECONSTRUCT_NODE_FEAT)
 from ..model import (do_full_graph_inference,
                      do_mini_batch_inference,
                      GSgnnModelBase, GSgnnModel,
@@ -195,6 +196,33 @@ def prepare_link_predict_mini_batch(data, task_info, mini_batch, device):
     return (blocks, pos_graph, neg_graph, node_feats, None, \
         pos_graph_feats, None, input_nodes)
 
+def prepare_reconstruct_node_feat(data, task_info, mini_batch, device):
+    """ Prepare mini-batch for node feature reconstruction.
+
+    The input is a mini-batch sampled by a node sampler.
+    The output is a prepared input following the
+    input arguments of GSgnnNodeModelInterface.forward.
+
+    Parameters
+    ----------
+    data: GSgnnData
+        Graph data
+    task_info: TaskInfo
+        Task meta information
+    mini_batch: tuple
+        Mini-batch info
+    device: torch.device
+        Device
+
+    Return
+    ------
+    tuple: mini-batch
+    """
+    # Same as preparing node regression data.
+    # Note: we may add some augmentation in the future,
+    # so keep a separate prepare function for node feature reconstruction.
+    return prepare_node_mini_batch(data, task_info, mini_batch, device)
+
 class GSgnnMultiTaskLearningTrainer(GSgnnTrainer):
     r""" A trainer for multi-task learning
 
@@ -258,6 +286,11 @@ def _prepare_mini_batch(self, data, task_info, mini_batch, device):
                                                    task_info,
                                                    mini_batch,
                                                    device)
+        elif task_info.task_type == BUILTIN_TASK_RECONSTRUCT_NODE_FEAT:
+            return prepare_reconstruct_node_feat(data,
+                                                 task_info,
+                                                 mini_batch,
+                                                 device)
         else:
             raise TypeError(f"Unknown task {task_info}", )
 
diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh
index 345048e544..b53b9948cc 100644
--- a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh
+++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh
@@ -146,6 +146,34 @@ then
     exit -1
 fi
 
+bst_cnt=$(grep "Best Test reconstruct_node_feat" /tmp/train_log.txt | wc -l)
+if test $bst_cnt -lt 1
+then
+    echo "We use SageMaker task tracker, we should have Best Test reconstruct_node_feat"
+    exit -1
+fi
+
+cnt=$(grep "Test reconstruct_node_feat" /tmp/train_log.txt | wc -l)
+if test $cnt -lt $((1+$bst_cnt))
+then
+    echo "We use SageMaker task tracker, we should have Test reconstruct_node_feat"
+    exit -1
+fi
+
+bst_cnt=$(grep "Best Validation reconstruct_node_feat" /tmp/train_log.txt | wc -l)
+if test $bst_cnt -lt 1
+then
+    echo "We use SageMaker task tracker, we should have Best Validation reconstruct_node_feat"
+    exit -1
+fi
+
+cnt=$(grep "Validation reconstruct_node_feat" /tmp/train_log.txt | wc -l)
+if test $cnt -lt $((1+$bst_cnt))
+then
+    echo "We use SageMaker task tracker, we should have Validation reconstruct_node_feat"
+    exit -1
+fi
+
 cnt=$(ls -l /data/gsgnn_mt/ | grep epoch | wc -l)
 if test $cnt != 3
 then
@@ -153,7 +181,167 @@ then
     exit -1
 fi
 
-echo "**************[Multi-task with learnable embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model"
+rm -fr /data/gsgnn_mt/
+rm /tmp/train_log.txt
+
+echo "**************[Multi-task] dataset: Movielens, RGAT layer 2, node feat: fixed HF BERT, BERT nodes: movie, 
inference: full-graph, save model" +python3 -m graphstorm.run.gs_multi_task_learning --workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --num-layers 2 --fanout "4,4" --model-encoder-type rgat + +error_and_exit $? + +# check prints + +bst_cnt=$(grep "Best Test node_classification" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Test node_classification" + exit -1 +fi + +cnt=$(grep "Test node_classification" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Test node_classification" + exit -1 +fi + +bst_cnt=$(grep "Best Validation node_classification" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation accuracy node_classification" + exit -1 +fi + +cnt=$(grep "Validation node_classification" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Validation node_classification" + exit -1 +fi + +bst_cnt=$(grep "Best Test edge_classification" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Test edge_classification" + exit -1 +fi + +cnt=$(grep "Test edge_classification" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Test edge_classification" + exit -1 +fi + +bst_cnt=$(grep "Best Validation edge_classification" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation edge_classification" + exit -1 +fi + +cnt=$(grep "Validation edge_classification" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Validation edge_classification" + exit -1 +fi + +bst_cnt=$(grep "Best Test edge_regression" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Test edge_regression" + exit -1 +fi + +cnt=$(grep "Test edge_regression" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Test edge_regression" + exit -1 +fi + +bst_cnt=$(grep "Best Validation edge_regression" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation edge_regression" + exit -1 +fi + +cnt=$(grep "Validation edge_regression" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Validation edge_regression" + exit -1 +fi + +bst_cnt=$(grep "Best Test link_prediction" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Test link_prediction" + exit -1 +fi + +cnt=$(grep "Test link_prediction" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Test link_prediction" + exit -1 +fi + +bst_cnt=$(grep "Best Validation link_prediction" /tmp/train_log.txt | wc 
-l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation link_prediction" + exit -1 +fi + +cnt=$(grep "Validation link_prediction" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Validation link_prediction" + exit -1 +fi + +bst_cnt=$(grep "Best Test reconstruct_node_feat" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Test reconstruct_node_feat" + exit -1 +fi + +cnt=$(grep "Test reconstruct_node_feat" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Test reconstruct_node_feat" + exit -1 +fi + +bst_cnt=$(grep "Best Validation reconstruct_node_feat" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation reconstruct_node_feat" + exit -1 +fi + +cnt=$(grep "Validation reconstruct_node_feat" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Validation reconstruct_node_feat" + exit -1 +fi + +cnt=$(ls -l /data/gsgnn_mt/ | grep epoch | wc -l) +if test $cnt != 3 +then + echo "The number of save models $cnt is not equal to the specified topk 3" + exit -1 +fi + +rm -fr /data/gsgnn_mt/ +rm /tmp/train_log.txt + +echo "**************[Multi-task with learnable embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, with learnable node embedding, inference: full-graph, save model" python3 -m graphstorm.run.gs_multi_task_learning --workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-node-embeddings True error_and_exit $? @@ -192,8 +380,20 @@ then exit -1 fi +echo "**************[Multi-task gen embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, load from saved model" +python3 -m graphstorm.run.gs_gen_node_embedding --workspace $GS_HOME/training_scripts/gsgnn_mt/ --num-trainers $NUM_TRAINERS --use-mini-batch-infer false --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-embed-path /data/gsgnn_mt/save-emb/ --restore-model-path /data/gsgnn_mt/epoch-2/ --restore-model-layers embed,gnn --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True -rm -fr /data/gsgnn_mt/infer-emb/ +error_and_exit $? + +cnt=$(ls -l /data/gsgnn_mt/save-emb/ | wc -l) +cnt=$[cnt - 1] +if test $cnt != 2 +then + echo "The number of saved embs $cnt is not equal to 2 (for movie and user)." +fi + +# Multi-task will save node embeddings of all the nodes. 
+python3 $GS_HOME/tests/end2end-tests/check_infer.py --train-embout /data/gsgnn_mt/emb/ --infer-embout /data/gsgnn_mt/save-emb/ --link-prediction echo "**************[Multi-task] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference only" python3 -m graphstorm.run.gs_multi_task_learning --inference --workspace $GS_HOME/inference_scripts/mt_infer --num-trainers $NUM_INFERs --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp_only_infer.yaml --use-mini-batch-infer false --save-embed-path /data/gsgnn_mt/infer-emb/ --restore-model-path /data/gsgnn_mt/epoch-2 --save-prediction-path /data/gsgnn_mt/prediction/ --logging-file /tmp/log.txt --preserve-input True --backend nccl diff --git a/tests/unit-tests/test_config.py b/tests/unit-tests/test_config.py index cdbbbfb2da..1421393e6f 100644 --- a/tests/unit-tests/test_config.py +++ b/tests/unit-tests/test_config.py @@ -34,7 +34,8 @@ BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION, - BUILTIN_TASK_LINK_PREDICTION) + BUILTIN_TASK_LINK_PREDICTION, + BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) from graphstorm.config.config import GRAPHSTORM_LP_EMB_L2_NORMALIZATION from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER @@ -1671,6 +1672,22 @@ def create_dummy_lp_config2(): "exclude_training_targets": False } +def create_dummy_nfr_config(): + return { + "target_ntype": "a", + "reconstruct_nfeat_name": "rfeat", + "task_weight": 0.5, + "mask_fields": ["nfr_train_mask", "nfr_eval_mask", "nfr_test_mask"] + } + +def create_dummy_nfr_config2(): + return { + "target_ntype": "a", + "reconstruct_nfeat_name": "rfeat", + "mask_fields": ["nfr_train_mask", "nfr_eval_mask", "nfr_test_mask"], + "eval_metric": "rmse" + } + def create_multi_task_config(tmp_path, file_name): yaml_object = create_dummpy_config_obj() yaml_object["gsf"]["basic"] = { @@ -1699,6 +1716,12 @@ def create_multi_task_config(tmp_path, file_name): }, { BUILTIN_TASK_LINK_PREDICTION : create_dummy_lp_config2() + }, + { + BUILTIN_TASK_RECONSTRUCT_NODE_FEAT: create_dummy_nfr_config() + }, + { + BUILTIN_TASK_RECONSTRUCT_NODE_FEAT: create_dummy_nfr_config2() } ] @@ -1712,7 +1735,7 @@ def test_multi_task_config(): args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'multi_task_test_default.yaml'), local_rank=0) config = GSConfig(args) - assert len(config.multi_tasks) == 6 + assert len(config.multi_tasks) == 8 nc_config = config.multi_tasks[0] assert nc_config.task_type == BUILTIN_TASK_NODE_CLASSIFICATION assert nc_config.task_id == f"{BUILTIN_TASK_NODE_CLASSIFICATION}-a-label_c" @@ -1815,7 +1838,6 @@ def test_multi_task_config(): assert lp_config.eval_metric[0] == "mrr" assert lp_config.lp_edge_weight_for_loss == "weight" - lp_config = config.multi_tasks[5] assert lp_config.task_type == BUILTIN_TASK_LINK_PREDICTION assert lp_config.task_id == f"{BUILTIN_TASK_LINK_PREDICTION}-ALL_ETYPE" @@ -1839,6 +1861,34 @@ def test_multi_task_config(): assert config.lp_edge_weight_for_loss == None assert config.model_select_etype == LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL + nfr_config = config.multi_tasks[6] + assert nfr_config.task_type == BUILTIN_TASK_RECONSTRUCT_NODE_FEAT + assert nfr_config.task_id == f"{BUILTIN_TASK_RECONSTRUCT_NODE_FEAT}-a-rfeat" + nfr_config = nfr_config.task_config + assert nfr_config.task_weight == 0.5 + assert 
nfr_config.train_mask == "nfr_train_mask"
+    assert nfr_config.val_mask == "nfr_eval_mask"
+    assert nfr_config.test_mask == "nfr_test_mask"
+    assert nfr_config.target_ntype == "a"
+    assert nfr_config.reconstruct_nfeat_name == "rfeat"
+    assert len(nfr_config.eval_metric) == 1
+    assert nfr_config.eval_metric[0] == "mse"
+    assert nfr_config.batch_size == 64
+
+    nfr_config = config.multi_tasks[7]
+    assert nfr_config.task_type == BUILTIN_TASK_RECONSTRUCT_NODE_FEAT
+    assert nfr_config.task_id == f"{BUILTIN_TASK_RECONSTRUCT_NODE_FEAT}-a-rfeat"
+    nfr_config = nfr_config.task_config
+    assert nfr_config.task_weight == 1.0
+    assert nfr_config.train_mask == "nfr_train_mask"
+    assert nfr_config.val_mask == "nfr_eval_mask"
+    assert nfr_config.test_mask == "nfr_test_mask"
+    assert nfr_config.target_ntype == "a"
+    assert nfr_config.reconstruct_nfeat_name == "rfeat"
+    assert len(nfr_config.eval_metric) == 1
+    assert nfr_config.eval_metric[0] == "rmse"
+    assert nfr_config.batch_size == 64
+
 if __name__ == '__main__':
     test_multi_task_config()
     test_id_mapping_file()
diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py
index d0225802cc..ca2e8b7ee8 100644
--- a/tests/unit-tests/test_gnn.py
+++ b/tests/unit-tests/test_gnn.py
@@ -57,11 +57,14 @@
                                    LinkPredictWeightedDotDecoder,
                                    LinkPredictWeightedDistMultDecoder)
 from graphstorm.model.node_decoder import EntityRegression, EntityClassifier
+from graphstorm.model.loss_func import RegressionLossFunc
 from graphstorm.dataloading import GSgnnData
 from graphstorm.dataloading import GSgnnNodeDataLoader, GSgnnEdgeDataLoader, GSgnnMultiTaskDataLoader
 from graphstorm.dataloading.dataset import prepare_batch_input
-from graphstorm import create_builtin_edge_gnn_model, create_builtin_node_gnn_model
-from graphstorm import create_builtin_lp_gnn_model
+from graphstorm import (create_builtin_edge_gnn_model,
+                        create_builtin_node_gnn_model,
+                        create_builtin_lp_gnn_model,
+                        create_builtin_reconstruct_nfeat_decoder)
 from graphstorm import get_node_feat_size
 from graphstorm.gsf import get_rel_names_for_reconstruct
 from graphstorm.model import do_full_graph_inference, do_mini_batch_inference
@@ -1398,6 +1401,32 @@ def test_node_regression():
     assert model.gnn_encoder.out_dims == 4
     assert isinstance(model.gnn_encoder, RelationalGATEncoder)
     assert isinstance(model.decoder, EntityRegression)
+    # By default, the regression decoder outputs a single float.
+    assert model.decoder.decoder.shape[1] == 1
     th.distributed.destroy_process_group()
     dgl.distributed.kvstore.close_kvstore()
 
+def test_node_feat_reconstruct():
+    """ Test the logic of building a node feature reconstruction decoder
+    """
+    # initialize the torch distributed environment
+    th.distributed.init_process_group(backend='gloo',
+                                      init_method='tcp://127.0.0.1:23456',
+                                      rank=0,
+                                      world_size=1)
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        # get the test dummy distributed graph
+        g, _ = generate_dummy_dist_graph(tmpdirname)
+        create_nr_config(Path(tmpdirname), 'gnn_nr.yaml')
+        args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'gnn_nr.yaml'),
+                         local_rank=0)
+        config = GSConfig(args)
+        setattr(config, "_reconstruct_nfeat_name", "feat")
+        decoder, loss_func = create_builtin_reconstruct_nfeat_decoder(
+            g, decoder_input_dim=32, config=config, train_task=True)
+        assert isinstance(decoder, EntityRegression)
+        assert decoder.decoder.shape[1] == 2
+        assert isinstance(loss_func, RegressionLossFunc)
     th.distributed.destroy_process_group()
     dgl.distributed.kvstore.close_kvstore()
 
@@ -2271,6 +2300,8 @@ def 
check_forward(mock_run_lp_mini_batch_predict, if __name__ == '__main__': + test_node_feat_reconstruct() + test_multi_task_forward() test_multi_task_predict() test_multi_task_mini_batch_predict() diff --git a/tests/unit-tests/test_trainer.py b/tests/unit-tests/test_trainer.py index d6191c5fb0..598ab44269 100644 --- a/tests/unit-tests/test_trainer.py +++ b/tests/unit-tests/test_trainer.py @@ -25,7 +25,8 @@ from graphstorm.config import (GSConfig, TaskInfo) from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION, - BUILTIN_TASK_LINK_PREDICTION) + BUILTIN_TASK_LINK_PREDICTION, + BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) from graphstorm.dataloading import GSgnnData from graphstorm.tracker import GSSageMakerTaskTracker from graphstorm import create_builtin_node_gnn_model @@ -34,7 +35,8 @@ from graphstorm.utils import setup_device, get_device from graphstorm.trainer.mt_trainer import (prepare_node_mini_batch, prepare_edge_mini_batch, - prepare_link_predict_mini_batch) + prepare_link_predict_mini_batch, + prepare_reconstruct_node_feat) from graphstorm.dataloading import (GSgnnNodeDataLoader, GSgnnEdgeDataLoader, GSgnnLinkPredictionDataLoader) @@ -195,6 +197,56 @@ def forward(self, task_id, mini_batch): def predict(self, task_id, mini_batch, return_proba=False): pass +def test_mtask_prepare_reconstruct_node_feat(): + with tempfile.TemporaryDirectory() as tmpdirname: + # get the test dummy distributed graph + _, part_config = generate_dummy_dist_graph(graph_name='dummy', dirname=tmpdirname) + np_data = GSgnnData(part_config=part_config) + + setup_device(0) + device = get_device() + # Without shuffling, the seed nodes should have the same order as the target nodes. + target_idx = {'n1': th.arange(np_data.g.number_of_nodes('n1'))} + task_id = "test_node_feat_reconstruct" + + # label is same as node feat + dataloader = GSgnnNodeDataLoader(np_data, target_idx, [10], 10, + label_field='feat', + node_feats='feat', + train_task=False) + task_config = GSConfig.__new__(GSConfig) + setattr(task_config, "task_weight", 0.75) + task_info = TaskInfo(task_type=BUILTIN_TASK_RECONSTRUCT_NODE_FEAT, + task_id=task_id, + task_config=task_config, + dataloader=dataloader) + node_feats = np_data.get_node_feats(target_idx, 'feat') + labels = np_data.get_node_feats(target_idx, 'feat') + mini_batch = (target_idx, target_idx, None) + + blocks, input_feats, _, lbl, input_nodes = \ + prepare_reconstruct_node_feat(np_data, task_info, mini_batch, device) + assert blocks is None + assert_equal(input_nodes["n1"].numpy(), target_idx["n1"].numpy()) + assert_equal(input_feats["n1"].cpu().numpy(), node_feats["n1"].numpy()) + assert_equal(lbl["n1"].cpu().numpy(), labels["n1"].numpy()) + assert_equal(node_feats["n1"].cpu().numpy(), lbl["n1"].cpu().numpy()) + + # there is no node feat + dataloader = GSgnnNodeDataLoader(np_data, target_idx, [10], 10, + label_field='feat', + train_task=False) + task_info = TaskInfo(task_type=BUILTIN_TASK_RECONSTRUCT_NODE_FEAT, + task_id=task_id, + task_config=task_config, + dataloader=dataloader) + _, input_feats, _, lbl, input_nodes = \ + prepare_reconstruct_node_feat(np_data, task_info, mini_batch, device) + assert_equal(input_nodes["n1"].numpy(), target_idx["n1"].numpy()) + assert len(input_feats) == 0 + assert_equal(lbl["n1"].cpu().numpy(), labels["n1"].numpy()) + assert_equal(node_feats["n1"].cpu().numpy(), lbl["n1"].cpu().numpy()) + def test_mtask_prepare_node_mini_batch(): with tempfile.TemporaryDirectory() as tmpdirname: # get the test dummy distributed graph @@ 
-371,3 +423,4 @@ def test_mtask_prepare_lp_mini_batch(): test_mtask_prepare_node_mini_batch() test_mtask_prepare_edge_mini_batch() test_mtask_prepare_lp_mini_batch() + test_mtask_prepare_reconstruct_node_feat() diff --git a/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml index fcf208c0f8..169d1a4758 100644 --- a/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml +++ b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml @@ -100,4 +100,15 @@ gsf: - "train_mask_field_lp" - null # empty means there is no validation mask - null # empty means there is no test mask - task_weight: 1.0 \ No newline at end of file + task_weight: 1.0 + - reconstruct_node_feat: + reconstruct_nfeat_name: "title" + target_ntype: "movie" + batch_size: 128 + mask_fields: + - "train_mask_c0" # node classification mask 0 + - "val_mask_c0" + - "test_mask_c0" + task_weight: 1.0 + eval_metric: + - "mse" \ No newline at end of file
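To close the loop on the YAML block above: a short sketch of how the new reconstruct_node_feat entry surfaces through GSConfig, following the pattern used in test_config.py in this patch; the YAML path is illustrative:

from argparse import Namespace
from graphstorm.config import GSConfig, BUILTIN_TASK_RECONSTRUCT_NODE_FEAT

# Load a multi-task YAML such as ml_nc_ec_er_lp.yaml (path is illustrative).
args = Namespace(yaml_config_file="ml_nc_ec_er_lp.yaml", local_rank=0)
config = GSConfig(args)

# Each task block becomes a TaskInfo; the reconstruction task carries
# its target node type and the name of the feature to reconstruct.
for task in config.multi_tasks:
    if task.task_type == BUILTIN_TASK_RECONSTRUCT_NODE_FEAT:
        print(task.task_id)  # e.g. reconstruct_node_feat-movie-title
        print(task.task_config.target_ntype,
              task.task_config.reconstruct_nfeat_name)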