From 977741807bf1e87d69b2300ce4cd641ecabc1a39 Mon Sep 17 00:00:00 2001 From: rvandewater Date: Tue, 15 Oct 2024 09:52:56 +0200 Subject: [PATCH] revert to log loss --- icu_benchmarks/models/wrappers.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/icu_benchmarks/models/wrappers.py b/icu_benchmarks/models/wrappers.py index cdfa7f7c..a9dc9caa 100644 --- a/icu_benchmarks/models/wrappers.py +++ b/icu_benchmarks/models/wrappers.py @@ -366,16 +366,17 @@ class MLWrapper(BaseModule, ABC): requires_backprop = False _supported_run_modes = [RunMode.classification, RunMode.regression] - def __init__(self, *args, run_mode=RunMode.classification, loss=average_precision_score, patience=10, mps=False, **kwargs): + def __init__(self, *args, run_mode=RunMode.classification, loss=log_loss, patience=10, mps=False, **kwargs): super().__init__() self.save_hyperparameters() self.scaler = None self.check_supported_runmode(run_mode) self.run_mode = run_mode - if loss.__name__ in ["average_precision_score", "roc_auc_score"]: - self.loss = scorer_wrapper(average_precision_score) - else: - self.loss = self.loss + # if loss.__name__ in ["average_precision_score", "roc_auc_score"]: + # self.loss = scorer_wrapper(average_precision_score) + # else: + # self.loss = self.loss + self.loss = loss self.patience = patience self.mps = mps self.loss_weight = None