fix(optimizers-losses-activations): from_config
marcpinet committed Dec 3, 2024
1 parent d3dfa1d commit 90a9194
Showing 4 changed files with 61 additions and 80 deletions.
27 changes: 7 additions & 20 deletions neuralnetlib/activations.py
@@ -17,26 +17,13 @@ def get_config(self) -> dict:
    @staticmethod
    def from_config(config: dict):
        name = config['name']
-        if name == 'Sigmoid':
-            return Sigmoid()
-        elif name == 'ReLU':
-            return ReLU()
-        elif name == 'Tanh':
-            return Tanh()
-        elif name == 'Softmax':
-            return Softmax()
-        elif name == 'Linear':
-            return Linear()
-        elif name == 'LeakyReLU':
-            return LeakyReLU(alpha=config['alpha'])
-        elif name == 'ELU':
-            return ELU()
-        elif name == 'SELU':
-            return SELU(alpha=config['alpha'], scale=config['scale'])
-        elif name == 'GELU':
-            return GELU()
-        else:
-            raise ValueError(f'Unknown activation function: {name}')
+        for activation_class in ActivationFunction.__subclasses__():
+            if activation_class.__name__ == name:
+                constructor_params = {k: v for k, v in config.items() if k != 'name'}
+                return activation_class(**constructor_params)
+
+        raise ValueError(f'Unknown activation function: {name}')


class Sigmoid(ActivationFunction):
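
The rewritten lookup leans on Python's ActivationFunction.__subclasses__() hook: any subclass is found by class name and rebuilt from the remaining config keys, so a new activation deserializes without touching from_config. A minimal standalone sketch of the pattern (the get_config shape and the LeakyReLU signature here mirror the diff, but the snippet is an illustration, not the repo's full code):

    class ActivationFunction:
        def get_config(self) -> dict:
            # Serialize the class name plus the instance's constructor
            # parameters (mirrors the pattern used in the diff).
            return {'name': self.__class__.__name__, **self.__dict__}

        @staticmethod
        def from_config(config: dict):
            name = config['name']
            for activation_class in ActivationFunction.__subclasses__():
                if activation_class.__name__ == name:
                    # Every key except 'name' is forwarded to the constructor.
                    constructor_params = {k: v for k, v in config.items() if k != 'name'}
                    return activation_class(**constructor_params)
            raise ValueError(f'Unknown activation function: {name}')


    class LeakyReLU(ActivationFunction):
        def __init__(self, alpha: float = 0.01):
            self.alpha = alpha


    # Round trip: instance -> config -> equivalent instance.
    config = LeakyReLU(alpha=0.2).get_config()   # {'name': 'LeakyReLU', 'alpha': 0.2}
    restored = ActivationFunction.from_config(config)
    assert isinstance(restored, LeakyReLU) and restored.alpha == 0.2

One caveat of this design: __subclasses__() returns only direct subclasses, so an activation that inherits from an intermediate base class would not be found by the scan.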
92 changes: 44 additions & 48 deletions neuralnetlib/losses.py
@@ -16,57 +16,53 @@ def get_config(self) -> dict:

    @staticmethod
    def from_config(config: dict) -> 'LossFunction':
-        if config['name'] == 'MeanSquaredError':
-            return MeanSquaredError()
-        elif config['name'] == 'BinaryCrossentropy':
-            return BinaryCrossentropy()
-        elif config['name'] == 'CategoricalCrossentropy':
-            return CategoricalCrossentropy()
-        elif config['name'] == 'MeanAbsoluteError':
-            return MeanAbsoluteError()
-        elif config['name'] == 'HuberLoss':
-            return HuberLoss(config['delta'])
-        elif config['name'] == 'KullbackLeiblerDivergence':
-            return KullbackLeiblerDivergence()
-        elif config['name'] == 'CrossEntropyWithLabelSmoothing':
-            return CrossEntropyWithLabelSmoothing(config['label_smoothing'])
-        elif config['name'] == 'Wasserstein':
-            return Wasserstein()
-        elif config['name'] == 'FocalLoss':
-            return FocalLoss(config['gamma'], config['alpha'])
-        else:
-            raise ValueError(f'Unknown loss function: {config["name"]}')
+        loss_name = config['name']
+
+        for loss_class in LossFunction.__subclasses__():
+            if loss_class.__name__ == loss_name:
+                constructor_params = {k: v for k, v in config.items() if k != 'name'}
+                return loss_class(**constructor_params)

    @staticmethod
    def from_name(name: str) -> "LossFunction":
+        aliases = {
+            "mse": "MeanSquaredError",
+            "bce": "BinaryCrossentropy",
+            "cce": "CategoricalCrossentropy",
+            "scce": "SparseCategoricalCrossentropy",
+            "mae": "MeanAbsoluteError",
+            "kld": "KullbackLeiblerDivergence",
+            "cels": "CrossEntropyWithLabelSmoothing",
+            "wass": "Wasserstein",
+            "focal": "FocalLoss",
+            "fl": "FocalLoss"
+        }
+
+        original_name = name
        name = name.lower().replace("_", "")
-        if name == "mse" or name == "meansquarederror":
-            return MeanSquaredError()
-        elif name == "bce" or name == "binarycrossentropy":
-            return BinaryCrossentropy()
-        elif name == "cce" or name == "categorycrossentropy":
-            return CategoricalCrossentropy()
-        elif name == "scce" or name == "sparsecategoricalcrossentropy":
-            return SparseCategoricalCrossentropy()
-        elif name == "mae" or name == "meanabsoluteerror":
-            return MeanAbsoluteError()
-        elif name == "kld" or name == "kullbackleiblerdivergence":
-            return KullbackLeiblerDivergence()
-        elif name == "crossentropywithlabelsmoothing" or name == "cels":
-            return CrossEntropyWithLabelSmoothing()
-        elif name == "Wasserstein" or name == "wasserstein" or name == "wass":
-            return Wasserstein()
-        elif name == "focalloss" or name == "focal" or name == "fl":
-            return FocalLoss()
-        elif name.startswith("huber") and len(name.split("_")) == 2:
-            delta = float(name.split("_")[-1])
-            return HuberLoss(delta)
-        else:
-            for subclass in LossFunction.__subclasses__():
-                if subclass.__name__.lower() == name:
-                    return subclass()
-
-            raise ValueError(f"No loss function found for the name: {name}")
+
+        if name.startswith("huber") and len(original_name.split("_")) == 2:
+            try:
+                delta = float(original_name.split("_")[-1])
+                return Huber(delta=delta)
+            except ValueError:
+                pass
+
+        if name in aliases:
+            name = aliases[name]
+
+        for loss_class in LossFunction.__subclasses__():
+            if loss_class.__name__.lower() == name or loss_class.__name__ == name:
+                if loss_class.__name__ == "Huber":
+                    return loss_class(delta=1.0)
+                elif loss_class.__name__ == "CrossEntropyWithLabelSmoothing":
+                    return loss_class(label_smoothing=0.1)
+                elif loss_class.__name__ == "FocalLoss":
+                    return loss_class(gamma=2.0, alpha=0.25)
+                else:
+                    return loss_class()
+
+        raise ValueError(f"No loss function found for the name: {original_name}")


class MeanSquaredError(LossFunction):
@@ -142,7 +138,7 @@ def __str__(self):
        return "MeanAbsoluteError"


-class HuberLoss(LossFunction):
+class Huber(LossFunction):
    def __init__(self, delta: float = 1.0):
        super().__init__()
        self.delta = delta
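
Combined with the alias table, from_name now accepts short aliases, class names in any case, and a huber_<delta> suffix form that parses the delta out of the original (pre-lowercased) string. Based on the logic shown above, usage would look like this (a sketch inferred from the diff, not taken from the repo's docs):

    from neuralnetlib.losses import LossFunction, Huber, MeanSquaredError, FocalLoss

    loss = LossFunction.from_name("mse")          # alias -> MeanSquaredError()
    assert isinstance(loss, MeanSquaredError)

    loss = LossFunction.from_name("FocalLoss")    # class name -> FocalLoss(gamma=2.0, alpha=0.25)
    assert isinstance(loss, FocalLoss)

    loss = LossFunction.from_name("huber_0.5")    # parsed suffix -> Huber(delta=0.5)
    assert isinstance(loss, Huber) and loss.delta == 0.5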
18 changes: 8 additions & 10 deletions neuralnetlib/optimizers.py
@@ -16,16 +16,14 @@ def get_config(self) -> dict:

    @staticmethod
    def from_config(config: dict):
-        if config['name'] == 'SGD':
-            return SGD.from_config(config)
-        elif config['name'] == 'Momentum':
-            return Momentum.from_config(config)
-        elif config['name'] == 'RMSprop':
-            return RMSprop.from_config(config)
-        elif config['name'] == 'Adam':
-            return Adam.from_config(config)
-        else:
-            raise ValueError(f"Unknown optimizer name: {config['name']}")
+        optimizer_name = config['name']
+
+        for optimizer_class in Optimizer.__subclasses__():
+            if optimizer_class.__name__ == optimizer_name:
+                constructor_params = {k: v for k, v in config.items() if k != 'name'}
+                return optimizer_class(**constructor_params)
+
+        raise ValueError(f"No optimizer found for the name: {optimizer_name}")

    @staticmethod
    def from_name(name: str) -> "Optimizer":
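
The same subclass scan replaces the per-class SGD.from_config / Adam.from_config delegation. It only works when get_config emits keys that match each constructor's parameter names exactly, since everything except 'name' is splatted into the constructor. A hypothetical sketch (the SGD signature below is assumed for illustration, not read from the repo):

    class Optimizer:
        @staticmethod
        def from_config(config: dict):
            optimizer_name = config['name']
            for optimizer_class in Optimizer.__subclasses__():
                if optimizer_class.__name__ == optimizer_name:
                    constructor_params = {k: v for k, v in config.items() if k != 'name'}
                    return optimizer_class(**constructor_params)
            raise ValueError(f"No optimizer found for the name: {optimizer_name}")


    class SGD(Optimizer):
        # Hypothetical signature for illustration; the config keys must
        # match these parameter names for **constructor_params to work.
        def __init__(self, learning_rate: float = 0.01):
            self.learning_rate = learning_rate


    sgd = Optimizer.from_config({'name': 'SGD', 'learning_rate': 0.1})
    assert isinstance(sgd, SGD) and sgd.learning_rate == 0.1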
4 changes: 2 additions & 2 deletions tests/test_losses.py
@@ -3,7 +3,7 @@
import numpy as np

from neuralnetlib.losses import MeanSquaredError, BinaryCrossentropy, CategoricalCrossentropy, MeanAbsoluteError, \
-    HuberLoss
+    Huber


class TestLossFunctions(unittest.TestCase):

@@ -41,7 +41,7 @@ def test_mean_absolute_error(self):
        self.assertAlmostEqual(calculated_loss, expected_loss)

    def test_huber_loss(self):
-        huber = HuberLoss(delta=1.0)
+        huber = Huber(delta=1.0)
        y_true = np.array([1, 2, 3])
        y_pred = np.array([1, 2, 4])
        error = y_true - y_pred
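
For reference, the expected value in test_huber_loss can be worked out by hand, assuming the standard Huber definition with mean reduction (an assumption; the reduction used by the repo is not visible in this hunk):

    import numpy as np

    # Huber: 0.5 * e**2 where |e| <= delta, else delta * (|e| - 0.5 * delta)
    delta = 1.0
    error = np.array([1, 2, 3]) - np.array([1, 2, 4])    # [0, 0, -1]
    quadratic = 0.5 * error ** 2
    linear = delta * (np.abs(error) - 0.5 * delta)
    loss = np.where(np.abs(error) <= delta, quadratic, linear).mean()
    print(loss)    # 0.1666... = (0 + 0 + 0.5) / 3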
