Skip to content

Commit

Permalink
revert list as argument
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Apr 27, 2024
1 parent 39d7d80 commit 515c8eb
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions python/graphstorm/eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,12 +462,15 @@ class GSgnnClassificationEvaluator(GSgnnBaseEvaluator, GSgnnPredictionEvalInterf
1) consecutive_increase and 2) average_increase.
"""
def __init__(self, eval_frequency,
eval_metric_list=["accuracy"],
eval_metric_list=None,
multilabel=False,
use_early_stop=False,
early_stop_burnin_rounds=0,
early_stop_rounds=3,
early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY): # pylint: disable=unused-argument
early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY):
# set default metric list
if eval_metric_list is None:
eval_metric_list = ["accuracy"]
super(GSgnnClassificationEvaluator, self).__init__(eval_frequency,
eval_metric_list,
use_early_stop,
Expand Down Expand Up @@ -595,11 +598,14 @@ class GSgnnRegressionEvaluator(GSgnnBaseEvaluator, GSgnnPredictionEvalInterface)
1) consecutive_increase and 2) average_increase.
"""
def __init__(self, eval_frequency,
eval_metric_list=["rmse"],
eval_metric_list=None,
use_early_stop=False,
early_stop_burnin_rounds=0,
early_stop_rounds=3,
early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY):
# set default metric list
if eval_metric_list is None:
eval_metric_list = ["rmse"]
super(GSgnnRegressionEvaluator, self).__init__(eval_frequency,
eval_metric_list, use_early_stop, early_stop_burnin_rounds,
early_stop_rounds, early_stop_strategy)
Expand Down Expand Up @@ -729,11 +735,14 @@ class GSgnnMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface):
1) consecutive_increase and 2) average_increase.
"""
def __init__(self, eval_frequency,
eval_metric_list=["mrr"],
eval_metric_list=None,
use_early_stop=False,
early_stop_burnin_rounds=0,
early_stop_rounds=3,
early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY):
# set default metric list
if eval_metric_list is None:
eval_metric_list = ["mrr"]
super(GSgnnMrrLPEvaluator, self).__init__(eval_frequency,
eval_metric_list, use_early_stop, early_stop_burnin_rounds,
early_stop_rounds, early_stop_strategy)
Expand Down Expand Up @@ -861,12 +870,15 @@ class GSgnnPerEtypeMrrLPEvaluator(GSgnnBaseEvaluator, GSgnnLPRankingEvalInterfac
1) consecutive_increase and 2) average_increase.
"""
def __init__(self, eval_frequency,
eval_metric_list=["mrr"],
eval_metric_list=None,
major_etype = LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL,
use_early_stop=False,
early_stop_burnin_rounds=0,
early_stop_rounds=3,
early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY):
# set default metric list
if eval_metric_list is None:
eval_metric_list = ["mrr"]
super(GSgnnPerEtypeMrrLPEvaluator, self).__init__(eval_frequency,
eval_metric_list,
use_early_stop,
Expand Down

0 comments on commit 515c8eb

Please sign in to comment.