diff --git a/docs/source/advanced/own-models.rst b/docs/source/advanced/own-models.rst index ebafdfafa2..26c6b45cc8 100644 --- a/docs/source/advanced/own-models.rst +++ b/docs/source/advanced/own-models.rst @@ -263,13 +263,13 @@ The GraphStorm trainers can have evaluators and task trackers associated. The fo .. code-block:: python # Optional: set up a evaluator - evaluator = GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop, - config.early_stop_burnin_rounds, - config.early_stop_rounds, - config.early_stop_strategy) + evaluator = GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + config.multilabel, + config.use_early_stop, + config.early_stop_burnin_rounds, + config.early_stop_rounds, + config.early_stop_strategy) trainer.setup_evaluator(evaluator) # Optional: set up a task tracker to show the progress of training. tracker = GSSageMakerTaskTracker(config.eval_frequency) diff --git a/docs/source/api/graphstorm.eval.rst b/docs/source/api/graphstorm.eval.rst index a8193f0f2c..92f5b73162 100644 --- a/docs/source/api/graphstorm.eval.rst +++ b/docs/source/api/graphstorm.eval.rst @@ -7,8 +7,10 @@ graphstorm.eval Learning (GML) tasks. If users want to implement customized evaluators or evaluation methods, a best practice is to - extend base evaluators, i.e., the ``GSgnnInstanceEvaluator`` class for node or edge prediction - tasks, and ``GSgnnLPEvaluator`` for link prediction tasks, and then implement the abstract methods. + extend the base evaluator, i.e., the ``GSgnnBaseEvaluator``, and the corresponding evaluation + interfaces, e.g., ``GSgnnPredictionEvalInterface``` for prediction evaluation, and + ``GSgnnLPRankingEvalInterface`` for ranking based link prediction evaluation, and then + implement the abstract methods defined in those interface classes. .. currentmodule:: graphstorm.eval @@ -20,8 +22,9 @@ Base Evaluators :nosignatures: :template: evaltemplate.rst - GSgnnInstanceEvaluator - GSgnnLPEvaluator + GSgnnBaseEvaluator + GSgnnPredictionEvalInterface + GSgnnLPRankingEvalInterface Evaluators ----------- @@ -31,8 +34,7 @@ Evaluators :nosignatures: :template: evaltemplate.rst - GSgnnLPEvaluator + GSgnnClassificationEvaluator + GSgnnRegressionEvaluator GSgnnMrrLPEvaluator GSgnnPerEtypeMrrLPEvaluator - GSgnnAccEvaluator - GSgnnRegressionEvaluator diff --git a/examples/customized_models/HGT/hgt_nc.py b/examples/customized_models/HGT/hgt_nc.py index d00f469285..c90a43b07a 100644 --- a/examples/customized_models/HGT/hgt_nc.py +++ b/examples/customized_models/HGT/hgt_nc.py @@ -14,7 +14,7 @@ from graphstorm.inference import GSgnnNodePredictionInferrer from graphstorm.dataloading import GSgnnNodeTrainData, GSgnnNodeInferData from graphstorm.dataloading import GSgnnNodeDataLoader -from graphstorm.eval import GSgnnAccEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator from graphstorm.tracker import GSSageMakerTaskTracker from graphstorm.utils import get_device @@ -326,13 +326,13 @@ def main(args): train_task=False) # Optional: set up a evaluator - evaluator = GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop, - config.early_stop_burnin_rounds, - config.early_stop_rounds, - config.early_stop_strategy) + evaluator = GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + config.multilabel, + config.use_early_stop, + config.early_stop_burnin_rounds, + config.early_stop_rounds, + config.early_stop_strategy) trainer.setup_evaluator(evaluator) # Optional: set up a task tracker to show the progress of training. tracker = GSSageMakerTaskTracker(config.eval_frequency) diff --git a/examples/peft_llm_gnn/main_lp.py b/examples/peft_llm_gnn/main_lp.py index afafa5378b..14d89a3def 100644 --- a/examples/peft_llm_gnn/main_lp.py +++ b/examples/peft_llm_gnn/main_lp.py @@ -54,14 +54,12 @@ def main(config_args): trainer.setup_device(device=get_device()) # set evaluator - evaluator = GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - config.num_negative_edges_eval, - config.lp_decoder_type, - config.use_early_stop, - config.early_stop_burnin_rounds, - config.early_stop_rounds, - config.early_stop_strategy + evaluator = GSgnnMrrLPEvaluator( + eval_frequency=config.eval_frequency, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy ) # disbale validation for efficiency # trainer.setup_evaluator(evaluator) diff --git a/examples/peft_llm_gnn/main_nc.py b/examples/peft_llm_gnn/main_nc.py index 5325160556..b5ebc67349 100644 --- a/examples/peft_llm_gnn/main_nc.py +++ b/examples/peft_llm_gnn/main_nc.py @@ -3,7 +3,7 @@ from graphstorm.config import get_argument_parser from graphstorm.config import GSConfig from graphstorm.dataloading import GSgnnNodeDataLoader -from graphstorm.eval import GSgnnAccEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator from graphstorm.dataloading import GSgnnNodeTrainData from graphstorm.utils import get_device from graphstorm.inference import GSgnnNodePredictionInferrer @@ -52,7 +52,7 @@ def main(config_args): trainer.setup_device(device=get_device()) # set evaluator - evaluator = GSgnnAccEvaluator( + evaluator = GSgnnClassificationEvaluator( config.eval_frequency, config.eval_metric, config.multilabel, diff --git a/examples/standalone_mode_demo.ipynb b/examples/standalone_mode_demo.ipynb index e5a8f09c6e..a34213418e 100644 --- a/examples/standalone_mode_demo.ipynb +++ b/examples/standalone_mode_demo.ipynb @@ -39,7 +39,7 @@ "from graphstorm.dataloading import GSgnnNodeTrainData, GSgnnNodeDataLoader, GSgnnNodeInferData\n", "from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer, EntityClassifier, ClassifyLossFunc, RelationalGCNEncoder\n", "from graphstorm.inference import GSgnnNodePredictionInferrer\n", - "from graphstorm.eval import GSgnnAccEvaluator" + "from graphstorm.eval import GSgnnClassificationEvaluator" ] }, { @@ -315,9 +315,8 @@ "trainer.setup_device(device=device)\n", "\n", "# set up evaluator for the trainer:\n", - "evaluator = GSgnnAccEvaluator(\n", + "evaluator = GSgnnClassificationEvaluator(\n", " eval_frequency=10000,\n", - " eval_metric=['accuracy'],\n", " multilabel=multilabel)\n", "\n", "trainer.setup_evaluator(evaluator)" diff --git a/examples/temporal_graph_learning/main_nc.py b/examples/temporal_graph_learning/main_nc.py index e6eb4405d3..b7ebf9934f 100644 --- a/examples/temporal_graph_learning/main_nc.py +++ b/examples/temporal_graph_learning/main_nc.py @@ -3,7 +3,7 @@ from graphstorm.config import get_argument_parser from graphstorm.config import GSConfig from graphstorm.dataloading import GSgnnNodeDataLoader -from graphstorm.eval import GSgnnAccEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator from graphstorm.dataloading import GSgnnNodeTrainData from graphstorm.utils import get_device from graphstorm.trainer import GSgnnNodePredictionTrainer @@ -45,7 +45,7 @@ def main(config_args): trainer.setup_device(device=get_device()) # set evaluator - evaluator = GSgnnAccEvaluator( + evaluator = GSgnnClassificationEvaluator( config.eval_frequency, config.eval_metric, config.multilabel, diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 1b6fe906ae..9e2cba7420 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -352,7 +352,7 @@ def verify_arguments(self, is_train): _ = self.log_report_frequency _ = self.task_type - # For classification tasks. + # For classification/regression tasks. if self.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_EDGE_CLASSIFICATION]: _ = self.label_field _ = self.num_classes @@ -368,6 +368,7 @@ def verify_arguments(self, is_train): BUILTIN_TASK_LINK_PREDICTION] and is_train: _ = self.exclude_training_targets _ = self.reverse_edge_types_map + # For link prediction tasks. if self.task_type == BUILTIN_TASK_LINK_PREDICTION: _ = self.gamma _ = self.lp_decoder_type diff --git a/python/graphstorm/eval/__init__.py b/python/graphstorm/eval/__init__.py index f07ea9339e..6860effa89 100644 --- a/python/graphstorm/eval/__init__.py +++ b/python/graphstorm/eval/__init__.py @@ -23,8 +23,7 @@ from .eval_func import SUPPORTED_REGRESSION_METRICS from .eval_func import SUPPORTED_LINK_PREDICTION_METRICS -from .evaluator import GSgnnInstanceEvaluator -from .evaluator import GSgnnLPEvaluator -from .evaluator import GSgnnMrrLPEvaluator, GSgnnPerEtypeMrrLPEvaluator -from .evaluator import GSgnnAccEvaluator +from .evaluator import GSgnnMrrLPEvaluator +from .evaluator import GSgnnPerEtypeMrrLPEvaluator +from .evaluator import GSgnnClassificationEvaluator from .evaluator import GSgnnRegressionEvaluator diff --git a/python/graphstorm/eval/eval_func.py b/python/graphstorm/eval/eval_func.py index 2439e9e09e..1ca530cb9c 100644 --- a/python/graphstorm/eval/eval_func.py +++ b/python/graphstorm/eval/eval_func.py @@ -103,6 +103,12 @@ def __init__(self): self.metric_function["mse"] = compute_mse self.metric_function["mae"] = compute_mae + # This is the operator used to measure each metric performance in evaluation + self.metric_eval_function = {} + self.metric_eval_function["rmse"] = compute_rmse + self.metric_eval_function["mse"] = compute_mse + self.metric_eval_function["mae"] = compute_mae + def assert_supported_metric(self, metric): """ check if the given metric is supported. """ @@ -135,6 +141,14 @@ def __init__(self): self.metric_comparator = {} self.metric_comparator["mrr"] = operator.le + # This is the operator used to measure each metric performance + self.metric_function = {} + self.metric_function["mrr"] = compute_mrr + + # This is the operator used to measure each metric performance in evaluation + self.metric_eval_function = {} + self.metric_eval_function["mrr"] = compute_mrr + def assert_supported_metric(self, metric): """ check if the given metric is supported. """ @@ -583,3 +597,19 @@ def compute_mae(pred, labels): diff = th.abs(pred.cpu() - labels.cpu()) return th.mean(diff).cpu().item() + +def compute_mrr(ranking): + """ Get link prediction mrr metrics + + Parameters + ---------- + ranking: + ranking of each positive edge + + Returns + ------- + link prediction mrr metrics: tensor + """ + logs = th.div(1.0, ranking) + metrics = th.tensor(th.div(th.sum(logs),len(logs))) + return metrics diff --git a/python/graphstorm/eval/evaluator.py b/python/graphstorm/eval/evaluator.py index 4ec2fbdd0d..07509b239b 100644 --- a/python/graphstorm/eval/evaluator.py +++ b/python/graphstorm/eval/evaluator.py @@ -15,7 +15,7 @@ Evaluator for different tasks. """ -import logging + import abc from statistics import mean import torch as th @@ -26,7 +26,7 @@ EARLY_STOP_CONSECUTIVE_INCREASE_STRATEGY, LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL) from ..utils import get_rank, get_world_size, barrier -from .utils import gen_mrr_score + def early_stop_avg_increase_judge(val_score, val_perf_list, comparator): """ @@ -82,7 +82,7 @@ def get_val_score_rank(val_score, val_perf_rank_list, comparator): Here use the most naive method, i.e., scan the entire list once to get the rank. For the same value, will treat the given validation score as the next rank. For example, in a - list [1., 1., 2., 2., 3., 4.], the given value 2. will be ranked to the 5th highest score. + list [1., 1., 2., 2., 3., 4.], the given value 2 will be ranked to the 5th highest score. Later on if need to increase the speed, could use more complex data structure, e.g. LinkedList @@ -108,15 +108,138 @@ def get_val_score_rank(val_score, val_perf_rank_list, comparator): return rank -# TODO(xiangsx): combine GSgnnInstanceEvaluator and GSgnnLPEvaluator -class GSgnnInstanceEvaluator(): - """ Template class for user defined evaluator. +class GSgnnPredictionEvalInterface(): + """ Interface for prediction evaluation functions + + The interface set the two abstract methods for prediction classes, i.e., Classification + and Regression, which share the same input arguments. + """ + + @abc.abstractmethod + def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters): + """Evaluate validation and test sets for Prediciton tasks + + GSgnnTrainers will call this function to do evalution in their eval() fuction. + + Classification and regression evaluators should provide both predictions and labels in + validation and test sets. + + Parameters + ---------- + val_pred : tensor + The tensor stores the prediction results on the validation nodes. + test_pred : tensor + The tensor stores the prediction results on the test nodes. + val_labels : tensor + The tensor stores the labels of the validation nodes. + test_labels : tensor + The tensor stores the labels of the test nodes. + total_iters: int + The current interation number. + + Returns + ----------- + eval_score: float + Validation score + test_score: float + Test score + """ + + @abc.abstractmethod + def compute_score(self, pred, labels, train=True): + """ Compute evaluation score for Prediciton tasks + + Classification and regression evaluators should provide both predictions and labels. + + Parameters + ---------- + pred: + Rediction result + labels: + Label + train: boolean + If in model training. + + Returns + ------- + Evaluation metric values: dict + """ + + +class GSgnnLPRankingEvalInterface(): + """ Interface for Link Prediction evaluation function using ranking methods + + The interface set the two abstract methods for Link Prediction classes that use ranking + method to compute evaluation metrics, such as "mrr" (Mean Reciprocal Rank). + + There are two methdos to be implemented if inherite this interface. + 1. ``evaluate()`` method, which will be called by Trainers to provide ranking-based evaluation + results of validation and test sets during training process. + 2. ``compute_score()`` method, which compute the scores for given rankings. + """ + + @abc.abstractmethod + def evaluate(self, val_rankings, test_rankings, total_iters): + """Evaluate validation and test sets for Link Prediciton tasks + + GSgnnTrainers will call this function to do evalution in their eval() fuction. + + Link Prediction evaluators should provide the ranking of validation and test sets as + input. + + Parameters + ---------- + val_rankings: dict of tensors + The rankings of validation edges for each edge type in format of {etype: ranking}. + test_rankings: dict of tensors + The rankings of testing edges for each edge type in format of {etype: ranking}. + total_iters: int + The current interation number. + + Returns + ----------- + eval_score: float + Validation score for each edge type in format of {etype: score}. + test_score: float + Test score for each edge type in format of {etype: score}. + """ + + @abc.abstractmethod + def compute_score(self, rankings, train=True): + """ Compute evaluation score for Prediciton tasks + + Ranking-based link prediction evaluators should provide ranking values as input. + + Parameters + ---------- + rankings: dict of tensors + Rankings of positive scores in format of {etype: ranking} + train: boolean + If in model training. + + Returns + ------- + Evaluation metric values: dict + scores for each edge type. + """ + + +class GSgnnBaseEvaluator(): + """ Base class for Evaluators. + + New base class in V0.3 to replace ``GSgnnInstanceEvaluator`` and ``GSgnnLPEvaluator``. This + class serves as the base for the built-in ``GSgnnClassificationEvaluator``, + ``GSgnnRegressionEvaluator``, ``GSgnnMrrLPEvaluator``, and ``GSgnnPerEtypeMrrLPEvaluator``. + + In order to create customized Evaluators, users can inherite this class and the corresponding + EvalInteface class, and then implement the two abstract methods, i.e., ``evaluate()`` + and ``compute_score()`` accordingly. Parameters ---------- eval_frequency: int The frequency (# of iterations) of doing evaluation. - eval_metric: list of string + eval_metric_list: list of string Evaluation metric used during evaluation. use_early_stop: bool Set true to use early stop. @@ -128,7 +251,7 @@ class GSgnnInstanceEvaluator(): The early stop strategy. GraphStorm supports two strategies: 1) consecutive_increase and 2) average_increase. """ - def __init__(self, eval_frequency, eval_metric, + def __init__(self, eval_frequency, eval_metric_list, use_early_stop=False, early_stop_burnin_rounds=0, early_stop_rounds=3, @@ -142,9 +265,8 @@ def __init__(self, eval_frequency, eval_metric, self._best_iter = None self.metrics_obj = None # Evaluation metrics obj - self._metric = eval_metric - assert len(self.metric) > 0, \ - "At least one metric must be defined" + self._metric_list = eval_metric_list + assert len(self.metric_list) > 0, "At least one metric must be defined, but got 0." self._eval_frequency = eval_frequency self._do_early_stop = use_early_stop if self._do_early_stop: @@ -153,7 +275,7 @@ def __init__(self, eval_frequency, eval_metric, self._early_stop_rounds = early_stop_rounds self._early_stop_strategy = early_stop_strategy self._val_perf_list = [] - # add this list to store + # add this list to store all of the performance rank of validation scores for pick top k self._val_perf_rank_list = [] def setup_task_tracker(self, task_tracker): @@ -166,32 +288,6 @@ def setup_task_tracker(self, task_tracker): """ self.tracker = task_tracker - @abc.abstractmethod - def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters): - """ - GSgnnLinkPredictionModel.fit() will call this function to do user defined evalution. - - Parameters - ---------- - val_pred : tensor - The tensor stores the prediction results on the validation nodes. - test_pred : tensor - The tensor stores the prediction results on the test nodes. - val_labels : tensor - The tensor stores the labels of the validation nodes. - test_labels : tensor - The tensor stores the labels of the test nodes. - total_iters: int - The current interation number. - - Returns - ----------- - eval_score: float - Validation score - test_score: float - Test score - """ - def do_eval(self, total_iters, epoch_end=False): """ Decide whether to do the evaluation in current iteration or epoch @@ -212,26 +308,6 @@ def do_eval(self, total_iters, epoch_end=False): return True return False - - @abc.abstractmethod - def compute_score(self, pred, labels): - """ Compute evaluation score - - Parameters - ---------- - pred: - Rediction result - labels: - Label - """ - - def print_history(self): - """ Print history eval info - """ - for val_score, test_score in self._history: - logging.info("val %s: %.3f, test %s: %.3f", - self.metric, val_score, self.metric, test_score) - def do_early_stop(self, val_score): """ Decide whether to stop the training @@ -260,10 +336,12 @@ def do_early_stop(self, val_score): # does not improve in the last N evaluation iterations if self._early_stop_strategy == EARLY_STOP_AVERAGE_INCREASE_STRATEGY: early_stop = early_stop_avg_increase_judge(val_score, - self._val_perf_list, self.get_metric_comparator()) + self._val_perf_list, + self.get_metric_comparator()) elif self._early_stop_strategy == EARLY_STOP_CONSECUTIVE_INCREASE_STRATEGY: early_stop = early_stop_cons_increase_judge(val_score, - self._val_perf_list, self.get_metric_comparator()) + self._val_perf_list, + self.get_metric_comparator()) else: return False @@ -277,9 +355,8 @@ def get_metric_comparator(self): We treat the first metric in all evaluation metrics as the major metric. """ - assert self.metrics_obj is not None, \ - "Evaluation metrics object should not be None" - metric = self.metric[0] + assert self.metrics_obj is not None, "Evaluation metrics object should not be None." + metric = self.metric_list[0] return self.metrics_obj.metric_comparator[metric] def get_val_score_rank(self, val_score): @@ -303,10 +380,10 @@ def get_val_score_rank(self, val_score): return rank @property - def metric(self): + def metric_list(self): """ evaluation metrics """ - return self._metric + return self._metric_list @property def best_val_score(self): @@ -329,12 +406,13 @@ def best_iter_num(self): @property def history(self): """ Evaluation history - + Returns ------- A list of evaluation history in training. The detailed contents of the list rely on specific Evaluators. For example, ``GSgnnRegressionEvaluator`` and - ``GSgnnAccEvaluator`` add a tuple of validation and testing score as one list element. + ``GSgnnClassificationEvaluator`` add a tuple of validation and testing score as one + list element. """ return self._history @@ -350,15 +428,29 @@ def task_tracker(self): """ return self.tracker -class GSgnnRegressionEvaluator(GSgnnInstanceEvaluator): - """ The class for user defined evaluator. + @property + def val_perf_rank_list(self): + """ validation performance rank list + """ + return self._val_perf_rank_list + + +class GSgnnClassificationEvaluator(GSgnnBaseEvaluator, GSgnnPredictionEvalInterface): + """Classification evaluator + + GS built-in evaluator for classification task. It uses "accuracy" as the default eval metric, + and sets the multilabel to be False. + + It replacees the ``GSgnnAccEvaluator`` since v0.3. Parameters ---------- eval_frequency: int The frequency (number of iterations) of doing evaluation. - eval_metric: list of string - Evaluation metric used during evaluation. + eval_metric_list: list of string + Evaluation metrics used during evaluation. Default: ["accuracy"]. + multilabel: bool + If set to true, the task is a multi-label classification task. Default: False. use_early_stop: bool Set true to use early stop. early_stop_burnin_rounds: int @@ -370,20 +462,28 @@ class GSgnnRegressionEvaluator(GSgnnInstanceEvaluator): 1) consecutive_increase and 2) average_increase. """ def __init__(self, eval_frequency, - eval_metric, + eval_metric_list=None, + multilabel=False, use_early_stop=False, early_stop_burnin_rounds=0, early_stop_rounds=3, early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY): - super(GSgnnRegressionEvaluator, self).__init__(eval_frequency, - eval_metric, use_early_stop, early_stop_burnin_rounds, - early_stop_rounds, early_stop_strategy) + # set default metric list + if eval_metric_list is None: + eval_metric_list = ["accuracy"] + super(GSgnnClassificationEvaluator, self).__init__(eval_frequency, + eval_metric_list, + use_early_stop, + early_stop_burnin_rounds, + early_stop_rounds, + early_stop_strategy) + self._multilabel = multilabel self._best_val_score = {} self._best_test_score = {} self._best_iter = {} - self.metrics_obj = RegressionMetrics() + self.metrics_obj = ClassificationMetrics(multilabel=self._multilabel) - for metric in self.metric: + for metric in self.metric_list: self.metrics_obj.assert_supported_metric(metric=metric) self._best_val_score[metric] = self.metrics_obj.init_best_metric(metric=metric) self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) @@ -407,9 +507,9 @@ def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters): Returns ----------- float - Validation MSE + Validation Score float - Test MSE + Test Score """ # exchange preds and labels between runners local_rank = get_rank() @@ -422,13 +522,13 @@ def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters): if test_labels is not None else None with th.no_grad(): - val_score = self.compute_score(val_pred, val_labels) - test_score = self.compute_score(test_pred, test_labels) + val_score = self.compute_score(val_pred, val_labels, train=False) + test_score = self.compute_score(test_pred, test_labels, train=False) - for metric in self.metric: + for metric in self.metric_list: # be careful whether > or < it might change per metric. - if self.metrics_obj.metric_comparator[metric](self._best_val_score[metric], - val_score[metric]): + if self.metrics_obj.metric_comparator[metric]( + self._best_val_score[metric], val_score[metric]): self._best_val_score[metric] = val_score[metric] self._best_test_score[metric] = test_score[metric] self._best_iter[metric] = total_iters @@ -436,7 +536,7 @@ def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters): return val_score, test_score - def compute_score(self, pred, labels): + def compute_score(self, pred, labels, train=True): """ Compute evaluation score Parameters @@ -445,35 +545,48 @@ def compute_score(self, pred, labels): Rediction result labels: Label + train: boolean + If in model training. Returns ------- Evaluation metric values: dict """ - scores = {} - if pred is None or labels is None: - for metric in self.metric: - scores[metric] = "N/A" - else: # pred is not None and labels is not None - pred = th.squeeze(pred) - labels = th.squeeze(labels) - pred = pred.to(th.float32) - labels = labels.to(th.float32) - for metric in self.metric: - scores[metric] = self.metrics_obj.metric_function[metric](pred, labels) - return scores + results = {} + for metric in self.metric_list: + if pred is not None and labels is not None: + if train: + # training expects always a single number to be + # returned and has a different (potentially) evalution function + results[metric] = self.metrics_obj.metric_function[metric](pred, labels) + else: + # validation or testing may have a different + # evaluation function, in our case the evaluation code + # may return a dictionary with the metric values for each metric + results[metric] = self.metrics_obj.metric_eval_function[metric](pred, labels) + else: + # if the pred is None or the labels is None the metric can not me computed + results[metric] = "N/A" + return results + + @property + def multilabel(self): + """ Indicator of if using mutliple labels + """ + return self._multilabel + + +class GSgnnRegressionEvaluator(GSgnnBaseEvaluator, GSgnnPredictionEvalInterface): + """ Regression Evaluator. -class GSgnnAccEvaluator(GSgnnInstanceEvaluator): - """ The class for user defined evaluator. + GS built-in evaluator for regression task. It uses "rmse" as the default eval metric. Parameters ---------- eval_frequency: int The frequency (number of iterations) of doing evaluation. - eval_metric: list of string - Evaluation metric used during evaluation. - multilabel: bool - If set to true, the task is a multi-label classification task. + eval_metric_list: list of string + Evaluation metric used during evaluation. Default: ["rmse"]. use_early_stop: bool Set true to use early stop. early_stop_burnin_rounds: int @@ -485,21 +598,23 @@ class GSgnnAccEvaluator(GSgnnInstanceEvaluator): 1) consecutive_increase and 2) average_increase. """ def __init__(self, eval_frequency, - eval_metric, multilabel, + eval_metric_list=None, use_early_stop=False, early_stop_burnin_rounds=0, early_stop_rounds=3, - early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY): # pylint: disable=unused-argument - super(GSgnnAccEvaluator, self).__init__(eval_frequency, - eval_metric, use_early_stop, early_stop_burnin_rounds, + early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY): + # set default metric list + if eval_metric_list is None: + eval_metric_list = ["rmse"] + super(GSgnnRegressionEvaluator, self).__init__(eval_frequency, + eval_metric_list, use_early_stop, early_stop_burnin_rounds, early_stop_rounds, early_stop_strategy) - self.multilabel = multilabel self._best_val_score = {} self._best_test_score = {} self._best_iter = {} - self.metrics_obj = ClassificationMetrics(multilabel=self.multilabel) + self.metrics_obj = RegressionMetrics() - for metric in self.metric: + for metric in self.metric_list: self.metrics_obj.assert_supported_metric(metric=metric) self._best_val_score[metric] = self.metrics_obj.init_best_metric(metric=metric) self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) @@ -523,9 +638,9 @@ def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters): Returns ----------- float - Validation Score + Validation MSE float - Test Score + Test MSE """ # exchange preds and labels between runners local_rank = get_rank() @@ -538,13 +653,13 @@ def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters): if test_labels is not None else None with th.no_grad(): - val_score = self.compute_score(val_pred, val_labels, train=False) - test_score = self.compute_score(test_pred, test_labels, train=False) + val_score = self.compute_score(val_pred, val_labels) + test_score = self.compute_score(test_pred, test_labels) - for metric in self.metric: + for metric in self.metric_list: # be careful whether > or < it might change per metric. - if self.metrics_obj.metric_comparator[metric]( - self._best_val_score[metric],val_score[metric]): + if self.metrics_obj.metric_comparator[metric](self._best_val_score[metric], + val_score[metric]): self._best_val_score[metric] = val_score[metric] self._best_test_score[metric] = test_score[metric] self._best_iter[metric] = total_iters @@ -561,37 +676,54 @@ def compute_score(self, pred, labels, train=True): Rediction result labels: Label + train: boolean + If in model training. Returns ------- Evaluation metric values: dict """ - results = {} - for metric in self.metric: + scores = {} + for metric in self.metric_list: if pred is not None and labels is not None: + pred = th.squeeze(pred) + labels = th.squeeze(labels) + pred = pred.to(th.float32) + labels = labels.to(th.float32) + if train: # training expects always a single number to be - # returned and has a different (potentially) function - results[metric] = self.metrics_obj.metric_function[metric](pred, labels) + # returned and has a different (potentially) evluation function + scores[metric] = self.metrics_obj.metric_function[metric](pred, labels) else: # validation or testing may have a different # evaluation function, in our case the evaluation code - # may return a dictionary with the metric values for each label - results[metric] = self.metrics_obj.metric_eval_function[metric](pred, labels) + # may return a dictionary with the metric values for each metric + scores[metric] = self.metrics_obj.metric_eval_function[metric](pred, labels) else: # if the pred is None or the labels is None the metric can not me computed - results[metric] = "N/A" - return results + scores[metric] = "N/A" -class GSgnnLPEvaluator(): - """ Template class for user defined evaluator. + return scores + + +class GSgnnMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface): + """ Link Prediction Evaluator using "mrr" as metric. + + GS built-in evaluator for Link Prediction tasks. It uses "mrr" as the default eval metric, + which implements the `GSgnnLPRankingEvalInterface`. + + To create a customized LP evaluator that use evaluation metric other than "mrr", users might + need to 1) define a new evaluation interface if the evaluation method requires different input + arguments; 2) inherite the new evaluation interface in a customized LP evaluator; 3) define + a customized LP trainer/inferrer to call the customized LP evaluator. Parameters ---------- eval_frequency: int The frequency (number of iterations) of doing evaluation. - eval_metric: list of string - Evaluation metric used during evaluation. + eval_metric_list: list of string + Evaluation metric used during evaluation. Default: ['mrr'] use_early_stop: bool Set true to use early stop. early_stop_burnin_rounds: int @@ -602,335 +734,120 @@ class GSgnnLPEvaluator(): The early stop strategy. GraphStorm supports two strategies: 1) consecutive_increase and 2) average_increase. """ - def __init__(self, eval_frequency, eval_metric, + def __init__(self, eval_frequency, + eval_metric_list=None, use_early_stop=False, early_stop_burnin_rounds=0, early_stop_rounds=3, early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY): - # nodes whose embeddings are used during evaluation - # if None all nodes are used. - self._target_nidx = None - self.tracker = None - self._best_val_score = None - self._best_test_score = None - self._best_iter = None - self.metrics_obj = None # Evaluation metrics obj - - self._metric = eval_metric - assert len(self.metric) > 0, "At least one metric must be defined" - self._eval_frequency = eval_frequency - self._do_early_stop = use_early_stop - if self._do_early_stop: - self._early_stop_burnin_rounds = early_stop_burnin_rounds - self._num_early_stop_calls = 0 - self._early_stop_rounds = early_stop_rounds - self._early_stop_strategy = early_stop_strategy - self._val_perf_list = [] - # add this list to store all of the performance rank of validation scores for pick top k - self._val_perf_rank_list = [] - - def setup_task_tracker(self, task_tracker): - """ Setup evaluation tracker - - Parameters - ---------- - task_tracker: - a task tracker - """ - self.tracker = task_tracker + # set default metric list + if eval_metric_list is None: + eval_metric_list = ["mrr"] + super(GSgnnMrrLPEvaluator, self).__init__(eval_frequency, + eval_metric_list, use_early_stop, early_stop_burnin_rounds, + early_stop_rounds, early_stop_strategy) + self.metrics_obj = LinkPredictionMetrics() - @abc.abstractmethod - def evaluate(self, val_scores, test_scores, total_iters): - """ - GSgnnLinkPredictionModel.fit() will call this function to do user defined evalution. + self._best_val_score = {} + self._best_test_score = {} + self._best_iter = {} + for metric in self.metric_list: + self._best_val_score[metric] = self.metrics_obj.init_best_metric(metric=metric) + self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) + self._best_iter[metric] = 0 - Note: Make sure each trainer will get the same validation scores. - The early stop and model saving progress rely on certain scores. + def evaluate(self, val_rankings, test_rankings, total_iters): + """ `GSgnnLinkPredictionTrainer` and `GSgnnLinkPredictionInferrer` will call this function + to compute validation and test scores. Parameters ---------- - val_scores: dict of tensors - The rankings of validation edges for each edge type. - test_scores: dict of tensors - The rankings of testing edges for each edge type. + val_rankings: dict of tensors + Rankings of positive scores of validation edges for each edge type. + test_rankings: dict of tensors + Rankings of positive scores of test edges for each edge type.. total_iters: int The current interation number. Returns ----------- - eval_score: float - Validation score - test_score: float - Test score - """ - - def do_eval(self, total_iters, epoch_end=False): - """ Decide whether to do the evaluation in current iteration or epoch - - Parameters - ---------- - epoch: int - The epoch number - total_iters: int - The total number of iterations has been taken. - epoch_end: bool - Whether it is the end of an epoch - - Returns - ------- - Whether do evaluation: bool - """ - if epoch_end: - return True - elif self._eval_frequency != 0 and \ - total_iters % self._eval_frequency == 0: - return True - return False - - def do_early_stop(self, val_score): - """ Decide whether to stop the training - - Parameters - ---------- - val_score: float - Evaluation score - """ - if self._do_early_stop is False: - return False - - assert len(val_score) == 1, \ - f"valudation score should be a signle key value pair but got {val_score}" - self._num_early_stop_calls += 1 - # Not enough existing validation scores - if self._num_early_stop_calls <= self._early_stop_burnin_rounds: - return False - - val_score = list(val_score.values())[0] - # Not enough validation scores to make early stop decision - if len(self._val_perf_list) < self._early_stop_rounds: - self._val_perf_list.append(val_score) - return False - - # early stop criteria: if the average evaluation value - # does not improve in the last N evaluation iterations - if self._early_stop_strategy == EARLY_STOP_AVERAGE_INCREASE_STRATEGY: - early_stop = early_stop_avg_increase_judge(val_score, - self._val_perf_list, self.get_metric_comparator()) - elif self._early_stop_strategy == EARLY_STOP_CONSECUTIVE_INCREASE_STRATEGY: - early_stop = early_stop_cons_increase_judge(val_score, - self._val_perf_list, self.get_metric_comparator()) - - self._val_perf_list.pop(0) - self._val_perf_list.append(val_score) - - return early_stop - - def get_metric_comparator(self): - """ Return the comparator of the major eval metric. - - We treat the first metric in all evaluation metrics as the major metric. - """ - - assert self.metrics_obj is not None, \ - "Evaluation metrics object should not be None" - metric = self.metric[0] - return self.metrics_obj.metric_comparator[metric] - - def get_val_score_rank(self, val_score): - """ - Get the rank of the given val score by comparing its values to the existing value list. - - Parameters - ---------- - val_score: dict - A dictionary whose key is the metric and the value is a score from evaluator's - validation computation. - """ - val_score = list(val_score.values())[0] - - rank = get_val_score_rank(val_score, - self._val_perf_rank_list, - self.get_metric_comparator()) - # after compare, append the score into existing list - self._val_perf_rank_list.append(val_score) - return rank - - @property - def target_nidx(self): - """ target_nidx - """ - return self._target_nidx - - @property - def metric(self): - """ evaluation metrics - """ - return self._metric - - @property - def best_val_score(self): - """ Best validation score - """ - return self._best_val_score - - @property - def best_test_score(self): - """ Best test score - """ - return self._best_test_score - - @property - def best_iter_num(self): - """ Best iteration number - """ - return self._best_iter - - @property - def val_perf_rank_list(self): - """ validation performance rank list - """ - return self._val_perf_rank_list - - @property - def eval_frequency(self): - """ Evaluation frequency. - """ - return self._eval_frequency - - @property - def task_tracker(self): - """ Task tracker of this evaluator + val_mrr: float + Validation mrr score + test_mrr: float + Test mrr score """ - return self.tracker - -class GSgnnMrrLPEvaluator(GSgnnLPEvaluator): - """ The class for link prediction evaluation using Mrr metric. + with th.no_grad(): + if test_rankings is not None: + test_score = self.compute_score(test_rankings) + else: + for metric in self.metric_list: + test_score = {metric: "N/A"} # Dummy - Parameters - ---------- - eval_frequency: int - The frequency (# of iterations) of doing evaluation. - data: GSgnnEdgeData - The processed dataset - num_negative_edges_eval: int - Number of negative edges sampled for each positive edge in evalation. - lp_decoder_type: str - Link prediction decoder type. - use_early_stop: bool - Set true to use early stop. - early_stop_burnin_rounds: int - Burn-in rounds before start checking for the early stop condition. - early_stop_rounds: int - The number of rounds for validation scores used to decide early stop. - early_stop_strategy: str - The early stop strategy. GraphStorm supports two strategies: - 1) consecutive_increase and 2) average_increase. - """ - def __init__(self, eval_frequency, data, - num_negative_edges_eval, lp_decoder_type, - use_early_stop=False, - early_stop_burnin_rounds=0, - early_stop_rounds=3, - early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY): - eval_metric = ["mrr"] - super(GSgnnMrrLPEvaluator, self).__init__(eval_frequency, - eval_metric, use_early_stop, early_stop_burnin_rounds, - early_stop_rounds, early_stop_strategy) - self.train_idxs = data.train_idxs - self.val_idxs = data.val_idxs - self.test_idxs = data.test_idxs - self.num_negative_edges_eval = num_negative_edges_eval - self.lp_decoder_type = lp_decoder_type + if val_rankings is not None: + val_score = self.compute_score(val_rankings) - self.metrics_obj = LinkPredictionMetrics() + if get_rank() == 0: + for metric in self.metric_list: + # be careful whether > or < it might change per metric. + if self.metrics_obj.metric_comparator[metric]( + self._best_val_score[metric], val_score[metric]): + self._best_val_score[metric] = val_score[metric] + self._best_test_score[metric] = test_score[metric] + self._best_iter[metric] = total_iters + else: + for metric in self.metric_list: + val_score = {metric: "N/A"} # Dummy - self._best_val_score = {} - self._best_test_score = {} - self._best_iter = {} - for metric in self.metric: - self._best_val_score[metric] = self.metrics_obj.init_best_metric(metric=metric) - self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) - self._best_iter[metric] = 0 + return val_score, test_score - def compute_score(self, rankings, train=False): # pylint:disable=unused-argument + def compute_score(self, rankings, train=True): """ Compute evaluation score Parameters ---------- rankings: dict of tensors Rankings of positive scores in format of {etype: ranking} - train: bool - TODO: Reversed for future use cases when we want to use different - way to generate scores for train (more efficient but less accurate) - and test. + train: boolean + If in model training. Returns ------- Evaluation metric values: dict """ # We calculate global mrr, etype is ignored. - # User can develop its own per etype MRR evaluator ranking = [] for _, rank in rankings.items(): ranking.append(rank) ranking = th.cat(ranking, dim=0) - metrics = gen_mrr_score(ranking) + # compute ranking value for each metric + metrics = {} + for metric in self.metric_list: + if train: + # training expects always a single number to be + # returned and has a different (potentially) evluation function + metrics[metric] = self.metrics_obj.metric_function[metric](ranking) + else: + # validation or testing may have a different + # evaluation function, in our case the evaluation code + # may return a dictionary with the metric values for each metric + metrics[metric] = self.metrics_obj.metric_eval_function[metric](ranking) # When world size == 1, we do not need the barrier if get_world_size() > 1: barrier() for _, metric_val in metrics.items(): th.distributed.all_reduce(metric_val) + return_metrics = {} for metric, metric_val in metrics.items(): return_metric = metric_val / get_world_size() return_metrics[metric] = return_metric.item() - return return_metrics - - def evaluate(self, val_scores, test_scores, total_iters): - """ GSgnnLinkPredictionModel.fit() will call this function to do user defined evalution. - - Parameters - ---------- - val_scores: dict of tensors - Rankings of positive scores of validation edges for each edge type. - test_scores: dict of tensors - Rankings of positive scores of test edges for each edge type.. - total_iters: int - The current interation number. - - Returns - ----------- - val_mrr: float - Validation mrr - test_mrr: float - Test mrr - """ - with th.no_grad(): - if test_scores is not None: - test_score = self.compute_score(test_scores) - else: - test_score = {"mrr": "N/A"} # Dummy - - if val_scores is not None: - val_score = self.compute_score(val_scores) - if get_rank() == 0: - for metric in self.metric: - # be careful whether > or < it might change per metric. - if self.metrics_obj.metric_comparator[metric]( - self._best_val_score[metric], val_score[metric]): - self._best_val_score[metric] = val_score[metric] - self._best_test_score[metric] = test_score[metric] - self._best_iter[metric] = total_iters - else: - val_score = {"mrr": "N/A"} # Dummy - - return val_score, test_score + return return_metrics -class GSgnnPerEtypeMrrLPEvaluator(GSgnnMrrLPEvaluator): +class GSgnnPerEtypeMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface): """ The class for link prediction evaluation using Mrr metric and return a Per etype mrr score. @@ -938,12 +855,8 @@ class GSgnnPerEtypeMrrLPEvaluator(GSgnnMrrLPEvaluator): ---------- eval_frequency: int The frequency (# of iterations) of doing evaluation. - data: GSgnnEdgeData - The processed dataset - num_negative_edges_eval: int - Number of negative edges sampled for each positive edge in evalation. - lp_decoder_type: str - Link prediction decoder type. + eval_metric_list: list of string + Evaluation metric used during evaluation. Default: ['mrr'] major_etype: tuple Canonical etype used for selecting the best model. If None, use the general mrr. use_early_stop: bool @@ -956,31 +869,42 @@ class GSgnnPerEtypeMrrLPEvaluator(GSgnnMrrLPEvaluator): The early stop strategy. GraphStorm supports two strategies: 1) consecutive_increase and 2) average_increase. """ - def __init__(self, eval_frequency, data, - num_negative_edges_eval, lp_decoder_type, + def __init__(self, eval_frequency, + eval_metric_list=None, major_etype = LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL, use_early_stop=False, early_stop_burnin_rounds=0, early_stop_rounds=3, early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY): - self.major_etype = major_etype + # set default metric list + if eval_metric_list is None: + eval_metric_list = ["mrr"] super(GSgnnPerEtypeMrrLPEvaluator, self).__init__(eval_frequency, - data, num_negative_edges_eval, lp_decoder_type, - use_early_stop, early_stop_burnin_rounds, - early_stop_rounds, early_stop_strategy) + eval_metric_list, + use_early_stop, + early_stop_burnin_rounds, + early_stop_rounds, + early_stop_strategy) + self.major_etype = major_etype + self.metrics_obj = LinkPredictionMetrics() + self._best_val_score = {} + self._best_test_score = {} + self._best_iter = {} + for metric in self.metric_list: + self._best_val_score[metric] = self.metrics_obj.init_best_metric(metric=metric) + self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) + self._best_iter[metric] = 0 - def compute_score(self, rankings, train=False): # pylint:disable=unused-argument + def compute_score(self, rankings, train=True): """ Compute evaluation score Parameters ---------- rankings: dict of tensors Rankings of positive scores in format of {etype: ranking} - train: bool - TODO: Reversed for future use cases when we want to use different - way to generate scores for train (more efficient but less accurate) - and test. + train: boolean + If in model training. Returns ------- @@ -988,19 +912,31 @@ def compute_score(self, rankings, train=False): # pylint:disable=unused-argument """ # We calculate global mrr, etype is ignored. # User can develop its own per etype MRR evaluator - metrics = {} - for etype, rank in rankings.items(): - metrics[etype] = gen_mrr_score(rank) + per_etype_metrics = {} + for etype, ranking in rankings.items(): + # compute ranking value for each metric + metrics = {} + for metric in self.metric_list: + if train: + # training expects always a single number to be + # returned and has a different (potentially) evluation function + metrics[metric] = self.metrics_obj.metric_function[metric](ranking) + else: + # validation or testing may have a different + # evaluation function, in our case the evaluation code + # may return a dictionary with the metric values for each metric + metrics[metric] = self.metrics_obj.metric_eval_function[metric](ranking) + per_etype_metrics[etype] = metrics # When world size == 1, we do not need the barrier if get_world_size() > 1: barrier() - for _, metric in metrics.items(): + for _, metric in per_etype_metrics.items(): for _, metric_val in metric.items(): th.distributed.all_reduce(metric_val) return_metrics = {} - for etype, metric in metrics.items(): + for etype, metric in per_etype_metrics.items(): for metric_key, metric_val in metric.items(): return_metric = metric_val / get_world_size() if metric_key not in return_metrics: @@ -1018,14 +954,15 @@ def _get_major_score(self, score): major_score = score[self.major_etype] return major_score - def evaluate(self, val_scores, test_scores, total_iters): - """ GSgnnLinkPredictionModel.fit() will call this function to do user defined evalution. + def evaluate(self, val_rankings, test_rankings, total_iters): + """ `GSgnnLinkPredictionTrainer` and `GSgnnLinkPredictionInferrer` will call this function + to compute validation and test mrr scores. Parameters ---------- - val_scores: dict of tensors + val_rankings: dict of tensors Rankings of positive scores of validation edges for each edge type. - test_scores: dict of tensors + test_rankings: dict of tensors Rankings of positive scores of test edges for each edge type.. total_iters: int The current interation number. @@ -1038,16 +975,17 @@ def evaluate(self, val_scores, test_scores, total_iters): Test mrr """ with th.no_grad(): - if test_scores is not None: - test_score = self.compute_score(test_scores) + if test_rankings is not None: + test_score = self.compute_score(test_rankings) else: - test_score = {"mrr": "N/A"} # Dummy + for metric in self.metric_list: + test_score = {metric: "N/A"} # Dummy - if val_scores is not None: - val_score = self.compute_score(val_scores) + if val_rankings is not None: + val_score = self.compute_score(val_rankings) if get_rank() == 0: - for metric in self.metric: + for metric in self.metric_list: # be careful whether > or < it might change per metric. major_val_score = self._get_major_score(val_score[metric]) major_test_score = self._get_major_score(test_score[metric]) @@ -1057,7 +995,8 @@ def evaluate(self, val_scores, test_scores, total_iters): self._best_test_score[metric] = major_test_score self._best_iter[metric] = total_iters else: - val_score = {"mrr": "N/A"} # Dummy + for metric in self.metric_list: + val_score = {metric: "N/A"} # Dummy return val_score, test_score diff --git a/python/graphstorm/inference/graphstorm_infer.py b/python/graphstorm/inference/graphstorm_infer.py index ec4d579c3c..3ac5565c37 100644 --- a/python/graphstorm/inference/graphstorm_infer.py +++ b/python/graphstorm/inference/graphstorm_infer.py @@ -93,7 +93,7 @@ def log_print_metrics(self, val_score, test_score, dur_eval, total_steps, train_ best_val_score = self.evaluator.best_val_score best_test_score = self.evaluator.best_test_score best_iter_num = self.evaluator.best_iter_num - self.task_tracker.log_iter_metrics(self.evaluator.metric, + self.task_tracker.log_iter_metrics(self.evaluator.metric_list, train_score=train_score, val_score=val_score, test_score=test_score, diff --git a/python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py b/python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py index f8dcf99d9c..065cd3502e 100644 --- a/python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py @@ -20,7 +20,7 @@ from graphstorm.config import get_argument_parser from graphstorm.config import GSConfig from graphstorm.inference import GSgnnEdgePredictionInferrer -from graphstorm.eval import GSgnnAccEvaluator, GSgnnRegressionEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator, GSgnnRegressionEvaluator from graphstorm.dataloading import GSgnnEdgeInferData, GSgnnEdgeDataLoader from graphstorm.utils import get_device, get_lm_ntypes, use_wholegraph @@ -31,9 +31,9 @@ def get_evaluator(config): # pylint: disable=unused-argument return GSgnnRegressionEvaluator(config.eval_frequency, config.eval_metric) elif config.task_type == 'edge_classification': - return GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel) + return GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + config.multilabel) else: raise AttributeError(config.task_type + ' is not supported.') diff --git a/python/graphstorm/run/gsgnn_ep/ep_infer_lm.py b/python/graphstorm/run/gsgnn_ep/ep_infer_lm.py index 73626f90d3..ac1be177be 100644 --- a/python/graphstorm/run/gsgnn_ep/ep_infer_lm.py +++ b/python/graphstorm/run/gsgnn_ep/ep_infer_lm.py @@ -21,7 +21,7 @@ from graphstorm.config import get_argument_parser from graphstorm.config import GSConfig from graphstorm.inference import GSgnnEdgePredictionInferrer -from graphstorm.eval import GSgnnAccEvaluator, GSgnnRegressionEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator, GSgnnRegressionEvaluator from graphstorm.dataloading import GSgnnEdgeInferData, GSgnnEdgeDataLoader from graphstorm.utils import get_device @@ -32,9 +32,9 @@ def get_evaluator(config): # pylint: disable=unused-argument return GSgnnRegressionEvaluator(config.eval_frequency, config.eval_metric) elif config.task_type == 'edge_classification': - return GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel) + return GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + config.multilabel) else: raise AttributeError(config.task_type + ' is not supported.') diff --git a/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py b/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py index c943c747ac..b0bb9fbda5 100644 --- a/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py +++ b/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py @@ -23,7 +23,7 @@ from graphstorm.config import GSConfig from graphstorm.trainer import GSgnnEdgePredictionTrainer from graphstorm.dataloading import GSgnnEdgeTrainData, GSgnnEdgeDataLoader -from graphstorm.eval import GSgnnAccEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator from graphstorm.eval import GSgnnRegressionEvaluator from graphstorm.model.utils import save_full_node_embeddings from graphstorm.model import do_full_graph_inference @@ -34,13 +34,13 @@ def get_evaluator(config): """ Get evaluator class """ if config.task_type == "edge_classification": - return GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop, - config.early_stop_burnin_rounds, - config.early_stop_rounds, - config.early_stop_strategy) + return GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + config.multilabel, + config.use_early_stop, + config.early_stop_burnin_rounds, + config.early_stop_rounds, + config.early_stop_strategy) elif config.task_type == "edge_regression": return GSgnnRegressionEvaluator(config.eval_frequency, config.eval_metric, diff --git a/python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py b/python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py index 18d45b1e9e..ea1cddae21 100644 --- a/python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py +++ b/python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py @@ -23,7 +23,7 @@ from graphstorm.config import GSConfig from graphstorm.trainer import GSgnnEdgePredictionTrainer from graphstorm.dataloading import GSgnnEdgeTrainData, GSgnnEdgeDataLoader -from graphstorm.eval import GSgnnAccEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator from graphstorm.eval import GSgnnRegressionEvaluator from graphstorm.model.utils import save_full_node_embeddings from graphstorm.model import do_full_graph_inference @@ -33,13 +33,13 @@ def get_evaluator(config): """ Get evaluator class """ if config.task_type == "edge_classification": - return GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop, - config.early_stop_burnin_rounds, - config.early_stop_rounds, - config.early_stop_strategy) + return GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + config.multilabel, + config.use_early_stop, + config.early_stop_burnin_rounds, + config.early_stop_rounds, + config.early_stop_strategy) elif config.task_type == "edge_regression": return GSgnnRegressionEvaluator(config.eval_frequency, config.eval_metric, diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py index 64f2a9bf27..f4360ab7a3 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py @@ -45,37 +45,29 @@ from graphstorm.model import do_full_graph_inference from graphstorm.utils import rt_profiler, sys_tracker, get_device -def get_evaluator(config, train_data): +def get_evaluator(config): """ Get evaluator according to config Parameters ---------- config: GSConfig Configuration - train_data: GSgnnEdgeData - Training data """ assert len(config.eval_metric) == 1, \ "GraphStorm doees not support computing multiple metrics at the same time." if config.report_eval_per_type: - return GSgnnPerEtypeMrrLPEvaluator(config.eval_frequency, - train_data, - config.num_negative_edges_eval, - config.lp_decoder_type, - config.model_select_etype, - config.use_early_stop, - config.early_stop_burnin_rounds, - config.early_stop_rounds, - config.early_stop_strategy) + return GSgnnPerEtypeMrrLPEvaluator(eval_frequency=config.eval_frequency, + major_etype=config.model_select_etype, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy) else: - return GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - config.num_negative_edges_eval, - config.lp_decoder_type, - config.use_early_stop, - config.early_stop_burnin_rounds, - config.early_stop_rounds, - config.early_stop_strategy) + return GSgnnMrrLPEvaluator(eval_frequency=config.eval_frequency, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy) def main(config_args): """ main function @@ -102,7 +94,7 @@ def main(config_args): if not config.no_validation: # TODO(zhengda) we need to refactor the evaluator. # Currently, we only support mrr - evaluator = get_evaluator(config, train_data) + evaluator = get_evaluator(config) trainer.setup_evaluator(evaluator) assert len(train_data.val_idxs) > 0, "The training data do not have validation set." # TODO(zhengda) we need to compute the size of the entire validation set to make sure diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py index 408cae8418..e06bece663 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py @@ -59,37 +59,29 @@ ) from graphstorm.utils import get_lm_ntypes -def get_evaluator(config, train_data): +def get_evaluator(config): """ Get evaluator according to config Parameters ---------- config: GSConfig Configuration - train_data: GSgnnEdgeData - Training data """ assert len(config.eval_metric) == 1, \ "GraphStorm doees not support computing multiple metrics at the same time." if config.report_eval_per_type: - return GSgnnPerEtypeMrrLPEvaluator(config.eval_frequency, - train_data, - config.num_negative_edges_eval, - config.lp_decoder_type, - config.model_select_etype, - config.use_early_stop, - config.early_stop_burnin_rounds, - config.early_stop_rounds, - config.early_stop_strategy) + return GSgnnPerEtypeMrrLPEvaluator(eval_frequency=config.eval_frequency, + major_etype=config.model_select_etype, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy) else: - return GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - config.num_negative_edges_eval, - config.lp_decoder_type, - config.use_early_stop, - config.early_stop_burnin_rounds, - config.early_stop_rounds, - config.early_stop_strategy) + return GSgnnMrrLPEvaluator(eval_frequency=config.eval_frequency, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy) def main(config_args): """ main function @@ -119,7 +111,7 @@ def main(config_args): if not config.no_validation: # TODO(zhengda) we need to refactor the evaluator. # Currently, we only support mrr - evaluator = get_evaluator(config, train_data) + evaluator = get_evaluator(config) trainer.setup_evaluator(evaluator) assert len(train_data.val_idxs) > 0, "The training data do not have validation set." # TODO(zhengda) we need to compute the size of the entire validation set to make sure diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py index a58d0c708a..a12c8e79e8 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py @@ -57,10 +57,7 @@ def main(config_args): infer.setup_device(device=get_device()) if not config.no_validation: infer.setup_evaluator( - GSgnnMrrLPEvaluator(config.eval_frequency, - infer_data, - config.num_negative_edges_eval, - config.lp_decoder_type)) + GSgnnMrrLPEvaluator(config.eval_frequency)) assert len(infer_data.test_idxs) > 0, "There is not test data for evaluation." tracker = gs.create_builtin_task_tracker(config) infer.setup_task_tracker(tracker) diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py b/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py index 7773df6477..916ba9abf9 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py @@ -51,10 +51,7 @@ def main(config_args): infer.setup_device(device=get_device()) if not config.no_validation: infer.setup_evaluator( - GSgnnMrrLPEvaluator(config.eval_frequency, - infer_data, - config.num_negative_edges_eval, - config.lp_decoder_type)) + GSgnnMrrLPEvaluator(config.eval_frequency)) assert len(infer_data.test_idxs) > 0, "There is not test data for evaluation." tracker = gs.create_builtin_task_tracker(config) infer.setup_task_tracker(tracker) diff --git a/python/graphstorm/run/gsgnn_np/gsgnn_np.py b/python/graphstorm/run/gsgnn_np/gsgnn_np.py index 98482c032a..818d4039e7 100644 --- a/python/graphstorm/run/gsgnn_np/gsgnn_np.py +++ b/python/graphstorm/run/gsgnn_np/gsgnn_np.py @@ -25,7 +25,7 @@ from graphstorm.trainer import GLEMNodePredictionTrainer from graphstorm.dataloading import GSgnnNodeTrainData, GSgnnNodeDataLoader,\ GSgnnNodeSemiSupDataLoader -from graphstorm.eval import GSgnnAccEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator from graphstorm.eval import GSgnnRegressionEvaluator from graphstorm.model.utils import save_full_node_embeddings from graphstorm.model import do_full_graph_inference @@ -38,13 +38,13 @@ def get_evaluator(config): if config.task_type == "node_classification": multilabel = config.multilabel[config.eval_target_ntype] \ if isinstance(config.multilabel, dict) else config.multilabel - return GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - multilabel, - config.use_early_stop, - config.early_stop_burnin_rounds, - config.early_stop_rounds, - config.early_stop_strategy) + return GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + multilabel, + config.use_early_stop, + config.early_stop_burnin_rounds, + config.early_stop_rounds, + config.early_stop_strategy) elif config.task_type == "node_regression": return GSgnnRegressionEvaluator(config.eval_frequency, config.eval_metric, diff --git a/python/graphstorm/run/gsgnn_np/np_infer_gnn.py b/python/graphstorm/run/gsgnn_np/np_infer_gnn.py index de5858974c..a9aeaa5176 100644 --- a/python/graphstorm/run/gsgnn_np/np_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_np/np_infer_gnn.py @@ -19,7 +19,7 @@ from graphstorm.config import get_argument_parser from graphstorm.config import GSConfig from graphstorm.inference import GSgnnNodePredictionInferrer -from graphstorm.eval import GSgnnAccEvaluator, GSgnnRegressionEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator, GSgnnRegressionEvaluator from graphstorm.dataloading import GSgnnNodeInferData, GSgnnNodeDataLoader from graphstorm.utils import get_device, get_lm_ntypes, use_wholegraph @@ -30,9 +30,9 @@ def get_evaluator(config): # pylint: disable=unused-argument return GSgnnRegressionEvaluator(config.eval_frequency, config.eval_metric) elif config.task_type == 'node_classification': - return GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel) + return GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + config.multilabel) else: raise AttributeError(config.task_type + ' is not supported.') diff --git a/python/graphstorm/trainer/ep_trainer.py b/python/graphstorm/trainer/ep_trainer.py index 4c572120dc..49052f1c40 100644 --- a/python/graphstorm/trainer/ep_trainer.py +++ b/python/graphstorm/trainer/ep_trainer.py @@ -283,7 +283,7 @@ def fit(self, train_loader, num_epochs, 'peak_RAM_mem_alloc_MB': \ resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024, 'best validation iteration': \ - self.evaluator.best_iter_num[self.evaluator.metric[0]], + self.evaluator.best_iter_num[self.evaluator.metric_list[0]], 'best model path': \ self.get_best_model_path() if save_model_path is not None else \ "No model is saved, please set save_model_path"} @@ -319,7 +319,7 @@ def eval(self, model, val_loader, test_loader, use_mini_batch_infer, total_steps test_start = time.time() sys_tracker.check('start prediction') - metric = set(self.evaluator.metric) + metric = set(self.evaluator.metric_list) need_proba = metric.intersection({'roc_auc', 'per_class_roc_auc', 'precision_recall'}) need_label_pred = metric.intersection({'accuracy', 'f1_score', 'per_class_f1_score'}) assert len(need_proba) == 0 or len(need_label_pred) == 0, \ diff --git a/python/graphstorm/trainer/gsgnn_trainer.py b/python/graphstorm/trainer/gsgnn_trainer.py index b5e70508ba..d62bb91639 100644 --- a/python/graphstorm/trainer/gsgnn_trainer.py +++ b/python/graphstorm/trainer/gsgnn_trainer.py @@ -186,7 +186,7 @@ def log_print_metrics(self, val_score, test_score, dur_eval, total_steps, train_ best_val_score = self.evaluator.best_val_score best_test_score = self.evaluator.best_test_score best_iter_num = self.evaluator.best_iter_num - self.task_tracker.log_iter_metrics(self.evaluator.metric, + self.task_tracker.log_iter_metrics(self.evaluator.metric_list, train_score=train_score, val_score=val_score, test_score=test_score, best_val_score=best_val_score, best_test_score=best_test_score, best_iter_num=best_iter_num, diff --git a/python/graphstorm/trainer/lp_trainer.py b/python/graphstorm/trainer/lp_trainer.py index 985acd2774..d434c88f49 100644 --- a/python/graphstorm/trainer/lp_trainer.py +++ b/python/graphstorm/trainer/lp_trainer.py @@ -275,7 +275,7 @@ def fit(self, train_loader, num_epochs, 'peak_RAM_mem_alloc_MB': \ resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024, 'best validation iteration': \ - self.evaluator.best_iter_num[self.evaluator.metric[0]], + self.evaluator.best_iter_num[self.evaluator.metric_list[0]], 'best model path': \ self.get_best_model_path() if save_model_path is not None else None} self.log_params(output) diff --git a/python/graphstorm/trainer/np_trainer.py b/python/graphstorm/trainer/np_trainer.py index f9d0d04aea..e58f84d3b1 100644 --- a/python/graphstorm/trainer/np_trainer.py +++ b/python/graphstorm/trainer/np_trainer.py @@ -262,7 +262,7 @@ def fit(self, train_loader, num_epochs, 'peak_RAM_mem_alloc_MB': \ resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024, 'best validation iteration': \ - self.evaluator.best_iter_num[self.evaluator.metric[0]], + self.evaluator.best_iter_num[self.evaluator.metric_list[0]], 'best model path': \ self.get_best_model_path() if save_model_path is not None else None} self.log_params(output) @@ -297,7 +297,7 @@ def eval(self, model, val_loader, test_loader, use_mini_batch_infer, total_steps teval = time.time() sys_tracker.check('before prediction') - metric = set(self.evaluator.metric) + metric = set(self.evaluator.metric_list) need_proba = metric.intersection({'roc_auc', 'per_class_roc_auc', 'precision_recall'}) need_label_pred = metric.intersection({'accuracy', 'f1_score', 'per_class_f1_score'}) assert len(need_proba) == 0 or len(need_label_pred) == 0, \ diff --git a/tests/unit-tests/test_dist_eval.py b/tests/unit-tests/test_dist_eval.py index 3af5abfa85..4ce8a53124 100644 --- a/tests/unit-tests/test_dist_eval.py +++ b/tests/unit-tests/test_dist_eval.py @@ -18,7 +18,7 @@ and when called by two workers (distributed evaluation). For classification tasks. it compares the output of - `GSgnnAccEvaluator.evaluate' when called by a single worker + `GSgnnClassificationEvaluator.evaluate' when called by a single worker (single process evaluation) and when called by two workers (distributed evaluation). """ @@ -30,7 +30,7 @@ from numpy.testing import assert_almost_equal import numpy as np -from graphstorm.eval import GSgnnAccEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator from graphstorm.eval import GSgnnRegressionEvaluator from graphstorm.eval import GSgnnMrrLPEvaluator from graphstorm.utils import setup_device @@ -41,7 +41,7 @@ from test_evaluator import gen_hg -def run_dist_lp_eval_worker(worker_rank, train_data, config, val_scores, test_scores, conn): +def run_dist_lp_eval_worker(worker_rank, config, val_scores, test_scores, conn): dist_init_method = 'tcp://{master_ip}:{master_port}'.format( master_ip='127.0.0.1', master_port='12345') th.distributed.init_process_group(backend="gloo", @@ -50,9 +50,6 @@ def run_dist_lp_eval_worker(worker_rank, train_data, config, val_scores, test_sc rank=worker_rank) lp_eval = GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - num_negative_edges_eval=config.num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, use_early_stop=config.use_early_stop) val_sc, test_sc = lp_eval.evaluate(val_scores, test_scores, 0) @@ -60,15 +57,15 @@ def run_dist_lp_eval_worker(worker_rank, train_data, config, val_scores, test_sc conn.send((val_sc, test_sc)) th.distributed.destroy_process_group() -def run_dist_lp_eval(train_data, config, +def run_dist_lp_eval(config, val_scores_0, val_scores_1, test_scores_0, test_scores_1): ctx = mp.get_context('spawn') conn1, conn2 = mp.Pipe() p0 = ctx.Process(target=run_dist_lp_eval_worker, - args=(0, train_data, config, val_scores_0, test_scores_0, conn2)) + args=(0, config, val_scores_0, test_scores_0, conn2)) p1 = ctx.Process(target=run_dist_lp_eval_worker, - args=(1, train_data, config, val_scores_1, test_scores_1, None)) + args=(1, config, val_scores_1, test_scores_1, None)) p0.start() p1.start() p0.join() @@ -82,11 +79,11 @@ def run_dist_lp_eval(train_data, config, conn2.close() return val_scores, test_scores -def run_local_lp_eval(train_data, config, val_scores, test_scores): +def run_local_lp_eval(config, val_scores, test_scores): ctx = mp.get_context('spawn') conn1, conn2 = mp.Pipe() p = ctx.Process(target=run_local_lp_eval_worker, - args=(train_data, config, val_scores, test_scores, conn2)) + args=(config, val_scores, test_scores, conn2)) p.start() p.join() assert p.exitcode == 0 @@ -96,7 +93,7 @@ def run_local_lp_eval(train_data, config, val_scores, test_scores): conn2.close() return val_scores, test_scores -def run_local_lp_eval_worker(train_data, config, val_scores, test_scores, conn): +def run_local_lp_eval_worker(config, val_scores, test_scores, conn): dist_init_method = 'tcp://{master_ip}:{master_port}'.format( master_ip='127.0.0.1', master_port='12346') th.distributed.init_process_group(backend="gloo", @@ -105,9 +102,6 @@ def run_local_lp_eval_worker(train_data, config, val_scores, test_scores, conn): rank=0) lp_eval = GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - num_negative_edges_eval=config.num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, use_early_stop=config.use_early_stop) val_sc, test_sc = lp_eval.evaluate(val_scores, test_scores, 0) conn.send((val_sc, test_sc)) @@ -149,25 +143,18 @@ def test_lp_dist_eval(seed): } # Dummy objects - train_data = Dummy({ - "train_idxs": th.randint(10, (10,)), - "val_idxs": th.randint(10, (10,)), - "test_idxs": th.randint(10, (10,)), - }) config = Dummy({ - "num_negative_edges_eval": 10, - "lp_decoder_type": BUILTIN_LP_DOT_DECODER, "eval_frequency": 100, "use_early_stop": False, - "eval_metric": ["mrr"] + "eval_metric_list": ["mrr"] }) # do evaluation with two workers - val_dist, test_dist = run_dist_lp_eval(train_data, config, + val_dist, test_dist = run_dist_lp_eval(config, val_scores_0, val_scores_1, test_scores_0, test_scores_1) # do evaluation with single worker - val_local, test_local = run_local_lp_eval(train_data, config, + val_local, test_local = run_local_lp_eval(config, {etypes[0]: th.cat((val_scores_0[etypes[0]], val_scores_1[etypes[0]]), dim = 0), etypes[1]: th.cat((val_scores_0[etypes[1]], val_scores_1[etypes[1]]), dim = 0)}, {etypes[0]: th.cat((test_scores_0[etypes[0]], test_scores_1[etypes[0]]), dim = 0), @@ -191,17 +178,17 @@ def run_dist_nc_eval_worker(eval_config, worker_rank, metric, val_pred, test_pre th.cuda.set_device(worker_rank) device = setup_device(worker_rank) - config, train_data = eval_config + config = eval_config - if config.eval_metric[0] in ["rmse", "mse"]: + if config.eval_metric_list[0] in ["rmse", "mse"]: evaluator = GSgnnRegressionEvaluator(config.eval_frequency, - config.eval_metric, + config.eval_metric_list, config.use_early_stop) else: - evaluator = GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop) + evaluator = GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric_list, + config.multilabel, + config.use_early_stop) val_score0, test_score0 = evaluator.evaluate( val_pred.to(device), @@ -220,7 +207,7 @@ def run_dist_nc_eval_worker(eval_config, worker_rank, metric, val_pred, test_pre test_labels.to(device), 300) if worker_rank == 0: - assert evaluator.metric == metric + assert evaluator.metric_list == metric assert evaluator.best_iter_num[metric[0]] == 200 assert evaluator.best_val_score == val_score1 assert evaluator.best_test_score == test_score1 @@ -284,17 +271,18 @@ def run_local_nc_eval_worker(eval_config, metric, val_pred, test_pred, init_method=dist_init_method, world_size=1, rank=0) - config, train_data = eval_config + config = eval_config - if config.eval_metric[0] in ["rmse", "mse"]: + if config.eval_metric_list[0] in ["rmse", "mse"]: evaluator = GSgnnRegressionEvaluator(config.eval_frequency, - config.eval_metric, + config.eval_metric_list, config.use_early_stop) else: - evaluator = GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop) + evaluator = GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric_list, + config.multilabel, + config.use_early_stop) + val_score0, test_score0 = evaluator.evaluate(val_pred, test_pred, val_labels0, test_labels, 100) val_score1, test_score1 = evaluator.evaluate(val_pred, test_pred, val_labels1, test_labels, 200) val_score2, _ = evaluator.evaluate(val_pred, test_pred, val_labels2, test_labels, 300) @@ -302,7 +290,7 @@ def run_local_nc_eval_worker(eval_config, metric, val_pred, test_pred, assert val_score0 != val_score2 assert test_score0 == test_score1 - assert evaluator.metric == metric + assert evaluator.metric_list == metric assert evaluator.best_iter_num[metric[0]] == 200 assert evaluator.best_val_score == val_score1 assert evaluator.best_test_score == test_score1 @@ -355,21 +343,18 @@ def test_nc_dist_eval(metric, seed, backend): val_labels2[:160] = val_pred[:160] config = Dummy({ - "eval_metric": metric, + "eval_metric_list": metric, "no_validation": False, "multilabel": False, "eval_frequency": 100, "use_early_stop": False, }) - train_data = Dummy({ - "do_validation": True - }) # do evaluation with single worker - metrics_local = run_local_nc_eval((config, train_data), metric, val_pred, test_pred, + metrics_local = run_local_nc_eval(config, metric, val_pred, test_pred, val_labels0, val_labels1, val_labels2, test_labels) # do evaluation with two workers - metrics_dist = run_dist_nc_eval((config, train_data), metric, val_pred, test_pred, + metrics_dist = run_dist_nc_eval(config, metric, val_pred, test_pred, val_labels0, val_labels1, val_labels2, test_labels, backend) @@ -435,21 +420,18 @@ def test_nc_dist_eval_multilabel(seed, backend): test_logits = softmax(test_logits) config = Dummy({ - "eval_metric": ["accuracy"], + "eval_metric_list": ["accuracy"], "no_validation": False, "multilabel": True, "eval_frequency": 100, "use_early_stop": False, }) - train_data = Dummy({ - "do_validation": True - }) # do evaluation with single worker - metrics_local = run_local_nc_eval((config, train_data), ["accuracy"], val_logits, test_logits, + metrics_local = run_local_nc_eval(config, ["accuracy"], val_logits, test_logits, val_labels0, val_labels1, val_labels2, test_labels0) # do evaluation with two workers - metrics_dist = run_dist_nc_eval((config, train_data), ["accuracy"], val_logits, test_logits, + metrics_dist = run_dist_nc_eval(config, ["accuracy"], val_logits, test_logits, val_labels0, val_labels1, val_labels2, test_labels0, backend) metrics_keys = list(metrics_local.keys()) @@ -478,18 +460,15 @@ def test_nc_dist_regression_eval(metric, seed, backend): test_labels[100:170] = test_pred[100:170] config = Dummy({ - "eval_metric": metric, + "eval_metric_list": metric, "no_validation": False, "eval_frequency": 100, "use_early_stop": False, }) - train_data = Dummy({ - "do_validation": True - }) - metrics_local = run_local_nc_eval((config, train_data), metric, val_pred, test_pred, + metrics_local = run_local_nc_eval(config, metric, val_pred, test_pred, val_labels0, val_labels1, val_labels2, test_labels) - metrics_dist = run_dist_nc_eval((config, train_data), metric, val_pred, test_pred, + metrics_dist = run_dist_nc_eval(config, metric, val_pred, test_pred, val_labels0, val_labels1, val_labels2, test_labels, backend) metrics_keys = list(metrics_local.keys()) diff --git a/tests/unit-tests/test_evaluator.py b/tests/unit-tests/test_evaluator.py index 2de010e749..50dae032e9 100644 --- a/tests/unit-tests/test_evaluator.py +++ b/tests/unit-tests/test_evaluator.py @@ -22,7 +22,7 @@ import dgl from graphstorm.eval import GSgnnMrrLPEvaluator, GSgnnPerEtypeMrrLPEvaluator -from graphstorm.eval import GSgnnAccEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator from graphstorm.eval import GSgnnRegressionEvaluator from graphstorm.eval.evaluator import early_stop_avg_increase_judge from graphstorm.eval.evaluator import early_stop_cons_increase_judge @@ -49,17 +49,7 @@ def gen_hg(): return hg def gen_mrr_lp_eval_data(): - # common Dummy objects - train_data = Dummy({ - "train_idxs": th.randint(10, (10,)), - "val_idxs": th.randint(10, (10,)), - "test_idxs": th.randint(10, (10,)), - "do_validation": True - }) - config = Dummy({ - "num_negative_edges_eval": 10, - "lp_decoder_type": BUILTIN_LP_DOT_DECODER, "eval_frequency": 100, "use_early_stop": False, }) @@ -71,7 +61,7 @@ def gen_mrr_lp_eval_data(): test_pos_scores = th.rand((10,1)) test_neg_scores = th.rand((10,10)) - return train_data, config, etypes, (val_pos_scores, val_neg_scores), (test_pos_scores, test_neg_scores) + return config, etypes, (val_pos_scores, val_neg_scores), (test_pos_scores, test_neg_scores) def test_mrr_per_etype_lp_evaluation(): # system heavily depends on th distributed @@ -81,7 +71,7 @@ def test_mrr_per_etype_lp_evaluation(): init_method=dist_init_method, world_size=1, rank=0) - train_data, config, etypes, val_scores, test_scores = gen_mrr_lp_eval_data() + config, etypes, val_scores, test_scores = gen_mrr_lp_eval_data() score = { ("a", "r1", "b"): 0.9, @@ -89,22 +79,14 @@ def test_mrr_per_etype_lp_evaluation(): } # Test get_major_score - lp = GSgnnPerEtypeMrrLPEvaluator(10, - train_data, - num_negative_edges_eval=4, - lp_decoder_type=BUILTIN_LP_DOT_DECODER, - use_early_stop=False) + lp = GSgnnPerEtypeMrrLPEvaluator(10, use_early_stop=False) assert lp.major_etype == LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL m_score = lp._get_major_score(score) assert m_score == sum(score.values()) / 2 # Test get_major_score - lp = GSgnnPerEtypeMrrLPEvaluator(config.eval_frequency, - train_data, - major_etype=("a", "r2", "b"), - num_negative_edges_eval=config.num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, + lp = GSgnnPerEtypeMrrLPEvaluator(config.eval_frequency, major_etype=("a", "r2", "b"), use_early_stop=config.use_early_stop) assert lp.major_etype == ("a", "r2", "b") @@ -114,11 +96,7 @@ def test_mrr_per_etype_lp_evaluation(): val_pos_scores, val_neg_scores = val_scores test_pos_scores, test_neg_scores = test_scores - lp = GSgnnPerEtypeMrrLPEvaluator(config.eval_frequency, - train_data, - num_negative_edges_eval=config.num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, - use_early_stop=config.use_early_stop) + lp = GSgnnPerEtypeMrrLPEvaluator(config.eval_frequency, use_early_stop=config.use_early_stop) rank0 = [] rank1 = [] @@ -180,11 +158,8 @@ def test_mrr_per_etype_lp_evaluation(): assert_almost_equal(np.array([test_s_mrr]), lp.best_test_score['mrr']) lp = GSgnnPerEtypeMrrLPEvaluator(config.eval_frequency, - train_data, - major_etype=etypes[1], - num_negative_edges_eval=config.num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, - use_early_stop=config.use_early_stop) + major_etype=etypes[1], + use_early_stop=config.use_early_stop) val_sc, test_sc = lp.evaluate(val_ranks, test_ranks, 0) assert_equal(val_s['mrr'][etypes[0]], val_sc['mrr'][etypes[0]]) @@ -205,15 +180,14 @@ def test_mrr_lp_evaluator(): init_method=dist_init_method, world_size=1, rank=0) - train_data, config, etypes, val_scores, test_scores = gen_mrr_lp_eval_data() + config, etypes, val_scores, test_scores = gen_mrr_lp_eval_data() val_pos_scores, val_neg_scores = val_scores test_pos_scores, test_neg_scores = test_scores - lp = GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - num_negative_edges_eval=config.num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, - use_early_stop=config.use_early_stop) + lp = GSgnnMrrLPEvaluator(config.eval_frequency, use_early_stop=config.use_early_stop) + + # checke default metric list + assert lp.metric_list == ['mrr'] rank = [] for i in range(len(val_pos_scores)): @@ -273,9 +247,6 @@ def test_mrr_lp_evaluator(): @patch.object(GSgnnMrrLPEvaluator, 'compute_score') def check_evaluate(mock_compute_score): lp = GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - num_negative_edges_eval=config. num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, use_early_stop=config.use_early_stop) mock_compute_score.side_effect = [ @@ -307,21 +278,11 @@ def check_evaluate(mock_compute_score): # check GSgnnMrrLPEvaluator.evaluate() check_evaluate() - # common Dummy objects - train_data = Dummy({ - "train_idxs": None, - "val_idxs": None, - "test_idxs": th.randint(10, (10,)), - "do_validation": True - }) # test evaluate @patch.object(GSgnnMrrLPEvaluator, 'compute_score') def check_evaluate_infer(mock_compute_score): lp = GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - num_negative_edges_eval=config.num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, - use_early_stop=config.use_early_stop) + use_early_stop=config.use_early_stop) mock_compute_score.side_effect = [ {"mrr": 0.6}, @@ -346,9 +307,6 @@ def check_evaluate_infer(mock_compute_score): # train_data.do_validation True # config.no_validation False lp = GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - num_negative_edges_eval=config.num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, use_early_stop=config.use_early_stop) assert lp.do_eval(120, epoch_end=True) is True assert lp.do_eval(200) is True @@ -356,8 +314,6 @@ def check_evaluate_infer(mock_compute_score): assert lp.do_eval(1) is False config3 = Dummy({ - "num_negative_edges_eval": 10, - "lp_decoder_type": BUILTIN_LP_DOT_DECODER, "eval_frequency": 0, "use_early_stop": False, }) @@ -366,16 +322,13 @@ def check_evaluate_infer(mock_compute_score): # config.no_validation False # eval_frequency is 0 lp = GSgnnMrrLPEvaluator(config3.eval_frequency, - train_data, - num_negative_edges_eval=config3.num_negative_edges_eval, - lp_decoder_type=config3.lp_decoder_type, - use_early_stop=config3.use_early_stop) + use_early_stop=config3.use_early_stop) assert lp.do_eval(120, epoch_end=True) is True assert lp.do_eval(200) is False th.distributed.destroy_process_group() -def test_acc_evaluator(): +def test_classification_evaluator(): # system heavily depends on th distributed dist_init_method = 'tcp://{master_ip}:{master_port}'.format( master_ip='127.0.0.1', master_port='12346') @@ -384,34 +337,40 @@ def test_acc_evaluator(): world_size=1, rank=0) + # test default settings + cl_eval = GSgnnClassificationEvaluator(eval_frequency=100) + assert cl_eval.metric_list == ["accuracy"] + assert cl_eval.multilabel is False + + # test given settings config = Dummy({ "multilabel": False, "eval_frequency": 100, - "eval_metric": ["accuracy"], + "eval_metric_list": ["accuracy"], "use_early_stop": False, }) # Test compute_score - nc = GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop) + cl_eval = GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric_list, + config.multilabel, + config.use_early_stop) pred = th.randint(10, (100,)) labels = th.randint(10, (100,)) - result = nc.compute_score(pred, labels, True) + result = cl_eval.compute_score(pred, labels, True) assert_equal(result["accuracy"], th.sum(pred == labels).item() / len(labels)) - result = nc.compute_score(None, None, True) + result = cl_eval.compute_score(None, None, True) assert result["accuracy"] == "N/A" - # Test evaluate - @patch.object(GSgnnAccEvaluator, 'compute_score') + # Test the evaluate method + @patch.object(GSgnnClassificationEvaluator, 'compute_score') def check_evaluate(mock_compute_score): - nc = GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop) + cl_eval = GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric_list, + config.multilabel, + config.use_early_stop) mock_compute_score.side_effect = [ {"accuracy": 0.7}, {"accuracy": 0.65}, @@ -420,34 +379,34 @@ def check_evaluate(mock_compute_score): {"accuracy": 0.76}, {"accuracy": 0.8}, ] - val_score, test_score = nc.evaluate(th.rand((10,)), th.rand((10,)), th.rand((10,)), th.rand((10,)), 100) + val_score, test_score = cl_eval.evaluate(th.rand((10,)), th.rand((10,)), th.rand((10,)), th.rand((10,)), 100) mock_compute_score.assert_called() assert val_score["accuracy"] == 0.7 assert test_score["accuracy"] == 0.65 - val_score, test_score = nc.evaluate(th.rand((10,)), th.rand((10,)), th.rand((10,)), th.rand((10,)), 200) + val_score, test_score = cl_eval.evaluate(th.rand((10,)), th.rand((10,)), th.rand((10,)), th.rand((10,)), 200) mock_compute_score.assert_called() assert val_score["accuracy"] == 0.8 assert test_score["accuracy"] == 0.7 - val_score, test_score = nc.evaluate(th.rand((10,)), th.rand((10,)), th.rand((10,)), th.rand((10,)), 300) + val_score, test_score = cl_eval.evaluate(th.rand((10,)), th.rand((10,)), th.rand((10,)), th.rand((10,)), 300) mock_compute_score.assert_called() assert val_score["accuracy"] == 0.76 assert test_score["accuracy"] == 0.8 - assert nc.best_val_score["accuracy"] == 0.8 - assert nc.best_test_score["accuracy"] == 0.7 - assert nc.best_iter_num["accuracy"] == 200 + assert cl_eval.best_val_score["accuracy"] == 0.8 + assert cl_eval.best_test_score["accuracy"] == 0.7 + assert cl_eval.best_iter_num["accuracy"] == 200 check_evaluate() - # Test evaluate with out test score - @patch.object(GSgnnAccEvaluator, 'compute_score') + # Test evaluate without test score + @patch.object(GSgnnClassificationEvaluator, 'compute_score') def check_evaluate_no_test(mock_compute_score): - nc = GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop) + cl_eval = GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric_list, + config.multilabel, + config.use_early_stop) mock_compute_score.side_effect = [ {"accuracy": 0.7}, {"accuracy": "N/A"}, @@ -456,55 +415,55 @@ def check_evaluate_no_test(mock_compute_score): {"accuracy": 0.76}, {"accuracy": "N/A"}, ] - val_score, test_score = nc.evaluate(th.rand((10,)), None, th.rand((10,)), None, 100) + val_score, test_score = cl_eval.evaluate(th.rand((10,)), None, th.rand((10,)), None, 100) mock_compute_score.assert_called() assert val_score["accuracy"] == 0.7 assert test_score["accuracy"] == "N/A" - val_score, test_score = nc.evaluate(th.rand((10,)), None, th.rand((10,)), None, 200) + val_score, test_score = cl_eval.evaluate(th.rand((10,)), None, th.rand((10,)), None, 200) mock_compute_score.assert_called() assert val_score["accuracy"] == 0.8 assert test_score["accuracy"] == "N/A" - val_score, test_score = nc.evaluate(th.rand((10,)), None, th.rand((10,)), None, 300) + val_score, test_score = cl_eval.evaluate(th.rand((10,)), None, th.rand((10,)), None, 300) mock_compute_score.assert_called() assert val_score["accuracy"] == 0.76 assert test_score["accuracy"] == "N/A" - assert nc.best_val_score["accuracy"] == 0.8 - assert nc.best_test_score["accuracy"] == "N/A" - assert nc.best_iter_num["accuracy"] == 200 + assert cl_eval.best_val_score["accuracy"] == 0.8 + assert cl_eval.best_test_score["accuracy"] == "N/A" + assert cl_eval.best_iter_num["accuracy"] == 200 check_evaluate_no_test() - # check GSgnnAccEvaluator.do_eval() + # check GSgnnClassificationEvaluator.do_eval() # train_data.do_validation True # config.no_validation False - nc = GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop) - assert nc.do_eval(120, epoch_end=True) is True - assert nc.do_eval(200) is True - assert nc.do_eval(0) is True - assert nc.do_eval(1) is False + cl_eval = GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric_list, + config.multilabel, + config.use_early_stop) + assert cl_eval.do_eval(120, epoch_end=True) is True + assert cl_eval.do_eval(200) is True + assert cl_eval.do_eval(0) is True + assert cl_eval.do_eval(1) is False config3 = Dummy({ "multilabel": False, "eval_frequency": 0, - "eval_metric": ["accuracy"], + "eval_metric_list": ["accuracy"], "use_early_stop": False, }) # train_data.do_validation True # config.no_validation False # eval_frequency is 0 - nc = GSgnnAccEvaluator(config3.eval_frequency, - config3.eval_metric, - config3.multilabel, - config3.use_early_stop) - assert nc.do_eval(120, epoch_end=True) is True - assert nc.do_eval(200) is False + cl_eval = GSgnnClassificationEvaluator(config3.eval_frequency, + config3.eval_metric_list, + config3.multilabel, + config3.use_early_stop) + assert cl_eval.do_eval(120, epoch_end=True) is True + assert cl_eval.do_eval(200) is False th.distributed.destroy_process_group() def test_regression_evaluator(): @@ -516,16 +475,18 @@ def test_regression_evaluator(): world_size=1, rank=0) + # test default settings + nr_eval = GSgnnRegressionEvaluator(eval_frequency=100) + assert nr_eval.metric_list == ["rmse"] + config = Dummy({ "eval_frequency": 100, - "eval_metric": ["rmse"], "use_early_stop": False, }) # Test compute_score nr = GSgnnRegressionEvaluator(config.eval_frequency, - config.eval_metric, - config.use_early_stop) + use_early_stop=config.use_early_stop) pred = th.rand(100) labels = th.rand(100) result = nr.compute_score(pred, labels) @@ -540,8 +501,7 @@ def test_regression_evaluator(): @patch.object(GSgnnRegressionEvaluator, 'compute_score') def check_evaluate(mock_compute_score): nr = GSgnnRegressionEvaluator(config.eval_frequency, - config.eval_metric, - config.use_early_stop) + use_early_stop=config.use_early_stop) mock_compute_score.side_effect = [ {"rmse": 0.7}, {"rmse": 0.8}, @@ -576,8 +536,7 @@ def check_evaluate(mock_compute_score): @patch.object(GSgnnRegressionEvaluator, 'compute_score') def check_evaluate_no_test(mock_compute_score): nr = GSgnnRegressionEvaluator(config.eval_frequency, - config.eval_metric, - config.use_early_stop) + use_early_stop=config.use_early_stop) mock_compute_score.side_effect = [ {"rmse": 0.7}, {"rmse": "N/A"}, @@ -611,8 +570,7 @@ def check_evaluate_no_test(mock_compute_score): # check GSgnnRegressionEvaluator.do_eval() # train_data.do_validation True nr = GSgnnRegressionEvaluator(config.eval_frequency, - config.eval_metric, - config.use_early_stop) + use_early_stop=config.use_early_stop) assert nr.do_eval(120, epoch_end=True) is True assert nr.do_eval(200) is True assert nr.do_eval(0) is True @@ -621,15 +579,13 @@ def check_evaluate_no_test(mock_compute_score): config3 = Dummy({ "eval_frequency": 0, "no_validation": False, - "eval_metric": ["rmse"], "use_early_stop": False, }) # train_data.do_validation True # eval_frequency is 0 nr = GSgnnRegressionEvaluator(config3.eval_frequency, - config3.eval_metric, - config3.use_early_stop) + use_early_stop=config3.use_early_stop) assert nr.do_eval(120, epoch_end=True) is True assert nr.do_eval(200) is False th.distributed.destroy_process_group() @@ -723,13 +679,14 @@ def test_early_stop_evaluator(): "early_stop_strategy": EARLY_STOP_AVERAGE_INCREASE_STRATEGY, }) - evaluator = GSgnnAccEvaluator(config2.eval_frequency, - config2.eval_metric, - config2.multilabel, - config2.use_early_stop, - config2.early_stop_burnin_rounds, - config2.early_stop_rounds, - config2.early_stop_strategy) + evaluator = GSgnnClassificationEvaluator(config2.eval_frequency, + config2.eval_metric, + config2.multilabel, + config2.use_early_stop, + config2.early_stop_burnin_rounds, + config2.early_stop_rounds, + config2.early_stop_strategy) + for _ in range(5): # always return false assert evaluator.do_early_stop({"accuracy": 0.5}) is False @@ -744,31 +701,17 @@ def test_early_stop_evaluator(): def test_early_stop_lp_evaluator(): # common Dummy objects - train_data = Dummy({ - "train_idxs": th.randint(10, (10,)), - "val_idxs": th.randint(10, (10,)), - "test_idxs": th.randint(10, (10,)), - "do_validation": True - }) - config = Dummy({ - "num_negative_edges_eval": 10, - "lp_decoder_type": BUILTIN_LP_DOT_DECODER, "eval_frequency": 100, "use_early_stop": False, }) evaluator = GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - num_negative_edges_eval=config.num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, use_early_stop=config.use_early_stop) for _ in range(10): # always return false assert evaluator.do_early_stop({"mrr": 0.5}) is False config = Dummy({ - "num_negative_edges_eval": 10, - "lp_decoder_type": BUILTIN_LP_DOT_DECODER, "eval_frequency": 100, "use_early_stop": True, "early_stop_burnin_rounds": 5, @@ -776,9 +719,6 @@ def test_early_stop_lp_evaluator(): "early_stop_strategy": EARLY_STOP_CONSECUTIVE_INCREASE_STRATEGY, }) evaluator = GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - num_negative_edges_eval=config.num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, use_early_stop=config.use_early_stop, early_stop_burnin_rounds=config.early_stop_burnin_rounds, early_stop_rounds=config.early_stop_rounds, @@ -797,8 +737,6 @@ def test_early_stop_lp_evaluator(): assert evaluator.do_early_stop({"mrr": 0.45}) # early stop config = Dummy({ - "num_negative_edges_eval": 10, - "lp_decoder_type": BUILTIN_LP_DOT_DECODER, "eval_frequency": 100, "use_early_stop": True, "early_stop_burnin_rounds": 5, @@ -806,9 +744,6 @@ def test_early_stop_lp_evaluator(): "early_stop_strategy": EARLY_STOP_AVERAGE_INCREASE_STRATEGY, }) evaluator = GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - num_negative_edges_eval=config.num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, use_early_stop=config.use_early_stop, early_stop_burnin_rounds=config.early_stop_burnin_rounds, early_stop_rounds=config.early_stop_rounds, @@ -835,10 +770,11 @@ def test_get_val_score_rank(): "use_early_stop": False, }) - evaluator = GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop) + evaluator = GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + config.multilabel, + config.use_early_stop) + # For accuracy, the bigger the better. val_score = {"accuracy": 0.47} assert evaluator.get_val_score_rank(val_score) == 1 @@ -891,25 +827,13 @@ def test_get_val_score_rank(): # ------------------- test LPEvaluator ------------------- # common Dummy objects - train_data = Dummy({ - "train_idxs": th.randint(10, (10,)), - "val_idxs": th.randint(10, (10,)), - "test_idxs": th.randint(10, (10,)), - }) - config = Dummy({ - "num_negative_edges_eval": 10, - "lp_decoder_type": BUILTIN_LP_DOT_DECODER, "eval_frequency": 100, "use_early_stop": False, - "eval_metric": ["mrr"] }) evaluator = GSgnnMrrLPEvaluator(config.eval_frequency, - train_data, - num_negative_edges_eval=config.num_negative_edges_eval, - lp_decoder_type=config.lp_decoder_type, - use_early_stop=config.use_early_stop) + use_early_stop=config.use_early_stop) # For MRR, the bigger the better val_score = {"mrr": 0.47} @@ -929,10 +853,11 @@ def test_get_val_score_rank(): # test evaluators test_mrr_per_etype_lp_evaluation() test_mrr_lp_evaluator() - test_acc_evaluator() test_regression_evaluator() test_early_stop_avg_increase_judge() test_early_stop_cons_increase_judge() test_early_stop_evaluator() test_early_stop_lp_evaluator() test_get_val_score_rank() + + test_classification_evaluator() \ No newline at end of file diff --git a/tests/unit-tests/test_inferrer.py b/tests/unit-tests/test_inferrer.py index 9e2c070a3c..48ec584fe6 100644 --- a/tests/unit-tests/test_inferrer.py +++ b/tests/unit-tests/test_inferrer.py @@ -26,7 +26,7 @@ from graphstorm.tracker import GSSageMakerTaskTracker from graphstorm import create_builtin_node_gnn_model from graphstorm.inference.graphstorm_infer import GSInferrer -from graphstorm.eval import GSgnnAccEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator from data_utils import generate_dummy_dist_graph @@ -79,10 +79,10 @@ def test_inferrer_setup_evaluator(): # case 1: by default trainer has no task_tracker assert inferrer.task_tracker is None - evaluator = GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop) + evaluator = GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + config.multilabel, + config.use_early_stop) # case 2: evaluator has no task_tracker by default assert evaluator.task_tracker is None diff --git a/tests/unit-tests/test_trainer.py b/tests/unit-tests/test_trainer.py index fa064605f8..c6ca1748cf 100644 --- a/tests/unit-tests/test_trainer.py +++ b/tests/unit-tests/test_trainer.py @@ -26,7 +26,7 @@ from graphstorm.tracker import GSSageMakerTaskTracker from graphstorm import create_builtin_node_gnn_model from graphstorm.trainer import GSgnnTrainer -from graphstorm.eval import GSgnnAccEvaluator +from graphstorm.eval import GSgnnClassificationEvaluator from data_utils import generate_dummy_dist_graph @@ -79,10 +79,10 @@ def test_trainer_setup_evaluator(): # case 1: by default trainer has no task_tracker assert trainer.task_tracker is None - evaluator = GSgnnAccEvaluator(config.eval_frequency, - config.eval_metric, - config.multilabel, - config.use_early_stop) + evaluator = GSgnnClassificationEvaluator(config.eval_frequency, + config.eval_metric, + config.multilabel, + config.use_early_stop) # case 2: evaluator has no task_tracker by default assert evaluator.task_tracker is None