[Refactor] Refactor Evaluator #822

Merged: 4 commits, Apr 29, 2024

Changes from all commits:
docs/source/advanced/own-models.rst (14 changes: 7 additions & 7 deletions)

@@ -263,13 +263,13 @@ The GraphStorm trainers can have evaluators and task trackers associated. The fo
 .. code-block:: python

     # Optional: set up a evaluator
-    evaluator = GSgnnAccEvaluator(config.eval_frequency,
-                                  config.eval_metric,
-                                  config.multilabel,
-                                  config.use_early_stop,
-                                  config.early_stop_burnin_rounds,
-                                  config.early_stop_rounds,
-                                  config.early_stop_strategy)
+    evaluator = GSgnnClassificationEvaluator(config.eval_frequency,
+                                             config.eval_metric,
+                                             config.multilabel,
+                                             config.use_early_stop,
+                                             config.early_stop_burnin_rounds,
+                                             config.early_stop_rounds,
+                                             config.early_stop_strategy)
     trainer.setup_evaluator(evaluator)
     # Optional: set up a task tracker to show the progress of training.
     tracker = GSSageMakerTaskTracker(config.eval_frequency)
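Note that the rename is a drop-in change for classification tasks: as the hunk above shows, GSgnnClassificationEvaluator takes exactly the same argument list as the removed GSgnnAccEvaluator, so existing call sites only need the new class name.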
docs/source/api/graphstorm.eval.rst (16 changes: 9 additions & 7 deletions)

@@ -7,8 +7,10 @@ graphstorm.eval
 Learning (GML) tasks.

 If users want to implement customized evaluators or evaluation methods, a best practice is to
-extend base evaluators, i.e., the ``GSgnnInstanceEvaluator`` class for node or edge prediction
-tasks, and ``GSgnnLPEvaluator`` for link prediction tasks, and then implement the abstract methods.
+extend the base evaluator, i.e., the ``GSgnnBaseEvaluator``, and the corresponding evaluation
+interfaces, e.g., ``GSgnnPredictionEvalInterface`` for prediction evaluation, and
+``GSgnnLPRankingEvalInterface`` for ranking-based link prediction evaluation, and then
+implement the abstract methods defined in those interface classes.

 .. currentmodule:: graphstorm.eval

@@ -20,8 +22,9 @@ Base Evaluators
     :nosignatures:
     :template: evaltemplate.rst

-    GSgnnInstanceEvaluator
-    GSgnnLPEvaluator
+    GSgnnBaseEvaluator
+    GSgnnPredictionEvalInterface
+    GSgnnLPRankingEvalInterface

 Evaluators
 -----------

@@ -31,8 +34,7 @@ Evaluators
     :nosignatures:
     :template: evaltemplate.rst

-    GSgnnLPEvaluator
+    GSgnnClassificationEvaluator
+    GSgnnRegressionEvaluator
     GSgnnMrrLPEvaluator
     GSgnnPerEtypeMrrLPEvaluator
-    GSgnnAccEvaluator
-    GSgnnRegressionEvaluator
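To make the new extension pattern concrete, here is a minimal sketch of a customized prediction evaluator (illustrative only, not part of this PR). The class name MyF1Evaluator and the metric choice are hypothetical, and the evaluate/compute_score signatures are assumptions inferred from the interface names above; check the generated API docs for the authoritative abstract-method signatures.

    import torch as th
    from sklearn.metrics import f1_score

    from graphstorm.eval import GSgnnBaseEvaluator, GSgnnPredictionEvalInterface

    class MyF1Evaluator(GSgnnBaseEvaluator, GSgnnPredictionEvalInterface):
        """Hypothetical evaluator that scores predictions with macro F1."""

        def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters):
            # Score the validation and test splits with the same metric function.
            val_score = self.compute_score(val_pred, val_labels)
            test_score = self.compute_score(test_pred, test_labels)
            return val_score, test_score

        def compute_score(self, pred, labels, train=True):
            # Turn per-class logits into hard labels, then compute macro F1.
            pred_labels = th.argmax(pred, dim=1).cpu().numpy()
            return {"f1": f1_score(labels.cpu().numpy(), pred_labels, average="macro")}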
examples/customized_models/HGT/hgt_nc.py (16 changes: 8 additions & 8 deletions)

@@ -14,7 +14,7 @@
 from graphstorm.inference import GSgnnNodePredictionInferrer
 from graphstorm.dataloading import GSgnnNodeTrainData, GSgnnNodeInferData
 from graphstorm.dataloading import GSgnnNodeDataLoader
-from graphstorm.eval import GSgnnAccEvaluator
+from graphstorm.eval import GSgnnClassificationEvaluator
 from graphstorm.tracker import GSSageMakerTaskTracker
 from graphstorm.utils import get_device

@@ -326,13 +326,13 @@ def main(args):
                                      train_task=False)

     # Optional: set up a evaluator
-    evaluator = GSgnnAccEvaluator(config.eval_frequency,
-                                  config.eval_metric,
-                                  config.multilabel,
-                                  config.use_early_stop,
-                                  config.early_stop_burnin_rounds,
-                                  config.early_stop_rounds,
-                                  config.early_stop_strategy)
+    evaluator = GSgnnClassificationEvaluator(config.eval_frequency,
+                                             config.eval_metric,
+                                             config.multilabel,
+                                             config.use_early_stop,
+                                             config.early_stop_burnin_rounds,
+                                             config.early_stop_rounds,
+                                             config.early_stop_strategy)
     trainer.setup_evaluator(evaluator)
     # Optional: set up a task tracker to show the progress of training.
     tracker = GSSageMakerTaskTracker(config.eval_frequency)
examples/peft_llm_gnn/main_lp.py (14 changes: 6 additions & 8 deletions)

@@ -54,14 +54,12 @@ def main(config_args):
     trainer.setup_device(device=get_device())

     # set evaluator
-    evaluator = GSgnnMrrLPEvaluator(config.eval_frequency,
-                                    train_data,
-                                    config.num_negative_edges_eval,
-                                    config.lp_decoder_type,
-                                    config.use_early_stop,
-                                    config.early_stop_burnin_rounds,
-                                    config.early_stop_rounds,
-                                    config.early_stop_strategy
+    evaluator = GSgnnMrrLPEvaluator(
+        eval_frequency=config.eval_frequency,
+        use_early_stop=config.use_early_stop,
+        early_stop_burnin_rounds=config.early_stop_burnin_rounds,
+        early_stop_rounds=config.early_stop_rounds,
+        early_stop_strategy=config.early_stop_strategy
     )
     # disbale validation for efficiency
     # trainer.setup_evaluator(evaluator)
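The dropped constructor arguments reflect the interface change: the refactored GSgnnMrrLPEvaluator scores precomputed candidate rankings through the ranking-based evaluation interface instead of re-deriving scores from the dataset and decoder, so train_data, config.num_negative_edges_eval, and config.lp_decoder_type no longer belong in the constructor. A hedged usage sketch follows; the evaluate signature and the val_rankings/test_rankings names are assumptions based on GSgnnLPRankingEvalInterface, with the rank tensors normally produced by the trainer or inferrer.

    # Hypothetical call, assuming the ranking-based interface accepts
    # per-positive-edge rank tensors for the validation and test splits.
    val_score, test_score = evaluator.evaluate(val_rankings, test_rankings, total_iters=100)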
examples/peft_llm_gnn/main_nc.py (4 changes: 2 additions & 2 deletions)

@@ -3,7 +3,7 @@
 from graphstorm.config import get_argument_parser
 from graphstorm.config import GSConfig
 from graphstorm.dataloading import GSgnnNodeDataLoader
-from graphstorm.eval import GSgnnAccEvaluator
+from graphstorm.eval import GSgnnClassificationEvaluator
 from graphstorm.dataloading import GSgnnNodeTrainData
 from graphstorm.utils import get_device
 from graphstorm.inference import GSgnnNodePredictionInferrer

@@ -52,7 +52,7 @@ def main(config_args):
     trainer.setup_device(device=get_device())

     # set evaluator
-    evaluator = GSgnnAccEvaluator(
+    evaluator = GSgnnClassificationEvaluator(
         config.eval_frequency,
         config.eval_metric,
         config.multilabel,
examples/standalone_mode_demo.ipynb (5 changes: 2 additions & 3 deletions)

@@ -39,7 +39,7 @@
 "from graphstorm.dataloading import GSgnnNodeTrainData, GSgnnNodeDataLoader, GSgnnNodeInferData\n",
 "from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer, EntityClassifier, ClassifyLossFunc, RelationalGCNEncoder\n",
 "from graphstorm.inference import GSgnnNodePredictionInferrer\n",
-"from graphstorm.eval import GSgnnAccEvaluator"
+"from graphstorm.eval import GSgnnClassificationEvaluator"
 ]
 },
 {

@@ -315,9 +315,8 @@
 "trainer.setup_device(device=device)\n",
 "\n",
 "# set up evaluator for the trainer:\n",
-"evaluator = GSgnnAccEvaluator(\n",
+"evaluator = GSgnnClassificationEvaluator(\n",
 "    eval_frequency=10000,\n",
-"    eval_metric=['accuracy'],\n",
 "    multilabel=multilabel)\n",
 "\n",
 "trainer.setup_evaluator(evaluator)"
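Beyond the rename, the second hunk drops eval_metric=['accuracy'], so the notebook now relies on the evaluator's default metric; accuracy is the plausible default for GSgnnClassificationEvaluator, but confirm against the class documentation.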
examples/temporal_graph_learning/main_nc.py (4 changes: 2 additions & 2 deletions)

@@ -3,7 +3,7 @@
 from graphstorm.config import get_argument_parser
 from graphstorm.config import GSConfig
 from graphstorm.dataloading import GSgnnNodeDataLoader
-from graphstorm.eval import GSgnnAccEvaluator
+from graphstorm.eval import GSgnnClassificationEvaluator
 from graphstorm.dataloading import GSgnnNodeTrainData
 from graphstorm.utils import get_device
 from graphstorm.trainer import GSgnnNodePredictionTrainer

@@ -45,7 +45,7 @@ def main(config_args):
     trainer.setup_device(device=get_device())

     # set evaluator
-    evaluator = GSgnnAccEvaluator(
+    evaluator = GSgnnClassificationEvaluator(
         config.eval_frequency,
         config.eval_metric,
         config.multilabel,
python/graphstorm/config/argument.py (3 changes: 2 additions & 1 deletion)

@@ -352,7 +352,7 @@ def verify_arguments(self, is_train):
         _ = self.log_report_frequency

         _ = self.task_type
-        # For classification tasks.
+        # For classification/regression tasks.
         if self.task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_EDGE_CLASSIFICATION]:
             _ = self.label_field
             _ = self.num_classes

@@ -368,6 +368,7 @@ def verify_arguments(self, is_train):
                               BUILTIN_TASK_LINK_PREDICTION] and is_train:
             _ = self.exclude_training_targets
             _ = self.reverse_edge_types_map
+        # For link prediction tasks.
         if self.task_type == BUILTIN_TASK_LINK_PREDICTION:
             _ = self.gamma
             _ = self.lp_decoder_type
python/graphstorm/eval/__init__.py (7 changes: 3 additions & 4 deletions)

@@ -23,8 +23,7 @@
 from .eval_func import SUPPORTED_REGRESSION_METRICS
 from .eval_func import SUPPORTED_LINK_PREDICTION_METRICS

-from .evaluator import GSgnnInstanceEvaluator
-from .evaluator import GSgnnLPEvaluator
-from .evaluator import GSgnnMrrLPEvaluator, GSgnnPerEtypeMrrLPEvaluator
-from .evaluator import GSgnnAccEvaluator
+from .evaluator import GSgnnMrrLPEvaluator
+from .evaluator import GSgnnPerEtypeMrrLPEvaluator
+from .evaluator import GSgnnClassificationEvaluator
 from .evaluator import GSgnnRegressionEvaluator
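This is the breaking part of the refactor for downstream code: GSgnnInstanceEvaluator, GSgnnLPEvaluator, and GSgnnAccEvaluator are no longer exported from graphstorm.eval. User code needs the same one-line update applied to the example scripts above, for instance:

    # Before this PR (now raises ImportError):
    # from graphstorm.eval import GSgnnAccEvaluator
    # After:
    from graphstorm.eval import GSgnnClassificationEvaluator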
python/graphstorm/eval/eval_func.py (30 changes: 30 additions & 0 deletions)

@@ -103,6 +103,12 @@ def __init__(self):
         self.metric_function["mse"] = compute_mse
         self.metric_function["mae"] = compute_mae

+        # This is the operator used to measure each metric performance in evaluation
+        self.metric_eval_function = {}
+        self.metric_eval_function["rmse"] = compute_rmse
+        self.metric_eval_function["mse"] = compute_mse
+        self.metric_eval_function["mae"] = compute_mae
+
     def assert_supported_metric(self, metric):
         """ check if the given metric is supported.
         """

@@ -135,6 +141,14 @@ def __init__(self):
         self.metric_comparator = {}
         self.metric_comparator["mrr"] = operator.le

+        # This is the operator used to measure each metric performance
+        self.metric_function = {}
+        self.metric_function["mrr"] = compute_mrr
+
+        # This is the operator used to measure each metric performance in evaluation
+        self.metric_eval_function = {}
+        self.metric_eval_function["mrr"] = compute_mrr
+
     def assert_supported_metric(self, metric):
         """ check if the given metric is supported.
         """

@@ -583,3 +597,19 @@ def compute_mae(pred, labels):

     diff = th.abs(pred.cpu() - labels.cpu())
     return th.mean(diff).cpu().item()
+
+def compute_mrr(ranking):
+    """ Get link prediction mrr metrics
+
+    Parameters
+    ----------
+    ranking:
+        ranking of each positive edge
+
+    Returns
+    -------
+    link prediction mrr metrics: tensor
+    """
+    logs = th.div(1.0, ranking)
+    metrics = th.tensor(th.div(th.sum(logs), len(logs)))
+    return metrics
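As a quick sanity check of the new helper: compute_mrr takes the rank of each positive edge among its candidates and returns the mean reciprocal rank. A small worked example with illustrative values:

    import torch as th

    # Ranks of three positive edges among their negative candidates.
    ranking = th.tensor([1.0, 2.0, 4.0])

    # mean(1/rank) = (1/1 + 1/2 + 1/4) / 3 = 0.5833...
    mrr = compute_mrr(ranking)  # tensor(0.5833)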