From a59b0387c946d4a2209f2a44720a02f10953e9f7 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 2 May 2024 22:29:22 -0700 Subject: [PATCH 01/79] Drafting multi-task learning --- python/graphstorm/config/__init__.py | 2 + python/graphstorm/config/argument.py | 52 ++ python/graphstorm/config/config.py | 26 + python/graphstorm/dataloading/dataloading.py | 101 ++++ python/graphstorm/model/__init__.py | 11 +- python/graphstorm/model/edge_gnn.py | 39 +- python/graphstorm/model/lp_gnn.py | 25 + python/graphstorm/model/multitask_gnn.py | 220 +++++++++ python/graphstorm/model/node_gnn.py | 50 +- python/graphstorm/run/gsgnn_mt/__init__.py | 0 python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 88 ++++ python/graphstorm/trainer/mt_trainer.py | 491 +++++++++++++++++++ training_scripts/gsgnn_mt/README.md | 18 + training_scripts/gsgnn_mt/ml_ncr_lp.json | 84 ++++ training_scripts/gsgnn_mt/ml_ncr_lp_yaml | 74 +++ 15 files changed, 1269 insertions(+), 12 deletions(-) create mode 100644 python/graphstorm/model/multitask_gnn.py create mode 100644 python/graphstorm/run/gsgnn_mt/__init__.py create mode 100644 python/graphstorm/run/gsgnn_mt/gsgnn_mt.py create mode 100644 python/graphstorm/trainer/mt_trainer.py create mode 100644 training_scripts/gsgnn_mt/README.md create mode 100644 training_scripts/gsgnn_mt/ml_ncr_lp.json create mode 100644 training_scripts/gsgnn_mt/ml_ncr_lp_yaml diff --git a/python/graphstorm/config/__init__.py b/python/graphstorm/config/__init__.py index ae2b1831cc..33d9539b68 100644 --- a/python/graphstorm/config/__init__.py +++ b/python/graphstorm/config/__init__.py @@ -46,3 +46,5 @@ BUILTIN_LP_LOSS_CONTRASTIVELOSS) from .config import (GRAPHSTORM_LP_EMB_L2_NORMALIZATION, GRAPHSTORM_LP_EMB_NORMALIZATION_METHODS) + +from .config import TaskInfo \ No newline at end of file diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 7955815d39..37c5474f98 100644 --- a/python/graphstorm/config/argument.py +++ 
b/python/graphstorm/config/argument.py @@ -146,6 +146,12 @@ def __init__(self, cmd_args): self.yaml_paths = cmd_args.yaml_config_file # Load all arguments from yaml config configuration = self.load_yaml_config(cmd_args.yaml_config_file) + + if 'multi_task_learning' in configuration: + # parse multi task learning config and save it into self._multi_tasks + self._parse_multi_tasks(configuration['multi_task_learning']) + del configuration['multi_task_learning'] + self.set_attributes(configuration) # Override class attributes using command-line arguments self.override_arguments(cmd_args) @@ -219,6 +225,38 @@ def set_attributes(self, configuration): for key, val in udf_family.items(): setattr(self, key, val) + def _parse_multi_tasks(self, multi_task_config): + """ Parse multi-task configuration + """ + assert len(multi_task_config) > 1, \ + "There must be at least two tasks" + + tasks = [] + for task_config in multi_task_config: + assert isinstance(task_config, dict) and len(task_config) == 1, \ + "When defining multiple tasks for " \ + "training, define one task each time." 
+ if "node_classification" in task_config: + task = self._parse_node_classification_task( + task_config["node_classification"]) + elif "node_regression" in task_config: + task = self._parse_node_regression_task( + task_config["node_regression"]) + elif "edge_classification" in task_config: + task = self._parse_edge_classification_task( + task_config["edge_classification"]) + elif "edge_regression" in task_config: + task = self._parse_edge_regression_task( + task_config["edge_regression"]) + elif "link_prediction" in task_config: + task = self._parse_link_prediction_task( + task_config["link_prediction"]) + else: + raise ValueError(f"Invalid task type in multi-task learning {task_config}.") + tasks.append(task) + logging.debug("Multi-task learning with %d tasks", len(tasks)) + self._multi_tasks = tasks + def load_yaml_config(self, yaml_path): """Helper function to load a yaml config file""" with open(yaml_path, "r", encoding='utf-8') as stream: @@ -2250,6 +2288,20 @@ def model_select_etype(self): # Per edge type lp evaluation is disabled. return LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL + ###Multi task support #### + @property + def multi_tasks(self): + """ Definition of tasks in multi-task learning. 
################ Task info data classes ############################
@dataclasses.dataclass
class TaskInfo:
    """Information of a training task in multi-task learning

    Parameters
    ----------
    task_type: str
        Task type
    node_type: str
        Node type of the task, if it is a node task
    edge_type: tuple of strs
        Edge type of the task, if it is an edge task
    node_label_field: str
        Node label field
    edge_label_field: str
        Edge label field
    task_id: str
        Unique ID of the task, used to look up the task's decoder and
        loss function in the multi-task model's task pool.
    dataloader: object
        The dataloader bound to this task. It is filled in lazily by the
        multi-task dataloader; it is excluded from ``repr``/``eq``.
    """
    task_type : str
    node_type : str = None
    edge_type : tuple = None
    node_label_field : str = None
    edge_label_field : str = None
    # Bug fix: the trainer reads `task_info.task_id`, but the original
    # dataclass never declared it. Declared here with a default so that
    # existing positional construction keeps working.
    task_id : str = None
    # Bug fix: the original `dataloader = None` had no annotation, which
    # makes it a plain class attribute rather than a dataclass field.
    # Declare it as a real field, kept out of comparison and repr.
    dataloader : object = dataclasses.field(default=None, compare=False, repr=False)
####################### Multi-task Dataloader ####################
class GSgnnMultiTaskDataLoader:
    r""" DataLoader designed for multi-task learning

    It wraps one dataloader per task. Each iteration returns a list of
    (TaskInfo, mini_batch) pairs, one per task. Shorter task dataloaders
    are restarted when exhausted, so every iteration yields a mini-batch
    for every task; the epoch length equals the longest task dataloader.

    Parameters
    ----------
    dataset: GSgnnData
        The GraphStorm dataset
    task_infos: list of TaskInfo
        Task meta information
    task_dataloaders: list of GSgnnDataLoader
        A list of task dataloaders
    """
    def __init__(self, dataset, task_infos, task_dataloaders):
        assert len(task_infos) == len(task_dataloaders), \
            "Number of task_info should match number of task dataloaders"
        # check dataloaders
        lens = []
        for task_info, dataloader in zip(task_infos, task_dataloaders):
            assert isinstance(dataloader, (GSgnnEdgeDataLoaderBase,
                                           GSgnnLinkPredictionDataLoaderBase,
                                           GSgnnNodeDataLoaderBase)), \
                "The task data loader should be a GSgnnEdgeDataLoaderBase " \
                " or a GSgnnLinkPredictionDataLoaderBase or a GSgnnNodeDataLoaderBase"
            num_iters = len(dataloader)
            lens.append(num_iters)
            logging.debug("Task %s has number of iterations of %d",
                          task_info, num_iters)

        self._len = max(lens)
        logging.info("Set the number of iterations to %d, which is the length " \
                     "of the largest task in the multi-task learning.", self._len)
        self._data = dataset
        self._task_infos = task_infos
        self._dataloaders = task_dataloaders # one dataloader for each task
        self._reset_loader()

    def _reset_loader(self):
        """ Reset all task dataloaders and the iteration counter.
        """
        for dataloader in self._dataloaders:
            # The task dataloaders re-initialize their internal state in
            # __iter__ and return themselves; we only need the side effect.
            iter(dataloader)
        self._num_iters = 0

    def __iter__(self):
        self._reset_loader()
        return self

    def __len__(self):
        return self._len

    def __next__(self):
        # Bug fix: the original incremented the counter *before* this
        # check, so an epoch yielded only len(self) - 1 mini-batches,
        # inconsistent with __len__. Check first, then count.
        if self._num_iters == self._len:
            # End of iterating all the dataloaders
            raise StopIteration
        self._num_iters += 1

        # call __next__ of each dataloader
        mini_batches = []
        for task_info, dataloader in zip(self._task_infos, self._dataloaders):
            try:
                mini_batch = next(dataloader)
            except StopIteration:
                # This task's loader is shorter than the longest one:
                # restart it so every iteration covers every task.
                iter(dataloader)
                mini_batch = next(dataloader)
            if task_info.dataloader is None:
                # Bind the dataloader to its task the first time we see it,
                # so evaluation code can retrieve per-task loaders later.
                task_info.dataloader = dataloader
            else:
                assert task_info.dataloader is dataloader, \
                    "Each task in multi-task learning should have a fixed dataloader."
            mini_batches.append((task_info, mini_batch))
        return mini_batches

    @property
    def data(self):
        """ The dataset of this dataloader.

        Returns
        -------
        GSgnnData : The dataset of the dataloader.
        """
        return self._data

    @property
    def dataloaders(self):
        """Get the list of dataloaders
        """
        # useful for conducting validation scores and test scores.
        return self._dataloaders

    @property
    def task_infos(self):
        """Get the list of task_infos
        """
        # useful for conducting validation scores and test scores.
        return self._task_infos
def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label=False):
    """ Perform mini-batch edge prediction with a trained model.

    This function usually follows full-graph GNN embedding inference.
    It switches the model to eval mode, delegates the decoding work to
    ``run_edge_mini_batch_predict`` and restores train mode afterwards.

    Parameters
    ----------
    model : GSgnnModel
        The GraphStorm model; supplies the edge decoder and the device.
    emb : dict of Tensor
        The GNN embeddings.
    loader : GSgnnEdgeDataLoader
        The GraphStorm dataloader.
    return_proba : bool
        Whether to return all the predictions or the maximum prediction.
    return_label : bool
        Whether or not to return labels.

    Returns
    -------
    dict of Tensor : prediction results.
    dict of Tensor : labels if return_label is True.
    """
    model.eval()
    decoder = model.decoder
    device = model.device

    # Bug fix: `emb` was omitted from this call, so `loader` was bound to
    # the callee's `emb` parameter and every argument after it shifted.
    # Forward the embeddings explicitly.
    preds, labels = run_edge_mini_batch_predict(decoder,
                                                emb,
                                                loader,
                                                device,
                                                return_proba,
                                                return_label)
    model.train()
    return preds, labels
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + GNN model for multi-task learning in GraphStorm +""" +import abc +import logging +import time +import torch as th +import dgl + +from ..config import (BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION, + BUILTIN_TASK_LINK_PREDICTION) +from .gnn import GSgnnModel, GSgnnModelBase + + +class GSgnnMultiTaskModelInterface: + """ The interface for GraphStorm multi-task learning. + + This interface defines two main methods for training and inference. + """ + @abc.abstractmethod + def forward(self, task_id, mini_batch): + """ The forward function for multi-task learning + + This method is used for training, It runs model forword + on a mini-batch for one task at a time. + The loss of the model in the mini-batch is returned. + + Parameters + ---------- + task_id: str + ID of the task. + mini_batch: tuple + Mini-batch info + + + Return + ------ + The loss of prediction. + """ + + @abc.abstractmethod + def predict(self, task_info, mini_batch): + """ The forward function for multi-task prediction. + + This method is used for inference, It runs model forword + on a mini-batch for one task at a time. + The prediction result is returned. + + Parameters + ---------- + task_info: TaskInfo + task meta information + mini_batch: tuple + mini-batch info + + Returns + ------- + Tensor or dict of Tensor: + the prediction results. 
+ """ + +class GSgnnMultiTaskSharedEncoderModel(GSgnnModel, GSgnnMultiTaskModelInterface): + """ GraphStorm GNN model for multi-task learning + with a shared encoder model and separate decoder models. + + Parameters + ---------- + alpha_l2norm : float + The alpha for L2 normalization. + """ + def __init__(self, alpha_l2norm): + super(GSgnnMultiTaskSharedEncoderModel, self).__init__() + self._alpha_l2norm = alpha_l2norm + self._task_pool = {} + + def add_task(self, task_id, task_type, + decoder, loss_func, weight): + """ Add a task into the multi-task pool + """ + assert task_id not in self._task_pool, \ + f"Task {task_id} already exists" + logging.info("Setup task %s", task_id) + self._task_pool[task_id] = (task_type, decoder, loss_func, weight) + + @property + def alpha_l2norm(self): + """Get parameter norm params + """ + return self._alpha_l2norm + + @property + def task_pool(self): + """ Get task pool + """ + return self._task_pool + + # pylint: disable=unused-argument + def forward(self, task_id, mini_batch): + """ The forward function for multi-task learning + """ + assert task_id in self.task_pool, \ + f"Unknown task: {task_id} in multi-task learning." \ + f"Existing tasks are {self.task_pool.keys()}" + + encoder_data, decoder_data = mini_batch + # message passing graph, node features, edge features, seed nodes + blocks, node_feats, _, input_nodes = encoder_data + if blocks is None or len(blocks) == 0: + # no GNN message passing + encode_embs = self.comput_input_embed(input_nodes, node_feats) + else: + # GNN message passing + encode_embs = self.compute_embed_step(blocks, node_feats, input_nodes) + + # Call emb normalization. 
+ encode_embs = self.normalize_node_embs(encode_embs) + + task_type, decoder, loss_func, weight = self.task_pool[task_id] + + if task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + labels = decoder_data + assert len(labels) == 1, \ + "In multi-task learning, only support do prediction " \ + "on one node type for a single node task." + pred_loss = 0 + target_ntype = list(labels.keys())[0] + + assert target_ntype in encode_embs, f"Node type {target_ntype} not in encode_embs" + assert target_ntype in labels, f"Node type {target_ntype} not in labels" + emb = encode_embs[target_ntype] + ntype_labels = labels[target_ntype] + ntype_logits = decoder(emb) + pred_loss = loss_func(ntype_logits, ntype_labels) + + return pred_loss, weight + elif task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + target_edges, target_edge_feats, labels = decoder_data + assert len(labels) == 1, \ + "In multi-task learning, only support do prediction " \ + "on one edge type for a single edge task." 
+ pred_loss = 0 + target_etype = list(labels.keys())[0] + logits = decoder(target_edges, encode_embs, target_edge_feats) + pred_loss = loss_func(logits, labels[target_etype]) + + return pred_loss, weight + elif task_type == BUILTIN_TASK_LINK_PREDICTION: + pos_graph, neg_graph, pos_edge_feats, neg_edge_feats = decoder_data + + pos_score = decoder(pos_graph, encode_embs, pos_edge_feats) + neg_score = decoder(neg_graph, encode_embs, neg_edge_feats) + assert pos_score.keys() == neg_score.keys(), \ + "Positive scores and Negative scores must have edges of same" \ + f"edge types, but get {pos_score.keys()} and {neg_score.keys()}" + pred_loss = loss_func(pos_score, neg_score) + return pred_loss, weight + else: + raise TypeError("Unknow task type %s", task_type) + + + def predict(self, task_id, mini_batch, return_proba=False): + """ The forward function for multi-task inference + """ + assert task_id in self.task_pool, \ + f"Unknown task: {task_id} in multi-task learning." \ + f"Existing tasks are {self.task_pool.keys()}" + + encoder_data, decoder_data = mini_batch + # message passing graph, node features, edge features, seed nodes + blocks, node_feats, _, input_nodes = encoder_data + if blocks is None or len(blocks) == 0: + # no GNN message passing + encode_embs = self.comput_input_embed(input_nodes, node_feats) + else: + # GNN message passing + encode_embs = self.compute_embed_step(blocks, node_feats, input_nodes) + + # Call emb normalization. + encode_embs = self.normalize_node_embs(encode_embs) + + task_type, decoder, _, _ = self.task_pool[task_id] + + if task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + assert len(encode_embs) == 1, \ + "In multi-task learning, only support do prediction " \ + "on one node type for a single node task." 
+ target_ntype = list(encode_embs.keys())[0] + predicts = {} + if return_proba: + predicts[target_ntype] = decoder.predict_proba(encode_embs[target_ntype]) + else: + predicts[target_ntype] = decoder.predict(encode_embs[target_ntype]) + return predicts + elif task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + target_edges, target_edge_feats, _ = decoder_data + if return_proba: + return decoder.predict_proba(target_edges, encode_embs, target_edge_feats) + return decoder.predict(target_edges, encode_embs, target_edge_feats) + elif task_type == BUILTIN_TASK_LINK_PREDICTION: + logging.warning("Prediction for link prediction is not implemented") + return None + else: + raise TypeError("Unknow task type %s", task_type) diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index adbb2027c4..5e30f8f7fe 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -311,6 +311,44 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= Labels if return_labels is True """ device = model.device + decoder = model.decoder + model.eval() + preds, labels = \ + run_node_mini_batch_predict(decoder, + emb, + loader, + device, + return_proba, + return_label) + model.train() + return preds, labels + +def run_node_mini_batch_predict(decoder, emb, loader, device, + return_proba=True, return_label=False): + """ Perform mini-batch prediction. + + Parameters + ---------- + decoder : GSNodeDecoder + The GraphStorm node decoder + emb : dict of Tensor + The GNN embeddings + loader : GSgnnNodeDataLoader + The GraphStorm dataloader + device: th.device + Device used to compute prediction result + return_proba : bool + Whether or not to return all the predictions or the maximum prediction + return_label : bool + Whether or not to return labels. + + Returns + ------- + dict of Tensor : + Prediction results. 
+ dict of Tensor : + Labels if return_labels is True + """ data = loader.data if return_label: @@ -321,15 +359,12 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= preds = {} labels = {} # TODO(zhengda) I need to check if the data loader only returns target nodes. - model.eval() with th.no_grad(): for input_nodes, seeds, _ in loader: for ntype, in_nodes in input_nodes.items(): - if isinstance(model.decoder, th.nn.ModuleDict): - assert ntype in model.decoder, f"Node type {ntype} not in decoder" - decoder = model.decoder[ntype] - else: - decoder = model.decoder + if isinstance(decoder, th.nn.ModuleDict): + assert ntype in decoder, f"Node type {ntype} not in decoder" + decoder = decoder[ntype] if return_proba: pred = decoder.predict_proba(emb[ntype][in_nodes].to(device)) else: @@ -344,8 +379,7 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= if ntype in labels: labels[ntype].append(lbl[ntype]) else: - labels[ntype] = [lbl[ntype]] - model.train() + labels[ntype] = lbl[ntype] for ntype, ntype_pred in preds.items(): preds[ntype] = th.cat(ntype_pred) diff --git a/python/graphstorm/run/gsgnn_mt/__init__.py b/python/graphstorm/run/gsgnn_mt/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py new file mode 100644 index 0000000000..acf31eec57 --- /dev/null +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -0,0 +1,88 @@ +""" + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + GSgnn multi-task learning +""" +import os + +import graphstorm as gs +from graphstorm.config import get_argument_parser +from graphstorm.config import GSConfig +from graphstorm.dataloading import GSgnnData +from graphstorm.trainer import GSgnnMultiTaskLearningTrainer + +from graphstorm.utils import rt_profiler, sys_tracker, get_device, use_wholegraph +from graphstorm.utils import get_lm_ntypes + +def main(config_args): + """ main function + """ + config = GSConfig(config_args) + config.verify_arguments(True) + + use_wg_feats = use_wholegraph(config.part_config) + gs.initialize(ip_config=config.ip_config, backend=config.backend, + local_rank=config.local_rank, + use_wholegraph=config.use_wholegraph_embed or use_wg_feats) + rt_profiler.init(config.profile_path, rank=gs.get_rank()) + sys_tracker.init(config.verbose, rank=gs.get_rank()) + train_data = GSgnnData(config.part_config, + node_feat_field=config.node_feat_name, + edge_feat_field=config.edge_feat_name, + lm_feat_ntypes=get_lm_ntypes(config.node_lm_configs)) + + tasks = config.multi_tasks + train_dataloaders = [] + val_dataloaders = [] + test_dataloaders = [] + decoders = [] + for task in tasks: + train_loader = create_task_train_dataloader(task, config, train_task=True) + decoder = create_task_decoder(task, config) + val_loader = create_task_val_dataloader(task, config, train_task=False) + test_loader = create_task_test_dataloader(task, config, train_task=False) + + train_dataloader = GSgnnMultiTaskDataLoader(train_dataloaders) + val_dataloader = GSgnnMultiTaskDataLoader(val_dataloaders) + test_dataloader = 
GSgnnMultiTaskDataLoader(test_dataloaders) + + trainer = GSgnnMultiTaskLearningTrainer(model, topk_model_to_save=config.topk_model_to_save) + + # Preparing input layer for training or inference. + # The input layer can pre-compute node features in the preparing step if needed. + # For example pre-compute all BERT embeddings + model.prepare_input_encoder(train_data) + if config.save_model_path is not None: + save_model_path = config.save_model_path + elif config.save_embed_path is not None: + # If we need to save embeddings, we need to save the model somewhere. + save_model_path = os.path.join(config.save_embed_path, "model") + else: + save_model_path = None + + trainer.fit(train_loader=train_dataloader, + val_loader=val_dataloader, + test_loader=test_dataloader, + num_epochs=config.num_epochs, + save_model_path=save_model_path, + use_mini_batch_infer=config.use_mini_batch_infer, + save_model_frequency=config.save_model_frequency, + save_perf_results_path=config.save_perf_results_path, + freeze_input_layer_epochs=config.freeze_lm_encoder_epochs, + max_grad_norm=config.max_grad_norm, + grad_norm_type=config.grad_norm_type) + + if config.save_embed_path is not None: + # Save node embeddings diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py new file mode 100644 index 0000000000..06b206866d --- /dev/null +++ b/python/graphstorm/trainer/mt_trainer.py @@ -0,0 +1,491 @@ +""" + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def run_node_predict_mini_batch(model, data, task_info, mini_batch, device):
    """ Run one training mini-batch for a node prediction task.

    Parameters
    ----------
    model: GSgnnMultiTaskModelInterface
        The multi-task model; called with the task id and mini-batch.
    data: GSgnnData
        The graph data used to fetch node features and labels.
    task_info: TaskInfo
        Task meta information (supplies the task's dataloader and id).
    mini_batch: tuple
        (input_nodes, seeds, blocks) produced by a node dataloader.
    device: th.device
        Device to move features, labels and blocks to.

    Returns
    -------
    The value returned by ``model.forward`` for this task — for
    GSgnnMultiTaskSharedEncoderModel this is a (loss, weight) pair.
    """
    graph = data.g
    input_nodes, seeds, blocks = mini_batch
    if not isinstance(input_nodes, dict):
        # Homogeneous graph: wrap the raw id tensor with its node type.
        assert len(graph.ntypes) == 1
        input_nodes = {graph.ntypes[0]: input_nodes}

    dataloader = task_info.dataloader
    node_feats = data.get_node_feats(input_nodes, dataloader.node_feat_fields, device)
    labels = data.get_node_feats(seeds, dataloader.label_field, device)
    blocks = [blk.to(device) for blk in blocks]

    # TODO: we don't support edge features for now.
    return model(task_info.task_id,
                 ((blocks, node_feats, None, input_nodes), labels))
+ loss = model(task_info.task_id, + ((blocks, input_feats, None, input_nodes), + (batch_graph, edge_decoder_feats, lbl))) + return loss + +def run_link_predict_mini_batch(model, data, task_info, mini_batch, device): + input_nodes, pos_graph, neg_graph, blocks = mini_batch + + if not isinstance(input_nodes, dict): + assert len(pos_graph.ntypes) == 1 + input_nodes = {pos_graph.ntypes[0]: input_nodes} + + nfeat_fields = task_info.dataloader.node_feat_fields + input_feats = data.get_node_feats(input_nodes, nfeat_fields, device) + + if task_info.dataloader.pos_graph_feat_fields is not None: + input_edges = {etype: pos_graph.edges[etype].data[dgl.EID] \ + for etype in pos_graph.canonical_etypes} + pos_graph_feats = data.get_edge_feats(input_edges, + task_info.dataloader.pos_graph_feat_fields, + device) + else: + pos_graph_feats = None + + pos_graph = pos_graph.to(device) + neg_graph = neg_graph.to(device) + blocks = [blk.to(device) for blk in blocks] + + # TODO: we don't support edge features for now. + loss = model(task_info.task_id, + ((blocks, input_feats, None, input_nodes), + (pos_graph, neg_graph,pos_graph_feats, None))) + return loss + +def multi_task_mini_batch_predict( + model, emb, loader, device, return_proba=True, return_label=False): + """ conduct mini batch prediction on multiple tasks + + Parameters + ---------- + model: GSgnnMultiTaskModelInterface, GSgnnModel + Multi-task learning model + emb : dict of Tensor + The GNN embeddings + loader: GSgnnMultiTaskDataLoader + The mini-batch dataloader. + device: th.device + Device used to compute test scores. + return_proba: bool + Whether to return all the predictions or the maximum prediction. 
+
+    Returns
+    -------
+    list: prediction results of each task
+    """
+    dataloaders = loader.dataloaders
+    task_infos = loader.task_infos
+    task_pool = model.task_pool
+    res = {}
+    with th.no_grad():
+        for dataloader, task_info in zip(dataloaders, task_infos):
+            if task_info.task_type in \
+                [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]:
+                task_type, decoder, _, _ = task_pool[task_info.task_id]
+                assert task_info.task_type == task_type
+                preds, labels = \
+                    run_node_mini_batch_predict(decoder,
+                                                emb,
+                                                dataloader,
+                                                device,
+                                                return_proba,
+                                                return_label)
+                res[task_info.task_id] = (preds, labels)
+            elif task_info.task_type in \
+                [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]:
+                task_type, decoder, _, _ = task_pool[task_info.task_id]
+                assert task_info.task_type == task_type
+                preds, labels = \
+                    run_edge_mini_batch_predict(decoder,
+                                                emb,
+                                                dataloader,
+                                                device,
+                                                return_proba,
+                                                return_label)
+                res[task_info.task_id] = (preds, labels)
+            elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION:
+                task_type, decoder, _, _ = task_pool[task_info.task_id]
+                assert task_info.task_type == task_type
+                ranking = run_lp_mini_batch_predict(decoder, emb, dataloader, device)
+                res[task_info.task_id] = ranking
+            else:
+                raise TypeError("Unknown task %s" % task_info)
+
+    return res
+
+class GSgnnMultiTaskLearningTrainer(GSgnnTrainer):
+    r""" A trainer for multi-task learning
+
+    This class is used to train models for multi task learning.
+
+    It makes use of the functions provided by `GSgnnTrainer`
+    to define two main functions: `fit` that performs the training
+    for the model that is provided when the object is created,
+    and `eval` that evaluates a provided model against test and
+    validation data.
+
+    Parameters
+    ----------
+    model : GSgnnMultiTaskModel
+        The GNN model for node prediction.
+    topk_model_to_save : int
+        The top K model to save.
+ """ + def __init__(self, model, topk_model_to_save=1): + super(GSgnnMultiTaskLearningTrainer, self).__init__(model, topk_model_to_save) + assert isinstance(model) and isinstance(model, GSgnnModelBase), \ + "The input model is not a GSgnnModel model. Please implement GSgnnModelBase." + + def _run_mini_batch(self, data, model, task_info, mini_batch, device): + """ run mini batch for a single task + + Parameters + ---------- + data: GSgnnData + Graph data + task_info: TaskInfo + task meta information + mini_batch: tuple + mini-batch info + + Return + ------ + loss + """ + if task_info.task_type in \ + [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + return run_node_predict_mini_batch(model, + data, + task_info, + mini_batch, + device) + elif task_info.task_type in \ + [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + return run_edge_predict_mini_batch(model, + data, + task_info, + mini_batch, + device) + elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION: + return run_link_predict_mini_batch(model, + data, + task_info, + mini_batch, + device) + else: + raise TypeError("Unknown task %s", task_info) + + def fit(self, train_loader, + num_epochs, + val_loader=None, + test_loader=None, + use_mini_batch_infer=True, + save_model_path=None, + save_model_frequency=-1, + save_perf_results_path=None, + freeze_input_layer_epochs=0, + max_grad_norm=None, + grad_norm_type=2.0): + """ The fit function for multi-task learning. + + Performs the training for `self.model`. Iterates over all the tasks + and run one mini-batch for each task in an iteration. The loss will be + accumulated. Performs the backwards step using `self.optimizer`. + If an evaluator has been assigned to the trainer, it will run evaluation + at the end of every epoch. + + Parameters + ---------- + train_loader : GSgnnMultiTaskDataLoader + The mini-batch sampler for training. + num_epochs : int + The max number of epochs to train the model. 
+ val_loader : GSgnnMultiTaskDataLoader + The mini-batch sampler for computing validation scores. The validation scores + are used for selecting models. + test_loader : GSgnnMultiTaskDataLoader + The mini-batch sampler for computing test scores. + use_mini_batch_infer : bool + Whether or not to use mini-batch inference. + save_model_path : str + The path where the model is saved. + save_model_frequency : int + The number of iteration to train the model before saving the model. + save_perf_results_path : str + The path of the file where the performance results are saved. + freeze_input_layer_epochs: int + Freeze the input layer for N epochs. This is commonly used when + the input layer contains language models. + Default: 0, no freeze. + max_grad_norm: float + Clip the gradient by the max_grad_norm to ensure stability. + Default: None, no clip. + grad_norm_type: float + Norm type for the gradient clip + Default: 2.0 + """ + # Check the correctness of configurations. + if self.evaluator is not None: + assert val_loader is not None, \ + "The evaluator is provided but validation set is not provided." + if not use_mini_batch_infer: + assert isinstance(self._model, GSgnnModel), \ + "Only GSgnnModel supports full-graph inference." + + # with freeze_input_layer_epochs is 0, computation graph will not be changed. + static_graph = freeze_input_layer_epochs == 0 + on_cpu = self.device == th.device('cpu') + if is_distributed(): + model = DistributedDataParallel(self._model, + device_ids=None if on_cpu else [self.device], + output_device=None if on_cpu else self.device, + find_unused_parameters=True, + static_graph=static_graph) + else: + model = self._model + device = model.device + data = train_loader.data + + # Preparing input layer for training or inference. + # The input layer can pre-compute node features in the preparing step if needed. 
+        # For example pre-compute all BERT embeddings
+        if freeze_input_layer_epochs > 0:
+            self._model.freeze_input_encoder(data)
+        # TODO(xiangsx) Support freezing gnn encoder and decoder
+
+        # training loop
+        total_steps = 0
+        sys_tracker.check('start training')
+        g = data.g
+        for epoch in range(num_epochs):
+            model.train()
+            epoch_start = time.time()
+            if freeze_input_layer_epochs <= epoch:
+                self._model.unfreeze_input_encoder()
+                # TODO(xiangsx) Support unfreezing gnn encoder and decoder
+
+            rt_profiler.start_record()
+            batch_tic = time.time()
+            for i, task_mini_batches in enumerate(train_loader):
+                rt_profiler.record('train_sample')
+                total_steps += 1
+
+                losses = []
+                for (task_info, mini_batch) in task_mini_batches:
+                    loss = self._run_mini_batch(data, model, task_info, mini_batch, device)
+                    losses.append((loss, task_info.task_config.task_weight)) # NOTE(review): assumes TaskInfo carries its task_config -- confirm
+
+                reg_loss = th.tensor(0.).to(device)
+                for d_para in model.get_dense_params():
+                    reg_loss += d_para.square().sum()
+                alpha_l2norm = model.alpha_l2norm
+
+                mt_loss = reg_loss * alpha_l2norm
+                mt_loss += sum(loss * weight for loss, weight in losses)
+                rt_profiler.record('train_forward')
+                self.optimizer.zero_grad()
+                mt_loss.backward()
+                rt_profiler.record('train_backward')
+                if max_grad_norm is not None:
+                    th.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm, grad_norm_type)
+                self.optimizer.step()
+                rt_profiler.record('train_step')
+
+                self.log_metric("Train loss", mt_loss.item(), total_steps)
+
+                if i % 20 == 0 and get_rank() == 0:
+                    rt_profiler.print_stats()
+                    logging.info("Epoch %05d | Batch %03d | Train Loss: %.4f | Time: %.4f",
+                                 epoch, i, mt_loss.item(), time.time() - batch_tic)
+
+                val_score = None
+                if self.evaluator is not None and \
+                    self.evaluator.do_eval(total_steps, epoch_end=False):
+
+                    val_score = self.eval(model.module if is_distributed() else model,
+                                          data, val_loader, test_loader, total_steps)
+                    # TODO(xiangsx): Add early stop support
+
+                # Every n iterations, check to save the top k models.
Will save
+                # the last k model or all models depends on the setting of top k
+                # TODO(xiangsx): support saving the best top k model.
+                if save_model_frequency > 0 and \
+                    total_steps % save_model_frequency == 0 and \
+                    total_steps != 0:
+
+                    if self.evaluator is None or val_score is not None:
+                        # We will save the best model when
+                        # 1. There is no evaluation, we will keep the
+                        #    latest K models.
+                        # 2. (TODO) There is evaluation, we need to follow the
+                        #    guidance of validation score.
+                        self.save_topk_models(model, epoch, i, None, save_model_path)
+
+                batch_tic = time.time()
+                rt_profiler.record('train_eval')
+
+            # ------- end of an epoch -------
+            barrier()
+            epoch_time = time.time() - epoch_start
+            if get_rank() == 0:
+                logging.info("Epoch %d take %.3f seconds", epoch, epoch_time)
+
+            val_score = None
+            if self.evaluator is not None and self.evaluator.do_eval(total_steps, epoch_end=True):
+                val_score = self.eval(model.module if is_distributed() else model,
+                                      data, val_loader, test_loader, total_steps)
+
+            # After each epoch, check to save the top k models.
+            # Will either save the last k model or all models
+            # depends on the setting of top k.
+ self.save_topk_models(model, epoch, None, None, save_model_path) + rt_profiler.print_stats() + barrier() + + + + rt_profiler.save_profile() + print_mem(device) + if get_rank() == 0 and self.evaluator is not None: + # final evaluation + output = {'best_test_score': self.evaluator.best_test_score, + 'best_val_score':self.evaluator.best_val_score, + 'last_test_score': self.evaluator.last_test_score, + 'last_val_score':self.evaluator.last_val_score, + 'peak_GPU_mem_alloc_MB': th.cuda.max_memory_allocated(device) / 1024 / 1024, + 'peak_RAM_mem_alloc_MB': \ + resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024, + 'best validation iteration': \ + self.evaluator.best_iter_num, + 'best model path': \ + self.get_best_model_path() if save_model_path is not None else None} + self.log_params(output) + + def eval(self, model, data, val_loader, test_loader, total_steps, + use_mini_batch_infer=False, return_proba=True): + """ do the model evaluation using validation and test sets + + Parameters + ---------- + model : Pytorch model + The GNN model. + data : GSgnnData + The training dataset + val_loader: GSNodeDataLoader + The dataloader for validation data + test_loader : GSNodeDataLoader + The dataloader for test data. + total_steps: int + Total number of iterations. + use_mini_batch_infer: bool + Whether do mini-batch inference + return_proba: bool + Whether to return all the predictions or the maximum prediction. 
+ + Returns + ------- + dict: validation score + """ + test_start = time.time() + sys_tracker.check('before prediction') + model.eval() + + if use_mini_batch_infer: + emb = do_mini_batch_inference(model, data, + fanout=val_loader.fanout, + task_tracker=self.task_tracker) + else: + emb = do_full_graph_inference(model, data, + fanout=val_loader.fanout, + task_tracker=self.task_tracker) + sys_tracker.check('compute embeddings') + + val_scores = \ + multi_task_mini_batch_predict(model, emb, val_loader, self.device, return_proba) \ + if val_loader is not None else None + + test_scores = \ + multi_task_mini_batch_predict(model, emb, test_loader, self.device, return_proba) \ + if test_loader is not None else None + + sys_tracker.check('after_test_score') + val_score, test_score = self.evaluator.evaluate( + val_scores, test_scores, total_steps) + sys_tracker.check('evaluate validation/test') + model.train() + + if get_rank() == 0: + self.log_print_metrics(val_score=val_score, + test_score=test_score, + dur_eval=time.time() - test_start, + total_steps=total_steps) + return val_score diff --git a/training_scripts/gsgnn_mt/README.md b/training_scripts/gsgnn_mt/README.md new file mode 100644 index 0000000000..9a9010ed8f --- /dev/null +++ b/training_scripts/gsgnn_mt/README.md @@ -0,0 +1,18 @@ +# Multi-task Learning Example Yaml Files +This folder presents example yaml files for multi-task learning with Movielens datasets. 
+ +## Build a graph for multi-task learning on Movielens dataset +``` +python3 $GS_HOME/tests/end2end-tests/data_gen/process_movielens.py + +python3 -m graphstorm.gconstruct.construct_graph \ + --conf-file $GS_HOME/training_scripts/gsgnn_mt/ml_ncr_lp.json \ + --num-processes 1 \ + --output-dir movielen_100k_multitask_1p_4t \ + --graph-name movie-lens-100k \ + --add-reverse-edges +``` + +## Run the example +``` +``` \ No newline at end of file diff --git a/training_scripts/gsgnn_mt/ml_ncr_lp.json b/training_scripts/gsgnn_mt/ml_ncr_lp.json new file mode 100644 index 0000000000..85990db309 --- /dev/null +++ b/training_scripts/gsgnn_mt/ml_ncr_lp.json @@ -0,0 +1,84 @@ +{ + "version": "gconstruct-v0.1", + "nodes": [ + { + "node_id_col": "id", + "node_type": "user", + "format": {"name": "hdf5"}, + "files": "/data/ml-100k/user.hdf5", + "features": [ + { + "feature_col": "feat" + } + ] + }, + { + "node_id_col": "id", + "node_type": "movie", + "format": {"name": "parquet"}, + "files": "/data/ml-100k/movie.parquet", + "features": [ + { + "feature_col": "title", + "transform": { + "name": "bert_hf", + "bert_model": "bert-base-uncased", + "max_seq_length": 16 + } + } + ], + "labels": [ + { + "label_col": "label", + "task_type": "classification", + "split_pct": [0.8, 0.1, 0.1], + "mask_field_names": ["train_mask_c0", + "val_mask_c0", + "test_mask_c0"] + }, + { + "label_col": "label", + "task_type": "classification", + "split_pct": [0.7, 0.2, 0.1], + "mask_field_names": ["train_mask_c1", + "val_mask_c1", + "test_mask_c1"] + } + ] + } + ], + "edges": [ + { + "source_id_col": "src_id", + "dest_id_col": "dst_id", + "relation": ["user", "rating", "movie"], + "format": {"name": "parquet"}, + "files": "/data/ml-100k/edges.parquet", + "labels": [ + { + "label_col": "rate", + "task_type": "classification", + "split_pct": [0.1, 0.1, 0.1], + "mask_field_names": ["train_mask_field_c", + "val_mask_field_c", + "test_mask_field_c"] + }, + { + "label_col": "rate", + "task_type": "regression", + 
"split_pct": [0.1, 0.1, 0.1], + "mask_field_names": ["train_mask_field_r", + "val_mask_field_r", + "test_mask_field_r"] + }, + { + "task_type": "link_prediction", + "split_pct": [0.1, 0.1, 0.1], + "mask_field_names": ["train_mask_field_l", + "val_mask_field_l", + "test_mask_field_l"] + } + ] + } + ] +} \ No newline at end of file diff --git a/training_scripts/gsgnn_mt/ml_ncr_lp_yaml b/training_scripts/gsgnn_mt/ml_ncr_lp_yaml new file mode 100644 index 0000000000..66f56da955 --- /dev/null +++ b/training_scripts/gsgnn_mt/ml_ncr_lp_yaml @@ -0,0 +1,74 @@ +--- +version: 1.0 +gsf: + basic: + backend: gloo + verbose: false + save_perf_results_path: null + gnn: + model_encoder_type: rgcn + fanout: "4" + num_layers: 1 + hidden_size: 128 + use_mini_batch_infer: true + input: + restore_model_path: null + output: + save_model_path: null + save_embed_path: null + hyperparam: + dropout: 0. + lr: 0.001 + lm_tune_lr: 0.0001 + num_epochs: 3 + wd_l2norm: 0 + no_validation: false + rgcn: + num_bases: -1 + use_self_loop: true + sparse_optimizer_lr: 1e-2 + use_node_embeddings: false + multi_task_learning: + - node_classification: + target_ntype: "movie" + label_field: "label" + multilabel: false + num_classes: 19 + batch_size: 16 # will overwrite the global batch_size + mask_fields: + - "train_mask_field_nc" + - "val_mask_field_nc" + - "test_mask_field_nc" + task_weight: 1.0 + - edge_classification: + target_etype: + - "user,rating,movie" + reverse_edge_types_map: + - "user,rating,rating-rev,movie" + label_field: "rate" + multilabel: false + num_classes: 5 + num_decoder_basis: 32 + exclude_training_targets: false + batch_size: 10 # will overwrite the global batch_size + mask_fields: + - "train_mask_field_ec" + - "val_mask_field_ec" + - "test_mask_field_ec" + task_weight: 0.5 # weight of the task + - link_prediction: + num_negative_edges: 4 + num_negative_edges_eval: 100 + train_negative_sampler: joint + eval_etype: + - "user,rating,movie" + train_etype: + - "user,rating,movie" + 
exclude_training_targets: false + reverse_edge_types_map: [] + batch_size: 10 # will overwrite the global batch_size + mask_fields: + - "train_mask_field_lp" + - "" # empty means there is no validation mask + - "" # empty means there is no test mask + task_weight: 1.0 \ No newline at end of file From 8825cb186c47128dc33635e943fe283f269444d4 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 2 May 2024 23:01:51 -0700 Subject: [PATCH 02/79] Update --- python/graphstorm/__init__.py | 2 + python/graphstorm/dataloading/__init__.py | 1 + python/graphstorm/gsf.py | 68 +++++++++ python/graphstorm/run/gsgnn_lp/gsgnn_lp.py | 64 +------- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 162 ++++++++++++++++++++- 5 files changed, 235 insertions(+), 62 deletions(-) diff --git a/python/graphstorm/__init__.py b/python/graphstorm/__init__.py index e6b34f4f06..71469ec8ee 100644 --- a/python/graphstorm/__init__.py +++ b/python/graphstorm/__init__.py @@ -28,3 +28,5 @@ from .gsf import create_builtin_lp_model from .gsf import create_builtin_edge_model from .gsf import create_builtin_node_model +from .gsf import (get_lp_train_sampler, + get_lp_eval_sampler) diff --git a/python/graphstorm/dataloading/__init__.py b/python/graphstorm/dataloading/__init__.py index 77b4e653be..c90525f2e9 100644 --- a/python/graphstorm/dataloading/__init__.py +++ b/python/graphstorm/dataloading/__init__.py @@ -34,6 +34,7 @@ from .dataloading import (GSgnnEdgeDataLoaderBase, GSgnnLinkPredictionDataLoaderBase, GSgnnNodeDataLoaderBase) +from .dataloading import GSgnnMultiTaskDataLoader from .dataset import GSgnnData diff --git a/python/graphstorm/gsf.py b/python/graphstorm/gsf.py index c8df8406be..e0d4106737 100644 --- a/python/graphstorm/gsf.py +++ b/python/graphstorm/gsf.py @@ -64,6 +64,31 @@ LinkPredictContrastiveDistMultDecoder, LinkPredictWeightedDotDecoder, LinkPredictWeightedDistMultDecoder) +from .dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, + 
BUILTIN_LP_JOINT_NEG_SAMPLER,BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER, + BUILTIN_LP_LOCALUNIFORM_NEG_SAMPLER, + BUILTIN_LP_LOCALJOINT_NEG_SAMPLER, + BUILTIN_LP_ALL_ETYPE_UNIFORM_NEG_SAMPLER, + BUILTIN_LP_ALL_ETYPE_JOINT_NEG_SAMPLER, + BUILTIN_FAST_LP_UNIFORM_NEG_SAMPLER, + BUILTIN_FAST_LP_JOINT_NEG_SAMPLER, + BUILTIN_FAST_LP_LOCALUNIFORM_NEG_SAMPLER, + BUILTIN_FAST_LP_LOCALJOINT_NEG_SAMPLER) +from .dataloading import (FastGSgnnLinkPredictionDataLoader, + FastGSgnnLPJointNegDataLoader, + FastGSgnnLPLocalUniformNegDataLoader, + FastGSgnnLPLocalJointNegDataLoader, + GSgnnLinkPredictionDataLoader, + GSgnnLPJointNegDataLoader, + GSgnnLPLocalUniformNegDataLoader, + GSgnnLPLocalJointNegDataLoader, + GSgnnLPInBatchJointNegDataLoader, + GSgnnAllEtypeLPJointNegDataLoader, + GSgnnAllEtypeLinkPredictionDataLoader) +from .dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) + from .tracker import get_task_tracker_class def initialize(ip_config=None, backend='gloo', local_rank=0, use_wholegraph=False): @@ -674,3 +699,46 @@ def check_homo(g): def create_builtin_task_tracker(config): tracker_class = get_task_tracker_class(config.task_tracker) return tracker_class(config.eval_frequency) + +def get_lp_eval_sampler(config): + test_dataloader_cls = None + if config.eval_etypes_negative_dstnode is not None: + test_dataloader_cls = GSgnnLinkPredictionPredefinedTestDataLoader + elif config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + test_dataloader_cls = GSgnnLinkPredictionTestDataLoader + elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: + test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader + else: + raise ValueError('Unknown test negative sampler.' 
+ 'Supported test negative samplers include ' + f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]') + return test_dataloader_cls + +def get_lp_train_sampler(config): + dataloader_cls = None + if config.train_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + dataloader_cls = GSgnnLinkPredictionDataLoader + elif config.train_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: + dataloader_cls = GSgnnLPJointNegDataLoader + elif config.train_negative_sampler == BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER: + dataloader_cls = GSgnnLPInBatchJointNegDataLoader + elif config.train_negative_sampler == BUILTIN_LP_LOCALUNIFORM_NEG_SAMPLER: + dataloader_cls = GSgnnLPLocalUniformNegDataLoader + elif config.train_negative_sampler == BUILTIN_LP_LOCALJOINT_NEG_SAMPLER: + dataloader_cls = GSgnnLPLocalJointNegDataLoader + elif config.train_negative_sampler == BUILTIN_LP_ALL_ETYPE_UNIFORM_NEG_SAMPLER: + dataloader_cls = GSgnnAllEtypeLinkPredictionDataLoader + elif config.train_negative_sampler == BUILTIN_LP_ALL_ETYPE_JOINT_NEG_SAMPLER: + dataloader_cls = GSgnnAllEtypeLPJointNegDataLoader + elif config.train_negative_sampler == BUILTIN_FAST_LP_UNIFORM_NEG_SAMPLER: + dataloader_cls = FastGSgnnLinkPredictionDataLoader + elif config.train_negative_sampler == BUILTIN_FAST_LP_JOINT_NEG_SAMPLER: + dataloader_cls = FastGSgnnLPJointNegDataLoader + elif config.train_negative_sampler == BUILTIN_FAST_LP_LOCALUNIFORM_NEG_SAMPLER: + dataloader_cls = FastGSgnnLPLocalUniformNegDataLoader + elif config.train_negative_sampler == BUILTIN_FAST_LP_LOCALJOINT_NEG_SAMPLER: + dataloader_cls = FastGSgnnLPLocalJointNegDataLoader + else: + raise ValueError('Unknown negative sampler') + + return dataloader_cls diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py index 5ea7af5457..96990f3123 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py @@ -23,31 +23,7 @@ from graphstorm.config import GSConfig 
from graphstorm.trainer import GSgnnLinkPredictionTrainer from graphstorm.dataloading import GSgnnData -from graphstorm.dataloading import GSgnnLinkPredictionDataLoader -from graphstorm.dataloading import (GSgnnLPJointNegDataLoader, - GSgnnLPLocalUniformNegDataLoader, - GSgnnLPLocalJointNegDataLoader, - GSgnnLPInBatchJointNegDataLoader) -from graphstorm.dataloading import GSgnnAllEtypeLPJointNegDataLoader -from graphstorm.dataloading import GSgnnAllEtypeLinkPredictionDataLoader -from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, - GSgnnLinkPredictionJointTestDataLoader, - GSgnnLinkPredictionPredefinedTestDataLoader) -from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, - BUILTIN_LP_JOINT_NEG_SAMPLER, - BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER, - BUILTIN_LP_LOCALUNIFORM_NEG_SAMPLER, - BUILTIN_LP_LOCALJOINT_NEG_SAMPLER) -from graphstorm.dataloading import BUILTIN_LP_ALL_ETYPE_UNIFORM_NEG_SAMPLER -from graphstorm.dataloading import BUILTIN_LP_ALL_ETYPE_JOINT_NEG_SAMPLER -from graphstorm.dataloading import (BUILTIN_FAST_LP_UNIFORM_NEG_SAMPLER, - BUILTIN_FAST_LP_JOINT_NEG_SAMPLER, - BUILTIN_FAST_LP_LOCALUNIFORM_NEG_SAMPLER, - BUILTIN_FAST_LP_LOCALJOINT_NEG_SAMPLER) -from graphstorm.dataloading import (FastGSgnnLinkPredictionDataLoader, - FastGSgnnLPJointNegDataLoader, - FastGSgnnLPLocalUniformNegDataLoader, - FastGSgnnLPLocalJointNegDataLoader) + from graphstorm.eval import GSgnnMrrLPEvaluator, GSgnnPerEtypeMrrLPEvaluator from graphstorm.model.utils import save_full_node_embeddings from graphstorm.model import do_full_graph_inference @@ -119,31 +95,9 @@ def main(config_args): tracker.log_params(config.__dict__) trainer.setup_task_tracker(tracker) - if config.train_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: - dataloader_cls = GSgnnLinkPredictionDataLoader - elif config.train_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: - dataloader_cls = GSgnnLPJointNegDataLoader - elif config.train_negative_sampler == 
BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER: - dataloader_cls = GSgnnLPInBatchJointNegDataLoader - elif config.train_negative_sampler == BUILTIN_LP_LOCALUNIFORM_NEG_SAMPLER: - dataloader_cls = GSgnnLPLocalUniformNegDataLoader - elif config.train_negative_sampler == BUILTIN_LP_LOCALJOINT_NEG_SAMPLER: - dataloader_cls = GSgnnLPLocalJointNegDataLoader - elif config.train_negative_sampler == BUILTIN_LP_ALL_ETYPE_UNIFORM_NEG_SAMPLER: - dataloader_cls = GSgnnAllEtypeLinkPredictionDataLoader - elif config.train_negative_sampler == BUILTIN_LP_ALL_ETYPE_JOINT_NEG_SAMPLER: - dataloader_cls = GSgnnAllEtypeLPJointNegDataLoader - elif config.train_negative_sampler == BUILTIN_FAST_LP_UNIFORM_NEG_SAMPLER: - dataloader_cls = FastGSgnnLinkPredictionDataLoader - elif config.train_negative_sampler == BUILTIN_FAST_LP_JOINT_NEG_SAMPLER: - dataloader_cls = FastGSgnnLPJointNegDataLoader - elif config.train_negative_sampler == BUILTIN_FAST_LP_LOCALUNIFORM_NEG_SAMPLER: - dataloader_cls = FastGSgnnLPLocalUniformNegDataLoader - elif config.train_negative_sampler == BUILTIN_FAST_LP_LOCALJOINT_NEG_SAMPLER: - dataloader_cls = FastGSgnnLPLocalJointNegDataLoader - else: - raise ValueError('Unknown negative sampler') + train_idxs = train_data.get_edge_train_set(config.train_etype) + dataloader_cls = gs.get_lp_train_sampler(config) dataloader = dataloader_cls(train_data, train_idxs, config.fanout, config.batch_size, config.num_negative_edges, node_feats=config.node_feat_name, @@ -157,19 +111,11 @@ def main(config_args): num_hard_negs=config.num_train_hard_negatives) # TODO(zhengda) let's use full-graph inference for now. 
- if config.eval_etypes_negative_dstnode is not None: - test_dataloader_cls = GSgnnLinkPredictionPredefinedTestDataLoader - elif config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: - test_dataloader_cls = GSgnnLinkPredictionTestDataLoader - elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: - test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader - else: - raise ValueError('Unknown test negative sampler.' - 'Supported test negative samplers include ' - f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]') + val_dataloader = None test_dataloader = None val_idxs = train_data.get_edge_val_set(config.eval_etype) + test_dataloader_cls = gs.get_lp_eval_sampler(config) if len(val_idxs) > 0: if config.eval_etypes_negative_dstnode is not None: val_dataloader = test_dataloader_cls(train_data, val_idxs, diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index acf31eec57..3816a867f8 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -20,12 +20,167 @@ import graphstorm as gs from graphstorm.config import get_argument_parser from graphstorm.config import GSConfig +from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION, + BUILTIN_TASK_LINK_PREDICTION) from graphstorm.dataloading import GSgnnData +from graphstorm.dataloading import (GSgnnNodeDataLoader, + GSgnnEdgeDataLoader, + GSgnnMultiTaskDataLoader) from graphstorm.trainer import GSgnnMultiTaskLearningTrainer from graphstorm.utils import rt_profiler, sys_tracker, get_device, use_wholegraph from graphstorm.utils import get_lm_ntypes +def create_task_train_dataloader(task, config, train_data): + if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + train_idxs = train_data.get_node_train_set(config.target_ntype) + return 
GSgnnNodeDataLoader(train_data, + train_idxs, + fanout=config.fanout, + batch_size=config.batch_size, + train_task=True, + node_feats=config.node_feat_name, + label_field=config.label_field) + elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + train_idxs = train_data.get_edge_train_set(config.target_etype) + return GSgnnEdgeDataLoader(train_data, + train_idxs, + fanout=config.fanout, + batch_size=config.batch_size, + node_feats=config.node_feat_name, + label_field=config.label_field, + decoder_edge_feats=config.decoder_edge_feat, + train_task=True, + reverse_edge_types_map=config.reverse_edge_types_map, + remove_target_edge_type=config.remove_target_edge_type, + exclude_training_targets=config.exclude_training_targets) + elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: + train_idxs = train_data.get_edge_train_set(config.train_etype) + dataloader_cls = gs.get_lp_train_sampler(config) + return dataloader_cls(train_data, + train_idxs, + config.fanout, + config.batch_size, + config.num_negative_edges, + node_feats=config.node_feat_name, + pos_graph_edge_feats=config.lp_edge_weight_for_loss, + train_task=True, + reverse_edge_types_map=config.reverse_edge_types_map, + exclude_training_targets=config.exclude_training_targets, + edge_dst_negative_field=config.train_etypes_negative_dstnode, + num_hard_negs=config.num_train_hard_negatives) + + return None + +def create_task_val_dataloader(task, config, train_data): + fanout = config.eval_fanout if config.use_mini_batch_infer else [] + if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + eval_ntype = config.eval_target_ntype \ + if config.eval_target_ntype is not None else config.target_ntype + val_idxs = train_data.get_node_val_set(eval_ntype) + + if len(val_idxs) > 0: + return GSgnnNodeDataLoader(train_data, + val_idxs, + fanout=fanout, + batch_size=config.eval_batch_size, + train_task=False, + node_feats=config.node_feat_name, + 
label_field=config.label_field, + construct_feat_ntype=config.construct_feat_ntype, + construct_feat_fanout=config.construct_feat_fanout) + elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + val_idxs = train_data.get_edge_val_set(config.target_etype) + if len(val_idxs) > 0: + return GSgnnEdgeDataLoader(train_data, + val_idxs, + fanout=fanout, + batch_size=config.eval_batch_size, + node_feats=config.node_feat_name, + label_field=config.label_field, + decoder_edge_feats=config.decoder_edge_feat, + train_task=False, + reverse_edge_types_map=config.reverse_edge_types_map, + remove_target_edge_type=config.remove_target_edge_type) + elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: + val_idxs = train_data.get_edge_val_set(config.eval_etype) + dataloader_cls = gs.get_lp_eval_sampler(config) + if len(val_idxs) > 0: + if config.eval_etypes_negative_dstnode is not None: + return dataloader_cls(train_data, val_idxs, + config.eval_batch_size, + fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode, + fanout=config.eval_fanout, + fixed_test_size=config.fixed_test_size, + node_feats=config.node_feat_name, + pos_graph_edge_feats=config.lp_edge_weight_for_loss) + else: + return dataloader_cls(train_data, val_idxs, + config.eval_batch_size, + config.num_negative_edges_eval, config.eval_fanout, + fixed_test_size=config.fixed_test_size, + node_feats=config.node_feat_name, + pos_graph_edge_feats=config.lp_edge_weight_for_loss) + + return None + +def create_task_test_dataloader(task, config, train_data): + if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + eval_ntype = config.eval_target_ntype \ + if config.eval_target_ntype is not None else config.target_ntype + test_idxs = train_data.get_node_test_set(eval_ntype) + fanout = config.eval_fanout if config.use_mini_batch_infer else [] + if len(test_idxs) > 0: + return GSgnnNodeDataLoader(train_data, + test_idxs, + fanout=fanout, + 
batch_size=config.eval_batch_size, + train_task=False, + node_feats=config.node_feat_name, + label_field=config.label_field, + construct_feat_ntype=config.construct_feat_ntype, + construct_feat_fanout=config.construct_feat_fanout) + + elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + test_idxs = train_data.get_edge_test_set(config.target_etype) + if len(test_idxs) > 0: + return GSgnnEdgeDataLoader(train_data, + test_idxs, + fanout=fanout, + batch_size=config.eval_batch_size, + node_feats=config.node_feat_name, + label_field=config.label_field, + decoder_edge_feats=config.decoder_edge_feat, + train_task=False, + reverse_edge_types_map=config.reverse_edge_types_map, + remove_target_edge_type=config.remove_target_edge_type) + elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: + test_idxs = train_data.get_edge_test_set(config.eval_etype) + dataloader_cls = gs.get_lp_eval_sampler(config) + if len(test_idxs) > 0: + if config.eval_etypes_negative_dstnode is not None: + return dataloader_cls(train_data, test_idxs, + config.eval_batch_size, + fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode, + fanout=config.eval_fanout, + fixed_test_size=config.fixed_test_size, + node_feats=config.node_feat_name, + pos_graph_edge_feats=config.lp_edge_weight_for_loss) + else: + return dataloader_cls(train_data, test_idxs, + config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout, + fixed_test_size=config.fixed_test_size, + node_feats=config.node_feat_name, + pos_graph_edge_feats=config.lp_edge_weight_for_loss) + return None + +def create_task_decoder(task, config): + +def create_evaluator(task, config): + def main(config_args): """ main function """ @@ -49,10 +204,11 @@ def main(config_args): test_dataloaders = [] decoders = [] for task in tasks: - train_loader = create_task_train_dataloader(task, config, train_task=True) + train_loader = create_task_train_dataloader(task, config, train_data) + val_loader = 
create_task_val_dataloader(task, config) + test_loader = create_task_test_dataloader(task, config) decoder = create_task_decoder(task, config) - val_loader = create_task_val_dataloader(task, config, train_task=False) - test_loader = create_task_test_dataloader(task, config, train_task=False) + evaluator = create_evaluator(task, config) train_dataloader = GSgnnMultiTaskDataLoader(train_dataloaders) val_dataloader = GSgnnMultiTaskDataLoader(val_dataloaders) From 82489ed4ed16c07e11362d6491a505426d8d98c5 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 6 May 2024 11:36:22 -0700 Subject: [PATCH 03/79] Update --- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index 3816a867f8..4b55434e80 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -177,7 +177,15 @@ def create_task_test_dataloader(task, config, train_data): pos_graph_edge_feats=config.lp_edge_weight_for_loss) return None -def create_task_decoder(task, config): +def create_task_decoder(task, g, decoder_input_dim, train_task): + if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + return create_builtin_node_decoder(decoder_input_dim, task, train_task) + elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + return create_builtin_edge_decoder(g, decoder_input_dim, task, train_task) + elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: + return create_builtin_lp_decoder(g, decoder_input_dim, task, train_task) + + return None, None def create_evaluator(task, config): @@ -197,17 +205,26 @@ def main(config_args): node_feat_field=config.node_feat_name, edge_feat_field=config.edge_feat_name, lm_feat_ntypes=get_lm_ntypes(config.node_lm_configs)) + model = GSgnnMultiTaskSharedEncoderModel(config.alpha_l2norm) + 
set_encoder(model, g, config, train_task) tasks = config.multi_tasks train_dataloaders = [] val_dataloaders = [] test_dataloaders = [] decoders = [] + encoder_out_dims = model.gnn_encoder.out_dims \ + if model.gnn_encoder is not None \ + else model.node_input_encoder.out_dims for task in tasks: train_loader = create_task_train_dataloader(task, config, train_data) val_loader = create_task_val_dataloader(task, config) test_loader = create_task_test_dataloader(task, config) - decoder = create_task_decoder(task, config) + train_dataloaders.append(train_loader) + val_dataloaders.append(val_loader) + test_dataloaders.append(test_loader) + decoder, loss_func = create_task_decoder(task, g, encoder_out_dims, train_task=True) + model.add_task(task.task_id, task.task_type, decoder, loss_func, task.weight) evaluator = create_evaluator(task, config) train_dataloader = GSgnnMultiTaskDataLoader(train_dataloaders) From 3efbaa1638cb9a59c74a843d31f85b55c85e20a6 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 6 May 2024 11:36:57 -0700 Subject: [PATCH 04/79] Update gsf --- python/graphstorm/gsf.py | 251 +++++++++++++++++++++++++-------------- 1 file changed, 160 insertions(+), 91 deletions(-) diff --git a/python/graphstorm/gsf.py b/python/graphstorm/gsf.py index e0d4106737..b07f576a1d 100644 --- a/python/graphstorm/gsf.py +++ b/python/graphstorm/gsf.py @@ -234,13 +234,18 @@ def create_builtin_node_gnn_model(g, config, train_task): """ return create_builtin_node_model(g, config, train_task) -def create_builtin_node_model(g, config, train_task): - """ Create a built-in model for node prediction. +# pylint: disable=unused-argument +def create_builtin_node_decoder(g, decoder_input_dim, config, train_task): + """ create builtin node decoder according to task config Parameters ---------- g: DGLGraph - The graph used in training and testing + The graph data. + Note(xiang): Make it consistent with create_builtin_edge_decoder. + Reserved for future. 
+ decoder_input_dim: int + Input dimension size of the decoder config: GSConfig Configurations train_task : bool @@ -248,47 +253,69 @@ def create_builtin_node_model(g, config, train_task): Returns ------- - GSgnnModel : The GNN model. + decoder: The node task decoder(s) + loss_func: The loss function(s) """ - if config.training_method["name"] == "glem": - model = GLEM(config.alpha_l2norm, config.target_ntype, **config.training_method["kwargs"]) - elif config.training_method["name"] == "default": - model = GSgnnNodeModel(config.alpha_l2norm) - set_encoder(model, g, config, train_task) - + dropout = config.dropout if train_task else 0 if config.task_type == BUILTIN_TASK_NODE_CLASSIFICATION: if not isinstance(config.num_classes, dict): - model.set_decoder(EntityClassifier(model.gnn_encoder.out_dims \ - if model.gnn_encoder is not None \ - else model.node_input_encoder.out_dims, - config.num_classes, - config.multilabel)) - model.set_loss_func(ClassifyLossFunc(config.multilabel, - config.multilabel_weights, - config.imbalance_class_weights)) + decoder = EntityClassifier(decoder_input_dim, + config.num_classes, + config.multilabel, + dropout=dropout) + loss_func = ClassifyLossFunc(config.multilabel, + config.multilabel_weights, + config.imbalance_class_weights) else: decoder = {} loss_func = {} for ntype in config.target_ntype: - decoder[ntype] = EntityClassifier(model.gnn_encoder.out_dims \ - if model.gnn_encoder is not None \ - else model.node_input_encoder.out_dims, - config.num_classes[ntype], - config.multilabel[ntype]) + decoder[ntype] = EntityClassifier(decoder_input_dim, + config.num_classes[ntype], + config.multilabel[ntype], + dropout=dropout) loss_func[ntype] = ClassifyLossFunc(config.multilabel[ntype], config.multilabel_weights[ntype], config.imbalance_class_weights[ntype]) - - model.set_decoder(decoder) - model.set_loss_func(loss_func) - elif config.task_type == BUILTIN_TASK_NODE_REGRESSION: - model.set_decoder(EntityRegression(model.gnn_encoder.out_dims \ 
- if model.gnn_encoder is not None \ - else model.node_input_encoder.out_dims)) - model.set_loss_func(RegressionLossFunc()) + decoder = EntityRegression(decoder_input_dim, + dropout=dropout) + loss_func = RegressionLossFunc() else: raise ValueError('unknown node task: {}'.format(config.task_type)) + + return decoder, loss_func + + +def create_builtin_node_model(g, config, train_task): + """ Create a built-in model for node prediction. + + Parameters + ---------- + g: DGLGraph + The graph used in training and testing + config: GSConfig + Configurations + train_task : bool + Whether this model is used for training. + + Returns + ------- + GSgnnModel : The GNN model. + """ + if config.training_method["name"] == "glem": + model = GLEM(config.alpha_l2norm, config.target_ntype, **config.training_method["kwargs"]) + elif config.training_method["name"] == "default": + model = GSgnnNodeModel(config.alpha_l2norm) + set_encoder(model, g, config, train_task) + + encoder_out_dims = model.gnn_encoder.out_dims \ + if model.gnn_encoder is not None \ + else model.node_input_encoder.out_dims + decoder, loss_func = create_builtin_node_decoder(g, encoder_out_dims, config) + model.set_decoder(decoder) + model.set_loss_func(loss_func) + if train_task: model.init_optimizer(lr=config.lr, sparse_optimizer_lr=config.sparse_optimizer_lr, weight_decay=config.wd_l2norm, @@ -313,28 +340,29 @@ def create_builtin_edge_gnn_model(g, config, train_task): """ return create_builtin_edge_model(g, config, train_task) -def create_builtin_edge_model(g, config, train_task): - """ Create a model for edge prediction. +def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task): + """ create builtin edge decoder according to task config Parameters ---------- g: DGLGraph - The graph used in training and testing + The graph data. + decoder_input_dim: int + Input dimension size of the decoder. config: GSConfig - Configurations + Configurations. 
train_task : bool Whether this model is used for training. Returns ------- - GSgnnModel : The GNN model. + decoder: The node task decoder(s) + loss_func: The loss function(s) """ - model = GSgnnEdgeModel(config.alpha_l2norm) - set_encoder(model, g, config, train_task) + dropout = config.dropout if train_task else 0 if config.task_type == BUILTIN_TASK_EDGE_CLASSIFICATION: num_classes = config.num_classes decoder_type = config.decoder_type - dropout = config.dropout if train_task else 0 # TODO(zhengda) we should support multiple target etypes target_etype = config.target_etype[0] if decoder_type == "DenseBiDecoder": @@ -342,9 +370,7 @@ def create_builtin_edge_model(g, config, train_task): assert config.num_ffn_layers_in_decoder == 0, \ "DenseBiDecoder does not support adding extra feedforward neural network layers" \ "You can increases num_basis to increase the parameter size." - decoder = DenseBiDecoder(in_units=model.gnn_encoder.out_dims \ - if model.gnn_encoder is not None \ - else model.node_input_encoder.out_dims, + decoder = DenseBiDecoder(in_units=decoder_input_dim, num_classes=num_classes, multilabel=config.multilabel, num_basis=num_decoder_basis, @@ -352,9 +378,7 @@ def create_builtin_edge_model(g, config, train_task): regression=False, target_etype=target_etype) elif decoder_type == "MLPDecoder": - decoder = MLPEdgeDecoder(model.gnn_encoder.out_dims \ - if model.gnn_encoder is not None \ - else model.node_input_encoder.out_dims, + decoder = MLPEdgeDecoder(decoder_input_dim, num_classes, multilabel=config.multilabel, target_etype=target_etype, @@ -373,9 +397,7 @@ def create_builtin_edge_model(g, config, train_task): for fname in decoder_edge_feat[target_etype]]) decoder = MLPEFeatEdgeDecoder( - h_dim=model.gnn_encoder.out_dims \ - if model.gnn_encoder is not None \ - else model.node_input_encoder.out_dims, + h_dim=decoder_input_dim, feat_dim=feat_dim, out_dim=num_classes, multilabel=config.multilabel, @@ -384,10 +406,9 @@ def create_builtin_edge_model(g, 
config, train_task): num_ffn_layers=config.num_ffn_layers_in_decoder) else: assert False, f"decoder {decoder_type} is not supported." - model.set_decoder(decoder) - model.set_loss_func(ClassifyLossFunc(config.multilabel, - config.multilabel_weights, - config.imbalance_class_weights)) + loss_func = ClassifyLossFunc(config.multilabel, + config.multilabel_weights, + config.imbalance_class_weights) elif config.task_type == BUILTIN_TASK_EDGE_REGRESSION: decoder_type = config.decoder_type dropout = config.dropout if train_task else 0 @@ -395,9 +416,7 @@ def create_builtin_edge_model(g, config, train_task): target_etype = config.target_etype[0] if decoder_type == "DenseBiDecoder": num_decoder_basis = config.num_decoder_basis - decoder = DenseBiDecoder(model.gnn_encoder.out_dims \ - if model.gnn_encoder is not None \ - else model.node_input_encoder.out_dims, + decoder = DenseBiDecoder(decoder_input_dim, 1, num_basis=num_decoder_basis, multilabel=False, @@ -405,9 +424,7 @@ def create_builtin_edge_model(g, config, train_task): dropout_rate=dropout, regression=True) elif decoder_type == "MLPDecoder": - decoder = MLPEdgeDecoder(model.gnn_encoder.out_dims \ - if model.gnn_encoder is not None \ - else model.node_input_encoder.out_dims, + decoder = MLPEdgeDecoder(decoder_input_dim, 1, multilabel=False, target_etype=target_etype, @@ -426,9 +443,7 @@ def create_builtin_edge_model(g, config, train_task): for fname in decoder_edge_feat[target_etype]]) decoder = MLPEFeatEdgeDecoder( - h_dim=model.gnn_encoder.out_dims \ - if model.gnn_encoder is not None \ - else model.node_input_encoder.out_dims, + h_dim=decoder_input_dim, feat_dim=feat_dim, out_dim=1, multilabel=False, @@ -437,10 +452,36 @@ def create_builtin_edge_model(g, config, train_task): regression=True) else: assert False, "decoder not supported" - model.set_decoder(decoder) - model.set_loss_func(RegressionLossFunc()) + loss_func = RegressionLossFunc() else: raise ValueError('unknown node task: {}'.format(config.task_type)) + 
return decoder, loss_func + +def create_builtin_edge_model(g, config, train_task): + """ Create a model for edge prediction. + + Parameters + ---------- + g: DGLGraph + The graph used in training and testing + config: GSConfig + Configurations + train_task : bool + Whether this model is used for training. + + Returns + ------- + GSgnnModel : The GNN model. + """ + model = GSgnnEdgeModel(config.alpha_l2norm) + set_encoder(model, g, config, train_task) + encoder_out_dims = model.gnn_encoder.out_dims \ + if model.gnn_encoder is not None \ + else model.node_input_encoder.out_dims + decoder, loss_func = create_builtin_edge_decoder(g, encoder_out_dims, config, train_task) + model.set_decoder(decoder) + model.set_loss_func(loss_func) + if train_task: model.init_optimizer(lr=config.lr, sparse_optimizer_lr=config.sparse_optimizer_lr, weight_decay=config.wd_l2norm, @@ -465,35 +506,26 @@ def create_builtin_lp_gnn_model(g, config, train_task): """ return create_builtin_lp_model(g, config, train_task) -def create_builtin_lp_model(g, config, train_task): - """ Create a model for link prediction. +# pylint: disable=unused-argument +def create_builtin_lp_decoder(g, decoder_input_dim, config, train_task): + """ create builtin link prediction decoder according to task config Parameters ---------- g: DGLGraph - The graph used in training and testing + The graph data. + decoder_input_dim: int + Input dimension size of the decoder. config: GSConfig - Configurations + Configurations. train_task : bool Whether this model is used for training. Returns ------- - GSgnnModel : The model. 
+ decoder: The node task decoder(s) + loss_func: The loss function(s) """ - model = GSgnnLinkPredictionModel(config.alpha_l2norm, - config.lp_embed_normalizer) - set_encoder(model, g, config, train_task) - num_train_etype = len(config.train_etype) \ - if config.train_etype is not None \ - else len(g.canonical_etypes) # train_etype is None, every etype is used for training - # For backword compatibility, we add this check. - # if train etype is 1, There is no need to use DistMult - assert num_train_etype > 1 or config.lp_decoder_type == BUILTIN_LP_DOT_DECODER, \ - "If number of train etype is 1, please use dot product" - out_dims = model.gnn_encoder.out_dims \ - if model.gnn_encoder is not None \ - else model.node_input_encoder.out_dims if config.lp_decoder_type == BUILTIN_LP_DOT_DECODER: # if the training set only contains one edge type or it is specified in the arguments, # we use dot product as the score function. @@ -501,40 +533,77 @@ def create_builtin_lp_model(g, config, train_task): logging.debug('use dot product for single-etype task.') logging.debug("Using inner product objective for supervision") if config.lp_edge_weight_for_loss is None: - decoder = LinkPredictContrastiveDotDecoder(out_dims) \ + decoder = LinkPredictContrastiveDotDecoder(decoder_input_dim) \ if config.lp_loss_func == BUILTIN_LP_LOSS_CONTRASTIVELOSS else \ - LinkPredictDotDecoder(out_dims) + LinkPredictDotDecoder(decoder_input_dim) else: - decoder = LinkPredictWeightedDotDecoder(out_dims, + decoder = LinkPredictWeightedDotDecoder(decoder_input_dim, config.lp_edge_weight_for_loss) elif config.lp_decoder_type == BUILTIN_LP_DISTMULT_DECODER: if get_rank() == 0: logging.debug("Using distmult objective for supervision") if config.lp_edge_weight_for_loss is None: decoder = LinkPredictContrastiveDistMultDecoder(g.canonical_etypes, - out_dims, + decoder_input_dim, config.gamma) \ if config.lp_loss_func == BUILTIN_LP_LOSS_CONTRASTIVELOSS else \ LinkPredictDistMultDecoder(g.canonical_etypes, - 
out_dims, + decoder_input_dim, config.gamma) else: decoder = LinkPredictWeightedDistMultDecoder(g.canonical_etypes, - out_dims, + decoder_input_dim, config.gamma, config.lp_edge_weight_for_loss) else: raise Exception(f"Unknow link prediction decoder type {config.lp_decoder_type}") - model.set_decoder(decoder) + if config.lp_loss_func == BUILTIN_LP_LOSS_CONTRASTIVELOSS: - model.set_loss_func(LinkPredictContrastiveLossFunc(config.contrastive_loss_temperature)) + loss_func = LinkPredictContrastiveLossFunc(config.contrastive_loss_temperature) elif config.lp_loss_func == BUILTIN_LP_LOSS_CROSS_ENTROPY: if config.lp_edge_weight_for_loss is None: - model.set_loss_func(LinkPredictBCELossFunc()) + loss_func = LinkPredictBCELossFunc() else: - model.set_loss_func(WeightedLinkPredictBCELossFunc()) + loss_func = WeightedLinkPredictBCELossFunc() else: raise TypeError(f"Unknown link prediction loss function {config.lp_loss_func}") + + return decoder, loss_func + +def create_builtin_lp_model(g, config, train_task): + """ Create a model for link prediction. + + Parameters + ---------- + g: DGLGraph + The graph used in training and testing + config: GSConfig + Configurations + train_task : bool + Whether this model is used for training. + + Returns + ------- + GSgnnModel : The model. + """ + model = GSgnnLinkPredictionModel(config.alpha_l2norm, + config.lp_embed_normalizer) + set_encoder(model, g, config, train_task) + num_train_etype = len(config.train_etype) \ + if config.train_etype is not None \ + else len(g.canonical_etypes) # train_etype is None, every etype is used for training + # For backword compatibility, we add this check. 
+ # if train etype is 1, There is no need to use DistMult + assert num_train_etype > 1 or config.lp_decoder_type == BUILTIN_LP_DOT_DECODER, \ + "If number of train etype is 1, please use dot product" + out_dims = model.gnn_encoder.out_dims \ + if model.gnn_encoder is not None \ + else model.node_input_encoder.out_dims + decoder, loss_func = create_builtin_lp_decoder(g, out_dims, config, train_task) + + model.set_decoder(decoder) + model.set_loss_func(loss_func) + if train_task: model.init_optimizer(lr=config.lr, sparse_optimizer_lr=config.sparse_optimizer_lr, weight_decay=config.wd_l2norm, From e82da16f3712315a373ab28b67b49ebb0cc9aeaf Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 6 May 2024 12:20:41 -0700 Subject: [PATCH 05/79] Update --- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index 4b55434e80..727abbea42 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -29,6 +29,7 @@ from graphstorm.dataloading import (GSgnnNodeDataLoader, GSgnnEdgeDataLoader, GSgnnMultiTaskDataLoader) +from graphstorm.model.multitask_gnn import GSgnnMultiTaskSharedEncoderModel from graphstorm.trainer import GSgnnMultiTaskLearningTrainer from graphstorm.utils import rt_profiler, sys_tracker, get_device, use_wholegraph @@ -179,11 +180,11 @@ def create_task_test_dataloader(task, config, train_data): def create_task_decoder(task, g, decoder_input_dim, train_task): if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - return create_builtin_node_decoder(decoder_input_dim, task, train_task) + return gs.create_builtin_node_decoder(decoder_input_dim, task, train_task) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - return create_builtin_edge_decoder(g, decoder_input_dim, task, train_task) + 
return gs.create_builtin_edge_decoder(g, decoder_input_dim, task, train_task) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: - return create_builtin_lp_decoder(g, decoder_input_dim, task, train_task) + return gs.create_builtin_lp_decoder(g, decoder_input_dim, task, train_task) return None, None @@ -206,13 +207,12 @@ def main(config_args): edge_feat_field=config.edge_feat_name, lm_feat_ntypes=get_lm_ntypes(config.node_lm_configs)) model = GSgnnMultiTaskSharedEncoderModel(config.alpha_l2norm) - set_encoder(model, g, config, train_task) + gs.set_encoder(model, train_data.g, config, train_task=True) tasks = config.multi_tasks train_dataloaders = [] val_dataloaders = [] test_dataloaders = [] - decoders = [] encoder_out_dims = model.gnn_encoder.out_dims \ if model.gnn_encoder is not None \ else model.node_input_encoder.out_dims @@ -220,9 +220,9 @@ def main(config_args): train_loader = create_task_train_dataloader(task, config, train_data) val_loader = create_task_val_dataloader(task, config) test_loader = create_task_test_dataloader(task, config) - train_dataloaders.append(train_loader) - val_dataloaders.append(val_loader) - test_dataloaders.append(test_loader) + train_dataloaders.append((task, train_loader)) + val_dataloaders.append((task, val_loader)) + test_dataloaders.append((task, test_loader)) decoder, loss_func = create_task_decoder(task, g, encoder_out_dims, train_task=True) model.add_task(task.task_id, task.task_type, decoder, loss_func, task.weight) evaluator = create_evaluator(task, config) From 9a8f8b4cc1b162232dd4fdff81ac18086e6ceffd Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 6 May 2024 12:32:27 -0700 Subject: [PATCH 06/79] update --- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 93 ++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index 727abbea42..8fde0d6732 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ 
b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -29,8 +29,15 @@ from graphstorm.dataloading import (GSgnnNodeDataLoader, GSgnnEdgeDataLoader, GSgnnMultiTaskDataLoader) + +from graphstorm.eval import (GSgnnClassificationEvaluator, + GSgnnRegressionEvaluator, + GSgnnPerEtypeMrrLPEvaluator, + GSgnnMrrLPEvaluator) from graphstorm.model.multitask_gnn import GSgnnMultiTaskSharedEncoderModel from graphstorm.trainer import GSgnnMultiTaskLearningTrainer +from graphstorm.model.utils import save_full_node_embeddings +from graphstorm.model import do_full_graph_inference from graphstorm.utils import rt_profiler, sys_tracker, get_device, use_wholegraph from graphstorm.utils import get_lm_ntypes @@ -189,6 +196,57 @@ def create_task_decoder(task, g, decoder_input_dim, train_task): return None, None def create_evaluator(task, config): + if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION]: + multilabel = config.multilabel[config.eval_target_ntype] \ + if isinstance(config.multilabel, dict) else config.multilabel + return GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + multilabel, + config.use_early_stop, + config.early_stop_burnin_rounds, + config.early_stop_rounds, + config.early_stop_strategy) + + elif task.task_type in [BUILTIN_TASK_NODE_REGRESSION]: + return GSgnnRegressionEvaluator(config.eval_frequency, + config.eval_metric, + config.use_early_stop, + config.early_stop_burnin_rounds, + config.early_stop_rounds, + config.early_stop_strategy) + elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION]: + return GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + config.multilabel, + config.use_early_stop, + config.early_stop_burnin_rounds, + config.early_stop_rounds, + config.early_stop_strategy) + + elif task.task_type in [BUILTIN_TASK_EDGE_REGRESSION]: + return GSgnnRegressionEvaluator(config.eval_frequency, + config.eval_metric, + config.use_early_stop, + config.early_stop_burnin_rounds, + config.early_stop_rounds, + 
config.early_stop_strategy) + elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: + assert len(config.eval_metric) == 1, \ + "GraphStorm doees not support computing multiple metrics at the same time." + if config.report_eval_per_type: + return GSgnnPerEtypeMrrLPEvaluator(eval_frequency=config.eval_frequency, + major_etype=config.model_select_etype, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy) + else: + return GSgnnMrrLPEvaluator(eval_frequency=config.eval_frequency, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy) + return None def main(config_args): """ main function @@ -259,3 +317,38 @@ def main(config_args): if config.save_embed_path is not None: # Save node embeddings + model = GSgnnMultiTaskSharedEncoderModel(config.alpha_l2norm) + gs.set_encoder(model, train_data.g, config, train_task=True) + best_model_path = trainer.get_best_model_path() + # TODO(zhengda) the model path has to be in a shared filesystem. + model.restore_model(best_model_path) + model = model.to(get_device()) + # Preparing input layer for training or inference. + # The input layer can pre-compute node features in the preparing step if needed. 
+ # For example pre-compute all BERT embeddings + model.prepare_input_encoder(train_data) + + embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout, + edge_mask="train_mask", task_tracker=tracker) + + save_full_node_embeddings( + train_data.g, + config.save_embed_path, + embeddings, + node_id_mapping_file=config.node_id_mapping_file, + save_embed_format=config.save_embed_format) + +def generate_parser(): + """ Generate an argument parser + """ + parser = get_argument_parser() + return parser + +if __name__ == '__main__': + arg_parser = generate_parser() + + # Ignore unknown args to make script more robust to input arguments + gs_args, _ = arg_parser.parse_known_args() + main(gs_args) + + From 00f1211824791ba9f85b97d527c20a3c1c916151 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 6 May 2024 16:02:50 -0700 Subject: [PATCH 07/79] Update --- python/graphstorm/eval/evaluator.py | 134 +++++++++++++++++++++ python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 5 + python/graphstorm/trainer/mt_trainer.py | 9 +- 3 files changed, 143 insertions(+), 5 deletions(-) diff --git a/python/graphstorm/eval/evaluator.py b/python/graphstorm/eval/evaluator.py index cc5c5dce54..32e6c35d8a 100644 --- a/python/graphstorm/eval/evaluator.py +++ b/python/graphstorm/eval/evaluator.py @@ -1020,3 +1020,137 @@ def get_val_score_rank(self, val_score): # after compare, append the score into existing list self._val_perf_rank_list.append(val_score) return rank + +class GSgnnMultiTaskEvalInterface(): + """ Interface for multi-task evaluation + + The interface set the two abstract methods + """ + @abc.abstractmethod + def evaluate(self, val_results, test_results, total_iters): + """Evaluate validation and test sets for Prediciton tasks + + GSgnnTrainers will call this function to do evalution in their eval() fuction. 
+ + Parameters + ---------- + val_results: dict + Validation results in a format of {task_id: validation results} + test_results: dict + Testing results in a format of {task_id: test results} + total_iters: int + The current interation number. + + Returns + ----------- + val_scores: dict + Validation scores in a format of {task_id:cores} + test_scores: dict + Test scores in a format of {task_id:cores} + """ + + @abc.abstractmethod + def compute_score(self, results, train=True): + """ Compute evaluation score for Prediciton tasks + + Parameters + ---------- + results: dict + Eval results in format of {task_id: validation results} + train: boolean + If in model training. + + Returns + ------- + Evaluation metric values: dict + Scores for each task + """ + +class GSgnnMultiTaskEvaluator(GSgnnBaseEvaluator, GSgnnMultiTaskEvalInterface): + """ Multi-task evaluator + + Parameters + ---------- + eval_frequency: int + The frequency (# of iterations) of doing evaluation. + task_evaluators: dict + Specific evaluators for different tasks. In a format of {task_id:GSgnnBaseEvaluator} + use_early_stop: bool + Set true to use early stop. + Note(xiang): Early stop not implemented. Reserved for future. + early_stop_burnin_rounds: int + Burn-in rounds before start checking for the early stop condition. + Note(xiang): Early stop not implemented. Reserved for future. + early_stop_rounds: int + The number of rounds for validation scores used to decide early stop. + Note(xiang): Early stop not implemented. Reserved for future. + early_stop_strategy: str + The early stop strategy. GraphStorm supports two strategies: + 1) consecutive_increase and 2) average_increase. + Note(xiang): Early stop not implemented. Reserved for future. 
+ """ + # pylint: disable=unused-argument + def __init__(self, eval_frequency, task_evaluators, + use_early_stop=False, + early_stop_burnin_rounds=0, + early_stop_rounds=3, + early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY): + # nodes whose embeddings are used during evaluation + # if None all nodes are used. + self._history = [] + self.tracker = None + self._best_val_score = None + self._best_test_score = None + self._best_iter = None + + self._task_evaluators = task_evaluators + assert len(self.task_evaluators) > 1, \ + "There must be evaluators for different tasks." \ + f"But get onely get {len(self.task_evaluators)}" + + self._metric_list = { + task_id: evaluator.metric_list for task_id, evaluator in self.task_evaluators.items() + } + + self._eval_frequency = eval_frequency + # TODO(xiang): Support early stop + assert use_early_stop is False, \ + "GSgnnMultiTaskEvaluator do not support early stop now." + self._do_early_stop = use_early_stop + + # add this list to store all of the performance rank of validation scores for pick top k + self._val_perf_rank_list = [] + + + # pylint: disable=unused-argument + def do_early_stop(self, val_score): + """ Decide whether to stop the training + + Note: do not support early stop for multi-task learning. + Will support it later. + + Parameters + ---------- + val_score: float + Evaluation score + """ + raise RuntimeError("GSgnnMultiTaskEvaluator.do_early_stop not implemented") + + def get_metric_comparator(self): + """ Return the comparator of the major eval metric. + + Note: not support now. 
+ + """ + raise RuntimeError("GSgnnMultiTaskEvaluator.get_metric_comparator not implemented") + + + @property + def task_evaluators(self): + """ Task evaluators + """ + return self._task_evaluators + + @property + def val_perf_rank_list(self): + raise RuntimeError("GSgnnMultiTaskEvaluator.val_perf_rank_list not supported") \ No newline at end of file diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index 8fde0d6732..842952ac26 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -303,6 +303,11 @@ def main(config_args): else: save_model_path = None + tracker = gs.create_builtin_task_tracker(config) + if gs.get_rank() == 0: + tracker.log_params(config.__dict__) + trainer.setup_task_tracker(tracker) + trainer.fit(train_loader=train_dataloader, val_loader=val_dataloader, test_loader=test_dataloader, diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index 06b206866d..fcade02da8 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -140,7 +140,7 @@ def multi_task_mini_batch_predict( Returns ------- - list: prediction results of each task + dict: prediction results of each task """ dataloaders = loader.dataloaders task_infos = loader.task_infos @@ -325,7 +325,6 @@ def fit(self, train_loader, # training loop total_steps = 0 sys_tracker.check('start training') - g = data.g for epoch in range(num_epochs): model.train() epoch_start = time.time() @@ -469,17 +468,17 @@ def eval(self, model, data, val_loader, test_loader, total_steps, task_tracker=self.task_tracker) sys_tracker.check('compute embeddings') - val_scores = \ + val_results = \ multi_task_mini_batch_predict(model, emb, val_loader, self.device, return_proba) \ if val_loader is not None else None - test_scores = \ + test_results = \ multi_task_mini_batch_predict(model, emb, test_loader, self.device, return_proba) \ if 
test_loader is not None else None sys_tracker.check('after_test_score') val_score, test_score = self.evaluator.evaluate( - val_scores, test_scores, total_steps) + val_results, test_results, total_steps) sys_tracker.check('evaluate validation/test') model.train() From 2813d3a34ee7576386574a1b71c13a127d522257 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 6 May 2024 22:18:15 -0700 Subject: [PATCH 08/79] Update --- python/graphstorm/eval/__init__.py | 1 + python/graphstorm/eval/evaluator.py | 117 +++++++++++++++++---- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 20 +++- 3 files changed, 116 insertions(+), 22 deletions(-) diff --git a/python/graphstorm/eval/__init__.py b/python/graphstorm/eval/__init__.py index 6860effa89..2993addbe5 100644 --- a/python/graphstorm/eval/__init__.py +++ b/python/graphstorm/eval/__init__.py @@ -27,3 +27,4 @@ from .evaluator import GSgnnPerEtypeMrrLPEvaluator from .evaluator import GSgnnClassificationEvaluator from .evaluator import GSgnnRegressionEvaluator +from .evaluator import GSgnnMultiTaskEvaluator diff --git a/python/graphstorm/eval/evaluator.py b/python/graphstorm/eval/evaluator.py index 32e6c35d8a..d84069c4bc 100644 --- a/python/graphstorm/eval/evaluator.py +++ b/python/graphstorm/eval/evaluator.py @@ -1049,23 +1049,6 @@ def evaluate(self, val_results, test_results, total_iters): Test scores in a format of {task_id:cores} """ - @abc.abstractmethod - def compute_score(self, results, train=True): - """ Compute evaluation score for Prediciton tasks - - Parameters - ---------- - results: dict - Eval results in format of {task_id: validation results} - train: boolean - If in model training. 
- - Returns - ------- - Evaluation metric values: dict - Scores for each task - """ - class GSgnnMultiTaskEvaluator(GSgnnBaseEvaluator, GSgnnMultiTaskEvalInterface): """ Multi-task evaluator @@ -1134,7 +1117,7 @@ def do_early_stop(self, val_score): val_score: float Evaluation score """ - raise RuntimeError("GSgnnMultiTaskEvaluator.do_early_stop not implemented") + raise RuntimeError("GSgnnMultiTaskEvaluator.do_early_stop is not implemented") def get_metric_comparator(self): """ Return the comparator of the major eval metric. @@ -1142,8 +1125,23 @@ def get_metric_comparator(self): Note: not support now. """ - raise RuntimeError("GSgnnMultiTaskEvaluator.get_metric_comparator not implemented") + raise RuntimeError("GSgnnMultiTaskEvaluator.get_metric_comparator is not implemented") + # pylint: disable=unused-argument + def get_val_score_rank(self, val_score): + """ + Get the rank of the given validation score by comparing its values to the existing value + list. + + Note: not support now. + + Parameters + ---------- + val_score: dict + A dictionary whose key is the metric and the value is a score from evaluator's + validation computation. 
+ """ + raise RuntimeError("GSgnnMultiTaskEvaluator.get_val_score_rank is not implemented") @property def task_evaluators(self): @@ -1151,6 +1149,85 @@ """ return self._task_evaluators + @property + def best_val_score(self): + """ Best validation score + """ + best_val_score = { + task_id: evaluator.best_val_score \ + for task_id, evaluator in self.task_evaluators.items() + } + return best_val_score + + @property + def best_test_score(self): + """ Best test score + """ + best_test_score = { + task_id: evaluator.best_test_score \ + for task_id, evaluator in self.task_evaluators.items() + } + return best_test_score + + @property + def best_iter_num(self): + """ Best iteration number + """ + best_iter = { + task_id: evaluator.best_iter_num \ + for task_id, evaluator in self.task_evaluators.items() + } + return best_iter + @property def val_perf_rank_list(self): - raise RuntimeError("GSgnnMultiTaskEvaluator.val_perf_rank_list not supported") \ No newline at end of file + raise RuntimeError("GSgnnMultiTaskEvaluator.val_perf_rank_list not supported") + + def evaluate(self, val_results, test_results, total_iters): + eval_tasks = {} + val_scores = {} + test_scores = {} + + if val_results is not None: + for task_id, val_result in val_results.items(): + eval_tasks[task_id] = [val_result] + + if test_results is not None: + for task_id, test_result in test_results.items(): + if task_id in eval_tasks: + eval_tasks[task_id].append(test_result) + else: + eval_tasks[task_id] = [None, test_result] + + for task_id, eval_task in eval_tasks.items(): + if len(eval_task) == 1: + # only has validation result + eval_task.append(None) + assert len(eval_task) == 2, \ + "An evaluation task is composed of two parts: " \ + f"validation and test, but get {len(eval_task)} parts" + assert task_id in self._task_evaluators, \ + f"The evaluator of {task_id} is not defined." 
+ task_evaluator = self._task_evaluators[task_id] + + + if isinstance(task_evaluator, GSgnnPredictionEvalInterface): + val_preds, val_labels = eval_task[0] + test_preds, test_labels = eval_task[1] + val_score, test_score = task_evaluator.evaluate( + val_preds, test_preds, val_labels, test_labels, total_iters) + elif isinstance(task_evaluator, GSgnnLPRankingEvalInterface): + val_rankings = eval_task[0] + test_rankings = eval_task[1] + val_score, test_score = task_evaluator.evaluate( + val_rankings, test_rankings, total_iters) + else: + raise TypeError("Unknown evaluator") + + val_scores[task_id] = val_score + test_scores[task_id] = test_score + + self._history.append((val_scores, test_scores)) + + return val_scores, test_scores + diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index 842952ac26..d5108b5698 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -16,6 +16,7 @@ GSgnn multi-task learning """ import os +import logging import graphstorm as gs from graphstorm.config import get_argument_parser @@ -33,7 +34,8 @@ from graphstorm.eval import (GSgnnClassificationEvaluator, GSgnnRegressionEvaluator, GSgnnPerEtypeMrrLPEvaluator, - GSgnnMrrLPEvaluator) + GSgnnMrrLPEvaluator, + GSgnnMultiTaskEvaluator) from graphstorm.model.multitask_gnn import GSgnnMultiTaskSharedEncoderModel from graphstorm.trainer import GSgnnMultiTaskLearningTrainer from graphstorm.model.utils import save_full_node_embeddings @@ -271,6 +273,7 @@ def main(config_args): train_dataloaders = [] val_dataloaders = [] test_dataloaders = [] + task_evaluators = {} encoder_out_dims = model.gnn_encoder.out_dims \ if model.gnn_encoder is not None \ else model.node_input_encoder.out_dims @@ -283,12 +286,25 @@ def main(config_args): test_dataloaders.append((task, test_loader)) decoder, loss_func = create_task_decoder(task, g, encoder_out_dims, train_task=True) model.add_task(task.task_id, task.task_type, 
decoder, loss_func, task.weight) - evaluator = create_evaluator(task, config) + if not config.no_validation: + if val_loader is None: + logging.warning("The training data do not have validation set.") + if test_loader is None: + logging.warning("The training data do not have test set.") + task_evaluators[task.task_id] = \ + create_evaluator(task, config) + train_dataloader = GSgnnMultiTaskDataLoader(train_dataloaders) val_dataloader = GSgnnMultiTaskDataLoader(val_dataloaders) test_dataloader = GSgnnMultiTaskDataLoader(test_dataloaders) + if not config.no_validation: + evaluator = GSgnnMultiTaskEvaluator(config.eval_frequency, + task_evaluators, + use_early_stop=config.use_early_stop) + trainer.setup_evaluator(evaluator) + trainer = GSgnnMultiTaskLearningTrainer(model, topk_model_to_save=config.topk_model_to_save) # Preparing input layer for training or inference. From 20d306aed98958c022bf3eab89d71e79d444b117 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 8 May 2024 18:08:32 -0700 Subject: [PATCH 09/79] update --- python/graphstorm/config/argument.py | 734 ++++++++++++++++++----- python/graphstorm/config/config.py | 100 ++- training_scripts/gsgnn_mt/ml_ncr_lp_yaml | 9 +- 3 files changed, 679 insertions(+), 164 deletions(-) diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 37c5474f98..efcef085f4 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -55,6 +55,12 @@ from .config import (GRAPHSTORM_MODEL_ALL_LAYERS, GRAPHSTORM_MODEL_EMBED_LAYER, GRAPHSTORM_MODEL_DECODER_LAYER, GRAPHSTORM_MODEL_LAYER_OPTIONS) +from .config import get_mttask_id +from .config import (NodeClassTaskInfo, + NodeRegressionTaskInfo, + EdgeClassTaskInfo, + EdgeRegressionTaskInfo, + LinkPredictionTaskInfo) from ..utils import TORCH_MAJOR_VER, get_log_level, get_graph_name @@ -225,6 +231,351 @@ def set_attributes(self, configuration): for key, val in udf_family.items(): setattr(self, key, val) + def 
_parse_general_task_config(self, task_config): + """ Parse the genral task info + + Parameters + ---------- + task_config: dict + Task config + """ + assert "mask_fields" in task_config, \ + "mask_fields should be provided for each node classification task " \ + "in multi task learning" + assert "task_weight" in task_config, \ + "task_weight should be provided for each node classification task " \ + "in multi task learning" + batch_size = task_config["batch_size"] if "batch_size" in task_config else 0 + assert batch_size >= 0 # if batch size is 0, will use the global batch_size + mask_fields = task_config["mask_fields"] + assert len(mask_fields) == 3, \ + "The mask_fileds should be a list as [train-mask, validation-mask, test-mask], " \ + f"but get {mask_fields}" + task_weight = task_config["task_weight"] + assert task_weight > 0, f"task_weight should be larger than 0, but get {task_weight}" + + return batch_size, mask_fields, task_weight + + def _parse_node_classification_task(self, task_config): + """ Parse the node classification task info + + Parameters + ---------- + task_config: dict + Node classification task config + """ + task_type = BUILTIN_TASK_NODE_CLASSIFICATION + assert "target_ntype" in task_config, \ + "target_ntype should be provided for each node classification task " \ + "in multi task learning" + assert "label_field" in task_config, \ + "label_field should be provided for each node classification task " \ + "in multi task learning" + assert "num_classes" in task_config, \ + "num_classes should be provided for each node classification task " \ + "in multi task learning" + + target_ntype = task_config["target_ntype"] + label_field = task_config["label_field"] + num_classes = task_config["num_classes"] + assert num_classes > 0 + multilabel = task_config["multilabel"] \ + if "multilabel" in task_config else False + assert multilabel in [True, False] + batch_size, mask_fields, task_weight = \ + self._parse_general_task_config(task_config) + + 
multilabel_weights = task_config["multilabel_weights"] \ + if "multilabel_weights" in task_config else None + if multilabel_weights is not None: + multilabel_weights = self.check_multilabel_weights(multilabel, multilabel_weights, num_classes) + imbalance_class_weights = task_config["imbalance_class_weights"] \ + if "imbalance_class_weights" in task_config else None + if imbalance_class_weights is not None: + imbalance_class_weights = self.check_imbalance_class_weights( + imbalance_class_weights, + num_classes) + task_id = get_mttask_id(task_type=task_type, + ntype=target_ntype, + label=label_field) + eval_metric = task_config["eval_metric"] \ + if "eval_metric" in task_config else ["accuracy"] + eval_metric = self.check_classification_eval_metrics(eval_metric) + + return NodeClassTaskInfo(task_type=task_type, + task_id=task_id, + batch_size=batch_size,mask_fields=mask_fields, + task_weight=task_weight, + eval_metric=eval_metric, + target_ntype=target_ntype, + label_field=label_field, + num_classes=num_classes, + multilabel=multilabel, + multilabel_weights=multilabel_weights, + imbalance_class_weights=imbalance_class_weights) + + def _parse_node_regression_task(self, task_config): + """ Parse the node regression task info + + Parameters + ---------- + task_config: dict + Node regression task config + """ + task_type = BUILTIN_TASK_NODE_REGRESSION + assert "target_ntype" in task_config, \ + "target_ntype should be provided for each node regression task " \ + "in multi task learning" + assert "label_field" in task_config, \ + "label_field should be provided for each node regression task " \ + "in multi task learning" + target_ntype = task_config["target_ntype"] + label_field = task_config["label_field"] + + batch_size, mask_fields, task_weight = \ + self._parse_general_task_config(task_config) + task_id = get_mttask_id(task_type=task_type, + ntype=target_ntype, + label=label_field) + + eval_metric = task_config["eval_metric"] \ + if "eval_metric" in task_config else 
["rmse"] + eval_metric = self.check_regression_eval_metrics(eval_metric) + + return NodeRegressionTaskInfo(task_type=task_type, + task_id=task_id, + batch_size=batch_size, + mask_fields=mask_fields, + task_weight=task_weight, + eval_metric=eval_metric, + target_ntype=target_ntype, + label_field=label_field) + + def _parse_edge_classification_task(self, task_config): + """ Parse the edge classification task info + + Parameters + ---------- + task_config: dict + Edge classification task config + """ + task_type = BUILTIN_TASK_EDGE_CLASSIFICATION + assert "target_etype" in task_config, \ + "target_etype should be provided for each edge classification task " \ + "in multi task learning" + assert "label_field" in task_config, \ + "label_field should be provided for each node classification task " \ + "in multi task learning" + assert "num_classes" in task_config, \ + "num_classes should be provided for each node classification task " \ + "in multi task learning" + target_etype = task_config["target_etype"] + label_field = task_config["label_field"] + num_classes = task_config["num_classes"] + batch_size, mask_fields, task_weight = \ + self._parse_general_task_config(task_config) + + multilabel = task_config["multilabel"] \ + if "multilabel" in task_config else False + assert multilabel in [True, False] + multilabel_weights = task_config["multilabel_weights"] \ + if "multilabel_weights" in task_config else None + if multilabel_weights is not None: + multilabel_weights = self.check_multilabel_weights(multilabel, multilabel_weights, num_classes) + imbalance_class_weights = task_config["imbalance_class_weights"] \ + if "imbalance_class_weights" in task_config else None + if imbalance_class_weights is not None: + imbalance_class_weights = self.check_imbalance_class_weights( + imbalance_class_weights, + num_classes) + + decoder_type = task_config["decoder_type"] \ + if "decoder_type" in task_config else "DenseBiDecoder" + num_decoder_basis = 
task_config["num_decoder_basis"] \ + if "num_decoder_basis" in task_config else 2 + decoder_edge_feat = task_config["decoder_edge_feat"] \ + if "decoder_edge_feat" in task_config else None + decoder_edge_feat = self.parse_decoder_edge_feat(decoder_edge_feat) + if isinstance(decoder_edge_feat, dict): + assert len(decoder_edge_feat) == 1 and \ + list(decoder_edge_feat.keys())[0] == target_etype, \ + "In multi-task learning, we define edge regression " \ + "tasks one by one. The edge type of " \ + "decoder_edge_feat of the current regression task " \ + f"must match {target_etype}, but get {decoder_edge_feat.keys()}" + + task_id = get_mttask_id(task_type=task_type, + etype=target_etype, + label=label_field) + + eval_metric = task_config["eval_metric"] \ + if "eval_metric" in task_config else ["accuracy"] + eval_metric = self.check_classification_eval_metrics(eval_metric) + + return EdgeClassTaskInfo(task_type=task_type, + task_id=task_id, + batch_size=batch_size, + mask_fields=mask_fields, + task_weight=task_weight, + eval_metric=eval_metric, + target_etype=target_etype, + label_field=label_field, + num_classes=num_classes, + multilabel=multilabel, + multilabel_weights=multilabel_weights, + imbalance_class_weights=imbalance_class_weights, + decoder_type=decoder_type, + num_decoder_basis=num_decoder_basis, + decoder_edge_feat=decoder_edge_feat + ) + + def _parse_edge_regression_task(self, task_config): + """ Parse the edge regression task info + + Parameters + ---------- + task_config: dict + Edge regression task config + """ + task_type = BUILTIN_TASK_EDGE_REGRESSION + assert "target_etype" in task_config, \ + "target_etype should be provided for each edge regression task " \ + "in multi task learning" + assert "label_field" in task_config, \ + "label_field should be provided for each node classification task " \ + "in multi task learning" + target_etype = task_config["target_etype"] + label_field = task_config["label_field"] + + batch_size, mask_fields, task_weight = 
\ + self._parse_general_task_config(task_config) + + decoder_type = task_config["decoder_type"] \ + if "decoder_type" in task_config else "DenseBiDecoder" + num_decoder_basis = task_config["num_decoder_basis"] \ + if "num_decoder_basis" in task_config else 2 + decoder_edge_feat = task_config["decoder_edge_feat"] \ + if "decoder_edge_feat" in task_config else None + decoder_edge_feat = self.parse_decoder_edge_feat(decoder_edge_feat) + if isinstance(decoder_edge_feat, dict): + assert len(decoder_edge_feat) == 1 and \ + list(decoder_edge_feat.keys())[0] == target_etype, \ + "In multi-task learning, we define edge regression " \ + "tasks one by one. The edge type of " \ + "decoder_edge_feat of the current regression task " \ + f"must match {target_etype}, but get {decoder_edge_feat.keys()}" + + + task_id = get_mttask_id(task_type=task_type, + etype=target_etype, + label=label_field) + eval_metric = task_config["eval_metric"] \ + if "eval_metric" in task_config else ["rmse"] + eval_metric = self.check_regression_eval_metrics(eval_metric) + return EdgeRegressionTaskInfo(task_type=task_type, + task_id=task_id, + batch_size=batch_size, + mask_fields=mask_fields, + task_weight=task_weight, + eval_metric=eval_metric, + target_etype=target_etype, + label_field=label_field, + decoder_type=decoder_type, + num_decoder_basis=num_decoder_basis, + decoder_edge_feat=decoder_edge_feat) + + def _parse_link_prediction_task(self, task_config): + """ Parse the link prediction task info + + Parameters + ---------- + task_config: dict + Link prediction task config + """ + task_type = BUILTIN_TASK_LINK_PREDICTION + + batch_size, mask_fields, task_weight = \ + self._parse_general_task_config(task_config) + + train_negative_sampler = task_config["train_negative_sampler"] \ + if "train_negative_sampler" in task_config \ + else BUILTIN_LP_UNIFORM_NEG_SAMPLER + eval_negative_sampler = task_config["eval_negative_sampler"] \ + if "eval_negative_sampler" in task_config \ + else 
BUILTIN_LP_JOINT_NEG_SAMPLER + num_negative_edges = task_config["num_negative_edges"] \ + if "num_negative_edges" in task_config \ + else 16 + assert num_negative_edges > 0, \ + "Number of negative edges must larger than 0" + num_negative_edges_eval = task_config["num_negative_edges_eval"] \ + if "num_negative_edges_eval" in task_config \ + else 1000 + assert num_negative_edges_eval > 0, \ + "Number of negative edges for evaluation must larger than 0" + + train_etype = task_config["train_etype"] \ + if "train_etype" in task_config \ + else None # None means all etypes + train_etype = self.parse_lp_etype(train_etype) + eval_etype = task_config["eval_etype"] \ + if "eval_etype" in task_config \ + else None # None means all etypes + eval_etype = self.parse_lp_etype(eval_etype) + reverse_edge_types_map = task_config["reverse_edge_types_map"] \ + if "reverse_edge_types_map" in task_config \ + else {} + reverse_edge_types_map = self.parse_reverse_edge_type_map(reverse_edge_types_map) + exclude_training_targets = task_config["exclude_training_targets"] \ + if "exclude_training_targets" in task_config \ + else True + assert exclude_training_targets in [True, False] + if exclude_training_targets is True: + assert len(reverse_edge_types_map) > 0, \ + "By default, exclude training targets is used." \ + "Reverse edge types map must be provided." + + lp_loss_func = task_config["lp_loss_func"] \ + if "lp_loss_func" in task_config \ + else BUILTIN_LP_LOSS_CROSS_ENTROPY + assert lp_loss_func in BUILTIN_LP_LOSS_FUNCTION + lp_decoder_type = task_config["lp_decoder_type"] \ + if "lp_decoder_type" in task_config \ + else BUILTIN_LP_DISTMULT_DECODER + assert lp_decoder_type in SUPPORTED_LP_DECODER, \ + f"Link prediction decoder {lp_decoder_type} not supported. 
" \ + f"GraphStorm only supports {SUPPORTED_LP_DECODER}" + gamma = task_config["gamma"] \ + if "gamma" in task_config \ + else 12.0 + + + report_eval_per_type = task_config["report_eval_per_type"] \ + if "report_eval_per_type" in task_config \ + else False + assert report_eval_per_type in [True, False], \ + "report_eval_per_type must be True or False" + + task_id = get_mttask_id( + task_type=task_type, + etype=train_etype if train_etype is not None else "ALL_ETYPE") + eval_metric = task_config["eval_metric"] \ + if "eval_metric" in task_config else ["mrr"] + eval_metric = self.check_lp_eval_metrics(eval_metric) + return LinkPredictionTaskInfo(task_type=task_type, + task_id=task_id, + batch_size=batch_size, + mask_fields=mask_fields, + task_weight=task_weight, + eval_metric=eval_metric, + train_etype=train_etype, + eval_etype=eval_etype, + train_negative_sampler=train_negative_sampler, + eval_negative_sampler=eval_negative_sampler, + num_negative_edges=num_negative_edges, + num_negative_edges_eval=num_negative_edges_eval, + ) + def _parse_multi_tasks(self, multi_task_config): """ Parse multi-task configuration """ @@ -1507,6 +1858,20 @@ def check_multilabel(multilabel): return check_multilabel(self._multilabel) return False + @staticmethod + def check_multilabel_weights(multilabel, multilabel_weights, num_classes): + assert multilabel is True, "Must be a multi-label classification task." 
+ try: + weights = multilabel_weights.split(",") + weights = [float(w) for w in weights] + except Exception: # pylint: disable=broad-except + raise RuntimeError("The weights should in following format 0.1,0.2,0.1,0.0") + for w in weights: + assert w >= 0., "multilabel weights can not be negative values" + assert len(weights) == num_classes, \ + "Each class must have an assigned weight" + return th.tensor(weights) + @property def multilabel_weights(self): """Used to specify label weight of each class in a @@ -1515,20 +1880,6 @@ def multilabel_weights(self): The weights should be in the following format 0.1,0.2,0.3,0.1,0.0 """ - - def check_multilabel_weights(multilabel, multilabel_weights, num_classes): - assert multilabel is True, "Must be a multi-label classification task." - try: - weights = multilabel_weights.split(",") - weights = [float(w) for w in weights] - except Exception: # pylint: disable=broad-except - raise RuntimeError("The weights should in following format 0.1,0.2,0.1,0.0") - for w in weights: - assert w >= 0., "multilabel weights can not be negative values" - assert len(weights) == num_classes, \ - "Each class must have an assigned weight" - return th.tensor(weights) - if hasattr(self, "_num_classes") and isinstance(self.num_classes, dict): if hasattr(self, "_multilabel_weights"): multilabel = self.multilabel @@ -1537,18 +1888,19 @@ def check_multilabel_weights(multilabel, multilabel_weights, num_classes): ntype_weights = {} for ntype in num_classes: if ntype in multilabel_weights: - ntype_weights[ntype] = check_multilabel_weights(multilabel[ntype], - multilabel_weights[ntype], - num_classes[ntype]) + ntype_weights[ntype] = self.check_multilabel_weights( + multilabel[ntype], + multilabel_weights[ntype], + num_classes[ntype]) else: ntype_weights[ntype] = None return ntype_weights return {ntype: None for ntype in self.num_classes} else: if hasattr(self, "_multilabel_weights"): - return check_multilabel_weights(self.multilabel, - 
self._multilabel_weights, - self.num_classes) + return self.check_multilabel_weights(self.multilabel, + self._multilabel_weights, + self.num_classes) return None @@ -1570,6 +1922,19 @@ def return_proba(self): # By default, return all the predictions return True + @staticmethod + def check_imbalance_class_weights(imbalance_class_weights, num_classes): + try: + weights = imbalance_class_weights.split(",") + weights = [float(w) for w in weights] + except Exception: # pylint: disable=broad-except + raise RuntimeError("The weights should in following format 0.1,0.2,0.3,0.1") + for w in weights: + assert w > 0., "Each weight should be larger than 0." + assert len(weights) == num_classes, \ + "Each class must have an assigned weight" + return th.tensor(weights) + @property def imbalance_class_weights(self): """ Used to specify a manual rescaling weight given to each class @@ -1579,19 +1944,6 @@ def imbalance_class_weights(self): Customer should provide the weight in following format 0.1,0.2,0.3,0.1 """ - - def check_imbalance_class_weights(imbalance_class_weights, num_classes): - try: - weights = imbalance_class_weights.split(",") - weights = [float(w) for w in weights] - except Exception: # pylint: disable=broad-except - raise RuntimeError("The weights should in following format 0.1,0.2,0.3,0.1") - for w in weights: - assert w > 0., "Each weight should be larger than 0." 
- assert len(weights) == num_classes, \ - "Each class must have an assigned weight" - return th.tensor(weights) - if hasattr(self, "_num_classes") and isinstance(self.num_classes, dict): if hasattr(self, "_imbalance_class_weights"): assert isinstance(self._imbalance_class_weights, dict), \ @@ -1601,7 +1953,7 @@ def check_imbalance_class_weights(imbalance_class_weights, num_classes): ntype_weights = {} for ntype in num_classes: if ntype in imbalance_class_weights: - ntype_weights[ntype] = check_imbalance_class_weights( + ntype_weights[ntype] = self.check_imbalance_class_weights( imbalance_class_weights[ntype], num_classes[ntype] ) @@ -1611,8 +1963,8 @@ def check_imbalance_class_weights(imbalance_class_weights, num_classes): return {ntype: None for ntype in self.num_classes} else: if hasattr(self, "_imbalance_class_weights"): - return check_imbalance_class_weights(self._imbalance_class_weights, - self.num_classes) + return self.check_imbalance_class_weights(self._imbalance_class_weights, + self.num_classes) return None ###classification/regression inference related #### @@ -1661,6 +2013,40 @@ def eval_target_ntype(self): return None #### edge related task variables #### + @staticmethod + def parse_reverse_edge_type_map(reverse_edge_types_map): + """ Parse the reverse edge type map. 
+ + Parameters + ---------- + reverse_edge_types_map: list + A list one etype and reverse edge type maps + in the format of head,relation,reverse relation,tail + + Return + ------ + A map: dict + """ + if reverse_edge_types_map is None: + return {} # empty dict + assert isinstance(reverse_edge_types_map, list), \ + "Reverse edge type map should has following format: " \ + "[\"head,relation,reverse relation,tail\", " \ + "\"head,relation,reverse relation,tail\", ...]" + + reverse_map = {} + try: + for etype_info in reverse_edge_types_map: + head, rel, rev_rel, tail = etype_info.split(",") + reverse_map[(head, rel, tail)] = (tail, rev_rel, head) + except Exception: # pylint: disable=broad-except + assert False, \ + "Reverse edge type map should has following format: " \ + "[\"head,relation,reverse relation,tail\", " \ + "\"head,relation,reverse relation,tail\", ...]" \ + f"But get {reverse_edge_types_map}" + return reverse_map + @property def reverse_edge_types_map(self): """ A list of reverse edge type info. 
@@ -1678,26 +2064,7 @@ def reverse_edge_types_map(self): # pylint: disable=no-member if hasattr(self, "_reverse_edge_types_map"): - if self._reverse_edge_types_map is None: - return {} # empty dict - assert isinstance(self._reverse_edge_types_map, list), \ - "Reverse edge type map should has following format: " \ - "[\"head,relation,reverse relation,tail\", " \ - "\"head,relation,reverse relation,tail\", ...]" - - reverse_edge_types_map = {} - try: - for etype_info in self._reverse_edge_types_map: - head, rel, rev_rel, tail = etype_info.split(",") - reverse_edge_types_map[(head, rel, tail)] = (tail, rev_rel, head) - except Exception: # pylint: disable=broad-except - assert False, \ - "Reverse edge type map should has following format: " \ - "[\"head,relation,reverse relation,tail\", " \ - "\"head,relation,reverse relation,tail\", ...]" \ - f"But get {self._reverse_edge_types_map}" - - return reverse_edge_types_map + return self.parse_reverse_edge_type_map(self._reverse_edge_types_map) # By default return an empty dict return {} @@ -1770,6 +2137,31 @@ def num_decoder_basis(self): # By default, return 2 return 2 + @staticmethod + def parse_decoder_edge_feat(decoder_edge_feats): + """ Parse decoder edge feat + + Parameter + --------- + decoder_edge_feats: list + Edge features that will be used by the decoder. 
+ """ + assert len(decoder_edge_feats) == 1, \ + "We only support edge classifcation or regression on one edge type" + + if ":" not in decoder_edge_feats[0]: + # global feat_name + return decoder_edge_feats[0] + + # per edge type feature + feat_name = decoder_edge_feats[0] + feat_info = feat_name.split(":") + assert len(feat_info) == 2, \ + f"Unknown format of the feature name: {feat_name}, " + \ + "must be EDGE_TYPE:FEAT_NAME" + etype = tuple(feat_info[0].split(",")) + return {etype: feat_info[1].split(",")} + @property def decoder_edge_feat(self): """ A list of edge features that can be used by a decoder to @@ -1781,28 +2173,15 @@ def decoder_edge_feat(self): (BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION), \ "Decoder edge feature only works with " \ "edge classification or regression tasks" - decoder_edge_feats = self._decoder_edge_feat - assert len(decoder_edge_feats) == 1, \ - "We only support edge classifcation or regression on one edge type" - - if ":" not in decoder_edge_feats[0]: - # global feat_name - return decoder_edge_feats[0] - - # per edge type feature - feat_name = decoder_edge_feats[0] - feat_info = feat_name.split(":") - assert len(feat_info) == 2, \ - f"Unknown format of the feature name: {feat_name}, " + \ - "must be EDGE_TYPE:FEAT_NAME" - etype = tuple(feat_info[0].split(",")) - assert etype in self.target_etype, \ - f"{etype} must in the training edge type list {self.target_etype}" - return {etype: feat_info[1].split(",")} + decoder_edge_feat = self.parse_decoder_edge_feat(self._decoder_edge_feat) + if isinstance(decoder_edge_feat, dict): + for etype in decoder_edge_feat.keys(): + assert etype in self.target_etype, \ + f"{etype} must in the training edge type list {self.target_etype}" + return decoder_edge_feat return None - ### Link Prediction specific ### @property def train_negative_sampler(self): @@ -1846,7 +2225,7 @@ def num_negative_edges_eval(self): # pylint: disable=no-member if hasattr(self, 
"_num_negative_edges_eval"): assert self._num_negative_edges_eval > 0, \ - "Number of negative edges must larger than 0" + "Number of negative edges for evaluation must larger than 0" return self._num_negative_edges_eval # Set default value to 1000. return 1000 @@ -2070,6 +2449,22 @@ def eval_etypes_negative_dstnode(self): # By default fixed negative is not used return None + @staticmethod + def parse_lp_etype(etypes): + """ Parse and validate the input link prediction etypes + + Parameters + ---------- + etypes: str + Edge types + """ + if etypes is None: + return None + assert isinstance(etypes, list) + assert len(etypes) > 0 + + return [tuple(etype.split(',')) for etype in etypes] + @property def train_etype(self): """ The list of canonical etypes that will be added as @@ -2079,12 +2474,7 @@ def train_etype(self): """ # pylint: disable=no-member if hasattr(self, "_train_etype"): - if self._train_etype is None: - return None - assert isinstance(self._train_etype, list) - assert len(self._train_etype) > 0 - - return [tuple(train_etype.split(',')) for train_etype in self._train_etype] + return self.parse_lp_etype(self._train_etype) # By default return None, which means use all edge types return None @@ -2097,11 +2487,8 @@ def eval_etype(self): """ # pylint: disable=no-member if hasattr(self, "_eval_etype"): - if self._eval_etype is None: - return None - assert isinstance(self._eval_etype, list) - assert len(self._eval_etype) > 0 - return [tuple(eval_etype.split(',')) for eval_etype in self._eval_etype] + return self.parse_lp_etype(self._eval_etype) + # By default return None, which means use all edge types return None @@ -2175,6 +2562,116 @@ def report_eval_per_type(self): return False + @staticmethod + def check_classification_eval_metrics(eval_metric): + """ Check the classification evaluation metrics + + Parameter + --------- + eval_metric: str or list of str + Evaluation metric(s). + + Return + ------ + list of str + Evaluation metric(s). 
+ """ + if isinstance(eval_metric, str): + eval_metric = eval_metric.lower() + assert eval_metric in SUPPORTED_CLASSIFICATION_METRICS, \ + f"Classification evaluation metric should be " \ + f"in {SUPPORTED_CLASSIFICATION_METRICS}" \ + f"but get {eval_metric}" + eval_metrics = [eval_metric] + elif isinstance(eval_metric, list) and len(eval_metric) > 0: + eval_metrics = [] + for metric in eval_metric: + metric = metric.lower() + assert metric in SUPPORTED_CLASSIFICATION_METRICS, \ + f"Classification evaluation metric should be " \ + f"in {SUPPORTED_CLASSIFICATION_METRICS}" \ + f"but get {eval_metric}" + eval_metrics.append(metric) + else: + assert False, "Classification evaluation metric " \ + "should be a string or a list of string" + # no eval_metric + return eval_metrics + + @staticmethod + def check_regression_eval_metrics(eval_metric): + """ Check the regression evaluation metrics + + Parameter + --------- + eval_metric: str or list of str + Evaluation metric(s). + + Return + ------ + list of str + Evaluation metric(s). + """ + if isinstance(eval_metric, str): + eval_metric = eval_metric.lower() + assert eval_metric in SUPPORTED_REGRESSION_METRICS, \ + f"Regression evaluation metric should be " \ + f"in {SUPPORTED_REGRESSION_METRICS}, " \ + f"but get {eval_metric}" + eval_metrics = [eval_metric] + elif isinstance(eval_metric, list) and len(eval_metric) > 0: + eval_metrics = [] + for metric in eval_metric: + metric = metric.lower() + assert metric in SUPPORTED_REGRESSION_METRICS, \ + f"Regression evaluation metric should be " \ + f"in {SUPPORTED_REGRESSION_METRICS}" \ + f"but get {eval_metric}" + eval_metrics.append(metric) + else: + assert False, "Regression evaluation metric " \ + "should be a string or a list of string" + # no eval_metric + + return eval_metrics + + @staticmethod + def check_lp_eval_metrics(eval_metric): + """ Check the link prediction evaluation metrics + + Parameter + --------- + eval_metric: str or list of str + Evaluation metric(s). 
+ + Return + ------ + list of str + Evaluation metric(s). + """ + if isinstance(eval_metric, str): + eval_metric = eval_metric.lower() + assert eval_metric in SUPPORTED_LINK_PREDICTION_METRICS, \ + f"Link prediction evaluation metric should be " \ + f"in {SUPPORTED_LINK_PREDICTION_METRICS}" \ + f"but get {eval_metric}" + eval_metrics = [eval_metric] + elif isinstance(eval_metric, list) and len(eval_metric) > 0: + eval_metrics = [] + for metric in eval_metric: + metric = metric.lower() + assert metric in SUPPORTED_LINK_PREDICTION_METRICS, \ + f"Link prediction evaluation metric should be " \ + f"in {SUPPORTED_LINK_PREDICTION_METRICS}" \ + f"but get {eval_metric}" + eval_metrics.append(metric) + else: + assert False, "Link prediction evaluation metric " \ + "should be a string or a list of string" + # no eval_metric + + return eval_metrics + @property def eval_metric(self): """ Evaluation metric used during evaluation @@ -2196,75 +2693,18 @@ def eval_metric(self): # check evaluation metrics if hasattr(self, "_eval_metric"): - if isinstance(self._eval_metric, str): - eval_metric = self._eval_metric.lower() - assert eval_metric in SUPPORTED_CLASSIFICATION_METRICS, \ - f"Classification evaluation metric should be " \ - f"in {SUPPORTED_CLASSIFICATION_METRICS}" \ - f"but get {self._eval_metric}" - eval_metric = [eval_metric] - elif isinstance(self._eval_metric, list) and len(self._eval_metric) > 0: - eval_metric = [] - for metric in self._eval_metric: - metric = metric.lower() - assert metric in SUPPORTED_CLASSIFICATION_METRICS, \ - f"Classification evaluation metric should be " \ - f"in {SUPPORTED_CLASSIFICATION_METRICS}" \ - f"but get {self._eval_metric}" - eval_metric.append(metric) - else: - assert False, "Classification evaluation metric " \ - "should be a string or a list of string" - # no eval_metric + eval_metric = self.check_classification_eval_metrics(self._eval_metric) else: eval_metric = ["accuracy"] elif self.task_type in [BUILTIN_TASK_NODE_REGRESSION, \ 
BUILTIN_TASK_EDGE_REGRESSION]: if hasattr(self, "_eval_metric"): - if isinstance(self._eval_metric, str): - eval_metric = self._eval_metric.lower() - assert eval_metric in SUPPORTED_REGRESSION_METRICS, \ - f"Regression evaluation metric should be " \ - f"in {SUPPORTED_REGRESSION_METRICS}, " \ - f"but get {self._eval_metric}" - eval_metric = [eval_metric] - elif isinstance(self._eval_metric, list) and len(self._eval_metric) > 0: - eval_metric = [] - for metric in self._eval_metric: - metric = metric.lower() - assert metric in SUPPORTED_REGRESSION_METRICS, \ - f"Regression evaluation metric should be " \ - f"in {SUPPORTED_REGRESSION_METRICS}" \ - f"but get {self._eval_metric}" - eval_metric.append(metric) - else: - assert False, "Regression evaluation metric " \ - "should be a string or a list of string" - # no eval_metric + eval_metric = self.check_regression_eval_metrics(self._eval_metric) else: eval_metric = ["rmse"] elif self.task_type == BUILTIN_TASK_LINK_PREDICTION: if hasattr(self, "_eval_metric"): - if isinstance(self._eval_metric, str): - eval_metric = self._eval_metric.lower() - assert eval_metric in SUPPORTED_LINK_PREDICTION_METRICS, \ - f"Link prediction evaluation metric should be " \ - f"in {SUPPORTED_LINK_PREDICTION_METRICS}" \ - f"but get {self._eval_metric}" - eval_metric = [eval_metric] - elif isinstance(self._eval_metric, list) and len(self._eval_metric) > 0: - eval_metric = [] - for metric in self._eval_metric: - metric = metric.lower() - assert metric in SUPPORTED_LINK_PREDICTION_METRICS, \ - f"Link prediction evaluation metric should be " \ - f"in {SUPPORTED_LINK_PREDICTION_METRICS}" \ - f"but get {self._eval_metric}" - eval_metric.append(metric) - else: - assert False, "Link prediction evaluation metric " \ - "should be a string or a list of string" - # no eval_metric + eval_metric = self.check_lp_eval_metrics(self._eval_metric) else: eval_metric = ["mrr"] else: diff --git a/python/graphstorm/config/config.py 
b/python/graphstorm/config/config.py index 417392071d..8b8b46d5a3 100644 --- a/python/graphstorm/config/config.py +++ b/python/graphstorm/config/config.py @@ -76,6 +76,25 @@ SUPPORTED_LP_DECODER = [BUILTIN_LP_DOT_DECODER, BUILTIN_LP_DISTMULT_DECODER] ################ Task info data classes ############################ +def get_mttask_id(task_type, ntype=None, etype=None, label=None): + task_id = [task_type] + if ntype is not None: + task_id.append(ntype) # node task + if etype is not None: + if isinstance(etype, str): + task_id.append(etype) + elif isinstance(etype, tuple): + task_id.append("_".join(etype)) + elif isinstance(etype, list): # a list of etypes + task_id.append("__".join(["_".join(et) for et in etype])) + else: + raise TypeError(f"Unknown etype format: {etype}. Must be a string " \ + "or a tuple of strings or a list of tuples of strings.") + if label is not None: + task_id.append(label) + + return "-".join(task_id) + @dataclasses.dataclass class TaskInfo: """Information of a training task in multi-task learning @@ -83,19 +102,74 @@ class TaskInfo: Parameters ---------- task_type: str - Task type - node_type: str - Node type of the task, if it is a node task - edge_type: tuple of strs - Edge type of the task, if it is a edge task - node_label_field: str - Node label field - edge_label_field: str - Edge label field + Task type. + task_id: str + Task id. Unique id for each task. + batch_size: int + Batch size of the current task. + mask_fields: list + Train/validation/test mask fields. + dataloader: + Task dataloader. + eval_metric: list + Evaluation metrics + task_weight: float + Weight of the task in final loss.
""" task_type : str - node_type : str = None - edge_type : tuple = None - node_label_field : str = None - edge_label_field : str = None + task_id : str dataloader = None # dataloder + batch_size: int = 0 + mask_fields: list + task_weight: float + eval_metric : list + +@dataclasses.dataclass +class NodeClassTaskInfo(TaskInfo): + target_ntype : str + label_field : str + num_classes: str + multilabel: bool = False + multilabel_weights: str = None + imbalance_class_weights: str = None + + +@dataclasses.dataclass +class NodeRegressionTaskInfo(TaskInfo): + target_ntype : str + label_field : str + +@dataclasses.dataclass +class EdgeClassTaskInfo(TaskInfo): + target_etype : tuple + label_field : str + num_classes : str + multilabel: bool = False + multilabel_weights: str = None + imbalance_class_weights: str = None + decoder_type : str + num_decoder_basis : int + decoder_edge_feat : dict + +@dataclasses.dataclass +class EdgeRegressionTaskInfo(TaskInfo): + target_etype : tuple + label_field : str + decoder_type : str + num_decoder_basis : int + decoder_edge_feat : dict + +@dataclasses.dataclass +class LinkPredictionTaskInfo(TaskInfo): + train_etype : list + eval_etype : list + train_negative_sampler : str + eval_negative_sampler : str + num_negative_edges : int + num_negative_edges_eval : int + reverse_edge_types_map : dict + exclude_training_targets : bool + lp_loss_func : str + lp_decoder_type : str + gamma : float + report_eval_per_type : bool diff --git a/training_scripts/gsgnn_mt/ml_ncr_lp_yaml b/training_scripts/gsgnn_mt/ml_ncr_lp_yaml index 66f56da955..55b740a6c7 100644 --- a/training_scripts/gsgnn_mt/ml_ncr_lp_yaml +++ b/training_scripts/gsgnn_mt/ml_ncr_lp_yaml @@ -40,6 +40,8 @@ gsf: - "val_mask_field_nc" - "test_mask_field_nc" task_weight: 1.0 + eval_metric: + - "accuracy" - edge_classification: target_etype: - "user,rating,movie" @@ -49,7 +51,6 @@ gsf: multilabel: false num_classes: 5 num_decoder_basis: 32 - exclude_training_targets: false batch_size: 10 # will 
overwrite the global batch_size mask_fields: - "train_mask_field_ec" @@ -64,11 +65,11 @@ gsf: - "user,rating,movie" train_etype: - "user,rating,movie" - exclude_training_targets: false + exclude_training_targets: true reverse_edge_types_map: [] batch_size: 10 # will overwrite the global batch_size mask_fields: - "train_mask_field_lp" - - "" # empty means there is no validation mask - - "" # empty means there is no test mask + - null # empty means there is no validation mask + - null # empty means there is no test mask task_weight: 1.0 \ No newline at end of file From 9613c6aecca1a8efd9dfc7dac471b94d2b4f4568 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 8 May 2024 18:08:53 -0700 Subject: [PATCH 10/79] Update --- python/graphstorm/config/argument.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index efcef085f4..e06fe56fcb 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -574,7 +574,12 @@ def _parse_link_prediction_task(self, task_config): eval_negative_sampler=eval_negative_sampler, num_negative_edges=num_negative_edges, num_negative_edges_eval=num_negative_edges_eval, - ) + reverse_edge_types_map=reverse_edge_types_map, + exclude_training_targets=exclude_training_targets, + lp_loss_func=lp_loss_func, + lp_decoder_type=lp_decoder_type, + gamma=gamma, + report_eval_per_type=report_eval_per_type) def _parse_multi_tasks(self, multi_task_config): """ Parse multi-task configuration From 166d9d5300d2b440f0e2f4ad4d266fcc9d2cb1e5 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 8 May 2024 23:18:31 -0700 Subject: [PATCH 11/79] Update --- python/graphstorm/config/argument.py | 814 ++++++++++----------------- python/graphstorm/config/config.py | 55 +- 2 files changed, 311 insertions(+), 558 deletions(-) diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index e06fe56fcb..2fcdb9aa96 
100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -56,11 +56,7 @@ from .config import (GRAPHSTORM_MODEL_ALL_LAYERS, GRAPHSTORM_MODEL_EMBED_LAYER, GRAPHSTORM_MODEL_DECODER_LAYER, GRAPHSTORM_MODEL_LAYER_OPTIONS) from .config import get_mttask_id -from .config import (NodeClassTaskInfo, - NodeRegressionTaskInfo, - EdgeClassTaskInfo, - EdgeRegressionTaskInfo, - LinkPredictionTaskInfo) +from .config import TaskInfo from ..utils import TORCH_MAJOR_VER, get_log_level, get_graph_name @@ -231,6 +227,30 @@ def set_attributes(self, configuration): for key, val in udf_family.items(): setattr(self, key, val) + def set_task_attributes(self, configuration): + """ Set graph task specific attributes + + This function is called when GSConfig is used to + store graph task specific information in multi-task learning. + + .. code:: python + + task_info = GSConfig.__new__(GSConfig) + task_info.set_task_attributes(task_config) + + target_ntype = task_info.target_ntype + + By reusing GSConfig object, we can use the same code base + for single task learning and multi-task learning. 
+ + Parameters + ---------- + configuration: dict + Task specific config + """ + for key, val in configuration.items(): + setattr(self, key, val) + def _parse_general_task_config(self, task_config): """ Parse the genral task info @@ -245,8 +265,7 @@ def _parse_general_task_config(self, task_config): assert "task_weight" in task_config, \ "task_weight should be provided for each node classification task " \ "in multi task learning" - batch_size = task_config["batch_size"] if "batch_size" in task_config else 0 - assert batch_size >= 0 # if batch size is 0, will use the global batch_size + mask_fields = task_config["mask_fields"] assert len(mask_fields) == 3, \ "The mask_fileds should be a list as [train-mask, validation-mask, test-mask], " \ @@ -254,7 +273,7 @@ def _parse_general_task_config(self, task_config): task_weight = task_config["task_weight"] assert task_weight > 0, f"task_weight should be larger than 0, but get {task_weight}" - return batch_size, mask_fields, task_weight + return mask_fields, task_weight def _parse_node_classification_task(self, task_config): """ Parse the node classification task info @@ -265,54 +284,26 @@ def _parse_node_classification_task(self, task_config): Node classification task config """ task_type = BUILTIN_TASK_NODE_CLASSIFICATION - assert "target_ntype" in task_config, \ - "target_ntype should be provided for each node classification task " \ - "in multi task learning" - assert "label_field" in task_config, \ - "label_field should be provided for each node classification task " \ - "in multi task learning" - assert "num_classes" in task_config, \ - "num_classes should be provided for each node classification task " \ - "in multi task learning" + task_info = GSConfig.__new__(GSConfig) + task_info.set_task_attributes(task_config) + task_info.verify_node_class_arguments() - target_ntype = task_config["target_ntype"] - label_field = task_config["label_field"] - num_classes = task_config["num_classes"] - assert num_classes > 0 - 
multilabel = task_config["multilabel"] \ - if "multilabel" in task_config else False - assert multilabel in [True, False] - batch_size, mask_fields, task_weight = \ + mask_fields, task_weight = \ self._parse_general_task_config(task_config) + target_ntype = task_info.target_ntype + label_field = task_info.label_field - multilabel_weights = task_config["multilabel_weights"] \ - if "multilabel_weights" in task_config else None - if multilabel_weights is not None: - multilabel_weights = self.check_multilabel_weights(multilabel, multilabel_weights, num_classes) - imbalance_class_weights = task_config["imbalance_class_weights"] \ - if "imbalance_class_weights" in task_config else None - if imbalance_class_weights is not None: - imbalance_class_weights = self.check_imbalance_class_weights( - imbalance_class_weights, - num_classes) task_id = get_mttask_id(task_type=task_type, ntype=target_ntype, label=label_field) - eval_metric = task_config["eval_metric"] \ - if "eval_metric" in task_config else ["accuracy"] - eval_metric = self.check_classification_eval_metrics(eval_metric) - - return NodeClassTaskInfo(task_type=task_type, - task_id=task_id, - batch_size=batch_size,mask_fields=mask_fields, - task_weight=task_weight, - eval_metric=eval_metric, - target_ntype=target_ntype, - label_field=label_field, - num_classes=num_classes, - multilabel=multilabel, - multilabel_weights=multilabel_weights, - imbalance_class_weights=imbalance_class_weights) + setattr(task_info, "task_type", task_type) + setattr(task_info, "mask_fields", mask_fields) + setattr(task_info, "task_weight", task_weight) + setattr(task_info, "task_id", task_id) + + return TaskInfo(task_type=task_type, + task_id=task_id, + task_info=task_info) def _parse_node_regression_task(self, task_config): """ Parse the node regression task info @@ -323,33 +314,26 @@ def _parse_node_regression_task(self, task_config): Node regression task config """ task_type = BUILTIN_TASK_NODE_REGRESSION - assert "target_ntype" in 
task_config, \ - "target_ntype should be provided for each node regression task " \ - "in multi task learning" - assert "label_field" in task_config, \ - "label_field should be provided for each node regression task " \ - "in multi task learning" - target_ntype = task_config["target_ntype"] - label_field = task_config["label_field"] + task_info = GSConfig.__new__(GSConfig) + task_info.set_task_attributes(task_config) + task_info.verify_node_regression_arguments() - batch_size, mask_fields, task_weight = \ + mask_fields, task_weight = \ self._parse_general_task_config(task_config) + target_ntype = task_info.target_ntype + label_field = task_info.label_field + task_id = get_mttask_id(task_type=task_type, ntype=target_ntype, label=label_field) + setattr(task_info, "task_type", task_type) + setattr(task_info, "mask_fields", mask_fields) + setattr(task_info, "task_weight", task_weight) + setattr(task_info, "task_id", task_id) - eval_metric = task_config["eval_metric"] \ - if "eval_metric" in task_config else ["accuracy"] - eval_metric = self.check_regression_eval_metrics(eval_metric) - - return NodeRegressionTaskInfo(task_type=task_type, - task_id=task_id, - batch_size=batch_size, - mask_fields=mask_fields, - task_weight=task_weight, - eval_metric=eval_metric, - target_ntype=target_ntype, - label_field=label_field) + return TaskInfo(task_type=task_type, + task_id=task_id, + task_info=task_info) def _parse_edge_classification_task(self, task_config): """ Parse the edge classification task info @@ -360,74 +344,25 @@ def _parse_edge_classification_task(self, task_config): Edge classification task config """ task_type = BUILTIN_TASK_EDGE_CLASSIFICATION - assert "target_etype" in task_config, \ - "target_etype should be provided for each edge classification task " \ - "in multi task learning" - assert "label_field" in task_config, \ - "label_field should be provided for each node classification task " \ - "in multi task learning" - assert "num_classes" in task_config, \ - 
"num_classes should be provided for each node classification task " \ - "in multi task learning" - target_etype = task_config["target_etype"] - label_field = task_config["label_field"] - num_classes = task_config["num_classes"] - batch_size, mask_fields, task_weight = \ - self._parse_general_task_config(task_config) + task_info = GSConfig.__new__(GSConfig) + task_info.set_task_attributes(task_config) + task_info.verify_edge_class_arguments() - multilabel = task_config["multilabel"] \ - if "multilabel" in task_config else False - assert multilabel in [True, False] - multilabel_weights = task_config["multilabel_weights"] \ - if "multilabel_weights" in task_config else None - if multilabel_weights is not None: - multilabel_weights = self.check_multilabel_weights(multilabel, multilabel_weights, num_classes) - imbalance_class_weights = task_config["imbalance_class_weights"] \ - if "imbalance_class_weights" in task_config else None - if imbalance_class_weights is not None: - imbalance_class_weights = self.check_imbalance_class_weights( - imbalance_class_weights, - num_classes) - - decoder_type = task_config["decoder_type"] \ - if "decoder_type" in task_config else "DenseBiDecoder" - num_decoder_basis = task_config["num_decoder_basis"] \ - if "num_decoder_basis" in task_config else 2 - decoder_edge_feat = task_config["decoder_edge_feat"] \ - if "decoder_edge_feat" in task_config else None - decoder_edge_feat = self.parse_decoder_edge_feat(decoder_edge_feat) - if isinstance(decoder_edge_feat, dict): - assert len(decoder_edge_feat) == 1 and \ - list(decoder_edge_feat.keys())[0] == target_etype, \ - "In multi-task learning, we define edge regression " \ - "tasks one by one. 
The edge type of " \ - "decoder_edge_feat of the current regression task " \ - f"must match {target_etype}, but get {decoder_edge_feat.keys()}" + mask_fields, task_weight = \ + self._parse_general_task_config(task_config) + target_etype = task_info.target_etype + label_field = task_info.label_field task_id = get_mttask_id(task_type=task_type, etype=target_etype, label=label_field) - - eval_metric = task_config["eval_metric"] \ - if "eval_metric" in task_config else ["accuracy"] - eval_metric = self.check_classification_eval_metrics(eval_metric) - - return EdgeClassTaskInfo(task_type=task_type, - task_id=task_id, - batch_size=batch_size, - mask_fields=mask_fields, - task_weight=task_weight, - eval_metric=eval_metric, - target_etype=target_etype, - label_field=label_field, - num_classes=num_classes, - multilabel=multilabel, - multilabel_weights=multilabel_weights, - imbalance_class_weights=imbalance_class_weights, - decoder_type=decoder_type, - num_decoder_basis=num_decoder_basis, - decoder_edge_feat=decoder_edge_feat - ) + setattr(task_info, "task_type", task_type) + setattr(task_info, "mask_fields", mask_fields) + setattr(task_info, "task_weight", task_weight) + setattr(task_info, "task_id", task_id) + return TaskInfo(task_type=task_type, + task_id=task_id, + task_info=task_info) def _parse_edge_regression_task(self, task_config): """ Parse the edge regression task info @@ -438,51 +373,27 @@ def _parse_edge_regression_task(self, task_config): Edge regression task config """ task_type = BUILTIN_TASK_EDGE_REGRESSION - assert "target_etype" in task_config, \ - "target_etype should be provided for each edge regression task " \ - "in multi task learning" - assert "label_field" in task_config, \ - "label_field should be provided for each node classification task " \ - "in multi task learning" - target_etype = task_config["target_etype"] - label_field = task_config["label_field"] + task_info = GSConfig.__new__(GSConfig) + task_info.set_task_attributes(task_config) + 
task_info.verify_edge_regression_arguments() - batch_size, mask_fields, task_weight = \ + mask_fields, task_weight = \ self._parse_general_task_config(task_config) - decoder_type = task_config["decoder_type"] \ - if "decoder_type" in task_config else "DenseBiDecoder" - num_decoder_basis = task_config["num_decoder_basis"] \ - if "num_decoder_basis" in task_config else 2 - decoder_edge_feat = task_config["decoder_edge_feat"] \ - if "decoder_edge_feat" in task_config else None - decoder_edge_feat = self.parse_decoder_edge_feat(decoder_edge_feat) - if isinstance(decoder_edge_feat, dict): - assert len(decoder_edge_feat) == 1 and \ - list(decoder_edge_feat.keys())[0] == target_etype, \ - "In multi-task learning, we define edge regression " \ - "tasks one by one. The edge type of " \ - "decoder_edge_feat of the current regression task " \ - f"must match {target_etype}, but get {decoder_edge_feat.keys()}" - + target_etype = task_info.target_etype + label_field = task_info.label_field task_id = get_mttask_id(task_type=task_type, etype=target_etype, label=label_field) - eval_metric = task_config["eval_metric"] \ - if "eval_metric" in task_config else ["accuracy"] - eval_metric = self.check_regression_eval_metrics(eval_metric) - return EdgeRegressionTaskInfo(task_type=task_type, - task_id=task_id, - batch_size=batch_size, - mask_fields=mask_fields, - task_weight=task_weight, - eval_metric=eval_metric, - target_etype=target_etype, - label_field=label_field, - decoder_type=decoder_type, - num_decoder_basis=num_decoder_basis, - decoder_edge_feat=decoder_edge_feat) + + setattr(task_info, "task_type", task_type) + setattr(task_info, "mask_fields", mask_fields) + setattr(task_info, "task_weight", task_weight) + setattr(task_info, "task_id", task_id) + return TaskInfo(task_type=task_type, + task_id=task_id, + task_info=task_info) def _parse_link_prediction_task(self, task_config): """ Parse the link prediction task info @@ -493,93 +404,25 @@ def _parse_link_prediction_task(self, 
task_config): Link prediction task config """ task_type = BUILTIN_TASK_LINK_PREDICTION + task_info = GSConfig.__new__(GSConfig) + task_info.set_task_attributes(task_config) + task_info.verify_link_prediction_arguments() - batch_size, mask_fields, task_weight = \ + mask_fields, task_weight = \ self._parse_general_task_config(task_config) - - train_negative_sampler = task_config["train_negative_sampler"] \ - if "train_negative_sampler" in task_config \ - else BUILTIN_LP_UNIFORM_NEG_SAMPLER - eval_negative_sampler = task_config["eval_negative_sampler"] \ - if "eval_negative_sampler" in task_config \ - else BUILTIN_LP_JOINT_NEG_SAMPLER - num_negative_edges = task_config["num_negative_edges"] \ - if "num_negative_edges" in task_config \ - else 16 - assert num_negative_edges > 0, \ - "Number of negative edges must larger than 0" - num_negative_edges_eval = task_config["num_negative_edges_eval"] \ - if "num_negative_edges_eval" in task_config \ - else 1000 - assert num_negative_edges_eval > 0, \ - "Number of negative edges for evaluation must larger than 0" - - train_etype = task_config["train_etype"] \ - if "train_etype" in task_config \ - else None # None means all etypes - train_etype = self.parse_lp_etype(train_etype) - eval_etype = task_config["eval_etype"] \ - if "eval_etype" in task_config \ - else None # None means all etypes - eval_etype = self.parse_lp_etype(eval_etype) - reverse_edge_types_map = task_config["reverse_edge_types_map"] \ - if "reverse_edge_types_map" in task_config \ - else {} - reverse_edge_types_map = self.parse_reverse_edge_type_map(reverse_edge_types_map) - exclude_training_targets = task_config["exclude_training_targets"] \ - if "exclude_training_targets" in task_config \ - else True - assert exclude_training_targets in [True, False] - if exclude_training_targets is True: - assert len(reverse_edge_types_map) > 0, \ - "By default, exclude training targets is used." \ - "Reverse edge types map must be provided."
- - lp_loss_func = task_config["lp_loss_func"] \ - if "lp_loss_func" in task_config \ - else BUILTIN_LP_LOSS_CROSS_ENTROPY - assert lp_loss_func in BUILTIN_LP_LOSS_FUNCTION - lp_decoder_type = task_config["lp_decoder_type"] \ - if "lp_decoder_type" in task_config \ - else BUILTIN_LP_DISTMULT_DECODER - assert lp_decoder_type in SUPPORTED_LP_DECODER, \ - f"Link prediction decoder {lp_decoder_type} not supported. " \ - f"GraphStorm only supports {SUPPORTED_LP_DECODER}" - gamma = task_config["gamma"] \ - if "gamma" in task_config \ - else 12.0 - - - report_eval_per_type = task_config["report_eval_per_type"] \ - if "report_eval_per_type" in task_config \ - else False - assert report_eval_per_type in [True, False], \ - "report_eval_per_type must be True or False" + train_etype = task_info.train_etype task_id = get_mttask_id( task_type=task_type, etype=train_etype if train_etype is not None else "ALL_ETYPE") - eval_metric = task_config["eval_metric"] \ - if "eval_metric" in task_config else ["accuracy"] - eval_metric = self.check_lp_eval_metrics(eval_metric) - return LinkPredictionTaskInfo(task_type=task_type, - task_id=task_id, - atch_size=batch_size, - mask_fields=mask_fields, - task_weight=task_weight, - eval_metric=eval_metric, - train_etype=train_etype, - eval_etype=eval_etype, - train_negative_sampler=train_negative_sampler, - eval_negative_sampler=eval_negative_sampler, - num_negative_edges=num_negative_edges, - num_negative_edges_eval=num_negative_edges_eval, - reverse_edge_types_map=reverse_edge_types_map, - exclude_training_targets=exclude_training_targets, - lp_loss_func=lp_loss_func, - lp_decoder_type=lp_decoder_type, - gamma=gamma, - report_eval_per_type=report_eval_per_type) + + setattr(task_info, "task_type", task_type) + setattr(task_info, "mask_fields", mask_fields) + setattr(task_info, "task_weight", task_weight) + setattr(task_info, "task_id", task_id) + return TaskInfo(task_type=task_type, + task_id=task_id, + task_info=task_info) def 
_parse_multi_tasks(self, multi_task_config): """ Parse multi-task configuration @@ -637,6 +480,72 @@ def override_arguments(self, cmd_args): # for basic attributes setattr(self, f"_{arg_key}", arg_val) + def verify_node_class_arguments(self): + """ Verify the correctness of arguments for node classification tasks. + """ + _ = self.target_ntype + _ = self.batch_size + _ = self.eval_metric + _ = self.label_field + _ = self.num_classes + _ = self.multilabel + _ = self.multilabel_weights + _ = self.imbalance_class_weights + + def verify_node_regression_arguments(self): + """ Verify the correctness of arguments for node regression tasks. + """ + _ = self.target_ntype + _ = self.batch_size + _ = self.eval_metric + _ = self.label_field + + def verify_edge_class_arguments(self): + """ Verify the correctness of arguments for edge classification tasks. + """ + _ = self.target_etype + _ = self.batch_size + _ = self.eval_metric + _ = self.label_field + _ = self.num_classes + _ = self.multilabel + _ = self.multilabel_weights + _ = self.imbalance_class_weights + _ = self.decoder_type + _ = self.num_decoder_basis + _ = self.decoder_edge_feat + + def verify_edge_regression_arguments(self): + """ Verify the correctness of arguments for edge regression tasks. + """ + _ = self.target_etype + _ = self.batch_size + _ = self.eval_metric + _ = self.label_field + _ = self.decoder_type + _ = self.num_decoder_basis + _ = self.decoder_edge_feat + + def verify_link_prediction_arguments(self): + """ Verify the correctness of arguments for link prediction tasks. 
+ """ + _ = self.target_etype + _ = self.batch_size + _ = self.eval_metric + _ = self.train_etype + _ = self.eval_etype + _ = self.train_negative_sampler + _ = self.eval_negative_sampler + _ = self.num_negative_edges + _ = self.num_negative_edges_eval + _ = self.reverse_edge_types_map + _ = self.exclude_training_targets + _ = self.lp_loss_func + _ = self.lp_decoder_type + _ = self.gamma + _ = self.report_eval_per_type + + def verify_arguments(self, is_train): """ Verify the correctness of arguments. @@ -798,7 +707,7 @@ def handle_argument_conflicts(self): self._turn_off_gradient_checkpoint("GLEM model") # TODO(xiangsx): Add more check - ###################### Environment Info ###################### +###################### Environment Info ###################### @property def save_perf_results_path(self): """ Save performance flag @@ -1863,20 +1772,6 @@ def check_multilabel(multilabel): return check_multilabel(self._multilabel) return False - @staticmethod - def check_multilabel_weights(multilabel, multilabel_weights, num_classes): - assert multilabel is True, "Must be a multi-label classification task." - try: - weights = multilabel_weights.split(",") - weights = [float(w) for w in weights] - except Exception: # pylint: disable=broad-except - raise RuntimeError("The weights should in following format 0.1,0.2,0.1,0.0") - for w in weights: - assert w >= 0., "multilabel weights can not be negative values" - assert len(weights) == num_classes, \ - "Each class must have an assigned weight" - return th.tensor(weights) - @property def multilabel_weights(self): """Used to specify label weight of each class in a @@ -1885,6 +1780,20 @@ def multilabel_weights(self): The weights should be in the following format 0.1,0.2,0.3,0.1,0.0 """ + + def check_multilabel_weights(multilabel, multilabel_weights, num_classes): + assert multilabel is True, "Must be a multi-label classification task." 
+ try: + weights = multilabel_weights.split(",") + weights = [float(w) for w in weights] + except Exception: # pylint: disable=broad-except + raise RuntimeError("The weights should in following format 0.1,0.2,0.1,0.0") + for w in weights: + assert w >= 0., "multilabel weights can not be negative values" + assert len(weights) == num_classes, \ + "Each class must have an assigned weight" + return th.tensor(weights) + if hasattr(self, "_num_classes") and isinstance(self.num_classes, dict): if hasattr(self, "_multilabel_weights"): multilabel = self.multilabel @@ -1893,19 +1802,18 @@ def multilabel_weights(self): ntype_weights = {} for ntype in num_classes: if ntype in multilabel_weights: - ntype_weights[ntype] = self.check_multilabel_weights( - multilabel[ntype], - multilabel_weights[ntype], - num_classes[ntype]) + ntype_weights[ntype] = check_multilabel_weights(multilabel[ntype], + multilabel_weights[ntype], + num_classes[ntype]) else: ntype_weights[ntype] = None return ntype_weights return {ntype: None for ntype in self.num_classes} else: if hasattr(self, "_multilabel_weights"): - return self.check_multilabel_weights(self.multilabel, - self._multilabel_weights, - self.num_classes) + return check_multilabel_weights(self.multilabel, + self._multilabel_weights, + self.num_classes) return None @@ -1927,19 +1835,6 @@ def return_proba(self): # By default, return all the predictions return True - @staticmethod - def check_imbalance_class_weights(imbalance_class_weights, num_classes): - try: - weights = imbalance_class_weights.split(",") - weights = [float(w) for w in weights] - except Exception: # pylint: disable=broad-except - raise RuntimeError("The weights should in following format 0.1,0.2,0.3,0.1") - for w in weights: - assert w > 0., "Each weight should be larger than 0." 
- assert len(weights) == num_classes, \ - "Each class must have an assigned weight" - return th.tensor(weights) - @property def imbalance_class_weights(self): """ Used to specify a manual rescaling weight given to each class @@ -1949,6 +1844,19 @@ def imbalance_class_weights(self): Customer should provide the weight in following format 0.1,0.2,0.3,0.1 """ + + def check_imbalance_class_weights(imbalance_class_weights, num_classes): + try: + weights = imbalance_class_weights.split(",") + weights = [float(w) for w in weights] + except Exception: # pylint: disable=broad-except + raise RuntimeError("The weights should in following format 0.1,0.2,0.3,0.1") + for w in weights: + assert w > 0., "Each weight should be larger than 0." + assert len(weights) == num_classes, \ + "Each class must have an assigned weight" + return th.tensor(weights) + if hasattr(self, "_num_classes") and isinstance(self.num_classes, dict): if hasattr(self, "_imbalance_class_weights"): assert isinstance(self._imbalance_class_weights, dict), \ @@ -1958,7 +1866,7 @@ def imbalance_class_weights(self): ntype_weights = {} for ntype in num_classes: if ntype in imbalance_class_weights: - ntype_weights[ntype] = self.check_imbalance_class_weights( + ntype_weights[ntype] = check_imbalance_class_weights( imbalance_class_weights[ntype], num_classes[ntype] ) @@ -1968,8 +1876,8 @@ def imbalance_class_weights(self): return {ntype: None for ntype in self.num_classes} else: if hasattr(self, "_imbalance_class_weights"): - return self.check_imbalance_class_weights(self._imbalance_class_weights, - self.num_classes) + return check_imbalance_class_weights(self._imbalance_class_weights, + self.num_classes) return None ###classification/regression inference related #### @@ -2018,40 +1926,6 @@ def eval_target_ntype(self): return None #### edge related task variables #### - @staticmethod - def parse_reverse_edge_type_map(reverse_edge_types_map): - """ Parse the reverse edge type map. 
- - Parameters - ---------- - reverse_edge_types_map: list - A list one etype and reverse edge type maps - in the format of head,relation,reverse relation,tail - - Return - ------ - A map: dict - """ - if reverse_edge_types_map is None: - return {} # empty dict - assert isinstance(reverse_edge_types_map, list), \ - "Reverse edge type map should has following format: " \ - "[\"head,relation,reverse relation,tail\", " \ - "\"head,relation,reverse relation,tail\", ...]" - - reverse_map = {} - try: - for etype_info in reverse_edge_types_map: - head, rel, rev_rel, tail = etype_info.split(",") - reverse_map[(head, rel, tail)] = (tail, rev_rel, head) - except Exception: # pylint: disable=broad-except - assert False, \ - "Reverse edge type map should has following format: " \ - "[\"head,relation,reverse relation,tail\", " \ - "\"head,relation,reverse relation,tail\", ...]" \ - f"But get {reverse_edge_types_map}" - return reverse_map - @property def reverse_edge_types_map(self): """ A list of reverse edge type info. 
@@ -2069,7 +1943,26 @@ def reverse_edge_types_map(self): # pylint: disable=no-member if hasattr(self, "_reverse_edge_types_map"): - return self.parse_reverse_edge_type_map(self._reverse_edge_types_map) + if self._reverse_edge_types_map is None: + return {} # empty dict + assert isinstance(self._reverse_edge_types_map, list), \ + "Reverse edge type map should has following format: " \ + "[\"head,relation,reverse relation,tail\", " \ + "\"head,relation,reverse relation,tail\", ...]" + + reverse_edge_types_map = {} + try: + for etype_info in self._reverse_edge_types_map: + head, rel, rev_rel, tail = etype_info.split(",") + reverse_edge_types_map[(head, rel, tail)] = (tail, rev_rel, head) + except Exception: # pylint: disable=broad-except + assert False, \ + "Reverse edge type map should has following format: " \ + "[\"head,relation,reverse relation,tail\", " \ + "\"head,relation,reverse relation,tail\", ...]" \ + f"But get {self._reverse_edge_types_map}" + + return reverse_edge_types_map # By default return an empty dict return {} @@ -2142,31 +2035,6 @@ def num_decoder_basis(self): # By default, return 2 return 2 - @staticmethod - def parse_decoder_edge_feat(decoder_edge_feats): - """ Parse decoder edge feat - - Parameter - --------- - decoder_edge_feats: list - Edge features that will be used by the decoder. 
- """ - assert len(decoder_edge_feats) == 1, \ - "We only support edge classifcation or regression on one edge type" - - if ":" not in decoder_edge_feats[0]: - # global feat_name - return decoder_edge_feats[0] - - # per edge type feature - feat_name = decoder_edge_feats[0] - feat_info = feat_name.split(":") - assert len(feat_info) == 2, \ - f"Unknown format of the feature name: {feat_name}, " + \ - "must be EDGE_TYPE:FEAT_NAME" - etype = tuple(feat_info[0].split(",")) - return {etype: feat_info[1].split(",")} - @property def decoder_edge_feat(self): """ A list of edge features that can be used by a decoder to @@ -2178,15 +2046,28 @@ def decoder_edge_feat(self): (BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION), \ "Decoder edge feature only works with " \ "edge classification or regression tasks" - decoder_edge_feat = self.parse_decoder_edge_feat(self._decoder_edge_feat) + decoder_edge_feats = self._decoder_edge_feat + assert len(decoder_edge_feats) == 1, \ + "We only support edge classifcation or regression on one edge type" + + if ":" not in decoder_edge_feats[0]: + # global feat_name + return decoder_edge_feats[0] + + # per edge type feature + feat_name = decoder_edge_feats[0] + feat_info = feat_name.split(":") + assert len(feat_info) == 2, \ + f"Unknown format of the feature name: {feat_name}, " + \ + "must be EDGE_TYPE:FEAT_NAME" + etype = tuple(feat_info[0].split(",")) + assert etype in self.target_etype, \ + f"{etype} must in the training edge type list {self.target_etype}" + return {etype: feat_info[1].split(",")} - if isinstance(decoder_edge_feat, dict): - for etype in decoder_edge_feat.keys(): - assert etype in self.target_etype, \ - f"{etype} must in the training edge type list {self.target_etype}" - return decoder_edge_feat return None + ### Link Prediction specific ### @property def train_negative_sampler(self): @@ -2230,7 +2111,7 @@ def num_negative_edges_eval(self): # pylint: disable=no-member if hasattr(self, 
"_num_negative_edges_eval"): assert self._num_negative_edges_eval > 0, \ - "Number of negative edges for evaluation must larger than 0" + "Number of negative edges must larger than 0" return self._num_negative_edges_eval # Set default value to 1000. return 1000 @@ -2454,22 +2335,6 @@ def eval_etypes_negative_dstnode(self): # By default fixed negative is not used return None - @staticmethod - def parse_lp_etype(etypes): - """ Parse and validate the input link prediction etypes - - Parameters - ---------- - etypes: str - Edge types - """ - if etypes is None: - return None - assert isinstance(etypes, list) - assert len(etypes) > 0 - - return [tuple(etype.split(',')) for etype in etypes] - @property def train_etype(self): """ The list of canonical etypes that will be added as @@ -2479,7 +2344,12 @@ def train_etype(self): """ # pylint: disable=no-member if hasattr(self, "_train_etype"): - return self.parse_lp_etype(self._train_etype) + if self._train_etype is None: + return None + assert isinstance(self._train_etype, list) + assert len(self._train_etype) > 0 + + return [tuple(train_etype.split(',')) for train_etype in self._train_etype] # By default return None, which means use all edge types return None @@ -2492,8 +2362,11 @@ def eval_etype(self): """ # pylint: disable=no-member if hasattr(self, "_eval_etype"): - return self.parse_lp_etype(self._eval_etype) - + if self._eval_etype is None: + return None + assert isinstance(self._eval_etype, list) + assert len(self._eval_etype) > 0 + return [tuple(eval_etype.split(',')) for eval_etype in self._eval_etype] # By default return None, which means use all edge types return None @@ -2567,116 +2440,6 @@ def report_eval_per_type(self): return False - @staticmethod - def check_classification_eval_metrics(eval_metric): - """ Check the classification evaluation metrics - - Parameter - --------- - eval_metric: str or list of str - Evaluation metric(s). - - Return - ------ - list of str - Evaluation metric(s). 
- """ - if isinstance(eval_metric, str): - eval_metric = eval_metric.lower() - assert eval_metric in SUPPORTED_CLASSIFICATION_METRICS, \ - f"Classification evaluation metric should be " \ - f"in {SUPPORTED_CLASSIFICATION_METRICS}" \ - f"but get {eval_metric}" - eval_metrics = [eval_metric] - elif isinstance(eval_metric, list) and len(eval_metric) > 0: - eval_metrics = [] - for metric in eval_metric: - metric = metric.lower() - assert metric in SUPPORTED_CLASSIFICATION_METRICS, \ - f"Classification evaluation metric should be " \ - f"in {SUPPORTED_CLASSIFICATION_METRICS}" \ - f"but get {eval_metric}" - eval_metrics.append(metric) - else: - assert False, "Classification evaluation metric " \ - "should be a string or a list of string" - # no eval_metric - return eval_metrics - - @staticmethod - def check_regression_eval_metrics(eval_metric): - """ Check the regression evaluation metrics - - Parameter - --------- - eval_metric: str or list of str - Evaluation metric(s). - - Return - ------ - list of str - Evaluation metric(s). - """ - if isinstance(eval_metric, str): - eval_metric = eval_metric.lower() - assert eval_metric in SUPPORTED_REGRESSION_METRICS, \ - f"Regression evaluation metric should be " \ - f"in {SUPPORTED_REGRESSION_METRICS}, " \ - f"but get {eval_metric}" - eval_metrics = [eval_metric] - elif isinstance(eval_metric, list) and len(eval_metric) > 0: - eval_metrics = [] - for metric in eval_metric: - metric = metric.lower() - assert metric in SUPPORTED_REGRESSION_METRICS, \ - f"Regression evaluation metric should be " \ - f"in {SUPPORTED_REGRESSION_METRICS}" \ - f"but get {eval_metric}" - eval_metrics.append(metric) - else: - assert False, "Regression evaluation metric " \ - "should be a string or a list of string" - # no eval_metric - - return eval_metrics - - @staticmethod - def check_lp_eval_metrics(eval_metric): - """ Check the link prediction evaluation metrics - - Parameter - --------- - eval_metric: str or list of str - Evaluation metric(s). 
- - Return - ------ - list of str - Evaluation metric(s). - """ - if isinstance(eval_metric, str): - eval_metric = eval_metric.lower() - assert eval_metric in SUPPORTED_LINK_PREDICTION_METRICS, \ - f"Link prediction evaluation metric should be " \ - f"in {SUPPORTED_LINK_PREDICTION_METRICS}" \ - f"but get {eval_metric}" - eval_metrics = [eval_metric] - elif isinstance(eval_metric, list) and len(eval_metric) > 0: - eval_metrics = [] - for metric in eval_metric: - metric = metric.lower() - assert metric in SUPPORTED_LINK_PREDICTION_METRICS, \ - f"Link prediction evaluation metric should be " \ - f"in {SUPPORTED_LINK_PREDICTION_METRICS}" \ - f"but get {eval_metric}" - eval_metrics.append(metric) - else: - assert False, "Link prediction evaluation metric " \ - "should be a string or a list of string" - # no eval_metric - - return eval_metrics - @property def eval_metric(self): """ Evaluation metric used during evaluation @@ -2698,18 +2461,75 @@ def eval_metric(self): # check evaluation metrics if hasattr(self, "_eval_metric"): - eval_metric = self.check_classification_eval_metrics(self._eval_metric) + if isinstance(self._eval_metric, str): + eval_metric = self._eval_metric.lower() + assert eval_metric in SUPPORTED_CLASSIFICATION_METRICS, \ + f"Classification evaluation metric should be " \ + f"in {SUPPORTED_CLASSIFICATION_METRICS}" \ + f"but get {self._eval_metric}" + eval_metric = [eval_metric] + elif isinstance(self._eval_metric, list) and len(self._eval_metric) > 0: + eval_metric = [] + for metric in self._eval_metric: + metric = metric.lower() + assert metric in SUPPORTED_CLASSIFICATION_METRICS, \ + f"Classification evaluation metric should be " \ + f"in {SUPPORTED_CLASSIFICATION_METRICS}" \ + f"but get {self._eval_metric}" + eval_metric.append(metric) + else: + assert False, "Classification evaluation metric " \ + "should be a string or a list of string" + # no eval_metric else: eval_metric = ["accuracy"] elif self.task_type in [BUILTIN_TASK_NODE_REGRESSION, \ 
BUILTIN_TASK_EDGE_REGRESSION]: if hasattr(self, "_eval_metric"): - eval_metric = self.check_regression_eval_metrics(self._eval_metric) + if isinstance(self._eval_metric, str): + eval_metric = self._eval_metric.lower() + assert eval_metric in SUPPORTED_REGRESSION_METRICS, \ + f"Regression evaluation metric should be " \ + f"in {SUPPORTED_REGRESSION_METRICS}, " \ + f"but get {self._eval_metric}" + eval_metric = [eval_metric] + elif isinstance(self._eval_metric, list) and len(self._eval_metric) > 0: + eval_metric = [] + for metric in self._eval_metric: + metric = metric.lower() + assert metric in SUPPORTED_REGRESSION_METRICS, \ + f"Regression evaluation metric should be " \ + f"in {SUPPORTED_REGRESSION_METRICS}" \ + f"but get {self._eval_metric}" + eval_metric.append(metric) + else: + assert False, "Regression evaluation metric " \ + "should be a string or a list of string" + # no eval_metric else: eval_metric = ["rmse"] elif self.task_type == BUILTIN_TASK_LINK_PREDICTION: if hasattr(self, "_eval_metric"): - eval_metric = self.check_lp_eval_metrics(self._eval_metric) + if isinstance(self._eval_metric, str): + eval_metric = self._eval_metric.lower() + assert eval_metric in SUPPORTED_LINK_PREDICTION_METRICS, \ + f"Link prediction evaluation metric should be " \ + f"in {SUPPORTED_LINK_PREDICTION_METRICS}" \ + f"but get {self._eval_metric}" + eval_metric = [eval_metric] + elif isinstance(self._eval_metric, list) and len(self._eval_metric) > 0: + eval_metric = [] + for metric in self._eval_metric: + metric = metric.lower() + assert metric in SUPPORTED_LINK_PREDICTION_METRICS, \ + f"Link prediction evaluation metric should be " \ + f"in {SUPPORTED_LINK_PREDICTION_METRICS}" \ + f"but get {self._eval_metric}" + eval_metric.append(metric) + else: + assert False, "Link prediction evaluation metric " \ + "should be a string or a list of string" + # no eval_metric else: eval_metric = ["mrr"] else: @@ -2733,20 +2553,6 @@ def model_select_etype(self): # Per edge type lp evaluation 
is disabled. return LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL - ###Multi task support #### - @property - def multi_tasks(self): - """ Definition of tasks in multi-task learning. - - Return: list of Tasks - """ - # pylint: disable=no-member - if hasattr(self, "_multi_tasks"): - assert len(self._multi_tasks) > 1, \ - "There must be at least two tasks for multi-task learning" - return self._multi_tasks - return None - @property def num_ffn_layers_in_input(self): """ Number of extra feedforward neural network layers in the input layer diff --git a/python/graphstorm/config/config.py b/python/graphstorm/config/config.py index 8b8b46d5a3..71cdc2b576 100644 --- a/python/graphstorm/config/config.py +++ b/python/graphstorm/config/config.py @@ -118,58 +118,5 @@ class TaskInfo: """ task_type : str task_id : str + task_config = None dataloader = None # dataloder - batch_size: int = 0 - mask_fields: list - task_weight: float - eval_metric : list - -@dataclasses.dataclass -class NodeClassTaskInfo(TaskInfo): - target_ntype : str - label_field : str - num_classes: str - multilabel: bool = False - multilabel_weights: str = None - imbalance_class_weights: str = None - - -@dataclasses.dataclass -class NodeRegressionTaskInfo(TaskInfo): - target_ntype : str - label_field : str - -@dataclasses.dataclass -class EdgeClassTaskInfo(TaskInfo): - target_etype : tuple - label_field : str - num_classes : str - multilabel: bool = False - multilabel_weights: str = None - imbalance_class_weights: str = None - decoder_type : str - num_decoder_basis : int - decoder_edge_feat : dict - -@dataclasses.dataclass -class EdgeRegressionTaskInfo(TaskInfo): - target_etype : tuple - label_field : str - decoder_type : str - num_decoder_basis : int - decoder_edge_feat : dict - -@dataclasses.dataclass -class LinkPredictionTaskInfo(TaskInfo): - train_etype : list - eval_etype : list - train_negative_sampler : str - eval_negative_sampler : str - num_negative_edges : int - num_negative_edges_eval : int - 
reverse_edge_types_map : dict - exclude_training_targets : bool - lp_loss_func : str - lp_decoder_type : str - gamma : float - report_eval_per_type : bool From ab62c4525ed716e03bf004379d46c0e7f8c29796 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 9 May 2024 16:42:23 -0700 Subject: [PATCH 12/79] Add unit tests --- python/graphstorm/config/argument.py | 33 +-- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 11 +- tests/unit-tests/test_config.py | 251 +++++++++++++++++++++ 3 files changed, 276 insertions(+), 19 deletions(-) diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 2fcdb9aa96..8a902cbcd4 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -149,9 +149,9 @@ def __init__(self, cmd_args): # Load all arguments from yaml config configuration = self.load_yaml_config(cmd_args.yaml_config_file) + multi_task_config = None if 'multi_task_learning' in configuration: - # parse multi task learning config and save it into self._multi_tasks - self._parse_multi_tasks(configuration['multi_task_learning']) + multi_task_config = configuration['multi_task_learning'] del configuration['multi_task_learning'] self.set_attributes(configuration) @@ -168,6 +168,10 @@ def __init__(self, cmd_args): # We do argument check as early as possible to prevent config bugs. 
self.handle_argument_conflicts() + # parse multi task learning config and save it into self._multi_tasks + if multi_task_config is not None: + self._parse_multi_tasks(multi_task_config) + def set_attributes(self, configuration): """Set class attributes from 2nd level arguments in yaml config""" if 'lm_model' in configuration: @@ -296,10 +300,8 @@ def _parse_node_classification_task(self, task_config): task_id = get_mttask_id(task_type=task_type, ntype=target_ntype, label=label_field) - setattr(task_info, "task_type", task_type) setattr(task_info, "mask_fields", mask_fields) setattr(task_info, "task_weight", task_weight) - setattr(task_info, "task_id", task_id) return TaskInfo(task_type=task_type, task_id=task_id, @@ -326,10 +328,8 @@ def _parse_node_regression_task(self, task_config): task_id = get_mttask_id(task_type=task_type, ntype=target_ntype, label=label_field) - setattr(task_info, "task_type", task_type) setattr(task_info, "mask_fields", mask_fields) setattr(task_info, "task_weight", task_weight) - setattr(task_info, "task_id", task_id) return TaskInfo(task_type=task_type, task_id=task_id, @@ -356,10 +356,8 @@ def _parse_edge_classification_task(self, task_config): task_id = get_mttask_id(task_type=task_type, etype=target_etype, label=label_field) - setattr(task_info, "task_type", task_type) setattr(task_info, "mask_fields", mask_fields) setattr(task_info, "task_weight", task_weight) - setattr(task_info, "task_id", task_id) return TaskInfo(task_type=task_type, task_id=task_id, task_info=task_info) @@ -386,11 +384,8 @@ def _parse_edge_regression_task(self, task_config): task_id = get_mttask_id(task_type=task_type, etype=target_etype, label=label_field) - - setattr(task_info, "task_type", task_type) setattr(task_info, "mask_fields", mask_fields) setattr(task_info, "task_weight", task_weight) - setattr(task_info, "task_id", task_id) return TaskInfo(task_type=task_type, task_id=task_id, task_info=task_info) @@ -415,11 +410,8 @@ def 
_parse_link_prediction_task(self, task_config): task_id = get_mttask_id( task_type=task_type, etype=train_etype if train_etype is not None else "ALL_ETYPE") - - setattr(task_info, "task_type", task_type) setattr(task_info, "mask_fields", mask_fields) setattr(task_info, "task_weight", task_weight) - setattr(task_info, "task_id", task_id) return TaskInfo(task_type=task_type, task_id=task_id, task_info=task_info) @@ -435,6 +427,10 @@ def _parse_multi_tasks(self, multi_task_config): assert isinstance(task_config, dict) and len(task_config) == 1, \ "When defining multiple tasks for " \ "training, define one task each time." + if "batch_size" not in task_config: + # If batch_size is not set + # Use the global batch size. + task_config["batch_size"] = self.batch_size if "node_classification" in task_config: task = self._parse_node_classification_task( task_config["node_classification"]) @@ -2589,6 +2585,15 @@ def num_ffn_layers_in_decoder(self): # Set default mlp layer number between gnn layer to 0 return 0 + ################## Multi task learning ################## + @property + def multi_tasks(self): + """ Tasks in multi-task learning + """ + assert hasattr(self, "_multi_tasks"), \ + "multi_task_learning must be set in the task config" + return self._multi_tasks + def _add_initialization_args(parser): group = parser.add_argument_group(title="initialization") group.add_argument( diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index d5108b5698..13c7b66472 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -189,11 +189,11 @@ def create_task_test_dataloader(task, config, train_data): def create_task_decoder(task, g, decoder_input_dim, train_task): if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - return gs.create_builtin_node_decoder(decoder_input_dim, task, train_task) + return gs.create_builtin_node_decoder(decoder_input_dim, 
task.task_config, train_task) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - return gs.create_builtin_edge_decoder(g, decoder_input_dim, task, train_task) + return gs.create_builtin_edge_decoder(g, decoder_input_dim, task.task_config, train_task) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: - return gs.create_builtin_lp_decoder(g, decoder_input_dim, task, train_task) + return gs.create_builtin_lp_decoder(g, decoder_input_dim, task.task_config, train_task) return None, None @@ -278,14 +278,15 @@ def main(config_args): if model.gnn_encoder is not None \ else model.node_input_encoder.out_dims for task in tasks: + task_config = task.task_config train_loader = create_task_train_dataloader(task, config, train_data) val_loader = create_task_val_dataloader(task, config) test_loader = create_task_test_dataloader(task, config) train_dataloaders.append((task, train_loader)) val_dataloaders.append((task, val_loader)) test_dataloaders.append((task, test_loader)) - decoder, loss_func = create_task_decoder(task, g, encoder_out_dims, train_task=True) - model.add_task(task.task_id, task.task_type, decoder, loss_func, task.weight) + decoder, loss_func = create_task_decoder(task, train_data.g, encoder_out_dims, train_task=True) + model.add_task(task.task_id, task.task_type, decoder, loss_func, task_config.weight) if not config.no_validation: if val_loader is None: logging.warning("The training data do not have validation set.") diff --git a/tests/unit-tests/test_config.py b/tests/unit-tests/test_config.py index dc9b218ac1..ddaeb080d8 100644 --- a/tests/unit-tests/test_config.py +++ b/tests/unit-tests/test_config.py @@ -30,6 +30,11 @@ from graphstorm.config.config import (BUILTIN_LP_LOSS_CROSS_ENTROPY, BUILTIN_LP_LOSS_LOGSIGMOID_RANKING, BUILTIN_LP_LOSS_CONTRASTIVELOSS) +from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + 
BUILTIN_TASK_EDGE_REGRESSION, + BUILTIN_TASK_LINK_PREDICTION) from graphstorm.config.config import GRAPHSTORM_LP_EMB_L2_NORMALIZATION from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER @@ -1587,7 +1592,253 @@ def test_id_mapping_file(): assert config.node_id_mapping_file == part_path assert config.edge_id_mapping_file == part_path +def create_dummy_nc_config(): + return { + "target_ntype": "a", + "label_field": "label_c", + "multilabel": True, + "num_classes": 20, + "eval_metric": ["F1_score", "precision_recall", "ROC_AUC"], + "multilabel_weights": "1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2", + "imbalance_class_weights": "1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2", + "batch_size": 20, + "task_weight": 1, + "mask_fields": ["class_train_mask", "class_eval_mask", "class_test_mask"] + } + +def create_dummy_nr_config(): + return { + "target_ntype": "a", + "label_field": "label_r", + "task_weight": 0.5, + "mask_fields": ["reg_train_mask", "reg_eval_mask", "reg_test_mask"] + } + +def create_dummy_ec_config(): + return { + "target_etype": ["query,match,asin"], + "reverse_edge_types_map": [], + "label_field": "label_ec", + "multilabel": True, + "num_classes": 4, + "num_decoder_basis": 4, + "remove_target_edge_type": False, + "decoder_type": "MLPDecoder", + "decoder_edge_feat": "feat", + "eval_metric": ["precision_recall"], + "multilabel_weights": "1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2", + "imbalance_class_weights": "1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2", + "batch_size": 20, + "task_weight": 1, + "mask_fields": ["ec_train_mask", "ec_eval_mask", "ec_test_mask"] + } + +def create_dummy_er_config(): + return { + "target_etype": ["query,match-2,asin"], + "label_field": "label_er", + "eval_metric": ["mse"], + "decoder_edge_feat": ["feat1", "feat2"], + "task_weight": 1, + "mask_fields": ["er_train_mask", "er_eval_mask", "er_test_mask"] + } + +def create_dummy_lp_config(): + return { + 
"train_negative_sampler": BUILTIN_LP_JOINT_NEG_SAMPLER, + "num_negative_edges": 4, + "num_negative_edges_eval": 100, + "train_etype": ["query,exactmatch,asin"], + "eval_etype": ["query,exactmatch,asin"], + "exclude_training_targets": True, + "reverse_edge_types_map": ["query,exactmatch,rev-exactmatch,asin"], + "gamma": 2.0, + "lp_loss_func": BUILTIN_LP_LOSS_CROSS_ENTROPY, + "lp_embed_normalizer": GRAPHSTORM_LP_EMB_L2_NORMALIZATION, + "lp_decoder_type": BUILTIN_LP_DOT_DECODER, + "eval_metric": "MRR", + "lp_edge_weight_for_loss": ["weight"], + "task_weight": 1, + "mask_fields": ["lp_train_mask", "lp_eval_mask", "lp_test_mask"] + } + +def create_dummy_lp_config2(): + return { + "lp_loss_func": BUILTIN_LP_LOSS_CONTRASTIVELOSS, + "lp_decoder_type": BUILTIN_LP_DISTMULT_DECODER, + "task_weight": 2, + "mask_fields": ["lp2_train_mask", "lp2_eval_mask", "lp2_test_mask"] + } + +def create_multi_task_config(tmp_path, file_name): + yaml_object = create_dummpy_config_obj() + yaml_object["gsf"]["basic"] = { + "backend": "gloo", + } + yaml_object["gsf"]["hyperparam"] = { + "batch_size": 64, + "eval_batch_size": 128, + } + yaml_object["multi_task_learning"] = [ + { + BUILTIN_TASK_NODE_CLASSIFICATION : create_dummy_nc_config() + }, + { + BUILTIN_TASK_NODE_REGRESSION : create_dummy_nr_config() + }, + { + BUILTIN_TASK_EDGE_CLASSIFICATION : create_dummy_ec_config() + }, + { + BUILTIN_TASK_EDGE_REGRESSION : create_dummy_er_config() + + }, + { + BUILTIN_TASK_LINK_PREDICTION : create_dummy_lp_config() + }, + { + BUILTIN_TASK_LINK_PREDICTION : create_dummy_lp_config2() + } + ] + + with open(os.path.join(tmp_path, file_name+"_default.yaml"), "w") as f: + yaml.dump(yaml_object, f) + +def test_multi_task_config(): + with tempfile.TemporaryDirectory() as tmpdirname: + create_rgcn_config(Path(tmpdirname), 'multi_task_test') + + args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'multi_task_test_default.yaml'), local_rank=0) + config = GSConfig(args) + + assert 
len(config.multi_task) == 6 + nc_config = config.multi_task[0] + assert nc_config.task_type == BUILTIN_TASK_NODE_CLASSIFICATION + assert nc_config.task_weight == 1 + assert len(nc_config.mask_fields) == 3 + assert nc_config.mask_fields[0] == "class_train_mask" + assert nc_config.mask_fields[1] == "class_eval_mask" + assert nc_config.mask_fields[2] == "class_test_mask" + nc_config = nc_config.task_config + assert nc_config.target_ntype == "a" + assert nc_config.label_field == "label_c" + assert nc_config.multilabel == True + assert nc_config.num_classes == 20 + assert len(nc_config.eval_metric) == 3 + assert nc_config.eval_metric[0] == "f1_score" + assert nc_config.eval_metric[1] == "precision_recall" + assert nc_config.eval_metric[2] == "roc_auc" + assert nc_config.imbalance_class_weights.tolist() == [1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2] + assert nc_config.multilabel_weights.tolist() == [1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2] + assert nc_config.batch_size == 20 + + nr_config = config.multi_task[1] + assert nr_config.task_type == BUILTIN_TASK_NODE_REGRESSION + assert nr_config.task_weight == 0.5 + assert len(nr_config.mask_fields) == 3 + assert nr_config.mask_fields[0] == "reg_train_mask" + assert nr_config.mask_fields[1] == "reg_eval_mask" + assert nr_config.mask_fields[2] == "reg_test_mask" + nr_config = nr_config.task_config + assert nr_config.target_ntype == "a" + assert nr_config.label_field == "label_r" + assert len(nr_config.eval_metric) == 1 + assert nr_config.eval_metric[0] == "rmse" + assert nr_config.batch_size == 64 + + ec_config = config.multi_task[2] + assert ec_config.task_type == BUILTIN_TASK_EDGE_CLASSIFICATION + assert ec_config.task_weight == 1 + assert len(ec_config.mask_fields) == 3 + assert ec_config.mask_fields[0] == "ec_train_mask" + assert ec_config.mask_fields[1] == "ec_eval_mask" + assert ec_config.mask_fields[2] == "ec_test_mask" + ec_config = ec_config.task_config + assert ec_config.target_etype[0] == ("query", "match", "asin") + 
assert ec_config.label_field == "label_ec" + assert ec_config.multilabel == True + assert ec_config.num_classes == 4 + assert ec_config.num_decoder_basis == 4 + assert ec_config.remove_target_edge_type == False + assert ec_config.decoder_type == "MLPDecoder" + assert ec_config.decoder_edge_feat == "feat" + assert len(ec_config.eval_metric) == 1 + assert ec_config.eval_metric[0] == "precision_recall" + assert ec_config.batch_size == 20 + assert ec_config.imbalance_class_weights.tolist() == [1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2] + assert ec_config.multilabel_weights.tolist() == [1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2] + + er_config = config.multi_task[3] + assert er_config.task_type == BUILTIN_TASK_EDGE_REGRESSION + assert er_config.task_weight == 1 + assert len(er_config.mask_fields) == 3 + assert er_config.mask_fields[0] == "er_train_mask" + assert er_config.mask_fields[1] == "er_eval_mask" + assert er_config.mask_fields[2] == "er_test_mask" + er_config = er_config.task_config + assert er_config.target_etype[0] == ("query", "match-2", "asin") + assert er_config.label_field == "label_er" + assert len(er_config.eval_metric) == 1 + assert er_config.eval_metric[0] == "mse" + assert er_config.decoder_edge_feat == ["feat1", "feat2"] + assert er_config.batch_size == 64 + assert ec_config.remove_target_edge_type == True + assert ec_config.decoder_type == "DenseBiDecoder" + assert ec_config.num_decoder_basis == 2 + + lp_config = config.multi_task[4] + assert lp_config.task_type == BUILTIN_TASK_LINK_PREDICTION + assert lp_config.task_weight == 1 + assert len(lp_config.mask_fields) == 3 + assert lp_config.mask_fields[0] == "lp_train_mask" + assert lp_config.mask_fields[1] == "lp_eval_mask" + assert lp_config.mask_fields[2] == "lp_test_mask" + lp_config = lp_config.task_config + assert lp_config.train_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER + assert lp_config.num_negative_edges == 4 + assert lp_config.num_negative_edges_eval == 100 + assert 
len(lp_config.train_etype) == 1 + assert lp_config.train_etype[0] == ("query", "exactmatch", "asin") + assert len(lp_config.eval_etype) == 1 + assert lp_config.eval_etype[0] == ("query", "exactmatch", "asin") + assert lp_config.exclude_training_targets == True + assert len(lp_config.reverse_edge_types_map) == 1 + assert lp_config.reverse_edge_types_map[("query", "exactmatch","asin")] == \ + ("asin", "rev-exactmatch","query") + assert lp_config.gamma == 2.0 + assert lp_config.lp_loss_func == BUILTIN_LP_LOSS_CROSS_ENTROPY + assert lp_config.lp_embed_normalizer == GRAPHSTORM_LP_EMB_L2_NORMALIZATION + assert lp_config.lp_decoder_type == BUILTIN_LP_DOT_DECODER + assert len(lp_config.eval_metric) == 1 + assert lp_config.eval_metric[0] == "mrr" + assert lp_config.lp_edge_weight_for_loss == "weight" + + + lp_config = config.multi_task[5] + assert lp_config.task_type == BUILTIN_TASK_LINK_PREDICTION + assert lp_config.task_weight == 2 + assert len(lp_config.mask_fields) == 3 + assert lp_config.mask_fields[0] == "lp2_train_mask" + assert lp_config.mask_fields[1] == "lp2_eval_mask" + assert lp_config.mask_fields[2] == "lp2_test_mask" + lp_config = lp_config.task_config + assert lp_config.train_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER + assert lp_config.num_negative_edges == 16 + assert lp_config.train_etype == None + assert lp_config.eval_etype == None + check_failure(lp_config, "exclude_training_targets") + assert len(lp_config.reverse_edge_types_map) == 0 + assert lp_config.gamma == 12.0 + assert lp_config.lp_loss_func == BUILTIN_LP_LOSS_CONTRASTIVELOSS + assert lp_config.lp_embed_normalizer == GRAPHSTORM_LP_EMB_L2_NORMALIZATION + assert lp_config.lp_decoder_type == BUILTIN_LP_DISTMULT_DECODER + assert len(lp_config.eval_metric) == 1 + assert lp_config.eval_metric[0] == "mrr" + assert config.lp_edge_weight_for_loss == None + assert config.model_select_etype == LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL + if __name__ == '__main__': + test_multi_task_config() 
test_id_mapping_file() test_load_basic_info() test_gnn_info() From b6a72676575dd82af25bb94029c4881fc0ea0b54 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 9 May 2024 17:31:58 -0700 Subject: [PATCH 13/79] Fix bugs --- python/graphstorm/config/argument.py | 59 +++++++++++++++++----------- python/graphstorm/config/config.py | 7 ++-- tests/unit-tests/test_config.py | 30 +++++++------- 3 files changed, 55 insertions(+), 41 deletions(-) diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 8a902cbcd4..9d4e3d6ed8 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -153,6 +153,7 @@ def __init__(self, cmd_args): if 'multi_task_learning' in configuration: multi_task_config = configuration['multi_task_learning'] del configuration['multi_task_learning'] + print(multi_task_config) self.set_attributes(configuration) # Override class attributes using command-line arguments @@ -168,6 +169,7 @@ def __init__(self, cmd_args): # We do argument check as early as possible to prevent config bugs. 
self.handle_argument_conflicts() + print(multi_task_config) # parse multi task learning config and save it into self._multi_tasks if multi_task_config is not None: self._parse_multi_tasks(multi_task_config) @@ -253,7 +255,7 @@ def set_task_attributes(self, configuration): Task specific config """ for key, val in configuration.items(): - setattr(self, key, val) + setattr(self, f"_{key}", val) def _parse_general_task_config(self, task_config): """ Parse the genral task info @@ -277,7 +279,8 @@ def _parse_general_task_config(self, task_config): task_weight = task_config["task_weight"] assert task_weight > 0, f"task_weight should be larger than 0, but get {task_weight}" - return mask_fields, task_weight + batch_size = self.batch_size if "batch_size" not in task_config else task_config["batch_size"] + return mask_fields, task_weight, batch_size def _parse_node_classification_task(self, task_config): """ Parse the node classification task info @@ -288,12 +291,15 @@ def _parse_node_classification_task(self, task_config): Node classification task config """ task_type = BUILTIN_TASK_NODE_CLASSIFICATION + mask_fields, task_weight, batch_size = \ + self._parse_general_task_config(task_config) + task_config["batch_size"] = batch_size + task_info = GSConfig.__new__(GSConfig) task_info.set_task_attributes(task_config) + setattr(task_info, "_task_type", task_type) task_info.verify_node_class_arguments() - mask_fields, task_weight = \ - self._parse_general_task_config(task_config) target_ntype = task_info.target_ntype label_field = task_info.label_field @@ -305,7 +311,7 @@ def _parse_node_classification_task(self, task_config): return TaskInfo(task_type=task_type, task_id=task_id, - task_info=task_info) + task_config=task_info) def _parse_node_regression_task(self, task_config): """ Parse the node regression task info @@ -316,12 +322,15 @@ def _parse_node_regression_task(self, task_config): Node regression task config """ task_type = BUILTIN_TASK_NODE_REGRESSION + mask_fields, 
task_weight, batch_size = \ + self._parse_general_task_config(task_config) + task_config["batch_size"] = batch_size + task_info = GSConfig.__new__(GSConfig) task_info.set_task_attributes(task_config) + setattr(task_info, "_task_type", task_type) task_info.verify_node_regression_arguments() - mask_fields, task_weight = \ - self._parse_general_task_config(task_config) target_ntype = task_info.target_ntype label_field = task_info.label_field @@ -333,7 +342,7 @@ def _parse_node_regression_task(self, task_config): return TaskInfo(task_type=task_type, task_id=task_id, - task_info=task_info) + task_config=task_info) def _parse_edge_classification_task(self, task_config): """ Parse the edge classification task info @@ -344,12 +353,15 @@ def _parse_edge_classification_task(self, task_config): Edge classification task config """ task_type = BUILTIN_TASK_EDGE_CLASSIFICATION + mask_fields, task_weight, batch_size = \ + self._parse_general_task_config(task_config) + task_config["batch_size"] = batch_size + task_info = GSConfig.__new__(GSConfig) task_info.set_task_attributes(task_config) + setattr(task_info, "_task_type", task_type) task_info.verify_edge_class_arguments() - mask_fields, task_weight = \ - self._parse_general_task_config(task_config) target_etype = task_info.target_etype label_field = task_info.label_field @@ -360,7 +372,7 @@ def _parse_edge_classification_task(self, task_config): setattr(task_info, "task_weight", task_weight) return TaskInfo(task_type=task_type, task_id=task_id, - task_info=task_info) + task_config=task_info) def _parse_edge_regression_task(self, task_config): """ Parse the edge regression task info @@ -371,13 +383,15 @@ def _parse_edge_regression_task(self, task_config): Edge regression task config """ task_type = BUILTIN_TASK_EDGE_REGRESSION + mask_fields, task_weight, batch_size = \ + self._parse_general_task_config(task_config) + task_config["batch_size"] = batch_size + task_info = GSConfig.__new__(GSConfig) 
task_info.set_task_attributes(task_config) + setattr(task_info, "_task_type", task_type) task_info.verify_edge_regression_arguments() - mask_fields, task_weight = \ - self._parse_general_task_config(task_config) - target_etype = task_info.target_etype label_field = task_info.label_field @@ -388,7 +402,7 @@ def _parse_edge_regression_task(self, task_config): setattr(task_info, "task_weight", task_weight) return TaskInfo(task_type=task_type, task_id=task_id, - task_info=task_info) + task_config=task_info) def _parse_link_prediction_task(self, task_config): """ Parse the link prediction task info @@ -399,14 +413,16 @@ def _parse_link_prediction_task(self, task_config): Link prediction task config """ task_type = BUILTIN_TASK_LINK_PREDICTION + mask_fields, task_weight, batch_size = \ + self._parse_general_task_config(task_config) + task_config["batch_size"] = batch_size + task_info = GSConfig.__new__(GSConfig) task_info.set_task_attributes(task_config) + setattr(task_info, "_task_type", task_type) task_info.verify_edge_regression_arguments() - mask_fields, task_weight = \ - self._parse_general_task_config(task_config) train_etype = task_info.train_etype - task_id = get_mttask_id( task_type=task_type, etype=train_etype if train_etype is not None else "ALL_ETYPE") @@ -414,7 +430,7 @@ def _parse_link_prediction_task(self, task_config): setattr(task_info, "task_weight", task_weight) return TaskInfo(task_type=task_type, task_id=task_id, - task_info=task_info) + task_config=task_info) def _parse_multi_tasks(self, multi_task_config): """ Parse multi-task configuration @@ -427,10 +443,7 @@ def _parse_multi_tasks(self, multi_task_config): assert isinstance(task_config, dict) and len(task_config) == 1, \ "When defining multiple tasks for " \ "training, define one task each time." - if "batch_size" not in task_config: - # If batch_size is not set - # Use the global batch size. 
- task_config["batch_size"] = self.batch_size + if "node_classification" in task_config: task = self._parse_node_classification_task( task_config["node_classification"]) diff --git a/python/graphstorm/config/config.py b/python/graphstorm/config/config.py index 71cdc2b576..81b39c6898 100644 --- a/python/graphstorm/config/config.py +++ b/python/graphstorm/config/config.py @@ -16,6 +16,7 @@ Builtin configs """ import dataclasses +import typing BUILTIN_GNN_ENCODER = ["gat", "rgat", "rgcn", "sage", "hgt", "gatv2"] BUILTIN_ENCODER = ["lm", "mlp"] + BUILTIN_GNN_ENCODER @@ -86,7 +87,7 @@ def get_mttask_id(task_type, ntype=None, etype=None, label=None): elif isinstance(etype, tuple): task_id.append("_".join(etype)) elif isinstance(etype, list): # a list of etypes - task_id.append("__".joint(["_".join(et) for et in etype])) + task_id.append("__".join(["_".join(et) for et in etype])) else: raise TypeError("Unknown etype format: %s. Must be a string " \ "or a tuple of strings or a list of tuples of strings.", etype) @@ -118,5 +119,5 @@ class TaskInfo: """ task_type : str task_id : str - task_config = None - dataloader = None # dataloder + task_config : typing.Any = None + dataloader : typing.Any = None # dataloader diff --git a/tests/unit-tests/test_config.py b/tests/unit-tests/test_config.py index ddaeb080d8..61bfc388e5 100644 --- a/tests/unit-tests/test_config.py +++ b/tests/unit-tests/test_config.py @@ -1624,10 +1624,10 @@ def create_dummy_ec_config(): "num_decoder_basis": 4, "remove_target_edge_type": False, "decoder_type": "MLPDecoder", - "decoder_edge_feat": "feat", + "decoder_edge_feat": ["feat"], "eval_metric": ["precision_recall"], - "multilabel_weights": "1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2", - "imbalance_class_weights": "1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2", + "multilabel_weights": "1,2,3,1", + "imbalance_class_weights": "1,2,3,1", "batch_size": 20, "task_weight": 1, "mask_fields": ["ec_train_mask", "ec_eval_mask", "ec_test_mask"] @@ -1638,7 +1638,7 @@ def 
create_dummy_er_config(): "target_etype": ["query,match-2,asin"], "label_field": "label_er", "eval_metric": ["mse"], - "decoder_edge_feat": ["feat1", "feat2"], + "decoder_edge_feat": ["query,no-match,asin:feat0,feat1"], "task_weight": 1, "mask_fields": ["er_train_mask", "er_eval_mask", "er_test_mask"] } @@ -1706,13 +1706,13 @@ def create_multi_task_config(tmp_path, file_name): def test_multi_task_config(): with tempfile.TemporaryDirectory() as tmpdirname: - create_rgcn_config(Path(tmpdirname), 'multi_task_test') + create_multi_task_config(Path(tmpdirname), 'multi_task_test') args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'multi_task_test_default.yaml'), local_rank=0) config = GSConfig(args) - assert len(config.multi_task) == 6 - nc_config = config.multi_task[0] + assert len(config.multi_tasks) == 6 + nc_config = config.multi_tasks[0] assert nc_config.task_type == BUILTIN_TASK_NODE_CLASSIFICATION assert nc_config.task_weight == 1 assert len(nc_config.mask_fields) == 3 @@ -1732,7 +1732,7 @@ def test_multi_task_config(): assert nc_config.multilabel_weights.tolist() == [1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2] assert nc_config.batch_size == 20 - nr_config = config.multi_task[1] + nr_config = config.multi_tasks[1] assert nr_config.task_type == BUILTIN_TASK_NODE_REGRESSION assert nr_config.task_weight == 0.5 assert len(nr_config.mask_fields) == 3 @@ -1746,7 +1746,7 @@ def test_multi_task_config(): assert nr_config.eval_metric[0] == "rmse" assert nr_config.batch_size == 64 - ec_config = config.multi_task[2] + ec_config = config.multi_tasks[2] assert ec_config.task_type == BUILTIN_TASK_EDGE_CLASSIFICATION assert ec_config.task_weight == 1 assert len(ec_config.mask_fields) == 3 @@ -1765,10 +1765,10 @@ def test_multi_task_config(): assert len(ec_config.eval_metric) == 1 assert ec_config.eval_metric[0] == "precision_recall" assert ec_config.batch_size == 20 - assert ec_config.imbalance_class_weights.tolist() == [1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2] - 
assert ec_config.multilabel_weights.tolist() == [1,2,3,1,2,1,2,3,1,2,1,2,3,1,2,1,2,3,1,2] + assert ec_config.imbalance_class_weights.tolist() == [1,2,3,1] + assert ec_config.multilabel_weights.tolist() == [1,2,3,1] - er_config = config.multi_task[3] + er_config = config.multi_tasks[3] assert er_config.task_type == BUILTIN_TASK_EDGE_REGRESSION assert er_config.task_weight == 1 assert len(er_config.mask_fields) == 3 @@ -1780,13 +1780,13 @@ def test_multi_task_config(): assert er_config.label_field == "label_er" assert len(er_config.eval_metric) == 1 assert er_config.eval_metric[0] == "mse" - assert er_config.decoder_edge_feat == ["feat1", "feat2"] + assert er_config.decoder_edge_feat == ["feat0", "feat1"] assert er_config.batch_size == 64 assert ec_config.remove_target_edge_type == True assert ec_config.decoder_type == "DenseBiDecoder" assert ec_config.num_decoder_basis == 2 - lp_config = config.multi_task[4] + lp_config = config.multi_tasks[4] assert lp_config.task_type == BUILTIN_TASK_LINK_PREDICTION assert lp_config.task_weight == 1 assert len(lp_config.mask_fields) == 3 @@ -1814,7 +1814,7 @@ def test_multi_task_config(): assert lp_config.lp_edge_weight_for_loss == "weight" - lp_config = config.multi_task[5] + lp_config = config.multi_tasks[5] assert lp_config.task_type == BUILTIN_TASK_LINK_PREDICTION assert lp_config.task_weight == 2 assert len(lp_config.mask_fields) == 3 From 71ad941e0e849c4235f40abcdf346a17b9d378ad Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 9 May 2024 22:41:39 -0700 Subject: [PATCH 14/79] Fix CI --- python/graphstorm/config/argument.py | 4 +-- python/graphstorm/trainer/mt_trainer.py | 6 ++--- tests/unit-tests/test_config.py | 34 +++++++++++++++---------- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 9d4e3d6ed8..6a129bb6a2 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -153,7 +153,6 @@ def 
__init__(self, cmd_args): if 'multi_task_learning' in configuration: multi_task_config = configuration['multi_task_learning'] del configuration['multi_task_learning'] - print(multi_task_config) self.set_attributes(configuration) # Override class attributes using command-line arguments @@ -169,7 +168,6 @@ def __init__(self, cmd_args): # We do argument check as early as possible to prevent config bugs. self.handle_argument_conflicts() - print(multi_task_config) # parse multi task learning config and save it into self._multi_tasks if multi_task_config is not None: self._parse_multi_tasks(multi_task_config) @@ -420,7 +418,7 @@ def _parse_link_prediction_task(self, task_config): task_info = GSConfig.__new__(GSConfig) task_info.set_task_attributes(task_config) setattr(task_info, "_task_type", task_type) - task_info.verify_edge_regression_arguments() + task_info.verify_link_prediction_arguments() train_etype = task_info.train_etype task_id = get_mttask_id( diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index fcade02da8..fb8f8608dd 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -52,7 +52,7 @@ def run_node_predict_mini_batch(model, data, task_info, mini_batch, device): # TODO: we don't support edge features for now. 
loss = model(task_info.task_id, ((blocks, input_feats, None, input_nodes), lbl)) - return loss + return loss, task_info.task_config.task_weight def run_edge_predict_mini_batch(model, data, task_info, mini_batch, device): input_nodes, batch_graph, blocks = mini_batch @@ -90,7 +90,7 @@ def run_edge_predict_mini_batch(model, data, task_info, mini_batch, device): loss = model(task_info.task_id, ((blocks, input_feats, None, input_nodes), (batch_graph, edge_decoder_feats, lbl))) - return loss + return loss, task_info.task_config.task_weight def run_link_predict_mini_batch(model, data, task_info, mini_batch, device): input_nodes, pos_graph, neg_graph, blocks = mini_batch @@ -119,7 +119,7 @@ def run_link_predict_mini_batch(model, data, task_info, mini_batch, device): loss = model(task_info.task_id, ((blocks, input_feats, None, input_nodes), (pos_graph, neg_graph,pos_graph_feats, None))) - return loss + return loss, task_info.task_config.task_weight def multi_task_mini_batch_predict( model, emb, loader, device, return_proba=True, return_label=False): diff --git a/tests/unit-tests/test_config.py b/tests/unit-tests/test_config.py index 61bfc388e5..6b548397cd 100644 --- a/tests/unit-tests/test_config.py +++ b/tests/unit-tests/test_config.py @@ -1638,7 +1638,7 @@ def create_dummy_er_config(): "target_etype": ["query,match-2,asin"], "label_field": "label_er", "eval_metric": ["mse"], - "decoder_edge_feat": ["query,no-match,asin:feat0,feat1"], + "decoder_edge_feat": ["query,match-2,asin:feat0,feat1"], "task_weight": 1, "mask_fields": ["er_train_mask", "er_eval_mask", "er_test_mask"] } @@ -1667,7 +1667,8 @@ def create_dummy_lp_config2(): "lp_loss_func": BUILTIN_LP_LOSS_CONTRASTIVELOSS, "lp_decoder_type": BUILTIN_LP_DISTMULT_DECODER, "task_weight": 2, - "mask_fields": ["lp2_train_mask", "lp2_eval_mask", "lp2_test_mask"] + "mask_fields": ["lp2_train_mask", "lp2_eval_mask", "lp2_test_mask"], + "exclude_training_targets": False } def create_multi_task_config(tmp_path, file_name): @@ 
-1714,12 +1715,13 @@ def test_multi_task_config(): assert len(config.multi_tasks) == 6 nc_config = config.multi_tasks[0] assert nc_config.task_type == BUILTIN_TASK_NODE_CLASSIFICATION + assert nc_config.task_id == f"{BUILTIN_TASK_NODE_CLASSIFICATION}-a-label_c" + nc_config = nc_config.task_config assert nc_config.task_weight == 1 assert len(nc_config.mask_fields) == 3 assert nc_config.mask_fields[0] == "class_train_mask" assert nc_config.mask_fields[1] == "class_eval_mask" assert nc_config.mask_fields[2] == "class_test_mask" - nc_config = nc_config.task_config assert nc_config.target_ntype == "a" assert nc_config.label_field == "label_c" assert nc_config.multilabel == True @@ -1734,12 +1736,13 @@ def test_multi_task_config(): nr_config = config.multi_tasks[1] assert nr_config.task_type == BUILTIN_TASK_NODE_REGRESSION + assert nr_config.task_id == f"{BUILTIN_TASK_NODE_REGRESSION}-a-label_r" + nr_config = nr_config.task_config assert nr_config.task_weight == 0.5 assert len(nr_config.mask_fields) == 3 assert nr_config.mask_fields[0] == "reg_train_mask" assert nr_config.mask_fields[1] == "reg_eval_mask" assert nr_config.mask_fields[2] == "reg_test_mask" - nr_config = nr_config.task_config assert nr_config.target_ntype == "a" assert nr_config.label_field == "label_r" assert len(nr_config.eval_metric) == 1 @@ -1748,12 +1751,13 @@ def test_multi_task_config(): ec_config = config.multi_tasks[2] assert ec_config.task_type == BUILTIN_TASK_EDGE_CLASSIFICATION + assert ec_config.task_id == f"{BUILTIN_TASK_EDGE_CLASSIFICATION}-query_match_asin-label_ec" + ec_config = ec_config.task_config assert ec_config.task_weight == 1 assert len(ec_config.mask_fields) == 3 assert ec_config.mask_fields[0] == "ec_train_mask" assert ec_config.mask_fields[1] == "ec_eval_mask" assert ec_config.mask_fields[2] == "ec_test_mask" - ec_config = ec_config.task_config assert ec_config.target_etype[0] == ("query", "match", "asin") assert ec_config.label_field == "label_ec" assert ec_config.multilabel == 
True @@ -1770,30 +1774,33 @@ def test_multi_task_config(): er_config = config.multi_tasks[3] assert er_config.task_type == BUILTIN_TASK_EDGE_REGRESSION + assert er_config.task_id == f"{BUILTIN_TASK_EDGE_REGRESSION}-query_match-2_asin-label_er" + er_config = er_config.task_config assert er_config.task_weight == 1 assert len(er_config.mask_fields) == 3 assert er_config.mask_fields[0] == "er_train_mask" assert er_config.mask_fields[1] == "er_eval_mask" assert er_config.mask_fields[2] == "er_test_mask" - er_config = er_config.task_config assert er_config.target_etype[0] == ("query", "match-2", "asin") assert er_config.label_field == "label_er" assert len(er_config.eval_metric) == 1 assert er_config.eval_metric[0] == "mse" - assert er_config.decoder_edge_feat == ["feat0", "feat1"] + assert len(er_config.decoder_edge_feat) == 1 + assert er_config.decoder_edge_feat[("query","match-2","asin")] == ["feat0", "feat1"] assert er_config.batch_size == 64 - assert ec_config.remove_target_edge_type == True - assert ec_config.decoder_type == "DenseBiDecoder" - assert ec_config.num_decoder_basis == 2 + assert er_config.remove_target_edge_type == True + assert er_config.decoder_type == "DenseBiDecoder" + assert er_config.num_decoder_basis == 2 lp_config = config.multi_tasks[4] assert lp_config.task_type == BUILTIN_TASK_LINK_PREDICTION + assert lp_config.task_id == f"{BUILTIN_TASK_LINK_PREDICTION}-query_exactmatch_asin" + lp_config = lp_config.task_config assert lp_config.task_weight == 1 assert len(lp_config.mask_fields) == 3 assert lp_config.mask_fields[0] == "lp_train_mask" assert lp_config.mask_fields[1] == "lp_eval_mask" assert lp_config.mask_fields[2] == "lp_test_mask" - lp_config = lp_config.task_config assert lp_config.train_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER assert lp_config.num_negative_edges == 4 assert lp_config.num_negative_edges_eval == 100 @@ -1816,17 +1823,18 @@ def test_multi_task_config(): lp_config = config.multi_tasks[5] assert lp_config.task_type == 
BUILTIN_TASK_LINK_PREDICTION + assert lp_config.task_id == f"{BUILTIN_TASK_LINK_PREDICTION}-ALL_ETYPE" + lp_config = lp_config.task_config assert lp_config.task_weight == 2 assert len(lp_config.mask_fields) == 3 assert lp_config.mask_fields[0] == "lp2_train_mask" assert lp_config.mask_fields[1] == "lp2_eval_mask" assert lp_config.mask_fields[2] == "lp2_test_mask" - lp_config = lp_config.task_config assert lp_config.train_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER assert lp_config.num_negative_edges == 16 assert lp_config.train_etype == None assert lp_config.eval_etype == None - check_failure(lp_config, "exclude_training_targets") + assert lp_config.exclude_training_targets == False assert len(lp_config.reverse_edge_types_map) == 0 assert lp_config.gamma == 12.0 assert lp_config.lp_loss_func == BUILTIN_LP_LOSS_CONTRASTIVELOSS From 75b54356aff259a0b823bd901a00509cbb00abbe Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Sun, 12 May 2024 23:16:49 -0700 Subject: [PATCH 15/79] Update --- tests/end2end-tests/data_process/test.sh | 2 +- tests/unit-tests/data_utils.py | 122 +++++++++++------ tests/unit-tests/test_dataloading.py | 161 +++++++++++++++++++++++ 3 files changed, 244 insertions(+), 41 deletions(-) diff --git a/tests/end2end-tests/data_process/test.sh b/tests/end2end-tests/data_process/test.sh index 1bc9caeb7d..df3c010c4e 100644 --- a/tests/end2end-tests/data_process/test.sh +++ b/tests/end2end-tests/data_process/test.sh @@ -121,7 +121,7 @@ python3 $GS_HOME/tests/end2end-tests/data_process/test_multitask_data.py --graph error_and_exit $? 
-echo "********* Test the DistDGL graph format with multi mask support from saved config ********" +echo "********* Test the DistDGL graph format with multi mask support from saved config g********" python3 -m graphstorm.gconstruct.construct_graph --conf-file /tmp/multitask_test_data/test_multitask_data_transform_new.conf --num-processes 2 --output-dir /tmp/test_partition2 --graph-name test --add-reverse-edges error_and_exit $? diff --git a/tests/unit-tests/data_utils.py b/tests/unit-tests/data_utils.py index c6d75927f4..6cfcf195da 100644 --- a/tests/unit-tests/data_utils.py +++ b/tests/unit-tests/data_utils.py @@ -32,6 +32,13 @@ from graphstorm.model.lm_model import TOKEN_IDX, ATT_MASK_IDX, VALID_LEN from util import create_tokens +SIZE_DICT = { + 'tiny': 1e+2, + 'small': 1e+4, + 'medium': 1e+6, + 'large': 1e+8, + 'largest': 1e+10 + } def generate_mask(idx, length): mask = np.zeros(length) @@ -49,14 +56,7 @@ def generate_dummy_hetero_graph(size='tiny', gen_mask=True, add_reverse=False): :return: hg: a heterogeneous graph. """ - size_dict = { - 'tiny': 1e+2, - 'small': 1e+4, - 'medium': 1e+6, - 'large': 1e+8, - 'largest': 1e+10 - } - + size_dict = SIZE_DICT data_size = int(size_dict[size]) num_nodes_dict = { @@ -136,14 +136,7 @@ def generate_dummy_hetero_graph_multi_target_ntypes(size='tiny', gen_mask=True): :return: hg: a heterogeneous graph. """ - size_dict = { - 'tiny': 1e+2, - 'small': 1e+4, - 'medium': 1e+6, - 'large': 1e+8, - 'largest': 1e+10 - } - + size_dict = SIZE_DICT data_size = int(size_dict[size]) num_nodes_dict = { @@ -215,6 +208,56 @@ def generate_dummy_hetero_graph_multi_target_ntypes(size='tiny', gen_mask=True): return hetero_graph +def generate_dummy_hetero_graph_multi_task(size='tiny'): + """ + generate a dummy heterogeneous graph. + Parameters + ---------- + size: the size of dummy graph data, could be one of tiny, small, medium, large, and largest + + :return: + hg: a heterogeneous graph. 
+ """ + gen_mask=True + size_dict = SIZE_DICT + # based on the graph generated for multi_target_ntypes + # we add some more tasks. + hetero_graph = generate_dummy_hetero_graph_multi_target_ntypes(size=size, gen_mask=gen_mask) + + data_size = int(size_dict[size]) + + # add extra mask for n0 + node_train_mask = generate_mask([0,1,2,3,4], data_size) + node_val_mask = generate_mask([5,6,7], data_size) + node_test_mask = generate_mask([8,9,10,11,12,13,14], data_size) + hetero_graph.nodes["n0"].data['train_mask1'] = node_train_mask + hetero_graph.nodes["n0"].data['val_mask1'] = node_val_mask + hetero_graph.nodes["n0"].data['test_mask1'] = node_test_mask + + node_train_mask = generate_mask([i for i in range(data_size//2, data_size)], data_size) + node_val_mask = generate_mask([i for i in range(data_size//4, data_size//2)], data_size) + node_test_mask = generate_mask([i for i in range(data_size//4)], data_size) + hetero_graph.nodes["n0"].data['train_mask2'] = node_train_mask + hetero_graph.nodes["n0"].data['val_mask2'] = node_val_mask + hetero_graph.nodes["n0"].data['test_mask2'] = node_test_mask + + edge_train_mask = generate_mask([0,1,2,3,4], 2 * data_size) + edge_val_mask = generate_mask([5,6,7], 2 * data_size) + edge_test_mask = generate_mask([8,9,10,11,12,13,14], 2 * data_size) + hetero_graph.edges[("n0", "r1", "n1")].data['train_mask1'] = edge_train_mask + hetero_graph.edges[("n0", "r1", "n1")].data['val_mask1'] = edge_val_mask + hetero_graph.edges[("n0", "r1", "n1")].data['test_mask1'] = edge_test_mask + + edge_train_mask = generate_mask([i for i in range(data_size, data_size * 2)], 2 * data_size) + edge_val_mask = generate_mask([i for i in range(data_size//2, data_size)], 2 * data_size) + edge_test_mask = generate_mask([i for i in range(data_size//2)], 2 * data_size) + hetero_graph.edges[("n0", "r1", "n1")].data['train_mask2'] = edge_train_mask + hetero_graph.edges[("n0", "r1", "n1")].data['val_mask2'] = edge_val_mask + hetero_graph.edges[("n0", "r1", 
"n1")].data['test_mask2'] = edge_test_mask + + return hetero_graph + + def generate_dummy_hetero_graph_reconstruct(size='tiny', gen_mask=True): """ generate a dummy heterogeneous graph for testing the construction of node features.. @@ -225,14 +268,7 @@ def generate_dummy_hetero_graph_reconstruct(size='tiny', gen_mask=True): :return: hg: a heterogeneous graph. """ - size_dict = { - 'tiny': 1e+2, - 'small': 1e+4, - 'medium': 1e+6, - 'large': 1e+8, - 'largest': 1e+10 - } - + size_dict = SIZE_DICT data_size = int(size_dict[size]) num_nodes_dict = { @@ -293,14 +329,7 @@ def generate_dummy_homo_graph(size='tiny', gen_mask=True): :return: hg: a homogeneous graph in one node type and one edge type. """ - size_dict = { - 'tiny': 1e+2, - 'small': 1e+4, - 'medium': 1e+6, - 'large': 1e+8, - 'largest': 1e+10 - } - + size_dict = SIZE_DICT data_size = int(size_dict[size]) num_nodes_dict = { @@ -368,14 +397,7 @@ def generate_dummy_homogeneous_failure_graph(size='tiny', gen_mask=True, type='n :return: hg: a homogeneous graph in one node type and one edge type. """ - size_dict = { - 'tiny': 1e+2, - 'small': 1e+4, - 'medium': 1e+6, - 'large': 1e+8, - 'largest': 1e+10 - } - + size_dict = SIZE_DICT data_size = int(size_dict[size]) if type == 'node': @@ -521,6 +543,26 @@ def generate_dummy_dist_graph_multi_target_ntypes(dirname, size='tiny', graph_na return partion_and_load_distributed_graph(hetero_graph=hetero_graph, dirname=dirname, graph_name=graph_name) +def generate_dummy_dist_graph_multi_task(dirname, size='tiny', graph_name='dummy', gen_mask=True): + """ + Generate a dummy DGL distributed graph for multi-task testing + with the given size + + Parameters + ---------- + dirname : the directory where the graph will be partitioned and stored. 
+ size: the size of dummy graph data, could be one of tiny, small, medium, large, and largest + graph_name: string as a name + + Returns + ------- + dist_graph: a DGL distributed graph + part_config : the path of the partition configuration file. + """ + hetero_graph = generate_dummy_hetero_graph_multi_task(size=size, gen_mask=gen_mask) + return partion_and_load_distributed_graph(hetero_graph=hetero_graph, dirname=dirname, + graph_name=graph_name) + def generate_dummy_dist_graph_homogeneous_failure_graph(dirname, size='tiny', graph_name='dummy', gen_mask=True, type='node'): diff --git a/tests/unit-tests/test_dataloading.py b/tests/unit-tests/test_dataloading.py index 10bd705847..4e8124c6a6 100644 --- a/tests/unit-tests/test_dataloading.py +++ b/tests/unit-tests/test_dataloading.py @@ -30,10 +30,16 @@ generate_dummy_dist_graph, generate_dummy_dist_graph_reconstruct, generate_dummy_dist_graph_homogeneous_failure_graph, + generate_dummy_dist_graph_multi_task, create_distill_data, ) import graphstorm as gs +from graphstorm.config import (TaskInfo, + BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_LINK_PREDICTION) from graphstorm.utils import setup_device, get_device from graphstorm.dataloading import GSgnnData from graphstorm.dataloading import GSgnnAllEtypeLinkPredictionDataLoader @@ -55,6 +61,7 @@ GSgnnLinkPredictionPredefinedTestDataLoader) from graphstorm.dataloading import DistillDataloaderGenerator, DistillDataManager from graphstorm.dataloading import DistributedFileSampler +from graphstorm.dataloading import GSgnnMultiTaskDataLoader from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, BUILTIN_LP_JOINT_NEG_SAMPLER, BUILTIN_LP_FIXED_NEG_SAMPLER) @@ -2181,7 +2188,161 @@ def test_GSgnnTrainData_homogeneous(): # after test pass, destroy all process group th.distributed.destroy_process_group() +def test_GSgnnMultiTaskDataLoader(): + with tempfile.TemporaryDirectory() as tmpdirname: + 
# get the test dummy distributed graph + dist_graph, part_config = generate_dummy_dist_graph_multi_task(graph_name='dummy', + dirname=tmpdirname, add_reverse=False) + gdata = GSgnnData(part_config=part_config) + + # n0: train_mask: 2, val_mask: 2, test_maks: 4 + # train_mask1: 5, val_mask1: 3, test_mask1: 7 + # train_mask2: 50, val_mask2: 25, test_mask2: 25 + # n1: train_mask: 50, val_mask: 2, tset_mask:2 + # + # ("n0", "r1", "n1"): train_mask: 2, val_mask: 2, test_maks: 4 + # ("n0", "r1", "n1"): train_mask1: 5, val_mask1: 3, test_maks2: 7 + # ("n0", "r1", "n1"): train_mask2: 2, val_mask2: 2, test_mask2: 4 + # ("n0", "r0", "n1"): train_mask: 50, val_mask: 2, test_maks: 4 + tast_info_n0_0 = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION, + task_id='tast_info_n0_0', + task_config=None) + tast_info_n0_1 = TaskInfo(task_type=BUILTIN_TASK_NODE_REGRESSION, + task_id='tast_info_n0_1', + task_config=None) + tast_info_n0_2 = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION, + task_id='tast_info_n0_2', + task_config=None) + task_info_n1_0 = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION, + task_id='task_info_n1_0', + task_config=None) + task_info_edge_0 = TaskInfo(task_type=BUILTIN_TASK_LINK_PREDICTION, + task_id='task_info_edge_0', + task_config=None) + task_info_edge_1 = TaskInfo(task_type=BUILTIN_TASK_EDGE_CLASSIFICATION, + task_id='task_info_edge_1', + task_config=None) + task_info_edge_2 = TaskInfo(task_type=BUILTIN_TASK_EDGE_CLASSIFICATION, + task_id='task_info_edge_2', + task_config=None) + task_infos = [tast_info_n0_0, tast_info_n0_1, + tast_info_n0_2, task_info_n1_0, + task_info_edge_0, task_info_edge_1, + task_info_edge_2] + + task_n0_0_dataloader = GSgnnNodeDataLoader(gdata, + gdata.get_node_train_set("n0", "train_mask"), + fanout=[10], + batch_size=2, + label_field="label", + train_task=True) + task_n0_1_dataloader = GSgnnNodeDataLoader(gdata, + gdata.get_node_train_set("n0", "train_mask1"), + fanout=[10], + batch_size=2, + label_field="label", + 
train_task=True) + task_n0_2_dataloader = GSgnnNodeDataLoader(gdata, + gdata.get_node_train_set("n0", "train_mask2"), + fanout=[10], + batch_size=5, + label_field="label", + train_task=True) + task_n1_0_dataloader = GSgnnNodeDataLoader(gdata, + gdata.get_node_train_set("n1", "train_mask"), + fanout=[10], + batch_size=10, + label_field="label", + train_task=True) + + task_edage_0_dataloader = GSgnnLinkPredictionDataLoader( + gdata, + gdata.get_edge_train_set(etypes=[("n0", "r1", "n1"), ("n0", "r0", "n1")]), + fanout=[10], + batch_size=10, + num_negative_edges=2, + train_task=True + ) + + task_edage_1_dataloader = GSgnnEdgeDataLoader( + gdata, + gdata.get_edge_train_set(etypes=[("n0", "r1", "n1")], + mask="train_mask1"), + fanout=[10], + batch_size=2, + num_negative_edges=2, + train_task=True + ) + + task_edage_2_dataloader = GSgnnEdgeDataLoader( + gdata, + gdata.get_edge_train_set(etypes=[("n0", "r1", "n1")], + mask="train_mask2"), + fanout=[10], + batch_size=2, + num_negative_edges=2, + train_task=True + ) + + dataloaders = [task_n0_0_dataloader, task_n0_1_dataloader, + task_n0_2_dataloader, task_n1_0_dataloader, + task_edage_0_dataloader, task_edage_1_dataloader, + task_edage_2_dataloader] + multi_dataloader = GSgnnMultiTaskDataLoader(gdata, task_infos, dataloaders) + len_n0_0 = len(task_n0_0_dataloader) + assert len_n0_0 == 1 + len_n0_1 = len(task_n0_1_dataloader) + assert len_n0_1 == 2 + len_n0_2 = len(task_n0_2_dataloader) + assert len_n0_2 == 10 + len_n1_0 = len(task_n1_0_dataloader) + assert len_n1_0 == 5 + len_edge0_0 = len(task_edage_0_dataloader) + assert len_edge0_0 == 5 + len_edge1_0 = len(task_edage_1_dataloader) + assert len_edge1_0 == 2 + len_edge2_0 = len(task_edage_2_dataloader) + assert len_edge2_0 == 1 + + assert len(multi_dataloader) == 10 + + len(multi_dataloader.dataloaders) == 7 + dataloaders = multi_dataloader.dataloaders + len(multi_dataloader.task_infos) == 7 + + dataloaders = multi_dataloader.dataloaders + task_infos = 
multi_dataloader.task_infos + for dataloader, task_info in zip(dataloaders, task_infos): + if task_info.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + assert isinstance(dataloader, GSgnnNodeDataLoader) + + if task_info.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION]: + assert isinstance(dataloader, GSgnnEdgeDataLoader) + if task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]: + assert isinstance(dataloader, GSgnnLinkPredictionDataLoader) + + for mini_batches in multi_dataloader: + assert len(mini_batches) == 7 + for task_info, mini_batch in mini_batches: + if task_info.task_id == "tast_info_n0_0": + assert task_info.dataloader == task_n0_0_dataloader + if task_info.task_id == "tast_info_n0_1": + assert task_info.dataloader == task_n0_1_dataloader + if task_info.task_id == "tast_info_n0_2": + assert task_info.dataloader == task_n0_2_dataloader + if task_info.task_id == "task_info_n1_0": + assert task_info.dataloader == task_n1_0_dataloader + if task_info.task_id == "task_info_edge_0": + assert task_info.dataloader == task_edage_0_dataloader + if task_info.task_id == "task_info_edge_1": + assert task_info.dataloader == task_edage_1_dataloader + if task_info.task_id == "task_info_edge_2": + assert task_info.dataloader == task_edage_2_dataloader + + if __name__ == '__main__': + + test_GSgnnMultiTaskDataLoader() test_GSgnnLinkPredictionPredefinedTestDataLoader(1) test_GSgnnLinkPredictionPredefinedTestDataLoader(10) test_edge_fixed_dst_negative_sample_gen_neg_pairs() From 729a5e17e05d125e83bb3fcff43faa4e6a1dcbe4 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 13 May 2024 11:05:33 -0700 Subject: [PATCH 16/79] Update --- tests/unit-tests/test_dataloading.py | 70 ++++++++++++++++++++++------ 1 file changed, 56 insertions(+), 14 deletions(-) diff --git a/tests/unit-tests/test_dataloading.py b/tests/unit-tests/test_dataloading.py index 4e8124c6a6..8594fec78b 100644 --- a/tests/unit-tests/test_dataloading.py +++ 
b/tests/unit-tests/test_dataloading.py @@ -32,6 +32,7 @@ generate_dummy_dist_graph_homogeneous_failure_graph, generate_dummy_dist_graph_multi_task, create_distill_data, + SIZE_DICT ) import graphstorm as gs @@ -2255,7 +2256,7 @@ def test_GSgnnMultiTaskDataLoader(): label_field="label", train_task=True) - task_edage_0_dataloader = GSgnnLinkPredictionDataLoader( + task_edge_0_dataloader = GSgnnLinkPredictionDataLoader( gdata, gdata.get_edge_train_set(etypes=[("n0", "r1", "n1"), ("n0", "r0", "n1")]), fanout=[10], @@ -2264,30 +2265,28 @@ def test_GSgnnMultiTaskDataLoader(): train_task=True ) - task_edage_1_dataloader = GSgnnEdgeDataLoader( + task_edge_1_dataloader = GSgnnEdgeDataLoader( gdata, gdata.get_edge_train_set(etypes=[("n0", "r1", "n1")], mask="train_mask1"), fanout=[10], batch_size=2, - num_negative_edges=2, train_task=True ) - task_edage_2_dataloader = GSgnnEdgeDataLoader( + task_edge_2_dataloader = GSgnnEdgeDataLoader( gdata, gdata.get_edge_train_set(etypes=[("n0", "r1", "n1")], mask="train_mask2"), fanout=[10], batch_size=2, - num_negative_edges=2, train_task=True ) dataloaders = [task_n0_0_dataloader, task_n0_1_dataloader, task_n0_2_dataloader, task_n1_0_dataloader, - task_edage_0_dataloader, task_edage_1_dataloader, - task_edage_2_dataloader] + task_edge_0_dataloader, task_edge_1_dataloader, + task_edge_2_dataloader] multi_dataloader = GSgnnMultiTaskDataLoader(gdata, task_infos, dataloaders) len_n0_0 = len(task_n0_0_dataloader) assert len_n0_0 == 1 @@ -2297,11 +2296,11 @@ def test_GSgnnMultiTaskDataLoader(): assert len_n0_2 == 10 len_n1_0 = len(task_n1_0_dataloader) assert len_n1_0 == 5 - len_edge0_0 = len(task_edage_0_dataloader) + len_edge0_0 = len(task_edge_0_dataloader) assert len_edge0_0 == 5 - len_edge1_0 = len(task_edage_1_dataloader) + len_edge1_0 = len(task_edge_1_dataloader) assert len_edge1_0 == 2 - len_edge2_0 = len(task_edage_2_dataloader) + len_edge2_0 = len(task_edge_2_dataloader) assert len_edge2_0 == 1 assert len(multi_dataloader) == 10 
@@ -2321,24 +2320,67 @@ def test_GSgnnMultiTaskDataLoader(): if task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]: assert isinstance(dataloader, GSgnnLinkPredictionDataLoader) + num_iters = 0 for mini_batches in multi_dataloader: assert len(mini_batches) == 7 for task_info, mini_batch in mini_batches: if task_info.task_id == "tast_info_n0_0": assert task_info.dataloader == task_n0_0_dataloader + _, seeds, blocks = mini_batch + assert len(seeds) == 2 + assert len(blocks) == 1 + assert set(seeds.tolist()).issubset({0,1}) if task_info.task_id == "tast_info_n0_1": assert task_info.dataloader == task_n0_1_dataloader + _, seeds, blocks = mini_batch + assert len(seeds) == 2 + assert len(blocks) == 1 + assert set(seeds.tolist()).issubset({0,1,2,3,4}) if task_info.task_id == "tast_info_n0_2": assert task_info.dataloader == task_n0_2_dataloader + _, seeds, blocks = mini_batch + assert len(seeds) == 5 + assert len(blocks) == 1 + assert np.all(seeds.numpy() >= SIZE_DICT['tiny'] // 2) if task_info.task_id == "task_info_n1_0": assert task_info.dataloader == task_n1_0_dataloader + _, seeds, blocks = mini_batch + assert len(seeds) == 10 + assert len(blocks) == 1 + assert np.all(seeds.numpy() < SIZE_DICT['tiny'] // 2) if task_info.task_id == "task_info_edge_0": - assert task_info.dataloader == task_edage_0_dataloader + assert task_info.dataloader == task_edge_0_dataloader + _, pos_graph, _, blocks = mini_batch + assert len(batch_graph) == 1 + assert ("n0", "r1", "n1") in batch_graph + assert ("n0", "r0", "n1") in batch_graph + assert len(blocks) == 1 + eids = pos_graph.edges[("n0", "r1", "n1")].data[dgl.EID] + assert len(eids) == 2 + assert set(eids.tolist()).issubset({0,1}) + eids = pos_graph.edges[("n0", "r0", "n1")].data[dgl.EID] + assert len(eids) == 10 + assert np.any(eids.numpy() < SIZE_DICT['tiny'] // 2) if task_info.task_id == "task_info_edge_1": - assert task_info.dataloader == task_edage_1_dataloader + assert task_info.dataloader == task_edge_1_dataloader + _, 
batch_graph, blocks = mini_batch + assert len(batch_graph) == 1 + assert ("n0", "r1", "n1") in batch_graph + assert batch_graph.num_edges(("n0", "r1", "n1")) == 2 + assert len(blocks) == 1 + eids = batch_graph.edges[("n0", "r1", "n1")].data[dgl.EID] + assert set(eids.tolist()).issubset({0,1,2,3,4}) if task_info.task_id == "task_info_edge_2": - assert task_info.dataloader == task_edage_2_dataloader - + assert task_info.dataloader == task_edge_2_dataloader + _, batch_graph, blocks = mini_batch + assert len(batch_graph) == 1 + assert ("n0", "r1", "n1") in batch_graph + assert batch_graph.num_edges(("n0", "r1", "n1")) == 2 + assert len(blocks) == 1 + eids = batch_graph.edges[("n0", "r1", "n1")].data[dgl.EID] + assert np.any(eids.numpy() >= SIZE_DICT['tiny']) + num_iters += 1 + assert num_iters == 10 if __name__ == '__main__': From 515a64202db92472cf9dd937e84fc04cbbdc310b Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 13 May 2024 14:50:53 -0700 Subject: [PATCH 17/79] update --- python/graphstorm/dataloading/dataloading.py | 8 +- tests/unit-tests/data_utils.py | 4 +- tests/unit-tests/test_dataloading.py | 83 +++++++++++--------- 3 files changed, 56 insertions(+), 39 deletions(-) diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index 28a36a75c7..9c98f72e58 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -1751,8 +1751,13 @@ def __next__(self): try: mini_batch = next(dataloader) except StopIteration: - dataloader.__iter__() + load = dataloader.__iter__() + # we assume dataloader __iter__ will return itself. + assert load is dataloader, \ + "We assume the return value of __iter__() function " \ + "of each task dataloader is itself." 
mini_batch = next(dataloader) + if task_info.dataloader is None: task_info.dataloader = dataloader else: @@ -1786,7 +1791,6 @@ def task_infos(self): return self._task_infos - ####################### Distillation ############################# class DistillDataManager: diff --git a/tests/unit-tests/data_utils.py b/tests/unit-tests/data_utils.py index 6cfcf195da..362eba0a3e 100644 --- a/tests/unit-tests/data_utils.py +++ b/tests/unit-tests/data_utils.py @@ -543,7 +543,7 @@ def generate_dummy_dist_graph_multi_target_ntypes(dirname, size='tiny', graph_na return partion_and_load_distributed_graph(hetero_graph=hetero_graph, dirname=dirname, graph_name=graph_name) -def generate_dummy_dist_graph_multi_task(dirname, size='tiny', graph_name='dummy', gen_mask=True): +def generate_dummy_dist_graph_multi_task(dirname, size='tiny', graph_name='dummy'): """ Generate a dummy DGL distributed graph for multi-task testing with the given size @@ -559,7 +559,7 @@ def generate_dummy_dist_graph_multi_task(dirname, size='tiny', graph_name='dummy dist_graph: a DGL distributed graph part_config : the path of the partition configuration file. 
""" - hetero_graph = generate_dummy_hetero_graph_multi_task(size=size, gen_mask=gen_mask) + hetero_graph = generate_dummy_hetero_graph_multi_task(size=size) return partion_and_load_distributed_graph(hetero_graph=hetero_graph, dirname=dirname, graph_name=graph_name) diff --git a/tests/unit-tests/test_dataloading.py b/tests/unit-tests/test_dataloading.py index 8594fec78b..8100604ddb 100644 --- a/tests/unit-tests/test_dataloading.py +++ b/tests/unit-tests/test_dataloading.py @@ -1102,7 +1102,7 @@ def check_dataloader_trim(mock_trim_data): @pytest.mark.parametrize("dataloader", [GSgnnNodeDataLoader]) @pytest.mark.parametrize("backend", ['gloo', 'nccl']) def test_np_dataloader_trim_data_device(dataloader, backend): - # initialize the torch distributed environment + # initialize the torch distributed environment th.distributed.init_process_group(backend=backend, init_method='tcp://127.0.0.1:23456', rank=0, @@ -2192,8 +2192,8 @@ def test_GSgnnTrainData_homogeneous(): def test_GSgnnMultiTaskDataLoader(): with tempfile.TemporaryDirectory() as tmpdirname: # get the test dummy distributed graph - dist_graph, part_config = generate_dummy_dist_graph_multi_task(graph_name='dummy', - dirname=tmpdirname, add_reverse=False) + _, part_config = generate_dummy_dist_graph_multi_task(graph_name='dummy', + dirname=tmpdirname) gdata = GSgnnData(part_config=part_config) # n0: train_mask: 2, val_mask: 2, test_maks: 4 @@ -2271,7 +2271,9 @@ def test_GSgnnMultiTaskDataLoader(): mask="train_mask1"), fanout=[10], batch_size=2, - train_task=True + label_field="label", + train_task=True, + remove_target_edge_type=False ) task_edge_2_dataloader = GSgnnEdgeDataLoader( @@ -2279,8 +2281,10 @@ def test_GSgnnMultiTaskDataLoader(): gdata.get_edge_train_set(etypes=[("n0", "r1", "n1")], mask="train_mask2"), fanout=[10], - batch_size=2, - train_task=True + batch_size=20, + label_field="label", + train_task=True, + remove_target_edge_type=False ) dataloaders = [task_n0_0_dataloader, task_n0_1_dataloader, @@ 
-2291,17 +2295,17 @@ def test_GSgnnMultiTaskDataLoader(): len_n0_0 = len(task_n0_0_dataloader) assert len_n0_0 == 1 len_n0_1 = len(task_n0_1_dataloader) - assert len_n0_1 == 2 + assert len_n0_1 == 3 len_n0_2 = len(task_n0_2_dataloader) assert len_n0_2 == 10 len_n1_0 = len(task_n1_0_dataloader) assert len_n1_0 == 5 len_edge0_0 = len(task_edge_0_dataloader) - assert len_edge0_0 == 5 + assert len_edge0_0 == 6 len_edge1_0 = len(task_edge_1_dataloader) - assert len_edge1_0 == 2 + assert len_edge1_0 == 3 len_edge2_0 = len(task_edge_2_dataloader) - assert len_edge2_0 == 1 + assert len_edge2_0 == 5 assert len(multi_dataloader) == 10 @@ -2320,70 +2324,79 @@ def test_GSgnnMultiTaskDataLoader(): if task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]: assert isinstance(dataloader, GSgnnLinkPredictionDataLoader) - num_iters = 0 + iter_num = 0 + n0_1_seeds_cnt = th.tensor([4] * 5) # check whether n0_1 datalaoder iterates the whole datasets 3 times. + edge0_seeds_cnt = th.tensor([2] * 50) # check whether edge0 datalaoder iterates the whole datasets 1 and half times. 
for mini_batches in multi_dataloader: assert len(mini_batches) == 7 for task_info, mini_batch in mini_batches: if task_info.task_id == "tast_info_n0_0": assert task_info.dataloader == task_n0_0_dataloader _, seeds, blocks = mini_batch - assert len(seeds) == 2 + assert len(seeds["n0"]) == 2 assert len(blocks) == 1 - assert set(seeds.tolist()).issubset({0,1}) + assert set(seeds["n0"].tolist()).issubset({0,1}) if task_info.task_id == "tast_info_n0_1": assert task_info.dataloader == task_n0_1_dataloader _, seeds, blocks = mini_batch - assert len(seeds) == 2 + assert len(seeds["n0"]) == 1 if (iter_num % 3 == 2) else 2 assert len(blocks) == 1 - assert set(seeds.tolist()).issubset({0,1,2,3,4}) + assert set(seeds["n0"].tolist()).issubset({0,1,2,3,4}) + n0_1_seeds_cnt[seeds["n0"]] = n0_1_seeds_cnt[seeds["n0"]] - 1 if task_info.task_id == "tast_info_n0_2": assert task_info.dataloader == task_n0_2_dataloader _, seeds, blocks = mini_batch - assert len(seeds) == 5 + assert len(seeds["n0"]) == 5 assert len(blocks) == 1 - assert np.all(seeds.numpy() >= SIZE_DICT['tiny'] // 2) + assert np.all(seeds["n0"].numpy() >= SIZE_DICT['tiny'] // 2) if task_info.task_id == "task_info_n1_0": assert task_info.dataloader == task_n1_0_dataloader _, seeds, blocks = mini_batch - assert len(seeds) == 10 + assert len(seeds["n1"]) == 10 assert len(blocks) == 1 - assert np.all(seeds.numpy() < SIZE_DICT['tiny'] // 2) + assert np.all(seeds["n1"].numpy() < SIZE_DICT['tiny'] // 2) if task_info.task_id == "task_info_edge_0": assert task_info.dataloader == task_edge_0_dataloader _, pos_graph, _, blocks = mini_batch - assert len(batch_graph) == 1 - assert ("n0", "r1", "n1") in batch_graph - assert ("n0", "r0", "n1") in batch_graph + assert ("n0", "r1", "n1") in pos_graph.canonical_etypes or \ + ("n0", "r0", "n1") in pos_graph.canonical_etypes assert len(blocks) == 1 - eids = pos_graph.edges[("n0", "r1", "n1")].data[dgl.EID] - assert len(eids) == 2 - assert set(eids.tolist()).issubset({0,1}) + num_edges = 0 + 
if ("n0", "r1", "n1") in pos_graph.canonical_etypes: + eids = pos_graph.edges[("n0", "r1", "n1")].data[dgl.EID] + assert set(eids.tolist()).issubset({0,1}) + num_edges += len(eids) eids = pos_graph.edges[("n0", "r0", "n1")].data[dgl.EID] - assert len(eids) == 10 + num_edges += len(eids) + assert num_edges == 10 if iter_num != 5 else 2 assert np.any(eids.numpy() < SIZE_DICT['tiny'] // 2) + edge0_seeds_cnt[eids] = edge0_seeds_cnt[eids] - 1 if task_info.task_id == "task_info_edge_1": assert task_info.dataloader == task_edge_1_dataloader _, batch_graph, blocks = mini_batch - assert len(batch_graph) == 1 - assert ("n0", "r1", "n1") in batch_graph - assert batch_graph.num_edges(("n0", "r1", "n1")) == 2 + assert len(batch_graph.canonical_etypes) == 1 + assert ("n0", "r1", "n1") in batch_graph.canonical_etypes + assert batch_graph.num_edges(("n0", "r1", "n1")) == 1 if (iter_num % 3 == 2) else 2 assert len(blocks) == 1 eids = batch_graph.edges[("n0", "r1", "n1")].data[dgl.EID] assert set(eids.tolist()).issubset({0,1,2,3,4}) if task_info.task_id == "task_info_edge_2": assert task_info.dataloader == task_edge_2_dataloader _, batch_graph, blocks = mini_batch - assert len(batch_graph) == 1 - assert ("n0", "r1", "n1") in batch_graph - assert batch_graph.num_edges(("n0", "r1", "n1")) == 2 + assert len(batch_graph.canonical_etypes) == 1 + assert ("n0", "r1", "n1") in batch_graph.canonical_etypes + assert batch_graph.num_edges(("n0", "r1", "n1")) == 20 assert len(blocks) == 1 eids = batch_graph.edges[("n0", "r1", "n1")].data[dgl.EID] assert np.any(eids.numpy() >= SIZE_DICT['tiny']) - num_iters += 1 - assert num_iters == 10 + iter_num += 1 + assert iter_num == 9 + assert np.any(n0_1_seeds_cnt.numpy() <= 1) + assert np.any(n0_1_seeds_cnt.numpy() >= 0) + assert np.any(edge0_seeds_cnt.numpy() <= 1) + assert np.any(edge0_seeds_cnt.numpy() >= 0) if __name__ == '__main__': - test_GSgnnMultiTaskDataLoader() test_GSgnnLinkPredictionPredefinedTestDataLoader(1) 
test_GSgnnLinkPredictionPredefinedTestDataLoader(10) From 23c24caf5199691bfd2896dd02f481f4c428b185 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 13 May 2024 16:26:01 -0700 Subject: [PATCH 18/79] Add test --- tests/unit-tests/test_evaluator.py | 59 +++++++++++++++++++++++++++--- 1 file changed, 54 insertions(+), 5 deletions(-) diff --git a/tests/unit-tests/test_evaluator.py b/tests/unit-tests/test_evaluator.py index 50dae032e9..4fb1060b53 100644 --- a/tests/unit-tests/test_evaluator.py +++ b/tests/unit-tests/test_evaluator.py @@ -21,9 +21,11 @@ from numpy.testing import assert_equal, assert_almost_equal import dgl -from graphstorm.eval import GSgnnMrrLPEvaluator, GSgnnPerEtypeMrrLPEvaluator -from graphstorm.eval import GSgnnClassificationEvaluator -from graphstorm.eval import GSgnnRegressionEvaluator +from graphstorm.eval import (GSgnnMrrLPEvaluator, + GSgnnPerEtypeMrrLPEvaluator, + GSgnnClassificationEvaluator, + GSgnnRegressionEvaluator, + GSgnnMultiTaskEvaluator) from graphstorm.eval.evaluator import early_stop_avg_increase_judge from graphstorm.eval.evaluator import early_stop_cons_increase_judge from graphstorm.config.config import EARLY_STOP_AVERAGE_INCREASE_STRATEGY @@ -185,7 +187,7 @@ def test_mrr_lp_evaluator(): test_pos_scores, test_neg_scores = test_scores lp = GSgnnMrrLPEvaluator(config.eval_frequency, use_early_stop=config.use_early_stop) - + # checke default metric list assert lp.metric_list == ['mrr'] @@ -774,7 +776,7 @@ def test_get_val_score_rank(): config.eval_metric, config.multilabel, config.use_early_stop) - + # For accuracy, the bigger the better. 
val_score = {"accuracy": 0.47} assert evaluator.get_val_score_rank(val_score) == 1 @@ -849,7 +851,54 @@ def test_get_val_score_rank(): assert evaluator.get_val_score_rank(val_score) == 3 +def test_multi_task_evaluator_early_stop(): + # common Dummy objects + config = Dummy({ + "multilabel": False, + "eval_frequency": 100, + "eval_metric": ["accuracy"], + "use_early_stop": False, + }) + + task_evaluators = [] + task_evaluators = xxx + + try: + GSgnnMultiTaskEvaluator(config.eval_frequency, + task_evaluators, + use_early_stop=True) + except: + + +def test_multi_task_evaluator(): + # common Dummy objects + config = Dummy({ + "multilabel": False, + "eval_frequency": 100, + "eval_metric": ["accuracy"], + "use_early_stop": False, + }) + + task_evaluators = [] + + failed = False + try: + # there is no evaluators, fail + GSgnnMultiTaskEvaluator(config.eval_frequency, + task_evaluators, + use_early_stop=False) + except: + failed = True + assert failed + + task_evaluators = xxx + mt_evaluator = GSgnnMultiTaskEvaluator(config.eval_frequency, + task_evaluators, + use_early_stop=False) + if __name__ == '__main__': + test_multi_task_evaluator_early_stop() + test_multi_task_evaluator() # test evaluators test_mrr_per_etype_lp_evaluation() test_mrr_lp_evaluator() From 3c86ab393a7ed3a49782e584b6cd9e12139966a8 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 13 May 2024 16:27:54 -0700 Subject: [PATCH 19/79] Fix --- python/graphstorm/dataloading/dataloading.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index 9c98f72e58..018c04d7d4 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -1706,9 +1706,9 @@ def __init__(self, dataset, task_infos, task_dataloaders): # check dataloaders lens = [] for task_info, dataloader in zip(task_infos, task_dataloaders): - assert isinstance(dataloader, 
GSgnnEdgeDataLoaderBase) or \ - isinstance(dataloader, GSgnnLinkPredictionDataLoaderBase) or \ - isinstance(dataloader, GSgnnNodeDataLoaderBase), \ + assert isinstance(dataloader, GSgnnEdgeDataLoaderBase, + GSgnnLinkPredictionDataLoaderBase, + GSgnnNodeDataLoaderBase), \ "The task data loader should be a GSgnnEdgeDataLoaderBase " \ " or a GSgnnLinkPredictionDataLoaderBase or a GSgnnNodeDataLoaderBase" num_iters = len(dataloader) @@ -1728,10 +1728,9 @@ def _reset_loader(self): """ reset the dataloaders """ for dataloader in self._dataloaders: - dataloader.__iter__() + iter(dataloader) self._num_iters = 0 - def __iter__(self): self._reset_loader() return self From e541837e6390803e46fd8e8bbf824438d1813457 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 13 May 2024 22:33:44 -0700 Subject: [PATCH 20/79] update --- python/graphstorm/dataloading/dataloading.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index 018c04d7d4..f917300bdd 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -18,7 +18,6 @@ import math import inspect import logging - import dgl import torch as th from torch.utils.data import DataLoader @@ -1687,6 +1686,7 @@ def __len__(self): return min(self.dataloader.expected_idxs, self.unlabeled_dataloader.expected_idxs) + ####################### Multi-task Dataloader #################### class GSgnnMultiTaskDataLoader: r""" DataLoader designed for multi-task learning @@ -1706,9 +1706,9 @@ def __init__(self, dataset, task_infos, task_dataloaders): # check dataloaders lens = [] for task_info, dataloader in zip(task_infos, task_dataloaders): - assert isinstance(dataloader, GSgnnEdgeDataLoaderBase, - GSgnnLinkPredictionDataLoaderBase, - GSgnnNodeDataLoaderBase), \ + assert isinstance(dataloader, (GSgnnEdgeDataLoaderBase, + GSgnnLinkPredictionDataLoaderBase, + 
GSgnnNodeDataLoaderBase)), \ "The task data loader should be a GSgnnEdgeDataLoaderBase " \ " or a GSgnnLinkPredictionDataLoaderBase or a GSgnnNodeDataLoaderBase" num_iters = len(dataloader) @@ -1731,6 +1731,7 @@ def _reset_loader(self): iter(dataloader) self._num_iters = 0 + def __iter__(self): self._reset_loader() return self @@ -1750,7 +1751,7 @@ def __next__(self): try: mini_batch = next(dataloader) except StopIteration: - load = dataloader.__iter__() + load = iter(dataloader) # we assume dataloader __iter__ will return itself. assert load is dataloader, \ "We assume the return value of __iter__() function " \ From 1dcc4aa20e18fd28486cc7899f5c8d859caa1b5d Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 13 May 2024 23:18:30 -0700 Subject: [PATCH 21/79] add test for evaluator --- tests/unit-tests/test_evaluator.py | 126 +++++++++++++++++++++++++---- 1 file changed, 111 insertions(+), 15 deletions(-) diff --git a/tests/unit-tests/test_evaluator.py b/tests/unit-tests/test_evaluator.py index 4fb1060b53..1667eb99ac 100644 --- a/tests/unit-tests/test_evaluator.py +++ b/tests/unit-tests/test_evaluator.py @@ -856,45 +856,141 @@ def test_multi_task_evaluator_early_stop(): config = Dummy({ "multilabel": False, "eval_frequency": 100, - "eval_metric": ["accuracy"], - "use_early_stop": False, }) + lp = GSgnnPerEtypeMrrLPEvaluator(config.eval_frequency, + use_early_stop=False) + c_eval = GSgnnClassificationEvaluator(config.eval_frequency, + ["accuracy"], + use_early_stop=False) - task_evaluators = [] - task_evaluators = xxx - + task_evaluators = [("lp", lp), ("c_eval", c_eval)] try: GSgnnMultiTaskEvaluator(config.eval_frequency, task_evaluators, use_early_stop=True) + assert False except: + pass def test_multi_task_evaluator(): # common Dummy objects config = Dummy({ - "multilabel": False, "eval_frequency": 100, - "eval_metric": ["accuracy"], - "use_early_stop": False, }) - task_evaluators = [] - failed = False try: # there is no evaluators, fail 
GSgnnMultiTaskEvaluator(config.eval_frequency, - task_evaluators, + [], use_early_stop=False) except: failed = True assert failed - task_evaluators = xxx - mt_evaluator = GSgnnMultiTaskEvaluator(config.eval_frequency, - task_evaluators, - use_early_stop=False) + # Test evaluate without test set + @patch.object(GSgnnMrrLPEvaluator, 'compute_score') + @patch.object(GSgnnClassificationEvaluator, 'compute_score') + @patch.object(GSgnnRegressionEvaluator, 'compute_score') + def check_multi_task_eval(mock_lp_comput_score, mock_class_compute_score, mock_reg_compute_score): + mock_lp_comput_score.side_effect = [ + {"mrr": 0.6}, + {"mrr": 0.7}, + {"mrr": 0.65}, + {"mrr": 0.8}, + {"mrr": 0.8}, + {"mrr": 0.7} + ] + + mock_class_compute_score.side_effect = [ + {"accuracy": 0.7}, + {"accuracy": 0.65}, + {"accuracy": 0.8}, + {"accuracy": 0.7}, + {"accuracy": 0.76}, + {"accuracy": 0.8}, + ] + + mock_reg_compute_score.side_effect = [ + {"rmse": 0.7}, + {"rmse": 0.8}, + {"rmse": 0.2}, + {"rmse": 0.23}, + {"rmse": 0.3}, + {"rmse": 0.31}, + ] + + lp = GSgnnPerEtypeMrrLPEvaluator(config.eval_frequency, + use_early_stop=False) + c_eval = GSgnnClassificationEvaluator(config.eval_frequency, + ["accuracy"], + use_early_stop=False) + r_eval = GSgnnRegressionEvaluator(config.eval_frequency, + use_early_stop=False) + + task_evaluators = [("lp", lp), ("c_eval", c_eval), + ("r_eval", r_eval)] + mt_evaluator = GSgnnMultiTaskEvaluator(config.eval_frequency, + task_evaluators, + use_early_stop=False) + assert len(mt_evaluator.task_evaluators) == 3 + + val_results = { + "lp": th.rand(10,), + "c_eval": th.rand(10,), + "r_eval": th.rand(10,), + } + test_results = { + "lp": th.rand(10,), + "c_eval": th.rand(10,), + "r_eval": th.rand(10,), + } + val_scores, test_scores = mt_evaluator.evaluate(val_results, test_results, 100) + assert len(val_scores) == 3 + assert len(test_scores) == 3 + assert val_scores["lp"] == 0.7 + assert val_scores["c_eval"] == 0.7 + assert val_scores["r_eval"] == 0.7 + assert 
test_scores["lp"] == 0.6 + assert test_scores["c_eval"] == 0.65 + assert test_scores["r_eval"] == 0.8 + + val_scores, test_scores = mt_evaluator.evaluate(val_results, test_results, 200) + assert len(val_scores) == 3 + assert len(test_scores) == 3 + assert val_scores["lp"] == 0.8 + assert val_scores["c_eval"] == 0.8 + assert val_scores["r_eval"] == 0.2 + assert test_scores["lp"] == 0.65 + assert test_scores["c_eval"] == 0.7 + assert test_scores["r_eval"] == 0.23 + + val_scores, test_scores = mt_evaluator.evaluate(val_results, test_results, 300) + assert len(val_scores) == 3 + assert len(test_scores) == 3 + assert val_scores["lp"] == 0.7 + assert val_scores["c_eval"] == 0.76 + assert val_scores["r_eval"] == 0.3 + assert test_scores["lp"] == 0.8 + assert test_scores["c_eval"] == 0.8 + assert test_scores["r_eval"] == 0.31 + + best_val_score = mt_evaluator.best_val_score() + best_test_score = mt_evaluator.best_test_score() + best_iter_num = mt_evaluator.best_iter_num() + assert len(best_val_score) == 3 + assert len(best_test_score) == 3 + assert len(best_iter_num) == 3 + assert best_val_score["lp"] == 0.8 + assert best_val_score["c_eval"] == 0.8 + assert best_val_score["r_eval"] == 0.2 + assert best_test_score["lp"] == 0.65 + assert best_test_score["c_eval"] == 0.7 + assert best_test_score["r_eval"] == 0.23 + assert best_iter_num["lp"] == 200 + assert best_iter_num["c_eval"] == 200 + assert best_iter_num["r_eval"] == 300 if __name__ == '__main__': test_multi_task_evaluator_early_stop() From 1335ce1a0c0306b94d167bb700ecac075dd80976 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 13 May 2024 23:48:06 -0700 Subject: [PATCH 22/79] Update multi-task evaluator --- python/graphstorm/eval/evaluator.py | 4 +- tests/unit-tests/test_evaluator.py | 86 +++++++++++++++-------------- 2 files changed, 46 insertions(+), 44 deletions(-) diff --git a/python/graphstorm/eval/evaluator.py b/python/graphstorm/eval/evaluator.py index d84069c4bc..a593ceaef1 100644 --- 
a/python/graphstorm/eval/evaluator.py +++ b/python/graphstorm/eval/evaluator.py @@ -785,7 +785,6 @@ def evaluate(self, val_rankings, test_rankings, total_iters): if val_rankings is not None: val_score = self.compute_score(val_rankings) - if get_rank() == 0: for metric in self.metric_list: # be careful whether > or < it might change per metric. @@ -1189,7 +1188,7 @@ def evaluate(self, val_results, test_results, total_iters): test_scores = {} if val_results is not None: - for task_id, val_result in val_results.itmes(): + for task_id, val_result in val_results.items(): eval_tasks[task_id] = [val_result] if test_results is not None: @@ -1210,7 +1209,6 @@ def evaluate(self, val_results, test_results, total_iters): f"The evaluator of {task_id} is not defined." task_evaluator = self._task_evaluators[task_id] - if isinstance(task_evaluator, GSgnnPredictionEvalInterface): val_preds, val_labels = eval_task[0] test_preds, test_labels = eval_task[1] diff --git a/tests/unit-tests/test_evaluator.py b/tests/unit-tests/test_evaluator.py index 1667eb99ac..af8235e5ed 100644 --- a/tests/unit-tests/test_evaluator.py +++ b/tests/unit-tests/test_evaluator.py @@ -863,7 +863,8 @@ def test_multi_task_evaluator_early_stop(): ["accuracy"], use_early_stop=False) - task_evaluators = [("lp", lp), ("c_eval", c_eval)] + task_evaluators = {"lp": lp, + "c_eval": c_eval} try: GSgnnMultiTaskEvaluator(config.eval_frequency, task_evaluators, @@ -893,7 +894,7 @@ def test_multi_task_evaluator(): @patch.object(GSgnnMrrLPEvaluator, 'compute_score') @patch.object(GSgnnClassificationEvaluator, 'compute_score') @patch.object(GSgnnRegressionEvaluator, 'compute_score') - def check_multi_task_eval(mock_lp_comput_score, mock_class_compute_score, mock_reg_compute_score): + def check_multi_task_eval(mock_reg_compute_score, mock_class_compute_score, mock_lp_comput_score): mock_lp_comput_score.side_effect = [ {"mrr": 0.6}, {"mrr": 0.7}, @@ -921,16 +922,17 @@ def check_multi_task_eval(mock_lp_comput_score, 
mock_class_compute_score, mock_r {"rmse": 0.31}, ] - lp = GSgnnPerEtypeMrrLPEvaluator(config.eval_frequency, - use_early_stop=False) + lp = GSgnnMrrLPEvaluator(config.eval_frequency, + use_early_stop=False) c_eval = GSgnnClassificationEvaluator(config.eval_frequency, ["accuracy"], use_early_stop=False) r_eval = GSgnnRegressionEvaluator(config.eval_frequency, use_early_stop=False) - task_evaluators = [("lp", lp), ("c_eval", c_eval), - ("r_eval", r_eval)] + task_evaluators = {"lp": lp, + "c_eval": c_eval, + "r_eval": r_eval} mt_evaluator = GSgnnMultiTaskEvaluator(config.eval_frequency, task_evaluators, use_early_stop=False) @@ -938,59 +940,61 @@ def check_multi_task_eval(mock_lp_comput_score, mock_class_compute_score, mock_r val_results = { "lp": th.rand(10,), - "c_eval": th.rand(10,), - "r_eval": th.rand(10,), + "c_eval": (th.rand(10,), th.rand(10,)), + "r_eval": (th.rand(10,), th.rand(10,)) } test_results = { "lp": th.rand(10,), - "c_eval": th.rand(10,), - "r_eval": th.rand(10,), + "c_eval": (th.rand(10,), th.rand(10,)), + "r_eval": (th.rand(10,), th.rand(10,)), } val_scores, test_scores = mt_evaluator.evaluate(val_results, test_results, 100) assert len(val_scores) == 3 assert len(test_scores) == 3 - assert val_scores["lp"] == 0.7 - assert val_scores["c_eval"] == 0.7 - assert val_scores["r_eval"] == 0.7 - assert test_scores["lp"] == 0.6 - assert test_scores["c_eval"] == 0.65 - assert test_scores["r_eval"] == 0.8 + assert val_scores["lp"]["mrr"] == 0.7 + assert val_scores["c_eval"]["accuracy"] == 0.7 + assert val_scores["r_eval"]["rmse"] == 0.7 + assert test_scores["lp"]["mrr"] == 0.6 + assert test_scores["c_eval"]["accuracy"] == 0.65 + assert test_scores["r_eval"]["rmse"] == 0.8 val_scores, test_scores = mt_evaluator.evaluate(val_results, test_results, 200) assert len(val_scores) == 3 assert len(test_scores) == 3 - assert val_scores["lp"] == 0.8 - assert val_scores["c_eval"] == 0.8 - assert val_scores["r_eval"] == 0.2 - assert test_scores["lp"] == 0.65 - assert 
test_scores["c_eval"] == 0.7 - assert test_scores["r_eval"] == 0.23 + assert val_scores["lp"]["mrr"] == 0.8 + assert val_scores["c_eval"]["accuracy"] == 0.8 + assert val_scores["r_eval"]["rmse"] == 0.2 + assert test_scores["lp"]["mrr"] == 0.65 + assert test_scores["c_eval"]["accuracy"] == 0.7 + assert test_scores["r_eval"]["rmse"] == 0.23 val_scores, test_scores = mt_evaluator.evaluate(val_results, test_results, 300) assert len(val_scores) == 3 assert len(test_scores) == 3 - assert val_scores["lp"] == 0.7 - assert val_scores["c_eval"] == 0.76 - assert val_scores["r_eval"] == 0.3 - assert test_scores["lp"] == 0.8 - assert test_scores["c_eval"] == 0.8 - assert test_scores["r_eval"] == 0.31 - - best_val_score = mt_evaluator.best_val_score() - best_test_score = mt_evaluator.best_test_score() - best_iter_num = mt_evaluator.best_iter_num() + assert val_scores["lp"]["mrr"] == 0.7 + assert val_scores["c_eval"]["accuracy"] == 0.76 + assert val_scores["r_eval"]["rmse"] == 0.3 + assert test_scores["lp"]["mrr"] == 0.8 + assert test_scores["c_eval"]["accuracy"] == 0.8 + assert test_scores["r_eval"]["rmse"] == 0.31 + + best_val_score = mt_evaluator.best_val_score + best_test_score = mt_evaluator.best_test_score + best_iter_num = mt_evaluator.best_iter_num assert len(best_val_score) == 3 assert len(best_test_score) == 3 assert len(best_iter_num) == 3 - assert best_val_score["lp"] == 0.8 - assert best_val_score["c_eval"] == 0.8 - assert best_val_score["r_eval"] == 0.2 - assert best_test_score["lp"] == 0.65 - assert best_test_score["c_eval"] == 0.7 - assert best_test_score["r_eval"] == 0.23 - assert best_iter_num["lp"] == 200 - assert best_iter_num["c_eval"] == 200 - assert best_iter_num["r_eval"] == 300 + assert best_val_score["lp"]["mrr"] == 0.8 + assert best_val_score["c_eval"]["accuracy"] == 0.8 + assert best_val_score["r_eval"]["rmse"] == 0.2 + assert best_test_score["lp"]["mrr"] == 0.65 + assert best_test_score["c_eval"]["accuracy"] == 0.7 + assert 
best_test_score["r_eval"]["rmse"] == 0.23 + assert best_iter_num["lp"]["mrr"] == 200 + assert best_iter_num["c_eval"]["accuracy"] == 200 + assert best_iter_num["r_eval"]["rmse"] == 200 + + check_multi_task_eval() if __name__ == '__main__': test_multi_task_evaluator_early_stop() From d4a74a9af7f64dd203968f6600be12746c43c2bf Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 15 May 2024 15:24:57 -0700 Subject: [PATCH 23/79] Add movielens test data --- tests/end2end-tests/create_data.sh | 7 +++++++ .../end2end-tests/data_gen/movielens_multi_task.json | 4 ++-- tests/end2end-tests/data_gen/process_movielens.py | 10 ++++++++-- 3 files changed, 17 insertions(+), 4 deletions(-) rename training_scripts/gsgnn_mt/ml_ncr_lp.json => tests/end2end-tests/data_gen/movielens_multi_task.json (96%) diff --git a/tests/end2end-tests/create_data.sh b/tests/end2end-tests/create_data.sh index d1a6a65eb5..8b081f3d9c 100644 --- a/tests/end2end-tests/create_data.sh +++ b/tests/end2end-tests/create_data.sh @@ -165,6 +165,13 @@ python3 -m graphstorm.gconstruct.construct_graph \ --graph-name movie-lens-100k \ --add-reverse-edges +python3 -m graphstorm.gconstruct.construct_graph \ + --conf-file $GS_HOME/tests/end2end-tests/data_gen/movielens_multi_task.json \ + --num-processes 1 \ + --output-dir movielen_100k_multi_task_train_val_1p_4t \ + --graph-name movie-lens-100k \ + --add-reverse-edges + date echo 'Done' diff --git a/training_scripts/gsgnn_mt/ml_ncr_lp.json b/tests/end2end-tests/data_gen/movielens_multi_task.json similarity index 96% rename from training_scripts/gsgnn_mt/ml_ncr_lp.json rename to tests/end2end-tests/data_gen/movielens_multi_task.json index 85990db309..4d79b8b2b3 100644 --- a/training_scripts/gsgnn_mt/ml_ncr_lp.json +++ b/tests/end2end-tests/data_gen/movielens_multi_task.json @@ -37,7 +37,7 @@ "test_mask_c0"] }, { - "label_col": "label", + "label_col": "label2", "task_type": "classification", "split_pct": [0.7, 0.2, 0.1], "mask_field_names": ["train_mask_c1", @@ -56,7 
+56,7 @@ "files": "/data/ml-100k/edges.parquet", "labels": [ { - "label_col": "rate", + "label_col": "rate_class", "task_type": "classification", "split_pct": [0.1, 0.1, 0.1], "mask_field_names": ["train_mask_field_c", diff --git a/tests/end2end-tests/data_gen/process_movielens.py b/tests/end2end-tests/data_gen/process_movielens.py index 9a0f90438c..66e7fe0e39 100644 --- a/tests/end2end-tests/data_gen/process_movielens.py +++ b/tests/end2end-tests/data_gen/process_movielens.py @@ -83,13 +83,19 @@ def write_data_parquet(data, data_file): user_data = {'id': user['id'], 'feat': feat, 'occupation': user['occupation']} write_data_parquet(user_data, '/data/ml-100k/users.parquet') -movie_data = {'id': ids, 'label': labels, 'title': title} +movie_data = {'id': ids, + 'title': title, + 'label': labels, + 'label2': labels } # label2 for multi-task learning test write_data_parquet(movie_data, '/data/ml-100k/movie.parquet') # process edges edges = pandas.read_csv('/data/ml-100k/u.data', delimiter='\t', header=None) # Set the rate to start from 0 to fit evaluation metrics, e.g., roc_auc or p_r -edge_data = {'src_id': edges[0], 'dst_id': edges[1], 'rate': edges[2]-1} +edge_data = {'src_id': edges[0], + 'dst_id': edges[1], + 'rate': edges[2]-1, + 'rate_class': edges[2]} # rate_class for multi-task learning test write_data_parquet(edge_data, '/data/ml-100k/edges.parquet') # generate data for homogeneous optimization test From 858b75135cf60b6a009c40cd04d622913bef2203 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 15 May 2024 15:34:12 -0700 Subject: [PATCH 24/79] update --- training_scripts/gsgnn_mt/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/training_scripts/gsgnn_mt/README.md b/training_scripts/gsgnn_mt/README.md index 9a9010ed8f..0eab74980c 100644 --- a/training_scripts/gsgnn_mt/README.md +++ b/training_scripts/gsgnn_mt/README.md @@ -6,9 +6,9 @@ This folder presents example yaml files for multi-task learning with Movielens d python3 
$GS_HOME/tests/end2end-tests/data_gen/process_movielens.py python3 -m graphstorm.gconstruct.construct_graph \ - --conf-file $GS_HOME/training_scripts/gsgnn_mt/ml_ncr_lp.json \ + --conf-file $GS_HOME/tests/end2end-tests/data_gen/movielens_multi_task.json \ --num-processes 1 \ - --output-dir movielen_100k_multitask_1p_4t \ + --output-dir movielen_100k_multi_task_train_val_1p_4t \ --graph-name movie-lens-100k \ --add-reverse-edges ``` From 9de160bac03ad709fa426382d16efb9992e6e0ed Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 15 May 2024 15:43:41 -0700 Subject: [PATCH 25/79] Update --- .../{ml_ncr_lp_yaml => ml_nc_ec_er_lp_yaml} | 50 ++++++++++++++----- 1 file changed, 37 insertions(+), 13 deletions(-) rename training_scripts/gsgnn_mt/{ml_ncr_lp_yaml => ml_nc_ec_er_lp_yaml} (57%) diff --git a/training_scripts/gsgnn_mt/ml_ncr_lp_yaml b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp_yaml similarity index 57% rename from training_scripts/gsgnn_mt/ml_ncr_lp_yaml rename to training_scripts/gsgnn_mt/ml_nc_ec_er_lp_yaml index 55b740a6c7..3a2f5e12d4 100644 --- a/training_scripts/gsgnn_mt/ml_ncr_lp_yaml +++ b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp_yaml @@ -9,7 +9,7 @@ gsf: model_encoder_type: rgcn fanout: "4" num_layers: 1 - hidden_size: 128 + hidden_size: 32 use_mini_batch_infer: true input: restore_model_path: null @@ -36,26 +36,49 @@ gsf: num_classes: 19 batch_size: 16 # will overwrite the global batch_size mask_fields: - - "train_mask_field_nc" - - "val_mask_field_nc" - - "test_mask_field_nc" + - "train_mask_c0" # node classification mask 0 + - "val_mask_c0" + - "test_mask_c0" + task_weight: 1.0 + eval_metric: + - "accuracy" + - node_classification: + target_ntype: "movie" + label_field: "label2" + multilabel: false + num_classes: 19 + batch_size: 16 # will overwrite the global batch_size + mask_fields: + - "train_mask_c1" # node classification mask 1 + - "val_mask_c1" + - "test_mask_c1" task_weight: 1.0 eval_metric: - "accuracy" - edge_classification: target_etype: - 
"user,rating,movie" - reverse_edge_types_map: - - "user,rating,rating-rev,movie" - label_field: "rate" + label_field: "rate_class" multilabel: false num_classes: 5 + num_decoder_basis: 2 + remove_target_edge_type: false + batch_size: 16 # will overwrite the global batch_size + mask_fields: + - "train_mask_field_c" # edge classification mask + - "val_mask_field_c" + - "test_mask_field_c" + task_weight: 0.5 # weight of the task + - edge_regression: + target_etype: + - "user,rating,movie" + label_field: "rate" num_decoder_basis: 32 - batch_size: 10 # will overwrite the global batch_size + remove_target_edge_type: false mask_fields: - - "train_mask_field_ec" - - "val_mask_field_ec" - - "test_mask_field_ec" + - "train_mask_field_r" # edge regression mask + - "val_mask_field_r" + - "test_mask_field_r" task_weight: 0.5 # weight of the task - link_prediction: num_negative_edges: 4 @@ -66,8 +89,9 @@ gsf: train_etype: - "user,rating,movie" exclude_training_targets: true - reverse_edge_types_map: [] - batch_size: 10 # will overwrite the global batch_size + reverse_edge_types_map: + - user,rating,rating-rev,movie + batch_size: 8 # will overwrite the global batch_size mask_fields: - "train_mask_field_lp" - null # empty means there is no validation mask From 0c630ce4b50600b7f7d980a790f919c21482a16c Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 15 May 2024 15:45:28 -0700 Subject: [PATCH 26/79] Update --- .../gsgnn_mt/{ml_nc_ec_er_lp_yaml => ml_nc_ec_er_lp.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename training_scripts/gsgnn_mt/{ml_nc_ec_er_lp_yaml => ml_nc_ec_er_lp.yaml} (100%) diff --git a/training_scripts/gsgnn_mt/ml_nc_ec_er_lp_yaml b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml similarity index 100% rename from training_scripts/gsgnn_mt/ml_nc_ec_er_lp_yaml rename to training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml From 771adf7c41e9b9f8cf2af4d329da1e75c1ebaa29 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 15 May 2024 16:10:48 -0700 Subject: [PATCH 
27/79] Add multi-task entry point --- python/graphstorm/run/gs_multi_task.py | 52 +++++++++++++++++++++++++ python/graphstorm/trainer/mt_trainer.py | 2 +- 2 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 python/graphstorm/run/gs_multi_task.py diff --git a/python/graphstorm/run/gs_multi_task.py b/python/graphstorm/run/gs_multi_task.py new file mode 100644 index 0000000000..03cbb02f35 --- /dev/null +++ b/python/graphstorm/run/gs_multi_task.py @@ -0,0 +1,52 @@ +""" + Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Entry point for running multi-task learning. + + Run as: + python3 -m graphstorm.run.gs_multi_task +""" +import os +import logging + +from .launch import get_argument_parser +from .launch import check_input_arguments +from .launch import submit_jobs + +def main(): + """ Main function + """ + parser = get_argument_parser() + args, exec_script_args = parser.parse_known_args() + check_input_arguments(args) + + lib_dir = os.path.abspath(os.path.dirname(__file__)) + if args.inference: + cmd = "gsgnn_mt/gsgnn_infer_mt.py" + else: + cmd = "gsgnn_mt/gsgnn_mt.py" + cmd_path = os.path.join(lib_dir, cmd) + exec_script_args = [cmd_path] + exec_script_args + + if "coo" not in args.graph_format: + args.graph_format = f"{args.graph_format},coo" + logging.debug("Automatically add COO format to graph formats for link prediction. 
" + \ + "New graph_format is %s", args.graph_format) + submit_jobs(args, exec_script_args) + +if __name__ == "__main__": + FMT = "%(asctime)s %(levelname)s %(message)s" + logging.basicConfig(format=FMT, level=logging.INFO) + main() diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index fb8f8608dd..d519872164 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -1,5 +1,5 @@ """ - Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. From 0d9f30d48733185d5e6b07c9a9e139dc9a8b87b5 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 15 May 2024 16:25:59 -0700 Subject: [PATCH 28/79] Update --- ...ulti_task.py => gs_multi_task_learning.py} | 2 +- .../end2end-tests/graphstorm-mt/mgpu_test.sh | 30 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) rename python/graphstorm/run/{gs_multi_task.py => gs_multi_task_learning.py} (95%) create mode 100644 tests/end2end-tests/graphstorm-mt/mgpu_test.sh diff --git a/python/graphstorm/run/gs_multi_task.py b/python/graphstorm/run/gs_multi_task_learning.py similarity index 95% rename from python/graphstorm/run/gs_multi_task.py rename to python/graphstorm/run/gs_multi_task_learning.py index 03cbb02f35..93ed3c6edd 100644 --- a/python/graphstorm/run/gs_multi_task.py +++ b/python/graphstorm/run/gs_multi_task_learning.py @@ -16,7 +16,7 @@ Entry point for running multi-task learning. 
Run as: - python3 -m graphstorm.run.gs_multi_task + python3 -m graphstorm.run.gs_multi_task_learning """ import os import logging diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh new file mode 100644 index 0000000000..dfbf2b18ee --- /dev/null +++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +service ssh restart + +DGL_HOME=/root/dgl +GS_HOME=$(pwd) +NUM_TRAINERS=4 +NUM_INFO_TRAINERS=2 +export PYTHONPATH=$GS_HOME/python/ +cd $GS_HOME/training_scripts/gsgnn_mt +echo "127.0.0.1" > ip_list.txt +cd $GS_HOME/inference_scripts/lp_infer +echo "127.0.0.1" > ip_list.txt + +error_and_exit () { + # check exec status of launch.py + status=$1 + echo $status + + if test $status -ne 0 + then + exit -1 + fi +} + +df /dev/shm -h + + +echo "**************dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, multi-task, save model" +python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True \ No newline at end of file From 6ca1d65b46f2e86118d2a8278e4f2aeb8b888b72 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 15 May 2024 22:21:38 -0700 Subject: [PATCH 29/79] Fix some bugs --- python/graphstorm/config/argument.py | 6 +- python/graphstorm/trainer/__init__.py | 1 + tests/unit-tests/test_config.py | 2 +- training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml | 127 +++++++++--------- 4 files changed, 69 insertions(+), 67 deletions(-) diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 0b380cc814..d86a1b5d05 100644 --- 
a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -150,9 +150,9 @@ def __init__(self, cmd_args): configuration = self.load_yaml_config(cmd_args.yaml_config_file) multi_task_config = None - if 'multi_task_learning' in configuration: - multi_task_config = configuration['multi_task_learning'] - del configuration['multi_task_learning'] + if 'multi_task_learning' in configuration['gsf']: + multi_task_config = configuration['gsf']['multi_task_learning'] + del configuration['gsf']['multi_task_learning'] self.set_attributes(configuration) # Override class attributes using command-line arguments diff --git a/python/graphstorm/trainer/__init__.py b/python/graphstorm/trainer/__init__.py index b5c0c3ad6f..7dfeba4948 100644 --- a/python/graphstorm/trainer/__init__.py +++ b/python/graphstorm/trainer/__init__.py @@ -20,3 +20,4 @@ from .ep_trainer import GSgnnEdgePredictionTrainer from .gsgnn_trainer import GSgnnTrainer from .glem_np_trainer import GLEMNodePredictionTrainer +from .mt_trainer import GSgnnMultiTaskLearningTrainer diff --git a/tests/unit-tests/test_config.py b/tests/unit-tests/test_config.py index aafc0f8a1e..76e7ba90f0 100644 --- a/tests/unit-tests/test_config.py +++ b/tests/unit-tests/test_config.py @@ -1680,7 +1680,7 @@ def create_multi_task_config(tmp_path, file_name): "batch_size": 64, "eval_batch_size": 128, } - yaml_object["multi_task_learning"] = [ + yaml_object['gsf']["multi_task_learning"] = [ { BUILTIN_TASK_NODE_CLASSIFICATION : create_dummy_nc_config() }, diff --git a/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml index 3a2f5e12d4..d676f73449 100644 --- a/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml +++ b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml @@ -5,6 +5,7 @@ gsf: backend: gloo verbose: false save_perf_results_path: null + batch_size: 32 gnn: model_encoder_type: rgcn fanout: "4" @@ -30,70 +31,70 @@ gsf: use_node_embeddings: false multi_task_learning: - 
node_classification: - target_ntype: "movie" - label_field: "label" - multilabel: false - num_classes: 19 - batch_size: 16 # will overwrite the global batch_size - mask_fields: - - "train_mask_c0" # node classification mask 0 - - "val_mask_c0" - - "test_mask_c0" - task_weight: 1.0 - eval_metric: - - "accuracy" + target_ntype: "movie" + label_field: "label" + multilabel: false + num_classes: 19 + batch_size: 16 # will overwrite the global batch_size + mask_fields: + - "train_mask_c0" # node classification mask 0 + - "val_mask_c0" + - "test_mask_c0" + task_weight: 1.0 + eval_metric: + - "accuracy" - node_classification: - target_ntype: "movie" - label_field: "label2" - multilabel: false - num_classes: 19 - batch_size: 16 # will overwrite the global batch_size - mask_fields: - - "train_mask_c1" # node classification mask 1 - - "val_mask_c1" - - "test_mask_c1" - task_weight: 1.0 - eval_metric: - - "accuracy" + target_ntype: "movie" + label_field: "label2" + multilabel: false + num_classes: 19 + batch_size: 16 # will overwrite the global batch_size + mask_fields: + - "train_mask_c1" # node classification mask 1 + - "val_mask_c1" + - "test_mask_c1" + task_weight: 1.0 + eval_metric: + - "accuracy" - edge_classification: - target_etype: - - "user,rating,movie" - label_field: "rate_class" - multilabel: false - num_classes: 5 - num_decoder_basis: 2 - remove_target_edge_type: false - batch_size: 16 # will overwrite the global batch_size - mask_fields: - - "train_mask_field_c" # edge classification mask - - "val_mask_field_c" - - "test_mask_field_c" - task_weight: 0.5 # weight of the task + target_etype: + - "user,rating,movie" + label_field: "rate_class" + multilabel: false + num_classes: 5 + num_decoder_basis: 2 + remove_target_edge_type: false + batch_size: 16 # will overwrite the global batch_size + mask_fields: + - "train_mask_field_c" # edge classification mask + - "val_mask_field_c" + - "test_mask_field_c" + task_weight: 0.5 # weight of the task - edge_regression: - 
target_etype: - - "user,rating,movie" - label_field: "rate" - num_decoder_basis: 32 - remove_target_edge_type: false - mask_fields: - - "train_mask_field_r" # edge regression mask - - "val_mask_field_r" - - "test_mask_field_r" - task_weight: 0.5 # weight of the task + target_etype: + - "user,rating,movie" + label_field: "rate" + num_decoder_basis: 32 + remove_target_edge_type: false + mask_fields: + - "train_mask_field_r" # edge regression mask + - "val_mask_field_r" + - "test_mask_field_r" + task_weight: 0.5 # weight of the task - link_prediction: - num_negative_edges: 4 - num_negative_edges_eval: 100 - train_negative_sampler: joint - eval_etype: - - "user,rating,movie" - train_etype: - - "user,rating,movie" - exclude_training_targets: true - reverse_edge_types_map: - - user,rating,rating-rev,movie - batch_size: 8 # will overwrite the global batch_size - mask_fields: - - "train_mask_field_lp" - - null # empty means there is no validation mask - - null # empty means there is no test mask - task_weight: 1.0 \ No newline at end of file + num_negative_edges: 4 + num_negative_edges_eval: 100 + train_negative_sampler: joint + eval_etype: + - "user,rating,movie" + train_etype: + - "user,rating,movie" + exclude_training_targets: true + reverse_edge_types_map: + - user,rating,rating-rev,movie + batch_size: 8 # will overwrite the global batch_size + mask_fields: + - "train_mask_field_lp" + - null # empty means there is no validation mask + - null # empty means there is no test mask + task_weight: 1.0 \ No newline at end of file From b26566036342fb65d0081e1f95426408034cda4d Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 16 May 2024 00:11:19 -0700 Subject: [PATCH 30/79] Update --- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 167 ++++++++++-------- training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml | 9 + 2 files changed, 103 insertions(+), 73 deletions(-) diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index 
13c7b66472..64973795dc 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -44,147 +44,168 @@ from graphstorm.utils import rt_profiler, sys_tracker, get_device, use_wholegraph from graphstorm.utils import get_lm_ntypes -def create_task_train_dataloader(task, config, train_data): +def create_task_train_dataloader(task, config, task_config, train_data): + """ + """ + fanout = config.fanout + # All tasks share the same input encoder, so the node feats must be same. + node_feats = config.node_feat_name if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: train_idxs = train_data.get_node_train_set(config.target_ntype) + # TODO(xiangsx): Support construct feat return GSgnnNodeDataLoader(train_data, train_idxs, - fanout=config.fanout, - batch_size=config.batch_size, + fanout=fanout, + batch_size=task_config.batch_size, train_task=True, - node_feats=config.node_feat_name, - label_field=config.label_field) + node_feats=node_feats, + label_field=task_config.label_field) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: train_idxs = train_data.get_edge_train_set(config.target_etype) + # TODO(xiangsx): Support construct feat return GSgnnEdgeDataLoader(train_data, train_idxs, - fanout=config.fanout, - batch_size=config.batch_size, - node_feats=config.node_feat_name, - label_field=config.label_field, - decoder_edge_feats=config.decoder_edge_feat, + fanout=fanout, + batch_size=task_config.batch_size, + node_feats=node_feats, + label_field=task_config.label_field, + decoder_edge_feats=task_config.decoder_edge_feat, train_task=True, - reverse_edge_types_map=config.reverse_edge_types_map, - remove_target_edge_type=config.remove_target_edge_type, - exclude_training_targets=config.exclude_training_targets) + reverse_edge_types_map=task_config.reverse_edge_types_map, + remove_target_edge_type=task_config.remove_target_edge_type, + 
exclude_training_targets=task_config.exclude_training_targets) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: train_idxs = train_data.get_edge_train_set(config.train_etype) dataloader_cls = gs.get_lp_train_sampler(config) return dataloader_cls(train_data, train_idxs, - config.fanout, - config.batch_size, - config.num_negative_edges, - node_feats=config.node_feat_name, - pos_graph_edge_feats=config.lp_edge_weight_for_loss, + fanout, + task_config.batch_size, + task_config.num_negative_edges, + node_feats=node_feats, + pos_graph_edge_feats=task_config.lp_edge_weight_for_loss, train_task=True, - reverse_edge_types_map=config.reverse_edge_types_map, - exclude_training_targets=config.exclude_training_targets, - edge_dst_negative_field=config.train_etypes_negative_dstnode, - num_hard_negs=config.num_train_hard_negatives) + reverse_edge_types_map=task_config.reverse_edge_types_map, + exclude_training_targets=task_config.exclude_training_targets, + edge_dst_negative_field=task_config.train_etypes_negative_dstnode, + num_hard_negs=task_config.num_train_hard_negatives) return None -def create_task_val_dataloader(task, config, train_data): - fanout = config.eval_fanout if config.use_mini_batch_infer else [] +def create_task_val_dataloader(task, config, task_config, train_data): + """ + """ + # All tasks share the same input encoder, so the node feats must be same. 
+ node_feats = config.node_feat_name if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: eval_ntype = config.eval_target_ntype \ if config.eval_target_ntype is not None else config.target_ntype val_idxs = train_data.get_node_val_set(eval_ntype) - + fanout = config.eval_fanout if config.use_mini_batch_infer else [] if len(val_idxs) > 0: + # TODO(xiangsx): Support construct feat return GSgnnNodeDataLoader(train_data, val_idxs, fanout=fanout, - batch_size=config.eval_batch_size, + batch_size=task_config.eval_batch_size, train_task=False, - node_feats=config.node_feat_name, - label_field=config.label_field, - construct_feat_ntype=config.construct_feat_ntype, - construct_feat_fanout=config.construct_feat_fanout) + node_feats=node_feats, + label_field=task_config.label_field) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: val_idxs = train_data.get_edge_val_set(config.target_etype) + fanout = config.eval_fanout if config.use_mini_batch_infer else [] if len(val_idxs) > 0: + # TODO(xiangsx): Support construct feat return GSgnnEdgeDataLoader(train_data, val_idxs, fanout=fanout, - batch_size=config.eval_batch_size, - node_feats=config.node_feat_name, - label_field=config.label_field, - decoder_edge_feats=config.decoder_edge_feat, + batch_size=task_config.eval_batch_size, + node_feats=node_feats, + label_field=task_config.label_field, + decoder_edge_feats=task_config.decoder_edge_feat, train_task=False, - reverse_edge_types_map=config.reverse_edge_types_map, - remove_target_edge_type=config.remove_target_edge_type) + reverse_edge_types_map=task_config.reverse_edge_types_map, + remove_target_edge_type=task_config.remove_target_edge_type) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: val_idxs = train_data.get_edge_val_set(config.eval_etype) dataloader_cls = gs.get_lp_eval_sampler(config) if len(val_idxs) > 0: + # TODO(xiangsx): Support construct feat if config.eval_etypes_negative_dstnode is 
not None: return dataloader_cls(train_data, val_idxs, - config.eval_batch_size, - fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode, - fanout=config.eval_fanout, - fixed_test_size=config.fixed_test_size, - node_feats=config.node_feat_name, - pos_graph_edge_feats=config.lp_edge_weight_for_loss) + task_config.eval_batch_size, + fixed_edge_dst_negative_field=task_config.eval_etypes_negative_dstnode, + fanout=task_config.eval_fanout, + fixed_test_size=task_config.fixed_test_size, + node_feats=node_feats, + pos_graph_edge_feats=task_config.lp_edge_weight_for_loss) else: return dataloader_cls(train_data, val_idxs, - config.eval_batch_size, - config.num_negative_edges_eval, config.eval_fanout, - fixed_test_size=config.fixed_test_size, - node_feats=config.node_feat_name, - pos_graph_edge_feats=config.lp_edge_weight_for_loss) + task_config.eval_batch_size, + task_config.num_negative_edges_eval, + fanout=task_config.eval_fanout, + fixed_test_size=task_config.fixed_test_size, + node_feats=node_feats, + pos_graph_edge_feats=task_config.lp_edge_weight_for_loss) return None -def create_task_test_dataloader(task, config, train_data): +def create_task_test_dataloader(task, config, task_config, train_data): + """ + """ + # All tasks share the same input encoder, so the node feats must be same. 
+ node_feats = config.node_feat_name if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: eval_ntype = config.eval_target_ntype \ if config.eval_target_ntype is not None else config.target_ntype test_idxs = train_data.get_node_test_set(eval_ntype) fanout = config.eval_fanout if config.use_mini_batch_infer else [] if len(test_idxs) > 0: + # TODO(xiangsx): Support construct feat return GSgnnNodeDataLoader(train_data, test_idxs, fanout=fanout, - batch_size=config.eval_batch_size, + batch_size=task_config.eval_batch_size, train_task=False, - node_feats=config.node_feat_name, - label_field=config.label_field, - construct_feat_ntype=config.construct_feat_ntype, - construct_feat_fanout=config.construct_feat_fanout) + node_feats=node_feats, + label_field=task_config.label_field) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: test_idxs = train_data.get_edge_test_set(config.target_etype) + fanout = config.eval_fanout if config.use_mini_batch_infer else [] if len(test_idxs) > 0: + # TODO(xiangsx): Support construct feat return GSgnnEdgeDataLoader(train_data, test_idxs, fanout=fanout, - batch_size=config.eval_batch_size, - node_feats=config.node_feat_name, - label_field=config.label_field, - decoder_edge_feats=config.decoder_edge_feat, + batch_size=task_config.eval_batch_size, + node_feats=node_feats, + label_field=task_config.label_field, + decoder_edge_feats=task_config.decoder_edge_feat, train_task=False, - reverse_edge_types_map=config.reverse_edge_types_map, - remove_target_edge_type=config.remove_target_edge_type) + reverse_edge_types_map=task_config.reverse_edge_types_map, + remove_target_edge_type=task_config.remove_target_edge_type) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: test_idxs = train_data.get_edge_test_set(config.eval_etype) dataloader_cls = gs.get_lp_eval_sampler(config) if len(test_idxs) > 0: + # TODO(xiangsx): Support construct feat if 
config.eval_etypes_negative_dstnode is not None: return dataloader_cls(train_data, test_idxs, - config.eval_batch_size, - fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode, - fanout=config.eval_fanout, - fixed_test_size=config.fixed_test_size, - node_feats=config.node_feat_name, - pos_graph_edge_feats=config.lp_edge_weight_for_loss) + task_config.eval_batch_size, + fixed_edge_dst_negative_field=task_config.eval_etypes_negative_dstnode, + fanout=task_config.eval_fanout, + fixed_test_size=task_config.fixed_test_size, + node_feats=node_feats, + pos_graph_edge_feats=task_config.lp_edge_weight_for_loss) else: return dataloader_cls(train_data, test_idxs, - config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout, - fixed_test_size=config.fixed_test_size, - node_feats=config.node_feat_name, - pos_graph_edge_feats=config.lp_edge_weight_for_loss) + task_config.eval_batch_size, + task_config.num_negative_edges_eval, + task_config.eval_fanout, + fixed_test_size=task_config.fixed_test_size, + node_feats=node_feats, + pos_graph_edge_feats=task_config.lp_edge_weight_for_loss) return None def create_task_decoder(task, g, decoder_input_dim, train_task): @@ -267,7 +288,7 @@ def main(config_args): edge_feat_field=config.edge_feat_name, lm_feat_ntypes=get_lm_ntypes(config.node_lm_configs)) model = GSgnnMultiTaskSharedEncoderModel(config.alpha_l2norm) - gs.set_encoder(model, train_data.g, config, train_task=True) + gs.gsf.set_encoder(model, train_data.g, config, train_task=True) tasks = config.multi_tasks train_dataloaders = [] @@ -279,9 +300,9 @@ def main(config_args): else model.node_input_encoder.out_dims for task in tasks: task_config = task.task_config - train_loader = create_task_train_dataloader(task, config, train_data) - val_loader = create_task_val_dataloader(task, config) - test_loader = create_task_test_dataloader(task, config) + train_loader = create_task_train_dataloader(task, config, task_config, train_data) + val_loader = 
create_task_val_dataloader(task, config, task_config, train_data) + test_loader = create_task_test_dataloader(task, config, task_config, train_data) train_dataloaders.append((task, train_loader)) val_dataloaders.append((task, val_loader)) test_dataloaders.append((task, test_loader)) diff --git a/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml index d676f73449..017eb4ab3b 100644 --- a/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml +++ b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml @@ -1,5 +1,14 @@ --- version: 1.0 +lm_model: + node_lm_models: + - + lm_type: bert + model_name: "bert-base-uncased" + gradient_checkpoint: true + node_types: + - movie + - user gsf: basic: backend: gloo From 3b6f8bfd11fd075d83c8f8946b7e604c4963ee51 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 16 May 2024 00:19:13 -0700 Subject: [PATCH 31/79] Update --- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 47 +++++++++++++--------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index 64973795dc..c3e0fbcaaa 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -47,11 +47,12 @@ def create_task_train_dataloader(task, config, task_config, train_data): """ """ + # All tasks share the same GNN model, so the fanout should be the global fanout fanout = config.fanout # All tasks share the same input encoder, so the node feats must be same. 
node_feats = config.node_feat_name if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - train_idxs = train_data.get_node_train_set(config.target_ntype) + train_idxs = train_data.get_node_train_set(task_config.target_ntype) # TODO(xiangsx): Support construct feat return GSgnnNodeDataLoader(train_data, train_idxs, @@ -61,7 +62,7 @@ def create_task_train_dataloader(task, config, task_config, train_data): node_feats=node_feats, label_field=task_config.label_field) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - train_idxs = train_data.get_edge_train_set(config.target_etype) + train_idxs = train_data.get_edge_train_set(task_config.target_etype) # TODO(xiangsx): Support construct feat return GSgnnEdgeDataLoader(train_data, train_idxs, @@ -75,8 +76,8 @@ def create_task_train_dataloader(task, config, task_config, train_data): remove_target_edge_type=task_config.remove_target_edge_type, exclude_training_targets=task_config.exclude_training_targets) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: - train_idxs = train_data.get_edge_train_set(config.train_etype) - dataloader_cls = gs.get_lp_train_sampler(config) + train_idxs = train_data.get_edge_train_set(task_config.train_etype) + dataloader_cls = gs.get_lp_train_sampler(task_config) return dataloader_cls(train_data, train_idxs, fanout, @@ -98,10 +99,12 @@ def create_task_val_dataloader(task, config, task_config, train_data): # All tasks share the same input encoder, so the node feats must be same. 
node_feats = config.node_feat_name if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - eval_ntype = config.eval_target_ntype \ - if config.eval_target_ntype is not None else config.target_ntype + eval_ntype = task_config.eval_target_ntype \ + if task_config.eval_target_ntype is not None \ + else task_config.target_ntype val_idxs = train_data.get_node_val_set(eval_ntype) - fanout = config.eval_fanout if config.use_mini_batch_infer else [] + # All tasks share the same GNN model, so the fanout should be the global fanout + fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] if len(val_idxs) > 0: # TODO(xiangsx): Support construct feat return GSgnnNodeDataLoader(train_data, @@ -112,8 +115,9 @@ def create_task_val_dataloader(task, config, task_config, train_data): node_feats=node_feats, label_field=task_config.label_field) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - val_idxs = train_data.get_edge_val_set(config.target_etype) - fanout = config.eval_fanout if config.use_mini_batch_infer else [] + val_idxs = train_data.get_edge_val_set(task_config.target_etype) + # All tasks share the same GNN model, so the fanout should be the global fanout + fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] if len(val_idxs) > 0: # TODO(xiangsx): Support construct feat return GSgnnEdgeDataLoader(train_data, @@ -127,11 +131,11 @@ def create_task_val_dataloader(task, config, task_config, train_data): reverse_edge_types_map=task_config.reverse_edge_types_map, remove_target_edge_type=task_config.remove_target_edge_type) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: - val_idxs = train_data.get_edge_val_set(config.eval_etype) - dataloader_cls = gs.get_lp_eval_sampler(config) + val_idxs = train_data.get_edge_val_set(task_config.eval_etype) + dataloader_cls = gs.get_lp_eval_sampler(task_config) if len(val_idxs) > 0: # TODO(xiangsx): Support construct feat - if 
config.eval_etypes_negative_dstnode is not None: + if task_config.eval_etypes_negative_dstnode is not None: return dataloader_cls(train_data, val_idxs, task_config.eval_batch_size, fixed_edge_dst_negative_field=task_config.eval_etypes_negative_dstnode, @@ -156,10 +160,12 @@ def create_task_test_dataloader(task, config, task_config, train_data): # All tasks share the same input encoder, so the node feats must be same. node_feats = config.node_feat_name if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - eval_ntype = config.eval_target_ntype \ - if config.eval_target_ntype is not None else config.target_ntype + eval_ntype = task_config.eval_target_ntype \ + if task_config.eval_target_ntype is not None \ + else task_config.target_ntype test_idxs = train_data.get_node_test_set(eval_ntype) - fanout = config.eval_fanout if config.use_mini_batch_infer else [] + # All tasks share the same GNN model, so the fanout should be the global fanout + fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] if len(test_idxs) > 0: # TODO(xiangsx): Support construct feat return GSgnnNodeDataLoader(train_data, @@ -171,8 +177,9 @@ def create_task_test_dataloader(task, config, task_config, train_data): label_field=task_config.label_field) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - test_idxs = train_data.get_edge_test_set(config.target_etype) - fanout = config.eval_fanout if config.use_mini_batch_infer else [] + test_idxs = train_data.get_edge_test_set(task_config.target_etype) + # All tasks share the same GNN model, so the fanout should be the global fanout + fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] if len(test_idxs) > 0: # TODO(xiangsx): Support construct feat return GSgnnEdgeDataLoader(train_data, @@ -186,11 +193,11 @@ def create_task_test_dataloader(task, config, task_config, train_data): reverse_edge_types_map=task_config.reverse_edge_types_map, 
remove_target_edge_type=task_config.remove_target_edge_type) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: - test_idxs = train_data.get_edge_test_set(config.eval_etype) - dataloader_cls = gs.get_lp_eval_sampler(config) + test_idxs = train_data.get_edge_test_set(task_config.eval_etype) + dataloader_cls = gs.get_lp_eval_sampler(task_config) if len(test_idxs) > 0: # TODO(xiangsx): Support construct feat - if config.eval_etypes_negative_dstnode is not None: + if task_config.eval_etypes_negative_dstnode is not None: return dataloader_cls(train_data, test_idxs, task_config.eval_batch_size, fixed_edge_dst_negative_field=task_config.eval_etypes_negative_dstnode, From 3775b945215b03da587084b3c035df176e1ad3a2 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 16 May 2024 10:16:52 -0700 Subject: [PATCH 32/79] Update --- python/graphstorm/config/argument.py | 20 ++++++++--- python/graphstorm/dataloading/dataloading.py | 1 + python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 18 +++++----- tests/unit-tests/test_config.py | 36 ++++++++++---------- 4 files changed, 43 insertions(+), 32 deletions(-) diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index d86a1b5d05..555d8a9380 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -305,7 +305,9 @@ def _parse_node_classification_task(self, task_config): task_id = get_mttask_id(task_type=task_type, ntype=target_ntype, label=label_field) - setattr(task_info, "mask_fields", mask_fields) + setattr(task_info, "train_mask", mask_fields[0]) + setattr(task_info, "val_mask", mask_fields[1]) + setattr(task_info, "test_mask", mask_fields[2]) setattr(task_info, "task_weight", task_weight) return TaskInfo(task_type=task_type, @@ -336,7 +338,9 @@ def _parse_node_regression_task(self, task_config): task_id = get_mttask_id(task_type=task_type, ntype=target_ntype, label=label_field) - setattr(task_info, "mask_fields", mask_fields) + setattr(task_info, "train_mask", 
mask_fields[0]) + setattr(task_info, "val_mask", mask_fields[1]) + setattr(task_info, "test_mask", mask_fields[2]) setattr(task_info, "task_weight", task_weight) return TaskInfo(task_type=task_type, @@ -367,7 +371,9 @@ def _parse_edge_classification_task(self, task_config): task_id = get_mttask_id(task_type=task_type, etype=target_etype, label=label_field) - setattr(task_info, "mask_fields", mask_fields) + setattr(task_info, "train_mask", mask_fields[0]) + setattr(task_info, "val_mask", mask_fields[1]) + setattr(task_info, "test_mask", mask_fields[2]) setattr(task_info, "task_weight", task_weight) return TaskInfo(task_type=task_type, task_id=task_id, @@ -397,7 +403,9 @@ def _parse_edge_regression_task(self, task_config): task_id = get_mttask_id(task_type=task_type, etype=target_etype, label=label_field) - setattr(task_info, "mask_fields", mask_fields) + setattr(task_info, "train_mask", mask_fields[0]) + setattr(task_info, "val_mask", mask_fields[1]) + setattr(task_info, "test_mask", mask_fields[2]) setattr(task_info, "task_weight", task_weight) return TaskInfo(task_type=task_type, task_id=task_id, @@ -425,7 +433,9 @@ def _parse_link_prediction_task(self, task_config): task_id = get_mttask_id( task_type=task_type, etype=train_etype if train_etype is not None else "ALL_ETYPE") - setattr(task_info, "mask_fields", mask_fields) + setattr(task_info, "train_mask", mask_fields[0]) + setattr(task_info, "val_mask", mask_fields[1]) + setattr(task_info, "test_mask", mask_fields[2]) setattr(task_info, "task_weight", task_weight) return TaskInfo(task_type=task_type, task_id=task_id, diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index c9d45b28da..ae78e155d2 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -1595,6 +1595,7 @@ def _prepare_dataloader(self, dataset, target_idx, fanout, batch_size, if len(construct_feat_ntype) > 0: sampler = 
MultiLayerNeighborSamplerForReconstruct(sampler, dataset, construct_feat_ntype, construct_feat_fanout) + print(target_idx) loader = dgl.dataloading.DistNodeDataLoader(g, target_idx, sampler, batch_size=batch_size, shuffle=train_task) diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index c3e0fbcaaa..bfb7a22233 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -52,7 +52,7 @@ def create_task_train_dataloader(task, config, task_config, train_data): # All tasks share the same input encoder, so the node feats must be same. node_feats = config.node_feat_name if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - train_idxs = train_data.get_node_train_set(task_config.target_ntype) + train_idxs = train_data.get_node_train_set(task_config.target_ntype, mask=task_config.train_mask) # TODO(xiangsx): Support construct feat return GSgnnNodeDataLoader(train_data, train_idxs, @@ -62,7 +62,7 @@ def create_task_train_dataloader(task, config, task_config, train_data): node_feats=node_feats, label_field=task_config.label_field) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - train_idxs = train_data.get_edge_train_set(task_config.target_etype) + train_idxs = train_data.get_edge_train_set(task_config.target_etype, mask=task_config.train_mask) # TODO(xiangsx): Support construct feat return GSgnnEdgeDataLoader(train_data, train_idxs, @@ -76,7 +76,7 @@ def create_task_train_dataloader(task, config, task_config, train_data): remove_target_edge_type=task_config.remove_target_edge_type, exclude_training_targets=task_config.exclude_training_targets) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: - train_idxs = train_data.get_edge_train_set(task_config.train_etype) + train_idxs = train_data.get_edge_train_set(task_config.train_etype, mask=task_config.train_mask) dataloader_cls = 
gs.get_lp_train_sampler(task_config) return dataloader_cls(train_data, train_idxs, @@ -102,7 +102,7 @@ def create_task_val_dataloader(task, config, task_config, train_data): eval_ntype = task_config.eval_target_ntype \ if task_config.eval_target_ntype is not None \ else task_config.target_ntype - val_idxs = train_data.get_node_val_set(eval_ntype) + val_idxs = train_data.get_node_val_set(eval_ntype, mask=task_config.val_mask) # All tasks share the same GNN model, so the fanout should be the global fanout fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] if len(val_idxs) > 0: @@ -115,7 +115,7 @@ def create_task_val_dataloader(task, config, task_config, train_data): node_feats=node_feats, label_field=task_config.label_field) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - val_idxs = train_data.get_edge_val_set(task_config.target_etype) + val_idxs = train_data.get_edge_val_set(task_config.target_etype, mask=task_config.val_mask) # All tasks share the same GNN model, so the fanout should be the global fanout fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] if len(val_idxs) > 0: @@ -131,7 +131,7 @@ def create_task_val_dataloader(task, config, task_config, train_data): reverse_edge_types_map=task_config.reverse_edge_types_map, remove_target_edge_type=task_config.remove_target_edge_type) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: - val_idxs = train_data.get_edge_val_set(task_config.eval_etype) + val_idxs = train_data.get_edge_val_set(task_config.eval_etype, mask=task_config.val_mask) dataloader_cls = gs.get_lp_eval_sampler(task_config) if len(val_idxs) > 0: # TODO(xiangsx): Support construct feat @@ -163,7 +163,7 @@ def create_task_test_dataloader(task, config, task_config, train_data): eval_ntype = task_config.eval_target_ntype \ if task_config.eval_target_ntype is not None \ else task_config.target_ntype - test_idxs = train_data.get_node_test_set(eval_ntype) + test_idxs = 
train_data.get_node_test_set(eval_ntype, mask=task_config.val_mask) # All tasks share the same GNN model, so the fanout should be the global fanout fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] if len(test_idxs) > 0: @@ -177,7 +177,7 @@ def create_task_test_dataloader(task, config, task_config, train_data): label_field=task_config.label_field) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - test_idxs = train_data.get_edge_test_set(task_config.target_etype) + test_idxs = train_data.get_edge_test_set(task_config.target_etype, mask=task_config.val_mask) # All tasks share the same GNN model, so the fanout should be the global fanout fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] if len(test_idxs) > 0: @@ -193,7 +193,7 @@ def create_task_test_dataloader(task, config, task_config, train_data): reverse_edge_types_map=task_config.reverse_edge_types_map, remove_target_edge_type=task_config.remove_target_edge_type) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: - test_idxs = train_data.get_edge_test_set(task_config.eval_etype) + test_idxs = train_data.get_edge_test_set(task_config.eval_etype, mask=task_config.val_mask) dataloader_cls = gs.get_lp_eval_sampler(task_config) if len(test_idxs) > 0: # TODO(xiangsx): Support construct feat diff --git a/tests/unit-tests/test_config.py b/tests/unit-tests/test_config.py index 76e7ba90f0..0b354d2440 100644 --- a/tests/unit-tests/test_config.py +++ b/tests/unit-tests/test_config.py @@ -1719,9 +1719,9 @@ def test_multi_task_config(): nc_config = nc_config.task_config assert nc_config.task_weight == 1 assert len(nc_config.mask_fields) == 3 - assert nc_config.mask_fields[0] == "class_train_mask" - assert nc_config.mask_fields[1] == "class_eval_mask" - assert nc_config.mask_fields[2] == "class_test_mask" + assert nc_config.train_mask == "class_train_mask" + assert nc_config.val_mask == "class_eval_mask" + assert nc_config.test_mask == 
"class_test_mask" assert nc_config.target_ntype == "a" assert nc_config.label_field == "label_c" assert nc_config.multilabel == True @@ -1740,9 +1740,9 @@ def test_multi_task_config(): nr_config = nr_config.task_config assert nr_config.task_weight == 0.5 assert len(nr_config.mask_fields) == 3 - assert nr_config.mask_fields[0] == "reg_train_mask" - assert nr_config.mask_fields[1] == "reg_eval_mask" - assert nr_config.mask_fields[2] == "reg_test_mask" + assert nr_config.train_mask == "reg_train_mask" + assert nr_config.val_mask == "reg_eval_mask" + assert nr_config.test_mask == "reg_test_mask" assert nr_config.target_ntype == "a" assert nr_config.label_field == "label_r" assert len(nr_config.eval_metric) == 1 @@ -1755,9 +1755,9 @@ def test_multi_task_config(): ec_config = ec_config.task_config assert ec_config.task_weight == 1 assert len(ec_config.mask_fields) == 3 - assert ec_config.mask_fields[0] == "ec_train_mask" - assert ec_config.mask_fields[1] == "ec_eval_mask" - assert ec_config.mask_fields[2] == "ec_test_mask" + assert ec_config.train_mask == "ec_train_mask" + assert ec_config.val_mask == "ec_eval_mask" + assert ec_config.test_mask == "ec_test_mask" assert ec_config.target_etype[0] == ("query", "match", "asin") assert ec_config.label_field == "label_ec" assert ec_config.multilabel == True @@ -1778,9 +1778,9 @@ def test_multi_task_config(): er_config = er_config.task_config assert er_config.task_weight == 1 assert len(er_config.mask_fields) == 3 - assert er_config.mask_fields[0] == "er_train_mask" - assert er_config.mask_fields[1] == "er_eval_mask" - assert er_config.mask_fields[2] == "er_test_mask" + assert er_config.train_mask == "er_train_mask" + assert er_config.val_mask == "er_eval_mask" + assert er_config.test_mask == "er_test_mask" assert er_config.target_etype[0] == ("query", "match-2", "asin") assert er_config.label_field == "label_er" assert len(er_config.eval_metric) == 1 @@ -1798,9 +1798,9 @@ def test_multi_task_config(): lp_config = 
lp_config.task_config assert lp_config.task_weight == 1 assert len(lp_config.mask_fields) == 3 - assert lp_config.mask_fields[0] == "lp_train_mask" - assert lp_config.mask_fields[1] == "lp_eval_mask" - assert lp_config.mask_fields[2] == "lp_test_mask" + assert lp_config.train_mask == "lp_train_mask" + assert lp_config.val_mask == "lp_eval_mask" + assert lp_config.test_mask == "lp_test_mask" assert lp_config.train_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER assert lp_config.num_negative_edges == 4 assert lp_config.num_negative_edges_eval == 100 @@ -1827,9 +1827,9 @@ def test_multi_task_config(): lp_config = lp_config.task_config assert lp_config.task_weight == 2 assert len(lp_config.mask_fields) == 3 - assert lp_config.mask_fields[0] == "lp2_train_mask" - assert lp_config.mask_fields[1] == "lp2_eval_mask" - assert lp_config.mask_fields[2] == "lp2_test_mask" + assert lp_config.train_mask == "lp2_train_mask" + assert lp_config.val_mask == "lp2_eval_mask" + assert lp_config.test_mask == "lp2_test_mask" assert lp_config.train_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER assert lp_config.num_negative_edges == 16 assert lp_config.train_etype == None From 5c41c8a93c980c28adaad4db806de5fc9c5a516e Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 16 May 2024 11:34:46 -0700 Subject: [PATCH 33/79] Fix some bugs --- python/graphstorm/__init__.py | 4 ++ python/graphstorm/dataloading/dataloading.py | 26 +++++++++-- python/graphstorm/model/__init__.py | 2 + python/graphstorm/model/multitask_gnn.py | 16 ++++--- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 48 ++++++++++++-------- python/graphstorm/trainer/mt_trainer.py | 6 ++- 6 files changed, 68 insertions(+), 34 deletions(-) diff --git a/python/graphstorm/__init__.py b/python/graphstorm/__init__.py index 721e2fb366..fd310ca5ae 100644 --- a/python/graphstorm/__init__.py +++ b/python/graphstorm/__init__.py @@ -28,5 +28,9 @@ from .gsf import create_builtin_lp_model from .gsf import create_builtin_edge_model from .gsf 
import create_builtin_node_model + +from .gsf import (create_builtin_node_decoder, + create_builtin_edge_decoder, + create_builtin_lp_decoder) from .gsf import (get_builtin_lp_train_dataloader_class, get_builtin_lp_eval_dataloader_class) diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index ae78e155d2..efb65566c7 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -1595,7 +1595,6 @@ def _prepare_dataloader(self, dataset, target_idx, fanout, batch_size, if len(construct_feat_ntype) > 0: sampler = MultiLayerNeighborSamplerForReconstruct(sampler, dataset, construct_feat_ntype, construct_feat_fanout) - print(target_idx) loader = dgl.dataloading.DistNodeDataLoader(g, target_idx, sampler, batch_size=batch_size, shuffle=train_task) @@ -1707,12 +1706,15 @@ def __init__(self, dataset, task_infos, task_dataloaders): # check dataloaders lens = [] for task_info, dataloader in zip(task_infos, task_dataloaders): + # For evaluation and testing, we allow some of the val_dataloaders or test_dataloaders + # are empty (None). 
assert isinstance(dataloader, (GSgnnEdgeDataLoaderBase, GSgnnLinkPredictionDataLoaderBase, - GSgnnNodeDataLoaderBase)), \ + GSgnnNodeDataLoaderBase)) or dataloader is None, \ "The task data loader should be an instance of GSgnnEdgeDataLoaderBase, " \ - "GSgnnLinkPredictionDataLoaderBase or GSgnnNodeDataLoaderBase" - num_iters = len(dataloader) + "GSgnnLinkPredictionDataLoaderBase or GSgnnNodeDataLoaderBase" \ + f"But get {type(dataloader)}" + num_iters = len(dataloader) if dataloader is not None else 0 lens.append(num_iters) logging.debug("Task %s has number of iterations of %d", task_info, num_iters) @@ -1729,7 +1731,8 @@ def _reset_loader(self): """ reset the dataloaders """ for dataloader in self._dataloaders: - iter(dataloader) + if dataloader is not None: + iter(dataloader) self._num_iters = 0 def __iter__(self): @@ -1748,6 +1751,19 @@ def __next__(self): # call __next__ of each dataloader mini_batches = [] for task_info, dataloader in zip(self._task_infos, self._dataloaders): + if dataloader is None: + # The dataloader is None + logging.warning("The dataloader of %s is None. " + "Please check whether the coresponding " + "train/val/test mask(s) are missing." 
+ "If you are calling iter(mt_dataloader) for validation " + "or testing, we suggest you to use " + "mt_dataloader.dataloaders to get task specific " + "dataloaders and call the corresponding evaluators " + "task by task", task_info.task_id) + mini_batches.append((task_info, None)) + continue + try: mini_batch = next(dataloader) except StopIteration: diff --git a/python/graphstorm/model/__init__.py b/python/graphstorm/model/__init__.py index 18a741e200..9fd9e9d9d4 100644 --- a/python/graphstorm/model/__init__.py +++ b/python/graphstorm/model/__init__.py @@ -35,6 +35,8 @@ GSgnnLinkPredictionModelBase, GSgnnLinkPredictionModelInterface, run_lp_mini_batch_predict) +from multitask_gnn import (GSgnnMultiTaskModelInterface, + GSgnnMultiTaskSharedEncoderModel) from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer from .sage_encoder import SAGEEncoder, SAGEConv diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index 5568143464..0bef7db272 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -17,16 +17,14 @@ """ import abc import logging -import time -import torch as th -import dgl +from torch import nn from ..config import (BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION, BUILTIN_TASK_LINK_PREDICTION) -from .gnn import GSgnnModel, GSgnnModelBase +from .gnn import GSgnnModel class GSgnnMultiTaskModelInterface: @@ -89,6 +87,7 @@ def __init__(self, alpha_l2norm): super(GSgnnMultiTaskSharedEncoderModel, self).__init__() self._alpha_l2norm = alpha_l2norm self._task_pool = {} + self._task_decoders = nn.ModuleDict() def add_task(self, task_id, task_type, decoder, loss_func, weight): @@ -97,7 +96,8 @@ def add_task(self, task_id, task_type, assert task_id not in self._task_pool, \ f"Task {task_id} already exists" 
logging.info("Setup task %s", task_id) - self._task_pool[task_id] = (task_type, decoder, loss_func, weight) + self._task_pool[task_id] = (task_type, loss_func, weight) + self._task_decoders[task_id] = decoder @property def alpha_l2norm(self): @@ -132,7 +132,8 @@ def forward(self, task_id, mini_batch): # Call emb normalization. encode_embs = self.normalize_node_embs(encode_embs) - task_type, decoder, loss_func, weight = self.task_pool[task_id] + task_type, loss_func, weight = self.task_pool[task_id] + decoder = self._task_decoders[task_id] if task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: labels = decoder_data @@ -195,7 +196,8 @@ def predict(self, task_id, mini_batch, return_proba=False): # Call emb normalization. encode_embs = self.normalize_node_embs(encode_embs) - task_type, decoder, _, _ = self.task_pool[task_id] + task_type, _, _ = self.task_pool[task_id] + decoder = self._task_decoders[task_id] if task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: assert len(encode_embs) == 1, \ diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index bfb7a22233..a222145717 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -51,6 +51,8 @@ def create_task_train_dataloader(task, config, task_config, train_data): fanout = config.fanout # All tasks share the same input encoder, so the node feats must be same. 
node_feats = config.node_feat_name + + logging.info("Create dataloader for %s", task.task_id) if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: train_idxs = train_data.get_node_train_set(task_config.target_ntype, mask=task_config.train_mask) # TODO(xiangsx): Support construct feat @@ -73,11 +75,10 @@ def create_task_train_dataloader(task, config, task_config, train_data): decoder_edge_feats=task_config.decoder_edge_feat, train_task=True, reverse_edge_types_map=task_config.reverse_edge_types_map, - remove_target_edge_type=task_config.remove_target_edge_type, - exclude_training_targets=task_config.exclude_training_targets) + remove_target_edge_type=task_config.remove_target_edge_type) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: train_idxs = train_data.get_edge_train_set(task_config.train_etype, mask=task_config.train_mask) - dataloader_cls = gs.get_lp_train_sampler(task_config) + dataloader_cls = gs.get_builtin_lp_train_dataloader_class(task_config) return dataloader_cls(train_data, train_idxs, fanout, @@ -96,6 +97,9 @@ def create_task_train_dataloader(task, config, task_config, train_data): def create_task_val_dataloader(task, config, task_config, train_data): """ """ + if task_config.val_mask is None: + # There is no validation mask + return None # All tasks share the same input encoder, so the node feats must be same. 
node_feats = config.node_feat_name if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: @@ -132,7 +136,7 @@ def create_task_val_dataloader(task, config, task_config, train_data): remove_target_edge_type=task_config.remove_target_edge_type) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: val_idxs = train_data.get_edge_val_set(task_config.eval_etype, mask=task_config.val_mask) - dataloader_cls = gs.get_lp_eval_sampler(task_config) + dataloader_cls = gs.get_builtin_lp_eval_dataloader_class(task_config) if len(val_idxs) > 0: # TODO(xiangsx): Support construct feat if task_config.eval_etypes_negative_dstnode is not None: @@ -157,13 +161,16 @@ def create_task_val_dataloader(task, config, task_config, train_data): def create_task_test_dataloader(task, config, task_config, train_data): """ """ + if task_config.test_mask is None: + # There is no validation mask + return None # All tasks share the same input encoder, so the node feats must be same. node_feats = config.node_feat_name if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: eval_ntype = task_config.eval_target_ntype \ if task_config.eval_target_ntype is not None \ else task_config.target_ntype - test_idxs = train_data.get_node_test_set(eval_ntype, mask=task_config.val_mask) + test_idxs = train_data.get_node_test_set(eval_ntype, mask=task_config.test_mask) # All tasks share the same GNN model, so the fanout should be the global fanout fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] if len(test_idxs) > 0: @@ -177,7 +184,7 @@ def create_task_test_dataloader(task, config, task_config, train_data): label_field=task_config.label_field) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - test_idxs = train_data.get_edge_test_set(task_config.target_etype, mask=task_config.val_mask) + test_idxs = train_data.get_edge_test_set(task_config.target_etype, mask=task_config.test_mask) # All 
tasks share the same GNN model, so the fanout should be the global fanout fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] if len(test_idxs) > 0: @@ -194,7 +201,7 @@ def create_task_test_dataloader(task, config, task_config, train_data): remove_target_edge_type=task_config.remove_target_edge_type) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: test_idxs = train_data.get_edge_test_set(task_config.eval_etype, mask=task_config.val_mask) - dataloader_cls = gs.get_lp_eval_sampler(task_config) + dataloader_cls = gs.get_builtin_lp_eval_dataloader_class(task_config) if len(test_idxs) > 0: # TODO(xiangsx): Support construct feat if task_config.eval_etypes_negative_dstnode is not None: @@ -217,7 +224,7 @@ def create_task_test_dataloader(task, config, task_config, train_data): def create_task_decoder(task, g, decoder_input_dim, train_task): if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - return gs.create_builtin_node_decoder(decoder_input_dim, task.task_config, train_task) + return gs.create_builtin_node_decoder(g, decoder_input_dim, task.task_config, train_task) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: return gs.create_builtin_edge_decoder(g, decoder_input_dim, task.task_config, train_task) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: @@ -236,7 +243,6 @@ def create_evaluator(task, config): config.early_stop_burnin_rounds, config.early_stop_rounds, config.early_stop_strategy) - elif task.task_type in [BUILTIN_TASK_NODE_REGRESSION]: return GSgnnRegressionEvaluator(config.eval_frequency, config.eval_metric, @@ -310,32 +316,34 @@ def main(config_args): train_loader = create_task_train_dataloader(task, config, task_config, train_data) val_loader = create_task_val_dataloader(task, config, task_config, train_data) test_loader = create_task_test_dataloader(task, config, task_config, train_data) - train_dataloaders.append((task, train_loader)) - 
val_dataloaders.append((task, val_loader)) - test_dataloaders.append((task, test_loader)) + train_dataloaders.append(train_loader) + val_dataloaders.append(val_loader) + test_dataloaders.append(test_loader) decoder, loss_func = create_task_decoder(task, train_data.g, encoder_out_dims, train_task=True) - model.add_task(task.task_id, task.task_type, decoder, loss_func, task_config.weight) + model.add_task(task.task_id, task.task_type, decoder, loss_func, task_config.task_weight) if not config.no_validation: if val_loader is None: logging.warning("The training data do not have validation set.") if test_loader is None: logging.warning("The training data do not have test set.") task_evaluators[task.task_id] = \ - create_evaluator(task, config) - + create_evaluator(task, task_config) - train_dataloader = GSgnnMultiTaskDataLoader(train_dataloaders) - val_dataloader = GSgnnMultiTaskDataLoader(val_dataloaders) - test_dataloader = GSgnnMultiTaskDataLoader(test_dataloaders) + train_dataloader = GSgnnMultiTaskDataLoader(train_data, tasks, train_dataloaders) + val_dataloader = GSgnnMultiTaskDataLoader(train_data, tasks, val_dataloaders) + test_dataloader = GSgnnMultiTaskDataLoader(train_data, tasks, test_dataloaders) + model.init_optimizer(lr=config.lr, + sparse_optimizer_lr=config.sparse_optimizer_lr, + weight_decay=config.wd_l2norm, + lm_lr=config.lm_tune_lr) + trainer = GSgnnMultiTaskLearningTrainer(model, topk_model_to_save=config.topk_model_to_save) if not config.no_validation: evaluator = GSgnnMultiTaskEvaluator(config.eval_frequency, task_evaluators, use_early_stop=config.use_early_stop) trainer.setup_evaluator(evaluator) - trainer = GSgnnMultiTaskLearningTrainer(model, topk_model_to_save=config.topk_model_to_save) - # Preparing input layer for training or inference. # The input layer can pre-compute node features in the preparing step if needed. 
# For example pre-compute all BERT embeddings diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index d519872164..0c14716bca 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -29,7 +29,9 @@ BUILTIN_TASK_EDGE_REGRESSION, BUILTIN_TASK_LINK_PREDICTION) from ..model import (do_full_graph_inference, - do_mini_batch_inference,GSgnnModelBase, GSgnnModel) + do_mini_batch_inference, + GSgnnModelBase, GSgnnModel, + GSgnnMultiTaskModelInterface) from .gsgnn_trainer import GSgnnTrainer from ..model import (run_node_mini_batch_predict, run_edge_mini_batch_predict, @@ -202,7 +204,7 @@ class GSgnnMultiTaskLearningTrainer(GSgnnTrainer): """ def __init__(self, model, topk_model_to_save=1): super(GSgnnMultiTaskLearningTrainer, self).__init__(model, topk_model_to_save) - assert isinstance(model) and isinstance(model, GSgnnModelBase), \ + assert isinstance(model, GSgnnMultiTaskModelInterface) and isinstance(model, GSgnnModelBase), \ "The input model is not a GSgnnModel model. Please implement GSgnnModelBase." 
def _run_mini_batch(self, data, model, task_info, mini_batch, device): From 25f48175df8b1aa1c1919682fb774700afca9be0 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 16 May 2024 14:22:29 -0700 Subject: [PATCH 34/79] Fix bugs --- python/graphstorm/model/__init__.py | 4 ++-- python/graphstorm/model/multitask_gnn.py | 24 +++++++++---------- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 4 ++++ python/graphstorm/trainer/mt_trainer.py | 15 ++++++++---- training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml | 14 ++++------- 5 files changed, 32 insertions(+), 29 deletions(-) diff --git a/python/graphstorm/model/__init__.py b/python/graphstorm/model/__init__.py index 9fd9e9d9d4..5fffded442 100644 --- a/python/graphstorm/model/__init__.py +++ b/python/graphstorm/model/__init__.py @@ -35,8 +35,8 @@ GSgnnLinkPredictionModelBase, GSgnnLinkPredictionModelInterface, run_lp_mini_batch_predict) -from multitask_gnn import (GSgnnMultiTaskModelInterface, - GSgnnMultiTaskSharedEncoderModel) +from .multitask_gnn import (GSgnnMultiTaskModelInterface, + GSgnnMultiTaskSharedEncoderModel) from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer from .sage_encoder import SAGEEncoder, SAGEConv diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index 0bef7db272..d4d9c8dbad 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -87,7 +87,7 @@ def __init__(self, alpha_l2norm): super(GSgnnMultiTaskSharedEncoderModel, self).__init__() self._alpha_l2norm = alpha_l2norm self._task_pool = {} - self._task_decoders = nn.ModuleDict() + self._decoder = nn.ModuleDict() def add_task(self, task_id, task_type, decoder, loss_func, weight): @@ -97,7 +97,7 @@ def add_task(self, task_id, task_type, f"Task {task_id} already exists" logging.info("Setup task %s", task_id) self._task_pool[task_id] = (task_type, loss_func, weight) - 
self._task_decoders[task_id] = decoder + self._decoder[task_id] = decoder @property def alpha_l2norm(self): @@ -133,7 +133,7 @@ def forward(self, task_id, mini_batch): encode_embs = self.normalize_node_embs(encode_embs) task_type, loss_func, weight = self.task_pool[task_id] - decoder = self._task_decoders[task_id] + task_decoder = self.decoder[task_id] if task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: labels = decoder_data @@ -147,7 +147,7 @@ def forward(self, task_id, mini_batch): assert target_ntype in labels, f"Node type {target_ntype} not in labels" emb = encode_embs[target_ntype] ntype_labels = labels[target_ntype] - ntype_logits = decoder(emb) + ntype_logits = task_decoder(emb) pred_loss = loss_func(ntype_logits, ntype_labels) return pred_loss, weight @@ -158,15 +158,15 @@ def forward(self, task_id, mini_batch): "on one edge type for a single edge task." pred_loss = 0 target_etype = list(labels.keys())[0] - logits = decoder(target_edges, encode_embs, target_edge_feats) + logits = task_decoder(target_edges, encode_embs, target_edge_feats) pred_loss = loss_func(logits, labels[target_etype]) return pred_loss, weight elif task_type == BUILTIN_TASK_LINK_PREDICTION: pos_graph, neg_graph, pos_edge_feats, neg_edge_feats = decoder_data - pos_score = decoder(pos_graph, encode_embs, pos_edge_feats) - neg_score = decoder(neg_graph, encode_embs, neg_edge_feats) + pos_score = task_decoder(pos_graph, encode_embs, pos_edge_feats) + neg_score = task_decoder(neg_graph, encode_embs, neg_edge_feats) assert pos_score.keys() == neg_score.keys(), \ "Positive scores and Negative scores must have edges of same" \ f"edge types, but get {pos_score.keys()} and {neg_score.keys()}" @@ -197,7 +197,7 @@ def predict(self, task_id, mini_batch, return_proba=False): encode_embs = self.normalize_node_embs(encode_embs) task_type, _, _ = self.task_pool[task_id] - decoder = self._task_decoders[task_id] + task_decoder = self.decoder[task_id] if task_type in 
[BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: assert len(encode_embs) == 1, \ @@ -206,15 +206,15 @@ def predict(self, task_id, mini_batch, return_proba=False): target_ntype = list(encode_embs.keys())[0] predicts = {} if return_proba: - predicts[target_ntype] = decoder.predict_proba(encode_embs[target_ntype]) + predicts[target_ntype] = task_decoder.predict_proba(encode_embs[target_ntype]) else: - predicts[target_ntype] = decoder.predict(encode_embs[target_ntype]) + predicts[target_ntype] = task_decoder.predict(encode_embs[target_ntype]) return predicts elif task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: target_edges, target_edge_feats, _ = decoder_data if return_proba: - return decoder.predict_proba(target_edges, encode_embs, target_edge_feats) - return decoder.predict(target_edges, encode_embs, target_edge_feats) + return task_decoder.predict_proba(target_edges, encode_embs, target_edge_feats) + return task_decoder.predict(target_edges, encode_embs, target_edge_feats) elif task_type == BUILTIN_TASK_LINK_PREDICTION: logging.warning("Prediction for link prediction is not implemented") return None diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index a222145717..ff0eaebce5 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -343,6 +343,10 @@ def main(config_args): task_evaluators, use_early_stop=config.use_early_stop) trainer.setup_evaluator(evaluator) + if config.restore_model_path is not None: + trainer.restore_model(model_path=config.restore_model_path, + model_layer_to_load=config.restore_model_layers) + trainer.setup_device(device=get_device()) # Preparing input layer for training or inference. # The input layer can pre-compute node features in the preparing step if needed. 
diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index 0c14716bca..2725cf1065 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -65,6 +65,7 @@ def run_edge_predict_mini_batch(model, data, task_info, mini_batch, device): input_feats = data.get_node_feats(input_nodes, nfeat_fields, device) if task_info.dataloader.decoder_edge_feat_fields is not None: + print(task_info.dataloader.decoder_edge_feat_fields) input_edges = {etype: batch_graph.edges[etype].data[dgl.EID] \ for etype in batch_graph.canonical_etypes} edge_decoder_feats = \ @@ -214,10 +215,14 @@ def _run_mini_batch(self, data, model, task_info, mini_batch, device): ---------- data: GSgnnData Graph data + model: GSgnnModel + Model task_info: TaskInfo - task meta information + Task meta information mini_batch: tuple - mini-batch info + Mini-batch info + device: torch.device + Device Return ------ @@ -342,13 +347,13 @@ def fit(self, train_loader, losses = [] for (task_info, mini_batch) in task_mini_batches: - loss, weight = self._run_mini_batch(data, task_info, mini_batch) + loss, weight = self._run_mini_batch(data, model, task_info, mini_batch, device) losses.append((loss, weight)) reg_loss = th.tensor(0.).to(device) - for d_para in model.get_dense_params(): + for d_para in model.module.get_dense_params(): reg_loss += d_para.square().sum() - alpha_l2norm = model.alpha_l2norm + alpha_l2norm = model.module.alpha_l2norm mt_loss = reg_loss * alpha_l2norm mt_loss += loss * weight diff --git a/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml index 017eb4ab3b..45d736a2a3 100644 --- a/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml +++ b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml @@ -1,20 +1,14 @@ --- version: 1.0 -lm_model: - node_lm_models: - - - lm_type: bert - model_name: "bert-base-uncased" - gradient_checkpoint: true - node_types: - - movie - - user gsf: basic: backend: 
gloo verbose: false save_perf_results_path: null batch_size: 32 + node_feat_name: + - user:feat + - movie:title gnn: model_encoder_type: rgcn fanout: "4" @@ -70,7 +64,7 @@ gsf: - "user,rating,movie" label_field: "rate_class" multilabel: false - num_classes: 5 + num_classes: 6 num_decoder_basis: 2 remove_target_edge_type: false batch_size: 16 # will overwrite the global batch_size From 216c6105782e169a689961a8cec531ee642d9479 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Fri, 17 May 2024 13:17:39 -0700 Subject: [PATCH 35/79] Update --- python/graphstorm/dataloading/dataloading.py | 11 ++ python/graphstorm/eval/evaluator.py | 22 ++++ python/graphstorm/model/multitask_gnn.py | 14 ++- python/graphstorm/model/node_gnn.py | 13 ++- python/graphstorm/trainer/mt_trainer.py | 109 ++++++++++++------ training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml | 4 +- 6 files changed, 128 insertions(+), 45 deletions(-) diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index efb65566c7..6a123caa67 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -1806,6 +1806,17 @@ def task_infos(self): # useful for conducting validation scores and test scores. return self._task_infos + @property + def fanout(self): + """ The fanout of each GNN layers of each dataloader + + Returns + ------- + list or a dict of list : the fanouts for each GNN layer. 
+ """ + fanouts = [dataloader.fanout if dataloader is not None else None for dataloader in self.dataloaders] + return fanouts + ####################### Distillation ############################# diff --git a/python/graphstorm/eval/evaluator.py b/python/graphstorm/eval/evaluator.py index e9e545fa1d..3f09c03058 100644 --- a/python/graphstorm/eval/evaluator.py +++ b/python/graphstorm/eval/evaluator.py @@ -1174,6 +1174,28 @@ def best_test_score(self): } return best_test_score + @property + def last_val_score(self): + """ Last validation score + """ + last_val_score = { + task_id: evaluator.last_val_score \ + for task_id, evaluator in self.task_evaluators.items() + } + return last_val_score + + @property + def last_test_score(self): + """ Last test score + """ + last_test_score = { + task_id: evaluator.last_test_score \ + for task_id, evaluator in self.task_evaluators.items() + } + return last_test_score + + + @property def best_iter_num(self): """ Best iteration number diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index d4d9c8dbad..231aeab13b 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -111,6 +111,12 @@ def task_pool(self): """ return self._task_pool + @property + def task_decoders(self): + """ Get task decoders + """ + return self._decoder + # pylint: disable=unused-argument def forward(self, task_id, mini_batch): """ The forward function for multi-task learning @@ -132,7 +138,7 @@ def forward(self, task_id, mini_batch): # Call emb normalization. 
encode_embs = self.normalize_node_embs(encode_embs) - task_type, loss_func, weight = self.task_pool[task_id] + task_type, loss_func, _ = self.task_pool[task_id] task_decoder = self.decoder[task_id] if task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: @@ -150,7 +156,7 @@ def forward(self, task_id, mini_batch): ntype_logits = task_decoder(emb) pred_loss = loss_func(ntype_logits, ntype_labels) - return pred_loss, weight + return pred_loss elif task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: target_edges, target_edge_feats, labels = decoder_data assert len(labels) == 1, \ @@ -161,7 +167,7 @@ def forward(self, task_id, mini_batch): logits = task_decoder(target_edges, encode_embs, target_edge_feats) pred_loss = loss_func(logits, labels[target_etype]) - return pred_loss, weight + return pred_loss elif task_type == BUILTIN_TASK_LINK_PREDICTION: pos_graph, neg_graph, pos_edge_feats, neg_edge_feats = decoder_data @@ -171,7 +177,7 @@ def forward(self, task_id, mini_batch): "Positive scores and Negative scores must have edges of same" \ f"edge types, but get {pos_score.keys()} and {neg_score.keys()}" pred_loss = loss_func(pos_score, neg_score) - return pred_loss, weight + return pred_loss else: raise TypeError("Unknow task type %s", task_type) diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index 5e30f8f7fe..04c63107f5 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -327,6 +327,9 @@ def run_node_mini_batch_predict(decoder, emb, loader, device, return_proba=True, return_label=False): """ Perform mini-batch prediction. + Note: caller should call model.eval() before calling this function + and call model.train() after when doing training. 
+ Parameters ---------- decoder : GSNodeDecoder @@ -360,15 +363,15 @@ def run_node_mini_batch_predict(decoder, emb, loader, device, labels = {} # TODO(zhengda) I need to check if the data loader only returns target nodes. with th.no_grad(): - for input_nodes, seeds, _ in loader: - for ntype, in_nodes in input_nodes.items(): + for _, seeds, _ in loader: + for ntype, seed_nodes in seeds.items(): if isinstance(decoder, th.nn.ModuleDict): assert ntype in decoder, f"Node type {ntype} not in decoder" decoder = decoder[ntype] if return_proba: - pred = decoder.predict_proba(emb[ntype][in_nodes].to(device)) + pred = decoder.predict_proba(emb[ntype][seed_nodes].to(device)) else: - pred = decoder.predict(emb[ntype][in_nodes].to(device)) + pred = decoder.predict(emb[ntype][seed_nodes].to(device)) if ntype in preds: preds[ntype].append(pred.cpu()) else: @@ -379,7 +382,7 @@ def run_node_mini_batch_predict(decoder, emb, loader, device, if ntype in labels: labels[ntype].append(lbl[ntype]) else: - labels[ntype] = lbl[ntype] + labels[ntype] = [lbl[ntype]] for ntype, ntype_pred in preds.items(): preds[ntype] = th.cat(ntype_pred) diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index 2725cf1065..e4e0d94a43 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -65,7 +65,6 @@ def run_edge_predict_mini_batch(model, data, task_info, mini_batch, device): input_feats = data.get_node_feats(input_nodes, nfeat_fields, device) if task_info.dataloader.decoder_edge_feat_fields is not None: - print(task_info.dataloader.decoder_edge_feat_fields) input_edges = {etype: batch_graph.edges[etype].data[dgl.EID] \ for etype in batch_graph.canonical_etypes} edge_decoder_feats = \ @@ -147,39 +146,62 @@ def multi_task_mini_batch_predict( """ dataloaders = loader.dataloaders task_infos = loader.task_infos - task_pool = model.task_pool + task_decoders = model.task_decoders res = {} with th.no_grad(): for dataloader, 
task_info in zip(dataloaders, task_infos): if task_info.task_type in \ [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - task_type, decoder, _, _ = task_pool[task_info.task_id] - assert task_info.task_type == task_type - preds, labels = \ - run_node_mini_batch_predict(decoder, - emb, - dataloader, - device, - return_proba, - return_label) - res[task_info.task_id] = (preds, labels) + if dataloader is None: + # In cases when there is no validation or test set. + # set pred and labels to None + res[task_info.task_id] = (None, None) + else: + decoder = task_decoders[task_info.task_id] + preds, labels = \ + run_node_mini_batch_predict(decoder, + emb, + dataloader, + device, + return_proba, + return_label) + assert len(labels) == 1, \ + "In multi-task learning, for each training task, " \ + "we only support prediction on one node type." \ + "For multiple node types, please treat them as " \ + "different training tasks." + ntype = list(labels.keys())[0] + res[task_info.task_id] = (preds[ntype], labels[ntype]) elif task_info.task_type in \ [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - task_type, decoder, _, _ = task_pool[task_info.task_id] - assert task_info.task_type == task_type - preds, labels = \ - run_edge_mini_batch_predict(decoder, - emb, - loader, - device, - return_proba, - return_label) - res[task_info.task_id] = (preds, labels) + if dataloader is None: + # In cases when there is no validation or test set. + # set pred and labels to None + res[task_info.task_id] = (None, None) + else: + decoder = task_decoders[task_info.task_id] + preds, labels = \ + run_edge_mini_batch_predict(decoder, + emb, + dataloader, + device, + return_proba, + return_label) + assert len(labels) == 1, \ + "In multi-task learning, for each training task, " \ + "we only support prediction on one edge type." \ + "For multiple edge types, please treat them as " \ + "different training tasks." 
+ etype = list(labels.keys())[0] + res[task_info.task_id] = (preds[etype], labels[etype]) elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION: - task_type, decoder, _, _ = task_pool[task_info.task_id] - assert task_info.task_type == task_type - ranking = run_lp_mini_batch_predict(decoder, emb, dataloader, device) - res[task_info.task_id] = ranking + if dataloader is None: + # In cases when there is no validation or test set. + res[task_info.task_id] = None + else: + decoder = task_decoders[task_info.task_id] + ranking = run_lp_mini_batch_predict(decoder, emb, dataloader, device) + res[task_info.task_id] = ranking else: raise TypeError("Unknown task %s", task_info) @@ -309,14 +331,13 @@ def fit(self, train_loader, "Only GSgnnModel supports full-graph inference." # with freeze_input_layer_epochs is 0, computation graph will not be changed. - static_graph = freeze_input_layer_epochs == 0 on_cpu = self.device == th.device('cpu') if is_distributed(): model = DistributedDataParallel(self._model, device_ids=None if on_cpu else [self.device], output_device=None if on_cpu else self.device, find_unused_parameters=True, - static_graph=static_graph) + static_graph=False) else: model = self._model device = model.device @@ -356,7 +377,8 @@ def fit(self, train_loader, alpha_l2norm = model.module.alpha_l2norm mt_loss = reg_loss * alpha_l2norm - mt_loss += loss * weight + for loss, weight in losses: + mt_loss += loss * weight rt_profiler.record('train_forward') self.optimizer.zero_grad() loss.backward() @@ -371,7 +393,7 @@ def fit(self, train_loader, if i % 20 == 0 and get_rank() == 0: rt_profiler.print_stats() logging.info("Epoch %05d | Batch %03d | Train Loss: %.4f | Time: %.4f", - epoch, i, loss.item(), time.time() - batch_tic) + epoch, i, mt_loss.item(), time.time() - batch_tic) val_score = None if self.evaluator is not None and \ @@ -465,22 +487,41 @@ def eval(self, model, data, val_loader, test_loader, total_steps, sys_tracker.check('before prediction') model.eval() + # 
All the tasks share the same GNN encoder so the fanouts are same + # for different tasks. + fanout = None + for task_fanout in val_loader.fanout: + if task_fanout is not None: + fanout = task_fanout + break + assert fanout is not None, \ + "There is no validation dataloader. eval() function should not be called" if use_mini_batch_infer: emb = do_mini_batch_inference(model, data, - fanout=val_loader.fanout, + fanout=fanout, task_tracker=self.task_tracker) else: emb = do_full_graph_inference(model, data, - fanout=val_loader.fanout, + fanout=fanout, task_tracker=self.task_tracker) sys_tracker.check('compute embeddings') val_results = \ - multi_task_mini_batch_predict(model, emb, val_loader, self.device, return_proba) \ + multi_task_mini_batch_predict(model, + emb=emb, + loader=val_loader, + device=self.device, + return_proba=return_proba, + return_label=True) \ if val_loader is not None else None test_results = \ - multi_task_mini_batch_predict(model, emb, test_loader, self.device, return_proba) \ + multi_task_mini_batch_predict(model, + emb=emb, + loader=test_loader, + device=self.device, + return_proba=return_proba, + return_label=True) \ if test_loader is not None else None sys_tracker.check('after_test_score') diff --git a/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml index 45d736a2a3..fcf208c0f8 100644 --- a/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml +++ b/training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml @@ -67,7 +67,7 @@ gsf: num_classes: 6 num_decoder_basis: 2 remove_target_edge_type: false - batch_size: 16 # will overwrite the global batch_size + batch_size: 64 # will overwrite the global batch_size mask_fields: - "train_mask_field_c" # edge classification mask - "val_mask_field_c" @@ -95,7 +95,7 @@ gsf: exclude_training_targets: true reverse_edge_types_map: - user,rating,rating-rev,movie - batch_size: 8 # will overwrite the global batch_size + batch_size: 128 # will overwrite the global batch_size mask_fields: - 
"train_mask_field_lp" - null # empty means there is no validation mask From 017c00f18e5cf916c387423a37e69e0612ed6fd9 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Fri, 17 May 2024 13:25:44 -0700 Subject: [PATCH 36/79] Update --- python/graphstorm/eval/evaluator.py | 22 ---------------------- python/graphstorm/trainer/mt_trainer.py | 2 -- 2 files changed, 24 deletions(-) diff --git a/python/graphstorm/eval/evaluator.py b/python/graphstorm/eval/evaluator.py index 3f09c03058..e9e545fa1d 100644 --- a/python/graphstorm/eval/evaluator.py +++ b/python/graphstorm/eval/evaluator.py @@ -1174,28 +1174,6 @@ def best_test_score(self): } return best_test_score - @property - def last_val_score(self): - """ Last validation score - """ - last_val_score = { - task_id: evaluator.last_val_score \ - for task_id, evaluator in self.task_evaluators.items() - } - return last_val_score - - @property - def last_test_score(self): - """ Last test score - """ - last_test_score = { - task_id: evaluator.last_test_score \ - for task_id, evaluator in self.task_evaluators.items() - } - return last_test_score - - - @property def best_iter_num(self): """ Best iteration number diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index e4e0d94a43..0f09994e49 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -447,8 +447,6 @@ def fit(self, train_loader, # final evaluation output = {'best_test_score': self.evaluator.best_test_score, 'best_val_score':self.evaluator.best_val_score, - 'last_test_score': self.evaluator.last_test_score, - 'last_val_score':self.evaluator.last_val_score, 'peak_GPU_mem_alloc_MB': th.cuda.max_memory_allocated(device) / 1024 / 1024, 'peak_RAM_mem_alloc_MB': \ resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024, From 3b592c476a4ab01bc88321248ac7513c35ddf49b Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 15:59:00 -0700 Subject: [PATCH 37/79] clean up duplicated code 
--- tests/unit-tests/test_evaluator.py | 148 ----------------------------- 1 file changed, 148 deletions(-) diff --git a/tests/unit-tests/test_evaluator.py b/tests/unit-tests/test_evaluator.py index 25d66a33e1..e78bcfacea 100644 --- a/tests/unit-tests/test_evaluator.py +++ b/tests/unit-tests/test_evaluator.py @@ -995,155 +995,7 @@ def check_multi_task_eval(mock_reg_compute_score, mock_class_compute_score, mock check_multi_task_eval() - -def test_multi_task_evaluator_early_stop(): - # common Dummy objects - config = Dummy({ - "multilabel": False, - "eval_frequency": 100, - }) - lp = GSgnnPerEtypeMrrLPEvaluator(config.eval_frequency, - use_early_stop=False) - c_eval = GSgnnClassificationEvaluator(config.eval_frequency, - ["accuracy"], - use_early_stop=False) - - task_evaluators = {"lp": lp, - "c_eval": c_eval} - try: - GSgnnMultiTaskEvaluator(config.eval_frequency, - task_evaluators, - use_early_stop=True) - assert False - except: - pass - - -def test_multi_task_evaluator(): - # common Dummy objects - config = Dummy({ - "eval_frequency": 100, - }) - - failed = False - try: - # there is no evaluators, fail - GSgnnMultiTaskEvaluator(config.eval_frequency, - [], - use_early_stop=False) - except: - failed = True - assert failed - - # Test evaluate without test set - @patch.object(GSgnnMrrLPEvaluator, 'compute_score') - @patch.object(GSgnnClassificationEvaluator, 'compute_score') - @patch.object(GSgnnRegressionEvaluator, 'compute_score') - def check_multi_task_eval(mock_reg_compute_score, mock_class_compute_score, mock_lp_comput_score): - mock_lp_comput_score.side_effect = [ - {"mrr": 0.6}, - {"mrr": 0.7}, - {"mrr": 0.65}, - {"mrr": 0.8}, - {"mrr": 0.8}, - {"mrr": 0.7} - ] - - mock_class_compute_score.side_effect = [ - {"accuracy": 0.7}, - {"accuracy": 0.65}, - {"accuracy": 0.8}, - {"accuracy": 0.7}, - {"accuracy": 0.76}, - {"accuracy": 0.8}, - ] - - mock_reg_compute_score.side_effect = [ - {"rmse": 0.7}, - {"rmse": 0.8}, - {"rmse": 0.2}, - {"rmse": 0.23}, - {"rmse": 
0.3}, - {"rmse": 0.31}, - ] - - lp = GSgnnMrrLPEvaluator(config.eval_frequency, - use_early_stop=False) - c_eval = GSgnnClassificationEvaluator(config.eval_frequency, - ["accuracy"], - use_early_stop=False) - r_eval = GSgnnRegressionEvaluator(config.eval_frequency, - use_early_stop=False) - - task_evaluators = {"lp": lp, - "c_eval": c_eval, - "r_eval": r_eval} - mt_evaluator = GSgnnMultiTaskEvaluator(config.eval_frequency, - task_evaluators, - use_early_stop=False) - assert len(mt_evaluator.task_evaluators) == 3 - - val_results = { - "lp": th.rand(10,), - "c_eval": (th.rand(10,), th.rand(10,)), - "r_eval": (th.rand(10,), th.rand(10,)) - } - test_results = { - "lp": th.rand(10,), - "c_eval": (th.rand(10,), th.rand(10,)), - "r_eval": (th.rand(10,), th.rand(10,)), - } - val_scores, test_scores = mt_evaluator.evaluate(val_results, test_results, 100) - assert len(val_scores) == 3 - assert len(test_scores) == 3 - assert val_scores["lp"]["mrr"] == 0.7 - assert val_scores["c_eval"]["accuracy"] == 0.7 - assert val_scores["r_eval"]["rmse"] == 0.7 - assert test_scores["lp"]["mrr"] == 0.6 - assert test_scores["c_eval"]["accuracy"] == 0.65 - assert test_scores["r_eval"]["rmse"] == 0.8 - - val_scores, test_scores = mt_evaluator.evaluate(val_results, test_results, 200) - assert len(val_scores) == 3 - assert len(test_scores) == 3 - assert val_scores["lp"]["mrr"] == 0.8 - assert val_scores["c_eval"]["accuracy"] == 0.8 - assert val_scores["r_eval"]["rmse"] == 0.2 - assert test_scores["lp"]["mrr"] == 0.65 - assert test_scores["c_eval"]["accuracy"] == 0.7 - assert test_scores["r_eval"]["rmse"] == 0.23 - - val_scores, test_scores = mt_evaluator.evaluate(val_results, test_results, 300) - assert len(val_scores) == 3 - assert len(test_scores) == 3 - assert val_scores["lp"]["mrr"] == 0.7 - assert val_scores["c_eval"]["accuracy"] == 0.76 - assert val_scores["r_eval"]["rmse"] == 0.3 - assert test_scores["lp"]["mrr"] == 0.8 - assert test_scores["c_eval"]["accuracy"] == 0.8 - assert 
test_scores["r_eval"]["rmse"] == 0.31 - - best_val_score = mt_evaluator.best_val_score - best_test_score = mt_evaluator.best_test_score - best_iter_num = mt_evaluator.best_iter_num - assert len(best_val_score) == 3 - assert len(best_test_score) == 3 - assert len(best_iter_num) == 3 - assert best_val_score["lp"]["mrr"] == 0.8 - assert best_val_score["c_eval"]["accuracy"] == 0.8 - assert best_val_score["r_eval"]["rmse"] == 0.2 - assert best_test_score["lp"]["mrr"] == 0.65 - assert best_test_score["c_eval"]["accuracy"] == 0.7 - assert best_test_score["r_eval"]["rmse"] == 0.23 - assert best_iter_num["lp"]["mrr"] == 200 - assert best_iter_num["c_eval"]["accuracy"] == 200 - assert best_iter_num["r_eval"]["rmse"] == 200 - - check_multi_task_eval() - if __name__ == '__main__': - test_multi_task_evaluator_early_stop() - test_multi_task_evaluator() # test evaluators test_multi_task_evaluator_early_stop() test_multi_task_evaluator() From d7e0405d203c7a2130cc3c294cd3744dd3a61e2c Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 16:04:58 -0700 Subject: [PATCH 38/79] Update --- tests/unit-tests/test_evaluator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit-tests/test_evaluator.py b/tests/unit-tests/test_evaluator.py index e78bcfacea..947d2e588f 100644 --- a/tests/unit-tests/test_evaluator.py +++ b/tests/unit-tests/test_evaluator.py @@ -995,6 +995,7 @@ def check_multi_task_eval(mock_reg_compute_score, mock_class_compute_score, mock check_multi_task_eval() + if __name__ == '__main__': # test evaluators test_multi_task_evaluator_early_stop() From 08e3fe6b7db79dcf14da8047fa886a5514d16494 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 16:06:52 -0700 Subject: [PATCH 39/79] update init --- python/graphstorm/model/__init__.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/graphstorm/model/__init__.py b/python/graphstorm/model/__init__.py index 08a8391f56..18a741e200 100644 --- 
a/python/graphstorm/model/__init__.py +++ b/python/graphstorm/model/__init__.py @@ -24,12 +24,17 @@ from .gnn import do_full_graph_inference from .gnn import do_mini_batch_inference from .node_gnn import GSgnnNodeModel, GSgnnNodeModelBase, GSgnnNodeModelInterface -from .node_gnn import node_mini_batch_gnn_predict, node_mini_batch_predict +from .node_gnn import (node_mini_batch_gnn_predict, + node_mini_batch_predict, + run_node_mini_batch_predict) from .edge_gnn import GSgnnEdgeModel, GSgnnEdgeModelBase, GSgnnEdgeModelInterface -from .edge_gnn import edge_mini_batch_gnn_predict, edge_mini_batch_predict +from .edge_gnn import (edge_mini_batch_gnn_predict, + edge_mini_batch_predict, + run_edge_mini_batch_predict) from .lp_gnn import (GSgnnLinkPredictionModel, GSgnnLinkPredictionModelBase, - GSgnnLinkPredictionModelInterface) + GSgnnLinkPredictionModelInterface, + run_lp_mini_batch_predict) from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer from .sage_encoder import SAGEEncoder, SAGEConv From 4945c2c22e31225ce252699014b41841b4453bb2 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 16:08:09 -0700 Subject: [PATCH 40/79] update ep_gnn.py --- python/graphstorm/model/edge_gnn.py | 39 ++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/python/graphstorm/model/edge_gnn.py b/python/graphstorm/model/edge_gnn.py index 536e61f311..0a4b6a38e5 100644 --- a/python/graphstorm/model/edge_gnn.py +++ b/python/graphstorm/model/edge_gnn.py @@ -311,6 +311,44 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label= model.eval() decoder = model.decoder device = model.device + + preds, labels = run_edge_mini_batch_predict(decoder, + loader, + device, + return_proba, + return_label) + model.train() + return preds, labels + +def run_edge_mini_batch_predict(decoder, emb, loader, device, + return_proba=True, return_label=False): + """ 
Perform mini-batch prediction using edge decoder + + This function usually follows full-grain GNN embedding inference. After having + the GNN embeddings, we need to perform mini-batch computation to make predictions + on the GNN embeddings. + + Parameters + ---------- + decoder : GSEdgeDecoder + The GraphStorm edge decoder + emb : dict of Tensor + The GNN embeddings + loader : GSgnnEdgeDataLoader + The GraphStorm dataloader + device: th.device + Device used to compute prediction result + return_proba: bool + Whether to return all the predictions or the maximum prediction + return_label : bool + Whether or not to return labels + + Returns + ------- + dict of Tensor : GNN prediction results. Return all the results when return_proba is true + otherwise return the maximum result. + dict of Tensor : labels if return_labels is True + """ data = loader.data g = data.g preds = {} @@ -379,7 +417,6 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label= append_to_dict(lbl, labels) barrier() - model.train() for target_etype, pred in preds.items(): preds[target_etype] = th.cat(pred) if return_label: From d0b37b4d3891c2ab1e8f2ea3e9c9cbd896d56ce0 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 16:09:06 -0700 Subject: [PATCH 41/79] update lp_gnn.py --- python/graphstorm/model/lp_gnn.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/graphstorm/model/lp_gnn.py b/python/graphstorm/model/lp_gnn.py index 91c2c3317c..ebf8449a43 100644 --- a/python/graphstorm/model/lp_gnn.py +++ b/python/graphstorm/model/lp_gnn.py @@ -154,6 +154,30 @@ def lp_mini_batch_predict(model, emb, loader, device): Rankings of positive scores in format of {etype: ranking} """ decoder = model.decoder + return run_lp_mini_batch_predict(decoder, + emb, + loader, + device) + +def run_lp_mini_batch_predict(decoder, emb, loader, device): + """ Perform mini-batch link prediction. 
+ + Parameters + ---------- + decoder : LinkPredictNoParamDecoder or LinkPredictLearnableDecoder + The GraphStorm link prediction decoder model + emb : dict of Tensor + The GNN embeddings + loader : GSgnnEdgeDataLoader + The GraphStorm dataloader + device: th.device + Device used to compute test scores + + Returns + ------- + rankings: dict of tensors + Rankings of positive scores in format of {etype: ranking} + """ with th.no_grad(): ranking = {} for pos_neg_tuple, neg_sample_type in loader: From 46da6cabea66a834de0d8c1a746a3710608e10fd Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 16:12:42 -0700 Subject: [PATCH 42/79] update --- python/graphstorm/model/edge_gnn.py | 5 +++- python/graphstorm/model/lp_gnn.py | 11 ++++++-- python/graphstorm/model/node_gnn.py | 43 +++++++++++++++++++++++++++-- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/python/graphstorm/model/edge_gnn.py b/python/graphstorm/model/edge_gnn.py index 0a4b6a38e5..7525e2428f 100644 --- a/python/graphstorm/model/edge_gnn.py +++ b/python/graphstorm/model/edge_gnn.py @@ -322,12 +322,15 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label= def run_edge_mini_batch_predict(decoder, emb, loader, device, return_proba=True, return_label=False): - """ Perform mini-batch prediction using edge decoder + """ Perform mini-batch prediction with the given decoder. This function usually follows full-grain GNN embedding inference. After having the GNN embeddings, we need to perform mini-batch computation to make predictions on the GNN embeddings. + Note: caller should call model.eval() before calling this function + and call model.train() after when doing training. 
+ Parameters ---------- decoder : GSEdgeDecoder diff --git a/python/graphstorm/model/lp_gnn.py b/python/graphstorm/model/lp_gnn.py index ebf8449a43..1e08443755 100644 --- a/python/graphstorm/model/lp_gnn.py +++ b/python/graphstorm/model/lp_gnn.py @@ -133,7 +133,7 @@ def forward(self, blocks, pos_graph, def lp_mini_batch_predict(model, emb, loader, device): """ Perform mini-batch prediction. - This function follows full-grain GNN embedding inference. + This function follows full-graph GNN embedding inference. After having the GNN embeddings, we need to perform mini-batch computation to make predictions on the GNN embeddings. @@ -160,7 +160,14 @@ def lp_mini_batch_predict(model, emb, loader, device): device) def run_lp_mini_batch_predict(decoder, emb, loader, device): - """ Perform mini-batch link prediction. + """ Perform mini-batch link prediction with the given decoder. + + This function follows full-graph GNN embedding inference. + After having the GNN embeddings, we need to perform mini-batch + computation to make predictions on the GNN embeddings. + + Note: caller should call model.eval() before calling this function + and call model.train() after when doing training. Parameters ---------- diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index fca05ada24..442582bf4f 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -311,6 +311,47 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= Labels if return_labels is True """ device = model.device + decoder = model.decoder + model.eval() + preds, labels = \ + run_node_mini_batch_predict(decoder, + emb, + loader, + device, + return_proba, + return_label) + model.train() + return preds, labels + +def run_node_mini_batch_predict(decoder, emb, loader, device, + return_proba=True, return_label=False): + """ Perform mini-batch prediction with the given decoder. 
+ + Note: caller should call model.eval() before calling this function + and call model.train() after when doing training. + + Parameters + ---------- + decoder : GSNodeDecoder + The GraphStorm node decoder + emb : dict of Tensor + The GNN embeddings + loader : GSgnnNodeDataLoader + The GraphStorm dataloader + device: th.device + Device used to compute prediction result + return_proba : bool + Whether or not to return all the predictions or the maximum prediction + return_label : bool + Whether or not to return labels. + + Returns + ------- + dict of Tensor : + Prediction results. + dict of Tensor : + Labels if return_labels is True + """ data = loader.data if return_label: @@ -321,7 +362,6 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= preds = {} labels = {} # TODO(zhengda) I need to check if the data loader only returns target nodes. - model.eval() with th.no_grad(): for _, seeds, _ in loader: # seeds are target nodes for ntype, seed_nodes in seeds.items(): @@ -345,7 +385,6 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= labels[ntype].append(lbl[ntype]) else: labels[ntype] = [lbl[ntype]] - model.train() for ntype, ntype_pred in preds.items(): preds[ntype] = th.cat(ntype_pred) From 1dc4dcc8957f602086ae5cd4dffab50ef125c172 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 16:14:10 -0700 Subject: [PATCH 43/79] update --- python/graphstorm/eval/evaluator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/graphstorm/eval/evaluator.py b/python/graphstorm/eval/evaluator.py index e9e545fa1d..3dc30d8efb 100644 --- a/python/graphstorm/eval/evaluator.py +++ b/python/graphstorm/eval/evaluator.py @@ -785,6 +785,7 @@ def evaluate(self, val_rankings, test_rankings, total_iters): if val_rankings is not None: val_score = self.compute_score(val_rankings) + if get_rank() == 0: for metric in self.metric_list: # be careful whether > or < it might change per metric. 
From 9b1942748f194a9064d8e5b617041d8927483d1d Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 16:31:08 -0700 Subject: [PATCH 44/79] update --- python/graphstorm/model/edge_gnn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/graphstorm/model/edge_gnn.py b/python/graphstorm/model/edge_gnn.py index 0a4b6a38e5..eb48c0d4f6 100644 --- a/python/graphstorm/model/edge_gnn.py +++ b/python/graphstorm/model/edge_gnn.py @@ -313,6 +313,7 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label= device = model.device preds, labels = run_edge_mini_batch_predict(decoder, + emb, loader, device, return_proba, From 28ec04cd4bb7f5bec72aaa923d602c43533adc80 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 17:07:02 -0700 Subject: [PATCH 45/79] Add unitests --- python/graphstorm/model/edge_gnn.py | 1 + python/graphstorm/model/node_gnn.py | 9 ++- tests/unit-tests/test_gnn.py | 95 ++++++++++++++++++++++++++--- 3 files changed, 92 insertions(+), 13 deletions(-) diff --git a/python/graphstorm/model/edge_gnn.py b/python/graphstorm/model/edge_gnn.py index 7525e2428f..a2e023ba81 100644 --- a/python/graphstorm/model/edge_gnn.py +++ b/python/graphstorm/model/edge_gnn.py @@ -313,6 +313,7 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label= device = model.device preds, labels = run_edge_mini_batch_predict(decoder, + emb, loader, device, return_proba, diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index 442582bf4f..c432f07f5d 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -365,11 +365,10 @@ def run_node_mini_batch_predict(decoder, emb, loader, device, with th.no_grad(): for _, seeds, _ in loader: # seeds are target nodes for ntype, seed_nodes in seeds.items(): - if isinstance(model.decoder, th.nn.ModuleDict): - assert ntype in model.decoder, f"Node type {ntype} not in decoder" - decoder = model.decoder[ntype] - 
else: - decoder = model.decoder + if isinstance(decoder, th.nn.ModuleDict): + assert ntype in decoder, f"Node type {ntype} not in decoder" + decoder = decoder[ntype] + if return_proba: pred = decoder.predict_proba(emb[ntype][seed_nodes].to(device)) else: diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index 399015a2ae..d2746fd264 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -60,9 +60,13 @@ from graphstorm import get_node_feat_size from graphstorm.gsf import get_rel_names_for_reconstruct from graphstorm.model import do_full_graph_inference, do_mini_batch_inference -from graphstorm.model.node_gnn import node_mini_batch_predict, node_mini_batch_gnn_predict +from graphstorm.model.node_gnn import (node_mini_batch_predict, + run_node_mini_batch_predict, + node_mini_batch_gnn_predict) from graphstorm.model.node_gnn import GSgnnNodeModelInterface -from graphstorm.model.edge_gnn import edge_mini_batch_predict, edge_mini_batch_gnn_predict +from graphstorm.model.edge_gnn import (edge_mini_batch_predict, + run_edge_mini_batch_predict, + edge_mini_batch_gnn_predict) from graphstorm.model.gnn_with_reconstruct import construct_node_feat, get_input_embeds_combined from graphstorm.model.utils import load_model, save_model @@ -279,9 +283,13 @@ def require_cache_embed(self): pred2_gnn_pred, _, labels2_gnn_pred, = node_mini_batch_gnn_predict(model, dataloader2, return_label=True) # Call last layer mini-batch inference with the GNN dataloader pred2_pred, labels2_pred = node_mini_batch_predict(model, embs, dataloader2, return_label=True) + + pred2_d_pred, labels2_d_pred = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_label=True) + if isinstance(pred1,dict): assert len(pred1) == len(pred2_gnn_pred) and len(labels1) == len(labels2_gnn_pred) assert len(pred1) == len(pred2_pred) and len(labels1) == len(labels2_pred) + assert len(pred1) == len(pred2_d_pred) and len(labels1) == 
len(labels2_gnn_pred) for ntype in pred1: assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), pred2_gnn_pred[ntype][0:len(pred2_gnn_pred)].numpy(), decimal=5) @@ -289,6 +297,9 @@ def require_cache_embed(self): assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), pred2_pred[ntype][0:len(pred2_pred)].numpy(), decimal=5) assert_equal(labels1[ntype].numpy(), labels2_pred[ntype].numpy()) + assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), + pred2_d_pred[ntype][0:len(pred2_d_pred)].numpy()) + assert_equal(labels1[ntype].numpy(), labels2_d_pred[ntype].numpy()) else: assert_almost_equal(pred1[0:len(pred1)].numpy(), pred2_gnn_pred[0:len(pred2_gnn_pred)].numpy(), decimal=5) @@ -296,24 +307,42 @@ def require_cache_embed(self): assert_almost_equal(pred1[0:len(pred1)].numpy(), pred2_pred[0:len(pred2_pred)].numpy(), decimal=5) assert_equal(labels1.numpy(), labels2_pred.numpy()) + assert_almost_equal(pred1[0:len(pred1)].numpy(), + labels2_d_pred[0:len(labels2_d_pred)].numpy()) + assert_equal(labels1.numpy(), labels2_d_pred.numpy()) # Test the return_proba argument. 
pred3, labels3 = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) + pred3_d, labels3_d = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) + pred4, labels4 = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) + pred4_d, labels4_d = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) if isinstance(pred3, dict): assert len(pred3) == len(pred4) and len(labels3) == len(labels4) + assert len(pred3) == len(pred3_d) and len(labels3) == len(labels3_d) + assert len(pred4) == len(pred4_d) and len(labels4) == len(labels4_d) for key in pred3: assert pred3[key].dim() == 2 # returns all predictions (2D tensor) when return_proba is true assert(th.is_floating_point(pred3[key])) + assert pred3_d[key].dim() == 2 + assert(th.is_floating_point(pred3_d[key])) assert(pred4[key].dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False assert(is_int(pred4[key])) assert(th.equal(pred3[key].argmax(dim=1), pred4[key])) + assert(pred4_d[key].dim() == 1) + assert(is_int(pred4_d[key])) + assert(th.equal(pred3[key].argmax(dim=1), pred4_d[key])) else: assert pred3.dim() == 2 # returns all predictions (2D tensor) when return_proba is true assert(th.is_floating_point(pred3)) + assert pred3_d.dim() == 2 + assert(th.is_floating_point(pred3_d)) assert(pred4.dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False assert(is_int(pred4)) assert(th.equal(pred3.argmax(dim=1), pred4)) + assert(labels4_d.dim() == 1) + assert(is_int(labels4_d)) + assert(th.equal(pred3.argmax(dim=1), labels4_d)) def check_node_prediction_with_reconstruct(model, data, construct_feat_ntype, train_ntypes, node_feat_field=None): """ Check whether full graph inference and mini batch inference generate the same @@ -416,32 +445,51 @@ def check_mlp_node_prediction(model, data): batch_size=10, 
label_field='label', node_feats='feat', train_task=False) pred2, _, labels2 = node_mini_batch_gnn_predict(model, dataloader2, return_label=True) + pred1_d, labels1_d = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_label=True) if isinstance(pred1, dict): assert len(pred1) == len(pred2) and len(labels1) == len(labels2) + assert len(pred1) == len(pred1_d) and len(labels1) == len(labels1_d) for ntype in pred1: assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), pred2[ntype][0:len(pred2)].numpy(), decimal=5) assert_equal(labels1[ntype].numpy(), labels2[ntype].numpy()) + assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), pred1_d[ntype][0:len(pred1_d)].numpy()) + assert_equal(labels1[ntype].numpy(), labels1_d[ntype].numpy()) else: assert_almost_equal(pred1[0:len(pred1)].numpy(), pred2[0:len(pred2)].numpy(), decimal=5) assert_equal(labels1.numpy(), labels2.numpy()) + assert_almost_equal(pred1[0:len(pred1)].numpy(), pred1_d[0:len(pred1_d)].numpy()) + assert_equal(labels1.numpy(), labels1_d.numpy()) # Test the return_proba argument. 
- pred3, labels3 = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) - pred4, labels4 = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) + pred3, _ = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) + pred4, _ = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) + pred3_d, _ = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) + pred4_d, _ = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) if isinstance(pred3, dict): assert len(pred3) == len(pred4) + assert len(pred3) == len(pred3_d) + assert len(pred4) == len(pred4_d) for ntype in pred3: assert pred3[ntype].dim() == 2 # returns all predictions (2D tensor) when return_proba is true assert(th.is_floating_point(pred3[ntype])) + assert pred3_d[ntype].dim() == 2 + assert(th.is_floating_point(pred3_d[ntype])) assert(pred4[ntype].dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False assert(is_int(pred4[ntype])) assert(th.equal(pred3[ntype].argmax(dim=1), pred4[ntype])) + assert(is_int(pred4_d[ntype])) + assert(th.equal(pred3[ntype].argmax(dim=1), pred4_d[ntype])) else: assert pred3.dim() == 2 # returns all predictions (2D tensor) when return_proba is true assert(th.is_floating_point(pred3)) + assert pred3_d.dim() == 2 + assert(th.is_floating_point(pred3_d)) assert(pred4.dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False assert(is_int(pred4)) assert(th.equal(pred3.argmax(dim=1), pred4)) + assert(pred4_d.dim() == 1) + assert(is_int(pred4_d)) + assert(th.equal(pred3.argmax(dim=1), pred4_d)) @pytest.mark.parametrize("norm", [None, 'batch', 'layer']) def test_rgcn_node_prediction(norm): @@ -752,15 +800,31 @@ def check_edge_prediction(model, data): pred2[("n0", "r1", "n1")][0:len(pred2[("n0", "r1", 
"n1")])].numpy(), decimal=5) assert_equal(labels1[("n0", "r1", "n1")].numpy(), labels2[("n0", "r1", "n1")].numpy()) + pred1_d, labels1_d = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_label=True) + assert_almost_equal(pred1[("n0", "r1", "n1")][0:len(pred1[("n0", "r1", "n1")])].numpy(), + pred1_d[("n0", "r1", "n1")][0:len(pred1_d[("n0", "r1", "n1")])].numpy()) + assert_equal(labels1[("n0", "r1", "n1")].numpy(), labels1_d[("n0", "r1", "n1")].numpy()) + + # Test the return_proba argument. - pred3, labels3 = edge_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) + pred3, _ = edge_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) assert(th.is_floating_point(pred3[("n0", "r1", "n1")])) assert pred3[("n0", "r1", "n1")].dim() == 2 # returns all predictions (2D tensor) when return_proba is true - pred4, labels4 = edge_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) + + pred3_d, _ = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) + assert(th.is_floating_point(pred3_d[("n0", "r1", "n1")])) + assert pred3_d[("n0", "r1", "n1")].dim() == 2 # returns all predictions (2D tensor) when return_proba is true + + pred4, _ = edge_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) assert(pred4[("n0", "r1", "n1")].dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False assert(is_int(pred4[("n0", "r1", "n1")])) assert(th.equal(pred3[("n0", "r1", "n1")].argmax(dim=1), pred4[("n0", "r1", "n1")])) + pred4_d, _ = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) + assert(pred4[("n0", "r1", "n1")].dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False + assert(is_int(pred4_d[("n0", "r1", "n1")])) + assert(th.equal(pred3[("n0", "r1", 
"n1")].argmax(dim=1), pred4_d[("n0", "r1", "n1")])) + def check_mlp_edge_prediction(model, data): """ Check whether full graph inference and mini batch inference generate the same prediction result for GSgnnEdgeModel without GNN layers. @@ -793,15 +857,30 @@ def check_mlp_edge_prediction(model, data): pred2[("n0", "r1", "n1")][0:len(pred2[("n0", "r1", "n1")])].numpy(), decimal=5) assert_equal(labels1[("n0", "r1", "n1")].numpy(), labels2[("n0", "r1", "n1")].numpy()) + pred1_d, labels1_d = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_label=True) + assert_almost_equal(pred1[("n0", "r1", "n1")][0:len(pred1[("n0", "r1", "n1")])].numpy(), + pred1_d[("n0", "r1", "n1")][0:len(pred1_d[("n0", "r1", "n1")])].numpy()) + assert_equal(labels1[("n0", "r1", "n1")].numpy(), labels1_d[("n0", "r1", "n1")].numpy()) + # Test the return_proba argument. - pred3, labels3 = edge_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) + pred3, _ = edge_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) assert pred3[("n0", "r1", "n1")].dim() == 2 # returns all predictions (2D tensor) when return_proba is true assert(th.is_floating_point(pred3[("n0", "r1", "n1")])) - pred4, labels4 = edge_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) + + pred3_d, _ = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) + assert(th.is_floating_point(pred3_d[("n0", "r1", "n1")])) + assert pred3_d[("n0", "r1", "n1")].dim() == 2 # returns all predictions (2D tensor) when return_proba is true + + pred4, _ = edge_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) assert(pred4[("n0", "r1", "n1")].dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False assert(is_int(pred4[("n0", "r1", "n1")])) assert(th.equal(pred3[("n0", "r1", "n1")].argmax(dim=1), pred4[("n0", "r1", 
"n1")])) + pred4_d, _ = run_edge_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) + assert(pred4[("n0", "r1", "n1")].dim() == 1) # returns maximum prediction (1D tensor) when return_proba is False + assert(is_int(pred4_d[("n0", "r1", "n1")])) + assert(th.equal(pred3[("n0", "r1", "n1")].argmax(dim=1), pred4_d[("n0", "r1", "n1")])) + @pytest.mark.parametrize("num_ffn_layers", [0, 2]) def test_rgcn_edge_prediction(num_ffn_layers): """ Test edge prediction logic correctness with a edge prediction model From 2675b6dfa190f6a330f9888217ba962518c59cc4 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 17:12:37 -0700 Subject: [PATCH 46/79] Update docstr --- python/graphstorm/model/node_gnn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index c432f07f5d..526be5b4b6 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -332,8 +332,9 @@ def run_node_mini_batch_predict(decoder, emb, loader, device, Parameters ---------- - decoder : GSNodeDecoder - The GraphStorm node decoder + decoder : GSNodeDecoder or th.nn.ModuleDict + The GraphStorm node decoder. 
+ It can be a GSNodeDecoder or a dict of GSNodeDecoders emb : dict of Tensor The GNN embeddings loader : GSgnnNodeDataLoader From 775d4b28586e3a45c0d264eb3f0c11c56bd311e8 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 17:13:58 -0700 Subject: [PATCH 47/79] Update --- python/graphstorm/model/node_gnn.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index 3e717b3304..526be5b4b6 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -325,15 +325,16 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= def run_node_mini_batch_predict(decoder, emb, loader, device, return_proba=True, return_label=False): - """ Perform mini-batch prediction. + """ Perform mini-batch prediction with the given decoder. Note: caller should call model.eval() before calling this function and call model.train() after when doing training. Parameters ---------- - decoder : GSNodeDecoder - The GraphStorm node decoder + decoder : GSNodeDecoder or th.nn.ModuleDict + The GraphStorm node decoder. 
+ It can be a GSNodeDecoder or a dict of GSNodeDecoders emb : dict of Tensor The GNN embeddings loader : GSgnnNodeDataLoader @@ -365,11 +366,10 @@ def run_node_mini_batch_predict(decoder, emb, loader, device, with th.no_grad(): for _, seeds, _ in loader: # seeds are target nodes for ntype, seed_nodes in seeds.items(): - if isinstance(model.decoder, th.nn.ModuleDict): - assert ntype in model.decoder, f"Node type {ntype} not in decoder" - decoder = model.decoder[ntype] - else: - decoder = model.decoder + if isinstance(decoder, th.nn.ModuleDict): + assert ntype in decoder, f"Node type {ntype} not in decoder" + decoder = decoder[ntype] + if return_proba: pred = decoder.predict_proba(emb[ntype][seed_nodes].to(device)) else: From 44c5357940bdbe3cb2523e43ce32c26083b37d25 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 20 May 2024 17:36:37 -0700 Subject: [PATCH 48/79] update --- tests/unit-tests/test_gnn.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index d2746fd264..4ac8a479bc 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -313,10 +313,10 @@ def require_cache_embed(self): # Test the return_proba argument. 
pred3, labels3 = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) - pred3_d, labels3_d = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) + pred3_d, labels3_d = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) pred4, labels4 = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) - pred4_d, labels4_d = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) + pred4_d, labels4_d = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) if isinstance(pred3, dict): assert len(pred3) == len(pred4) and len(labels3) == len(labels4) assert len(pred3) == len(pred3_d) and len(labels3) == len(labels3_d) @@ -445,7 +445,7 @@ def check_mlp_node_prediction(model, data): batch_size=10, label_field='label', node_feats='feat', train_task=False) pred2, _, labels2 = node_mini_batch_gnn_predict(model, dataloader2, return_label=True) - pred1_d, labels1_d = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_label=True) + pred1_d, labels1_d = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_label=True) if isinstance(pred1, dict): assert len(pred1) == len(pred2) and len(labels1) == len(labels2) assert len(pred1) == len(pred1_d) and len(labels1) == len(labels1_d) @@ -463,8 +463,8 @@ def check_mlp_node_prediction(model, data): # Test the return_proba argument. 
pred3, _ = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True) pred4, _ = node_mini_batch_predict(model, embs, dataloader1, return_proba=False, return_label=True) - pred3_d, _ = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) - pred4_d, _ = node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) + pred3_d, _ = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=True, return_label=True) + pred4_d, _ = run_node_mini_batch_predict(model.decoder, embs, dataloader1, model.device, return_proba=False, return_label=True) if isinstance(pred3, dict): assert len(pred3) == len(pred4) assert len(pred3) == len(pred3_d) From c719891ed8b1baec4e0c99758a4854bff5dfdca6 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Tue, 21 May 2024 00:30:06 -0700 Subject: [PATCH 49/79] Add test --- python/graphstorm/model/__init__.py | 1 + python/graphstorm/model/multitask_gnn.py | 92 ++++++++++ python/graphstorm/trainer/mt_trainer.py | 119 +++---------- tests/unit-tests/test_trainer.py | 204 ++++++++++++++++++++++- 4 files changed, 314 insertions(+), 102 deletions(-) diff --git a/python/graphstorm/model/__init__.py b/python/graphstorm/model/__init__.py index 5fffded442..34cce011ce 100644 --- a/python/graphstorm/model/__init__.py +++ b/python/graphstorm/model/__init__.py @@ -37,6 +37,7 @@ run_lp_mini_batch_predict) from .multitask_gnn import (GSgnnMultiTaskModelInterface, GSgnnMultiTaskSharedEncoderModel) +from .multitask_gnn import multi_task_mini_batch_predict from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer from .sage_encoder import SAGEEncoder, SAGEConv diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index 231aeab13b..112743ed73 100644 --- 
a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -27,6 +27,11 @@ from .gnn import GSgnnModel +from .node_gnn import run_node_mini_batch_predict +from .edge_gnn import run_edge_mini_batch_predict +from .lp_gnn import run_lp_mini_batch_predict + + class GSgnnMultiTaskModelInterface: """ The interface for GraphStorm multi-task learning. @@ -226,3 +231,90 @@ def predict(self, task_id, mini_batch, return_proba=False): return None else: raise TypeError("Unknow task type %s", task_type) + +def multi_task_mini_batch_predict( + model, emb, loader, device, return_proba=True, return_label=False): + """ conduct mini batch prediction on multiple tasks + + Parameters + ---------- + model: GSgnnMultiTaskModelInterface, GSgnnModel + Multi-task learning model + emb : dict of Tensor + The GNN embeddings + loader: GSgnnMultiTaskDataLoader + The mini-batch dataloader. + device: th.device + Device used to compute test scores. + return_proba: bool + Whether to return all the predictions or the maximum prediction. + return_label : bool + Whether or not to return labels. + + Returns + ------- + dict: prediction results of each task + """ + dataloaders = loader.dataloaders + task_infos = loader.task_infos + task_decoders = model.task_decoders + res = {} + with th.no_grad(): + for dataloader, task_info in zip(dataloaders, task_infos): + if task_info.task_type in \ + [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + if dataloader is None: + # In cases when there is no validation or test set. + # set pred and labels to None + res[task_info.task_id] = (None, None) + else: + decoder = task_decoders[task_info.task_id] + preds, labels = \ + run_node_mini_batch_predict(decoder, + emb, + dataloader, + device, + return_proba, + return_label) + assert len(labels) == 1, \ + "In multi-task learning, for each training task, " \ + "we only support prediction on one node type." 
\ + "For multiple node types, please treat them as " \ + "different training tasks." + ntype = list(labels.keys())[0] + res[task_info.task_id] = (preds[ntype], labels[ntype]) + elif task_info.task_type in \ + [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + if dataloader is None: + # In cases when there is no validation or test set. + # set pred and labels to None + res[task_info.task_id] = (None, None) + else: + decoder = task_decoders[task_info.task_id] + preds, labels = \ + run_edge_mini_batch_predict(decoder, + emb, + dataloader, + device, + return_proba, + return_label) + assert len(labels) == 1, \ + "In multi-task learning, for each training task, " \ + "we only support prediction on one edge type." \ + "For multiple edge types, please treat them as " \ + "different training tasks." + etype = list(labels.keys())[0] + res[task_info.task_id] = (preds[etype], labels[etype]) + elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION: + if dataloader is None: + # In cases when there is no validation or test set. 
+ res[task_info.task_id] = None + else: + decoder = task_decoders[task_info.task_id] + ranking = run_lp_mini_batch_predict(decoder, emb, dataloader, device) + res[task_info.task_id] = ranking + else: + raise TypeError("Unknown task %s", task_info) + + return res + diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index 0f09994e49..05c75f8b81 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -31,21 +31,22 @@ from ..model import (do_full_graph_inference, do_mini_batch_inference, GSgnnModelBase, GSgnnModel, - GSgnnMultiTaskModelInterface) + GSgnnMultiTaskModelInterface, + multi_task_mini_batch_predict) from .gsgnn_trainer import GSgnnTrainer -from ..model import (run_node_mini_batch_predict, - run_edge_mini_batch_predict, - run_lp_mini_batch_predict) from ..utils import sys_tracker, rt_profiler, print_mem, get_rank -from ..utils import barrier, is_distributed, get_backend +from ..utils import barrier, is_distributed -def run_node_predict_mini_batch(model, data, task_info, mini_batch, device): +def run_node_mini_batch(model, data, task_info, mini_batch, device): + """ Run node mini_batch forward + """ g = data.g input_nodes, seeds, blocks = mini_batch if not isinstance(input_nodes, dict): assert len(g.ntypes) == 1 input_nodes = {g.ntypes[0]: input_nodes} + nfeat_fields = task_info.dataloader.node_feat_fields label_field = task_info.dataloader.label_field input_feats = data.get_node_feats(input_nodes, nfeat_fields, device) @@ -56,11 +57,12 @@ def run_node_predict_mini_batch(model, data, task_info, mini_batch, device): return loss, task_info.task_config.task_weight -def run_edge_predict_mini_batch(model, data, task_info, mini_batch, device): +def run_edge_mini_batch(model, data, task_info, mini_batch, device): input_nodes, batch_graph, blocks = mini_batch if not isinstance(input_nodes, dict): assert len(batch_graph.ntypes) == 1 input_nodes = {batch_graph.ntypes[0]: 
input_nodes} + nfeat_fields = task_info.dataloader.node_feat_fields input_feats = data.get_node_feats(input_nodes, nfeat_fields, device) @@ -123,89 +125,6 @@ def run_link_predict_mini_batch(model, data, task_info, mini_batch, device): (pos_graph, neg_graph,pos_graph_feats, None))) return loss, task_info.task_config.task_weight -def multi_task_mini_batch_predict( - model, emb, loader, device, return_proba=True, return_label=False): - """ conduct mini batch prediction on multiple tasks - - Parameters - ---------- - model: GSgnnMultiTaskModelInterface, GSgnnModel - Multi-task learning model - emb : dict of Tensor - The GNN embeddings - loader: GSgnnMultiTaskDataLoader - The mini-batch dataloader. - device: th.device - Device used to compute test scores. - return_proba: bool - Whether to return all the predictions or the maximum prediction. - - Returns - ------- - dict: prediction results of each task - """ - dataloaders = loader.dataloaders - task_infos = loader.task_infos - task_decoders = model.task_decoders - res = {} - with th.no_grad(): - for dataloader, task_info in zip(dataloaders, task_infos): - if task_info.task_type in \ - [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - if dataloader is None: - # In cases when there is no validation or test set. - # set pred and labels to None - res[task_info.task_id] = (None, None) - else: - decoder = task_decoders[task_info.task_id] - preds, labels = \ - run_node_mini_batch_predict(decoder, - emb, - dataloader, - device, - return_proba, - return_label) - assert len(labels) == 1, \ - "In multi-task learning, for each training task, " \ - "we only support prediction on one node type." \ - "For multiple node types, please treat them as " \ - "different training tasks." 
- ntype = list(labels.keys())[0] - res[task_info.task_id] = (preds[ntype], labels[ntype]) - elif task_info.task_type in \ - [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - if dataloader is None: - # In cases when there is no validation or test set. - # set pred and labels to None - res[task_info.task_id] = (None, None) - else: - decoder = task_decoders[task_info.task_id] - preds, labels = \ - run_edge_mini_batch_predict(decoder, - emb, - dataloader, - device, - return_proba, - return_label) - assert len(labels) == 1, \ - "In multi-task learning, for each training task, " \ - "we only support prediction on one edge type." \ - "For multiple edge types, please treat them as " \ - "different training tasks." - etype = list(labels.keys())[0] - res[task_info.task_id] = (preds[etype], labels[etype]) - elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION: - if dataloader is None: - # In cases when there is no validation or test set. - res[task_info.task_id] = None - else: - decoder = task_decoders[task_info.task_id] - ranking = run_lp_mini_batch_predict(decoder, emb, dataloader, device) - res[task_info.task_id] = ranking - else: - raise TypeError("Unknown task %s", task_info) - - return res class GSgnnMultiTaskLearningTrainer(GSgnnTrainer): r""" A trainer for multi-task learning @@ -252,18 +171,18 @@ def _run_mini_batch(self, data, model, task_info, mini_batch, device): """ if task_info.task_type in \ [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - return run_node_predict_mini_batch(model, - data, - task_info, - mini_batch, - device) + return run_node_mini_batch(model, + data, + task_info, + mini_batch, + device) elif task_info.task_type in \ [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - return run_edge_predict_mini_batch(model, - data, - task_info, - mini_batch, - device) + return run_edge_mini_batch(model, + data, + task_info, + mini_batch, + device) elif task_info.task_type == 
BUILTIN_TASK_LINK_PREDICTION: return run_link_predict_mini_batch(model, data, diff --git a/tests/unit-tests/test_trainer.py b/tests/unit-tests/test_trainer.py index c6ca1748cf..a84089b7a6 100644 --- a/tests/unit-tests/test_trainer.py +++ b/tests/unit-tests/test_trainer.py @@ -22,11 +22,24 @@ from argparse import Namespace import torch as th -from graphstorm.config import GSConfig +from graphstorm.config import (GSConfig, TaskInfo) +from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION, + BUILTIN_TASK_LINK_PREDICTION) +from graphstorm.dataloading import GSgnnData from graphstorm.tracker import GSSageMakerTaskTracker from graphstorm import create_builtin_node_gnn_model from graphstorm.trainer import GSgnnTrainer from graphstorm.eval import GSgnnClassificationEvaluator +from graphstorm.utils import setup_device, get_device +from graphstorm.trainer.mt_trainer import (run_node_mini_batch, + run_edge_mini_batch, + run_link_predict_mini_batch) +from graphstorm.dataloading import (GSgnnNodeDataLoader, + GSgnnEdgeDataLoader, + GSgnnLinkPredictionDataLoader) +from graphstorm.model import GSgnnMultiTaskModelInterface, GSgnnModel +from numpy.testing import assert_equal from data_utils import generate_dummy_dist_graph @@ -86,7 +99,7 @@ def test_trainer_setup_evaluator(): # case 2: evaluator has no task_tracker by default assert evaluator.task_tracker is None - + # case 3: when setup an evaluator that has no task_tracker and train has no task tracker # eitehr, create a new task_tracker and set it to the evaluator. 
trainer.setup_evaluator(evaluator) @@ -115,6 +128,193 @@ def test_trainer_setup_evaluator(): th.distributed.destroy_process_group() dgl.distributed.kvstore.close_kvstore() +class DummyGSgnnMultiTaskSharedEncoderModel(GSgnnModel, GSgnnMultiTaskModelInterface): + """ Dummy GSgnnMultiTaskSharedEncoderModel for testing + """ + def __init__(self, task_id, task_type, input_nodes, labels, node_feats, expected_loss): + self.task_id = task_id + self.task_type = task_type + self.input_nodes = input_nodes + self.labels = labels + self.node_feats = node_feats + self.expected_loss = expected_loss + + def forward(self, task_id, mini_batch): + assert task_id == self.task_id + assert len(mini_batch) == 2 + encoder_data, decoder_data = mini_batch + + if self.task_type == BUILTIN_TASK_NODE_CLASSIFICATION: + assert len(encoder_data) == 4 + blocks, node_feats, _, input_nodes = encoder_data + lbl = decoder_data + assert blocks is None + assert_equal(lbl.numpy(), self.labels.numpy()) + for ntype, idx in input_nodes.items(): + assert_equal(idx.numpy(), self.input_nodes[ntype].numpy()) + + for ntype, feats in node_feats.items(): + assert_equal(feats.numpy(), self.node_feats[ntype].numpy()) + + return self.expected_loss + if self.task_type == BUILTIN_TASK_EDGE_REGRESSION: + assert len(encoder_data) == 4 + blocks, node_feats, _, input_nodes = encoder_data + assert blocks is None + for ntype, idx in input_nodes.items(): + assert_equal(idx.numpy(), self.input_nodes[ntype].numpy()) + + for ntype, feats in node_feats.items(): + assert_equal(feats.numpy(), self.node_feats[ntype].numpy()) + assert len(decoder_data) == 3 + batch_graph, edge_decoder_feats, lbl = decoder_data + assert batch_graph is None + assert edge_decoder_feats is None + assert_equal(lbl.numpy(), self.labels.numpy()) + + return self.expected_loss + if self.task_type == BUILTIN_TASK_LINK_PREDICTION: + assert len(encoder_data) == 4 + blocks, node_feats, _, input_nodes = encoder_data + assert blocks is None + for ntype, idx in 
input_nodes.items(): + assert_equal(idx.numpy(), self.input_nodes[ntype].numpy()) + + for ntype, feats in node_feats.items(): + assert_equal(feats.numpy(), self.node_feats[ntype].numpy()) + + pos_graph, neg_graph, pos_graph_feats, _ = decoder_data + assert pos_graph is None + assert neg_graph is None + assert pos_graph_feats is None + + return self.expected_loss + + assert False + + def predict(self, task_id, mini_batch, return_proba=False): + pass + +def test_mtask_run_node_mini_batch(): + with tempfile.TemporaryDirectory() as tmpdirname: + # get the test dummy distributed graph + _, part_config = generate_dummy_dist_graph(graph_name='dummy', dirname=tmpdirname) + np_data = GSgnnData(part_config=part_config) + + setup_device(0) + device = get_device() + # Without shuffling, the seed nodes should have the same order as the target nodes. + target_idx = {'n1': th.arange(np_data.g.number_of_nodes('n1'))} + task_id = "test_node_prediction" + dataloader = GSgnnNodeDataLoader(np_data, target_idx, [10], 10, + label_field='label', + node_feats='feat', + edge_feats='feat', + train_task=False) + task_config = GSConfig.__new__(GSConfig) + expected_loss = th.rand(np_data.g.number_of_nodes('n1')) + setattr(task_config, "task_weight", 0.75) + task_info = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION, + task_id=task_id, + task_config=task_config, + dataloader=dataloader) + node_feats = np_data.get_node_feats(target_idx, 'feat', device=device) + labels = np_data.get_node_feats(target_idx, 'label', device=device) + mini_batch = (target_idx, target_idx, None) + model = DummyGSgnnMultiTaskSharedEncoderModel(task_id=task_id, + task_type=BUILTIN_TASK_NODE_CLASSIFICATION, + input_nodes=target_idx, + labels=labels, + node_feats=node_feats, + expected_loss=expected_loss) + loss, weight = run_node_mini_batch(model, np_data, task_info, mini_batch, device) + assert_equal(loss.numpy(), expected_loss.numpy()) + assert weight == 0.75 + +def test_mtask_run_edge_mini_batch(): + with
tempfile.TemporaryDirectory() as tmpdirname: + # get the test dummy distributed graph + _, part_config = generate_dummy_dist_graph(graph_name='dummy', dirname=tmpdirname) + ep_data = GSgnnData(part_config=part_config) + + setup_device(0) + device = get_device() + + target_idx = {('n0', 'r1', 'n1'): th.arange(ep_data.g.number_of_edges('r1'))} + task_id = "test_edge_prediction" + dataloader = GSgnnEdgeDataLoader(ep_data, target_idx, [10], 10, + node_feats='feat', + edge_feats='feat', + label_field='label', + train_task=True, remove_target_edge_type=False) + + task_config = GSConfig.__new__(GSConfig) + expected_loss = th.rand(ep_data.g.number_of_edges('r1')) + setattr(task_config, "task_weight", 0.71) + task_info = TaskInfo(task_type=BUILTIN_TASK_EDGE_REGRESSION, + task_id=task_id, + task_config=task_config, + dataloader=dataloader) + input_nodes = { + "n0": th.arange(10), + "n1": th.arange(20), + } + node_feats = ep_data.get_node_feats(input_nodes, 'feat', device=device) + labels = ep_data.get_node_feats(target_idx, 'label', device=device) + mini_batch = (input_nodes, None, None) + model = DummyGSgnnMultiTaskSharedEncoderModel(task_id, + task_type=BUILTIN_TASK_EDGE_REGRESSION, input_nodes=input_nodes, + labels=labels, + node_feats=node_feats, + expected_loss=expected_loss) + + + loss, weight = run_edge_mini_batch(model, ep_data, task_info, mini_batch, device) + assert_equal(loss.numpy(), expected_loss.numpy()) + assert weight == 0.71 + +def test_mtask_run_lp_mini_batch(): + with tempfile.TemporaryDirectory() as tmpdirname: + # get the test dummy distributed graph + _, part_config = generate_dummy_dist_graph(graph_name='dummy', dirname=tmpdirname) + ep_data = GSgnnData(part_config=part_config) + + setup_device(0) + device = get_device() + + target_idx = {('n0', 'r1', 'n1'): th.arange(ep_data.g.number_of_edges('r1'))} + task_id = "test_link_prediction" + dataloader = GSgnnLinkPredictionDataLoader(ep_data, target_idx, + [10], 10, + num_negative_edges=2, + train_task=False) + task_config
= GSConfig.__new__(GSConfig) + expected_loss = th.rand(ep_data.g.number_of_edges('r1')) + setattr(task_config, "task_weight", 0.72) + task_info = TaskInfo(task_type=BUILTIN_TASK_LINK_PREDICTION, + task_id=task_id, + task_config=task_config, + dataloader=dataloader) + input_nodes = { + "n0": th.arange(10), + "n1": th.arange(20), + } + node_feats = ep_data.get_node_feats(input_nodes, 'feat', device=device) + + mini_batch = (input_nodes, None, None, None) + model = DummyGSgnnMultiTaskSharedEncoderModel(task_id, + task_type=BUILTIN_TASK_LINK_PREDICTION, input_nodes=input_nodes, + labels=None, + node_feats=node_feats, + expected_loss=expected_loss) + + loss, weight = run_link_predict_mini_batch(model, ep_data, task_info, mini_batch, device) + assert_equal(loss.numpy(), expected_loss.numpy()) + assert weight == 0.72 if __name__ == '__main__': test_trainer_setup_evaluator() + + test_mtask_run_node_mini_batch() + test_mtask_run_edge_mini_batch() + test_mtask_run_lp_mini_batch() From 958e2b9638a297b6259a86913afaa48fdbb63e3e Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Tue, 21 May 2024 11:25:49 -0700 Subject: [PATCH 50/79] Update --- python/graphstorm/model/edge_gnn.py | 6 +++--- python/graphstorm/model/lp_gnn.py | 5 ++++- python/graphstorm/model/node_gnn.py | 4 ++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/graphstorm/model/edge_gnn.py b/python/graphstorm/model/edge_gnn.py index a2e023ba81..75eea6791c 100644 --- a/python/graphstorm/model/edge_gnn.py +++ b/python/graphstorm/model/edge_gnn.py @@ -323,13 +323,13 @@ def edge_mini_batch_predict(model, emb, loader, return_proba=True, return_label= def run_edge_mini_batch_predict(decoder, emb, loader, device, return_proba=True, return_label=False): - """ Perform mini-batch prediction with the given decoder. + """ Perform mini-batch edge prediction with the given decoder. - This function usually follows full-grain GNN embedding inference. After having + This function usually follows full-graph GNN embedding inference.
After having the GNN embeddings, we need to perform mini-batch computation to make predictions on the GNN embeddings. - Note: caller should call model.eval() before calling this function + Note: callers should call model.eval() before calling this function and call model.train() after when doing training. Parameters diff --git a/python/graphstorm/model/lp_gnn.py b/python/graphstorm/model/lp_gnn.py index 1e08443755..2ba2f9322d 100644 --- a/python/graphstorm/model/lp_gnn.py +++ b/python/graphstorm/model/lp_gnn.py @@ -137,6 +137,9 @@ def lp_mini_batch_predict(model, emb, loader, device): After having the GNN embeddings, we need to perform mini-batch computation to make predictions on the GNN embeddings. + Note: callers should call model.eval() before calling this function + and call model.train() after when doing training. + Parameters ---------- model : GSgnnModel @@ -166,7 +169,7 @@ def run_lp_mini_batch_predict(decoder, emb, loader, device): After having the GNN embeddings, we need to perform mini-batch computation to make predictions on the GNN embeddings. - Note: caller should call model.eval() before calling this function + Note: callers should call model.eval() before calling this function and call model.train() after when doing training. Parameters diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index 526be5b4b6..f1c79cd968 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -325,9 +325,9 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= def run_node_mini_batch_predict(decoder, emb, loader, device, return_proba=True, return_label=False): - """ Perform mini-batch prediction with the given decoder. + """ Perform mini-batch node prediction with the given decoder. - Note: caller should call model.eval() before calling this function + Note: callers should call model.eval() before calling this function and call model.train() after when doing training. 
Parameters From 04ec3e943be91f90a78fabd300ef6c6ddf5b9f33 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 23 May 2024 15:51:34 -0700 Subject: [PATCH 51/79] Add test for multitask_gnn.py --- python/graphstorm/model/multitask_gnn.py | 45 +- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 2 +- python/graphstorm/trainer/mt_trainer.py | 4 +- tests/unit-tests/test_gnn.py | 574 ++++++++++++++++++++- 4 files changed, 605 insertions(+), 20 deletions(-) diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index 112743ed73..f762886e3a 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -17,6 +17,7 @@ """ import abc import logging +import torch as th from torch import nn from ..config import (BUILTIN_TASK_NODE_CLASSIFICATION, @@ -95,13 +96,27 @@ def __init__(self, alpha_l2norm): self._decoder = nn.ModuleDict() def add_task(self, task_id, task_type, - decoder, loss_func, weight): + decoder, loss_func): """ Add a task into the multi-task pool + + Parameters + ---------- + task_id: str + Task ID. + task_type: str + Task type. + decoder: GSNodeDecoder or + GSEdgeDecoder or + LinkPredictNoParamDecoder or + LinkPredictLearnableDecoder + Task decoder. + loss_func: func + Loss function. """ assert task_id not in self._task_pool, \ f"Task {task_id} already exists" logging.info("Setup task %s", task_id) - self._task_pool[task_id] = (task_type, loss_func, weight) + self._task_pool[task_id] = (task_type, loss_func) self._decoder[task_id] = decoder @property @@ -143,7 +158,7 @@ def forward(self, task_id, mini_batch): # Call emb normalization. 
encode_embs = self.normalize_node_embs(encode_embs) - task_type, loss_func, _ = self.task_pool[task_id] + task_type, loss_func = self.task_pool[task_id] task_decoder = self.decoder[task_id] if task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: @@ -163,13 +178,13 @@ def forward(self, task_id, mini_batch): return pred_loss elif task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - target_edges, target_edge_feats, labels = decoder_data + batch_graph, target_edge_feats, labels = decoder_data assert len(labels) == 1, \ "In multi-task learning, only support do prediction " \ "on one edge type for a single edge task." pred_loss = 0 target_etype = list(labels.keys())[0] - logits = task_decoder(target_edges, encode_embs, target_edge_feats) + logits = task_decoder(batch_graph, encode_embs, target_edge_feats) pred_loss = loss_func(logits, labels[target_etype]) return pred_loss @@ -207,7 +222,7 @@ def predict(self, task_id, mini_batch, return_proba=False): # Call emb normalization. 
encode_embs = self.normalize_node_embs(encode_embs) - task_type, _, _ = self.task_pool[task_id] + task_type, _ = self.task_pool[task_id] task_decoder = self.decoder[task_id] if task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: @@ -222,10 +237,10 @@ def predict(self, task_id, mini_batch, return_proba=False): predicts[target_ntype] = task_decoder.predict(encode_embs[target_ntype]) return predicts elif task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - target_edges, target_edge_feats, _ = decoder_data + batch_graph, target_edge_feats, _ = decoder_data if return_proba: - return task_decoder.predict_proba(target_edges, encode_embs, target_edge_feats) - return task_decoder.predict(target_edges, encode_embs, target_edge_feats) + return task_decoder.predict_proba(batch_graph, encode_embs, target_edge_feats) + return task_decoder.predict(batch_graph, encode_embs, target_edge_feats) elif task_type == BUILTIN_TASK_LINK_PREDICTION: logging.warning("Prediction for link prediction is not implemented") return None @@ -276,13 +291,13 @@ def multi_task_mini_batch_predict( device, return_proba, return_label) - assert len(labels) == 1, \ + assert labels is None or len(labels) == 1, \ "In multi-task learning, for each training task, " \ "we only support prediction on one node type." \ "For multiple node types, please treat them as " \ "different training tasks." 
- ntype = list(labels.keys())[0] - res[task_info.task_id] = (preds[ntype], labels[ntype]) + ntype = list(preds.keys())[0] + res[task_info.task_id] = (preds[ntype], labels[ntype] if labels is not None else None) elif task_info.task_type in \ [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: if dataloader is None: @@ -298,13 +313,13 @@ def multi_task_mini_batch_predict( device, return_proba, return_label) - assert len(labels) == 1, \ + assert labels is None or len(labels) == 1, \ "In multi-task learning, for each training task, " \ "we only support prediction on one edge type." \ "For multiple edge types, please treat them as " \ "different training tasks." - etype = list(labels.keys())[0] - res[task_info.task_id] = (preds[etype], labels[etype]) + etype = list(preds.keys())[0] + res[task_info.task_id] = (preds[etype], labels[etype] if labels is not None else None) elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION: if dataloader is None: # In cases when there is no validation or test set. 
diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index ff0eaebce5..74d37be20f 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -320,7 +320,7 @@ def main(config_args): val_dataloaders.append(val_loader) test_dataloaders.append(test_loader) decoder, loss_func = create_task_decoder(task, train_data.g, encoder_out_dims, train_task=True) - model.add_task(task.task_id, task.task_type, decoder, loss_func, task_config.task_weight) + model.add_task(task.task_id, task.task_type, decoder, loss_func) if not config.no_validation: if val_loader is None: logging.warning("The training data do not have validation set.") diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index 05c75f8b81..8df2091e15 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -121,8 +121,8 @@ def run_link_predict_mini_batch(model, data, task_info, mini_batch, device): # TODO: we don't support edge features for now. 
loss = model(task_info.task_id, - ((blocks, input_feats, None, input_nodes), - (pos_graph, neg_graph,pos_graph_feats, None))) + ((blocks, input_feats, None, input_nodes), + (pos_graph, neg_graph,pos_graph_feats, None))) return loss, task_info.task_config.task_weight diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index 4ac8a479bc..526f9c2f01 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -32,8 +32,13 @@ import dgl -from graphstorm.config import GSConfig +from graphstorm.config import GSConfig, TaskInfo from graphstorm.config import BUILTIN_LP_DOT_DECODER +from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION, + BUILTIN_TASK_LINK_PREDICTION) from graphstorm.model import GSNodeEncoderInputLayer, RelationalGCNEncoder from graphstorm.model import GSgnnNodeModel, GSgnnEdgeModel from graphstorm.model import GSLMNodeEncoderInputLayer, GSPureLMNodeInputLayer @@ -53,7 +58,7 @@ LinkPredictWeightedDistMultDecoder) from graphstorm.model.node_decoder import EntityRegression, EntityClassifier from graphstorm.dataloading import GSgnnData -from graphstorm.dataloading import GSgnnNodeDataLoader, GSgnnEdgeDataLoader +from graphstorm.dataloading import GSgnnNodeDataLoader, GSgnnEdgeDataLoader, GSgnnMultiTaskDataLoader from graphstorm.dataloading.dataset import prepare_batch_input from graphstorm import create_builtin_edge_gnn_model, create_builtin_node_gnn_model from graphstorm import create_builtin_lp_gnn_model @@ -67,8 +72,13 @@ from graphstorm.model.edge_gnn import (edge_mini_batch_predict, run_edge_mini_batch_predict, edge_mini_batch_gnn_predict) +from graphstorm.model.multitask_gnn import multi_task_mini_batch_predict from graphstorm.model.gnn_with_reconstruct import construct_node_feat, get_input_embeds_combined from graphstorm.model.utils import load_model, save_model +from graphstorm.model import 
GSgnnMultiTaskSharedEncoderModel +from graphstorm.dataloading import (GSgnnEdgeDataLoaderBase, + GSgnnLinkPredictionDataLoaderBase, + GSgnnNodeDataLoaderBase) from data_utils import generate_dummy_dist_graph, generate_dummy_dist_graph_multi_target_ntypes from data_utils import generate_dummy_dist_graph_reconstruct @@ -1714,7 +1724,567 @@ def check_predict(mock_get_labels, return_dict): th.distributed.destroy_process_group() +class DummyNCDecoder(nn.Module): + + def forward(self, inputs): + return inputs + + def predict(self, inputs): + return inputs + + def predict_proba(self, inputs): + return inputs * 2 + +class DummyNRDecoder(nn.Module): + + def forward(self, inputs): + return inputs + + def predict(self, inputs): + return inputs + + def predict_proba(self, inputs): + return inputs * 2 + +class DummyECDecoder(nn.Module): + + def forward(self, g, h, e_h): + return h["n0"] + + def predict(self, g, h, e_h): + return h["n0"] + + def predict_proba(self, g, h, e_h): + return h["n0"] * 2 + +class DummyERDecoder(nn.Module): + + def forward(self, g, h, e_h): + return h["n0"] * 2 + + def predict(self, g, h, e_h): + return h["n0"] + + def predict_proba(self, g, h, e_h): + return h["n0"] * 2 + +class DummyLPDecoder(nn.Module): + + def forward(self, g, h, e_h=None): + return h + + +def test_multi_task_forward(): + mt_model = GSgnnMultiTaskSharedEncoderModel(0.1) + + def pred_los_func(logits, labels): + return logits - labels + def pred_lp_loss_func(pos_score, neg_score): + return pos_score["n0"] + neg_score["n0"] + mt_model.add_task("nc_task", + BUILTIN_TASK_NODE_CLASSIFICATION, + DummyNCDecoder(), + pred_los_func) + + mt_model.add_task("nr_task", + BUILTIN_TASK_NODE_REGRESSION, + DummyNRDecoder(), + pred_los_func) + + mt_model.add_task("ec_task", + BUILTIN_TASK_EDGE_CLASSIFICATION, + DummyECDecoder(), + pred_los_func) + + mt_model.add_task("er_task", + BUILTIN_TASK_EDGE_REGRESSION, + DummyERDecoder(), + pred_los_func) + + mt_model.add_task("lp_task", + 
BUILTIN_TASK_LINK_PREDICTION, + DummyLPDecoder(), + pred_lp_loss_func) + + @patch.object(GSgnnMultiTaskSharedEncoderModel, 'comput_input_embed') + @patch.object(GSgnnMultiTaskSharedEncoderModel, 'compute_embed_step') + @patch.object(GSgnnMultiTaskSharedEncoderModel, 'normalize_node_embs') + def check_forward(mock_normalize_node_embs, + mock_compute_emb, + mock_input_embed): + + def normalize_size_effect_func(embs): + return embs + + def compute_side_effect_func(blocks, node_feats, input_nodes): + return input_nodes + + def input_embed_side_effect_func(input_nodes, node_feats): + return input_nodes + + mock_normalize_node_embs.side_effect = normalize_size_effect_func + mock_compute_emb.side_effect = compute_side_effect_func + mock_input_embed.side_effect = input_embed_side_effect_func + + ### blocks is None (no GNN setting) + # NC task + task_id = "nc_task" + blocks = None + input_nodes = {"n0": th.randint(5, (10,))} + labels = {"n0": th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) + + # NR task + task_id = "nr_task" + blocks = None + input_nodes = {"n0": th.rand((10,))} + labels = {"n0": th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) + + # EC task + task_id = "ec_task" + blocks = None + input_nodes = {"n0": th.randint(5, (10,))} + labels = {("n0", "r1", "n1"): th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]-labels[("n0", "r1", "n1")]).numpy()) + + # ER task + task_id = "er_task" + blocks = None + input_nodes = {"n0": th.rand((10,))} + labels = {("n0", "r1", "n1"): th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + loss 
= mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]*2-labels[("n0", "r1", "n1")]).numpy()) + + # LP task + task_id = "lp_task" + blocks = None + input_nodes = {"n0": th.rand((10,))} + mini_batch = mini_batch = ((blocks, None, None, input_nodes), (None, None, None, None)) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]*2).numpy()) + + ### blocks is a list (GNN setting) + # NC task + task_id = "nc_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.randint(5, (10,))} + labels = {"n0": th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) + + # NR task + task_id = "nr_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.rand((10,))} + labels = {"n0": th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) + + # EC task + task_id = "ec_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.randint(5, (10,))} + labels = {("n0", "r1", "n1"): th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]-labels[("n0", "r1", "n1")]).numpy()) + + # ER task + task_id = "er_task" + blocks = [None, None] # trick mt_model there are two gnn layers. 
+ input_nodes = {"n0": th.rand((10,))} + labels = {("n0", "r1", "n1"): th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]*2-labels[("n0", "r1", "n1")]).numpy()) + + # LP task + task_id = "lp_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.rand((10,))} + mini_batch = mini_batch = ((blocks, None, None, input_nodes), (None, None, None, None)) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]*2).numpy()) + + + check_forward() + +def test_multi_task_predict(): + mt_model = GSgnnMultiTaskSharedEncoderModel(0.1) + + def pred_los_func(logits, labels): + return logits - labels + def pred_lp_loss_func(pos_score, neg_score): + return pos_score["n0"] + neg_score["n0"] + mt_model.add_task("nc_task", + BUILTIN_TASK_NODE_CLASSIFICATION, + DummyNCDecoder(), + pred_los_func) + + mt_model.add_task("nr_task", + BUILTIN_TASK_NODE_REGRESSION, + DummyNRDecoder(), + pred_los_func) + + mt_model.add_task("ec_task", + BUILTIN_TASK_EDGE_CLASSIFICATION, + DummyECDecoder(), + pred_los_func) + + mt_model.add_task("er_task", + BUILTIN_TASK_EDGE_REGRESSION, + DummyERDecoder(), + pred_los_func) + + mt_model.add_task("lp_task", + BUILTIN_TASK_LINK_PREDICTION, + DummyLPDecoder(), + pred_lp_loss_func) + + @patch.object(GSgnnMultiTaskSharedEncoderModel, 'comput_input_embed') + @patch.object(GSgnnMultiTaskSharedEncoderModel, 'compute_embed_step') + @patch.object(GSgnnMultiTaskSharedEncoderModel, 'normalize_node_embs') + def check_forward(mock_normalize_node_embs, + mock_compute_emb, + mock_input_embed): + + def normalize_size_effect_func(embs): + return embs + + def compute_side_effect_func(blocks, node_feats, input_nodes): + return input_nodes + + def input_embed_side_effect_func(input_nodes, node_feats): + return input_nodes + + mock_normalize_node_embs.side_effect = 
normalize_size_effect_func + mock_compute_emb.side_effect = compute_side_effect_func + mock_input_embed.side_effect = input_embed_side_effect_func + + ### blocks is None (no GNN setting) + # NC task + task_id = "nc_task" + blocks = None + input_nodes = {"n0": th.randint(5, (10,))} + labels = {"n0": th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + pred = mt_model.predict(task_id, mini_batch) + assert_equal(pred["n0"].numpy(), (input_nodes["n0"]).numpy()) + + # NR task + task_id = "nr_task" + blocks = None + input_nodes = {"n0": th.rand((10,))} + labels = {"n0": th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + pred = mt_model.predict(task_id, mini_batch) + assert_equal(pred["n0"].numpy(), (input_nodes["n0"]).numpy()) + + # EC task + task_id = "ec_task" + blocks = None + input_nodes = {"n0": th.randint(5, (10,))} + labels = {("n0", "r1", "n1"): th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + pred = mt_model.predict(task_id, mini_batch) + assert_equal(pred.numpy(), (input_nodes["n0"]).numpy()) + + # ER task + task_id = "er_task" + blocks = None + input_nodes = {"n0": th.rand((10,))} + labels = {("n0", "r1", "n1"): th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + pred = mt_model.predict(task_id, mini_batch) + assert_equal(pred.numpy(), (input_nodes["n0"]).numpy()) + + # LP task + task_id = "lp_task" + blocks = None + input_nodes = {"n0": th.rand((10,))} + mini_batch = mini_batch = ((blocks, None, None, input_nodes), (None, None, None, None)) + pred = mt_model.predict(task_id, mini_batch) + assert pred is None + + ### blocks is a list (GNN setting) and call return_proba=True + # NC task + task_id = "nc_task" + blocks = [None, None] # trick mt_model there are two gnn layers. 
+ input_nodes = {"n0": th.randint(5, (10,))} + labels = {"n0": th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + pred = mt_model.predict(task_id, mini_batch, return_proba=True) + assert_equal(pred["n0"].numpy(), (input_nodes["n0"]*2).numpy()) + + # NR task + task_id = "nr_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.rand((10,))} + labels = {"n0": th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + pred = mt_model.predict(task_id, mini_batch, return_proba=True) + assert_equal(pred["n0"].numpy(), (input_nodes["n0"]*2).numpy()) + + # EC task + task_id = "ec_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.randint(5, (10,))} + labels = {("n0", "r1", "n1"): th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + pred = mt_model.predict(task_id, mini_batch, return_proba=True) + assert_equal(pred.numpy(), (input_nodes["n0"]*2).numpy()) + + # ER task + task_id = "er_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.rand((10,))} + labels = {("n0", "r1", "n1"): th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + pred = mt_model.predict(task_id, mini_batch, return_proba=True) + assert_equal(pred.numpy(), (input_nodes["n0"]*2).numpy()) + + # LP task + task_id = "lp_task" + blocks = [None, None] # trick mt_model there are two gnn layers. 
+ input_nodes = {"n0": th.rand((10,))} + mini_batch = mini_batch = ((blocks, None, None, input_nodes), (None, None, None, None)) + pred = mt_model.predict(task_id, mini_batch, return_proba=True) + assert pred is None + + check_forward() + +class DummyGSgnnNodeDataLoader(GSgnnNodeDataLoaderBase): + def __init__(self): + pass # do nothing + + def __len__(self): + return 10 + + def __iter__(self): + return self + +class DummyGSgnnEdgeDataLoader(GSgnnEdgeDataLoaderBase): + def __init__(self): + pass # do nothing + + def __len__(self): + return 10 + + def __iter__(self): + return self + +class DummyGSgnnLinkPredictionDataLoader(GSgnnLinkPredictionDataLoaderBase): + def __init__(self): + pass # do nothing + + def __len__(self): + return 10 + + def __iter__(self): + return self + +def test_multi_task_mini_batch_predict(): + mt_model = GSgnnMultiTaskSharedEncoderModel(0.1) + + def pred_los_func(logits, labels): + return logits - labels + def pred_lp_loss_func(pos_score, neg_score): + return pos_score["n0"] + neg_score["n0"] + mt_model.add_task("nc_task", + BUILTIN_TASK_NODE_CLASSIFICATION, + DummyNCDecoder(), + pred_los_func) + + mt_model.add_task("nr_task", + BUILTIN_TASK_NODE_REGRESSION, + DummyNRDecoder(), + pred_los_func) + + mt_model.add_task("ec_task", + BUILTIN_TASK_EDGE_CLASSIFICATION, + DummyECDecoder(), + pred_los_func) + + mt_model.add_task("er_task", + BUILTIN_TASK_EDGE_REGRESSION, + DummyERDecoder(), + pred_los_func) + + mt_model.add_task("lp_task", + BUILTIN_TASK_LINK_PREDICTION, + DummyLPDecoder(), + pred_lp_loss_func) + + tast_info_nc = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION, + task_id='nc_task', + task_config=None) + nc_dataloader = DummyGSgnnNodeDataLoader() + tast_info_nr = TaskInfo(task_type=BUILTIN_TASK_NODE_REGRESSION, + task_id='nr_task', + task_config=None) + nr_dataloader = DummyGSgnnNodeDataLoader() + tast_info_ec = TaskInfo(task_type=BUILTIN_TASK_EDGE_CLASSIFICATION, + task_id='ec_task', + task_config=None) + ec_dataloader = 
DummyGSgnnEdgeDataLoader() + tast_info_er = TaskInfo(task_type=BUILTIN_TASK_EDGE_REGRESSION, + task_id='er_task', + task_config=None) + er_dataloader = DummyGSgnnEdgeDataLoader() + tast_info_lp = TaskInfo(task_type=BUILTIN_TASK_LINK_PREDICTION, + task_id='lp_task', + task_config=None) + lp_dataloader = DummyGSgnnLinkPredictionDataLoader() + task_infos = [tast_info_nc, tast_info_nr, tast_info_ec, tast_info_er, tast_info_lp] + dataloaders = [nc_dataloader, nr_dataloader, ec_dataloader, er_dataloader, lp_dataloader] + + node_pred = {"n0": th.arange(10)} + node_prob = {"n0": th.arange(10)/10} + node_label = {"n0": th.arange(10)} + edge_pred = {("n0", "r0", "n1"): th.arange(5)} + edge_prob = {("n0", "r0", "n1"): th.arange(5)/10} + edge_label = {("n0", "r0", "n1"): th.arange(5)} + lp_pred = {("n0", "r0", "n1"): th.arange(5)/10, + ("n0", "r0", "n2"): th.arange(5)/20} + + def run_node_mini_batch_predict_side_func(decoder, emb, loader, device, return_prob, return_label): + pred = node_pred + label = None + if return_prob: + pred = node_prob + if return_label: + label = node_label + + return pred, label + + def run_edge_mini_batch_predict_side_func(decoder, emb, loader, device, return_prob, return_label): + pred = edge_pred + label = None + if return_prob: + pred = edge_prob + if return_label: + label = edge_label + + return pred, label + + def run_lpmini_batch_predict_side_func(decoder, emb, loader, device): + return lp_pred + + @patch("graphstorm.model.multitask_gnn.run_node_mini_batch_predict", side_effect = run_node_mini_batch_predict_side_func) + @patch("graphstorm.model.multitask_gnn.run_edge_mini_batch_predict", side_effect = run_edge_mini_batch_predict_side_func) + @patch("graphstorm.model.multitask_gnn.run_lp_mini_batch_predict", side_effect = run_lpmini_batch_predict_side_func) + def check_forward(mock_run_lp_mini_batch_predict, + mock_run_edge_mini_batch_predict, + mock_run_node_mini_batch_predict): + + mt_dataloader = GSgnnMultiTaskDataLoader(None, task_infos, 
dataloaders) + res = multi_task_mini_batch_predict(mt_model, + None, + mt_dataloader, + device=th.device('cpu'), + return_proba=False, + return_label=False) + assert len(res["nc_task"]) == 2 + assert_equal(res["nc_task"][0].numpy(), node_pred["n0"].numpy()) + assert res["nc_task"][1] is None + assert len(res["nr_task"]) == 2 + assert_equal(res["nr_task"][0].numpy(), node_pred["n0"].numpy()) + assert res["nr_task"][1] is None + assert len(res["ec_task"]) == 2 + assert_equal(res["ec_task"][0].numpy(), edge_pred[("n0", "r0", "n1")].numpy()) + assert res["ec_task"][1] is None + assert len(res["er_task"]) == 2 + assert_equal(res["er_task"][0].numpy(), edge_pred[("n0", "r0", "n1")].numpy()) + assert res["er_task"][1] is None + assert_equal(res["lp_task"][("n0", "r0", "n1")].numpy(), lp_pred[("n0", "r0", "n1")].numpy()) + assert_equal(res["lp_task"][("n0", "r0", "n2")].numpy(), lp_pred[("n0", "r0", "n2")].numpy()) + + res = multi_task_mini_batch_predict(mt_model, + None, + mt_dataloader, + device=th.device('cpu'), + return_proba=True, + return_label=False) + assert len(res["nc_task"]) == 2 + assert_equal(res["nc_task"][0].numpy(), node_prob["n0"].numpy()) + assert res["nc_task"][1] is None + assert len(res["nr_task"]) == 2 + assert_equal(res["nr_task"][0].numpy(), node_prob["n0"].numpy()) + assert res["nr_task"][1] is None + assert len(res["ec_task"]) == 2 + assert_equal(res["ec_task"][0].numpy(), edge_prob[("n0", "r0", "n1")].numpy()) + assert res["ec_task"][1] is None + assert len(res["er_task"]) == 2 + assert_equal(res["er_task"][0].numpy(), edge_prob[("n0", "r0", "n1")].numpy()) + assert res["er_task"][1] is None + assert_equal(res["lp_task"][("n0", "r0", "n1")].numpy(), lp_pred[("n0", "r0", "n1")].numpy()) + assert_equal(res["lp_task"][("n0", "r0", "n2")].numpy(), lp_pred[("n0", "r0", "n2")].numpy()) + + res = multi_task_mini_batch_predict(mt_model, + None, + mt_dataloader, + device=th.device('cpu'), + return_proba=False, + return_label=True) + assert 
len(res["nc_task"]) == 2 + assert_equal(res["nc_task"][0].numpy(), node_pred["n0"].numpy()) + assert_equal(res["nc_task"][1].numpy(), node_label["n0"].numpy()) + assert len(res["nr_task"]) == 2 + assert_equal(res["nr_task"][0].numpy(), node_pred["n0"].numpy()) + assert_equal(res["nr_task"][1].numpy(), node_label["n0"].numpy()) + assert len(res["ec_task"]) == 2 + assert_equal(res["ec_task"][0].numpy(), edge_pred[("n0", "r0", "n1")].numpy()) + assert_equal(res["ec_task"][0].numpy(), edge_label[("n0", "r0", "n1")].numpy()) + assert len(res["er_task"]) == 2 + assert_equal(res["er_task"][0].numpy(), edge_pred[("n0", "r0", "n1")].numpy()) + assert_equal(res["ec_task"][0].numpy(), edge_label[("n0", "r0", "n1")].numpy()) + assert_equal(res["lp_task"][("n0", "r0", "n1")].numpy(), lp_pred[("n0", "r0", "n1")].numpy()) + assert_equal(res["lp_task"][("n0", "r0", "n2")].numpy(), lp_pred[("n0", "r0", "n2")].numpy()) + + + new_dataloaders = [nc_dataloader, None, ec_dataloader, None, None] + mt_dataloader = GSgnnMultiTaskDataLoader(None, task_infos, new_dataloaders) + + res = multi_task_mini_batch_predict(mt_model, + None, + mt_dataloader, + device=th.device('cpu'), + return_proba=False, + return_label=False) + assert len(res["nc_task"]) == 2 + assert_equal(res["nc_task"][0].numpy(), node_pred["n0"].numpy()) + assert res["nc_task"][1] is None + assert len(res["nr_task"]) == 2 + assert res["nr_task"][0] is None + assert res["nr_task"][1] is None + assert len(res["ec_task"]) == 2 + assert_equal(res["ec_task"][0].numpy(), edge_pred[("n0", "r0", "n1")].numpy()) + assert res["ec_task"][1] is None + assert len(res["er_task"]) == 2 + assert res["er_task"][0] is None + assert res["er_task"][1] is None + assert res["lp_task"] is None + + + + check_forward() + + if __name__ == '__main__': + test_multi_task_forward() + test_multi_task_predict() + test_multi_task_mini_batch_predict() + test_lm_rgcn_node_prediction_with_reconstruct() test_rgcn_node_prediction_with_reconstruct(True) 
test_rgcn_node_prediction_with_reconstruct(False) From 601aa1b47c253cf0052ba0c1e954a72a6b31d3ff Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 23 May 2024 15:55:49 -0700 Subject: [PATCH 52/79] Add GSgnnMultiTaskSharedEncoderModel --- python/graphstorm/model/multitask_gnn.py | 335 +++++++++++++++++++++++ 1 file changed, 335 insertions(+) create mode 100644 python/graphstorm/model/multitask_gnn.py diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py new file mode 100644 index 0000000000..f762886e3a --- /dev/null +++ b/python/graphstorm/model/multitask_gnn.py @@ -0,0 +1,335 @@ +""" + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + GNN model for multi-task learning in GraphStorm +""" +import abc +import logging +import torch as th +from torch import nn + +from ..config import (BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION, + BUILTIN_TASK_LINK_PREDICTION) +from .gnn import GSgnnModel + + +from .node_gnn import run_node_mini_batch_predict +from .edge_gnn import run_edge_mini_batch_predict +from .lp_gnn import run_lp_mini_batch_predict + + +class GSgnnMultiTaskModelInterface: + """ The interface for GraphStorm multi-task learning. + + This interface defines two main methods for training and inference. 
+ """ + @abc.abstractmethod + def forward(self, task_id, mini_batch): + """ The forward function for multi-task learning + + This method is used for training, It runs model forword + on a mini-batch for one task at a time. + The loss of the model in the mini-batch is returned. + + Parameters + ---------- + task_id: str + ID of the task. + mini_batch: tuple + Mini-batch info + + + Return + ------ + The loss of prediction. + """ + + @abc.abstractmethod + def predict(self, task_info, mini_batch): + """ The forward function for multi-task prediction. + + This method is used for inference, It runs model forword + on a mini-batch for one task at a time. + The prediction result is returned. + + Parameters + ---------- + task_info: TaskInfo + task meta information + mini_batch: tuple + mini-batch info + + Returns + ------- + Tensor or dict of Tensor: + the prediction results. + """ + +class GSgnnMultiTaskSharedEncoderModel(GSgnnModel, GSgnnMultiTaskModelInterface): + """ GraphStorm GNN model for multi-task learning + with a shared encoder model and separate decoder models. + + Parameters + ---------- + alpha_l2norm : float + The alpha for L2 normalization. + """ + def __init__(self, alpha_l2norm): + super(GSgnnMultiTaskSharedEncoderModel, self).__init__() + self._alpha_l2norm = alpha_l2norm + self._task_pool = {} + self._decoder = nn.ModuleDict() + + def add_task(self, task_id, task_type, + decoder, loss_func): + """ Add a task into the multi-task pool + + Parameters + ---------- + task_id: str + Task ID. + task_type: str + Task type. + decoder: GSNodeDecoder or + GSEdgeDecoder or + LinkPredictNoParamDecoder or + LinkPredictLearnableDecoder + Task decoder. + loss_func: func + Loss function. 
+ """ + assert task_id not in self._task_pool, \ + f"Task {task_id} already exists" + logging.info("Setup task %s", task_id) + self._task_pool[task_id] = (task_type, loss_func) + self._decoder[task_id] = decoder + + @property + def alpha_l2norm(self): + """Get parameter norm params + """ + return self._alpha_l2norm + + @property + def task_pool(self): + """ Get task pool + """ + return self._task_pool + + @property + def task_decoders(self): + """ Get task decoders + """ + return self._decoder + + # pylint: disable=unused-argument + def forward(self, task_id, mini_batch): + """ The forward function for multi-task learning + """ + assert task_id in self.task_pool, \ + f"Unknown task: {task_id} in multi-task learning." \ + f"Existing tasks are {self.task_pool.keys()}" + + encoder_data, decoder_data = mini_batch + # message passing graph, node features, edge features, seed nodes + blocks, node_feats, _, input_nodes = encoder_data + if blocks is None or len(blocks) == 0: + # no GNN message passing + encode_embs = self.comput_input_embed(input_nodes, node_feats) + else: + # GNN message passing + encode_embs = self.compute_embed_step(blocks, node_feats, input_nodes) + + # Call emb normalization. + encode_embs = self.normalize_node_embs(encode_embs) + + task_type, loss_func = self.task_pool[task_id] + task_decoder = self.decoder[task_id] + + if task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + labels = decoder_data + assert len(labels) == 1, \ + "In multi-task learning, only support do prediction " \ + "on one node type for a single node task." 
+ pred_loss = 0 + target_ntype = list(labels.keys())[0] + + assert target_ntype in encode_embs, f"Node type {target_ntype} not in encode_embs" + assert target_ntype in labels, f"Node type {target_ntype} not in labels" + emb = encode_embs[target_ntype] + ntype_labels = labels[target_ntype] + ntype_logits = task_decoder(emb) + pred_loss = loss_func(ntype_logits, ntype_labels) + + return pred_loss + elif task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + batch_graph, target_edge_feats, labels = decoder_data + assert len(labels) == 1, \ + "In multi-task learning, only support do prediction " \ + "on one edge type for a single edge task." + pred_loss = 0 + target_etype = list(labels.keys())[0] + logits = task_decoder(batch_graph, encode_embs, target_edge_feats) + pred_loss = loss_func(logits, labels[target_etype]) + + return pred_loss + elif task_type == BUILTIN_TASK_LINK_PREDICTION: + pos_graph, neg_graph, pos_edge_feats, neg_edge_feats = decoder_data + + pos_score = task_decoder(pos_graph, encode_embs, pos_edge_feats) + neg_score = task_decoder(neg_graph, encode_embs, neg_edge_feats) + assert pos_score.keys() == neg_score.keys(), \ + "Positive scores and Negative scores must have edges of same" \ + f"edge types, but get {pos_score.keys()} and {neg_score.keys()}" + pred_loss = loss_func(pos_score, neg_score) + return pred_loss + else: + raise TypeError("Unknow task type %s", task_type) + + + def predict(self, task_id, mini_batch, return_proba=False): + """ The forward function for multi-task inference + """ + assert task_id in self.task_pool, \ + f"Unknown task: {task_id} in multi-task learning." 
\ + f"Existing tasks are {self.task_pool.keys()}" + + encoder_data, decoder_data = mini_batch + # message passing graph, node features, edge features, seed nodes + blocks, node_feats, _, input_nodes = encoder_data + if blocks is None or len(blocks) == 0: + # no GNN message passing + encode_embs = self.comput_input_embed(input_nodes, node_feats) + else: + # GNN message passing + encode_embs = self.compute_embed_step(blocks, node_feats, input_nodes) + + # Call emb normalization. + encode_embs = self.normalize_node_embs(encode_embs) + + task_type, _ = self.task_pool[task_id] + task_decoder = self.decoder[task_id] + + if task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + assert len(encode_embs) == 1, \ + "In multi-task learning, only support do prediction " \ + "on one node type for a single node task." + target_ntype = list(encode_embs.keys())[0] + predicts = {} + if return_proba: + predicts[target_ntype] = task_decoder.predict_proba(encode_embs[target_ntype]) + else: + predicts[target_ntype] = task_decoder.predict(encode_embs[target_ntype]) + return predicts + elif task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + batch_graph, target_edge_feats, _ = decoder_data + if return_proba: + return task_decoder.predict_proba(batch_graph, encode_embs, target_edge_feats) + return task_decoder.predict(batch_graph, encode_embs, target_edge_feats) + elif task_type == BUILTIN_TASK_LINK_PREDICTION: + logging.warning("Prediction for link prediction is not implemented") + return None + else: + raise TypeError("Unknow task type %s", task_type) + +def multi_task_mini_batch_predict( + model, emb, loader, device, return_proba=True, return_label=False): + """ conduct mini batch prediction on multiple tasks + + Parameters + ---------- + model: GSgnnMultiTaskModelInterface, GSgnnModel + Multi-task learning model + emb : dict of Tensor + The GNN embeddings + loader: GSgnnMultiTaskDataLoader + The mini-batch dataloader. 
+ device: th.device + Device used to compute test scores. + return_proba: bool + Whether to return all the predictions or the maximum prediction. + return_label : bool + Whether or not to return labels. + + Returns + ------- + dict: prediction results of each task + """ + dataloaders = loader.dataloaders + task_infos = loader.task_infos + task_decoders = model.task_decoders + res = {} + with th.no_grad(): + for dataloader, task_info in zip(dataloaders, task_infos): + if task_info.task_type in \ + [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + if dataloader is None: + # In cases when there is no validation or test set. + # set pred and labels to None + res[task_info.task_id] = (None, None) + else: + decoder = task_decoders[task_info.task_id] + preds, labels = \ + run_node_mini_batch_predict(decoder, + emb, + dataloader, + device, + return_proba, + return_label) + assert labels is None or len(labels) == 1, \ + "In multi-task learning, for each training task, " \ + "we only support prediction on one node type." \ + "For multiple node types, please treat them as " \ + "different training tasks." + ntype = list(preds.keys())[0] + res[task_info.task_id] = (preds[ntype], labels[ntype] if labels is not None else None) + elif task_info.task_type in \ + [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + if dataloader is None: + # In cases when there is no validation or test set. + # set pred and labels to None + res[task_info.task_id] = (None, None) + else: + decoder = task_decoders[task_info.task_id] + preds, labels = \ + run_edge_mini_batch_predict(decoder, + emb, + dataloader, + device, + return_proba, + return_label) + assert labels is None or len(labels) == 1, \ + "In multi-task learning, for each training task, " \ + "we only support prediction on one edge type." \ + "For multiple edge types, please treat them as " \ + "different training tasks." 
+ etype = list(preds.keys())[0] + res[task_info.task_id] = (preds[etype], labels[etype] if labels is not None else None) + elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION: + if dataloader is None: + # In cases when there is no validation or test set. + res[task_info.task_id] = None + else: + decoder = task_decoders[task_info.task_id] + ranking = run_lp_mini_batch_predict(decoder, emb, dataloader, device) + res[task_info.task_id] = ranking + else: + raise TypeError("Unknown task %s", task_info) + + return res + From 4a3ec01a9f1bc3292512660e1b704e73b339bf77 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 23 May 2024 15:56:43 -0700 Subject: [PATCH 53/79] Add unitests --- tests/unit-tests/test_gnn.py | 574 ++++++++++++++++++++++++++++++++++- 1 file changed, 572 insertions(+), 2 deletions(-) diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index 4ac8a479bc..526f9c2f01 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -32,8 +32,13 @@ import dgl -from graphstorm.config import GSConfig +from graphstorm.config import GSConfig, TaskInfo from graphstorm.config import BUILTIN_LP_DOT_DECODER +from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION, + BUILTIN_TASK_LINK_PREDICTION) from graphstorm.model import GSNodeEncoderInputLayer, RelationalGCNEncoder from graphstorm.model import GSgnnNodeModel, GSgnnEdgeModel from graphstorm.model import GSLMNodeEncoderInputLayer, GSPureLMNodeInputLayer @@ -53,7 +58,7 @@ LinkPredictWeightedDistMultDecoder) from graphstorm.model.node_decoder import EntityRegression, EntityClassifier from graphstorm.dataloading import GSgnnData -from graphstorm.dataloading import GSgnnNodeDataLoader, GSgnnEdgeDataLoader +from graphstorm.dataloading import GSgnnNodeDataLoader, GSgnnEdgeDataLoader, GSgnnMultiTaskDataLoader from graphstorm.dataloading.dataset import prepare_batch_input from 
graphstorm import create_builtin_edge_gnn_model, create_builtin_node_gnn_model from graphstorm import create_builtin_lp_gnn_model @@ -67,8 +72,13 @@ from graphstorm.model.edge_gnn import (edge_mini_batch_predict, run_edge_mini_batch_predict, edge_mini_batch_gnn_predict) +from graphstorm.model.multitask_gnn import multi_task_mini_batch_predict from graphstorm.model.gnn_with_reconstruct import construct_node_feat, get_input_embeds_combined from graphstorm.model.utils import load_model, save_model +from graphstorm.model import GSgnnMultiTaskSharedEncoderModel +from graphstorm.dataloading import (GSgnnEdgeDataLoaderBase, + GSgnnLinkPredictionDataLoaderBase, + GSgnnNodeDataLoaderBase) from data_utils import generate_dummy_dist_graph, generate_dummy_dist_graph_multi_target_ntypes from data_utils import generate_dummy_dist_graph_reconstruct @@ -1714,7 +1724,567 @@ def check_predict(mock_get_labels, return_dict): th.distributed.destroy_process_group() +class DummyNCDecoder(nn.Module): + + def forward(self, inputs): + return inputs + + def predict(self, inputs): + return inputs + + def predict_proba(self, inputs): + return inputs * 2 + +class DummyNRDecoder(nn.Module): + + def forward(self, inputs): + return inputs + + def predict(self, inputs): + return inputs + + def predict_proba(self, inputs): + return inputs * 2 + +class DummyECDecoder(nn.Module): + + def forward(self, g, h, e_h): + return h["n0"] + + def predict(self, g, h, e_h): + return h["n0"] + + def predict_proba(self, g, h, e_h): + return h["n0"] * 2 + +class DummyERDecoder(nn.Module): + + def forward(self, g, h, e_h): + return h["n0"] * 2 + + def predict(self, g, h, e_h): + return h["n0"] + + def predict_proba(self, g, h, e_h): + return h["n0"] * 2 + +class DummyLPDecoder(nn.Module): + + def forward(self, g, h, e_h=None): + return h + + +def test_multi_task_forward(): + mt_model = GSgnnMultiTaskSharedEncoderModel(0.1) + + def pred_los_func(logits, labels): + return logits - labels + def 
pred_lp_loss_func(pos_score, neg_score): + return pos_score["n0"] + neg_score["n0"] + mt_model.add_task("nc_task", + BUILTIN_TASK_NODE_CLASSIFICATION, + DummyNCDecoder(), + pred_los_func) + + mt_model.add_task("nr_task", + BUILTIN_TASK_NODE_REGRESSION, + DummyNRDecoder(), + pred_los_func) + + mt_model.add_task("ec_task", + BUILTIN_TASK_EDGE_CLASSIFICATION, + DummyECDecoder(), + pred_los_func) + + mt_model.add_task("er_task", + BUILTIN_TASK_EDGE_REGRESSION, + DummyERDecoder(), + pred_los_func) + + mt_model.add_task("lp_task", + BUILTIN_TASK_LINK_PREDICTION, + DummyLPDecoder(), + pred_lp_loss_func) + + @patch.object(GSgnnMultiTaskSharedEncoderModel, 'comput_input_embed') + @patch.object(GSgnnMultiTaskSharedEncoderModel, 'compute_embed_step') + @patch.object(GSgnnMultiTaskSharedEncoderModel, 'normalize_node_embs') + def check_forward(mock_normalize_node_embs, + mock_compute_emb, + mock_input_embed): + + def normalize_size_effect_func(embs): + return embs + + def compute_side_effect_func(blocks, node_feats, input_nodes): + return input_nodes + + def input_embed_side_effect_func(input_nodes, node_feats): + return input_nodes + + mock_normalize_node_embs.side_effect = normalize_size_effect_func + mock_compute_emb.side_effect = compute_side_effect_func + mock_input_embed.side_effect = input_embed_side_effect_func + + ### blocks is None (no GNN setting) + # NC task + task_id = "nc_task" + blocks = None + input_nodes = {"n0": th.randint(5, (10,))} + labels = {"n0": th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) + + # NR task + task_id = "nr_task" + blocks = None + input_nodes = {"n0": th.rand((10,))} + labels = {"n0": th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) + + # EC task + task_id = 
"ec_task" + blocks = None + input_nodes = {"n0": th.randint(5, (10,))} + labels = {("n0", "r1", "n1"): th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]-labels[("n0", "r1", "n1")]).numpy()) + + # ER task + task_id = "er_task" + blocks = None + input_nodes = {"n0": th.rand((10,))} + labels = {("n0", "r1", "n1"): th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]*2-labels[("n0", "r1", "n1")]).numpy()) + + # LP task + task_id = "lp_task" + blocks = None + input_nodes = {"n0": th.rand((10,))} + mini_batch = mini_batch = ((blocks, None, None, input_nodes), (None, None, None, None)) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]*2).numpy()) + + ### blocks is a list (GNN setting) + # NC task + task_id = "nc_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.randint(5, (10,))} + labels = {"n0": th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) + + # NR task + task_id = "nr_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.rand((10,))} + labels = {"n0": th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) + + # EC task + task_id = "ec_task" + blocks = [None, None] # trick mt_model there are two gnn layers. 
+ input_nodes = {"n0": th.randint(5, (10,))} + labels = {("n0", "r1", "n1"): th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]-labels[("n0", "r1", "n1")]).numpy()) + + # ER task + task_id = "er_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.rand((10,))} + labels = {("n0", "r1", "n1"): th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]*2-labels[("n0", "r1", "n1")]).numpy()) + + # LP task + task_id = "lp_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.rand((10,))} + mini_batch = mini_batch = ((blocks, None, None, input_nodes), (None, None, None, None)) + loss = mt_model(task_id, mini_batch) + assert_equal(loss.numpy(), (input_nodes["n0"]*2).numpy()) + + + check_forward() + +def test_multi_task_predict(): + mt_model = GSgnnMultiTaskSharedEncoderModel(0.1) + + def pred_los_func(logits, labels): + return logits - labels + def pred_lp_loss_func(pos_score, neg_score): + return pos_score["n0"] + neg_score["n0"] + mt_model.add_task("nc_task", + BUILTIN_TASK_NODE_CLASSIFICATION, + DummyNCDecoder(), + pred_los_func) + + mt_model.add_task("nr_task", + BUILTIN_TASK_NODE_REGRESSION, + DummyNRDecoder(), + pred_los_func) + + mt_model.add_task("ec_task", + BUILTIN_TASK_EDGE_CLASSIFICATION, + DummyECDecoder(), + pred_los_func) + + mt_model.add_task("er_task", + BUILTIN_TASK_EDGE_REGRESSION, + DummyERDecoder(), + pred_los_func) + + mt_model.add_task("lp_task", + BUILTIN_TASK_LINK_PREDICTION, + DummyLPDecoder(), + pred_lp_loss_func) + + @patch.object(GSgnnMultiTaskSharedEncoderModel, 'comput_input_embed') + @patch.object(GSgnnMultiTaskSharedEncoderModel, 'compute_embed_step') + @patch.object(GSgnnMultiTaskSharedEncoderModel, 
'normalize_node_embs') + def check_forward(mock_normalize_node_embs, + mock_compute_emb, + mock_input_embed): + + def normalize_size_effect_func(embs): + return embs + + def compute_side_effect_func(blocks, node_feats, input_nodes): + return input_nodes + + def input_embed_side_effect_func(input_nodes, node_feats): + return input_nodes + + mock_normalize_node_embs.side_effect = normalize_size_effect_func + mock_compute_emb.side_effect = compute_side_effect_func + mock_input_embed.side_effect = input_embed_side_effect_func + + ### blocks is None (no GNN setting) + # NC task + task_id = "nc_task" + blocks = None + input_nodes = {"n0": th.randint(5, (10,))} + labels = {"n0": th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + pred = mt_model.predict(task_id, mini_batch) + assert_equal(pred["n0"].numpy(), (input_nodes["n0"]).numpy()) + + # NR task + task_id = "nr_task" + blocks = None + input_nodes = {"n0": th.rand((10,))} + labels = {"n0": th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + pred = mt_model.predict(task_id, mini_batch) + assert_equal(pred["n0"].numpy(), (input_nodes["n0"]).numpy()) + + # EC task + task_id = "ec_task" + blocks = None + input_nodes = {"n0": th.randint(5, (10,))} + labels = {("n0", "r1", "n1"): th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + pred = mt_model.predict(task_id, mini_batch) + assert_equal(pred.numpy(), (input_nodes["n0"]).numpy()) + + # ER task + task_id = "er_task" + blocks = None + input_nodes = {"n0": th.rand((10,))} + labels = {("n0", "r1", "n1"): th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + pred = mt_model.predict(task_id, mini_batch) + assert_equal(pred.numpy(), (input_nodes["n0"]).numpy()) + + # LP task + task_id = "lp_task" + blocks = None + input_nodes = {"n0": th.rand((10,))} + mini_batch = mini_batch = ((blocks, None, None, input_nodes), (None, None, None, None)) + 
pred = mt_model.predict(task_id, mini_batch) + assert pred is None + + ### blocks is a list (GNN setting) and call return_proba=True + # NC task + task_id = "nc_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.randint(5, (10,))} + labels = {"n0": th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + pred = mt_model.predict(task_id, mini_batch, return_proba=True) + assert_equal(pred["n0"].numpy(), (input_nodes["n0"]*2).numpy()) + + # NR task + task_id = "nr_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.rand((10,))} + labels = {"n0": th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), labels) + pred = mt_model.predict(task_id, mini_batch, return_proba=True) + assert_equal(pred["n0"].numpy(), (input_nodes["n0"]*2).numpy()) + + # EC task + task_id = "ec_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.randint(5, (10,))} + labels = {("n0", "r1", "n1"): th.randint(5, (10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + pred = mt_model.predict(task_id, mini_batch, return_proba=True) + assert_equal(pred.numpy(), (input_nodes["n0"]*2).numpy()) + + # ER task + task_id = "er_task" + blocks = [None, None] # trick mt_model there are two gnn layers. + input_nodes = {"n0": th.rand((10,))} + labels = {("n0", "r1", "n1"): th.rand((10,))} + mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) + pred = mt_model.predict(task_id, mini_batch, return_proba=True) + assert_equal(pred.numpy(), (input_nodes["n0"]*2).numpy()) + + # LP task + task_id = "lp_task" + blocks = [None, None] # trick mt_model there are two gnn layers. 
+ input_nodes = {"n0": th.rand((10,))} + mini_batch = mini_batch = ((blocks, None, None, input_nodes), (None, None, None, None)) + pred = mt_model.predict(task_id, mini_batch, return_proba=True) + assert pred is None + + check_forward() + +class DummyGSgnnNodeDataLoader(GSgnnNodeDataLoaderBase): + def __init__(self): + pass # do nothing + + def __len__(self): + return 10 + + def __iter__(self): + return self + +class DummyGSgnnEdgeDataLoader(GSgnnEdgeDataLoaderBase): + def __init__(self): + pass # do nothing + + def __len__(self): + return 10 + + def __iter__(self): + return self + +class DummyGSgnnLinkPredictionDataLoader(GSgnnLinkPredictionDataLoaderBase): + def __init__(self): + pass # do nothing + + def __len__(self): + return 10 + + def __iter__(self): + return self + +def test_multi_task_mini_batch_predict(): + mt_model = GSgnnMultiTaskSharedEncoderModel(0.1) + + def pred_los_func(logits, labels): + return logits - labels + def pred_lp_loss_func(pos_score, neg_score): + return pos_score["n0"] + neg_score["n0"] + mt_model.add_task("nc_task", + BUILTIN_TASK_NODE_CLASSIFICATION, + DummyNCDecoder(), + pred_los_func) + + mt_model.add_task("nr_task", + BUILTIN_TASK_NODE_REGRESSION, + DummyNRDecoder(), + pred_los_func) + + mt_model.add_task("ec_task", + BUILTIN_TASK_EDGE_CLASSIFICATION, + DummyECDecoder(), + pred_los_func) + + mt_model.add_task("er_task", + BUILTIN_TASK_EDGE_REGRESSION, + DummyERDecoder(), + pred_los_func) + + mt_model.add_task("lp_task", + BUILTIN_TASK_LINK_PREDICTION, + DummyLPDecoder(), + pred_lp_loss_func) + + tast_info_nc = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION, + task_id='nc_task', + task_config=None) + nc_dataloader = DummyGSgnnNodeDataLoader() + tast_info_nr = TaskInfo(task_type=BUILTIN_TASK_NODE_REGRESSION, + task_id='nr_task', + task_config=None) + nr_dataloader = DummyGSgnnNodeDataLoader() + tast_info_ec = TaskInfo(task_type=BUILTIN_TASK_EDGE_CLASSIFICATION, + task_id='ec_task', + task_config=None) + ec_dataloader = 
DummyGSgnnEdgeDataLoader() + tast_info_er = TaskInfo(task_type=BUILTIN_TASK_EDGE_REGRESSION, + task_id='er_task', + task_config=None) + er_dataloader = DummyGSgnnEdgeDataLoader() + tast_info_lp = TaskInfo(task_type=BUILTIN_TASK_LINK_PREDICTION, + task_id='lp_task', + task_config=None) + lp_dataloader = DummyGSgnnLinkPredictionDataLoader() + task_infos = [tast_info_nc, tast_info_nr, tast_info_ec, tast_info_er, tast_info_lp] + dataloaders = [nc_dataloader, nr_dataloader, ec_dataloader, er_dataloader, lp_dataloader] + + node_pred = {"n0": th.arange(10)} + node_prob = {"n0": th.arange(10)/10} + node_label = {"n0": th.arange(10)} + edge_pred = {("n0", "r0", "n1"): th.arange(5)} + edge_prob = {("n0", "r0", "n1"): th.arange(5)/10} + edge_label = {("n0", "r0", "n1"): th.arange(5)} + lp_pred = {("n0", "r0", "n1"): th.arange(5)/10, + ("n0", "r0", "n2"): th.arange(5)/20} + + def run_node_mini_batch_predict_side_func(decoder, emb, loader, device, return_prob, return_label): + pred = node_pred + label = None + if return_prob: + pred = node_prob + if return_label: + label = node_label + + return pred, label + + def run_edge_mini_batch_predict_side_func(decoder, emb, loader, device, return_prob, return_label): + pred = edge_pred + label = None + if return_prob: + pred = edge_prob + if return_label: + label = edge_label + + return pred, label + + def run_lpmini_batch_predict_side_func(decoder, emb, loader, device): + return lp_pred + + @patch("graphstorm.model.multitask_gnn.run_node_mini_batch_predict", side_effect = run_node_mini_batch_predict_side_func) + @patch("graphstorm.model.multitask_gnn.run_edge_mini_batch_predict", side_effect = run_edge_mini_batch_predict_side_func) + @patch("graphstorm.model.multitask_gnn.run_lp_mini_batch_predict", side_effect = run_lpmini_batch_predict_side_func) + def check_forward(mock_run_lp_mini_batch_predict, + mock_run_edge_mini_batch_predict, + mock_run_node_mini_batch_predict): + + mt_dataloader = GSgnnMultiTaskDataLoader(None, task_infos, 
dataloaders) + res = multi_task_mini_batch_predict(mt_model, + None, + mt_dataloader, + device=th.device('cpu'), + return_proba=False, + return_label=False) + assert len(res["nc_task"]) == 2 + assert_equal(res["nc_task"][0].numpy(), node_pred["n0"].numpy()) + assert res["nc_task"][1] is None + assert len(res["nr_task"]) == 2 + assert_equal(res["nr_task"][0].numpy(), node_pred["n0"].numpy()) + assert res["nr_task"][1] is None + assert len(res["ec_task"]) == 2 + assert_equal(res["ec_task"][0].numpy(), edge_pred[("n0", "r0", "n1")].numpy()) + assert res["ec_task"][1] is None + assert len(res["er_task"]) == 2 + assert_equal(res["er_task"][0].numpy(), edge_pred[("n0", "r0", "n1")].numpy()) + assert res["er_task"][1] is None + assert_equal(res["lp_task"][("n0", "r0", "n1")].numpy(), lp_pred[("n0", "r0", "n1")].numpy()) + assert_equal(res["lp_task"][("n0", "r0", "n2")].numpy(), lp_pred[("n0", "r0", "n2")].numpy()) + + res = multi_task_mini_batch_predict(mt_model, + None, + mt_dataloader, + device=th.device('cpu'), + return_proba=True, + return_label=False) + assert len(res["nc_task"]) == 2 + assert_equal(res["nc_task"][0].numpy(), node_prob["n0"].numpy()) + assert res["nc_task"][1] is None + assert len(res["nr_task"]) == 2 + assert_equal(res["nr_task"][0].numpy(), node_prob["n0"].numpy()) + assert res["nr_task"][1] is None + assert len(res["ec_task"]) == 2 + assert_equal(res["ec_task"][0].numpy(), edge_prob[("n0", "r0", "n1")].numpy()) + assert res["ec_task"][1] is None + assert len(res["er_task"]) == 2 + assert_equal(res["er_task"][0].numpy(), edge_prob[("n0", "r0", "n1")].numpy()) + assert res["er_task"][1] is None + assert_equal(res["lp_task"][("n0", "r0", "n1")].numpy(), lp_pred[("n0", "r0", "n1")].numpy()) + assert_equal(res["lp_task"][("n0", "r0", "n2")].numpy(), lp_pred[("n0", "r0", "n2")].numpy()) + + res = multi_task_mini_batch_predict(mt_model, + None, + mt_dataloader, + device=th.device('cpu'), + return_proba=False, + return_label=True) + assert 
len(res["nc_task"]) == 2 + assert_equal(res["nc_task"][0].numpy(), node_pred["n0"].numpy()) + assert_equal(res["nc_task"][1].numpy(), node_label["n0"].numpy()) + assert len(res["nr_task"]) == 2 + assert_equal(res["nr_task"][0].numpy(), node_pred["n0"].numpy()) + assert_equal(res["nr_task"][1].numpy(), node_label["n0"].numpy()) + assert len(res["ec_task"]) == 2 + assert_equal(res["ec_task"][0].numpy(), edge_pred[("n0", "r0", "n1")].numpy()) + assert_equal(res["ec_task"][0].numpy(), edge_label[("n0", "r0", "n1")].numpy()) + assert len(res["er_task"]) == 2 + assert_equal(res["er_task"][0].numpy(), edge_pred[("n0", "r0", "n1")].numpy()) + assert_equal(res["ec_task"][0].numpy(), edge_label[("n0", "r0", "n1")].numpy()) + assert_equal(res["lp_task"][("n0", "r0", "n1")].numpy(), lp_pred[("n0", "r0", "n1")].numpy()) + assert_equal(res["lp_task"][("n0", "r0", "n2")].numpy(), lp_pred[("n0", "r0", "n2")].numpy()) + + + new_dataloaders = [nc_dataloader, None, ec_dataloader, None, None] + mt_dataloader = GSgnnMultiTaskDataLoader(None, task_infos, new_dataloaders) + + res = multi_task_mini_batch_predict(mt_model, + None, + mt_dataloader, + device=th.device('cpu'), + return_proba=False, + return_label=False) + assert len(res["nc_task"]) == 2 + assert_equal(res["nc_task"][0].numpy(), node_pred["n0"].numpy()) + assert res["nc_task"][1] is None + assert len(res["nr_task"]) == 2 + assert res["nr_task"][0] is None + assert res["nr_task"][1] is None + assert len(res["ec_task"]) == 2 + assert_equal(res["ec_task"][0].numpy(), edge_pred[("n0", "r0", "n1")].numpy()) + assert res["ec_task"][1] is None + assert len(res["er_task"]) == 2 + assert res["er_task"][0] is None + assert res["er_task"][1] is None + assert res["lp_task"] is None + + + + check_forward() + + if __name__ == '__main__': + test_multi_task_forward() + test_multi_task_predict() + test_multi_task_mini_batch_predict() + test_lm_rgcn_node_prediction_with_reconstruct() test_rgcn_node_prediction_with_reconstruct(True) 
test_rgcn_node_prediction_with_reconstruct(False) From 46125cadfb6ab1849d7cbbb870f8b028028a5669 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 23 May 2024 15:57:51 -0700 Subject: [PATCH 54/79] Update --- python/graphstorm/model/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/graphstorm/model/__init__.py b/python/graphstorm/model/__init__.py index 18a741e200..34cce011ce 100644 --- a/python/graphstorm/model/__init__.py +++ b/python/graphstorm/model/__init__.py @@ -35,6 +35,9 @@ GSgnnLinkPredictionModelBase, GSgnnLinkPredictionModelInterface, run_lp_mini_batch_predict) +from .multitask_gnn import (GSgnnMultiTaskModelInterface, + GSgnnMultiTaskSharedEncoderModel) +from .multitask_gnn import multi_task_mini_batch_predict from .rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer from .rgat_encoder import RelationalGATEncoder, RelationalAttLayer from .sage_encoder import SAGEEncoder, SAGEConv From 24039306bc07e8ac6e5f8716a1d068f1e9da61da Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 23 May 2024 15:59:52 -0700 Subject: [PATCH 55/79] update dataloader --- python/graphstorm/dataloading/dataloading.py | 36 +++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index c9d45b28da..6a123caa67 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -1706,12 +1706,15 @@ def __init__(self, dataset, task_infos, task_dataloaders): # check dataloaders lens = [] for task_info, dataloader in zip(task_infos, task_dataloaders): + # For evaluation and testing, we allow some of the val_dataloaders or test_dataloaders + # are empty (None). 
assert isinstance(dataloader, (GSgnnEdgeDataLoaderBase, GSgnnLinkPredictionDataLoaderBase, - GSgnnNodeDataLoaderBase)), \ + GSgnnNodeDataLoaderBase)) or dataloader is None, \ "The task data loader should be an instance of GSgnnEdgeDataLoaderBase, " \ - "GSgnnLinkPredictionDataLoaderBase or GSgnnNodeDataLoaderBase" - num_iters = len(dataloader) + "GSgnnLinkPredictionDataLoaderBase or GSgnnNodeDataLoaderBase" \ + f"But get {type(dataloader)}" + num_iters = len(dataloader) if dataloader is not None else 0 lens.append(num_iters) logging.debug("Task %s has number of iterations of %d", task_info, num_iters) @@ -1728,7 +1731,8 @@ def _reset_loader(self): """ reset the dataloaders """ for dataloader in self._dataloaders: - iter(dataloader) + if dataloader is not None: + iter(dataloader) self._num_iters = 0 def __iter__(self): @@ -1747,6 +1751,19 @@ def __next__(self): # call __next__ of each dataloader mini_batches = [] for task_info, dataloader in zip(self._task_infos, self._dataloaders): + if dataloader is None: + # The dataloader is None + logging.warning("The dataloader of %s is None. " + "Please check whether the coresponding " + "train/val/test mask(s) are missing." + "If you are calling iter(mt_dataloader) for validation " + "or testing, we suggest you to use " + "mt_dataloader.dataloaders to get task specific " + "dataloaders and call the corresponding evaluators " + "task by task", task_info.task_id) + mini_batches.append((task_info, None)) + continue + try: mini_batch = next(dataloader) except StopIteration: @@ -1789,6 +1806,17 @@ def task_infos(self): # useful for conducting validation scores and test scores. return self._task_infos + @property + def fanout(self): + """ The fanout of each GNN layers of each dataloader + + Returns + ------- + list or a dict of list : the fanouts for each GNN layer. 
+ """ + fanouts = [dataloader.fanout if dataloader is not None else None for dataloader in self.dataloaders] + return fanouts + ####################### Distillation ############################# From caa851af33470bff6402f753a99823e0e763d848 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 23 May 2024 16:35:27 -0700 Subject: [PATCH 56/79] Fix lint --- python/graphstorm/dataloading/dataloading.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index 6a123caa67..b035e5df43 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -1814,7 +1814,8 @@ def fanout(self): ------- list or a dict of list : the fanouts for each GNN layer. """ - fanouts = [dataloader.fanout if dataloader is not None else None for dataloader in self.dataloaders] + fanouts = [dataloader.fanout if dataloader is not None \ + else None for dataloader in self.dataloaders] return fanouts From e3e33f85aa5f88e06590cf68cb718af8194681d2 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 23 May 2024 17:57:39 -0700 Subject: [PATCH 57/79] Fix lint --- python/graphstorm/model/multitask_gnn.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index f762886e3a..d84f5f7717 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -51,7 +51,7 @@ def forward(self, task_id, mini_batch): task_id: str ID of the task. mini_batch: tuple - Mini-batch info + Mini-batch info. Return @@ -60,7 +60,7 @@ def forward(self, task_id, mini_batch): """ @abc.abstractmethod - def predict(self, task_info, mini_batch): + def predict(self, task_id, mini_batch): """ The forward function for multi-task prediction. 
This method is used for inference, It runs model forword @@ -69,10 +69,10 @@ def predict(self, task_info, mini_batch): Parameters ---------- - task_info: TaskInfo - task meta information + task_id: str + Task ID. mini_batch: tuple - mini-batch info + Mini-batch info. Returns ------- @@ -199,8 +199,7 @@ def forward(self, task_id, mini_batch): pred_loss = loss_func(pos_score, neg_score) return pred_loss else: - raise TypeError("Unknow task type %s", task_type) - + raise TypeError(f"Unknow task type {task_type}") def predict(self, task_id, mini_batch, return_proba=False): """ The forward function for multi-task inference @@ -245,7 +244,7 @@ def predict(self, task_id, mini_batch, return_proba=False): logging.warning("Prediction for link prediction is not implemented") return None else: - raise TypeError("Unknow task type %s", task_type) + raise TypeError(f"Unknow task type {task_type}") def multi_task_mini_batch_predict( model, emb, loader, device, return_proba=True, return_label=False): @@ -297,7 +296,8 @@ def multi_task_mini_batch_predict( "For multiple node types, please treat them as " \ "different training tasks." ntype = list(preds.keys())[0] - res[task_info.task_id] = (preds[ntype], labels[ntype] if labels is not None else None) + res[task_info.task_id] = (preds[ntype], labels[ntype] \ + if labels is not None else None) elif task_info.task_type in \ [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: if dataloader is None: @@ -319,7 +319,8 @@ def multi_task_mini_batch_predict( "For multiple edge types, please treat them as " \ "different training tasks." etype = list(preds.keys())[0] - res[task_info.task_id] = (preds[etype], labels[etype] if labels is not None else None) + res[task_info.task_id] = (preds[etype], labels[etype] \ + if labels is not None else None) elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION: if dataloader is None: # In cases when there is no validation or test set. 
@@ -329,7 +330,6 @@ def multi_task_mini_batch_predict( ranking = run_lp_mini_batch_predict(decoder, emb, dataloader, device) res[task_info.task_id] = ranking else: - raise TypeError("Unknown task %s", task_info) + raise TypeError(f"Unknown task {task_info}") return res - From 36f4c60ade900e26eb7d893262e82bbcf04a87d5 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Thu, 23 May 2024 22:57:36 -0700 Subject: [PATCH 58/79] update --- tests/end2end-tests/data_process/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/end2end-tests/data_process/test.sh b/tests/end2end-tests/data_process/test.sh index df3c010c4e..1bc9caeb7d 100644 --- a/tests/end2end-tests/data_process/test.sh +++ b/tests/end2end-tests/data_process/test.sh @@ -121,7 +121,7 @@ python3 $GS_HOME/tests/end2end-tests/data_process/test_multitask_data.py --graph error_and_exit $? -echo "********* Test the DistDGL graph format with multi mask support from saved config g********" +echo "********* Test the DistDGL graph format with multi mask support from saved config ********" python3 -m graphstorm.gconstruct.construct_graph --conf-file /tmp/multitask_test_data/test_multitask_data_transform_new.conf --num-processes 2 --output-dir /tmp/test_partition2 --graph-name test --add-reverse-edges error_and_exit $? 
From 4a0a7dfe67addd56c4eb5209ab393eaca2f3e742 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Fri, 24 May 2024 00:32:57 -0700 Subject: [PATCH 59/79] Fix DDP bug --- python/graphstorm/model/multitask_gnn.py | 52 +++++++++++- python/graphstorm/trainer/mt_trainer.py | 84 ++++++++----------- .../end2end-tests/graphstorm-mt/mgpu_test.sh | 7 +- 3 files changed, 90 insertions(+), 53 deletions(-) diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index d84f5f7717..df03c121ee 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -137,15 +137,63 @@ def task_decoders(self): """ return self._decoder + def _run_mini_batch(self, task_info, mini_batch): + """ Run mini_batch forward + """ + if task_info.task_type in \ + [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + # Order follow GSgnnNodeModelInterface.forward + blocks, input_feats, edge_feats, lbl,input_nodes = mini_batch + loss = self._forward(task_info.task_id, + (blocks, input_feats, edge_feats, input_nodes), + lbl) + + elif task_info.task_type in \ + [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + # Order follow GSgnnEdgeModelInterface.forward + blocks, target_edges, node_feats, edge_feats, \ + edge_decoder_feats, lbl, input_nodes = mini_batch + loss = self._forward(task_info.task_id, + (blocks, node_feats, None, input_nodes), + (target_edges, edge_decoder_feats, lbl)) + + elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION: + # Order follow GSgnnLinkPredictionModelInterface.forward + blocks, pos_graph, neg_graph, node_feats, edge_feats, \ + pos_edge_feats, neg_edge_feats, input_nodes = mini_batch + + loss = self._forward(task_info.task_id, + (blocks, node_feats, edge_feats, input_nodes), + (pos_graph, neg_graph, pos_edge_feats, neg_edge_feats)) + else: + raise TypeError("Unknown task %s", task_info) + + return loss, task_info.task_config.task_weight + + def forward(self, 
task_mini_batches): + losses = [] + for (task_info, mini_batch) in task_mini_batches: + loss, weight = self._run_mini_batch(task_info, mini_batch) + losses.append((loss, weight)) + + reg_loss = th.tensor(0.).to(losses[0][0].device) + for d_para in self.get_dense_params(): + reg_loss += d_para.square().sum() + alpha_l2norm = self.alpha_l2norm + + mt_loss = reg_loss * alpha_l2norm + for loss, weight in losses: + mt_loss += loss * weight + return mt_loss + # pylint: disable=unused-argument - def forward(self, task_id, mini_batch): + def _forward(self, task_id, encoder_data, decoder_data): """ The forward function for multi-task learning """ assert task_id in self.task_pool, \ f"Unknown task: {task_id} in multi-task learning." \ f"Existing tasks are {self.task_pool.keys()}" - encoder_data, decoder_data = mini_batch # message passing graph, node features, edge features, seed nodes blocks, node_feats, _, input_nodes = encoder_data if blocks is None or len(blocks) == 0: diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index 8df2091e15..6cdba8ae6e 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -38,7 +38,7 @@ from ..utils import sys_tracker, rt_profiler, print_mem, get_rank from ..utils import barrier, is_distributed -def run_node_mini_batch(model, data, task_info, mini_batch, device): +def prepare_node_mini_batch(data, task_info, mini_batch, device): """ Run node mini_batch forward """ g = data.g @@ -52,19 +52,21 @@ def run_node_mini_batch(model, data, task_info, mini_batch, device): input_feats = data.get_node_feats(input_nodes, nfeat_fields, device) lbl = data.get_node_feats(seeds, label_field, device) blocks = [block.to(device) for block in blocks] - # TODO: we don't support edge features for now. 
- loss = model(task_info.task_id, ((blocks, input_feats, None, input_nodes), lbl)) - return loss, task_info.task_config.task_weight + # Order follow GSgnnNodeModelInterface.forward + # TODO: we don't support edge features for now. + return (blocks, input_feats, None, lbl, input_nodes) -def run_edge_mini_batch(model, data, task_info, mini_batch, device): +def prepare_edge_mini_batch(data, task_info, mini_batch, device): + """ + """ input_nodes, batch_graph, blocks = mini_batch if not isinstance(input_nodes, dict): assert len(batch_graph.ntypes) == 1 input_nodes = {batch_graph.ntypes[0]: input_nodes} nfeat_fields = task_info.dataloader.node_feat_fields - input_feats = data.get_node_feats(input_nodes, nfeat_fields, device) + node_feats = data.get_node_feats(input_nodes, nfeat_fields, device) if task_info.dataloader.decoder_edge_feat_fields is not None: input_edges = {etype: batch_graph.edges[etype].data[dgl.EID] \ @@ -90,13 +92,12 @@ def run_edge_mini_batch(model, data, task_info, mini_batch, device): batch_graph = batch_graph.to(device) rt_profiler.record('train_graph2GPU') + # Order follow GSgnnEdgeModelInterface.forward # TODO(zhengda) we don't support edge features for now. 
- loss = model(task_info.task_id, - ((blocks, input_feats, None, input_nodes), - (batch_graph, edge_decoder_feats, lbl))) - return loss, task_info.task_config.task_weight + return (blocks, batch_graph, node_feats, None, + edge_decoder_feats, lbl, input_nodes) -def run_link_predict_mini_batch(model, data, task_info, mini_batch, device): +def prepare_link_predict_mini_batch(data, task_info, mini_batch, device): input_nodes, pos_graph, neg_graph, blocks = mini_batch if not isinstance(input_nodes, dict): @@ -104,7 +105,7 @@ def run_link_predict_mini_batch(model, data, task_info, mini_batch, device): input_nodes = {pos_graph.ntypes[0]: input_nodes} nfeat_fields = task_info.dataloader.node_feat_fields - input_feats = data.get_node_feats(input_nodes, nfeat_fields, device) + node_feats = data.get_node_feats(input_nodes, nfeat_fields, device) if task_info.dataloader.pos_graph_feat_fields is not None: input_edges = {etype: pos_graph.edges[etype].data[dgl.EID] \ @@ -119,12 +120,8 @@ def run_link_predict_mini_batch(model, data, task_info, mini_batch, device): neg_graph = neg_graph.to(device) blocks = [blk.to(device) for blk in blocks] - # TODO: we don't support edge features for now. - loss = model(task_info.task_id, - ((blocks, input_feats, None, input_nodes), - (pos_graph, neg_graph,pos_graph_feats, None))) - return loss, task_info.task_config.task_weight - + return (blocks, pos_graph, neg_graph, node_feats, None, \ + pos_graph_feats, None, input_nodes) class GSgnnMultiTaskLearningTrainer(GSgnnTrainer): r""" A trainer for multi-task learning @@ -149,8 +146,8 @@ def __init__(self, model, topk_model_to_save=1): assert isinstance(model, GSgnnMultiTaskModelInterface) and isinstance(model, GSgnnModelBase), \ "The input model is not a GSgnnModel model. Please implement GSgnnModelBase." 
- def _run_mini_batch(self, data, model, task_info, mini_batch, device): - """ run mini batch for a single task + def _prepare_mini_batch(self, data, task_info, mini_batch, device): + """ prepare mini batch for a single task Parameters ---------- @@ -167,28 +164,25 @@ def _run_mini_batch(self, data, model, task_info, mini_batch, device): Return ------ - loss + tuple: mini-batch """ if task_info.task_type in \ [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - return run_node_mini_batch(model, - data, - task_info, - mini_batch, - device) + return prepare_node_mini_batch(data, + task_info, + mini_batch, + device) elif task_info.task_type in \ [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - return run_edge_mini_batch(model, - data, - task_info, - mini_batch, - device) + return prepare_edge_mini_batch(data, + task_info, + mini_batch, + device) elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION: - return run_link_predict_mini_batch(model, - data, - task_info, - mini_batch, - device) + return prepare_link_predict_mini_batch(data, + task_info, + mini_batch, + device) else: raise TypeError("Unknown task %s", task_info) @@ -285,19 +279,13 @@ def fit(self, train_loader, rt_profiler.record('train_sample') total_steps += 1 - losses = [] + mini_batches = [] for (task_info, mini_batch) in task_mini_batches: - loss, weight = self._run_mini_batch(data, model, task_info, mini_batch, device) - losses.append((loss, weight)) + mini_batches.append((task_info, \ + self._prepare_mini_batch(data, task_info, mini_batch, device))) - reg_loss = th.tensor(0.).to(device) - for d_para in model.module.get_dense_params(): - reg_loss += d_para.square().sum() - alpha_l2norm = model.module.alpha_l2norm + loss = model(mini_batches) - mt_loss = reg_loss * alpha_l2norm - for loss, weight in losses: - mt_loss += loss * weight rt_profiler.record('train_forward') self.optimizer.zero_grad() loss.backward() @@ -312,7 +300,7 @@ def fit(self, train_loader, if i % 20 
== 0 and get_rank() == 0: rt_profiler.print_stats() logging.info("Epoch %05d | Batch %03d | Train Loss: %.4f | Time: %.4f", - epoch, i, mt_loss.item(), time.time() - batch_tic) + epoch, i, loss.item(), time.time() - batch_tic) val_score = None if self.evaluator is not None and \ @@ -358,8 +346,6 @@ def fit(self, train_loader, rt_profiler.print_stats() barrier() - - rt_profiler.save_profile() print_mem(device) if get_rank() == 0 and self.evaluator is not None: diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh index dfbf2b18ee..d8a3cdcdd6 100644 --- a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh @@ -26,5 +26,8 @@ error_and_exit () { df /dev/shm -h -echo "**************dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, multi-task, save model" -python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True \ No newline at end of file +echo "**************[Multi-task] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model" +python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True + +echo 
"**************[Multi-task with learnable embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model" +python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-node-embeddings True \ No newline at end of file From 424b4b2caa6e3e279581fe2fceae742c5a564942 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Fri, 24 May 2024 00:40:25 -0700 Subject: [PATCH 60/79] Update --- tests/unit-tests/test_gnn.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index 526f9c2f01..20e6129390 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -1832,8 +1832,7 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = None input_nodes = {"n0": th.randint(5, (10,))} labels = {"n0": th.randint(5, (10,))} - mini_batch = ((blocks, None, None, input_nodes), labels) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), labels) assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) # NR task @@ -1841,8 +1840,7 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = None input_nodes = {"n0": th.rand((10,))} labels = {"n0": th.rand((10,))} - mini_batch = ((blocks, None, None, input_nodes), labels) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), labels) assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) # EC task @@ -1850,8 +1848,7 @@ def 
input_embed_side_effect_func(input_nodes, node_feats): blocks = None input_nodes = {"n0": th.randint(5, (10,))} labels = {("n0", "r1", "n1"): th.randint(5, (10,))} - mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), (None, None, labels)) assert_equal(loss.numpy(), (input_nodes["n0"]-labels[("n0", "r1", "n1")]).numpy()) # ER task @@ -1859,16 +1856,14 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = None input_nodes = {"n0": th.rand((10,))} labels = {("n0", "r1", "n1"): th.rand((10,))} - mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), (None, None, labels)) assert_equal(loss.numpy(), (input_nodes["n0"]*2-labels[("n0", "r1", "n1")]).numpy()) # LP task task_id = "lp_task" blocks = None input_nodes = {"n0": th.rand((10,))} - mini_batch = mini_batch = ((blocks, None, None, input_nodes), (None, None, None, None)) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), (None, None, None, None)) assert_equal(loss.numpy(), (input_nodes["n0"]*2).numpy()) ### blocks is a list (GNN setting) @@ -1877,8 +1872,7 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = [None, None] # trick mt_model there are two gnn layers. input_nodes = {"n0": th.randint(5, (10,))} labels = {"n0": th.randint(5, (10,))} - mini_batch = ((blocks, None, None, input_nodes), labels) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), labels) assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) # NR task @@ -1886,8 +1880,7 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = [None, None] # trick mt_model there are two gnn layers. 
input_nodes = {"n0": th.rand((10,))} labels = {"n0": th.rand((10,))} - mini_batch = ((blocks, None, None, input_nodes), labels) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), labels) assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) # EC task @@ -1895,8 +1888,7 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = [None, None] # trick mt_model there are two gnn layers. input_nodes = {"n0": th.randint(5, (10,))} labels = {("n0", "r1", "n1"): th.randint(5, (10,))} - mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), (None, None, labels)) assert_equal(loss.numpy(), (input_nodes["n0"]-labels[("n0", "r1", "n1")]).numpy()) # ER task @@ -1904,16 +1896,14 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = [None, None] # trick mt_model there are two gnn layers. input_nodes = {"n0": th.rand((10,))} labels = {("n0", "r1", "n1"): th.rand((10,))} - mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), (None, None, labels)) assert_equal(loss.numpy(), (input_nodes["n0"]*2-labels[("n0", "r1", "n1")]).numpy()) # LP task task_id = "lp_task" blocks = [None, None] # trick mt_model there are two gnn layers. 
input_nodes = {"n0": th.rand((10,))} - mini_batch = mini_batch = ((blocks, None, None, input_nodes), (None, None, None, None)) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), (None, None, None, None)) assert_equal(loss.numpy(), (input_nodes["n0"]*2).numpy()) From f62e03571f5e853a4802e2168cd640e4335a4252 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Fri, 24 May 2024 00:40:48 -0700 Subject: [PATCH 61/79] update --- python/graphstorm/model/multitask_gnn.py | 52 +++++++++++++++++++++++- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index d84f5f7717..df03c121ee 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -137,15 +137,63 @@ def task_decoders(self): """ return self._decoder + def _run_mini_batch(self, task_info, mini_batch): + """ Run mini_batch forward + """ + if task_info.task_type in \ + [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + # Order follow GSgnnNodeModelInterface.forward + blocks, input_feats, edge_feats, lbl,input_nodes = mini_batch + loss = self._forward(task_info.task_id, + (blocks, input_feats, edge_feats, input_nodes), + lbl) + + elif task_info.task_type in \ + [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + # Order follow GSgnnEdgeModelInterface.forward + blocks, target_edges, node_feats, edge_feats, \ + edge_decoder_feats, lbl, input_nodes = mini_batch + loss = self._forward(task_info.task_id, + (blocks, node_feats, None, input_nodes), + (target_edges, edge_decoder_feats, lbl)) + + elif task_info.task_type == BUILTIN_TASK_LINK_PREDICTION: + # Order follow GSgnnLinkPredictionModelInterface.forward + blocks, pos_graph, neg_graph, node_feats, edge_feats, \ + pos_edge_feats, neg_edge_feats, input_nodes = mini_batch + + loss = self._forward(task_info.task_id, + (blocks, node_feats, edge_feats, 
input_nodes), + (pos_graph, neg_graph, pos_edge_feats, neg_edge_feats)) + else: + raise TypeError("Unknown task %s", task_info) + + return loss, task_info.task_config.task_weight + + def forward(self, task_mini_batches): + losses = [] + for (task_info, mini_batch) in task_mini_batches: + loss, weight = self._run_mini_batch(task_info, mini_batch) + losses.append((loss, weight)) + + reg_loss = th.tensor(0.).to(losses[0][0].device) + for d_para in self.get_dense_params(): + reg_loss += d_para.square().sum() + alpha_l2norm = self.alpha_l2norm + + mt_loss = reg_loss * alpha_l2norm + for loss, weight in losses: + mt_loss += loss * weight + return mt_loss + # pylint: disable=unused-argument - def forward(self, task_id, mini_batch): + def _forward(self, task_id, encoder_data, decoder_data): """ The forward function for multi-task learning """ assert task_id in self.task_pool, \ f"Unknown task: {task_id} in multi-task learning." \ f"Existing tasks are {self.task_pool.keys()}" - encoder_data, decoder_data = mini_batch # message passing graph, node features, edge features, seed nodes blocks, node_feats, _, input_nodes = encoder_data if blocks is None or len(blocks) == 0: From d8dd0466976241115215891c7bcd54835d2f2256 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Fri, 24 May 2024 00:41:14 -0700 Subject: [PATCH 62/79] update test --- tests/unit-tests/test_gnn.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index 526f9c2f01..20e6129390 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -1832,8 +1832,7 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = None input_nodes = {"n0": th.randint(5, (10,))} labels = {"n0": th.randint(5, (10,))} - mini_batch = ((blocks, None, None, input_nodes), labels) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), labels) 
assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) # NR task @@ -1841,8 +1840,7 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = None input_nodes = {"n0": th.rand((10,))} labels = {"n0": th.rand((10,))} - mini_batch = ((blocks, None, None, input_nodes), labels) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), labels) assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) # EC task @@ -1850,8 +1848,7 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = None input_nodes = {"n0": th.randint(5, (10,))} labels = {("n0", "r1", "n1"): th.randint(5, (10,))} - mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), (None, None, labels)) assert_equal(loss.numpy(), (input_nodes["n0"]-labels[("n0", "r1", "n1")]).numpy()) # ER task @@ -1859,16 +1856,14 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = None input_nodes = {"n0": th.rand((10,))} labels = {("n0", "r1", "n1"): th.rand((10,))} - mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), (None, None, labels)) assert_equal(loss.numpy(), (input_nodes["n0"]*2-labels[("n0", "r1", "n1")]).numpy()) # LP task task_id = "lp_task" blocks = None input_nodes = {"n0": th.rand((10,))} - mini_batch = mini_batch = ((blocks, None, None, input_nodes), (None, None, None, None)) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), (None, None, None, None)) assert_equal(loss.numpy(), (input_nodes["n0"]*2).numpy()) ### blocks is a list (GNN setting) @@ -1877,8 +1872,7 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = [None, None] # trick mt_model there are two gnn 
layers. input_nodes = {"n0": th.randint(5, (10,))} labels = {"n0": th.randint(5, (10,))} - mini_batch = ((blocks, None, None, input_nodes), labels) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), labels) assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) # NR task @@ -1886,8 +1880,7 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = [None, None] # trick mt_model there are two gnn layers. input_nodes = {"n0": th.rand((10,))} labels = {"n0": th.rand((10,))} - mini_batch = ((blocks, None, None, input_nodes), labels) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), labels) assert_equal(loss.numpy(), (input_nodes["n0"]-labels["n0"]).numpy()) # EC task @@ -1895,8 +1888,7 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = [None, None] # trick mt_model there are two gnn layers. input_nodes = {"n0": th.randint(5, (10,))} labels = {("n0", "r1", "n1"): th.randint(5, (10,))} - mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), (None, None, labels)) assert_equal(loss.numpy(), (input_nodes["n0"]-labels[("n0", "r1", "n1")]).numpy()) # ER task @@ -1904,16 +1896,14 @@ def input_embed_side_effect_func(input_nodes, node_feats): blocks = [None, None] # trick mt_model there are two gnn layers. input_nodes = {"n0": th.rand((10,))} labels = {("n0", "r1", "n1"): th.rand((10,))} - mini_batch = ((blocks, None, None, input_nodes), (None, None, labels)) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), (None, None, labels)) assert_equal(loss.numpy(), (input_nodes["n0"]*2-labels[("n0", "r1", "n1")]).numpy()) # LP task task_id = "lp_task" blocks = [None, None] # trick mt_model there are two gnn layers. 
input_nodes = {"n0": th.rand((10,))} - mini_batch = mini_batch = ((blocks, None, None, input_nodes), (None, None, None, None)) - loss = mt_model(task_id, mini_batch) + loss = mt_model._forward(task_id, (blocks, None, None, input_nodes), (None, None, None, None)) assert_equal(loss.numpy(), (input_nodes["n0"]*2).numpy()) From 1a5a16567a519f2c01a62f7ea5e6cd8ace7aa1e3 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Fri, 24 May 2024 01:39:19 -0700 Subject: [PATCH 63/79] Update --- python/graphstorm/trainer/mt_trainer.py | 13 +- tests/unit-tests/test_trainer.py | 160 ++++++++++++++++-------- 2 files changed, 115 insertions(+), 58 deletions(-) diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index 6cdba8ae6e..bd220dd369 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -51,7 +51,8 @@ def prepare_node_mini_batch(data, task_info, mini_batch, device): label_field = task_info.dataloader.label_field input_feats = data.get_node_feats(input_nodes, nfeat_fields, device) lbl = data.get_node_feats(seeds, label_field, device) - blocks = [block.to(device) for block in blocks] + blocks = [block.to(device) for block in blocks] \ + if blocks is not None else None # Order follow GSgnnNodeModelInterface.forward # TODO: we don't support edge features for now. @@ -84,18 +85,19 @@ def prepare_edge_mini_batch(data, task_info, mini_batch, device): assert len(batch_graph.etypes) == 1 target_etype = batch_graph.canonical_etypes[0] # TODO(zhengda) the data loader should return labels directly. 
- seeds = batch_graph.edges[target_etype[1]].data[dgl.EID] + seeds = batch_graph.edges[target_etype].data[dgl.EID] label_field = task_info.dataloader.label_field lbl = data.get_edge_feats({target_etype: seeds}, label_field, device) - blocks = [block.to(device) for block in blocks] + blocks = [block.to(device) for block in blocks] \ + if blocks is not None else None batch_graph = batch_graph.to(device) rt_profiler.record('train_graph2GPU') # Order follow GSgnnEdgeModelInterface.forward # TODO(zhengda) we don't support edge features for now. return (blocks, batch_graph, node_feats, None, - edge_decoder_feats, lbl, input_nodes) + edge_decoder_feats, lbl, input_nodes) def prepare_link_predict_mini_batch(data, task_info, mini_batch, device): input_nodes, pos_graph, neg_graph, blocks = mini_batch @@ -118,7 +120,8 @@ def prepare_link_predict_mini_batch(data, task_info, mini_batch, device): pos_graph = pos_graph.to(device) neg_graph = neg_graph.to(device) - blocks = [blk.to(device) for blk in blocks] + blocks = [blk.to(device) for blk in blocks] \ + if blocks is not None else None return (blocks, pos_graph, neg_graph, node_feats, None, \ pos_graph_feats, None, input_nodes) diff --git a/tests/unit-tests/test_trainer.py b/tests/unit-tests/test_trainer.py index a84089b7a6..1ce6d02cd9 100644 --- a/tests/unit-tests/test_trainer.py +++ b/tests/unit-tests/test_trainer.py @@ -32,9 +32,9 @@ from graphstorm.trainer import GSgnnTrainer from graphstorm.eval import GSgnnClassificationEvaluator from graphstorm.utils import setup_device, get_device -from graphstorm.trainer.mt_trainer import (run_node_mini_batch, - run_edge_mini_batch, - run_link_predict_mini_batch) +from graphstorm.trainer.mt_trainer import (prepare_node_mini_batch, + prepare_edge_mini_batch, + prepare_link_predict_mini_batch) from graphstorm.dataloading import (GSgnnNodeDataLoader, GSgnnEdgeDataLoader, GSgnnLinkPredictionDataLoader) @@ -195,7 +195,7 @@ def forward(self, task_id, mini_batch): def predict(self, task_id, 
mini_batch, return_proba=False): pass -def test_mtask_run_node_mini_batch(): +def test_mtask_prepare_node_mini_batch(): with tempfile.TemporaryDirectory() as tmpdirname: # get the test dummy distributed graph _, part_config = generate_dummy_dist_graph(graph_name='dummy', dirname=tmpdirname) @@ -209,32 +209,41 @@ def test_mtask_run_node_mini_batch(): dataloader = GSgnnNodeDataLoader(np_data, target_idx, [10], 10, label_field='label', node_feats='feat', - edge_feats='feat', train_task=False) task_config = GSConfig.__new__(GSConfig) - expected_loss = th.rand(np_data.g.number_of_nodes('n1')) setattr(task_config, "task_weight", 0.75) task_info = TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION, task_id=task_id, task_config=task_config, dataloader=dataloader) - node_feats = np_data.get_node_feats(target_idx, 'feat', device=device) - labels = np_data.get_node_feats(target_idx, 'label', device=device) + node_feats = np_data.get_node_feats(target_idx, 'feat') + labels = np_data.get_node_feats(target_idx, 'label') mini_batch = (target_idx, target_idx, None) - model = DummyGSgnnMultiTaskSharedEncoderModel(task_id=task_id, - task_type=BUILTIN_TASK_NODE_CLASSIFICATION, - input_nodes=target_idx, - labels=labels, - node_feast=node_feats, - expected_loss=expected_loss) - loss, weight = run_node_mini_batch(model, np_data, task_info, mini_batch, device) - assert assert_equal(loss.numpy(), expected_loss.numpy()) - assert weight == 0.75 - -def test_mtask_run_edge_mini_batch(): + + blocks, input_feats, _, lbl, input_nodes = \ + prepare_node_mini_batch(np_data, task_info, mini_batch, device) + assert blocks is None + assert_equal(input_nodes["n1"].numpy(), target_idx["n1"].numpy()) + assert_equal(input_feats["n1"].cpu().numpy(), node_feats["n1"].numpy()) + assert_equal(lbl["n1"].cpu().numpy(), labels["n1"].numpy()) + + dataloader = GSgnnNodeDataLoader(np_data, target_idx, [10], 10, + label_field='label', + train_task=False) + task_info = 
TaskInfo(task_type=BUILTIN_TASK_NODE_CLASSIFICATION, + task_id=task_id, + task_config=task_config, + dataloader=dataloader) + _, input_feats, _, lbl, input_nodes = \ + prepare_node_mini_batch(np_data, task_info, mini_batch, device) + assert_equal(input_nodes["n1"].numpy(), target_idx["n1"].numpy()) + assert len(input_feats) == 0 + assert_equal(lbl["n1"].cpu().numpy(), labels["n1"].numpy()) + +def test_mtask_prepare_edge_mini_batch(): with tempfile.TemporaryDirectory() as tmpdirname: # get the test dummy distributed graph - _, part_config = generate_dummy_dist_graph(graph_name='dummy', dirname=tmpdirname) + g, part_config = generate_dummy_dist_graph(graph_name='dummy', dirname=tmpdirname) ep_data = GSgnnData(part_config=part_config) setup_device(0) @@ -244,36 +253,55 @@ def test_mtask_run_edge_mini_batch(): task_id = "test_edge_prediction" dataloader = GSgnnEdgeDataLoader(ep_data, target_idx, [10], 10, node_feats='feat', - edge_feats='feat', label_field='label', train_task=True, remove_target_edge_type=False) task_config = GSConfig.__new__(GSConfig) - expected_loss = th.rand(ep_data.g.number_of_edges('r1')) setattr(task_config, "task_weight", 0.71) task_info = TaskInfo(task_type=BUILTIN_TASK_EDGE_REGRESSION, task_id=task_id, task_config=task_config, dataloader=dataloader) - input_nodes = { + input_node_idx = { "n0": th.arange(10), "n1": th.arange(20), } - node_feats = ep_data.get_node_feats(input_nodes, 'feat', device=device) - labels = ep_data.get_node_feats(target_idx, 'label', device=device) - mini_batch = (input_nodes, None, None) - model = DummyGSgnnMultiTaskSharedEncoderModel(task_id, - task_type=BUILTIN_TASK_EDGE_REGRESSION, - labels=labels, - node_feast=node_feats, - expected_loss=expected_loss) - + node_feats = ep_data.get_node_feats(input_node_idx, 'feat') + labels = ep_data.get_edge_feats(target_idx, 'label') + print(g.edges[('n0', 'r1', 'n1')]) + batch_graph = dgl.heterograph( + {('n0', 'r1', 'n1'): (th.randint(g.number_of_nodes("n0"), 
(g.number_of_edges('r1'),)), + th.randint(g.number_of_nodes("n1"), (g.number_of_edges('r1'),)))} + ) + batch_graph.edges[('n0', 'r1', 'n1')].data[dgl.EID] = th.arange(ep_data.g.number_of_edges('r1')) + mini_batch = (input_node_idx, batch_graph, None) + blocks, edge_graph, input_feats, _, \ + edge_decoder_feats, lbl, input_nodes = prepare_edge_mini_batch(ep_data, task_info, mini_batch, device) + + assert blocks is None + assert edge_decoder_feats is None + assert edge_graph.number_of_edges('r1') == batch_graph.number_of_edges('r1') + assert_equal(input_nodes["n0"].numpy(), input_node_idx["n0"].numpy()) + assert_equal(input_nodes["n1"].numpy(), input_node_idx["n1"].numpy()) + assert_equal(input_feats["n0"].cpu().numpy(), node_feats["n0"].numpy()) + assert_equal(input_feats["n1"].cpu().numpy(), node_feats["n1"].numpy()) + assert_equal(lbl[('n0', 'r1', 'n1')].cpu().numpy(), labels[('n0', 'r1', 'n1')].numpy()) - loss, weight = run_edge_mini_batch(model, ep_data, task_info, mini_batch, device) - assert assert_equal(loss.numpy(), expected_loss.numpy()) - assert weight == 0.71 - -def test_mtask_run_lp_mini_batch(): + dataloader = GSgnnEdgeDataLoader(ep_data, target_idx, [10], 10, + label_field='label', + train_task=True, remove_target_edge_type=False) + task_info = TaskInfo(task_type=BUILTIN_TASK_EDGE_REGRESSION, + task_id=task_id, + task_config=task_config, + dataloader=dataloader) + _, _, input_feats, _, \ + _, lbl, input_nodes = prepare_edge_mini_batch(ep_data, task_info, mini_batch, device) + assert_equal(input_nodes["n0"].numpy(), input_node_idx["n0"].numpy()) + assert_equal(input_nodes["n1"].numpy(), input_node_idx["n1"].numpy()) + assert len(input_feats) == 0 + assert_equal(lbl[('n0', 'r1', 'n1')].cpu().numpy(), labels[('n0', 'r1', 'n1')].numpy()) + +def test_mtask_prepare_lp_mini_batch(): with tempfile.TemporaryDirectory() as tmpdirname: # get the test dummy distributed graph _, part_config = generate_dummy_dist_graph(graph_name='dummy', dirname=tmpdirname) @@ 
-287,34 +315,60 @@ def test_mtask_run_lp_mini_batch(): dataloader = GSgnnLinkPredictionDataLoader(ep_data, target_idx, [10], 10, num_negative_edges=2, + node_feats='feat', train_task=False) task_config = GSConfig.__new__(GSConfig) - expected_loss = th.rand(ep_data.g.number_of_edges('r1')) setattr(task_config, "task_weight", 0.72) task_info = TaskInfo(task_type=BUILTIN_TASK_LINK_PREDICTION, task_id=task_id, task_config=task_config, dataloader=dataloader) - input_nodes = { + input_node_idx = { "n0": th.arange(10), "n1": th.arange(20), } - node_feats = ep_data.get_node_feats(input_nodes, 'feat', device=device) - - mini_batch = (input_nodes, None, None, None) - model = DummyGSgnnMultiTaskSharedEncoderModel(task_id, - task_type=BUILTIN_TASK_LINK_PREDICTION, - labels=None, - node_feast=node_feats, - expected_loss=expected_loss) + node_feats = ep_data.get_node_feats(input_node_idx, 'feat') + + input_pos_graph = dgl.heterograph( + {('n0', 'r1', 'n1'): (th.tensor([0,1]), + th.tensor([1,2]))}) + input_neg_graph = dgl.heterograph( + {('n0', 'r1', 'n1'): (th.tensor([0,1]), + th.tensor([1,2]))}) + + mini_batch = (input_node_idx, input_pos_graph, input_neg_graph, None) + + blocks, pos_graph, neg_graph, input_feats, _, \ + pos_graph_feats, _, input_nodes = \ + prepare_link_predict_mini_batch(ep_data, task_info, mini_batch, device) + + assert blocks is None + assert_equal(input_nodes["n0"].numpy(), input_node_idx["n0"].numpy()) + assert_equal(input_nodes["n1"].numpy(), input_node_idx["n1"].numpy()) + assert_equal(input_feats["n0"].cpu().numpy(), node_feats["n0"].numpy()) + assert_equal(input_feats["n1"].cpu().numpy(), node_feats["n1"].numpy()) + assert pos_graph_feats is None + assert input_pos_graph.number_of_edges('r1') == pos_graph.number_of_edges('r1') + assert input_neg_graph.number_of_edges('r1') == neg_graph.number_of_edges('r1') - loss, weight = run_link_predict_mini_batch(model, ep_data, task_info, mini_batch, device) - assert assert_equal(loss.numpy(), 
expected_loss.numpy()) - assert weight == 0.72 + dataloader = GSgnnLinkPredictionDataLoader(ep_data, target_idx, + [10], 10, + num_negative_edges=2, + train_task=False) + task_info = TaskInfo(task_type=BUILTIN_TASK_LINK_PREDICTION, + task_id=task_id, + task_config=task_config, + dataloader=dataloader) + _, _, _, input_feats, _, \ + _, _, input_nodes = \ + prepare_link_predict_mini_batch(ep_data, task_info, mini_batch, device) + assert len(input_feats) == 0 + assert_equal(input_nodes["n0"].numpy(), input_node_idx["n0"].numpy()) + assert_equal(input_nodes["n1"].numpy(), input_node_idx["n1"].numpy()) if __name__ == '__main__': test_trainer_setup_evaluator() - test_mtask_run_node_mini_batch() - test_mtask_run_edge_mini_batch() - test_mtask_run_lp_mini_batch() + test_mtask_prepare_node_mini_batch() + test_mtask_prepare_edge_mini_batch() + test_mtask_prepare_lp_mini_batch() From 1d92008d5730b488713aa9890c0d9eeea47adc9e Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Fri, 24 May 2024 01:41:47 -0700 Subject: [PATCH 64/79] Update --- python/graphstorm/model/multitask_gnn.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index df03c121ee..5fc7a7cf4a 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -39,7 +39,7 @@ class GSgnnMultiTaskModelInterface: This interface defines two main methods for training and inference. """ @abc.abstractmethod - def forward(self, task_id, mini_batch): + def forward(self, task_mini_batches): """ The forward function for multi-task learning This method is used for training, It runs model forword @@ -48,11 +48,8 @@ def forward(self, task_id, mini_batch): Parameters ---------- - task_id: str - ID of the task. - mini_batch: tuple - Mini-batch info. - + task_mini_batches: list + Mini-batches. 
Return ------ @@ -166,7 +163,7 @@ def _run_mini_batch(self, task_info, mini_batch): (blocks, node_feats, edge_feats, input_nodes), (pos_graph, neg_graph, pos_edge_feats, neg_edge_feats)) else: - raise TypeError("Unknown task %s", task_info) + raise TypeError(f"Unknown task {task_info}") return loss, task_info.task_config.task_weight From d6c2cb30cfccab8e1f8e7640af24042209af9329 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Fri, 24 May 2024 16:44:38 -0700 Subject: [PATCH 65/79] Update --- python/graphstorm/__init__.py | 1 + python/graphstorm/gsf.py | 37 +++++++- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 89 ++++++++++++++----- .../end2end-tests/graphstorm-mt/mgpu_test.sh | 5 +- 4 files changed, 104 insertions(+), 28 deletions(-) diff --git a/python/graphstorm/__init__.py b/python/graphstorm/__init__.py index fd310ca5ae..29b747104a 100644 --- a/python/graphstorm/__init__.py +++ b/python/graphstorm/__init__.py @@ -28,6 +28,7 @@ from .gsf import create_builtin_lp_model from .gsf import create_builtin_edge_model from .gsf import create_builtin_node_model +from .gsf import create_task_decoder from .gsf import (create_builtin_node_decoder, create_builtin_edge_decoder, diff --git a/python/graphstorm/gsf.py b/python/graphstorm/gsf.py index dc1c6fc0ab..91924437c6 100644 --- a/python/graphstorm/gsf.py +++ b/python/graphstorm/gsf.py @@ -28,10 +28,11 @@ from .utils import sys_tracker, get_rank from .utils import setup_device -from .config import BUILTIN_TASK_NODE_CLASSIFICATION -from .config import BUILTIN_TASK_NODE_REGRESSION -from .config import BUILTIN_TASK_EDGE_CLASSIFICATION -from .config import BUILTIN_TASK_EDGE_REGRESSION +from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION, + BUILTIN_TASK_LINK_PREDICTION) from .config import BUILTIN_LP_DOT_DECODER from .config import BUILTIN_LP_DISTMULT_DECODER from .config import (BUILTIN_LP_LOSS_CROSS_ENTROPY, @@ -842,3 
+843,31 @@ def get_builtin_lp_train_dataloader_class(config): raise ValueError('Unknown negative sampler') return dataloader_cls + +def create_task_decoder(task_info, g, decoder_input_dim, train_task): + """ Create task decoders according to task_info. + + Parameters + ---------- + task_info: TaskInfo + Task info. + g: Dist DGLGraph + Graph + decoder_input_dim: int + The dimension of the input embedding of the decoder + train_task: bool + Whether the task is a training task + + Return + ------ + decoder: The node task decoder(s) + loss_func: The loss function(s) + """ + if task_info.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: + return create_builtin_node_decoder(g, decoder_input_dim, task_info.task_config, train_task) + elif task_info.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: + return create_builtin_edge_decoder(g, decoder_input_dim, task_info.task_config, train_task) + elif task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]: + return create_builtin_lp_decoder(g, decoder_input_dim, task_info.task_config, train_task) + + return None, None diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index 74d37be20f..531e122c9e 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. - GSgnn multi-task learning + GSgnn multi-task learning training entry point. """ import os import logging @@ -44,9 +44,23 @@ from graphstorm.utils import rt_profiler, sys_tracker, get_device, use_wholegraph from graphstorm.utils import get_lm_ntypes -def create_task_train_dataloader(task, config, task_config, train_data): - """ +def create_task_train_dataloader(task, config, train_data): + """ Create task specific dataloader for training tasks + + Parameters + ---------- + task: TaskInfo + Task info. 
+ config: GSConfig + Training config. + train_data: GSgnnData + Training data. + + Return + ------ + Task dataloader """ + task_config = task.task_config # All tasks share the same GNN model, so the fanout should be the global fanout fanout = config.fanout # All tasks share the same input encoder, so the node feats must be same. @@ -94,9 +108,23 @@ def create_task_train_dataloader(task, config, task_config, train_data): return None -def create_task_val_dataloader(task, config, task_config, train_data): - """ +def create_task_val_dataloader(task, config, train_data): + """ Create task specific validation dataloader. + + Parameters + ---------- + task: TaskInfo + Task info. + config: GSConfig + Training config. + train_data: GSgnnData + Training data. + + Return + ------ + Task dataloader """ + task_config = task.task_config if task_config.val_mask is None: # There is no validation mask return None @@ -158,9 +186,23 @@ def create_task_val_dataloader(task, config, task_config, train_data): return None -def create_task_test_dataloader(task, config, task_config, train_data): - """ +def create_task_test_dataloader(task, config, train_data): + """ Create task specific test dataloader. + + Parameters + ---------- + task: TaskInfo + Task info. + config: GSConfig + Training config. + train_data: GSgnnData + Training data. 
+ + Return + ------ + Task dataloader """ + task_config = task.task_config if task_config.test_mask is None: # There is no validation mask return None @@ -222,17 +264,19 @@ def create_task_test_dataloader(task, config, task_config, train_data): pos_graph_edge_feats=task_config.lp_edge_weight_for_loss) return None -def create_task_decoder(task, g, decoder_input_dim, train_task): - if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - return gs.create_builtin_node_decoder(g, decoder_input_dim, task.task_config, train_task) - elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - return gs.create_builtin_edge_decoder(g, decoder_input_dim, task.task_config, train_task) - elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: - return gs.create_builtin_lp_decoder(g, decoder_input_dim, task.task_config, train_task) +def create_evaluator(task): + """ Create task specific evaluators - return None, None + Parameters + ---------- + task: TaskInfo + Task info. 
-def create_evaluator(task, config): + Return + ------ + Evaluators + """ + config = task.task_config if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION]: multilabel = config.multilabel[config.eval_target_ntype] \ if isinstance(config.multilabel, dict) else config.multilabel @@ -312,14 +356,13 @@ def main(config_args): if model.gnn_encoder is not None \ else model.node_input_encoder.out_dims for task in tasks: - task_config = task.task_config - train_loader = create_task_train_dataloader(task, config, task_config, train_data) - val_loader = create_task_val_dataloader(task, config, task_config, train_data) - test_loader = create_task_test_dataloader(task, config, task_config, train_data) + train_loader = create_task_train_dataloader(task, config, train_data) + val_loader = create_task_val_dataloader(task, config, train_data) + test_loader = create_task_test_dataloader(task, config, train_data) train_dataloaders.append(train_loader) val_dataloaders.append(val_loader) test_dataloaders.append(test_loader) - decoder, loss_func = create_task_decoder(task, train_data.g, encoder_out_dims, train_task=True) + decoder, loss_func = gs.create_task_decoder(task, train_data.g, encoder_out_dims, train_task=True) model.add_task(task.task_id, task.task_type, decoder, loss_func) if not config.no_validation: if val_loader is None: @@ -327,7 +370,7 @@ def main(config_args): if test_loader is None: logging.warning("The training data do not have test set.") task_evaluators[task.task_id] = \ - create_evaluator(task, task_config) + create_evaluator(task) train_dataloader = GSgnnMultiTaskDataLoader(train_data, tasks, train_dataloaders) val_dataloader = GSgnnMultiTaskDataLoader(train_data, tasks, val_dataloaders) @@ -380,7 +423,7 @@ def main(config_args): if config.save_embed_path is not None: # Save node embeddings model = GSgnnMultiTaskSharedEncoderModel(config.alpha_l2norm) - gs.set_encoder(model, train_data.g, config, train_task=True) + gs.gsf.set_encoder(model, train_data.g, config, 
train_task=True) best_model_path = trainer.get_best_model_path() # TODO(zhengda) the model path has to be in a shared filesystem. model.restore_model(best_model_path) diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh index d8a3cdcdd6..ec1cd72921 100644 --- a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh @@ -30,4 +30,7 @@ echo "**************[Multi-task] dataset: Movielens, RGCN layer 1, node feat: fi python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True echo "**************[Multi-task with learnable embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model" -python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-node-embeddings True \ No newline at end of file +python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 
--logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-node-embeddings True + +echo "**************[Multi-task with learnable embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model" +python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-node-embeddings True --use-mini-batch-infer False --save-embed-path /data/gsgnn_mt/emb/ \ No newline at end of file From c90fb5f0eed5a718b03c6d7e86ae1ee4f2a6ea23 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Sun, 26 May 2024 12:44:30 -0700 Subject: [PATCH 66/79] Update --- python/graphstorm/inference/lp_infer.py | 2 +- python/graphstorm/model/multitask_gnn.py | 11 +- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 6 +- python/graphstorm/trainer/mt_trainer.py | 2 +- .../end2end-tests/graphstorm-mt/mgpu_test.sh | 158 +++++++++++++++++- 5 files changed, 167 insertions(+), 12 deletions(-) diff --git a/python/graphstorm/inference/lp_infer.py b/python/graphstorm/inference/lp_infer.py index c5239e6642..1a02db4f3c 100644 --- a/python/graphstorm/inference/lp_infer.py +++ b/python/graphstorm/inference/lp_infer.py @@ -72,7 +72,7 @@ def infer(self, data, loader, save_embed_path, save_embed_format : str Specify the format of saved embeddings. infer_batch_size: int - Specify the inference batch size when computing node embeddings + Specify the inference batch size when computing node embeddings with mini batch inference. 
""" sys_tracker.check('start inferencing') diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index df03c121ee..5fc7a7cf4a 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -39,7 +39,7 @@ class GSgnnMultiTaskModelInterface: This interface defines two main methods for training and inference. """ @abc.abstractmethod - def forward(self, task_id, mini_batch): + def forward(self, task_mini_batches): """ The forward function for multi-task learning This method is used for training, It runs model forword @@ -48,11 +48,8 @@ def forward(self, task_id, mini_batch): Parameters ---------- - task_id: str - ID of the task. - mini_batch: tuple - Mini-batch info. - + task_mini_batches: list + Mini-batches. Return ------ @@ -166,7 +163,7 @@ def _run_mini_batch(self, task_info, mini_batch): (blocks, node_feats, edge_feats, input_nodes), (pos_graph, neg_graph, pos_edge_feats, neg_edge_feats)) else: - raise TypeError("Unknown task %s", task_info) + raise TypeError(f"Unknown task {task_info}") return loss, task_info.task_config.task_weight diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index 531e122c9e..df7f0532b3 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -424,6 +424,10 @@ def main(config_args): # Save node embeddings model = GSgnnMultiTaskSharedEncoderModel(config.alpha_l2norm) gs.gsf.set_encoder(model, train_data.g, config, train_task=True) + + for task in tasks: + decoder, loss_func = gs.create_task_decoder(task, train_data.g, encoder_out_dims, train_task=True) + model.add_task(task.task_id, task.task_type, decoder, loss_func) best_model_path = trainer.get_best_model_path() # TODO(zhengda) the model path has to be in a shared filesystem. 
model.restore_model(best_model_path) @@ -434,7 +438,7 @@ def main(config_args): model.prepare_input_encoder(train_data) embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout, - edge_mask="train_mask", task_tracker=tracker) + task_tracker=tracker) save_full_node_embeddings( train_data.g, diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index bd220dd369..c0e404f3f2 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -140,7 +140,7 @@ class GSgnnMultiTaskLearningTrainer(GSgnnTrainer): Parameters ---------- model : GSgnnMultiTaskModel - The GNN model for node prediction. + The GNN model for prediction. topk_model_to_save : int The top K model to save. """ diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh index ec1cd72921..4336706804 100644 --- a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh @@ -29,8 +29,162 @@ df /dev/shm -h echo "**************[Multi-task] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model" python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True +error_and_exit $? 
+ +# check prints + +bst_cnt=$(grep "Best Test node_classification" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Test node_classification" + exit -1 +fi + +cnt=$(grep "Test node_classification" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Test node_classification" + exit -1 +fi + +bst_cnt=$(grep "Best Validation node_classification" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation accuracy node_classification" + exit -1 +fi + +cnt=$(grep "Validation node_classification" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Validation node_classification" + exit -1 +fi + +bst_cnt=$(grep "Best Test edge_classification" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Test edge_classification" + exit -1 +fi + +cnt=$(grep "Test edge_classification" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Test edge_classification" + exit -1 +fi + +bst_cnt=$(grep "Best Validation edge_classification" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation edge_classification" + exit -1 +fi + +cnt=$(grep "Validation edge_classification" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Validation edge_classification" + exit -1 +fi + +bst_cnt=$(grep "Best Test edge_regression" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Test edge_regression" + exit -1 +fi + +cnt=$(grep "Test edge_regression" /tmp/train_log.txt | wc -l) +if test $cnt -lt 
$((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Test edge_regression" + exit -1 +fi + +bst_cnt=$(grep "Best Validation edge_regression" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation edge_regression" + exit -1 +fi + +cnt=$(grep "Validation edge_regression" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Validation edge_regression" + exit -1 +fi + +bst_cnt=$(grep "Best Test link_prediction" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Test link_prediction" + exit -1 +fi + +cnt=$(grep "Test link_prediction" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Test link_prediction" + exit -1 +fi + +bst_cnt=$(grep "Best Validation link_prediction" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation link_prediction" + exit -1 +fi + +cnt=$(grep "Validation link_prediction" /tmp/train_log.txt | wc -l) +if test $cnt -lt $((1+$bst_cnt)) +then + echo "We use SageMaker task tracker, we should have Validation link_prediction" + exit -1 +fi + +cnt=$(ls -l /data/gsgnn_mt/ | grep epoch | wc -l) +if test $cnt != 1 +then + echo "The number of save models $cnt is not equal to the specified topk 1" + exit -1 +fi + echo "**************[Multi-task with learnable embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model" python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml 
--save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-node-embeddings True -echo "**************[Multi-task with learnable embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model" -python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-node-embeddings True --use-mini-batch-infer False --save-embed-path /data/gsgnn_mt/emb/ \ No newline at end of file +error_and_exit $? + +rm /tmp/train_log.txt +rm -rf /data/gsgnn_mt/ + +echo "**************[Multi-task] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model" +python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-mini-batch-infer False --save-embed-path /data/gsgnn_mt/emb/ + +error_and_exit $?
+ +cnt=$(grep "save_embed_path: /data/gsgnn_mt/emb/" /tmp/train_log.txt | wc -l) +if test $cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have save_embed_path" + exit -1 +fi + +echo "**************[Multi-task] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: mini-batch load from saved model" +python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --restore-model-path /data/gsgnn_mt/epoch-2/ --save-model-path /data/gsgnn_mt_2/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug + +error_and_exit $? + +cnt=$(ls -l /data/gsgnn_mt_2/ | grep epoch | wc -l) +if test $cnt != 1 +then + echo "The number of save models $cnt is not equal to the specified topk 1" + exit -1 +fi + +echo "**************[Multi-task gen embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, load from saved model" +python3 -m graphstorm.run.gs_gen_node_embedding --workspace $GS_HOME/training_scripts/gsgnn_mt/ --num-trainers $NUM_TRAINERS --use-mini-batch-infer false --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-embed-path /data/gsgnn_mt/save-emb/ --restore-model-path --restore-model-path /data/gsgnn_mt/epoch-2/ --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True + +error_and_exit $? 
\ No newline at end of file From 3a98131770588aa5171d9fe95357f6b67837e099 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Sun, 26 May 2024 12:53:38 -0700 Subject: [PATCH 67/79] Update --- python/graphstorm/dataloading/dataloading.py | 4 ++-- python/graphstorm/model/multitask_gnn.py | 20 +++++++++++++++++--- tests/unit-tests/test_gnn.py | 4 ++-- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index b035e5df43..6b6e072078 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -1707,7 +1707,7 @@ def __init__(self, dataset, task_infos, task_dataloaders): lens = [] for task_info, dataloader in zip(task_infos, task_dataloaders): # For evaluation and testing, we allow some of the val_dataloaders or test_dataloaders - # are empty (None). + # to be empty (None). assert isinstance(dataloader, (GSgnnEdgeDataLoaderBase, GSgnnLinkPredictionDataLoaderBase, GSgnnNodeDataLoaderBase)) or dataloader is None, \ @@ -1812,7 +1812,7 @@ def fanout(self): Returns ------- - list or a dict of list : the fanouts for each GNN layer. + list of list or list of dict of list : the fanouts for each GNN layer. """ fanouts = [dataloader.fanout if dataloader is not None \ else None for dataloader in self.dataloaders] diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index 5fc7a7cf4a..68cf5c8349 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -79,7 +79,7 @@ def predict(self, task_id, mini_batch): class GSgnnMultiTaskSharedEncoderModel(GSgnnModel, GSgnnMultiTaskModelInterface): """ GraphStorm GNN model for multi-task learning - with a shared encoder model and separate decoder models. + with a shared encoder model and separate decoder models for each task. 
Parameters ---------- @@ -135,7 +135,7 @@ def task_decoders(self): return self._decoder def _run_mini_batch(self, task_info, mini_batch): - """ Run mini_batch forward + """ Run mini_batch forward. """ if task_info.task_type in \ [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: @@ -168,6 +168,10 @@ def _run_mini_batch(self, task_info, mini_batch): return loss, task_info.task_config.task_weight def forward(self, task_mini_batches): + """ The forward function for multi-task learning + It will iterate over the mini-batches and call + forward for each task. + """ losses = [] for (task_info, mini_batch) in task_mini_batches: loss, weight = self._run_mini_batch(task_info, mini_batch) @@ -185,7 +189,17 @@ def forward(self, task_mini_batches): # pylint: disable=unused-argument def _forward(self, task_id, encoder_data, decoder_data): - """ The forward function for multi-task learning + """ The forward function to run forward for a specific + task with task_id. + + Parameters + ---------- + task_id: str + Task ID. + encoder_data: tuple + The input data for the encoder. + decoder_data: tuple + The input for the decoder. """ assert task_id in self.task_pool, \ f"Unknown task: {task_id} in multi-task learning." 
\ diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index 20e6129390..d0225802cc 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -77,8 +77,8 @@ from graphstorm.model.utils import load_model, save_model from graphstorm.model import GSgnnMultiTaskSharedEncoderModel from graphstorm.dataloading import (GSgnnEdgeDataLoaderBase, - GSgnnLinkPredictionDataLoaderBase, - GSgnnNodeDataLoaderBase) + GSgnnLinkPredictionDataLoaderBase, + GSgnnNodeDataLoaderBase) from data_utils import generate_dummy_dist_graph, generate_dummy_dist_graph_multi_target_ntypes from data_utils import generate_dummy_dist_graph_reconstruct From 9baac44d39dd7577676ecd74f19837eff6d8fc20 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Sun, 26 May 2024 23:09:24 -0700 Subject: [PATCH 68/79] Update --- tests/end2end-tests/graphstorm-mt/mgpu_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh index 4336706804..e116685430 100644 --- a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh @@ -185,6 +185,6 @@ then fi echo "**************[Multi-task gen embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, load from saved model" -python3 -m graphstorm.run.gs_gen_node_embedding --workspace $GS_HOME/training_scripts/gsgnn_mt/ --num-trainers $NUM_TRAINERS --use-mini-batch-infer false --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-embed-path /data/gsgnn_mt/save-emb/ --restore-model-path --restore-model-path /data/gsgnn_mt/epoch-2/ --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True +python3 -m graphstorm.run.gs_gen_node_embedding --workspace $GS_HOME/training_scripts/gsgnn_mt/ --num-trainers 
$NUM_TRAINERS --use-mini-batch-infer false --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-embed-path /data/gsgnn_mt/save-emb/ --restore-model-path /data/gsgnn_mt/epoch-2/ --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True error_and_exit $? \ No newline at end of file From 6003cdc544f47e8ad08859489d0c4a84af869e9b Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Sun, 26 May 2024 23:11:32 -0700 Subject: [PATCH 69/79] update --- tests/end2end-tests/graphstorm-mt/mgpu_test.sh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh index e116685430..48c6b6f98c 100644 --- a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh @@ -183,8 +183,3 @@ then echo "The number of save models $cnt is not equal to the specified topk 1" exit -1 fi - -echo "**************[Multi-task gen embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, load from saved model" -python3 -m graphstorm.run.gs_gen_node_embedding --workspace $GS_HOME/training_scripts/gsgnn_mt/ --num-trainers $NUM_TRAINERS --use-mini-batch-infer false --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-embed-path /data/gsgnn_mt/save-emb/ --restore-model-path /data/gsgnn_mt/epoch-2/ --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True - -error_and_exit $? 
\ No newline at end of file From a7c14e3f7b42a19c25112004a93d178a59a1f2a8 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Sun, 26 May 2024 23:24:09 -0700 Subject: [PATCH 70/79] Update --- .github/workflow_scripts/e2e_mgpu_check.sh | 1 + tests/end2end-tests/graphstorm-mt/mgpu_test.sh | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/.github/workflow_scripts/e2e_mgpu_check.sh b/.github/workflow_scripts/e2e_mgpu_check.sh index 66417f78a7..35ea1b4654 100644 --- a/.github/workflow_scripts/e2e_mgpu_check.sh +++ b/.github/workflow_scripts/e2e_mgpu_check.sh @@ -10,4 +10,5 @@ sh ./tests/end2end-tests/create_data.sh bash ./tests/end2end-tests/graphstorm-lp/mgpu_test.sh bash ./tests/end2end-tests/graphstorm-nc/mgpu_test.sh bash ./tests/end2end-tests/graphstorm-ec/mgpu_test.sh +bash ./tests/end2end-tests/graphstorm-mt/mgpu_test.sh diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh index 48c6b6f98c..96cf29569f 100644 --- a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh @@ -146,7 +146,7 @@ then fi cnt=$(ls -l /data/gsgnn_mt/ | grep epoch | wc -l) -if test $cnt != 1 +if test $cnt != 3 then echo "The number of save models $cnt is not equal to the specified topk 1" exit -1 @@ -172,13 +172,20 @@ then exit -1 fi +cnt=$(ls -l /data/gsgnn_mt/emb/ | wc -l) +cnt=$[cnt - 1] +if test $cnt != 2 +then + echo "The number of saved embs $cnt is not equal to 2 (for movie and user)." 
+fi + echo "**************[Multi-task] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: mini-batch load from saved model" python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --restore-model-path /data/gsgnn_mt/epoch-2/ --save-model-path /data/gsgnn_mt_2/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug error_and_exit $? cnt=$(ls -l /data/gsgnn_mt_2/ | grep epoch | wc -l) -if test $cnt != 1 +if test $cnt != 3 then echo "The number of save models $cnt is not equal to the specified topk 1" exit -1 From 2a9aa880c90dd977eb9832d856eb8717f0cbfb18 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Sun, 26 May 2024 23:51:22 -0700 Subject: [PATCH 71/79] Fix lint --- python/graphstorm/trainer/mt_trainer.py | 76 +++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 6 deletions(-) diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index c0e404f3f2..dd7282176f 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -39,7 +39,26 @@ from ..utils import barrier, is_distributed def prepare_node_mini_batch(data, task_info, mini_batch, device): - """ Run node mini_batch forward + """ Prepare mini-batch for node classification and regression tasks. + + The input is a mini-batch sampled by a node sampler. + The output is a prepared input following the + input arguments of GSgnnNodeModelInterface.forward.
+ + Parameters + ---------- + data: GSgnnData + Graph data + task_info: TaskInfo + Task meta information + mini_batch: tuple + Mini-batch info + device: torch.device + Device + + Return + ------ + tuple: mini-batch """ g = data.g input_nodes, seeds, blocks = mini_batch @@ -59,7 +78,26 @@ def prepare_node_mini_batch(data, task_info, mini_batch, device): return (blocks, input_feats, None, lbl, input_nodes) def prepare_edge_mini_batch(data, task_info, mini_batch, device): - """ + """ Prepare mini-batch for edge classification and regression tasks. + + The input is a mini-batch sampled by an edge sampler. + The output is a prepared input following the + input arguments of GSgnnEdgeModelInterface.forward. + + Parameters + ---------- + data: GSgnnData + Graph data + task_info: TaskInfo + Task meta information + mini_batch: tuple + Mini-batch info + device: torch.device + Device + + Return + ------ + tuple: mini-batch """ input_nodes, batch_graph, blocks = mini_batch if not isinstance(input_nodes, dict): @@ -70,6 +108,7 @@ def prepare_edge_mini_batch(data, task_info, mini_batch, device): node_feats = data.get_node_feats(input_nodes, nfeat_fields, device) if task_info.dataloader.decoder_edge_feat_fields is not None: + # There are edge features used in decoder. input_edges = {etype: batch_graph.edges[etype].data[dgl.EID] \ for etype in batch_graph.canonical_etypes} edge_decoder_feats = \ @@ -84,11 +123,10 @@ def prepare_edge_mini_batch(data, task_info, mini_batch, device): # retrieving seed edge id from the graph to find labels assert len(batch_graph.etypes) == 1 target_etype = batch_graph.canonical_etypes[0] - # TODO(zhengda) the data loader should return labels directly.
seeds = batch_graph.edges[target_etype].data[dgl.EID] - label_field = task_info.dataloader.label_field lbl = data.get_edge_feats({target_etype: seeds}, label_field, device) + blocks = [block.to(device) for block in blocks] \ if blocks is not None else None batch_graph = batch_graph.to(device) @@ -100,6 +138,27 @@ def prepare_edge_mini_batch(data, task_info, mini_batch, device): edge_decoder_feats, lbl, input_nodes) def prepare_link_predict_mini_batch(data, task_info, mini_batch, device): + """ Prepare mini-batch for link prediction tasks. + + The input is a mini-batch sampled by an edge sampler. + The output is a prepared input following the + input arguments of GSgnnLinkPredictionModelInterface.forward. + + Parameters + ---------- + data: GSgnnData + Graph data + task_info: TaskInfo + Task meta information + mini_batch: tuple + Mini-batch info + device: torch.device + Device + + Return + ------ + tuple: mini-batch + """ input_nodes, pos_graph, neg_graph, blocks = mini_batch if not isinstance(input_nodes, dict): @@ -123,6 +182,7 @@ def prepare_link_predict_mini_batch(data, task_info, mini_batch, device): blocks = [blk.to(device) for blk in blocks] \ if blocks is not None else None + # follow the interface of GSgnnLinkPredictionModelInterface.forward return (blocks, pos_graph, neg_graph, node_feats, None, \ pos_graph_feats, None, input_nodes) @@ -146,7 +206,8 @@ class GSgnnMultiTaskLearningTrainer(GSgnnTrainer): """ def __init__(self, model, topk_model_to_save=1): super(GSgnnMultiTaskLearningTrainer, self).__init__(model, topk_model_to_save) - assert isinstance(model, GSgnnMultiTaskModelInterface) and isinstance(model, GSgnnModelBase), \ + assert isinstance(model, GSgnnMultiTaskModelInterface) \ + and isinstance(model, GSgnnModelBase), \ "The input model is not a GSgnnModel model. Please implement GSgnnModelBase."
def _prepare_mini_batch(self, data, task_info, mini_batch, device): @@ -187,8 +248,9 @@ def _prepare_mini_batch(self, data, task_info, mini_batch, device): mini_batch, device) else: - raise TypeError("Unknown task %s", task_info) + raise TypeError(f"Unknown task {task_info}", ) + # pylint: disable=unused-argument def fit(self, train_loader, num_epochs, val_loader=None, @@ -227,6 +289,8 @@ def fit(self, train_loader, The number of iteration to train the model before saving the model. save_perf_results_path : str The path of the file where the performance results are saved. + TODO(xiangsx): Add support for saving performance results on disk. + Reserved for future. freeze_input_layer_epochs: int Freeze the input layer for N epochs. This is commonly used when the input layer contains language models. From fc45ea215961d2030436381e5d3e07cc996240f8 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 27 May 2024 00:14:57 -0700 Subject: [PATCH 72/79] Fix lint --- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 52 ++++++++++++++-------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index df7f0532b3..d264297474 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -68,7 +68,9 @@ def create_task_train_dataloader(task, config, train_data): logging.info("Create dataloader for %s", task.task_id) if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: - train_idxs = train_data.get_node_train_set(task_config.target_ntype, mask=task_config.train_mask) + train_idxs = train_data.get_node_train_set( + task_config.target_ntype, + mask=task_config.train_mask) # TODO(xiangsx): Support construct feat return GSgnnNodeDataLoader(train_data, train_idxs, @@ -78,7 +80,9 @@ def create_task_train_dataloader(task, config, train_data): node_feats=node_feats, label_field=task_config.label_field) elif task.task_type 
in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - train_idxs = train_data.get_edge_train_set(task_config.target_etype, mask=task_config.train_mask) + train_idxs = train_data.get_edge_train_set( + task_config.target_etype, + mask=task_config.train_mask) # TODO(xiangsx): Support construct feat return GSgnnEdgeDataLoader(train_data, train_idxs, @@ -91,7 +95,9 @@ def create_task_train_dataloader(task, config, train_data): reverse_edge_types_map=task_config.reverse_edge_types_map, remove_target_edge_type=task_config.remove_target_edge_type) elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: - train_idxs = train_data.get_edge_train_set(task_config.train_etype, mask=task_config.train_mask) + train_idxs = train_data.get_edge_train_set( + task_config.train_etype, + mask=task_config.train_mask) dataloader_cls = gs.get_builtin_lp_train_dataloader_class(task_config) return dataloader_cls(train_data, train_idxs, @@ -226,7 +232,9 @@ def create_task_test_dataloader(task, config, train_data): label_field=task_config.label_field) elif task.task_type in [BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION]: - test_idxs = train_data.get_edge_test_set(task_config.target_etype, mask=task_config.test_mask) + test_idxs = train_data.get_edge_test_set( + task_config.target_etype, + mask=task_config.test_mask) # All tasks share the same GNN model, so the fanout should be the global fanout fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] if len(test_idxs) > 0: @@ -314,18 +322,20 @@ def create_evaluator(task): assert len(config.eval_metric) == 1, \ "GraphStorm doees not support computing multiple metrics at the same time." 
if config.report_eval_per_type: - return GSgnnPerEtypeMrrLPEvaluator(eval_frequency=config.eval_frequency, - major_etype=config.model_select_etype, - use_early_stop=config.use_early_stop, - early_stop_burnin_rounds=config.early_stop_burnin_rounds, - early_stop_rounds=config.early_stop_rounds, - early_stop_strategy=config.early_stop_strategy) + return GSgnnPerEtypeMrrLPEvaluator( + eval_frequency=config.eval_frequency, + major_etype=config.model_select_etype, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy) else: - return GSgnnMrrLPEvaluator(eval_frequency=config.eval_frequency, - use_early_stop=config.use_early_stop, - early_stop_burnin_rounds=config.early_stop_burnin_rounds, - early_stop_rounds=config.early_stop_rounds, - early_stop_strategy=config.early_stop_strategy) + return GSgnnMrrLPEvaluator( + eval_frequency=config.eval_frequency, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy) return None def main(config_args): @@ -362,7 +372,10 @@ def main(config_args): train_dataloaders.append(train_loader) val_dataloaders.append(val_loader) test_dataloaders.append(test_loader) - decoder, loss_func = gs.create_task_decoder(task, train_data.g, encoder_out_dims, train_task=True) + decoder, loss_func = gs.create_task_decoder(task, + train_data.g, + encoder_out_dims, + train_task=True) model.add_task(task.task_id, task.task_type, decoder, loss_func) if not config.no_validation: if val_loader is None: @@ -426,7 +439,10 @@ def main(config_args): gs.gsf.set_encoder(model, train_data.g, config, train_task=True) for task in tasks: - decoder, loss_func = gs.create_task_decoder(task, train_data.g, encoder_out_dims, train_task=True) + decoder, loss_func = gs.create_task_decoder(task, + 
train_data.g, + encoder_out_dims, + train_task=True) model.add_task(task.task_id, task.task_type, decoder, loss_func) best_model_path = trainer.get_best_model_path() # TODO(zhengda) the model path has to be in a shared filesystem. @@ -459,5 +475,3 @@ def generate_parser(): # Ignore unknown args to make script more robust to input arguments gs_args, _ = arg_parser.parse_known_args() main(gs_args) - - From 7f9947c02a98bff42190f17783ecf3f0fe146175 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 27 May 2024 13:24:45 -0700 Subject: [PATCH 73/79] Update --- tests/unit-tests/test_config.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/unit-tests/test_config.py b/tests/unit-tests/test_config.py index 0b354d2440..cdbbbfb2da 100644 --- a/tests/unit-tests/test_config.py +++ b/tests/unit-tests/test_config.py @@ -1718,7 +1718,6 @@ def test_multi_task_config(): assert nc_config.task_id == f"{BUILTIN_TASK_NODE_CLASSIFICATION}-a-label_c" nc_config = nc_config.task_config assert nc_config.task_weight == 1 - assert len(nc_config.mask_fields) == 3 assert nc_config.train_mask == "class_train_mask" assert nc_config.val_mask == "class_eval_mask" assert nc_config.test_mask == "class_test_mask" @@ -1739,7 +1738,6 @@ def test_multi_task_config(): assert nr_config.task_id == f"{BUILTIN_TASK_NODE_REGRESSION}-a-label_r" nr_config = nr_config.task_config assert nr_config.task_weight == 0.5 - assert len(nr_config.mask_fields) == 3 assert nr_config.train_mask == "reg_train_mask" assert nr_config.val_mask == "reg_eval_mask" assert nr_config.test_mask == "reg_test_mask" @@ -1754,7 +1752,6 @@ def test_multi_task_config(): assert ec_config.task_id == f"{BUILTIN_TASK_EDGE_CLASSIFICATION}-query_match_asin-label_ec" ec_config = ec_config.task_config assert ec_config.task_weight == 1 - assert len(ec_config.mask_fields) == 3 assert ec_config.train_mask == "ec_train_mask" assert ec_config.val_mask == "ec_eval_mask" assert ec_config.test_mask == "ec_test_mask" @@ -1777,7 +1774,6 @@ def 
test_multi_task_config(): assert er_config.task_id == f"{BUILTIN_TASK_EDGE_REGRESSION}-query_match-2_asin-label_er" er_config = er_config.task_config assert er_config.task_weight == 1 - assert len(er_config.mask_fields) == 3 assert er_config.train_mask == "er_train_mask" assert er_config.val_mask == "er_eval_mask" assert er_config.test_mask == "er_test_mask" @@ -1797,7 +1793,6 @@ def test_multi_task_config(): assert lp_config.task_id == f"{BUILTIN_TASK_LINK_PREDICTION}-query_exactmatch_asin" lp_config = lp_config.task_config assert lp_config.task_weight == 1 - assert len(lp_config.mask_fields) == 3 assert lp_config.train_mask == "lp_train_mask" assert lp_config.val_mask == "lp_eval_mask" assert lp_config.test_mask == "lp_test_mask" @@ -1826,7 +1821,6 @@ def test_multi_task_config(): assert lp_config.task_id == f"{BUILTIN_TASK_LINK_PREDICTION}-ALL_ETYPE" lp_config = lp_config.task_config assert lp_config.task_weight == 2 - assert len(lp_config.mask_fields) == 3 assert lp_config.train_mask == "lp2_train_mask" assert lp_config.val_mask == "lp2_eval_mask" assert lp_config.test_mask == "lp2_test_mask" From 2da6621515622d57c14b3d3c4237e61effecc636 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Mon, 27 May 2024 22:47:01 -0700 Subject: [PATCH 74/79] update --- tests/end2end-tests/graphstorm-mt/mgpu_test.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh index 96cf29569f..0298465592 100644 --- a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh @@ -9,7 +9,7 @@ NUM_INFO_TRAINERS=2 export PYTHONPATH=$GS_HOME/python/ cd $GS_HOME/training_scripts/gsgnn_mt echo "127.0.0.1" > ip_list.txt -cd $GS_HOME/inference_scripts/lp_infer +cd $GS_HOME/inference_scripts/gsgnn_mt echo "127.0.0.1" > ip_list.txt error_and_exit () { @@ -27,7 +27,7 @@ df /dev/shm -h echo "**************[Multi-task] dataset: Movielens, RGCN 
layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model" -python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True +python3 -m graphstorm.run.gs_multi_task_learning --workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True error_and_exit $? @@ -153,7 +153,7 @@ then fi echo "**************[Multi-task with learnable embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model" -python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-node-embeddings True +python3 -m graphstorm.run.gs_multi_task_learning --workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path 
/data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-node-embeddings True error_and_exit $? @@ -161,7 +161,7 @@ rm /tmp/train_log.txt rm -frm /data/gsgnn_mt/ echo "**************[Multi-task] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model" -python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-mini-batch-infer False --save-embed-path /data/gsgnn_mt/emb/ +python3 -m graphstorm.run.gs_multi_task_learning --workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-mini-batch-infer False --save-embed-path /data/gsgnn_mt/emb/ error_and_exit $? 
@@ -180,7 +180,7 @@ then fi echo "**************[Multi-task] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: mini-batch load from saved model" -python3 -m graphstorm.run.gs_multi_task_learning -workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --restore-model-path /data/gsgnn_mt/epoch-2/ --save-model-path /data/gsgnn_mt_2/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug +python3 -m graphstorm.run.gs_multi_task_learning --workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --restore-model-path /data/gsgnn_mt/epoch-2/ --save-model-path /data/gsgnn_mt_2/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug error_and_exit $? 
From 28181765c5baabd534de638150b65264cf4dc22a Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Tue, 28 May 2024 15:15:47 -0700 Subject: [PATCH 75/79] Update --- python/graphstorm/run/gs_multi_task_learning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/graphstorm/run/gs_multi_task_learning.py b/python/graphstorm/run/gs_multi_task_learning.py index 93ed3c6edd..dd58526a79 100644 --- a/python/graphstorm/run/gs_multi_task_learning.py +++ b/python/graphstorm/run/gs_multi_task_learning.py @@ -34,7 +34,7 @@ def main(): lib_dir = os.path.abspath(os.path.dirname(__file__)) if args.inference: - cmd = "gsgnn_mt/gsgnn_infer_mt.py" + assert False, "Not implemented" else: cmd = "gsgnn_mt/gsgnn_mt.py" cmd_path = os.path.join(lib_dir, cmd) From 0340223d105c8be706887ec7cc7f7d0abb9f5166 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 29 May 2024 00:34:46 -0700 Subject: [PATCH 76/79] Update --- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index d264297474..d8e820b2ba 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -66,6 +66,9 @@ def create_task_train_dataloader(task, config, train_data): # All tasks share the same input encoder, so the node feats must be same. 
node_feats = config.node_feat_name + assert task_config.train_mask is not None, \ + "For multi-task learning, train_mask field name " \ + "must be provided through mask_fields, but get None" logging.info("Create dataloader for %s", task.task_id) if task.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: train_idxs = train_data.get_node_train_set( @@ -252,13 +255,15 @@ def create_task_test_dataloader(task, config, train_data): elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]: test_idxs = train_data.get_edge_test_set(task_config.eval_etype, mask=task_config.val_mask) dataloader_cls = gs.get_builtin_lp_eval_dataloader_class(task_config) + # All tasks share the same GNN model, so the fanout should be the global fanout + fanout = config.eval_fanout if task_config.use_mini_batch_infer else [] if len(test_idxs) > 0: # TODO(xiangsx): Support construct feat if task_config.eval_etypes_negative_dstnode is not None: return dataloader_cls(train_data, test_idxs, task_config.eval_batch_size, fixed_edge_dst_negative_field=task_config.eval_etypes_negative_dstnode, - fanout=task_config.eval_fanout, + fanout=fanout, fixed_test_size=task_config.fixed_test_size, node_feats=node_feats, pos_graph_edge_feats=task_config.lp_edge_weight_for_loss) @@ -266,7 +271,7 @@ def create_task_test_dataloader(task, config, train_data): return dataloader_cls(train_data, test_idxs, task_config.eval_batch_size, task_config.num_negative_edges_eval, - task_config.eval_fanout, + fanout=fanout, fixed_test_size=task_config.fixed_test_size, node_feats=node_feats, pos_graph_edge_feats=task_config.lp_edge_weight_for_loss) From 6442b60de4398a5cc2dc450c59dbc6d6925f6d71 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 29 May 2024 00:35:30 -0700 Subject: [PATCH 77/79] Update --- python/graphstorm/config/argument.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py 
index f5cc2c569f..3203220d84 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -263,18 +263,16 @@ def _parse_general_task_config(self, task_config): task_config: dict Task config """ - assert "mask_fields" in task_config, \ - "mask_fields should be provided for each node classification task " \ - "in multi task learning" - assert "task_weight" in task_config, \ - "task_weight should be provided for each node classification task " \ - "in multi task learning" - - mask_fields = task_config["mask_fields"] - assert len(mask_fields) == 3, \ - "The mask_fileds should be a list as [train-mask, validation-mask, test-mask], " \ - f"but get {mask_fields}" - task_weight = task_config["task_weight"] + if "mask_fields" in task_config: + mask_fields = task_config["mask_fields"] + assert len(mask_fields) == 3, \ + "The mask_fileds should be a list as [train-mask, validation-mask, test-mask], " \ + f"but get {mask_fields}" + else: + mask_fields = (None, None, None) + + task_weight = task_config["task_weight"] \ + if "task_weight" in task_config else 1.0 assert task_weight > 0, f"task_weight should be larger than 0, but get {task_weight}" batch_size = self.batch_size \ From 0e7c6602675494303ecaa7b749709520474a051e Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 29 May 2024 19:10:21 -0700 Subject: [PATCH 78/79] Update --- python/graphstorm/gsf.py | 20 ++++++++++---------- python/graphstorm/trainer/mt_trainer.py | 19 ++++++++++++++----- python/graphstorm/trainer/np_trainer.py | 5 ++++- 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/python/graphstorm/gsf.py b/python/graphstorm/gsf.py index 91924437c6..9b9d0136e6 100644 --- a/python/graphstorm/gsf.py +++ b/python/graphstorm/gsf.py @@ -28,11 +28,11 @@ from .utils import sys_tracker, get_rank from .utils import setup_device -from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, - BUILTIN_TASK_NODE_REGRESSION, - BUILTIN_TASK_EDGE_CLASSIFICATION, - 
BUILTIN_TASK_EDGE_REGRESSION, - BUILTIN_TASK_LINK_PREDICTION) +from .config import (BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION, + BUILTIN_TASK_LINK_PREDICTION) from .config import BUILTIN_LP_DOT_DECODER from .config import BUILTIN_LP_DISTMULT_DECODER from .config import (BUILTIN_LP_LOSS_CROSS_ENTROPY, @@ -845,7 +845,7 @@ def get_builtin_lp_train_dataloader_class(config): return dataloader_cls def create_task_decoder(task_info, g, decoder_input_dim, train_task): - """ Create task decoders according to task_info. + """ Create a task decoder according to task_info. Parameters ---------- @@ -860,8 +860,8 @@ def create_task_decoder(task_info, g, decoder_input_dim, train_task): Return ------ - decoder: The node task decoder(s) - loss_func: The loss function(s) + decoder: The task decoder + loss_func: The loss function """ if task_info.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: return create_builtin_node_decoder(g, decoder_input_dim, task_info.task_config, train_task) @@ -869,5 +869,5 @@ def create_task_decoder(task_info, g, decoder_input_dim, train_task): return create_builtin_edge_decoder(g, decoder_input_dim, task_info.task_config, train_task) elif task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]: return create_builtin_lp_decoder(g, decoder_input_dim, task_info.task_config, train_task) - - return None, None + else: + raise TypeError(f"Unknown task type {task_info.task_type}") diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index dd7282176f..0553f0fa01 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -63,7 +63,10 @@ def prepare_node_mini_batch(data, task_info, mini_batch, device): g = data.g input_nodes, seeds, blocks = mini_batch if not isinstance(input_nodes, dict): - assert len(g.ntypes) == 1 + # This happens on a homogeneous graph. 
+ assert len(g.ntypes) == 1, \ + "The graph should be a homogeneous graph, " \ + f"but it has multiple node types {g.ntypes}" input_nodes = {g.ntypes[0]: input_nodes} nfeat_fields = task_info.dataloader.node_feat_fields @@ -101,7 +104,9 @@ def prepare_edge_mini_batch(data, task_info, mini_batch, device): """ input_nodes, batch_graph, blocks = mini_batch if not isinstance(input_nodes, dict): - assert len(batch_graph.ntypes) == 1 + assert len(batch_graph.ntypes) == 1, \ + "The graph should be a homogeneous graph, " \ + f"but it has multiple node types {batch_graph.ntypes}" input_nodes = {batch_graph.ntypes[0]: input_nodes} nfeat_fields = task_info.dataloader.node_feat_fields @@ -121,7 +126,9 @@ def prepare_edge_mini_batch(data, task_info, mini_batch, device): edge_decoder_feats = None # retrieving seed edge id from the graph to find labels - assert len(batch_graph.etypes) == 1 + assert len(batch_graph.etypes) == 1, \ + "Edge classification/regression tasks only support " \ + "conducting prediction on one edge type." target_etype = batch_graph.canonical_etypes[0] seeds = batch_graph.edges[target_etype].data[dgl.EID] label_field = task_info.dataloader.label_field @@ -162,7 +169,9 @@ def prepare_link_predict_mini_batch(data, task_info, mini_batch, device): input_nodes, pos_graph, neg_graph, blocks = mini_batch if not isinstance(input_nodes, dict): - assert len(pos_graph.ntypes) == 1 + assert len(pos_graph.ntypes) == 1, \ + "The graph should be a homogeneous graph, " \ + f"but it has multiple node types {pos_graph.ntypes}" input_nodes = {pos_graph.ntypes[0]: input_nodes} nfeat_fields = task_info.dataloader.node_feat_fields @@ -189,7 +198,7 @@ def prepare_link_predict_mini_batch(data, task_info, mini_batch, device): class GSgnnMultiTaskLearningTrainer(GSgnnTrainer): r""" A trainer for multi-task learning - This class is used to train models for multi task learning. + This class is used to train models for multi-task learning. 
It makes use of the functions provided by `GSgnnTrainer` to define two main functions: `fit` that performs the training diff --git a/python/graphstorm/trainer/np_trainer.py b/python/graphstorm/trainer/np_trainer.py index d875f8f7da..d4a264fc10 100644 --- a/python/graphstorm/trainer/np_trainer.py +++ b/python/graphstorm/trainer/np_trainer.py @@ -170,7 +170,10 @@ def fit(self, train_loader, num_epochs, total_steps += 1 if not isinstance(input_nodes, dict): - assert len(g.ntypes) == 1 + # This happens on a homogeneous graph. + assert len(g.ntypes) == 1, \ + "The graph should be a homogeneous graph, " \ + f"but it has multiple node types {g.ntypes}" input_nodes = {g.ntypes[0]: input_nodes} nfeat_fields = train_loader.node_feat_fields label_field = train_loader.label_field From 8ffeea6c8b0f2a97239531d05e7ca722d61379b1 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 29 May 2024 23:02:53 -0700 Subject: [PATCH 79/79] update --- python/graphstorm/trainer/mt_trainer.py | 8 +++++--- tests/end2end-tests/graphstorm-mt/mgpu_test.sh | 4 ++-- tests/unit-tests/test_trainer.py | 1 - 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index 0553f0fa01..b6063e3b6b 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -217,7 +217,9 @@ def __init__(self, model, topk_model_to_save=1): super(GSgnnMultiTaskLearningTrainer, self).__init__(model, topk_model_to_save) assert isinstance(model, GSgnnMultiTaskModelInterface) \ and isinstance(model, GSgnnModelBase), \ - "The input model is not a GSgnnModel model. Please implement GSgnnModelBase." + "The input model is not a GSgnnModel model "\ + "or not implement the GSgnnMultiTaskModelInterface." \ + "Please implement GSgnnModelBase." 
def _prepare_mini_batch(self, data, task_info, mini_batch, device): """ prepare mini batch for a single task @@ -386,8 +388,8 @@ def fit(self, train_loader, data, val_loader, test_loader, total_steps) # TODO(xiangsx): Add early stop support - # Every n iterations, check to save the top k models. Will save - # the last k model or all models depends on the setting of top k + # Every n iterations, save the model and keep + # the lask k models. # TODO(xiangsx): support saving the best top k model. if save_model_frequency > 0 and \ total_steps % save_model_frequency == 0 and \ diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh index 0298465592..f809a63df1 100644 --- a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh @@ -148,7 +148,7 @@ fi cnt=$(ls -l /data/gsgnn_mt/ | grep epoch | wc -l) if test $cnt != 3 then - echo "The number of save models $cnt is not equal to the specified topk 1" + echo "The number of save models $cnt is not equal to the specified topk 3" exit -1 fi @@ -187,6 +187,6 @@ error_and_exit $? cnt=$(ls -l /data/gsgnn_mt_2/ | grep epoch | wc -l) if test $cnt != 3 then - echo "The number of save models $cnt is not equal to the specified topk 1" + echo "The number of save models $cnt is not equal to the specified topk 3" exit -1 fi diff --git a/tests/unit-tests/test_trainer.py b/tests/unit-tests/test_trainer.py index 1ce6d02cd9..d6191c5fb0 100644 --- a/tests/unit-tests/test_trainer.py +++ b/tests/unit-tests/test_trainer.py @@ -268,7 +268,6 @@ def test_mtask_prepare_edge_mini_batch(): } node_feats = ep_data.get_node_feats(input_node_idx, 'feat') labels = ep_data.get_edge_feats(target_idx, 'label') - print(g.edges[('n0', 'r1', 'n1')]) batch_graph = dgl.heterograph( {('n0', 'r1', 'n1'): (th.randint(g.number_of_nodes("n0"), (g.number_of_edges('r1'),)), th.randint(g.number_of_nodes("n1"), (g.number_of_edges('r1'),)))}