diff --git a/python/graphstorm/__init__.py b/python/graphstorm/__init__.py index f81c863127..92116ba166 100644 --- a/python/graphstorm/__init__.py +++ b/python/graphstorm/__init__.py @@ -29,7 +29,8 @@ from .gsf import create_builtin_edge_model from .gsf import create_builtin_node_model from .gsf import (create_task_decoder, - create_evaluator) + create_evaluator, + create_lp_evaluator) from .gsf import (create_builtin_node_decoder, create_builtin_edge_decoder, diff --git a/python/graphstorm/eval/__init__.py b/python/graphstorm/eval/__init__.py index 1c84b58f2f..9ac636b653 100644 --- a/python/graphstorm/eval/__init__.py +++ b/python/graphstorm/eval/__init__.py @@ -27,6 +27,8 @@ from .evaluator import (GSgnnBaseEvaluator, GSgnnPredictionEvalInterface, GSgnnLPRankingEvalInterface, + GSgnnLPEvaluator, + GSgnnPerEtypeLPEvaluator, GSgnnMrrLPEvaluator, GSgnnPerEtypeMrrLPEvaluator, GSgnnHitsLPEvaluator, diff --git a/python/graphstorm/eval/eval_func.py b/python/graphstorm/eval/eval_func.py index 39bca4a6d4..c940d61017 100644 --- a/python/graphstorm/eval/eval_func.py +++ b/python/graphstorm/eval/eval_func.py @@ -152,7 +152,7 @@ class LinkPredictionMetrics: Parameters ---------- eval_metric_list: list of string - Evaluation metric(s) used during evaluation, for example, ["hit_at_10", "hit_at_100"]. + Evaluation metric(s) used during evaluation, for example, ["mrr", "hit_at_1", "hit_at_100"]. """ def __init__(self, eval_metric_list=None): self.supported_metrics = SUPPORTED_LINK_PREDICTION_METRICS diff --git a/python/graphstorm/eval/evaluator.py b/python/graphstorm/eval/evaluator.py index 8898e1157c..184e50b1b9 100644 --- a/python/graphstorm/eval/evaluator.py +++ b/python/graphstorm/eval/evaluator.py @@ -16,6 +16,7 @@ Evaluator for different tasks. """ +import warnings import abc from statistics import mean import torch as th @@ -144,14 +145,14 @@ def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters): Returns ----------- eval_score: dict - Validation scores of differnet metrics in the format of {metric: val_score}. + Validation scores of different metrics in the format of {metric: val_score}. test_score: dict Test scores of different metrics in the format of {metric: test_score}. """ @abc.abstractmethod def compute_score(self, pred, labels, train=True): - """ Compute evaluation score of Prediciton results. + """ Compute evaluation score of Prediction results. **Classification** and **regression** evaluators should provide both predictions and labels to this method. @@ -177,7 +178,7 @@ class GSgnnLPRankingEvalInterface(): The interface sets two abstract methods for Link Prediction evaluator classes that use ranking method to compute evaluation metrics, such as ``mrr`` (Mean Reciprocal Rank). - There are two methdos to be implemented if inherite this interface. + There are two methods to be implemented if inherit this interface. 1. ``evaluate()`` method, which will be called by different **Trainer** in their ``eval()`` function to provide ranking-based evaluation results of validation and test sets during @@ -188,7 +189,7 @@ class GSgnnLPRankingEvalInterface(): @abc.abstractmethod def evaluate(self, val_rankings, test_rankings, total_iters): - """Evaluate Link Prediciton results on validation and test sets. + """Evaluate Link Prediction results on validation and test sets. **Link Prediction** evaluators should provide the ranking of validation and test sets as input to this method. 
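(For orientation, a minimal sketch of what a "ranking" passed to this interface looks like, assuming illustrative score values; it mirrors how the unit tests later in this patch derive ranks from positive and negative scores:)

    import torch as th

    # One positive edge scored against its negative samples (illustrative values).
    pos_score = th.tensor([0.9])
    neg_scores = th.tensor([0.7, 0.95, 0.2, 0.5])
    scores = th.cat([pos_score, neg_scores])
    _, indices = th.sort(scores, descending=True)
    # 1-based position of the positive score after sorting; here it is 2.
    rank = th.nonzero(indices == 0) + 1
    # Ranking-based evaluators then consume {canonical_etype: tensor_of_ranks}.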
@@ -212,7 +213,7 @@ def evaluate(self, val_rankings, test_rankings, total_iters): @abc.abstractmethod def compute_score(self, rankings, train=True): - """ Compute Link Prediciton evaluation score. + """ Compute Link Prediction evaluation score. Ranking-based Link Prediction evaluators should provide ranking values as input to this method. @@ -237,8 +238,8 @@ class GSgnnBaseEvaluator(): ``GSgnnClassificationEvaluator``, ``GSgnnRegressionEvaluator``, ``GSgnnMrrLPEvaluator``, ``GSgnnPerEtypeMrrLPEvaluator``, and ``GSgnnRconstructFeatRegScoreEvaluator``. - In order to create customized Evaluators, users can inherite this class and the corresponding - EvalInteface class, and then implement their two abstract methods, i.e., ``evaluate()`` + In order to create customized Evaluators, users can inherit this class and the corresponding + EvalInterface class, and then implement their two abstract methods, i.e., ``evaluate()`` and ``compute_score()`` accordingly. Parameters @@ -300,7 +301,7 @@ def setup_task_tracker(self, task_tracker): def do_eval(self, total_iters, epoch_end=False): """ Decide whether to do the evaluation in current iteration or epoch. - Return `True`, if the current iteration is larger than 0 and is a mutiple of the given + Return `True`, if the current iteration is larger than 0 and is a multiple of the given `eval_frequency`, or is the end of an epoch. Otherwise return `False`. Parameters @@ -336,7 +337,7 @@ def do_early_stop(self, val_score): return False assert len(val_score) == 1, \ - f"valudation score should be a signle key value pair but got {val_score}" + f"validation score should be a single key value pair but got {val_score}" self._num_early_stop_calls += 1 # Not enough existing validation scores if self._num_early_stop_calls <= self._early_stop_burnin_rounds: @@ -517,7 +518,7 @@ def __init__(self, eval_frequency, self._best_iter[metric] = 0 def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters): - """ Compute classificaton metric scores on validation and test sets. + """ Compute classification metric scores on validation and test sets. Parameters ---------- @@ -673,7 +674,7 @@ def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters): Returns ----------- eval_score: dict - Validation scores of differnet regression metrics in the format of + Validation scores of different regression metrics in the format of {metric: val_score}. test_score: dict Test scores of different regression metrics in the format of {metric: test_score}. @@ -731,7 +732,7 @@ def compute_score(self, pred, labels, train=True): if train: # training expects always a single number to be - # returned and has a different (potentially) evluation function + # returned and has a different (potentially) evaluation function scores[metric] = self.metrics_obj.metric_function[metric](pred, labels) else: # validation or testing may have a different @@ -747,7 +748,7 @@ def compute_score(self, pred, labels, train=True): class GSgnnRconstructFeatRegScoreEvaluator(GSgnnRegressionEvaluator): """ Evaluator for feature reconstruction tasks using regression scores. - A built-in evalutor for feature reconstruction tasks. It uses ``mse`` or ``rmse`` as + A built-in evaluator for feature reconstruction tasks. It uses ``mse`` or ``rmse`` as evaluation metrics. 
This evaluator requires the prediction results to be a 2D float tensor and @@ -817,7 +818,7 @@ def compute_score(self, pred, labels, train=True): if train: # training expects always a single number to be - # returned and has a different (potentially) evluation function + # returned and has a different (potentially) evaluation function scores[metric] = self.metrics_obj.metric_function[metric](pred, labels) else: # validation or testing may have a different @@ -830,6 +831,333 @@ def compute_score(self, pred, labels, train=True): return scores +class GSgnnLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface): + """ Link Prediction Evaluator using “mrr” and/or "hit@k" as metric. + + GS built-in evaluator for Link Prediction tasks. It uses "mrr" as the default eval metric, + which implements the `GSgnnLPRankingEvalInterface`. + + Parameters + ---------- + eval_frequency: int + The frequency (number of iterations) of doing evaluation. + eval_metric_list: list of string + Evaluation metric used during evaluation. Default: ['mrr'] + use_early_stop: bool + Set true to use early stop. + early_stop_burnin_rounds: int + Burn-in rounds before start checking for the early stop condition. + early_stop_rounds: int + The number of rounds for validation scores used to decide early stop. + early_stop_strategy: str + The early stop strategy. GraphStorm supports two strategies: + 1) consecutive_increase and 2) average_increase. + + .. versionadded:: 0.4.0 + The :py:class:`GSgnnLPEvaluator`. + """ + def __init__(self, eval_frequency, + eval_metric_list=None, + use_early_stop=False, + early_stop_burnin_rounds=0, + early_stop_rounds=3, + early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY): + # set default metric list + if eval_metric_list is None: + eval_metric_list = ["mrr"] + super(GSgnnLPEvaluator, self).__init__(eval_frequency, + eval_metric_list, use_early_stop, early_stop_burnin_rounds, + early_stop_rounds, early_stop_strategy) + self.metrics_obj = LinkPredictionMetrics(eval_metric_list) + + self._best_val_score = {} + self._best_test_score = {} + self._best_iter = {} + for metric in self.metric_list: + self._best_val_score[metric] = self.metrics_obj.init_best_metric(metric=metric) + self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) + self._best_iter[metric] = 0 + + def evaluate(self, val_rankings, test_rankings, total_iters): + """ `GSgnnLinkPredictionTrainer` and `GSgnnLinkPredictionInferrer` will call this function + to compute validation and test scores. + + Parameters + ---------- + val_rankings: dict of tensors + Rankings of positive scores of validation edges for each edge type. + test_rankings: dict of tensors + Rankings of positive scores of test edges for each edge type. + total_iters: int + The current interation number. + + Returns + ----------- + val_score: float + Validation score + test_score: float + Test score + """ + with th.no_grad(): + if test_rankings is not None: + test_score = self.compute_score(test_rankings) + else: + test_score = {} + for metric in self.metric_list: + test_score[metric] = "N/A" # Dummy + + if val_rankings is not None: + val_score = self.compute_score(val_rankings) + + if get_rank() == 0: + for metric in self.metric_list: + # be careful whether > or < it might change per metric. 
+ if self.metrics_obj.metric_comparator[metric]( + self._best_val_score[metric], val_score[metric]): + self._best_val_score[metric] = val_score[metric] + self._best_test_score[metric] = test_score[metric] + self._best_iter[metric] = total_iters + else: + val_score = {} + for metric in self.metric_list: + val_score[metric] = "N/A" # Dummy + + self._history.append((val_score, test_score)) + + return val_score, test_score + + def compute_score(self, rankings, train=True): + """ Compute evaluation score + + Parameters + ---------- + rankings: dict of tensors + Rankings of positive scores in format of {etype: ranking} + train: boolean + If in model training. + + Returns + ------- + Evaluation metric values: dict + """ + # We calculate global score, etype is ignored. + ranking = [] + for _, rank in rankings.items(): + ranking.append(rank) + ranking = th.cat(ranking, dim=0) + + # compute ranking value for each metric + metrics = {} + for metric in self.metric_list: + if train: + # training expects always a single number to be + # returned and has a different (potentially) evaluation function + metrics[metric] = self.metrics_obj.metric_function[metric](ranking) + else: + # validation or testing may have a different + # evaluation function, in our case the evaluation code + # may return a dictionary with the metric values for each metric + metrics[metric] = self.metrics_obj.metric_eval_function[metric](ranking) + + # When world size == 1, we do not need the barrier + if get_world_size() > 1: + barrier() + for _, metric_val in metrics.items(): + th.distributed.all_reduce(metric_val) + + return_metrics = {} + for metric, metric_val in metrics.items(): + return_metric = metric_val / get_world_size() + return_metrics[metric] = return_metric.item() + + return return_metrics + +class GSgnnPerEtypeLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface): + """ + The class for link prediction evaluation using "mrr" and/or "hit@k" metrics and + return a per etype score. + + Parameters + ---------- + eval_frequency: int + The frequency (number of iterations) of doing evaluation. + eval_metric_list: list of string + Evaluation metric used during evaluation. Default: ['mrr'] + major_etype: tuple + Canonical etype used for selecting the best model. If None, use the general hit@k. + use_early_stop: bool + Set true to use early stop. + early_stop_burnin_rounds: int + Burn-in rounds before start checking for the early stop condition. + early_stop_rounds: int + The number of rounds for validation scores used to decide early stop. + early_stop_strategy: str + The early stop strategy. GraphStorm supports two strategies: + 1) consecutive_increase and 2) average_increase. + + .. versionadded:: 0.4.0 + The :py:class:`GSgnnPerEtypeLPEvaluator`. 
+ """ + def __init__(self, eval_frequency, + eval_metric_list=None, + major_etype=LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL, + use_early_stop=False, + early_stop_burnin_rounds=0, + early_stop_rounds=3, + early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY): + # set default metric list + if eval_metric_list is None: + eval_metric_list = ["mrr"] + super(GSgnnPerEtypeLPEvaluator, self).__init__(eval_frequency, + eval_metric_list, use_early_stop, early_stop_burnin_rounds, + early_stop_rounds, early_stop_strategy) + + self.major_etype = major_etype + self.metrics_obj = LinkPredictionMetrics(eval_metric_list) + + self._best_val_score = {} + self._best_test_score = {} + self._best_iter = {} + for metric in self.metric_list: + self._best_val_score[metric] = self.metrics_obj.init_best_metric(metric=metric) + self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) + self._best_iter[metric] = 0 + + def evaluate(self, val_rankings, test_rankings, total_iters): + """ `GSgnnLinkPredictionTrainer` and `GSgnnLinkPredictionInferrer` will call this function + to compute validation and test scores. + + Parameters + ---------- + val_rankings: dict of tensors + Rankings of positive scores of validation edges for each edge type. + test_rankings: dict of tensors + Rankings of positive scores of test edges for each edge type. + total_iters: int + The current interation number. + + Returns + ----------- + val_score: float + Validation score + test_score: float + Test score + """ + with th.no_grad(): + if test_rankings is not None: + test_score = self.compute_score(test_rankings) + else: + test_score = {} + for metric in self.metric_list: + test_score[metric] = "N/A" # Dummy + + if val_rankings is not None: + val_score = self.compute_score(val_rankings) + + if get_rank() == 0: + for metric in self.metric_list: + # be careful whether > or < it might change per metric. + major_val_score = self._get_major_score(val_score[metric]) + major_test_score = self._get_major_score(test_score[metric]) + if self.metrics_obj.metric_comparator[metric]( + self._best_val_score[metric], major_val_score): + self._best_val_score[metric] = major_val_score + self._best_test_score[metric] = major_test_score + self._best_iter[metric] = total_iters + else: + val_score = {} + for metric in self.metric_list: + val_score[metric] = "N/A" # Dummy + + self._history.append((val_score, test_score)) + + return val_score, test_score + + def _get_major_score(self, score): + """ Get the score for save best model(s) and early stop + """ + if isinstance(self.major_etype, str) and \ + self.major_etype == LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL: + major_score = sum(score.values()) / len(score) + else: + major_score = score[self.major_etype] + return major_score + + def compute_score(self, rankings, train=True): + """ Compute evaluation score + + Parameters + ---------- + rankings: dict of tensors + Rankings of positive scores in format of {etype: ranking} + train: boolean + If in model training. 
+ + Returns + ------- + Evaluation metric values: dict + """ + # We calculate per etype score + per_etype_metrics = {} + for etype, ranking in rankings.items(): + # compute ranking value for each metric + metrics = {} + for metric in self.metric_list: + if train: + # training expects always a single number to be + # returned and has a different (potentially) evaluation function + metrics[metric] = self.metrics_obj.metric_function[metric](ranking) + else: + # validation or testing may have a different + # evaluation function, in our case the evaluation code + # may return a dictionary with the metric values for each metric + metrics[metric] = self.metrics_obj.metric_eval_function[metric](ranking) + per_etype_metrics[etype] = metrics + + # When world size == 1, we do not need the barrier + if get_world_size() > 1: + barrier() + for _, metric in per_etype_metrics.items(): + for _, metric_val in metric.items(): + th.distributed.all_reduce(metric_val) + + return_metrics = {} + for etype, metric in per_etype_metrics.items(): + for metric_key, metric_val in metric.items(): + return_metric = metric_val / get_world_size() + if metric_key not in return_metrics: + return_metrics[metric_key] = {} + return_metrics[metric_key][etype] = return_metric.item() + return return_metrics + + def get_val_score_rank(self, val_score): + """ Get the rank of the validation score of the ``major_etype`` initialized in class + initialization by comparing its value to the existing historical values. If using + the default ``major_etype``, it will compute the rank as the summation of validation + values of all edge types. + + Parameters + ---------- + val_score: dict of dict + A dict in the format of {metric: {etype: score}}. + + Returns + -------- + rank: int + The rank of the validation score of the given ``major_etype`` initialized in + class initialization. If using the default ``major_etype``, the rank will be + computed based on the summation of validation scores for all edge types. + """ + val_score = list(val_score.values())[0] + val_score = self._get_major_score(val_score) + + rank = get_val_score_rank(val_score, + self._val_perf_rank_list, + self.get_metric_comparator()) + # after compare, append the score into existing list + self._val_perf_rank_list.append(val_score) + return rank + class GSgnnMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface): """ Evaluator for Link Prediction tasks using ``mrr`` as metric. @@ -838,7 +1166,7 @@ class GSgnnMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface): To create a customized Link Prediction evaluator that use an evaluation metric other than ``mrr``, users might need to 1) define a new evaluation interface if the evaluation method - requires different input arguments; 2) inherite the new evaluation interface in a + requires different input arguments; 2) inherit the new evaluation interface in a customized Link Prediction evaluator; 3) define a customized Link Prediction Trainer/Inferrer to call the customized Link Prediction evaluator. @@ -860,6 +1188,9 @@ class GSgnnMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface): The early stop strategy. GraphStorm supports two strategies: 1) ``consecutive_increase``, and 2) ``average_increase``. Default: ``average_increase``. + + .. deprecated:: 0.4.0 + Use :py:class:`GSgnnLPEvaluator` instead. 
""" def __init__(self, eval_frequency, eval_metric_list=None, @@ -883,6 +1214,14 @@ def __init__(self, eval_frequency, self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) self._best_iter[metric] = 0 + + warnings.warn( + "The GSgnnMrrLPEvaluator has been deprecated from version 0.4.0. " + "Please use GSgnnLPEvaluator instead.", + DeprecationWarning, + stacklevel=1 + ) + def evaluate(self, val_rankings, test_rankings, total_iters): """ ``GSgnnLinkPredictionTrainer`` and ``GSgnnLinkPredictionInferrer`` will call this function to compute validation and test ``mrr`` scores. @@ -959,7 +1298,7 @@ def compute_score(self, rankings, train=True): for metric in self.metric_list: if train: # training expects always a single number to be - # returned and has a different (potentially) evluation function + # returned and has a different (potentially) evaluation function metrics[metric] = self.metrics_obj.metric_function[metric](ranking) else: # validation or testing may have a different @@ -981,7 +1320,8 @@ def compute_score(self, rankings, train=True): return return_metrics class GSgnnPerEtypeMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface): - """ Evaluator for Link Prediction tasks using ``mrr`` as metric, and + """ + Evaluator for Link Prediction tasks using ``mrr`` as metric, and return per edge type ``mrr`` scores. Parameters @@ -1005,6 +1345,9 @@ class GSgnnPerEtypeMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterfac The early stop strategy. GraphStorm supports two strategies: 1) ``consecutive_increase``, and 2) ``average_increase``. Default: ``average_increase``. + + .. deprecated:: 0.4.0 + Use :py:class:`GSgnnPerEtypeLPEvaluator` instead. """ def __init__(self, eval_frequency, eval_metric_list=None, @@ -1033,6 +1376,13 @@ def __init__(self, eval_frequency, self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) self._best_iter[metric] = 0 + warnings.warn( + "The GSgnnPerEtypeMrrLPEvaluator has been deprecated from version 0.4.0. " + "Please use GSgnnPerEtypeLPEvaluator instead.", + DeprecationWarning, + stacklevel=1 + ) + def evaluate(self, val_rankings, test_rankings, total_iters): """ ``GSgnnLinkPredictionTrainer`` and ``GSgnnLinkPredictionInferrer`` will call this function to compute validation and test ``mrr`` scores. @@ -1109,7 +1459,7 @@ def compute_score(self, rankings, train=True): for metric in self.metric_list: if train: # training expects always a single number to be - # returned and has a different (potentially) evluation function + # returned and has a different (potentially) evaluation function metrics[metric] = self.metrics_obj.metric_function[metric](ranking) else: # validation or testing may have a different @@ -1197,6 +1547,9 @@ class GSgnnHitsLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface): The early stop strategy. GraphStorm supports two strategies: 1) ``consecutive_increase``, and 2) ``average_increase``. Default: ``average_increase``. + + .. deprecated:: 0.4.0 + Use :py:class:`GSgnnLPEvaluator` instead. """ def __init__(self, eval_frequency, eval_metric_list=None, @@ -1220,6 +1573,13 @@ def __init__(self, eval_frequency, self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) self._best_iter[metric] = 0 + warnings.warn( + "The GSgnnHitsLPEvaluator has been deprecated from version 0.4.0. 
" + "Please use GSgnnLPEvaluator instead.", + DeprecationWarning, + stacklevel=1 + ) + def evaluate(self, val_rankings, test_rankings, total_iters): """ ``GSgnnLinkPredictionTrainer`` and ``GSgnnLinkPredictionInferrer`` will call this function to compute validation and test ``hit@k`` scores. @@ -1320,8 +1680,9 @@ def compute_score(self, rankings, train=True): return return_metrics class GSgnnPerEtypeHitsLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface): - """ Evaluator for Link Prediction tasks using ``hit@k`` as metric, and - return per edge type ``hit@k`` scores. + """ + Evaluator for Link Prediction tasks using ``hit@k`` as metric, and + return per edge type ``hit@k`` scores. Parameters ---------- @@ -1344,6 +1705,9 @@ class GSgnnPerEtypeHitsLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterfa early_stop_strategy: str 1) ``consecutive_increase``, and 2) ``average_increase``. Default: ``average_increase``. + + .. deprecated:: 0.4.0 + Use :py:class:`GSgnnPerEtypeLPEvaluator` instead. """ def __init__(self, eval_frequency, eval_metric_list=None, @@ -1370,6 +1734,13 @@ def __init__(self, eval_frequency, self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) self._best_iter[metric] = 0 + warnings.warn( + "The GSgnnPerEtypeHitsLPEvaluator has been deprecated from version 0.4.0. " + "Please use GSgnnPerEtypeLPEvaluator instead.", + DeprecationWarning, + stacklevel=1 + ) + def evaluate(self, val_rankings, test_rankings, total_iters): """ ``GSgnnLinkPredictionTrainer`` and ``GSgnnLinkPredictionInferrer`` will call this function to compute validation and test ``hit@k`` scores. @@ -1520,7 +1891,7 @@ class GSgnnMultiTaskEvalInterface(): def evaluate(self, val_results, test_results, total_iters): """Evaluate validation and test sets for Prediciton tasks - GSgnnTrainers will call this function to do evalution in their eval() fuction. + GSgnnTrainers will call this function to do evaluation in their eval() fuction. Parameters ---------- diff --git a/python/graphstorm/gsf.py b/python/graphstorm/gsf.py index 723959847d..18d9bf69b6 100644 --- a/python/graphstorm/gsf.py +++ b/python/graphstorm/gsf.py @@ -46,6 +46,7 @@ BUILTIN_LP_LOSS_CONTRASTIVELOSS, BUILTIN_CLASS_LOSS_CROSS_ENTROPY, BUILTIN_CLASS_LOSS_FOCAL) +from graphstorm.eval.eval_func import SUPPORTED_HIT_AT_METRICS from .model.embed import GSNodeEncoderInputLayer from .model.lm_embed import GSLMNodeEncoderInputLayer, GSPureLMNodeInputLayer from .model.rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer @@ -112,6 +113,8 @@ from .eval import (GSgnnClassificationEvaluator, GSgnnRegressionEvaluator, GSgnnRconstructFeatRegScoreEvaluator, + GSgnnPerEtypeLPEvaluator, + GSgnnLPEvaluator, GSgnnPerEtypeMrrLPEvaluator, GSgnnMrrLPEvaluator) from .trainer import (GSgnnLinkPredictionTrainer, @@ -1118,23 +1121,21 @@ def create_evaluator(task_info): config.early_stop_rounds, config.early_stop_strategy) elif task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]: - assert len(config.eval_metric) == 1, \ - "GraphStorm doees not support computing multiple metrics at the same time for link prediction tasks." 
if config.report_eval_per_type: - return GSgnnPerEtypeMrrLPEvaluator( - eval_frequency=config.eval_frequency, - major_etype=config.model_select_etype, - use_early_stop=config.use_early_stop, - early_stop_burnin_rounds=config.early_stop_burnin_rounds, - early_stop_rounds=config.early_stop_rounds, - early_stop_strategy=config.early_stop_strategy) + return GSgnnPerEtypeLPEvaluator(eval_frequency=config.eval_frequency, + eval_metric_list=config.eval_metric, + major_etype=config.model_select_etype, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy) else: - return GSgnnMrrLPEvaluator( - eval_frequency=config.eval_frequency, - use_early_stop=config.use_early_stop, - early_stop_burnin_rounds=config.early_stop_burnin_rounds, - early_stop_rounds=config.early_stop_rounds, - early_stop_strategy=config.early_stop_strategy) + return GSgnnLPEvaluator(eval_frequency=config.eval_frequency, + eval_metric_list=config.eval_metric, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy) elif task_info.task_type in [BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]: return GSgnnRconstructFeatRegScoreEvaluator( config.eval_frequency, @@ -1144,3 +1145,36 @@ def create_evaluator(task_info): config.early_stop_rounds, config.early_stop_strategy) return None + +def create_lp_evaluator(config): + """ Create LP specific evaluator. + + Parameters + ---------- + config: GSConfig + Configuration. + + Return + ------ + Evaluator: A link prediction evaluator + """ + assert all((x.startswith(SUPPORTED_HIT_AT_METRICS) or x == 'mrr') for x in + config.eval_metric), ( + "Invalid LP evaluation metrics. 
" + "GraphStorm only supports MRR and Hit@K metrics for link prediction.") + + if config.report_eval_per_type: + return GSgnnPerEtypeLPEvaluator(eval_frequency=config.eval_frequency, + eval_metric_list=config.eval_metric, + major_etype=config.model_select_etype, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy) + else: + return GSgnnLPEvaluator(eval_frequency=config.eval_frequency, + eval_metric_list=config.eval_metric, + use_early_stop=config.use_early_stop, + early_stop_burnin_rounds=config.early_stop_burnin_rounds, + early_stop_rounds=config.early_stop_rounds, + early_stop_strategy=config.early_stop_strategy) \ No newline at end of file diff --git a/python/graphstorm/inference/lp_infer.py b/python/graphstorm/inference/lp_infer.py index cb58b42f89..44bd986bd8 100644 --- a/python/graphstorm/inference/lp_infer.py +++ b/python/graphstorm/inference/lp_infer.py @@ -104,12 +104,11 @@ def infer(self, data, loader, save_embed_path, if self.evaluator is not None: test_start = time.time() test_rankings = lp_mini_batch_predict(self._model, embs, loader, device) - # TODO: to refactor the names - val_mrr, test_mrr = self.evaluator.evaluate(None, test_rankings, 0) + val_score, test_score = self.evaluator.evaluate(None, test_rankings, 0) sys_tracker.check('run evaluation') if get_rank() == 0: - self.log_print_metrics(val_score=val_mrr, - test_score=test_mrr, + self.log_print_metrics(val_score=val_score, + test_score=test_score, dur_eval=time.time() - test_start, total_steps=0) diff --git a/python/graphstorm/model/edge_decoder.py b/python/graphstorm/model/edge_decoder.py index 0bf46e0687..50b12b37b5 100644 --- a/python/graphstorm/model/edge_decoder.py +++ b/python/graphstorm/model/edge_decoder.py @@ -1116,29 +1116,33 @@ def get_relembs(self): return self._w_relation.weight, self.etype2rid class LinkPredictRotatEDecoder(LinkPredictMultiRelationLearnableDecoder): - r""" Decoder for link prediction using the RotatE as the score function. + r""" + .. versionadded:: 0.4 + The :py:class:`LinkPredictRotatEDecoder`. - Score function of RotateE measures the angular distance between - head and tail elements. The angular distance is defined as: + Decoder for link prediction using the RotatE as the score function. - .. math:: + Score function of RotateE measures the angular distance between + head and tail elements. The angular distance is defined as: - d_r(h, t)=\|h\circ r-t\| + .. math:: - The RotatE score function is defined as: + d_r(h, t)=\|h\circ r-t\| - .. math:: + The RotatE score function is defined as: - gamma - \|h\circ r-t\|^2 + .. math:: - where gamma is a margin. + gamma - \|h\circ r-t\|^2 - For more details, please refer to https://arxiv.org/abs/1902.10197 - or https://dglke.dgl.ai/doc/kg.html#rotatee. + where gamma is a margin. - Note: The relation embedding of RotatE has two parts, - one for real numbers and one for complex numbers. - Each has the dimension size as half of the input dimension size. + For more details, please refer to https://arxiv.org/abs/1902.10197 + or https://dglke.dgl.ai/doc/kg.html#rotatee. + + Note: The relation embedding of RotatE has two parts, + one for real numbers and one for complex numbers. + Each has the dimension size as half of the input dimension size. 
Parameters ---------- @@ -1376,14 +1380,18 @@ def out_dims(self): return 1 class LinkPredictContrastiveRotatEDecoder(LinkPredictRotatEDecoder): - """ Decoder for link prediction designed for contrastive loss - using the RotatE as the score function. + """ + .. versionadded:: 0.4 + The :py:class:`LinkPredictContrastiveRotatEDecoder`. - Note: - ------ - This class is specifically implemented for contrastive loss. But - it could also be used by other pair-wise loss functions for link - prediction tasks. + Decoder for link prediction designed for contrastive loss + using the RotatE as the score function. + + Note: + ------ + This class is specifically implemented for contrastive loss. But + it could also be used by other pair-wise loss functions for link + prediction tasks. Parameters ---------- @@ -1442,10 +1450,13 @@ def forward(self, g, h, e_h=None): return scores class LinkPredictWeightedRotatEDecoder(LinkPredictRotatEDecoder): - """Link prediction decoder with the score function of RotatE - with edge weight. + """ + .. versionadded:: 0.4 + The :py:class:`LinkPredictWeightedRotatEDecoder`. + + Link prediction decoder with the score function of RotatE with edge weight. - When computing loss, edge weights are used to adjust the loss. + When computing loss, edge weights are used to adjust the loss. Parameters ---------- @@ -1510,26 +1521,30 @@ def forward(self, g, h, e_h): return scores class LinkPredictTransEDecoder(LinkPredictMultiRelationLearnableDecoder): - r""" Decoder for link prediction using the TransE as the score function. + r""" + .. versionadded:: 0.4 + The :py:class:`LinkPredictTransEDecoder`. - Score function of TransE measures the angular distance between - head and tail elements. The angular distance is defined as: + Decoder for link prediction using the TransE as the score function. - .. math:: + Score function of TransE measures the angular distance between + head and tail elements. The angular distance is defined as: - d_r(h, t)= -\|h+r-t\| + .. math:: - The TransE score function is defined as: + d_r(h, t)= -\|h+r-t\| - .. math:: + The TransE score function is defined as: - gamma - \|h+r-t\|^{frac{1}{2}} \text{or} gamma - \|h+r-t\| + .. math:: - where gamma is a margin. + gamma - \|h+r-t\|^{frac{1}{2}} \text{or} gamma - \|h+r-t\| - For more details, please refer to - https://papers.nips.cc/paper_files/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html - or https://dglke.dgl.ai/doc/kg.html#transe. + where gamma is a margin. + + For more details, please refer to + https://papers.nips.cc/paper_files/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html + or https://dglke.dgl.ai/doc/kg.html#transe. Parameters ---------- @@ -1769,14 +1784,18 @@ def out_dims(self): return 1 class LinkPredictContrastiveTransEDecoder(LinkPredictTransEDecoder): - """ Decoder for link prediction designed for contrastive loss - using the TransE as the score function. + """ + .. versionadded:: 0.4 + The :py:class:`LinkPredictContrastiveTransEDecoder`. - Note: - ------ - This class is specifically implemented for contrastive loss. But - it could also be used by other pair-wise loss functions for link - prediction tasks. + Decoder for link prediction designed for contrastive loss + using the TransE as the score function. + + Note: + ------ + This class is specifically implemented for contrastive loss. But + it could also be used by other pair-wise loss functions for link + prediction tasks. 
Parameters ---------- @@ -1834,10 +1853,13 @@ def forward(self, g, h, e_h=None): return scores class LinkPredictWeightedTransEDecoder(LinkPredictTransEDecoder): - """Link prediction decoder with the score function of TransE - with edge weight. + """ + .. versionadded:: 0.4 + The :py:class:`LinkPredictWeightedTransEDecoder`. + + Link prediction decoder with the score function of TransE with edge weight. - When computing loss, edge weights are used to adjust the loss. + When computing loss, edge weights are used to adjust the loss. Parameters ---------- diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py index 5cef15e91e..d5918ea55a 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py @@ -23,57 +23,9 @@ from graphstorm.config import GSConfig from graphstorm.trainer import GSgnnLinkPredictionTrainer from graphstorm.dataloading import GSgnnData -from graphstorm.eval import (GSgnnMrrLPEvaluator, GSgnnPerEtypeMrrLPEvaluator, - GSgnnHitsLPEvaluator, GSgnnPerEtypeHitsLPEvaluator) from graphstorm.model.utils import save_full_node_embeddings from graphstorm.model import do_full_graph_inference from graphstorm.utils import rt_profiler, sys_tracker, get_device -from graphstorm.eval.eval_func import SUPPORTED_HIT_AT_METRICS - -def get_evaluator(config): - """ Get evaluator according to config - - Parameters - ---------- - config: GSConfig - Configuration - """ - # TODO: to create a generic evaluator for LP tasks - assert (len(config.eval_metric) == 1 and config.eval_metric[0] == 'mrr') \ - or (len(config.eval_metric) >= 1 - and all((x.startswith(SUPPORTED_HIT_AT_METRICS) for x in config.eval_metric))), \ - "GraphStorm does not support computing MRR and Hit@K metrics at the same time." - - if config.report_eval_per_type: - if 'mrr' in config.eval_metric: - return GSgnnPerEtypeMrrLPEvaluator(eval_frequency=config.eval_frequency, - major_etype=config.model_select_etype, - use_early_stop=config.use_early_stop, - early_stop_burnin_rounds=config.early_stop_burnin_rounds, - early_stop_rounds=config.early_stop_rounds, - early_stop_strategy=config.early_stop_strategy) - else: - return GSgnnPerEtypeHitsLPEvaluator(eval_frequency=config.eval_frequency, - eval_metric_list=config.eval_metric, - major_etype=config.model_select_etype, - use_early_stop=config.use_early_stop, - early_stop_burnin_rounds=config.early_stop_burnin_rounds, - early_stop_rounds=config.early_stop_rounds, - early_stop_strategy=config.early_stop_strategy) - else: - if 'mrr' in config.eval_metric: - return GSgnnMrrLPEvaluator(eval_frequency=config.eval_frequency, - use_early_stop=config.use_early_stop, - early_stop_burnin_rounds=config.early_stop_burnin_rounds, - early_stop_rounds=config.early_stop_rounds, - early_stop_strategy=config.early_stop_strategy) - else: - return GSgnnHitsLPEvaluator(eval_frequency=config.eval_frequency, - eval_metric_list=config.eval_metric, - use_early_stop=config.use_early_stop, - early_stop_burnin_rounds=config.early_stop_burnin_rounds, - early_stop_rounds=config.early_stop_rounds, - early_stop_strategy=config.early_stop_strategy) def main(config_args): """ main function @@ -98,7 +50,7 @@ def main(config_args): if not config.no_validation: # TODO(zhengda) we need to refactor the evaluator. 
# Currently, we only support mrr - evaluator = get_evaluator(config) + evaluator = gs.create_lp_evaluator(config) trainer.setup_evaluator(evaluator) val_idxs = train_data.get_edge_val_set(config.eval_etype) assert len(val_idxs) > 0, "The training data do not have validation set." diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py index 44b6aac5cd..696ef93283 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py @@ -24,9 +24,6 @@ from graphstorm.config import GSConfig from graphstorm.trainer import GSgnnLinkPredictionTrainer from graphstorm.dataloading import GSgnnData -from graphstorm.eval.eval_func import SUPPORTED_HIT_AT_METRICS -from graphstorm.eval import (GSgnnMrrLPEvaluator, GSgnnPerEtypeMrrLPEvaluator, - GSgnnHitsLPEvaluator, GSgnnPerEtypeHitsLPEvaluator) from graphstorm.model.utils import save_full_node_embeddings from graphstorm.model import do_full_graph_inference from graphstorm.utils import ( @@ -37,51 +34,6 @@ ) from graphstorm.utils import get_lm_ntypes -def get_evaluator(config): - """ Get evaluator according to config - - Parameters - ---------- - config: GSConfig - Configuration - """ - # TODO: to create a generic evaluator for LP tasks - assert (len(config.eval_metric) == 1 and config.eval_metric[0] == 'mrr') \ - or (len(config.eval_metric) >= 1 - and all((x.startswith(SUPPORTED_HIT_AT_METRICS) for x in config.eval_metric))), \ - "GraphStorm does not support computing MRR and Hit@K metrics at the same time." - - if config.report_eval_per_type: - if 'mrr' in config.eval_metric: - return GSgnnPerEtypeMrrLPEvaluator(eval_frequency=config.eval_frequency, - major_etype=config.model_select_etype, - use_early_stop=config.use_early_stop, - early_stop_burnin_rounds=config.early_stop_burnin_rounds, - early_stop_rounds=config.early_stop_rounds, - early_stop_strategy=config.early_stop_strategy) - else: - return GSgnnPerEtypeHitsLPEvaluator(eval_frequency=config.eval_frequency, - eval_metric_list=config.eval_metric, - major_etype=config.model_select_etype, - use_early_stop=config.use_early_stop, - early_stop_burnin_rounds=config.early_stop_burnin_rounds, - early_stop_rounds=config.early_stop_rounds, - early_stop_strategy=config.early_stop_strategy) - else: - if 'mrr' in config.eval_metric: - return GSgnnMrrLPEvaluator(eval_frequency=config.eval_frequency, - use_early_stop=config.use_early_stop, - early_stop_burnin_rounds=config.early_stop_burnin_rounds, - early_stop_rounds=config.early_stop_rounds, - early_stop_strategy=config.early_stop_strategy) - else: - return GSgnnHitsLPEvaluator(eval_frequency=config.eval_frequency, - eval_metric_list=config.eval_metric, - use_early_stop=config.use_early_stop, - early_stop_burnin_rounds=config.early_stop_burnin_rounds, - early_stop_rounds=config.early_stop_rounds, - early_stop_strategy=config.early_stop_strategy) - def main(config_args): """ main function """ @@ -108,7 +60,7 @@ def main(config_args): if not config.no_validation: # TODO(zhengda) we need to refactor the evaluator. # Currently, we only support mrr - evaluator = get_evaluator(config) + evaluator = gs.create_lp_evaluator(config) trainer.setup_evaluator(evaluator) val_idxs = train_data.get_edge_val_set(config.eval_etype) assert len(val_idxs) > 0, "The training data do not have validation set." 
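(For reference, a minimal usage sketch of the new generic evaluators that gs.create_lp_evaluator() builds; the edge type and ranking values are illustrative only, and the single-process gloo setup mirrors the unit tests in this patch:)

    import torch as th
    import graphstorm as gs
    from graphstorm.eval import GSgnnLPEvaluator, GSgnnPerEtypeLPEvaluator

    # The evaluators all-reduce scores across workers, so a process group is
    # expected to be initialized (single worker here, as in the unit tests).
    th.distributed.init_process_group(backend="gloo",
                                      init_method="tcp://127.0.0.1:12346",
                                      world_size=1, rank=0)

    # Mixed mrr / hit@k metrics are now allowed in one evaluator.
    evaluator = GSgnnLPEvaluator(eval_frequency=1000,
                                 eval_metric_list=["mrr", "hit_at_1", "hit_at_10"])

    # Rankings are per-edge-type tensors of 1-based positive-edge ranks.
    val_rankings = {("user", "rating", "movie"): th.tensor([1, 3, 2, 10])}
    test_rankings = {("user", "rating", "movie"): th.tensor([2, 5, 1, 4])}
    val_score, test_score = evaluator.evaluate(val_rankings, test_rankings, 1000)
    # e.g. val_score -> {"mrr": ..., "hit_at_1": ..., "hit_at_10": ...}

    # Per-edge-type reporting, with one etype driving best-model selection.
    per_etype_evaluator = GSgnnPerEtypeLPEvaluator(
        eval_frequency=1000,
        eval_metric_list=["mrr", "hit_at_10"],
        major_etype=("user", "rating", "movie"))

In the run scripts above, both variants are obtained from a GSConfig via evaluator = gs.create_lp_evaluator(config), which returns the per-etype evaluator when config.report_eval_per_type is set and GSgnnLPEvaluator otherwise.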
diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py index 782fc4e112..36e10b0aa6 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py @@ -21,7 +21,7 @@ from graphstorm.config import get_argument_parser from graphstorm.config import GSConfig from graphstorm.inference import GSgnnLinkPredictionInferrer -from graphstorm.eval import GSgnnMrrLPEvaluator, GSgnnHitsLPEvaluator +from graphstorm.eval import GSgnnLPEvaluator from graphstorm.dataloading import GSgnnData from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, GSgnnLinkPredictionJointTestDataLoader, @@ -56,19 +56,14 @@ def main(config_args): model_layer_to_load=config.restore_model_layers) infer = GSgnnLinkPredictionInferrer(model) infer.setup_device(device=get_device()) - # TODO: to create a generic evaluator for LP tasks - if len(config.eval_metric) > 1 and ("mrr" in config.eval_metric) \ - and any((x.startswith(SUPPORTED_HIT_AT_METRICS) for x in config.eval_metric)): - logging.warning("GraphStorm does not support computing MRR and Hit@K metrics at the " - "same time. If both metrics are given, only 'mrr' is returned.") + assert all((x.startswith(SUPPORTED_HIT_AT_METRICS) or x == 'mrr') for x in + config.eval_metric), ( + "Invalid LP evaluation metrics. " + "GraphStorm only supports MRR and Hit@K metrics for link prediction.") if not config.no_validation: infer_idxs = infer_data.get_edge_test_set(config.eval_etype) - if len(config.eval_metric) == 0 or 'mrr' in config.eval_metric: - infer.setup_evaluator( - GSgnnMrrLPEvaluator(config.eval_frequency)) - else: - infer.setup_evaluator(GSgnnHitsLPEvaluator( - config.eval_frequency, eval_metric_list=config.eval_metric)) + infer.setup_evaluator(GSgnnLPEvaluator( + config.eval_frequency, eval_metric_list=config.eval_metric)) assert len(infer_idxs) > 0, "There is not test data for evaluation." else: infer_idxs = infer_data.get_edge_infer_set(config.eval_etype) diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py b/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py index 64bdb778c0..4004556cd3 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py @@ -23,7 +23,7 @@ from graphstorm.config import get_argument_parser from graphstorm.config import GSConfig from graphstorm.inference import GSgnnLinkPredictionInferrer -from graphstorm.eval import GSgnnMrrLPEvaluator, GSgnnHitsLPEvaluator +from graphstorm.eval import GSgnnLPEvaluator from graphstorm.dataloading import GSgnnData from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, GSgnnLinkPredictionJointTestDataLoader, @@ -50,19 +50,14 @@ def main(config_args): model_layer_to_load=config.restore_model_layers) infer = GSgnnLinkPredictionInferrer(model) infer.setup_device(device=get_device()) - # TODO: to create a generic evaluator for LP tasks - if len(config.eval_metric) > 1 and ("mrr" in config.eval_metric) \ - and any((x.startswith(SUPPORTED_HIT_AT_METRICS) for x in config.eval_metric)): - logging.warning("GraphStorm does not support computing MRR and Hit@K metrics at the " - "same time. If both metrics are given, only 'mrr' is returned.") + assert all((x.startswith(SUPPORTED_HIT_AT_METRICS) or x == 'mrr') for x in + config.eval_metric), ( + "Invalid LP evaluation metrics. 
" + "GraphStorm only supports MRR and Hit@K metrics for link prediction.") if not config.no_validation: infer_idxs = infer_data.get_edge_test_set(config.eval_etype) - if len(config.eval_metric) == 0 or 'mrr' in config.eval_metric: - infer.setup_evaluator( - GSgnnMrrLPEvaluator(config.eval_frequency)) - else: - infer.setup_evaluator(GSgnnHitsLPEvaluator( - config.eval_frequency, eval_metric_list=config.eval_metric)) + infer.setup_evaluator(GSgnnLPEvaluator( + config.eval_frequency, eval_metric_list=config.eval_metric)) assert len(infer_idxs) > 0, "There is not test data for evaluation." else: infer_idxs = infer_data.get_edge_infer_set(config.eval_etype) diff --git a/tests/end2end-tests/graphstorm-lp/mgpu_test.sh b/tests/end2end-tests/graphstorm-lp/mgpu_test.sh index 2e4a9cf780..6de93d6c3c 100644 --- a/tests/end2end-tests/graphstorm-lp/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-lp/mgpu_test.sh @@ -26,7 +26,7 @@ error_and_exit () { df /dev/shm -h -echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, exclude_training_targets: true, save model" +echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, exclude_training_targets: true, save model, eval_metric [hit_at_1 hit_at_3 hit_at_10]" python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --use-node-embeddings true --eval-batch-size 1024 --exclude-training-targets True --reverse-edge-types-map user,rating,rating-rev,movie --save-model-path /data/gsgnn_lp_ml_dot/ --topk-model-to-save 1 --save-model-frequency 1000 --save-embed-path /data/gsgnn_lp_ml_dot/emb/ --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --eval-metric hit_at_1 hit_at_3 hit_at_10 error_and_exit $? @@ -132,6 +132,112 @@ fi rm /tmp/train_log.txt +echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, exclude_training_targets: true, save model, eval_metric [hit_at_1 mrr hit_at_10]" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --use-node-embeddings true --eval-batch-size 1024 --exclude-training-targets True --reverse-edge-types-map user,rating,rating-rev,movie --save-model-path /data/gsgnn_lp_ml_dot/ --topk-model-to-save 1 --save-model-frequency 1000 --save-embed-path /data/gsgnn_lp_ml_dot/emb/ --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --eval-metric hit_at_1 mrr hit_at_10 + +error_and_exit $? 
+ +# check prints +cnt=$(grep "save_embed_path: /data/gsgnn_lp_ml_dot/emb/" /tmp/train_log.txt | wc -l) +if test $cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have save_embed_path" + exit -1 +fi + +cnt=$(grep "save_model_path: /data/gsgnn_lp_ml_dot/" /tmp/train_log.txt | wc -l) +if test $cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have save_model_path" + exit -1 +fi + +bst_cnt=$(grep "Best Test hit_at_1" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Test hit@1" + exit -1 +fi + +cnt=$(grep "| Test hit_at_1" /tmp/train_log.txt | wc -l) +if test $cnt -lt $bst_cnt +then + echo "We use SageMaker task tracker, we should have Test hit@1" + exit -1 +fi + +bst_cnt=$(grep "Best Validation hit_at_1" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation hit@1" + exit -1 +fi + +cnt=$(grep "Validation hit_at_1" /tmp/train_log.txt | wc -l) +if test $cnt -lt $bst_cnt +then + echo "We use SageMaker task tracker, we should have Validation hit@1" + exit -1 +fi + +bst_cnt=$(grep "Best Test mrr" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Test mrr" + exit -1 +fi + +cnt=$(grep "| Test mrr" /tmp/train_log.txt | wc -l) +if test $cnt -lt $bst_cnt +then + echo "We use SageMaker task tracker, we should have Test mrr" + exit -1 +fi + +bst_cnt=$(grep "Best Validation mrr" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation mrr" + exit -1 +fi + +cnt=$(grep "Validation mrr" /tmp/train_log.txt | wc -l) +if test $cnt -lt $bst_cnt +then + echo "We use SageMaker task tracker, we should have Validation mrr" + exit -1 +fi + +bst_cnt=$(grep "Best Test hit_at_10" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Test hit@10" + exit -1 +fi + +cnt=$(grep "| Test hit_at_10" /tmp/train_log.txt | wc -l) +if test $cnt -lt $bst_cnt +then + echo "We use SageMaker task tracker, we should have Test hit@10" + exit -1 +fi + +bst_cnt=$(grep "Best Validation hit_at_10" /tmp/train_log.txt | wc -l) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation hit@10" + exit -1 +fi + +cnt=$(grep "Validation hit_at_10" /tmp/train_log.txt | wc -l) +if test $cnt -lt $bst_cnt +then + echo "We use SageMaker task tracker, we should have Validation hit@10" + exit -1 +fi + +rm /tmp/train_log.txt + echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, exclude_training_targets: true, save model" python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --use-node-embeddings true --eval-batch-size 1024 --exclude-training-targets True --reverse-edge-types-map user,rating,rating-rev,movie --save-model-path /data/gsgnn_lp_ml_dot/ --topk-model-to-save 1 --save-model-frequency 1000 --save-embed-path /data/gsgnn_lp_ml_dot/emb/ --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True diff --git 
a/tests/unit-tests/test_evaluator.py b/tests/unit-tests/test_evaluator.py index cee2be38ec..44306e694a 100644 --- a/tests/unit-tests/test_evaluator.py +++ b/tests/unit-tests/test_evaluator.py @@ -21,7 +21,9 @@ from numpy.testing import assert_equal, assert_almost_equal import dgl -from graphstorm.eval import (GSgnnMrrLPEvaluator, +from graphstorm.eval import (GSgnnLPEvaluator, + GSgnnPerEtypeLPEvaluator, + GSgnnMrrLPEvaluator, GSgnnPerEtypeMrrLPEvaluator, GSgnnHitsLPEvaluator, GSgnnPerEtypeHitsLPEvaluator, @@ -84,6 +86,22 @@ def gen_hits_lp_eval_data(): return config, etypes, (val_pos_scores, val_neg_scores), (test_pos_scores, test_neg_scores) +def gen_lp_eval_data(): + config = Dummy({ + "eval_frequency": 100, + "eval_metric_list": ["hit_at_1", "mrr", "hit_at_10", "hit_at_100", "hit_at_200"], + "use_early_stop": False, + }) + + etypes = [("n0", "r0", "n1"), ("n0", "r1", "n1")] + + val_pos_scores = th.rand((10, 1)) + val_neg_scores = th.rand((10, 10)) + test_pos_scores = th.rand((10, 1)) + test_neg_scores = th.rand((10, 10)) + + return config, etypes, (val_pos_scores, val_neg_scores), (test_pos_scores, test_neg_scores) + def test_mrr_per_etype_lp_evaluation(): # system heavily depends on th distributed @@ -645,6 +663,333 @@ def check_evaluate_infer(mock_compute_score): th.distributed.destroy_process_group() +def test_per_etype_lp_evaluation(): + # system heavily depends on th distributed + dist_init_method = 'tcp://{master_ip}:{master_port}'.format( + master_ip='127.0.0.1', master_port='12346') + th.distributed.init_process_group(backend="gloo", + init_method=dist_init_method, + world_size=1, + rank=0) + config, etypes, val_scores, test_scores = gen_lp_eval_data() + + score = { + ("a", "r1", "b"): 0.9, + ("a", "r2", "b"): 0.8, + } + + # Test get_major_score + lp = GSgnnPerEtypeLPEvaluator(config.eval_frequency, + eval_metric_list=config.eval_metric_list, + use_early_stop=False) + assert lp.major_etype == LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL + + m_score = lp._get_major_score(score) + assert m_score == sum(score.values()) / 2 + + # Test get_major_score + lp = GSgnnPerEtypeLPEvaluator(config.eval_frequency, + eval_metric_list=config.eval_metric_list, + major_etype=("a", "r2", "b"), + use_early_stop=config.use_early_stop) + assert lp.major_etype == ("a", "r2", "b") + + m_score = lp._get_major_score(score) + assert m_score == score[("a", "r2", "b")] + + # Test score computation + val_pos_scores, val_neg_scores = val_scores + test_pos_scores, test_neg_scores = test_scores + + lp = GSgnnPerEtypeLPEvaluator(config.eval_frequency, + eval_metric_list=config.eval_metric_list, + use_early_stop=config.use_early_stop) + + # test for val scores + rank0 = [] + rank1 = [] + for i in range(len(val_pos_scores)): + val_pos = val_pos_scores[i] + val_neg0 = val_neg_scores[i] / 2 + val_neg1 = val_neg_scores[i] / 4 + scores = th.cat([val_pos, val_neg0]) + _, indices = th.sort(scores, descending=True) + ranking = th.nonzero(indices == 0) + 1 + rank0.append(ranking.cpu().detach()) + scores = th.cat([val_pos, val_neg1]) + _, indices = th.sort(scores, descending=True) + ranking = th.nonzero(indices == 0) + 1 + rank1.append(ranking.cpu().detach()) + val_ranks = {etypes[0]: th.cat(rank0, dim=0), etypes[1]: th.cat(rank1, dim=0)} + val_s = lp.compute_score(val_ranks) + + for metric in config.eval_metric_list: + if metric == 'mrr': + mrr = 1.0 / val_ranks[etypes[0]] + mrr = th.sum(mrr) / len(mrr) + assert_almost_equal(val_s['mrr'][etypes[0]], mrr.numpy(), decimal=7) + mrr = 1.0 / val_ranks[etypes[1]] + mrr = th.sum(mrr) / 
len(mrr) + assert_almost_equal(val_s['mrr'][etypes[1]], mrr.numpy(), decimal=7) + else: + k = int(metric[len(SUPPORTED_HIT_AT_METRICS) + 1:]) + hits_0 = th.div(th.sum(th.squeeze(val_ranks[etypes[0]]) <= k), len(th.squeeze(val_ranks[etypes[0]]))) + assert_almost_equal(val_s[metric][etypes[0]], hits_0.numpy(), decimal=7) + hits_1 = th.div(th.sum(th.squeeze(val_ranks[etypes[1]]) <= k), len(th.squeeze(val_ranks[etypes[1]]))) + assert_almost_equal(val_s[metric][etypes[1]], hits_1.numpy(), decimal=7) + + # test for test scores + rank0 = [] + rank1 = [] + for i in range(len(test_pos_scores)): + test_pos = test_pos_scores[i] + test_neg0 = test_neg_scores[i] / 2 + test_neg1 = test_neg_scores[i] / 4 + scores = th.cat([test_pos, test_neg0]) + _, indices = th.sort(scores, descending=True) + ranking = th.nonzero(indices == 0) + 1 + rank0.append(ranking.cpu().detach()) + scores = th.cat([test_pos, test_neg1]) + _, indices = th.sort(scores, descending=True) + ranking = th.nonzero(indices == 0) + 1 + rank1.append(ranking.cpu().detach()) + test_ranks = {etypes[0]: th.cat(rank0, dim=0), etypes[1]: th.cat(rank1, dim=0)} + test_s = lp.compute_score(test_ranks) + + for metric in config.eval_metric_list: + if metric == 'mrr': + mrr = 1.0 / test_ranks[etypes[0]] + mrr = th.sum(mrr) / len(mrr) + assert_almost_equal(np.array([test_s['mrr'][etypes[0]]]), mrr.numpy(), decimal=7) + mrr = 1.0 / test_ranks[etypes[1]] + mrr = th.sum(mrr) / len(mrr) + assert_almost_equal(np.array([test_s['mrr'][etypes[1]]]), mrr.numpy(), decimal=7) + else: + k = int(metric[len(SUPPORTED_HIT_AT_METRICS) + 1:]) + hits_0 = th.div(th.sum(th.squeeze(test_ranks[etypes[0]]) <= k), len(th.squeeze(test_ranks[etypes[0]]))) + assert_almost_equal(test_s[metric][etypes[0]], hits_0.numpy(), decimal=7) + hits_1 = th.div(th.sum(th.squeeze(test_ranks[etypes[1]]) <= k), len(th.squeeze(test_ranks[etypes[1]]))) + assert_almost_equal(test_s[metric][etypes[1]], hits_1.numpy(), decimal=7) + + # Check evaluate() + val_sc, test_sc = lp.evaluate(val_ranks, test_ranks, 0) + for metric in config.eval_metric_list: + val_s_score = (val_s[metric][etypes[0]] + val_s[metric][etypes[1]]) / 2 + test_s_score = (test_s[metric][etypes[0]] + test_s[metric][etypes[1]]) / 2 + assert_equal(val_s[metric][etypes[0]], val_sc[metric][etypes[0]]) + assert_equal(val_s[metric][etypes[1]], val_sc[metric][etypes[1]]) + assert_equal(test_s[metric][etypes[0]], test_sc[metric][etypes[0]]) + assert_equal(test_s[metric][etypes[1]], test_sc[metric][etypes[1]]) + + assert_almost_equal(np.array([val_s_score]), lp.best_val_score[metric]) + assert_almost_equal(np.array([test_s_score]), lp.best_test_score[metric]) + + lp = GSgnnPerEtypeLPEvaluator(config.eval_frequency, + eval_metric_list=config.eval_metric_list, + major_etype=etypes[1], + use_early_stop=config.use_early_stop) + + for metric in config.eval_metric_list: + val_sc, test_sc = lp.evaluate(val_ranks, test_ranks, 0) + assert_equal(val_s[metric][etypes[0]], val_sc[metric][etypes[0]]) + assert_equal(val_s[metric][etypes[1]], val_sc[metric][etypes[1]]) + assert_equal(test_s[metric][etypes[0]], test_sc[metric][etypes[0]]) + assert_equal(test_s[metric][etypes[1]], test_sc[metric][etypes[1]]) + + assert_almost_equal(val_s[metric][etypes[1]], lp.best_val_score[metric]) + assert_almost_equal(test_s[metric][etypes[1]], lp.best_test_score[metric]) + + th.distributed.destroy_process_group() + +def test_lp_evaluator(): + # system heavily depends on th distributed + dist_init_method = 'tcp://{master_ip}:{master_port}'.format( + master_ip='127.0.0.1', 
+        master_port='12346')
+    th.distributed.init_process_group(backend="gloo",
+                                      init_method=dist_init_method,
+                                      world_size=1,
+                                      rank=0)
+    config, etypes, val_scores, test_scores = gen_lp_eval_data()
+    val_pos_scores, val_neg_scores = val_scores
+    test_pos_scores, test_neg_scores = test_scores
+
+    # test default settings
+    lp = GSgnnLPEvaluator(config.eval_frequency,
+                          use_early_stop=config.use_early_stop)
+    assert lp.metric_list == ["mrr"]
+
+    # test given settings
+    lp = GSgnnLPEvaluator(config.eval_frequency,
+                          eval_metric_list=config.eval_metric_list,
+                          use_early_stop=config.use_early_stop)
+
+    # test computation for val scores
+    rank = []
+    for i in range(len(val_pos_scores)):
+        val_pos = val_pos_scores[i]
+        val_neg0 = val_neg_scores[i] / 2
+        val_neg1 = val_neg_scores[i] / 4
+        scores = th.cat([val_pos, val_neg0])
+        _, indices = th.sort(scores, descending=True)
+        ranking = th.nonzero(indices == 0) + 1
+        rank.append(ranking.cpu().detach())
+        scores = th.cat([val_pos, val_neg1])
+        _, indices = th.sort(scores, descending=True)
+        ranking = th.nonzero(indices == 0) + 1
+        rank.append(ranking.cpu().detach())
+    val_ranks = {etypes[0]: th.cat(rank, dim=0)}
+    val_s = lp.compute_score(val_ranks)
+    for metric in config.eval_metric_list:
+        if metric == 'mrr':
+            mrr = 1.0 / val_ranks[etypes[0]]
+            mrr = th.sum(mrr) / len(mrr)
+            assert_almost_equal(val_s['mrr'], mrr.numpy(), decimal=7)
+        else:
+            k = int(metric[len(SUPPORTED_HIT_AT_METRICS) + 1:])
+            hits_0 = th.div(th.sum(th.squeeze(val_ranks[etypes[0]]) <= k), len(th.squeeze(val_ranks[etypes[0]])))
+            assert_almost_equal(val_s[metric], hits_0.numpy(), decimal=7)
+
+    # test computation for test scores
+    rank = []
+    for i in range(len(test_pos_scores)):
+        test_pos = test_pos_scores[i]
+        test_neg0 = test_neg_scores[i] / 2
+        test_neg1 = test_neg_scores[i] / 4
+        scores = th.cat([test_pos, test_neg0])
+        _, indices = th.sort(scores, descending=True)
+        ranking = th.nonzero(indices == 0) + 1
+        rank.append(ranking.cpu().detach())
+        scores = th.cat([test_pos, test_neg1])
+        _, indices = th.sort(scores, descending=True)
+        ranking = th.nonzero(indices == 0) + 1
+        rank.append(ranking.cpu().detach())
+    test_ranks = {etypes[0]: th.cat(rank, dim=0)}
+    test_s = lp.compute_score(test_ranks)
+    for metric in config.eval_metric_list:
+        if metric == 'mrr':
+            mrr = 1.0 / test_ranks[etypes[0]]
+            mrr = th.sum(mrr) / len(mrr)
+            assert_almost_equal(test_s['mrr'], mrr.numpy(), decimal=7)
+        else:
+            k = int(metric[len(SUPPORTED_HIT_AT_METRICS) + 1:])
+            hits_0 = th.div(th.sum(th.squeeze(test_ranks[etypes[0]]) <= k), len(th.squeeze(test_ranks[etypes[0]])))
+            assert_almost_equal(test_s[metric], hits_0.numpy(), decimal=7)
+
+    # check evaluate()
+    val_sc, test_sc = lp.evaluate(val_ranks, test_ranks, 0)
+    for metric in config.eval_metric_list:
+        assert_equal(val_s[metric], val_sc[metric])
+        assert_equal(test_s[metric], test_sc[metric])
+
+    # val_ranks is None
+    val_sc, test_sc = lp.evaluate(None, test_ranks, 0)
+    for metric in config.eval_metric_list:
+        assert_equal(val_sc[metric], "N/A")
+        assert_equal(test_s[metric], test_sc[metric])
+
+    # test_ranks is None
+    val_sc, test_sc = lp.evaluate(val_ranks, None, 0)
+    for metric in config.eval_metric_list:
+        assert_equal(val_s[metric], val_sc[metric])
+        assert_equal(test_sc[metric], "N/A")
+
+    # test evaluate
+    @patch.object(GSgnnLPEvaluator, 'compute_score')
+    def check_evaluate(mock_compute_score):
+        lp = GSgnnLPEvaluator(config.eval_frequency,
+                              eval_metric_list=["hit_at_1", "mrr", "hit_at_10"],
+                              use_early_stop=config.use_early_stop)
+
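+        # Note: the mocked compute_score() results below are consumed in pairs,
+        # one pair per evaluate() call; judging from the assertions, the first
+        # result of each pair is used for the test rankings and the second for
+        # the validation rankings.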
+        mock_compute_score.side_effect = [
+            {"hit_at_1": 0.6, "mrr": 0.66, "hit_at_10": 0.9},
+            {"hit_at_1": 0.7, "mrr": 0.73, "hit_at_10": 0.75},
+            {"hit_at_1": 0.65, "mrr": 0.68, "hit_at_10": 0.8},
+            {"hit_at_1": 0.76, "mrr": 0.82, "hit_at_10": 0.76},
+            {"hit_at_1": 0.76, "mrr": 0.81, "hit_at_10": 0.78},
+            {"hit_at_1": 0.8, "mrr": 0.86, "hit_at_10": 0.85}
+        ]
+
+        val_score, test_score = lp.evaluate(
+            {("u", "b", "v") : ()}, {("u", "b", "v") : ()}, 100)
+        mock_compute_score.assert_called()
+        assert val_score["hit_at_1"] == 0.7 and val_score["hit_at_10"] == 0.75 and val_score["mrr"] == 0.73
+        assert test_score["hit_at_1"] == 0.6 and test_score["hit_at_10"] == 0.9 and test_score["mrr"] == 0.66
+
+        val_score, test_score = lp.evaluate(
+            {("u", "b", "v") : ()}, {("u", "b", "v") : ()}, 200)
+        mock_compute_score.assert_called()
+        assert val_score["hit_at_1"] == 0.76 and val_score["hit_at_10"] == 0.76 and val_score["mrr"] == 0.82
+        assert test_score["hit_at_1"] == 0.65 and test_score["hit_at_10"] == 0.8 and test_score["mrr"] == 0.68
+
+        val_score, test_score = lp.evaluate(
+            {("u", "b", "v") : ()}, {("u", "b", "v") : ()}, 300)
+        mock_compute_score.assert_called()
+        assert val_score["hit_at_1"] == 0.8 and val_score["hit_at_10"] == 0.85 and val_score["mrr"] == 0.86
+        assert test_score["hit_at_1"] == 0.76 and test_score["hit_at_10"] == 0.78 and test_score["mrr"] == 0.81
+
+        assert lp.best_val_score["hit_at_1"] == 0.8 and lp.best_val_score["hit_at_10"] == 0.85 and lp.best_val_score["mrr"] == 0.86
+        assert lp.best_test_score["hit_at_1"] == 0.76 and lp.best_test_score["hit_at_10"] == 0.78 and lp.best_test_score["mrr"] == 0.81
+        assert lp.best_iter_num["hit_at_1"] == 300 and lp.best_iter_num["hit_at_10"] == 300 and lp.best_iter_num["mrr"] == 300
+
+    # check GSgnnLPEvaluator.evaluate()
+    check_evaluate()
+
+    # test evaluate
+    @patch.object(GSgnnLPEvaluator, 'compute_score')
+    def check_evaluate_infer(mock_compute_score):
+        lp = GSgnnLPEvaluator(config.eval_frequency,
+                              eval_metric_list=["hit_at_1", "mrr", "hit_at_10"],
+                              use_early_stop=config.use_early_stop)
+
+        mock_compute_score.side_effect = [
+            {"hit_at_1": 0.7, "mrr": 0.66, "hit_at_10": 0.9},
+            {"hit_at_1": 0.78, "mrr": 0.73, "hit_at_10": 0.88},
+        ]
+
+        val_score, test_score = lp.evaluate(None, [], 100)
+        mock_compute_score.assert_called()
+        assert val_score["hit_at_1"] == "N/A" and val_score["hit_at_10"] == "N/A" and val_score["mrr"] == "N/A"
+        assert test_score["hit_at_1"] == 0.7 and test_score["hit_at_10"] == 0.9 and test_score["mrr"] == 0.66
+
+        val_score, test_score = lp.evaluate(None, [], 200)
+        mock_compute_score.assert_called()
+        assert val_score["hit_at_1"] == "N/A" and val_score["hit_at_10"] == "N/A" and val_score["mrr"] == "N/A"
+        assert test_score["hit_at_1"] == 0.78 and test_score["hit_at_10"] == 0.88 and test_score["mrr"] == 0.73
+
+        assert lp.best_val_score["hit_at_1"] == 0 and lp.best_val_score["hit_at_10"] == 0 and lp.best_val_score["mrr"] == 0
+        assert lp.best_test_score["hit_at_1"] == 0 and lp.best_test_score["hit_at_10"] == 0 and lp.best_test_score["mrr"] == 0
+        assert lp.best_iter_num["hit_at_1"] == 0 and lp.best_iter_num["hit_at_10"] == 0 and lp.best_iter_num["mrr"] == 0
+
+    check_evaluate_infer()
+
+    # check GSgnnLPEvaluator.do_eval()
+    # train_data.do_validation True
+    # config.no_validation False
+    lp = GSgnnLPEvaluator(config.eval_frequency,
+                          eval_metric_list=["hit_at_1", "mrr", "hit_at_10"],
+                          use_early_stop=config.use_early_stop)
+    assert lp.do_eval(120, epoch_end=True) is True
+    assert lp.do_eval(200) is True
+    assert lp.do_eval(0) is True
+    assert lp.do_eval(1) is False
+
+    config3 = Dummy({
+        "eval_frequency": 0,
+        "eval_metric_list": ["hit_at_1", "mrr", "hit_at_10"],
+        "use_early_stop": False,
+    })
+
+    # train_data.do_validation True
+    # config.no_validation False
+    # eval_frequency is 0
+    lp = GSgnnLPEvaluator(config3.eval_frequency,
+                          eval_metric_list=config3.eval_metric_list,
+                          use_early_stop=config3.use_early_stop)
+    assert lp.do_eval(120, epoch_end=True) is True
+    assert lp.do_eval(200) is False
+
+    th.distributed.destroy_process_group()
+
 def test_classification_evaluator():
     # system heavily depends on th distributed
     dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
@@ -1327,4 +1672,6 @@ def check_multi_task_eval(mock_reg_compute_score, mock_class_compute_score, mock
 
     test_classification_evaluator()
     test_hits_per_etype_lp_evaluation()
-    test_hits_lp_evaluator()
\ No newline at end of file
+    test_hits_lp_evaluator()
+    test_per_etype_lp_evaluation()
+    test_lp_evaluator()
\ No newline at end of file