
Commit

change to responding comments
Ubuntu committed Apr 24, 2024
1 parent 2b262fe commit f55b406
Showing 11 changed files with 87 additions and 88 deletions.
4 changes: 2 additions & 2 deletions docs/source/api/graphstorm.eval.rst
@@ -23,7 +23,7 @@ Base Evaluators

GSgnnBaseEvaluator
GSgnnPredictionEvalInterface
GSgnnLPMrrEvalInterface
GSgnnLPRankingEvalInterface

Evaluators
-----------
@@ -35,5 +35,5 @@ Evaluators

GSgnnClassificationEvaluator
GSgnnRegressionEvaluator
GSgnnLPEvaluator
GSgnnMrrLPEvaluator
GSgnnPerEtypeMrrLPEvaluator
4 changes: 2 additions & 2 deletions examples/peft_llm_gnn/main_lp.py
@@ -4,7 +4,7 @@
from graphstorm.config import get_argument_parser
from graphstorm.config import GSConfig
from graphstorm.dataloading import GSgnnLinkPredictionDataLoader, GSgnnLinkPredictionTestDataLoader
from graphstorm.eval import GSgnnLPEvaluator
from graphstorm.eval import GSgnnMrrLPEvaluator
from graphstorm.dataloading import GSgnnLPTrainData
from graphstorm.utils import get_device
from graphstorm.inference import GSgnnLinkPredictionInferrer
@@ -54,7 +54,7 @@ def main(config_args):
trainer.setup_device(device=get_device())

# set evaluator
evaluator = GSgnnLPEvaluator(
evaluator = GSgnnMrrLPEvaluator(
eval_frequency=config.eval_frequency,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
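The constructor call above is cut off by the diff view. As a minimal sketch of how the renamed evaluator is typically wired into a link-prediction training script, assuming a parsed GSConfig object `config` and a trainer instance `trainer` as in main_lp.py (the argument set mirrors the constructor shown later in evaluator.py; attaching via `setup_evaluator()` mirrors the inference scripts below):

from graphstorm.eval import GSgnnMrrLPEvaluator

# Sketch only: `config` and `trainer` are assumed to exist as in main_lp.py.
evaluator = GSgnnMrrLPEvaluator(
    eval_frequency=config.eval_frequency,
    use_early_stop=config.use_early_stop,
    early_stop_burnin_rounds=config.early_stop_burnin_rounds,
    early_stop_rounds=config.early_stop_rounds,
    early_stop_strategy=config.early_stop_strategy)
# Trainers and inferrers attach the evaluator the same way the inference
# scripts below do with setup_evaluator().
trainer.setup_evaluator(evaluator)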
2 changes: 1 addition & 1 deletion python/graphstorm/eval/__init__.py
@@ -23,7 +23,7 @@
from .eval_func import SUPPORTED_REGRESSION_METRICS
from .eval_func import SUPPORTED_LINK_PREDICTION_METRICS

from .evaluator import GSgnnLPEvaluator
from .evaluator import GSgnnMrrLPEvaluator
from .evaluator import GSgnnPerEtypeMrrLPEvaluator
from .evaluator import GSgnnClassificationEvaluator
from .evaluator import GSgnnRegressionEvaluator
79 changes: 39 additions & 40 deletions python/graphstorm/eval/evaluator.py
@@ -164,43 +164,49 @@ def compute_score(self, pred, labels, train=True):
"""


class GSgnnLPMrrEvalInterface():
""" Interface for Link Prediction evaluation function using "mrr"
class GSgnnLPRankingEvalInterface():
""" Interface for Link Prediction evaluation function using ranking method
The interface set the two abstract methods for Link Prediction classes that use "mrr"
as the evaluation metric.
The interface set the two abstract methods for Link Prediction classes that use ranking
method to compute evaluation metrics, such as "mrr" (Mean Reciprocal Rank).
There are two methods to be implemented when inheriting this interface.
1. ``evaluate()`` method, which will be called by Trainers to provide ranking-based evaluation
results of validation and test sets during training process.
2. ``compute_score()`` method, which computes the scores for the given rankings.
"""

@abc.abstractmethod
def evaluate(self, val_scores, test_scores, total_iters):
def evaluate(self, val_rankings, test_rankings, total_iters):
"""Evaluate validation and test sets for Link Prediciton tasks
GSgnnTrainers will call this function to do evalution in their eval() fuction.
Link Prediction evaluators should provide the LP scores for validation and test sets.
Link Prediction evaluators should provide the ranking of validation and test sets as
input.
Parameters
----------
val_scores: dict of tensors
The rankings of validation edges for each edge type.
The rankings of validation edges for each edge type in format of {etype: ranking}.
test_scores: dict of tensors
The rankings of testing edges for each edge type.
The rankings of testing edges for each edge type in format of {etype: ranking}.
total_iters: int
The current iteration number.
Returns
-----------
eval_score: float
Validation score
Validation score for each edge type in format of {etype: score}.
test_score: float
Test score
Test score for each edge type in format of {etype: score}.
"""

@abc.abstractmethod
def compute_score(self, rankings, train=False):
def compute_score(self, rankings):
""" Compute evaluation score for Prediciton tasks
Classification and regression evaluators should provide both predictions and labels.
Ranking-based link prediction evaluators should provide ranking values as input.
Parameters
----------
@@ -210,6 +216,7 @@ def compute_score(self, rankings, train=False):
Returns
-------
Evaluation metric values: dict
scores for each edge type.
"""


@@ -218,7 +225,7 @@ class GSgnnBaseEvaluator():
New base class in V0.3 to replace ``GSgnnInstanceEvaluator`` and ``GSgnnLPEvaluator``. This
class serves as the base for the built-in ``GSgnnClassificationEvaluator``,
``GSgnnRegressionEvaluator``, and ``GSgnnLinkPredictionEvaluator``.
``GSgnnRegressionEvaluator``, ``GSgnnMrrLPEvaluator``, and ``GSgnnPerEtypeMrrLPEvaluator``.
In order to create customized Evaluators, users can inherit this class and the corresponding
EvalInterface class, and then implement the two abstract methods, i.e., ``evaluate()``
@@ -590,7 +597,7 @@ def __init__(self, eval_frequency,
early_stop_burnin_rounds=0,
early_stop_rounds=3,
early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY):
# set up default metric to be mse
# set up default metric to be "rmse"
if eval_metric_list is None:
eval_metric_list = ["rmse"]
super(GSgnnRegressionEvaluator, self).__init__(eval_frequency,
@@ -692,11 +699,11 @@ def compute_score(self, pred, labels, train=True):
return scores


class GSgnnLPEvaluator(GSgnnBaseEvaluator, GSgnnLPMrrEvalInterface):
""" Link Prediction Evaluator.
class GSgnnMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface):
""" Link Prediction Evaluator using "mrr" as metric.
GS built-in evaluator for Link Prediction task. It uses "mrr" as the default eval metric,
which therefore implements the `GSgnnLPMrrEvalInterface`.
which implements the `GSgnnLPRankingEvalInterface`.
To create a customized LP evaluator that uses an evaluation metric other than "mrr", users might
need to 1) define a new evaluation interface if the evaluation method requires different input
@@ -727,7 +734,7 @@ def __init__(self, eval_frequency,
early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY):
if eval_metric_list is None:
eval_metric_list = ["mrr"]
super(GSgnnLPEvaluator, self).__init__(eval_frequency,
super(GSgnnMrrLPEvaluator, self).__init__(eval_frequency,
eval_metric_list, use_early_stop, early_stop_burnin_rounds,
early_stop_rounds, early_stop_strategy)
self.metrics_obj = LinkPredictionMetrics()
@@ -740,9 +747,9 @@ def __init__(self, eval_frequency,
self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric)
self._best_iter[metric] = 0

def evaluate(self, val_scores, test_scores, total_iters):
def evaluate(self, val_rankings, test_rankings, total_iters):
""" `GSgnnLinkPredictionTrainer` and `GSgnnLinkPredictionInferrer` will call this function
to compute validation and test mrr scores.
to compute validation and test scores.
Parameters
----------
@@ -761,14 +768,14 @@ def evaluate(self, val_scores, test_scores, total_iters):
Test mrr score
"""
with th.no_grad():
if test_scores is not None:
test_score = self.compute_score(test_scores)
if test_rankings is not None:
test_score = self.compute_score(test_rankings)
else:
for metric in self.metric_list:
test_score = {metric: "N/A"} # Dummy

if val_scores is not None:
val_score = self.compute_score(val_scores)
if val_rankings is not None:
val_score = self.compute_score(val_rankings)

if get_rank() == 0:
for metric in self.metric_list:
@@ -784,17 +791,13 @@ def evaluate(self, val_scores, test_scores, total_iters):

return val_score, test_score

def compute_score(self, rankings, train=False): # pylint:disable=unused-argument
def compute_score(self, rankings):
""" Compute evaluation score
Parameters
----------
rankings: dict of tensors
Rankings of positive scores in format of {etype: ranking}
train: bool
TODO: Reversed for future use cases when we want to use different
way to generate scores for train (more efficient but less accurate)
and test.
Returns
-------
@@ -821,7 +824,7 @@ def compute_score(self, rankings, train=False): # pylint:disable=unused-argument
return return_metrics


class GSgnnPerEtypeMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPMrrEvalInterface):
class GSgnnPerEtypeMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface):
""" The class for link prediction evaluation using Mrr metric and
return a Per etype mrr score.
@@ -869,17 +872,13 @@ def __init__(self, eval_frequency,
self._best_test_score[metric] = self.metrics_obj.init_best_metric(metric=metric)
self._best_iter[metric] = 0

def compute_score(self, rankings, train=False): # pylint:disable=unused-argument
def compute_score(self, rankings):
""" Compute evaluation score
Parameters
----------
rankings: dict of tensors
Rankings of positive scores in format of {etype: ranking}
train: bool
TODO: Reversed for future use cases when we want to use different
way to generate scores for train (more efficient but less accurate)
and test.
Returns
-------
@@ -917,7 +916,7 @@ def _get_major_score(self, score):
major_score = score[self.major_etype]
return major_score

def evaluate(self, val_scores, test_scores, total_iters):
def evaluate(self, val_rankings, test_rankings, total_iters):
""" `GSgnnLinkPredictionTrainer` and `GSgnnLinkPredictionInferrer` will call this function
to compute validation and test mrr scores.
@@ -938,13 +937,13 @@ def evaluate(self, val_scores, test_scores, total_iters):
Test mrr
"""
with th.no_grad():
if test_scores is not None:
test_score = self.compute_score(test_scores)
if test_rankings is not None:
test_score = self.compute_score(test_rankings)
else:
test_score = {"mrr": "N/A"} # Dummy

if val_scores is not None:
val_score = self.compute_score(val_scores)
if val_rankings is not None:
val_score = self.compute_score(val_rankings)

if get_rank() == 0:
for metric in self.metric_list:
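The docstrings above spell out the intended extension pattern: combine ``GSgnnBaseEvaluator`` with ``GSgnnLPRankingEvalInterface`` and implement ``evaluate()`` and ``compute_score()``. A rough, hypothetical sketch of that pattern (not part of this commit; the hits@10 metric, the constructor defaults, and the import paths are assumptions modeled on the MRR evaluator above):

import torch as th

# Assumed import paths; the constant mirrors the default used by the
# built-in evaluators above.
from graphstorm.config import EARLY_STOP_AVERAGE_INCREASE_STRATEGY
from graphstorm.eval.evaluator import GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface


class HypotheticalHitsLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface):
    """ Illustrative LP evaluator that scores per-etype rankings with hits@10. """

    def __init__(self, eval_frequency,
                 eval_metric_list=None,
                 use_early_stop=False,
                 early_stop_burnin_rounds=0,
                 early_stop_rounds=3,
                 early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY):
        if eval_metric_list is None:
            eval_metric_list = ["hits_at_10"]
        # A real implementation would also initialize metric bookkeeping
        # (best scores, best iteration) as GSgnnMrrLPEvaluator does above.
        super().__init__(eval_frequency, eval_metric_list, use_early_stop,
                         early_stop_burnin_rounds, early_stop_rounds,
                         early_stop_strategy)

    def compute_score(self, rankings):
        # Concatenate the rankings of all edge types and count how often the
        # positive edge lands in the top 10.
        ranking = th.cat(list(rankings.values()), dim=0)
        return {"hits_at_10": (ranking <= 10).float().mean().item()}

    def evaluate(self, val_rankings, test_rankings, total_iters):
        val_score = self.compute_score(val_rankings) \
            if val_rankings is not None else {"hits_at_10": "N/A"}
        test_score = self.compute_score(test_rankings) \
            if test_rankings is not None else {"hits_at_10": "N/A"}
        return val_score, test_score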
4 changes: 2 additions & 2 deletions python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py
@@ -40,7 +40,7 @@
BUILTIN_LP_LOCALJOINT_NEG_SAMPLER)
from graphstorm.dataloading import BUILTIN_LP_ALL_ETYPE_UNIFORM_NEG_SAMPLER
from graphstorm.dataloading import BUILTIN_LP_ALL_ETYPE_JOINT_NEG_SAMPLER
from graphstorm.eval import GSgnnLPEvaluator, GSgnnPerEtypeMrrLPEvaluator
from graphstorm.eval import GSgnnMrrLPEvaluator, GSgnnPerEtypeMrrLPEvaluator
from graphstorm.model.utils import save_full_node_embeddings
from graphstorm.model import do_full_graph_inference
from graphstorm.utils import rt_profiler, sys_tracker, get_device
@@ -63,7 +63,7 @@ def get_evaluator(config):
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
else:
return GSgnnLPEvaluator(eval_frequency=config.eval_frequency,
return GSgnnMrrLPEvaluator(eval_frequency=config.eval_frequency,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
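The ``get_evaluator`` helper above chooses between the per-etype evaluator and the aggregate MRR evaluator. For the per-etype branch, a hedged sketch of the construction, assuming the ``major_etype`` keyword suggested by ``_get_major_score`` in evaluator.py and a made-up edge type:

from graphstorm.eval import GSgnnPerEtypeMrrLPEvaluator

# Hypothetical: report MRR separately per edge type, and let one canonical
# edge type (assumed name) drive early stopping and best-score selection.
evaluator = GSgnnPerEtypeMrrLPEvaluator(
    eval_frequency=config.eval_frequency,
    major_etype=("user", "purchases", "item"),
    use_early_stop=config.use_early_stop)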
4 changes: 2 additions & 2 deletions python/graphstorm/run/gsgnn_lp/gsgnn_lp.py
@@ -48,7 +48,7 @@
FastGSgnnLPJointNegDataLoader,
FastGSgnnLPLocalUniformNegDataLoader,
FastGSgnnLPLocalJointNegDataLoader)
from graphstorm.eval import GSgnnLPEvaluator, GSgnnPerEtypeMrrLPEvaluator
from graphstorm.eval import GSgnnMrrLPEvaluator, GSgnnPerEtypeMrrLPEvaluator
from graphstorm.model.utils import save_full_node_embeddings
from graphstorm.model import do_full_graph_inference
from graphstorm.utils import (
@@ -77,7 +77,7 @@ def get_evaluator(config):
early_stop_rounds=config.early_stop_rounds,
early_stop_strategy=config.early_stop_strategy)
else:
return GSgnnLPEvaluator(eval_frequency=config.eval_frequency,
return GSgnnMrrLPEvaluator(eval_frequency=config.eval_frequency,
use_early_stop=config.use_early_stop,
early_stop_burnin_rounds=config.early_stop_burnin_rounds,
early_stop_rounds=config.early_stop_rounds,
4 changes: 2 additions & 2 deletions python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py
@@ -20,7 +20,7 @@
from graphstorm.config import get_argument_parser
from graphstorm.config import GSConfig
from graphstorm.inference import GSgnnLinkPredictionInferrer
from graphstorm.eval import GSgnnLPEvaluator
from graphstorm.eval import GSgnnMrrLPEvaluator
from graphstorm.dataloading import GSgnnEdgeInferData
from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader,
GSgnnLinkPredictionJointTestDataLoader,
@@ -57,7 +57,7 @@ def main(config_args):
infer.setup_device(device=get_device())
if not config.no_validation:
infer.setup_evaluator(
GSgnnLPEvaluator(config.eval_frequency))
GSgnnMrrLPEvaluator(config.eval_frequency))
assert len(infer_data.test_idxs) > 0, "There is no test data for evaluation."
tracker = gs.create_builtin_task_tracker(config)
infer.setup_task_tracker(tracker)
4 changes: 2 additions & 2 deletions python/graphstorm/run/gsgnn_lp/lp_infer_lm.py
@@ -21,7 +21,7 @@
from graphstorm.config import get_argument_parser
from graphstorm.config import GSConfig
from graphstorm.inference import GSgnnLinkPredictionInferrer
from graphstorm.eval import GSgnnLPEvaluator
from graphstorm.eval import GSgnnMrrLPEvaluator
from graphstorm.dataloading import GSgnnEdgeInferData
from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader,
GSgnnLinkPredictionJointTestDataLoader,
@@ -51,7 +51,7 @@ def main(config_args):
infer.setup_device(device=get_device())
if not config.no_validation:
infer.setup_evaluator(
GSgnnLPEvaluator(config.eval_frequency))
GSgnnMrrLPEvaluator(config.eval_frequency))
assert len(infer_data.test_idxs) > 0, "There is no test data for evaluation."
tracker = gs.create_builtin_task_tracker(config)
infer.setup_task_tracker(tracker)
8 changes: 4 additions & 4 deletions python/graphstorm/trainer/lp_trainer.py
@@ -323,16 +323,16 @@ def eval(self, model, data, val_loader, test_loader,
edge_mask=edge_mask_for_gnn_embeddings,
task_tracker=self.task_tracker)
sys_tracker.check('compute embeddings')
val_scores = lp_mini_batch_predict(model, emb, val_loader, self.device) \
val_rankings = lp_mini_batch_predict(model, emb, val_loader, self.device) \
if val_loader is not None else None
sys_tracker.check('after_val_score')
if test_loader is not None:
test_scores = lp_mini_batch_predict(model, emb, test_loader, self.device)
test_rankings = lp_mini_batch_predict(model, emb, test_loader, self.device)
else:
test_scores = None
test_rankings = None
sys_tracker.check('after_test_score')
val_score, test_score = self.evaluator.evaluate(
val_scores, test_scores, total_steps)
val_rankings, test_rankings, total_steps)
sys_tracker.check('evaluate validation/test')
model.train()

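In the trainer's ``eval()`` above, ``lp_mini_batch_predict`` produces the per-etype ranking dictionaries that the evaluator now consumes. A standalone sketch of that contract with toy data (the edge type and rank values are made up; the one-process group mirrors the unit tests below):

import torch as th

from graphstorm.eval import GSgnnMrrLPEvaluator

# The evaluator is normally driven by GSgnnLinkPredictionTrainer; this sketch
# initializes a one-process group first, as the unit tests below do.
th.distributed.init_process_group(backend="gloo",
                                  init_method="tcp://127.0.0.1:23456",
                                  world_size=1, rank=0)

# Toy rankings: the rank of each positive edge among its negatives, per edge type.
val_rankings = {("user", "likes", "item"): th.tensor([1, 2, 4])}
test_rankings = {("user", "likes", "item"): th.tensor([1, 5, 10])}

evaluator = GSgnnMrrLPEvaluator(eval_frequency=100)
val_score, test_score = evaluator.evaluate(val_rankings, test_rankings, 0)

# MRR is the mean reciprocal rank, so the validation score here is
# (1/1 + 1/2 + 1/4) / 3, roughly 0.583.
print(val_score, test_score)
th.distributed.destroy_process_group()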
10 changes: 5 additions & 5 deletions tests/unit-tests/test_dist_eval.py
@@ -32,7 +32,7 @@

from graphstorm.eval import GSgnnClassificationEvaluator
from graphstorm.eval import GSgnnRegressionEvaluator
from graphstorm.eval import GSgnnLPEvaluator
from graphstorm.eval import GSgnnMrrLPEvaluator
from graphstorm.utils import setup_device

from graphstorm.config import BUILTIN_LP_DOT_DECODER
@@ -49,8 +49,8 @@ def run_dist_lp_eval_worker(worker_rank, config, val_scores, test_scores, conn):
world_size=2,
rank=worker_rank)

lp_eval = GSgnnLPEvaluator(config.eval_frequency,
use_early_stop=config.use_early_stop)
lp_eval = GSgnnMrrLPEvaluator(config.eval_frequency,
use_early_stop=config.use_early_stop)
val_sc, test_sc = lp_eval.evaluate(val_scores, test_scores, 0)

if worker_rank == 0:
@@ -101,8 +101,8 @@ def run_local_lp_eval_worker(config, val_scores, test_scores, conn):
world_size=1,
rank=0)

lp_eval = GSgnnLPEvaluator(config.eval_frequency,
use_early_stop=config.use_early_stop)
lp_eval = GSgnnMrrLPEvaluator(config.eval_frequency,
use_early_stop=config.use_early_stop)
val_sc, test_sc = lp_eval.evaluate(val_scores, test_scores, 0)
conn.send((val_sc, test_sc))
th.distributed.destroy_process_group()

