[New Feature] New evaluator to support MRR and Hits metrics simultaneously for link prediction. (#1043)

*Issue #, if available:*

*Description of changes:*
Adds a new link prediction (LP) evaluator that supports computing MRR and Hit@K metrics simultaneously.
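As a quick illustration for reviewers, below is a minimal, hedged sketch of how the unified evaluator can be constructed. The keyword arguments mirror the call sites added in `gsf.py` later in this diff; the evaluation frequency and metric list are illustrative only.

```python
# Minimal sketch, not part of this diff: build the unified LP evaluator with
# both MRR and Hit@K requested. Values below are illustrative.
from graphstorm.eval import GSgnnLPEvaluator

evaluator = GSgnnLPEvaluator(
    eval_frequency=1000,                    # evaluate every 1000 steps (example value)
    eval_metric_list=["mrr", "hit_at_10"],  # MRR and Hit@K reported together
    use_early_stop=False)
```

In training and inference scripts the same object is obtained through `gs.create_lp_evaluator(config)`, as shown in the `gsgnn_lm_lp.py` change below.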

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
Oxfordblue7 authored Sep 26, 2024
1 parent aae9724 commit 88ab782
Showing 13 changed files with 989 additions and 213 deletions.
3 changes: 2 additions & 1 deletion python/graphstorm/__init__.py
@@ -29,7 +29,8 @@
from .gsf import create_builtin_edge_model
from .gsf import create_builtin_node_model
from .gsf import (create_task_decoder,
create_evaluator)
create_evaluator,
create_lp_evaluator)

from .gsf import (create_builtin_node_decoder,
create_builtin_edge_decoder,
2 changes: 2 additions & 0 deletions python/graphstorm/eval/__init__.py
@@ -27,6 +27,8 @@
from .evaluator import (GSgnnBaseEvaluator,
GSgnnPredictionEvalInterface,
GSgnnLPRankingEvalInterface,
GSgnnLPEvaluator,
GSgnnPerEtypeLPEvaluator,
GSgnnMrrLPEvaluator,
GSgnnPerEtypeMrrLPEvaluator,
GSgnnHitsLPEvaluator,
2 changes: 1 addition & 1 deletion python/graphstorm/eval/eval_func.py
@@ -152,7 +152,7 @@ class LinkPredictionMetrics:
Parameters
----------
eval_metric_list: list of string
Evaluation metric(s) used during evaluation, for example, ["hit_at_10", "hit_at_100"].
Evaluation metric(s) used during evaluation, for example, ["mrr", "hit_at_1", "hit_at_100"].
"""
def __init__(self, eval_metric_list=None):
self.supported_metrics = SUPPORTED_LINK_PREDICTION_METRICS
413 changes: 392 additions & 21 deletions python/graphstorm/eval/evaluator.py

Large diffs are not rendered by default.
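Since the `evaluator.py` diff is collapsed here, the following is a small, hedged sketch of the arithmetic a combined MRR/Hit@K evaluation performs on a shared ranking tensor. It illustrates the metric definitions only and is not the GraphStorm implementation.

```python
import torch

def mrr_and_hits(ranks: torch.Tensor, ks=(1, 10, 100)):
    """Compute MRR and Hit@K from the 1-based ranks of positive edges.

    ranks[i] is the rank of the i-th positive edge among its negative
    candidates (1 means it was ranked first).
    """
    metrics = {"mrr": (1.0 / ranks.float()).mean().item()}
    for k in ks:
        metrics[f"hit_at_{k}"] = (ranks <= k).float().mean().item()
    return metrics

# Three positive edges ranked 1st, 4th and 20th among their candidates.
print(mrr_and_hits(torch.tensor([1, 4, 20])))
# {'mrr': 0.433..., 'hit_at_1': 0.333..., 'hit_at_10': 0.667..., 'hit_at_100': 1.0}
```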

64 changes: 49 additions & 15 deletions python/graphstorm/gsf.py
@@ -46,6 +46,7 @@
BUILTIN_LP_LOSS_CONTRASTIVELOSS,
BUILTIN_CLASS_LOSS_CROSS_ENTROPY,
BUILTIN_CLASS_LOSS_FOCAL)
from graphstorm.eval.eval_func import SUPPORTED_HIT_AT_METRICS
from .model.embed import GSNodeEncoderInputLayer
from .model.lm_embed import GSLMNodeEncoderInputLayer, GSPureLMNodeInputLayer
from .model.rgcn_encoder import RelationalGCNEncoder, RelGraphConvLayer
@@ -112,6 +113,8 @@
from .eval import (GSgnnClassificationEvaluator,
GSgnnRegressionEvaluator,
GSgnnRconstructFeatRegScoreEvaluator,
GSgnnPerEtypeLPEvaluator,
GSgnnLPEvaluator,
GSgnnPerEtypeMrrLPEvaluator,
GSgnnMrrLPEvaluator)
from .trainer import (GSgnnLinkPredictionTrainer,
@@ -1118,23 +1121,21 @@ def create_evaluator(task_info):
config.early_stop_rounds,
config.early_stop_strategy)
elif task_info.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
assert len(config.eval_metric) == 1, \
"GraphStorm does not support computing multiple metrics at the same time for link prediction tasks."
if config.report_eval_per_type:
return GSgnnPerEtypeMrrLPEvaluator(
eval_frequency=config.eval_frequency,
major_etype=config.model_select_etype,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
return GSgnnPerEtypeLPEvaluator(eval_frequency=config.eval_frequency,
eval_metric_list=config.eval_metric,
major_etype=config.model_select_etype,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
else:
return GSgnnMrrLPEvaluator(
eval_frequency=config.eval_frequency,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
return GSgnnLPEvaluator(eval_frequency=config.eval_frequency,
eval_metric_list=config.eval_metric,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
elif task_info.task_type in [BUILTIN_TASK_RECONSTRUCT_NODE_FEAT]:
return GSgnnRconstructFeatRegScoreEvaluator(
config.eval_frequency,
@@ -1144,3 +1145,36 @@ def create_evaluator(task_info):
config.early_stop_rounds,
config.early_stop_strategy)
return None

def create_lp_evaluator(config):
""" Create LP specific evaluator.
Parameters
----------
config: GSConfig
Configuration.
Returns
-------
Evaluator: A link prediction evaluator.
"""
assert all((x.startswith(SUPPORTED_HIT_AT_METRICS) or x == 'mrr') for x in
config.eval_metric), (
"Invalid LP evaluation metrics. "
"GraphStorm only supports MRR and Hit@K metrics for link prediction.")

if config.report_eval_per_type:
return GSgnnPerEtypeLPEvaluator(eval_frequency=config.eval_frequency,
eval_metric_list=config.eval_metric,
major_etype=config.model_select_etype,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
else:
return GSgnnLPEvaluator(eval_frequency=config.eval_frequency,
eval_metric_list=config.eval_metric,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
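
To make the metric check above concrete, here is a hedged, standalone illustration of which metric lists the assert accepts; it assumes `SUPPORTED_HIT_AT_METRICS` is the `"hit_at"` prefix used elsewhere in `eval_func.py`.

```python
# Standalone illustration of the validation in create_lp_evaluator();
# SUPPORTED_HIT_AT_METRICS is assumed to be the "hit_at" prefix.
SUPPORTED_HIT_AT_METRICS = "hit_at"

def is_valid_lp_metric_list(eval_metric):
    return all(x.startswith(SUPPORTED_HIT_AT_METRICS) or x == "mrr"
               for x in eval_metric)

assert is_valid_lp_metric_list(["mrr", "hit_at_1", "hit_at_100"])  # accepted
assert not is_valid_lp_metric_list(["mrr", "amri"])                # rejected
```

Depending on `report_eval_per_type`, the function then returns either the per-edge-type or the aggregate evaluator with the same metric list.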
7 changes: 3 additions & 4 deletions python/graphstorm/inference/lp_infer.py
@@ -104,12 +104,11 @@ def infer(self, data, loader, save_embed_path,
if self.evaluator is not None:
test_start = time.time()
test_rankings = lp_mini_batch_predict(self._model, embs, loader, device)
# TODO: to refactor the names
val_mrr, test_mrr = self.evaluator.evaluate(None, test_rankings, 0)
val_score, test_score = self.evaluator.evaluate(None, test_rankings, 0)
sys_tracker.check('run evaluation')
if get_rank() == 0:
self.log_print_metrics(val_score=val_mrr,
test_score=test_mrr,
self.log_print_metrics(val_score=val_score,
test_score=test_score,
dur_eval=time.time() - test_start,
total_steps=0)

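The rename from `val_mrr`/`test_mrr` to `val_score`/`test_score` reflects that the evaluator can now return one value per configured metric. A hedged sketch of what the logging step would then see, using a hypothetical stand-in for `log_print_metrics` and illustrative numbers:

```python
# Hedged sketch: val_score/test_score are assumed to be dicts keyed by metric
# name, as the existing GraphStorm evaluators return; numbers are illustrative.
def print_eval_scores(val_score, test_score, total_steps):
    for metric, value in test_score.items():
        print(f"step {total_steps}: test {metric}={value:.4f}, "
              f"val {metric}={val_score[metric]:.4f}")

print_eval_scores(val_score={"mrr": 0.42, "hit_at_10": 0.65},
                  test_score={"mrr": 0.41, "hit_at_10": 0.63},
                  total_steps=0)
```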
114 changes: 68 additions & 46 deletions python/graphstorm/model/edge_decoder.py
@@ -1116,29 +1116,33 @@ def get_relembs(self):
return self._w_relation.weight, self.etype2rid

class LinkPredictRotatEDecoder(LinkPredictMultiRelationLearnableDecoder):
r""" Decoder for link prediction using the RotatE as the score function.
r"""
.. versionadded:: 0.4
The :py:class:`LinkPredictRotatEDecoder`.
Score function of RotateE measures the angular distance between
head and tail elements. The angular distance is defined as:
Decoder for link prediction using the RotatE as the score function.
.. math::
Score function of RotateE measures the angular distance between
head and tail elements. The angular distance is defined as:
d_r(h, t)=\|h\circ r-t\|
.. math::
The RotatE score function is defined as:
d_r(h, t)=\|h\circ r-t\|
.. math::
The RotatE score function is defined as:
gamma - \|h\circ r-t\|^2
.. math::
where gamma is a margin.
gamma - \|h\circ r-t\|^2
For more details, please refer to https://arxiv.org/abs/1902.10197
or https://dglke.dgl.ai/doc/kg.html#rotatee.
where gamma is a margin.
Note: The relation embedding of RotatE has two parts,
one for real numbers and one for complex numbers.
Each has the dimension size as half of the input dimension size.
For more details, please refer to https://arxiv.org/abs/1902.10197
or https://dglke.dgl.ai/doc/kg.html#rotatee.
Note: The relation embedding of RotatE has two parts,
one for real numbers and one for complex numbers.
Each has the dimension size as half of the input dimension size.
Parameters
----------
@@ -1376,14 +1380,18 @@ def out_dims(self):
return 1
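
As a numeric companion to the RotatE formulas above, here is a hedged sketch of the score computation. It illustrates the formula only, not the GraphStorm decoder, and assumes the common phase parameterization of the relation embedding.

```python
import torch

def rotate_score(head, rel_phase, tail, gamma=12.0):
    """gamma - ||h o r - t||, with h rotated by the relation phase in the complex plane."""
    h_re, h_im = head.chunk(2)          # head embedding split into real/imaginary halves
    t_re, t_im = tail.chunk(2)
    r_re, r_im = torch.cos(rel_phase), torch.sin(rel_phase)
    rot_re = h_re * r_re - h_im * r_im  # complex multiplication h * r
    rot_im = h_re * r_im + h_im * r_re
    dist = torch.sqrt((rot_re - t_re) ** 2 + (rot_im - t_im) ** 2).sum()
    return gamma - dist

dim = 8
score = rotate_score(torch.randn(dim), torch.rand(dim // 2) * 3.14, torch.randn(dim))
```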

class LinkPredictContrastiveRotatEDecoder(LinkPredictRotatEDecoder):
""" Decoder for link prediction designed for contrastive loss
using the RotatE as the score function.
"""
.. versionadded:: 0.4
The :py:class:`LinkPredictContrastiveRotatEDecoder`.
Note:
------
This class is specifically implemented for contrastive loss. But
it could also be used by other pair-wise loss functions for link
prediction tasks.
Decoder for link prediction designed for contrastive loss
using the RotatE as the score function.
Note:
------
This class is specifically implemented for contrastive loss. But
it could also be used by other pair-wise loss functions for link
prediction tasks.
Parameters
----------
@@ -1442,10 +1450,13 @@ def forward(self, g, h, e_h=None):
return scores

class LinkPredictWeightedRotatEDecoder(LinkPredictRotatEDecoder):
"""Link prediction decoder with the score function of RotatE
with edge weight.
"""
.. versionadded:: 0.4
The :py:class:`LinkPredictWeightedRotatEDecoder`.
Link prediction decoder with the score function of RotatE with edge weight.
When computing loss, edge weights are used to adjust the loss.
When computing loss, edge weights are used to adjust the loss.
Parameters
----------
@@ -1510,26 +1521,30 @@ def forward(self, g, h, e_h):
return scores

class LinkPredictTransEDecoder(LinkPredictMultiRelationLearnableDecoder):
r""" Decoder for link prediction using the TransE as the score function.
r"""
.. versionadded:: 0.4
The :py:class:`LinkPredictTransEDecoder`.
Score function of TransE measures the angular distance between
head and tail elements. The angular distance is defined as:
Decoder for link prediction using the TransE as the score function.
.. math::
Score function of TransE measures the angular distance between
head and tail elements. The angular distance is defined as:
d_r(h, t)= -\|h+r-t\|
.. math::
The TransE score function is defined as:
d_r(h, t)= -\|h+r-t\|
.. math::
The TransE score function is defined as:
gamma - \|h+r-t\|^{frac{1}{2}} \text{or} gamma - \|h+r-t\|
.. math::
where gamma is a margin.
gamma - \|h+r-t\|^{frac{1}{2}} \text{or} gamma - \|h+r-t\|
For more details, please refer to
https://papers.nips.cc/paper_files/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html
or https://dglke.dgl.ai/doc/kg.html#transe.
where gamma is a margin.
For more details, please refer to
https://papers.nips.cc/paper_files/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html
or https://dglke.dgl.ai/doc/kg.html#transe.
Parameters
----------
@@ -1769,14 +1784,18 @@ def out_dims(self):
return 1
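
Similarly, a hedged numeric sketch of the TransE score above (formula illustration only, not the GraphStorm decoder):

```python
import torch

def transe_score(head, rel, tail, gamma=12.0, norm="l2"):
    """gamma - ||h + r - t||, using either the L1 or the L2 norm."""
    diff = head + rel - tail
    dist = diff.abs().sum() if norm == "l1" else diff.norm(p=2)
    return gamma - dist

dim = 8
score = transe_score(torch.randn(dim), torch.randn(dim), torch.randn(dim))
```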

class LinkPredictContrastiveTransEDecoder(LinkPredictTransEDecoder):
""" Decoder for link prediction designed for contrastive loss
using the TransE as the score function.
"""
.. versionadded:: 0.4
The :py:class:`LinkPredictContrastiveTransEDecoder`.
Note:
------
This class is specifically implemented for contrastive loss. But
it could also be used by other pair-wise loss functions for link
prediction tasks.
Decoder for link prediction designed for contrastive loss
using the TransE as the score function.
Note:
------
This class is specifically implemented for contrastive loss. But
it could also be used by other pair-wise loss functions for link
prediction tasks.
Parameters
----------
@@ -1834,10 +1853,13 @@ def forward(self, g, h, e_h=None):
return scores

class LinkPredictWeightedTransEDecoder(LinkPredictTransEDecoder):
"""Link prediction decoder with the score function of TransE
with edge weight.
"""
.. versionadded:: 0.4
The :py:class:`LinkPredictWeightedTransEDecoder`.
Link prediction decoder with the score function of TransE with edge weight.
When computing loss, edge weights are used to adjust the loss.
When computing loss, edge weights are used to adjust the loss.
Parameters
----------
50 changes: 1 addition & 49 deletions python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py
@@ -23,57 +23,9 @@
from graphstorm.config import GSConfig
from graphstorm.trainer import GSgnnLinkPredictionTrainer
from graphstorm.dataloading import GSgnnData
from graphstorm.eval import (GSgnnMrrLPEvaluator, GSgnnPerEtypeMrrLPEvaluator,
GSgnnHitsLPEvaluator, GSgnnPerEtypeHitsLPEvaluator)
from graphstorm.model.utils import save_full_node_embeddings
from graphstorm.model import do_full_graph_inference
from graphstorm.utils import rt_profiler, sys_tracker, get_device
from graphstorm.eval.eval_func import SUPPORTED_HIT_AT_METRICS

def get_evaluator(config):
""" Get evaluator according to config
Parameters
----------
config: GSConfig
Configuration
"""
# TODO: to create a generic evaluator for LP tasks
assert (len(config.eval_metric) == 1 and config.eval_metric[0] == 'mrr') \
or (len(config.eval_metric) >= 1
and all((x.startswith(SUPPORTED_HIT_AT_METRICS) for x in config.eval_metric))), \
"GraphStorm does not support computing MRR and Hit@K metrics at the same time."

if config.report_eval_per_type:
if 'mrr' in config.eval_metric:
return GSgnnPerEtypeMrrLPEvaluator(eval_frequency=config.eval_frequency,
major_etype=config.model_select_etype,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
else:
return GSgnnPerEtypeHitsLPEvaluator(eval_frequency=config.eval_frequency,
eval_metric_list=config.eval_metric,
major_etype=config.model_select_etype,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
else:
if 'mrr' in config.eval_metric:
return GSgnnMrrLPEvaluator(eval_frequency=config.eval_frequency,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
else:
return GSgnnHitsLPEvaluator(eval_frequency=config.eval_frequency,
eval_metric_list=config.eval_metric,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)

def main(config_args):
""" main function
@@ -98,7 +50,7 @@ def main(config_args):
if not config.no_validation:
# TODO(zhengda) we need to refactor the evaluator.
# Currently, we only support mrr
evaluator = get_evaluator(config)
evaluator = gs.create_lp_evaluator(config)
trainer.setup_evaluator(evaluator)
val_idxs = train_data.get_edge_val_set(config.eval_etype)
assert len(val_idxs) > 0, "The training data do not have validation set."
