Skip to content

Commit

Permalink
Fix bug caught in test
Browse files Browse the repository at this point in the history
  • Loading branch information
charlie-becker committed Feb 29, 2024
1 parent 2031415 commit 6e0d85a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
8 changes: 4 additions & 4 deletions mlguess/keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class EvidentialRegressorDNN(keras.models.Model):
def __init__(self, hidden_layers=2, hidden_neurons=64, activation="relu", optimizer="adam", loss_weights=None,
use_noise=False, noise_sd=0.01, lr=0.00001, use_dropout=False, dropout_alpha=0.1, batch_size=128,
epochs=2, kernel_reg=None, l1_weight=0.01, l2_weight=0.01, sgd_momentum=0.9, adam_beta_1=0.9,
adam_beta_2=0.999, epsilon=1e-7, verbose=1, training_var=None, **kwargs):
adam_beta_2=0.999, epsilon=1e-7, verbose=1, training_var=None, n_output_tasks=1, **kwargs):

super().__init__(**kwargs)
self.hidden_layers = hidden_layers
Expand All @@ -197,16 +197,16 @@ def __init__(self, hidden_layers=2, hidden_neurons=64, activation="relu", optimi
self.optimizer_obj = None
self.training_var = training_var
self.epsilon = epsilon
self.n_output_tasks = 1
self.n_output_tasks = n_output_tasks
self.N_OUTPUT_PARAMS = 4
self.hyperparameters = ["hidden_layers", "hidden_neurons", "activation", "training_var",
"optimizer", "sgd_momentum", "adam_beta_1", "adam_beta_2", "epsilon",
"loss_weights", "lr", "kernel_reg", "l1_weight", "l2_weight",
"batch_size", "use_noise", "noise_sd", "use_dropout", "dropout_alpha", "epochs",
"verbose", "n_output_tasks", "epsilon"]

if self.activation == "leaky":
self.activation = LeakyReLU()
# if self.activation == "leaky":
# self.activation = LeakyReLU()
if self.kernel_reg == "l1":
self.kernel_reg = L1(self.l1_weight)
elif self.kernel_reg == "l2":
Expand Down
3 changes: 1 addition & 2 deletions mlguess/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def test_evidential_regression_model(self):
assert p_with_uncertainty.shape[-1] == 3
assert p_without_uncertainty.shape[-1] == 4
model.save("test_evi_regression.keras")
# load_model("test_evi_regression.keras")

load_model("test_evi_regression.keras")

if __name__ == "__main__":
unittest.main()

0 comments on commit 6e0d85a

Please sign in to comment.