From 993a71f55ab0c89a18994d717c43fd3ae0f8374c Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Thu, 24 Oct 2024 10:08:49 -0700 Subject: [PATCH] Add Adjusted Mean Ranking Index metric for Link Prediction (#1061) *Issue #, if available:* *Description of changes:* * Add Adjust Mean Rank Index LP metric. This metric is normalized by the candidate list size, allowing for easier comparison between datasets/models and negative edge sample counts. * To get the list sizes we modify `run_lp_mini_batch_predict` and `lp_mini_batch_predict` to _conditionally_ return a tuple of `rankings, lengths` that allows us to calculate AMRI in the cases where it's needed. The return type is determined by a new argument added, with a default value, so the changes are backwards compatible, existing calls to the two functions will function as before. * Add `LinkPredictionTestScoreInterface` as a common ancestor to `LinkPredictNoParamDecoder` and `LinkPredictLearnableDecoder`, this way we can ensure at runtime that the decoder should implement the `calc_test_scores` function that is used by the LP evaluator. * Modify the `GSgnnLPRankingEvalInterface` `evaluate` function to add optional `**kwargs`. In cases where we want to calculate a metric that needs the candidate list lengths we use these kwargs to pass the information. * Modify the `def compute_score(self, rankings, train=True, **kwargs):` to add optional kwargs, currently only used during AMRI calculation. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: xiang song(charlie.song) --- .github/workflow_scripts/pytest_check.sh | 1 - .../mt_infer/ml_nc_ec_er_lp_only_infer.yaml | 3 + .../ml_nc_ec_er_lp_with_mask_infer.yaml | 3 + python/graphstorm/eval/eval_func.py | 112 ++++++++-- python/graphstorm/eval/evaluator.py | 206 +++++++++++++++--- python/graphstorm/gsf.py | 9 +- .../graphstorm/inference/graphstorm_infer.py | 3 +- python/graphstorm/inference/lp_infer.py | 13 +- python/graphstorm/inference/mt_infer.py | 28 ++- python/graphstorm/model/edge_decoder.py | 57 ++++- python/graphstorm/model/lp_gnn.py | 72 ++++-- python/graphstorm/run/gsgnn_lp/gsgnn_lp.py | 2 - python/graphstorm/trainer/gsgnn_trainer.py | 4 +- python/graphstorm/trainer/lp_trainer.py | 30 ++- python/graphstorm/trainer/mt_trainer.py | 40 ++-- .../end2end-tests/graphstorm-lp/mgpu_test.sh | 18 +- .../end2end-tests/graphstorm-mt/mgpu_test.sh | 6 +- tests/unit-tests/test_eval_func.py | 78 +++++-- tests/unit-tests/test_inferrer.py | 5 +- tests/unit-tests/test_trainer.py | 27 ++- 20 files changed, 571 insertions(+), 146 deletions(-) diff --git a/.github/workflow_scripts/pytest_check.sh b/.github/workflow_scripts/pytest_check.sh index 3b8b87a576..1937eaf098 100644 --- a/.github/workflow_scripts/pytest_check.sh +++ b/.github/workflow_scripts/pytest_check.sh @@ -7,4 +7,3 @@ FORCE_CUDA=1 python3 -m pip install -e '.[test]' --no-build-isolation python3 -m pip install pytest sh ./tests/unit-tests/prepare_test_data.sh export NCCL_IB_DISABLE=1; export NCCL_SHM_DISABLE=1; NCCL_NET=Socket NCCL_DEBUG=INFO python3 -m pytest -x ./tests/unit-tests -s - diff --git a/inference_scripts/mt_infer/ml_nc_ec_er_lp_only_infer.yaml b/inference_scripts/mt_infer/ml_nc_ec_er_lp_only_infer.yaml index ed78af480c..810fbe868e 100644 --- a/inference_scripts/mt_infer/ml_nc_ec_er_lp_only_infer.yaml +++ b/inference_scripts/mt_infer/ml_nc_ec_er_lp_only_infer.yaml @@ -71,6 +71,9 @@ gsf: reverse_edge_types_map: - 
user,rating,rating-rev,movie batch_size: 128 # will overwrite the global batch_size + eval_metric: + - "mrr" + - "amri" - reconstruct_node_feat: reconstruct_nfeat_name: "title" target_ntype: "movie" diff --git a/inference_scripts/mt_infer/ml_nc_ec_er_lp_with_mask_infer.yaml b/inference_scripts/mt_infer/ml_nc_ec_er_lp_with_mask_infer.yaml index 00ae2bbfe5..70e308432c 100644 --- a/inference_scripts/mt_infer/ml_nc_ec_er_lp_with_mask_infer.yaml +++ b/inference_scripts/mt_infer/ml_nc_ec_er_lp_with_mask_infer.yaml @@ -91,6 +91,9 @@ gsf: - "train_mask_field_lp" - null # empty means there is no validation mask - "test_mask_field_lp" + eval_metric: + - "mrr" + - "amri" - reconstruct_node_feat: reconstruct_nfeat_name: "title" target_ntype: "movie" diff --git a/python/graphstorm/eval/eval_func.py b/python/graphstorm/eval/eval_func.py index c940d61017..07991357ee 100644 --- a/python/graphstorm/eval/eval_func.py +++ b/python/graphstorm/eval/eval_func.py @@ -16,9 +16,11 @@ Evaluation functions """ import logging +import operator +from collections.abc import Callable from enum import Enum from functools import partial -import operator + import numpy as np import torch as th from sklearn.metrics import roc_auc_score @@ -28,7 +30,7 @@ SUPPORTED_CLASSIFICATION_METRICS = {'accuracy', 'precision_recall', \ 'roc_auc', 'f1_score', 'per_class_f1_score', 'per_class_roc_auc', SUPPORTED_HIT_AT_METRICS} SUPPORTED_REGRESSION_METRICS = {'rmse', 'mse', 'mae'} -SUPPORTED_LINK_PREDICTION_METRICS = {"mrr", SUPPORTED_HIT_AT_METRICS} +SUPPORTED_LINK_PREDICTION_METRICS = {"mrr", SUPPORTED_HIT_AT_METRICS, "amri"} class ClassificationMetrics: """ object that compute metrics for classification tasks. @@ -158,16 +160,19 @@ def __init__(self, eval_metric_list=None): self.supported_metrics = SUPPORTED_LINK_PREDICTION_METRICS # This is the operator used to compare whether current value is better than the current best - self.metric_comparator = {} + self.metric_comparator: dict[str, Callable] = {} self.metric_comparator["mrr"] = operator.le + self.metric_comparator["amri"] = operator.le # This is the operator used to measure each metric performance - self.metric_function = {} + self.metric_function: dict[str, Callable[..., th.Tensor]] = {} self.metric_function["mrr"] = compute_mrr + self.metric_function["amri"] = compute_amri # This is the operator used to measure each metric performance in evaluation - self.metric_eval_function = {} + self.metric_eval_function: dict[str, Callable[..., th.Tensor]] = {} self.metric_eval_function["mrr"] = compute_mrr + self.metric_eval_function["amri"] = compute_amri if eval_metric_list: for eval_metric in eval_metric_list: @@ -190,20 +195,23 @@ def assert_supported_metric(self, metric): assert metric in self.supported_metrics, \ f"Metric {metric} not supported for link prediction" - def init_best_metric(self, metric): + def init_best_metric(self, metric: str): """ Return the initial value for the metric to keep track of the best metric. Parameters ---------- - metric: the metric to initialize + metric: str + the name of the metric to initialize Returns ------- - + float + An initial value for the metric. 
""" # Need to check if the given metric is supported first self.assert_supported_metric(metric) - return 0 + # The minimum value for AMRI is -1.0 so we init with that + return -1.0 if metric == "amri" else 0.0 def labels_to_one_hot(labels, total_labels): @@ -694,18 +702,82 @@ def compute_mae(pred, labels): diff = th.abs(pred.cpu() - labels.cpu()) return th.mean(diff).cpu().item() -def compute_mrr(ranking): - """ Get link prediction mrr metrics +def compute_mrr(ranking: th.Tensor) -> th.Tensor: + """ Get link prediction Mean Reciprocal Rank (MRR) metrics - Parameters - ---------- - ranking: - ranking of each positive edge + Parameters + ---------- + ranking: torch.Tensor + ranking of each positive edge - Returns - ------- - link prediction mrr metrics: tensor + Returns + ------- + th.Tensor + link prediction mrr metrics """ - logs = th.div(1.0, ranking) - metrics = th.tensor(th.div(th.sum(logs),len(logs))) + reciprocal_ranks = th.div(1.0, ranking) + metrics = th.tensor(th.div(th.sum(reciprocal_ranks), len(reciprocal_ranks))) return metrics + +def compute_amri(ranking: th.Tensor, candidate_sizes: th.Tensor) -> th.Tensor: + """Computes the Adjusted Mean Rank Index (AMRI) for the given ranking and candidate sizes. + + AMRI is a metric that evaluates the performance of link prediction models by considering both + the rank of the correct candidate and the number of candidates. It is calculated as: + + .. math:: + AMRI = 1 - \\frac{\\text{MR}-1}{\\mathbb{E}[\\text{MR}-1]} + + where MR is the mean rank, and `E[MR]` is the expected mean rank, which is used + to adjust for chance. E[MR] is defined as: + + .. math:: + \\mathbb{E}[\\text{MR}] = \\mathbb{E} \\left[ \\frac{1}{n} \\sum^n_{i=1}{r_i} \\right] + + Where :math:`r_i` is the rank the model assigns to the positive edge, + compared to the negative edges in the candidate list, and :math:`n` is the number of + candidate lists, one per positive edge. + + AMRI values will be in the :math:`[-1, 1]` range, where 1 corresponds + to optimal performance where each individual rank is 1. A value of 0 indicates + model performance similar to a model assigning random scores, or equal score + to every candidate. The value is negative if the model performs worse than the + constant-score model." + + For more details see https://arxiv.org/abs/2002.06914 + + Parameters + ---------- + ranking : torch.Tensor + ranking of each positive edge + candidate_sizes : th.Tensor + The size of each candidate list. If all candidate lists have + the same size this will be a single-value tensor. + + Returns + ------- + th.Tensor + A single-value Tensor with the AMRI metric. + + .. versionadded: 0.4.0 + """ + if candidate_sizes.shape[0] > 1: + assert ranking.shape[0] == candidate_sizes.shape[0], \ + ("ranking and candidate_sizes must have the same length, " + f"got {ranking.shape=} {candidate_sizes.shape=}" ) + assert th.all(ranking <= candidate_sizes).item(), \ + "all ranks must be <= candidate_sizes" + + # We use the simplified form of AMRI calculation + # 1 - \frac{MR-1}{E[MR-1]} = 1 - \frac{2*\sum_n{r-1}}{\sum_n{|S|}} + # where n is the number of evaluations (number of positive edges), + # r is the ranking of the positive edge in each ranked score list, + # and |S| is the edge candidate set size. 
+ # See equation (8) in https://arxiv.org/abs/2002.06914 + nominator = 2 * th.sum(ranking - 1) + if candidate_sizes.shape[0] == 1: + denominator = candidate_sizes.item() * ranking.shape[0] + else: + denominator = th.sum(candidate_sizes) + + return 1 - th.div(nominator, denominator) diff --git a/python/graphstorm/eval/evaluator.py b/python/graphstorm/eval/evaluator.py index 8aa81cfccf..96cf1113cb 100644 --- a/python/graphstorm/eval/evaluator.py +++ b/python/graphstorm/eval/evaluator.py @@ -16,9 +16,11 @@ Evaluator for different tasks. """ -import warnings import abc +import warnings from statistics import mean +from typing import Any, Dict, Optional, Tuple + import torch as th from .eval_func import SUPPORTED_HIT_AT_METRICS @@ -188,7 +190,13 @@ class GSgnnLPRankingEvalInterface(): """ @abc.abstractmethod - def evaluate(self, val_rankings, test_rankings, total_iters): + def evaluate( + self, + val_rankings, + test_rankings, + total_iters, + **kwargs, + ) -> Tuple[Dict[str, th.Tensor], Dict[str, th.Tensor]]: """Evaluate Link Prediction results on validation and test sets. **Link Prediction** evaluators should provide the ranking of validation and test sets as @@ -202,6 +210,8 @@ def evaluate(self, val_rankings, test_rankings, total_iters): The rankings of testing edges for each edge type in the format of {etype: ranking}. total_iters: int The current iteration number. + kwargs: dict + Keyword arguments to pass downstream to metric calculation functions. Returns ----------- @@ -212,7 +222,7 @@ def evaluate(self, val_rankings, test_rankings, total_iters): """ @abc.abstractmethod - def compute_score(self, rankings, train=True): + def compute_score(self, rankings, train=True, **kwargs): """ Compute Link Prediction evaluation score. Ranking-based Link Prediction evaluators should provide ranking values as input @@ -224,6 +234,15 @@ def compute_score(self, rankings, train=True): Rankings of positive scores in the format of {etype: ranking} train: boolean If in model training. + kwargs: dict + Keyword arguments to pass downstream to the metric computation. + + Currently we support: + + candidate_sizes : dict of tensors + A mapping from edge type to the size of each candidate list + (positive + negative pairs). + If the tensor has a single element we use that as the size of all lists. Returns ------- @@ -882,7 +901,13 @@ def __init__(self, eval_frequency, self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) self._best_iter[metric] = 0 - def evaluate(self, val_rankings, test_rankings, total_iters): + def evaluate( + self, + val_rankings, + test_rankings, + total_iters, + **kwargs, + ): """ `GSgnnLinkPredictionTrainer` and `GSgnnLinkPredictionInferrer` will call this function to compute validation and test scores. @@ -896,6 +921,21 @@ def evaluate(self, val_rankings, test_rankings, total_iters): {etype: ranking}. total_iters: int The current iteration number. + kwargs: + Keyword arguments to pass downstream to the metric computation. + + Currently we support: + + val_candidate_sizes : torch.Tensor + A tensor containing the size of each candidate list (positive + negative pairs) + for each testing edge in the validation set. + If the tensor has a single element we use that as the size of all lists. + test_candidate_sizes : torch.Tensor + A tensor containing the size of each candidate list (positive + negative pairs) + for every edge in the test set. + If the tensor has a single element we use that as the size of all lists. 
+ + ..versionadded:: 0.4.0 Returns ----------- @@ -908,14 +948,20 @@ def evaluate(self, val_rankings, test_rankings, total_iters): """ with th.no_grad(): if test_rankings is not None: - test_score = self.compute_score(test_rankings) + test_score = self.compute_score( + test_rankings, + candidate_sizes=kwargs.pop("test_candidate_sizes", None), + ) else: test_score = {} for metric in self.metric_list: test_score[metric] = "N/A" # Dummy if val_rankings is not None: - val_score = self.compute_score(val_rankings) + val_score = self.compute_score( + val_rankings, + candidate_sizes=kwargs.pop("val_candidate_sizes", None), + ) if get_rank() == 0: for metric in self.metric_list: @@ -934,7 +980,12 @@ def evaluate(self, val_rankings, test_rankings, total_iters): return val_score, test_score - def compute_score(self, rankings, train=True): + def compute_score( + self, + rankings: Dict[str, th.Tensor], + train=True, + **kwargs + ): """ Compute evaluation score. Parameters @@ -943,6 +994,15 @@ def compute_score(self, rankings, train=True): Rankings of positive scores in the format of {etype: ranking}. train: boolean If in model training. + kwargs: dict + Keyword arguments to pass downstream to the metric computation. + + Currently we support: + + candidate_sizes : dict of tensors + A mapping from edge type to the the size of each candidate list + (positive + negative pairs). + If the tensor has a single element we use that as the size of all lists. Returns ------- @@ -954,19 +1014,30 @@ def compute_score(self, rankings, train=True): for _, rank in rankings.items(): ranking.append(rank) ranking = th.cat(ranking, dim=0) + sizes_list = [] + candidate_sizes: Optional[Dict[str, th.Tensor]] = kwargs.get("candidate_sizes", None) # compute ranking value for each metric - metrics = {} + metrics: Dict[str, th.Tensor] = {} for metric in self.metric_list: + if metric == "amri": + assert candidate_sizes, \ + f"candidate_sizes needs to have a value for AMRI, got {candidate_sizes=}." 
+ for etype, _ in rankings.items(): + sizes_list.append(candidate_sizes[etype]) + candidate_sizes_tensor = th.cat(sizes_list, dim=0) + arg_tuple = (ranking, candidate_sizes_tensor) + else: + arg_tuple = (ranking,) if train: # training expects always a single number to be # returned and has a different (potentially) evaluation function - metrics[metric] = self.metrics_obj.metric_function[metric](ranking) + metrics[metric] = self.metrics_obj.metric_function[metric](*arg_tuple) else: # validation or testing may have a different # evaluation function, in our case the evaluation code # may return a dictionary with the metric values for each metric - metrics[metric] = self.metrics_obj.metric_eval_function[metric](ranking) + metrics[metric] = self.metrics_obj.metric_eval_function[metric](*arg_tuple) # When world size == 1, we do not need the barrier if get_world_size() > 1: @@ -974,7 +1045,7 @@ def compute_score(self, rankings, train=True): for _, metric_val in metrics.items(): th.distributed.all_reduce(metric_val) - return_metrics = {} + return_metrics: Dict[str, float] = {} for metric, metric_val in metrics.items(): return_metric = metric_val / get_world_size() return_metrics[metric] = return_metric.item() @@ -1037,7 +1108,13 @@ def __init__(self, eval_frequency, self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric) self._best_iter[metric] = 0 - def evaluate(self, val_rankings, test_rankings, total_iters): + def evaluate( + self, + val_rankings, + test_rankings, + total_iters, + **kwargs, + ): """ `GSgnnLinkPredictionTrainer` and `GSgnnLinkPredictionInferrer` will call this function to compute validation and test scores. @@ -1051,6 +1128,25 @@ def evaluate(self, val_rankings, test_rankings, total_iters): {etype: ranking}. total_iters: int The current iteration number. + kwargs: dict + Keyword arguments to pass downstream to metric calculation functions. + + Currently we support: + + val_candidate_sizes : dict of tensors + The size of each candidate list (positive + negative pairs) + in the validation set, in the format of {etype: size_tensor}. + If all candidate lists have the same size this + will be a single-value tensor per etype. + + test_candidate_sizes : dict of tensors + The size of each candidate list (positive + negative pairs) + in the test set, in the format of {etype: size_tensor}. + If all candidate lists have the same size this + will be a single-value tensor per etype. + + + ..versionadded:: 0.4.0 Returns ----------- @@ -1063,14 +1159,20 @@ def evaluate(self, val_rankings, test_rankings, total_iters): """ with th.no_grad(): if test_rankings is not None: - test_score = self.compute_score(test_rankings) + test_score = self.compute_score( + test_rankings, + candidate_sizes=kwargs.pop("test_candidate_sizes", None), + ) else: test_score = {} for metric in self.metric_list: test_score[metric] = "N/A" # Dummy if val_rankings is not None: - val_score = self.compute_score(val_rankings) + val_score = self.compute_score( + val_rankings, + candidate_sizes=kwargs.pop("val_candidate_sizes", None), + ) if get_rank() == 0: for metric in self.metric_list: @@ -1101,7 +1203,7 @@ def _get_major_score(self, score): major_score = score[self.major_etype] return major_score - def compute_score(self, rankings, train=True): + def compute_score(self, rankings, train=True, **kwargs): """ Compute per edge type evaluation score. Parameters @@ -1110,6 +1212,18 @@ def compute_score(self, rankings, train=True): Rankings of positive scores in the format of {etype: ranking}. 
train: boolean If in model training. + kwargs: dict + Keyword arguments to pass downstream to the metric computation. + + Currently we support: + + candidate_sizes: dict of tensors, optional + The size of each candidate list corresponding to each value in the + ``rankings``, in the format of {etype: sizes}. If a tensor for + an edge type has a single element we use that as the size of all + lists. + + ..versionadded:: 0.4.0 Returns ------- @@ -1121,16 +1235,19 @@ def compute_score(self, rankings, train=True): for etype, ranking in rankings.items(): # compute ranking value for each metric metrics = {} + candidate_sizes: Dict[str, th.Tensor] = kwargs.get("candidate_sizes", None) + etype_candidate_sizes = candidate_sizes[etype] if candidate_sizes is not None else None for metric in self.metric_list: + arg_tuple = (ranking, etype_candidate_sizes) if metric == "amri" else (ranking, ) if train: # training expects always a single number to be # returned and has a different (potentially) evaluation function - metrics[metric] = self.metrics_obj.metric_function[metric](ranking) + metrics[metric] = self.metrics_obj.metric_function[metric](*arg_tuple) else: # validation or testing may have a different # evaluation function, in our case the evaluation code # may return a dictionary with the metric values for each metric - metrics[metric] = self.metrics_obj.metric_eval_function[metric](ranking) + metrics[metric] = self.metrics_obj.metric_eval_function[metric](*arg_tuple) per_etype_metrics[etype] = metrics # When world size == 1, we do not need the barrier @@ -1241,7 +1358,7 @@ def __init__(self, eval_frequency, stacklevel=1 ) - def evaluate(self, val_rankings, test_rankings, total_iters): + def evaluate(self, val_rankings, test_rankings, total_iters, **kwargs): """ ``GSgnnLinkPredictionTrainer`` and ``GSgnnLinkPredictionInferrer`` will call this function to compute validation and test ``mrr`` scores. @@ -1291,7 +1408,7 @@ def evaluate(self, val_rankings, test_rankings, total_iters): return val_score, test_score - def compute_score(self, rankings, train=True): + def compute_score(self, rankings, train=True, **kwargs): """ Compute ``mrr`` evaluation score. Parameters @@ -1402,7 +1519,7 @@ def __init__(self, eval_frequency, stacklevel=1 ) - def evaluate(self, val_rankings, test_rankings, total_iters): + def evaluate(self, val_rankings, test_rankings, total_iters, **kwargs): """ ``GSgnnLinkPredictionTrainer`` and ``GSgnnLinkPredictionInferrer`` will call this function to compute validation and test ``mrr`` scores. @@ -1455,7 +1572,7 @@ def evaluate(self, val_rankings, test_rankings, total_iters): return val_score, test_score - def compute_score(self, rankings, train=True): + def compute_score(self, rankings, train=True, **kwargs): """ Compute per edge type ``mrr`` evaluation score. Parameters @@ -1599,7 +1716,7 @@ def __init__(self, eval_frequency, stacklevel=1 ) - def evaluate(self, val_rankings, test_rankings, total_iters): + def evaluate(self, val_rankings, test_rankings, total_iters, **kwargs): """ ``GSgnnLinkPredictionTrainer`` and ``GSgnnLinkPredictionInferrer`` will call this function to compute validation and test ``hit@k`` scores. @@ -1651,7 +1768,7 @@ def evaluate(self, val_rankings, test_rankings, total_iters): return val_score, test_score - def compute_score(self, rankings, train=True): + def compute_score(self, rankings, train=True, **kwargs): """ Compute ``hit@k`` evaluation score. 
Parameters @@ -1760,7 +1877,7 @@ def __init__(self, eval_frequency, stacklevel=1 ) - def evaluate(self, val_rankings, test_rankings, total_iters): + def evaluate(self, val_rankings, test_rankings, total_iters, **kwargs): """ ``GSgnnLinkPredictionTrainer`` and ``GSgnnLinkPredictionInferrer`` will call this function to compute validation and test ``hit@k`` scores. @@ -1852,7 +1969,7 @@ class initialization. If using the default ``major_etype``, the rank will be self._val_perf_rank_list.append(val_score) return rank - def compute_score(self, rankings, train=True): + def compute_score(self, rankings, train=True, **kwargs): """ Compute per edge type ``hit@k`` evaluation score. Parameters @@ -1901,16 +2018,19 @@ def compute_score(self, rankings, train=True): return return_metrics -class GSgnnMultiTaskEvalInterface(): +class GSgnnMultiTaskEvalInterface(abc.ABC): """ Interface for multi-task evaluation The interface set one abstract method """ @abc.abstractmethod - def evaluate(self, val_results, test_results, total_iters): - """Evaluate validation and test sets for Prediciton tasks - - GSgnnTrainers will call this function to do evaluation in their eval() fuction. + def evaluate( + self, + val_results: Dict[str, Any], + test_results: Dict[str, Any], + total_iters: int + ): + """Evaluate multi-task training results, using task-specific evaluators for each task. Parameters ---------- @@ -1919,7 +2039,7 @@ def evaluate(self, val_results, test_results, total_iters): test_results: dict Testing results in a format of {task_id: test results} total_iters: int - The current interation number. + The current iteration number. Returns ----------- @@ -2064,7 +2184,12 @@ def best_iter_num(self): def val_perf_rank_list(self): raise RuntimeError("GSgnnMultiTaskEvaluator.val_perf_rank_list not supported") - def evaluate(self, val_results, test_results, total_iters): + def evaluate( + self, + val_results: Dict[str, Any], + test_results: Dict[str, Any], + total_iters: int, + ): eval_tasks = {} val_scores = {} test_scores = {} @@ -2099,12 +2224,21 @@ def evaluate(self, val_results, test_results, total_iters): val_score, test_score = task_evaluator.evaluate( val_preds, test_preds, val_labels, test_labels, total_iters) elif isinstance(task_evaluator, GSgnnLPRankingEvalInterface): - val_rankings = eval_task[0] - test_rankings = eval_task[1] + val_rankings_and_lengths = eval_task[0] + test_rankings_and_lengths = eval_task[1] val_score, test_score = task_evaluator.evaluate( - val_rankings, test_rankings, total_iters) + val_rankings_and_lengths[0], + test_rankings_and_lengths[0], + total_iters, + val_candidate_sizes=val_rankings_and_lengths[1], + test_candidate_sizes=test_rankings_and_lengths[1], + ) else: - raise TypeError("Unknown evaluator") + raise RuntimeError( + f"Unknown evaluator type: {type(task_evaluator)}. 
" + "Evaluators need to implement either GSgnnPredictionEvalInterface " + "or GSgnnLPRankingEvalInterface" + ) val_scores[task_id] = val_score test_scores[task_id] = test_score diff --git a/python/graphstorm/gsf.py b/python/graphstorm/gsf.py index 722ba6de42..9121a6e8ec 100644 --- a/python/graphstorm/gsf.py +++ b/python/graphstorm/gsf.py @@ -49,8 +49,7 @@ BUILTIN_CLASS_LOSS_FOCAL) from .eval.eval_func import ( SUPPORTED_HIT_AT_METRICS, - SUPPORTED_LINK_PREDICTION_METRICS, -) + SUPPORTED_LINK_PREDICTION_METRICS) from .model.embed import GSNodeEncoderInputLayer from .model.lm_embed import GSLMNodeEncoderInputLayer, GSPureLMNodeInputLayer from .model.rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer @@ -1210,9 +1209,9 @@ def create_lp_evaluator(config): assert all( (x.startswith(SUPPORTED_HIT_AT_METRICS) or x in SUPPORTED_LINK_PREDICTION_METRICS) for x in config.eval_metric), ( - "Invalid LP evaluation metrics. " - "GraphStorm only supports MRR and Hit@K metrics for link prediction." - ) + "Invalid LP evaluation metrics. " + f"GraphStorm only supports {SUPPORTED_LINK_PREDICTION_METRICS} as metrics " + f"for link prediction, got {config.eval_metric}") if config.report_eval_per_type: return GSgnnPerEtypeLPEvaluator(eval_frequency=config.eval_frequency, diff --git a/python/graphstorm/inference/graphstorm_infer.py b/python/graphstorm/inference/graphstorm_infer.py index eb4623efc9..7b8647282d 100644 --- a/python/graphstorm/inference/graphstorm_infer.py +++ b/python/graphstorm/inference/graphstorm_infer.py @@ -17,6 +17,7 @@ """ from ..tracker import GSSageMakerTaskTracker + class GSInferrer(): """ Generic GSgnn Inferrer. @@ -25,7 +26,7 @@ class GSInferrer(): model : GSgnnModel This model could be one of the internal GraphStorm GNN models, i.e., ``GSgnnNodeModel``, ``GSgnnEdgeModel``, ``GSgnnLinkPredictionModel``, or a model - class that inherit them. + class that inherit them. For customized GNN models, they should be the concrete implementation that inherits one of the ``GSgnnNodeModelBase``, ``GSgnnEdgeModelBase``, and ``GSgnnLinkPredictionModelBase`` classes. 
diff --git a/python/graphstorm/inference/lp_infer.py b/python/graphstorm/inference/lp_infer.py index 44bd986bd8..f6225f350c 100644 --- a/python/graphstorm/inference/lp_infer.py +++ b/python/graphstorm/inference/lp_infer.py @@ -18,6 +18,7 @@ import time from .graphstorm_infer import GSInferrer +from ..eval.evaluator import GSgnnLPRankingEvalInterface from ..model.utils import save_full_node_embeddings as save_gsgnn_embeddings from ..model.utils import save_relation_embeddings from ..model.edge_decoder import LinkPredictMultiRelationLearnableDecoder @@ -103,8 +104,16 @@ def infer(self, data, loader, save_embed_path, if self.evaluator is not None: test_start = time.time() - test_rankings = lp_mini_batch_predict(self._model, embs, loader, device) - val_score, test_score = self.evaluator.evaluate(None, test_rankings, 0) + test_rankings, test_lengths = lp_mini_batch_predict( + self._model, embs, loader, device, return_batch_lengths=True) + assert isinstance(self.evaluator, GSgnnLPRankingEvalInterface) + val_score, test_score = self.evaluator.evaluate( + None, + test_rankings, + 0, + val_candidate_sizes=None, + test_candidate_sizes=test_lengths, + ) sys_tracker.check('run evaluation') if get_rank() == 0: self.log_print_metrics(val_score=val_score, diff --git a/python/graphstorm/inference/mt_infer.py b/python/graphstorm/inference/mt_infer.py index df1cc01e3f..69394f5199 100644 --- a/python/graphstorm/inference/mt_infer.py +++ b/python/graphstorm/inference/mt_infer.py @@ -18,12 +18,16 @@ import os import time import logging +from typing import Any, Dict, Optional + import torch as th from ..config import (BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION) +from ..dataloading import GSgnnMultiTaskDataLoader +from ..eval.evaluator import GSgnnMultiTaskEvaluator from .graphstorm_infer import GSInferrer from ..model.utils import save_full_node_embeddings as save_gsgnn_embeddings from ..model.utils import (save_node_prediction_results, @@ -54,10 +58,10 @@ class GSgnnMultiTaskLearningInferrer(GSInferrer): # pylint: disable=unused-argument def infer(self, data, - predict_test_loader=None, - lp_test_loader=None, - recon_nfeat_test_loader=None, - recon_efeat_test_loader=None, + predict_test_loader: Optional[GSgnnMultiTaskDataLoader] = None, + lp_test_loader: Optional[GSgnnMultiTaskDataLoader] = None, + recon_nfeat_test_loader: Optional[GSgnnMultiTaskDataLoader] = None, + recon_efeat_test_loader: Optional[GSgnnMultiTaskDataLoader] = None, save_embed_path=None, save_prediction_path=None, use_mini_batch_infer=False, @@ -209,7 +213,8 @@ def gen_embs(edge_mask=None): # 2. node feature reconstruction (as it has the chance # to reuse the node embeddings generated at the beginning) # 3. link prediction. 
- pre_results = {} + pre_results: Dict[str, Any] = {} + test_lengths = None if predict_test_loader is not None: # compute prediction results for node classification, # node regressoin, edge classification @@ -309,13 +314,20 @@ def nfrecon_gen_embs(skip_last_self_loop=False, node_embs=embs): inplace=True) decoder = model.task_decoders[task_info.task_id] - ranking = run_lp_mini_batch_predict(decoder, lp_test_embs, dataloader, device) - pre_results[task_info.task_id] = ranking + ranking, test_lengths = run_lp_mini_batch_predict( + decoder, lp_test_embs, dataloader, device, return_batch_lengths=True) + pre_results[task_info.task_id] = (ranking, test_lengths) if do_eval: test_start = time.time() + assert isinstance(self.evaluator, GSgnnMultiTaskEvaluator) + val_score, test_score = self.evaluator.evaluate( - pre_results, pre_results, 0) + pre_results, + pre_results, + 0, + ) + sys_tracker.check('run evaluation') if get_rank() == 0: self.log_print_metrics(val_score=val_score, diff --git a/python/graphstorm/model/edge_decoder.py b/python/graphstorm/model/edge_decoder.py index a276d540e0..5065283991 100644 --- a/python/graphstorm/model/edge_decoder.py +++ b/python/graphstorm/model/edge_decoder.py @@ -17,6 +17,8 @@ """ import abc import logging +from typing import Dict, Tuple, Union + import numpy as np import torch as th from torch import nn @@ -909,7 +911,58 @@ def predict_proba(self, g, h, e_h): return out ##################### Link Prediction Decoders ####################### -class LinkPredictNoParamDecoder(GSLayerNoParam): +class LinkPredictionTestScoreInterface(abc.ABC): + """ Mixin class for link prediction test score computation + """ + + @abc.abstractmethod + def calc_test_scores( + self, + emb: Dict[str, th.Tensor], + pos_neg_tuple: Dict[Tuple[str, str, str], th.Tensor], + neg_sample_type: str, + device: Union[int, th.device], + ) -> Dict[Tuple[str, str, str], Tuple[th.Tensor, th.Tensor]]: + """ Compute scores for positive edges and negative edges. + + Parameters + ---------- + emb: dict of Tensor + Node embeddings in the format of {ntype: emb}. + pos_neg_tuple: dict of tuple + Positive and negative edges stored in a dict of tuple in the format of + {("src_ntype1", "etype1", "dst_ntype1" ): (pos_src_idx, neg_src_idx, + pos_dst_idx, neg_dst_idx)}. + + The `pos_src_idx` represents the postive source node indexes in the format + of Torch.Tensor. The `neg_src_idx` represents the negative source node indexes + in the format of Torch.Tensor. The `pos_dst_idx` represents the postive destination + node indexes in the format of Torch.Tensor. The `neg_dst_idx` represents the + negative destination node indexes in the format of Torch.Tensor. + + We define positive and negative edges as: + + * The positive edges: (pos_src_idx, pos_dst_idx) + * The negative edges: (pos_src_idx, neg_dst_idx) and + (neg_src_idx, pos_dst_idx) + + neg_sample_type: str + Describe how negative samples are sampled. There are two options: + + * ``Uniform``: For each positive edge, we sample K negative edges. + * ``Joint``: For one batch of positive edges, we sample K negative edges. + + device: th.device + Device used to compute scores. 
+ + Returns + -------- + scores: dict of tuple + Return a dictionary of edge type's positive scores and negative scores in the format + of {(src_ntype, etype, dst_ntype): (pos_scores, neg_scores)} + """ + +class LinkPredictNoParamDecoder(GSLayerNoParam, LinkPredictionTestScoreInterface): """ Abstract class for Link prediction decoder without trainable parameters """ @@ -936,7 +989,7 @@ def forward(self, g, h, e_h=None): in the input graph. """ -class LinkPredictLearnableDecoder(GSLayer): +class LinkPredictLearnableDecoder(GSLayer, LinkPredictionTestScoreInterface): """ Abstract class for Link prediction decoder with trainable parameters """ diff --git a/python/graphstorm/model/lp_gnn.py b/python/graphstorm/model/lp_gnn.py index c443280438..5d27463903 100644 --- a/python/graphstorm/model/lp_gnn.py +++ b/python/graphstorm/model/lp_gnn.py @@ -16,8 +16,14 @@ GNN model for link prediction in GraphStorm. """ import abc +from collections import defaultdict +from typing import Dict, List, Tuple, Union + import torch as th + +from ..dataloading.dataloading import GSgnnEdgeDataLoader from .gnn import GSgnnModel, GSgnnModelBase +from ..model.edge_decoder import LinkPredictionTestScoreInterface from .utils import normalize_node_embs from ..eval.utils import calc_ranking @@ -137,8 +143,9 @@ def forward(self, blocks, pos_graph, # weighted addition to the total loss return pred_loss + alpha_l2norm * reg_loss -def lp_mini_batch_predict(model, emb, loader, device): - """ Perform mini-batch prediction. +def lp_mini_batch_predict(model, emb, loader, device, return_batch_lengths=False): + """ Perform mini-batch prediction for link prediction and return rankings + of true edges in the predicted scores. This function follows full-graph GNN embedding inference. After having the GNN embeddings, we need to perform mini-batch @@ -157,19 +164,31 @@ def lp_mini_batch_predict(model, emb, loader, device): The GraphStorm dataloader device: th.device Device used to compute test scores + return_batch_lengths: bool, default False + Whether to return the lengths of each batch of edges for each ranking value. Returns ------- - rankings: dict of tensors + rankings: dict[str, torch.Tensor], if `return_batch_lengths` was False Rankings of positive scores in format of {etype: ranking} + rankings, batch_lengths: tuple[dict, dict], if `return_batch_lengths` was True + A tuple of rankings of positive scores in format of {etype: ranking}, + and the corresponding batch lengths for each ranking value. """ decoder = model.decoder return run_lp_mini_batch_predict(decoder, emb, loader, - device) - -def run_lp_mini_batch_predict(decoder, emb, loader, device): + device, + return_batch_lengths) + +def run_lp_mini_batch_predict( + decoder, + emb: Dict[str, th.Tensor], + loader: GSgnnEdgeDataLoader, + device: Union[th.device, int], + return_batch_lengths=False, + ): """ Perform mini-batch link prediction with the given decoder. This function follows full-graph GNN embedding inference. @@ -187,16 +206,24 @@ def run_lp_mini_batch_predict(decoder, emb, loader, device): The GNN embeddings loader : GSgnnEdgeDataLoader The GraphStorm dataloader - device: th.device + device: th.device or int Device used to compute test scores + return_batch_lengths: bool, default False + Whether to return the candidate list sizes of each ranking value. 
Returns ------- - rankings: dict of tensors + rankings: dict[tuple, torch.Tensor], if `return_batch_lengths` was False Rankings of positive scores in format of {etype: ranking} + rankings, batch_lengths: tuple[dict, dict], if `return_batch_lengths` was True + A tuple of rankings of positive scores in format of {etype: ranking}, + and the corresponding batch lengths for each ranking value. """ with th.no_grad(): - ranking = {} + ranking: Dict[Tuple, List[th.Tensor]] = defaultdict(list) + batch_lengths: Dict[Tuple, List[th.Tensor]] = defaultdict(list) + assert isinstance(decoder, LinkPredictionTestScoreInterface), \ + f"The decoder must implement LinkPredictionTestScoreInterface, got {decoder=}" for pos_neg_tuple, neg_sample_type in loader: score = \ decoder.calc_test_scores( @@ -205,12 +232,27 @@ def run_lp_mini_batch_predict(decoder, emb, loader, device): # We do not concatenate rankings into a single # ranking tensor to avoid unnecessary data copy. pos_score, neg_score = s - if canonical_etype in ranking: - ranking[canonical_etype].append(calc_ranking(pos_score, neg_score)) - else: - ranking[canonical_etype] = [calc_ranking(pos_score, neg_score)] - - rankings = {} + assert pos_score.shape[0] == neg_score.shape[0], \ + "There should be as many negative lists as there are positive examples" + score_ranking = calc_ranking(pos_score, neg_score) + ranking[canonical_etype].append(score_ranking) + # Set the number of candidates for each positive example + # (equal to neg_score.shape[0]) + # which will be the number of negatives (equal to neg_score.shape[1]) + # plus one for the positive example + lengths_tensor = th.tensor(neg_score.shape[0] * [neg_score.shape[1] + 1]) + # Ensure rankings and lengths are on the same device + lengths_tensor = lengths_tensor.to(score_ranking.device) + + batch_lengths[canonical_etype].append(lengths_tensor) + + rankings: Dict[Tuple, th.Tensor] = {} + batch_length_tensors: Dict[Tuple, th.Tensor] = {} for canonical_etype, rank in ranking.items(): rankings[canonical_etype] = th.cat(rank, dim=0) + etype_lengths = batch_lengths[canonical_etype] + batch_length_tensors[canonical_etype] = th.cat(etype_lengths, dim=0) + + if return_batch_lengths: + return rankings, batch_length_tensors return rankings diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py index 696ef93283..95d83ee41a 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py @@ -58,8 +58,6 @@ def main(config_args): model_layer_to_load=config.restore_model_layers) trainer.setup_device(device=get_device()) if not config.no_validation: - # TODO(zhengda) we need to refactor the evaluator. 
- # Currently, we only support mrr evaluator = gs.create_lp_evaluator(config) trainer.setup_evaluator(evaluator) val_idxs = train_data.get_edge_val_set(config.eval_etype) diff --git a/python/graphstorm/trainer/gsgnn_trainer.py b/python/graphstorm/trainer/gsgnn_trainer.py index e9673c7904..3c5b89beae 100644 --- a/python/graphstorm/trainer/gsgnn_trainer.py +++ b/python/graphstorm/trainer/gsgnn_trainer.py @@ -17,7 +17,9 @@ """ import os import logging +from typing import Optional +from ..eval.evaluator import GSgnnBaseEvaluator from ..model import GSOptimizer from ..model import GSgnnModel, GSgnnModelBase from ..model.utils import TopKList @@ -344,7 +346,7 @@ def can_do_validation(self, val_dataloader): return True @property - def evaluator(self): + def evaluator(self) -> Optional[GSgnnBaseEvaluator]: """ The evaluator associated with the trainer. """ return self._evaluator diff --git a/python/graphstorm/trainer/lp_trainer.py b/python/graphstorm/trainer/lp_trainer.py index a6bd8b27fa..05aaf0a38a 100644 --- a/python/graphstorm/trainer/lp_trainer.py +++ b/python/graphstorm/trainer/lp_trainer.py @@ -22,6 +22,7 @@ from torch.nn.parallel import DistributedDataParallel import dgl +from ..eval.evaluator import GSgnnLPRankingEvalInterface from ..model.lp_gnn import GSgnnLinkPredictionModelInterface from ..model.lp_gnn import lp_mini_batch_predict from ..model.gnn_with_reconstruct import GNNEncoderWithReconstructedEmbed @@ -96,7 +97,7 @@ def fit(self, train_loader, num_epochs, This function performs the training for the given link prediction model. It iterates over the training batches provided by the ``train_loader`` to compute the loss, and then performs the backward steps using trainer's - own optimizer. + own optimizer. If an evaluator and a validation dataloader are added to this trainer, during training, the trainer will perform model evaluation in three cases: @@ -122,7 +123,7 @@ def fit(self, train_loader, num_epochs, save model checkpoints. Default: None. save_model_frequency: int - The number of iterations to train the model before saving a model checkpoint. + The number of iterations to train the model before saving a model checkpoint. Default: -1, meaning only save model after each epoch. save_perf_results_path: str The path of the file where the performance results are saved. Default: None. @@ -346,7 +347,7 @@ def eval(self, model, data, val_loader, test_loader, Returns ------- val_score: dict - Validation scores of differnet metrics in the format of {metric: val_score}. + Validation scores of different metrics in the format of {metric: val_score}. 
""" test_start = time.time() sys_tracker.check('before prediction') @@ -361,16 +362,29 @@ def eval(self, model, data, val_loader, test_loader, edge_mask=edge_mask_for_gnn_embeddings, task_tracker=self.task_tracker) sys_tracker.check('compute embeddings') - val_scores = lp_mini_batch_predict(model, emb, val_loader, self.device) \ - if val_loader is not None else None + if val_loader is not None: + val_rankings, val_lengths = lp_mini_batch_predict( + model, emb, val_loader, self.device, return_batch_lengths=True) + else: + val_rankings, val_lengths = None, None sys_tracker.check('after_val_score') if test_loader is not None: - test_scores = lp_mini_batch_predict(model, emb, test_loader, self.device) + test_rankings, test_lengths = lp_mini_batch_predict( + model, emb, test_loader, self.device, return_batch_lengths=True) else: - test_scores = None + test_rankings, test_lengths = None, None sys_tracker.check('after_test_score') + assert self.evaluator is not None, \ + "Evaluator needs to be setup, use trainer.setup_evaluator(evaluator)" + assert isinstance(self.evaluator, GSgnnLPRankingEvalInterface), \ + f"Evaluator needs to implement GSgnnLPRankingEvalInterface, got {type(self.evaluator)}" val_score, test_score = self.evaluator.evaluate( - val_scores, test_scores, total_steps) + val_rankings, + test_rankings, + total_steps, + val_candidate_sizes=val_lengths, + test_candidate_sizes=test_lengths, + ) sys_tracker.check('evaluate validation/test') model.train() diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py index 90242ae012..1217ea2e5b 100644 --- a/python/graphstorm/trainer/mt_trainer.py +++ b/python/graphstorm/trainer/mt_trainer.py @@ -30,6 +30,7 @@ BUILTIN_TASK_LINK_PREDICTION, BUILTIN_TASK_RECONSTRUCT_NODE_FEAT, BUILTIN_TASK_RECONSTRUCT_EDGE_FEAT) +from ..eval.evaluator import GSgnnMultiTaskEvaluator from ..model import (do_full_graph_inference, do_mini_batch_inference, GSgnnModelBase, GSgnnModel, @@ -732,25 +733,34 @@ def gen_embs(edge_mask=None): inplace=True) decoder = model.task_decoders[task_info.task_id] - val_scores = run_lp_mini_batch_predict(decoder, - lp_test_embs, - lp_val_loader, - self.device) \ - if lp_val_loader is not None else None - test_scores = run_lp_mini_batch_predict(decoder, - lp_test_embs, - lp_test_loader, - self.device) \ - if lp_test_loader is not None else None + if lp_val_loader is not None: + val_rankings_and_lengths = run_lp_mini_batch_predict( + decoder, + lp_test_embs, + lp_val_loader, + self.device, + return_batch_lengths=True) + else: + val_rankings_and_lengths = (None, None) + + if lp_test_loader is not None: + test_rankings_and_lengths = run_lp_mini_batch_predict( + decoder, + lp_test_embs, + lp_test_loader, + self.device, + return_batch_lengths=True) + else: + test_rankings_and_lengths = (None, None) if val_results is not None: - val_results[task_info.task_id] = val_scores + val_results[task_info.task_id] = val_rankings_and_lengths else: - val_results = {task_info.task_id: val_scores} + val_results = {task_info.task_id: val_rankings_and_lengths} if test_results is not None: - test_results[task_info.task_id] = test_scores + test_results[task_info.task_id] = test_rankings_and_lengths else: - test_results = {task_info.task_id: test_scores} + test_results = {task_info.task_id: test_rankings_and_lengths} if len(nfeat_recon_tasks) > 0: def nfrecon_gen_embs(skip_last_self_loop=False, node_embs=embs): @@ -811,6 +821,8 @@ def nfrecon_gen_embs(skip_last_self_loop=False, node_embs=embs): 
sys_tracker.check('after_test_score') + assert isinstance(self.evaluator, GSgnnMultiTaskEvaluator), \ + "Evaluator must be a GSgnnMultiTaskEvaluator" val_score, test_score = self.evaluator.evaluate( val_results, test_results, total_steps) sys_tracker.check('evaluate validation/test') diff --git a/tests/end2end-tests/graphstorm-lp/mgpu_test.sh b/tests/end2end-tests/graphstorm-lp/mgpu_test.sh index 6de93d6c3c..439b7bf9ff 100644 --- a/tests/end2end-tests/graphstorm-lp/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-lp/mgpu_test.sh @@ -132,8 +132,8 @@ fi rm /tmp/train_log.txt -echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, exclude_training_targets: true, save model, eval_metric [hit_at_1 mrr hit_at_10]" -python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --use-node-embeddings true --eval-batch-size 1024 --exclude-training-targets True --reverse-edge-types-map user,rating,rating-rev,movie --save-model-path /data/gsgnn_lp_ml_dot/ --topk-model-to-save 1 --save-model-frequency 1000 --save-embed-path /data/gsgnn_lp_ml_dot/emb/ --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --eval-metric hit_at_1 mrr hit_at_10 +echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, exclude_training_targets: true, save model, eval_metric [hit_at_1 mrr amri hit_at_10]" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --use-node-embeddings true --eval-batch-size 1024 --exclude-training-targets True --reverse-edge-types-map user,rating,rating-rev,movie --save-model-path /data/gsgnn_lp_ml_dot/ --topk-model-to-save 1 --save-model-frequency 1000 --save-embed-path /data/gsgnn_lp_ml_dot/emb/ --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --eval-metric hit_at_1 mrr amri hit_at_10 error_and_exit $? 
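The metric names passed to --eval-metric above map directly onto LinkPredictionMetrics; a small sketch of what enabling amri changes, mirroring the updated unit tests further down:

    from graphstorm.eval.eval_func import LinkPredictionMetrics

    metrics = LinkPredictionMetrics(["hit_at_1", "mrr", "amri", "hit_at_10"])
    assert "amri" in metrics.metric_eval_function
    # AMRI can be negative, so its running "best" starts at -1.0 instead of 0.0;
    # this feeds the best-score tracking checked in the log assertions below.
    assert metrics.init_best_metric("amri") == -1.0
    assert metrics.init_best_metric("mrr") == 0.0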
@@ -201,6 +201,20 @@ then exit -1 fi +cnt=$(grep -c "| Test amri" /tmp/train_log.txt) +if test $cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Test amri" + exit 1 +fi + +bst_cnt=$(grep -c "Best Validation amri" /tmp/train_log.txt) +if test $bst_cnt -lt 1 +then + echo "We use SageMaker task tracker, we should have Best Validation amri" + exit 1 +fi + cnt=$(grep "Validation mrr" /tmp/train_log.txt | wc -l) if test $cnt -lt $bst_cnt then diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh index 00e9c27338..77ef1f749d 100644 --- a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh @@ -5,7 +5,6 @@ service ssh restart DGL_HOME=/root/dgl GS_HOME=$(pwd) NUM_TRAINERS=4 -NUM_INFO_TRAINERS=2 NUM_INFERs=2 export PYTHONPATH=$GS_HOME/python/ cd $GS_HOME/training_scripts/gsgnn_mt @@ -417,11 +416,12 @@ then exit -1 fi -cnt=$(ls -l /data/gsgnn_mt/emb/ | wc -l) -cnt=$[cnt - 1] +cnt=$(find /data/gsgnn_mt/emb/ -maxdepth 1 -type d | wc -l) +cnt=$(($cnt - 1)) if test $cnt != 2 then echo "The number of saved embs $cnt is not equal to 2 (for movie and user)." + exit 1 fi echo "**************[Multi-task] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: mini-batch load from saved model and train" diff --git a/tests/unit-tests/test_eval_func.py b/tests/unit-tests/test_eval_func.py index 007432ffb0..5996f8a889 100644 --- a/tests/unit-tests/test_eval_func.py +++ b/tests/unit-tests/test_eval_func.py @@ -13,20 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. """ -import torch as th import inspect + +import pytest +import torch as th from numpy.testing import assert_almost_equal + from graphstorm.eval.eval_func import (eval_roc_auc, eval_acc) -from graphstorm.eval.eval_func import (compute_mse, - compute_rmse, - compute_roc_auc, - compute_f1_score, - compute_precision_recall_auc, - compute_per_class_roc_auc, - compute_hit_at_classification, - compute_hit_at_link_prediction) +from graphstorm.eval.eval_func import ( + compute_amri, + compute_mse, + compute_rmse, + compute_roc_auc, + compute_f1_score, + compute_precision_recall_auc, + compute_per_class_roc_auc, + compute_hit_at_classification, + compute_hit_at_link_prediction) from graphstorm.eval.eval_func import ClassificationMetrics, LinkPredictionMetrics def test_compute_mse(): @@ -512,12 +517,14 @@ def test_compute_hit_at_classification(): assert hit_at == 4 def test_LinkPredictionMetrics(): - eval_metric_list = ["mrr", "hit_at_5", "hit_at_10"] + eval_metric_list = ["mrr", "hit_at_5", "hit_at_10", "amri"] metric = LinkPredictionMetrics(eval_metric_list) assert "mrr" in metric.metric_comparator assert "mrr" in metric.metric_function assert "mrr" in metric.metric_eval_function + assert "amri" in metric.metric_eval_function + assert "hit_at_5" in metric.metric_comparator assert "hit_at_5" in metric.metric_function @@ -532,16 +539,12 @@ def test_LinkPredictionMetrics(): assert signature.parameters["k"].default == 10 metric.assert_supported_metric("mrr") + metric.assert_supported_metric("amri") metric.assert_supported_metric("hit_at_5") metric.assert_supported_metric("hit_at_10") - pass_assert = False - try: + with pytest.raises(AssertionError): metric.assert_supported_metric("hit_at_ten") - pass_assert = True - except: - pass_assert = False - assert not pass_assert def test_compute_hit_at_link_prediction(): preds = 1 - th.arange(100) / 120 # preds 
for all positive and negative samples @@ -567,6 +570,49 @@ def test_compute_hit_at_link_prediction(): hit_at = compute_hit_at_link_prediction(ranking, 200) assert hit_at == 7 / 7 +def test_compute_amri(): + # Compute amri when candidate lists vary in size + ranks = th.tensor([4, 1, 4, 5, 1, 5, 5, 1, 2, 3]) + candidate_lists = th.tensor([10, 12, 7, 5, 4, 10, 12, 10, 20, 10]) + + # Use the definition from the paper to verify the values + def amri_definition(rankings, candidate_tensor) -> float: + mr = th.sum(rankings) / rankings.shape[0] + emr = (1 / (2 * candidate_tensor.shape[0])) * th.sum(candidate_tensor) + expected_amri = 1 - ((mr - 1) / emr) + return expected_amri + + actual_amri = compute_amri(ranks, candidate_lists) + expected_amri = amri_definition(ranks, candidate_lists) + assert_almost_equal( + actual_amri, + expected_amri, + decimal=6 + ) + + # Compute amri when all lists have the same size + candidate_size = th.tensor([10]) + actual_amri = compute_amri(ranks, candidate_size) + expected_amri = amri_definition(ranks, candidate_size) + assert_almost_equal( + actual_amri, + expected_amri, + decimal=4 + ) + + # amri should be 0 when all ranks are the expected rank plus one + ranks = th.tensor([10]*5) # MR is 10, MR-1 is 9 + candidate_size = th.tensor([18]*5) # E [MR] is (1/(2*5))*(18*5) = 9 + actual_amri = compute_amri(ranks, candidate_size) + expected_amri = 0 + assert_almost_equal( + actual_amri, + expected_amri, + decimal=2 + ) + + + if __name__ == '__main__': test_LinkPredictionMetrics() test_compute_hit_at_link_prediction() diff --git a/tests/unit-tests/test_inferrer.py b/tests/unit-tests/test_inferrer.py index 827692134b..e3f4c0505c 100644 --- a/tests/unit-tests/test_inferrer.py +++ b/tests/unit-tests/test_inferrer.py @@ -224,12 +224,13 @@ def mock_func_do_full_graph_inference(*args, **kwargs): "n2": None, } + lp_res = np.arange(5) + lp_length = np.array([5]) def mock_func_run_lp_mini_batch_predict(*args, **kwargs): - return lp_res + return lp_res, lp_length ntask_res = (np.arange(10), np.arange(10)) etask_res = (np.arange(20), np.arange(20)) - lp_res = np.arange(5) def mock_func_multi_task_mini_batch_predict(model, emb, dataloaders, task_infos, device, return_proba, return_label): assert len(emb) == 2 res = {} diff --git a/tests/unit-tests/test_trainer.py b/tests/unit-tests/test_trainer.py index 0166448b07..f54b9abebd 100644 --- a/tests/unit-tests/test_trainer.py +++ b/tests/unit-tests/test_trainer.py @@ -33,6 +33,7 @@ BUILTIN_TASK_RECONSTRUCT_NODE_FEAT, BUILTIN_TASK_RECONSTRUCT_EDGE_FEAT) from graphstorm.dataloading import GSgnnData, GSgnnMultiTaskDataLoader +from graphstorm.eval.evaluator import GSgnnLPRankingEvalInterface, GSgnnMultiTaskEvaluator from graphstorm.tracker import GSSageMakerTaskTracker from graphstorm import create_builtin_node_gnn_model from graphstorm.trainer import GSgnnTrainer @@ -47,7 +48,7 @@ from graphstorm.dataloading import (GSgnnNodeDataLoader, GSgnnEdgeDataLoader, GSgnnLinkPredictionDataLoader) -from graphstorm.model import GSgnnMultiTaskModelInterface, GSgnnModel, GSgnnModelBase +from graphstorm.model import GSgnnMultiTaskModelInterface, GSgnnModel from numpy.testing import assert_equal from util import (DummyGSgnnEncoderModel, @@ -498,14 +499,14 @@ def test_mtask_prepare_lp_mini_batch(): assert_equal(input_nodes["n0"].numpy(), input_node_idx["n0"].numpy()) assert_equal(input_nodes["n1"].numpy(), input_node_idx["n1"].numpy()) -class MTaskCheckerEvaluator(): - def __init__(self, val_resluts, test_results, steps): - self._val_results = val_resluts - 
self._test_results = test_results - self._steps = steps +class MTaskCheckerEvaluator(GSgnnMultiTaskEvaluator): + def __init__(self, val_rankings, test_rankings, total_iters): + self._val_results = val_rankings + self._test_results = test_rankings + self._steps = total_iters - def evaluate(self, val_results, test_results, total_steps): - assert self._steps == total_steps + def evaluate(self, val_results, test_results, total_iters, **kwargs): + assert self._steps == total_iters def compare_results(target_res, check_res): assert len(target_res) == len(check_res) for task_id, target_r in target_res.items(): @@ -518,6 +519,16 @@ def compare_results(target_res, check_res): assert_equal(tr_1, cr_1) assert_equal(tr_2, cr_2) else: + # In case LP results also returned candidate list + # lengths, we check the lengths and values + if isinstance(check_r, tuple): + check_r, candidate_sizes = check_r + if candidate_sizes.shape[0] > 1: + assert check_r.shape[0] == candidate_sizes.shape[0], \ + ("ranking and candidate_sizes must have the same length, " + f"got {check_r.shape=} {candidate_sizes.shape=}" ) + assert th.all(check_r <= candidate_sizes).item(), \ + "all ranks must be <= candidate_sizes" assert_equal(target_r, check_r) if self._val_results is not None:
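The tuple handling above follows the shape contract that run_lp_mini_batch_predict now returns: for a batch of N positive edges with K negatives each, the ranks lie in [1, K+1] and every candidate list size is K+1. A toy sketch with made-up scores (the rank computation is a simplified stand-in for the existing calc_ranking helper):

    import torch as th

    # Made-up scores: 3 positive edges, 4 negatives per positive edge.
    pos_score = th.tensor([0.9, 0.2, 0.6])
    neg_score = th.tensor([[0.1, 0.3, 0.5, 0.7],
                           [0.4, 0.8, 0.1, 0.6],
                           [0.2, 0.3, 0.9, 0.1]])

    # Rank of each positive among its own candidate list (1 is the best).
    ranking = (neg_score >= pos_score.unsqueeze(1)).sum(dim=1) + 1   # tensor([1, 4, 2])
    # Candidate list size = number of negatives + 1 for the positive itself,
    # matching the lengths tensor built in run_lp_mini_batch_predict.
    lengths = th.full((neg_score.shape[0],), neg_score.shape[1] + 1)  # tensor([5, 5, 5])

    # The invariant the multi-task trainer test asserts on the returned tuple.
    assert th.all(ranking <= lengths).item()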