
Commit

fix(dropout/batchnorm): save&load
marcpinet committed Dec 10, 2024
1 parent 47535a8 commit 200a5e3
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions neuralnetlib/layers.py
@@ -318,20 +318,23 @@ def get_config(self) -> dict:
             'name': self.__class__.__name__,
             'rate': self.rate,
             'adaptive': self.adaptive,
+            'min_rate': self.dropout_impl.min_rate if self.adaptive else 0.1,
+            'max_rate': self.dropout_impl.max_rate if self.adaptive else 0.9,
+            'temperature': self.dropout_impl.temperature if self.adaptive else 1.0,
             'random_state': self.random_state
         }
-
-        if self.adaptive:
-            config.update(self.dropout_impl.get_config())
-
         return config

     @staticmethod
     def from_config(config: dict):
-        adaptive = config.pop('adaptive', False)
-        if adaptive:
-            return Dropout(adaptive=True, **config)
-        return Dropout(**config)
+        return Dropout(
+            rate=config['rate'],
+            adaptive=config['adaptive'],
+            min_rate=config['min_rate'],
+            max_rate=config['max_rate'],
+            temperature=config['temperature'],
+            random_state=config['random_state']
+        )
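
What the Dropout change fixes: the old get_config only merged the adaptive-specific fields (min_rate, max_rate, temperature) into the config when the layer was adaptive, and the old from_config splatted the whole remaining config dict (including keys such as 'name') into the constructor, so a saved adaptive layer could not be restored reliably. The new code serializes all three fields unconditionally, falling back to defaults for non-adaptive layers, and selects the constructor arguments explicitly. A minimal round-trip sketch, assuming the import path neuralnetlib.layers and only the keyword arguments visible in from_config above:

    from neuralnetlib.layers import Dropout  # assumed import path

    layer = Dropout(rate=0.5, adaptive=True, min_rate=0.1,
                    max_rate=0.9, temperature=1.0, random_state=42)
    restored = Dropout.from_config(layer.get_config())
    assert restored.get_config() == layer.get_config()  # config now round-trips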


class Conv2D(Layer):
@@ -1275,15 +1278,26 @@ def get_config(self) -> dict:
             'gamma': self.gamma.tolist() if self.gamma is not None else None,
             'beta': self.beta.tolist() if self.beta is not None else None,
             'momentum': self.momentum,
-            'epsilon': self.epsilon
+            'epsilon': self.epsilon,
+            'running_mean': self.running_mean.tolist() if self.running_mean is not None else None,
+            'running_var': self.running_var.tolist() if self.running_var is not None else None,
+            'input_shape': self.gamma.shape if self.gamma is not None else None
         }

     @staticmethod
     def from_config(config: dict):
         layer = BatchNormalization(config['momentum'], config['epsilon'])

         if config['gamma'] is not None:
             layer.gamma = np.array(config['gamma'])
             layer.beta = np.array(config['beta'])
+
+            layer.running_mean = np.array(config['running_mean'])
+            layer.running_var = np.array(config['running_var'])
+
+            layer.d_gamma = np.zeros_like(layer.gamma)
+            layer.d_beta = np.zeros_like(layer.beta)

         return layer
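
The BatchNormalization change saves the running statistics alongside gamma and beta and restores them in from_config, so a reloaded layer normalizes with the statistics accumulated during training rather than with fresh ones; it also re-creates zeroed d_gamma/d_beta gradient buffers so training can resume. A minimal round-trip sketch, assuming the import path neuralnetlib.layers and plain NumPy arrays for the attributes, as in from_config above:

    import numpy as np
    from neuralnetlib.layers import BatchNormalization  # assumed import path

    layer = BatchNormalization(0.9, 1e-5)  # momentum, epsilon (positional, as in from_config)
    layer.gamma = np.ones(8)               # stand-ins for trained parameters
    layer.beta = np.zeros(8)
    layer.running_mean = np.full(8, 0.5)   # stand-ins for accumulated statistics
    layer.running_var = np.full(8, 2.0)

    restored = BatchNormalization.from_config(layer.get_config())
    assert np.allclose(restored.running_mean, layer.running_mean)  # stats survive save/load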


