From e2678c4b3a3e57514f7acd8d618797d3cb260876 Mon Sep 17 00:00:00 2001
From: GitHub Action <52708150+marcpinet@users.noreply.github.com>
Date: Tue, 3 Dec 2024 18:52:53 +0100
Subject: [PATCH] fix(optimizer-losses): generic from_config

Replace the hand-written name dispatch in LossFunction.from_config and
Optimizer.from_config with a generic lookup over __subclasses__(), so
new losses and optimizers deserialize without these methods needing
updates. LossFunction.from_name now resolves aliases through a table,
and HuberLoss is renamed to Huber.

---
 neuralnetlib/losses.py     | 94 ++++++++++++++++++--------------------
 neuralnetlib/optimizers.py | 18 ++++----
 2 files changed, 54 insertions(+), 58 deletions(-)

diff --git a/neuralnetlib/losses.py b/neuralnetlib/losses.py
index 24f7e50..dbb7dcb 100644
--- a/neuralnetlib/losses.py
+++ b/neuralnetlib/losses.py
@@ -16,57 +16,55 @@ def get_config(self) -> dict:
 
     @staticmethod
     def from_config(config: dict) -> 'LossFunction':
-        if config['name'] == 'MeanSquaredError':
-            return MeanSquaredError()
-        elif config['name'] == 'BinaryCrossentropy':
-            return BinaryCrossentropy()
-        elif config['name'] == 'CategoricalCrossentropy':
-            return CategoricalCrossentropy()
-        elif config['name'] == 'MeanAbsoluteError':
-            return MeanAbsoluteError()
-        elif config['name'] == 'HuberLoss':
-            return HuberLoss(config['delta'])
-        elif config['name'] == 'KullbackLeiblerDivergence':
-            return KullbackLeiblerDivergence()
-        elif config['name'] == 'CrossEntropyWithLabelSmoothing':
-            return CrossEntropyWithLabelSmoothing(config['label_smoothing'])
-        elif config['name'] == 'Wasserstein':
-            return Wasserstein()
-        elif config['name'] == 'FocalLoss':
-            return FocalLoss(config['gamma'], config['alpha'])
-        else:
-            raise ValueError(f'Unknown loss function: {config["name"]}')
+        loss_name = config['name']
+
+        for loss_class in LossFunction.__subclasses__():
+            if loss_class.__name__ == loss_name:
+                constructor_params = {k: v for k, v in config.items() if k != 'name'}
+                return loss_class(**constructor_params)
+
+        raise ValueError(f'Unknown loss function: {loss_name}')
 
     @staticmethod
     def from_name(name: str) -> "LossFunction":
+        aliases = {
+            "mse": "MeanSquaredError",
+            "bce": "BinaryCrossentropy",
+            "cce": "CategoricalCrossentropy",
+            "scce": "SparseCategoricalCrossentropy",
+            "mae": "MeanAbsoluteError",
+            "kld": "KullbackLeiblerDivergence",
+            "cels": "CrossEntropyWithLabelSmoothing",
+            "wass": "Wasserstein",
+            "focal": "FocalLoss",
+            "fl": "FocalLoss"
+        }
+
+        original_name = name
         name = name.lower().replace("_", "")
-        if name == "mse" or name == "meansquarederror":
-            return MeanSquaredError()
-        elif name == "bce" or name == "binarycrossentropy":
-            return BinaryCrossentropy()
-        elif name == "cce" or name == "categorycrossentropy":
-            return CategoricalCrossentropy()
-        elif name == "scce" or name == "sparsecategoricalcrossentropy":
-            return SparseCategoricalCrossentropy()
-        elif name == "mae" or name == "meanabsoluteerror":
-            return MeanAbsoluteError()
-        elif name == "kld" or name == "kullbackleiblerdivergence":
-            return KullbackLeiblerDivergence()
-        elif name == "crossentropywithlabelsmoothing" or name == "cels":
-            return CrossEntropyWithLabelSmoothing()
-        elif name == "Wasserstein" or name == "wasserstein" or name == "wass":
-            return Wasserstein()
-        elif name == "focalloss" or name == "focal" or name == "fl":
-            return FocalLoss()
-        elif name.startswith("huber") and len(name.split("_")) == 2:
-            delta = float(name.split("_")[-1])
-            return HuberLoss(delta)
-        else:
-            for subclass in LossFunction.__subclasses__():
-                if subclass.__name__.lower() == name:
-                    return subclass()
-
-            raise ValueError(f"No loss function found for the name: {name}")
+
+        if name.startswith("huber") and len(original_name.split("_")) == 2:
+            try:
+                delta = float(original_name.split("_")[-1])
+                return Huber(delta=delta)
+            except ValueError:
+                pass
+
+        if name in aliases:
+            name = aliases[name]
+
+        for loss_class in LossFunction.__subclasses__():
+            if loss_class.__name__.lower() == name or loss_class.__name__ == name:
+                if loss_class.__name__ == "Huber":
+                    return loss_class(delta=1.0)
+                elif loss_class.__name__ == "CrossEntropyWithLabelSmoothing":
+                    return loss_class(label_smoothing=0.1)
+                elif loss_class.__name__ == "FocalLoss":
+                    return loss_class(gamma=2.0, alpha=0.25)
+                else:
+                    return loss_class()
+
+        raise ValueError(f"No loss function found for the name: {original_name}")
 
 
 class MeanSquaredError(LossFunction):
@@ -142,7 +140,7 @@ def __str__(self):
         return "MeanAbsoluteError"
 
 
-class HuberLoss(LossFunction):
+class Huber(LossFunction):
     def __init__(self, delta: float = 1.0):
         super().__init__()
         self.delta = delta
diff --git a/neuralnetlib/optimizers.py b/neuralnetlib/optimizers.py
index a941c83..07ca802 100644
--- a/neuralnetlib/optimizers.py
+++ b/neuralnetlib/optimizers.py
@@ -16,16 +16,14 @@ def get_config(self) -> dict:
 
     @staticmethod
     def from_config(config: dict):
-        if config['name'] == 'SGD':
-            return SGD.from_config(config)
-        elif config['name'] == 'Momentum':
-            return Momentum.from_config(config)
-        elif config['name'] == 'RMSprop':
-            return RMSprop.from_config(config)
-        elif config['name'] == 'Adam':
-            return Adam.from_config(config)
-        else:
-            raise ValueError(f"Unknown optimizer name: {config['name']}")
+        optimizer_name = config['name']
+
+        for optimizer_class in Optimizer.__subclasses__():
+            if optimizer_class.__name__ == optimizer_name:
+                constructor_params = {k: v for k, v in config.items() if k != 'name'}
+                return optimizer_class(**constructor_params)
+
+        raise ValueError(f"No optimizer found for the name: {optimizer_name}")
 
     @staticmethod
     def from_name(name: str) -> "Optimizer":
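
For reference, a minimal round-trip sketch (not part of the patch) of the generic lookup introduced above. It assumes the classes are importable from neuralnetlib.losses as the diff shows, and that get_config() reports the class name under the 'name' key alongside the constructor arguments (e.g. 'delta' for Huber), which is what the generic from_config requires:

    from neuralnetlib.losses import LossFunction, Huber, MeanSquaredError

    loss = Huber(delta=0.5)
    config = loss.get_config()  # assumed shape: {'name': 'Huber', 'delta': 0.5}

    # from_config scans LossFunction.__subclasses__() for a class whose
    # __name__ matches config['name'] and forwards the remaining keys as
    # keyword arguments to its constructor.
    restored = LossFunction.from_config(config)
    assert isinstance(restored, Huber)
    assert restored.delta == 0.5

    # from_name resolves aliases case-insensitively and parses the
    # huber_<delta> shorthand.
    assert isinstance(LossFunction.from_name("mse"), MeanSquaredError)
    assert LossFunction.from_name("huber_0.5").delta == 0.5

Note that __subclasses__() only returns direct subclasses that have already been imported, so a loss defined outside losses.py must be imported before from_config can find it, and configs saved before this patch under the old name 'HuberLoss' will no longer resolve.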