Add Adjusted Mean Ranking Index metric for Link Prediction (#1061)
*Issue #, if available:*

*Description of changes:*

* Add the Adjusted Mean Rank Index (AMRI) link prediction metric. This metric
is normalized by the candidate list size, allowing for easier comparison
across datasets, models, and negative edge sample counts.
* To get the list sizes, we modify `run_lp_mini_batch_predict` and
`lp_mini_batch_predict` to _conditionally_ return a tuple of `rankings,
lengths`, which lets us calculate AMRI in the cases where it is needed. The
return type is determined by a new argument with a default value, so the
changes are backwards compatible: existing calls to the two functions behave
as before (see the sketch after this list).
* Add `LinkPredictionTestScoreInterface` as a common ancestor of
`LinkPredictNoParamDecoder` and `LinkPredictLearnableDecoder`, so that we can
ensure at runtime that the decoder implements the `calc_test_scores` function
used by the LP evaluator.
* Modify the `GSgnnLPRankingEvalInterface` `evaluate` function to accept
optional `**kwargs`. When we want to calculate a metric that needs the
candidate list lengths, we use these kwargs to pass that information.
* Modify `compute_score(self, rankings, train=True, **kwargs)` to accept
optional kwargs, currently only used during AMRI calculation.
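To make the conditional tuple return concrete, here is a minimal sketch of the pattern. The function name, the flag name `return_candidate_lengths`, and the assumption that each positive edge's score sits at position 0 of its candidate list are illustrative only, not the actual GraphStorm signatures.

```python
from typing import Union

import torch as th


def lp_mini_batch_predict_sketch(
    candidate_scores: list[th.Tensor],
    return_candidate_lengths: bool = False,
) -> Union[th.Tensor, tuple[th.Tensor, th.Tensor]]:
    """Illustrative only: compute positive-edge ranks from candidate scores.

    Each tensor in ``candidate_scores`` holds the scores of one positive edge
    and its negatives, with the positive edge's score at position 0.
    """
    ranks, lengths = [], []
    for scores in candidate_scores:
        # Rank of the positive edge = 1 + number of candidates scored above it.
        ranks.append(1 + th.sum(scores[1:] > scores[0]))
        lengths.append(scores.shape[0])

    rankings = th.stack(ranks)
    if return_candidate_lengths:
        # New behavior: also return the candidate list sizes, needed by AMRI.
        return rankings, th.tensor(lengths)
    # Default: unchanged return type, so existing callers work as before.
    return rankings
```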

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: xiang song(charlie.song) <[email protected]>
thvasilo and classicsong authored Oct 24, 2024
1 parent ed0f698 commit 993a71f
Showing 20 changed files with 571 additions and 146 deletions.
1 change: 0 additions & 1 deletion .github/workflow_scripts/pytest_check.sh
@@ -7,4 +7,3 @@ FORCE_CUDA=1 python3 -m pip install -e '.[test]' --no-build-isolation
python3 -m pip install pytest
sh ./tests/unit-tests/prepare_test_data.sh
export NCCL_IB_DISABLE=1; export NCCL_SHM_DISABLE=1; NCCL_NET=Socket NCCL_DEBUG=INFO python3 -m pytest -x ./tests/unit-tests -s

3 changes: 3 additions & 0 deletions inference_scripts/mt_infer/ml_nc_ec_er_lp_only_infer.yaml
@@ -71,6 +71,9 @@ gsf:
reverse_edge_types_map:
- user,rating,rating-rev,movie
batch_size: 128 # will overwrite the global batch_size
eval_metric:
- "mrr"
- "amri"
- reconstruct_node_feat:
reconstruct_nfeat_name: "title"
target_ntype: "movie"
@@ -91,6 +91,9 @@ gsf:
- "train_mask_field_lp"
- null # empty means there is no validation mask
- "test_mask_field_lp"
eval_metric:
- "mrr"
- "amri"
- reconstruct_node_feat:
reconstruct_nfeat_name: "title"
target_ntype: "movie"
112 changes: 92 additions & 20 deletions python/graphstorm/eval/eval_func.py
@@ -16,9 +16,11 @@
Evaluation functions
"""
import logging
import operator
from collections.abc import Callable
from enum import Enum
from functools import partial
import operator

import numpy as np
import torch as th
from sklearn.metrics import roc_auc_score
@@ -28,7 +30,7 @@
SUPPORTED_CLASSIFICATION_METRICS = {'accuracy', 'precision_recall', \
'roc_auc', 'f1_score', 'per_class_f1_score', 'per_class_roc_auc', SUPPORTED_HIT_AT_METRICS}
SUPPORTED_REGRESSION_METRICS = {'rmse', 'mse', 'mae'}
SUPPORTED_LINK_PREDICTION_METRICS = {"mrr", SUPPORTED_HIT_AT_METRICS}
SUPPORTED_LINK_PREDICTION_METRICS = {"mrr", SUPPORTED_HIT_AT_METRICS, "amri"}

class ClassificationMetrics:
""" object that compute metrics for classification tasks.
@@ -158,16 +160,19 @@ def __init__(self, eval_metric_list=None):
self.supported_metrics = SUPPORTED_LINK_PREDICTION_METRICS

# This is the operator used to compare whether current value is better than the current best
self.metric_comparator = {}
self.metric_comparator: dict[str, Callable] = {}
self.metric_comparator["mrr"] = operator.le
self.metric_comparator["amri"] = operator.le

# This is the operator used to measure each metric performance
self.metric_function = {}
self.metric_function: dict[str, Callable[..., th.Tensor]] = {}
self.metric_function["mrr"] = compute_mrr
self.metric_function["amri"] = compute_amri

# This is the operator used to measure each metric performance in evaluation
self.metric_eval_function = {}
self.metric_eval_function: dict[str, Callable[..., th.Tensor]] = {}
self.metric_eval_function["mrr"] = compute_mrr
self.metric_eval_function["amri"] = compute_amri

if eval_metric_list:
for eval_metric in eval_metric_list:
@@ -190,20 +195,23 @@ def assert_supported_metric(self, metric):
assert metric in self.supported_metrics, \
f"Metric {metric} not supported for link prediction"

def init_best_metric(self, metric):
def init_best_metric(self, metric: str):
"""
Return the initial value for the metric to keep track of the best metric.
Parameters
----------
metric: the metric to initialize
metric: str
the name of the metric to initialize
Returns
-------
float
An initial value for the metric.
"""
# Need to check if the given metric is supported first
self.assert_supported_metric(metric)
return 0
# The minimum value for AMRI is -1.0 so we init with that
return -1.0 if metric == "amri" else 0.0


def labels_to_one_hot(labels, total_labels):
@@ -694,18 +702,82 @@ def compute_mae(pred, labels):
diff = th.abs(pred.cpu() - labels.cpu())
return th.mean(diff).cpu().item()

def compute_mrr(ranking):
""" Get link prediction mrr metrics
def compute_mrr(ranking: th.Tensor) -> th.Tensor:
""" Get link prediction Mean Reciprocal Rank (MRR) metrics
Parameters
----------
ranking:
ranking of each positive edge
Parameters
----------
ranking: torch.Tensor
ranking of each positive edge
Returns
-------
link prediction mrr metrics: tensor
Returns
-------
th.Tensor
link prediction mrr metrics
"""
logs = th.div(1.0, ranking)
metrics = th.tensor(th.div(th.sum(logs),len(logs)))
reciprocal_ranks = th.div(1.0, ranking)
metrics = th.tensor(th.div(th.sum(reciprocal_ranks), len(reciprocal_ranks)))
return metrics

def compute_amri(ranking: th.Tensor, candidate_sizes: th.Tensor) -> th.Tensor:
"""Computes the Adjusted Mean Rank Index (AMRI) for the given ranking and candidate sizes.
AMRI is a metric that evaluates the performance of link prediction models by considering both
the rank of the correct candidate and the number of candidates. It is calculated as:
.. math::
AMRI = 1 - \\frac{\\text{MR}-1}{\\mathbb{E}[\\text{MR}-1]}
where MR is the mean rank, and `E[MR]` is the expected mean rank, which is used
to adjust for chance. E[MR] is defined as:
.. math::
\\mathbb{E}[\\text{MR}] = \\mathbb{E} \\left[ \\frac{1}{n} \\sum^n_{i=1}{r_i} \\right]
Where :math:`r_i` is the rank the model assigns to the positive edge,
compared to the negative edges in the candidate list, and :math:`n` is the number of
candidate lists, one per positive edge.
AMRI values will be in the :math:`[-1, 1]` range, where 1 corresponds
to optimal performance where each individual rank is 1. A value of 0 indicates
model performance similar to a model assigning random scores, or equal score
to every candidate. The value is negative if the model performs worse than the
constant-score model.
For more details see https://arxiv.org/abs/2002.06914
Parameters
----------
ranking : torch.Tensor
ranking of each positive edge
candidate_sizes : th.Tensor
The size of each candidate list. If all candidate lists have
the same size this will be a single-value tensor.
Returns
-------
th.Tensor
A single-value Tensor with the AMRI metric.
.. versionadded:: 0.4.0
"""
if candidate_sizes.shape[0] > 1:
assert ranking.shape[0] == candidate_sizes.shape[0], \
("ranking and candidate_sizes must have the same length, "
f"got {ranking.shape=} {candidate_sizes.shape=}" )
assert th.all(ranking <= candidate_sizes).item(), \
"all ranks must be <= candidate_sizes"

# We use the simplified form of AMRI calculation
# 1 - \frac{MR-1}{E[MR-1]} = 1 - \frac{2*\sum_n{r-1}}{\sum_n{|S|}}
# where n is the number of evaluations (number of positive edges),
# r is the ranking of the positive edge in each ranked score list,
# and |S| is the edge candidate set size.
# See equation (8) in https://arxiv.org/abs/2002.06914
nominator = 2 * th.sum(ranking - 1)
if candidate_sizes.shape[0] == 1:
denominator = candidate_sizes.item() * ranking.shape[0]
else:
denominator = th.sum(candidate_sizes)

return 1 - th.div(nominator, denominator)
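As a quick sanity check of the two metrics above, the following usage sketch assumes GraphStorm is installed and that `compute_mrr` and `compute_amri` are importable from `graphstorm.eval.eval_func` (the module changed in this commit):

```python
import torch as th

from graphstorm.eval.eval_func import compute_amri, compute_mrr

# Ranks of three positive edges within their candidate lists (rank 1 is best).
ranking = th.tensor([1, 2, 4])
# All candidate lists have the same size, so a single-value tensor suffices.
candidate_sizes = th.tensor([10])

mrr = compute_mrr(ranking)   # (1/1 + 1/2 + 1/4) / 3 ≈ 0.583
amri = compute_amri(ranking, candidate_sizes)
# AMRI = 1 - 2 * sum(r - 1) / sum(|S|) = 1 - 2 * (0 + 1 + 3) / (3 * 10) ≈ 0.733
print(f"MRR = {mrr:.3f}, AMRI = {amri:.3f}")
```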