Skip to content

Commit

Permalink
revert to log loss
Browse files Browse the repository at this point in the history
  • Loading branch information
rvandewater committed Oct 15, 2024
1 parent b307474 commit 9777418
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions icu_benchmarks/models/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,16 +366,17 @@ class MLWrapper(BaseModule, ABC):
requires_backprop = False
_supported_run_modes = [RunMode.classification, RunMode.regression]

def __init__(self, *args, run_mode=RunMode.classification, loss=average_precision_score, patience=10, mps=False, **kwargs):
def __init__(self, *args, run_mode=RunMode.classification, loss=log_loss, patience=10, mps=False, **kwargs):
super().__init__()
self.save_hyperparameters()
self.scaler = None
self.check_supported_runmode(run_mode)
self.run_mode = run_mode
if loss.__name__ in ["average_precision_score", "roc_auc_score"]:
self.loss = scorer_wrapper(average_precision_score)
else:
self.loss = self.loss
# if loss.__name__ in ["average_precision_score", "roc_auc_score"]:
# self.loss = scorer_wrapper(average_precision_score)
# else:
# self.loss = self.loss
self.loss = loss
self.patience = patience
self.mps = mps
self.loss_weight = None
Expand Down

0 comments on commit 9777418

Please sign in to comment.