Skip to content

Commit

Permalink
fix(optimizer-losses): generic from_config
Browse files Browse the repository at this point in the history
  • Loading branch information
marcpinet committed Dec 3, 2024
1 parent d3dfa1d commit e2678c4
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 58 deletions.
92 changes: 44 additions & 48 deletions neuralnetlib/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,57 +16,53 @@ def get_config(self) -> dict:

@staticmethod
def from_config(config: dict) -> 'LossFunction':
if config['name'] == 'MeanSquaredError':
return MeanSquaredError()
elif config['name'] == 'BinaryCrossentropy':
return BinaryCrossentropy()
elif config['name'] == 'CategoricalCrossentropy':
return CategoricalCrossentropy()
elif config['name'] == 'MeanAbsoluteError':
return MeanAbsoluteError()
elif config['name'] == 'HuberLoss':
return HuberLoss(config['delta'])
elif config['name'] == 'KullbackLeiblerDivergence':
return KullbackLeiblerDivergence()
elif config['name'] == 'CrossEntropyWithLabelSmoothing':
return CrossEntropyWithLabelSmoothing(config['label_smoothing'])
elif config['name'] == 'Wasserstein':
return Wasserstein()
elif config['name'] == 'FocalLoss':
return FocalLoss(config['gamma'], config['alpha'])
else:
raise ValueError(f'Unknown loss function: {config["name"]}')
loss_name = config['name']

for loss_class in LossFunction.__subclasses__():
if loss_class.__name__ == loss_name:
constructor_params = {k: v for k, v in config.items() if k != 'name'}
return loss_class(**constructor_params)

@staticmethod
def from_name(name: str) -> "LossFunction":
aliases = {
"mse": "MeanSquaredError",
"bce": "BinaryCrossentropy",
"cce": "CategoricalCrossentropy",
"scce": "SparseCategoricalCrossentropy",
"mae": "MeanAbsoluteError",
"kld": "KullbackLeiblerDivergence",
"cels": "CrossEntropyWithLabelSmoothing",
"wass": "Wasserstein",
"focal": "FocalLoss",
"fl": "FocalLoss"
}

original_name = name
name = name.lower().replace("_", "")
if name == "mse" or name == "meansquarederror":
return MeanSquaredError()
elif name == "bce" or name == "binarycrossentropy":
return BinaryCrossentropy()
elif name == "cce" or name == "categorycrossentropy":
return CategoricalCrossentropy()
elif name == "scce" or name == "sparsecategoricalcrossentropy":
return SparseCategoricalCrossentropy()
elif name == "mae" or name == "meanabsoluteerror":
return MeanAbsoluteError()
elif name == "kld" or name == "kullbackleiblerdivergence":
return KullbackLeiblerDivergence()
elif name == "crossentropywithlabelsmoothing" or name == "cels":
return CrossEntropyWithLabelSmoothing()
elif name == "Wasserstein" or name == "wasserstein" or name == "wass":
return Wasserstein()
elif name == "focalloss" or name == "focal" or name == "fl":
return FocalLoss()
elif name.startswith("huber") and len(name.split("_")) == 2:
delta = float(name.split("_")[-1])
return HuberLoss(delta)
else:
for subclass in LossFunction.__subclasses__():
if subclass.__name__.lower() == name:
return subclass()

raise ValueError(f"No loss function found for the name: {name}")

if name.startswith("huber") and len(original_name.split("_")) == 2:
try:
delta = float(original_name.split("_")[-1])
return Huber(delta=delta)
except ValueError:
pass

if name in aliases:
name = aliases[name]

for loss_class in LossFunction.__subclasses__():
if loss_class.__name__.lower() == name or loss_class.__name__ == name:
if loss_class.__name__ == "Huber":
return loss_class(delta=1.0)
elif loss_class.__name__ == "CrossEntropyWithLabelSmoothing":
return loss_class(label_smoothing=0.1)
elif loss_class.__name__ == "FocalLoss":
return loss_class(gamma=2.0, alpha=0.25)
else:
return loss_class()

raise ValueError(f"No loss function found for the name: {original_name}")


class MeanSquaredError(LossFunction):
Expand Down Expand Up @@ -142,7 +138,7 @@ def __str__(self):
return "MeanAbsoluteError"


class HuberLoss(LossFunction):
class Huber(LossFunction):
def __init__(self, delta: float = 1.0):
super().__init__()
self.delta = delta
Expand Down
18 changes: 8 additions & 10 deletions neuralnetlib/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@ def get_config(self) -> dict:

@staticmethod
def from_config(config: dict):
if config['name'] == 'SGD':
return SGD.from_config(config)
elif config['name'] == 'Momentum':
return Momentum.from_config(config)
elif config['name'] == 'RMSprop':
return RMSprop.from_config(config)
elif config['name'] == 'Adam':
return Adam.from_config(config)
else:
raise ValueError(f"Unknown optimizer name: {config['name']}")
optimizer_name = config['name']

for optimizer_class in Optimizer.__subclasses__():
if optimizer_class.__name__ == optimizer_name:
constructor_params = {k: v for k, v in config.items() if k != 'name'}
return optimizer_class(**constructor_params)

raise ValueError(f"No optimizer found for the name: {optimizer_name}")

@staticmethod
def from_name(name: str) -> "Optimizer":
Expand Down

0 comments on commit e2678c4

Please sign in to comment.