Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
charlie-becker committed Feb 29, 2024
1 parent 3627d95 commit 3b360d3
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions mlguess/keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from collections import defaultdict
import logging


class BaseRegressor(object):
"""
A base class for regression models.
Expand Down Expand Up @@ -1390,7 +1389,7 @@ def __init__(
self.balanced_classes = balanced_classes
self.steps_per_epoch = steps_per_epoch
self.outputs = 4
self.current_epoch = keras.Variable(initializer=0, dtype='float32', trainable=False)
self.current_epoch = keras.Variable(initializer=20, dtype='float32', trainable=False)

"""
Create Keras neural network model and compile it.
Expand Down Expand Up @@ -1428,9 +1427,14 @@ def call(self, inputs):
for l in range(1, len(self.model_layers)):
layer_output = self.model_layers[l](layer_output)

self.current_epoch.assign_add(1)
return layer_output

# def fit(self, x, y, epochs):
#
# report_epoch_callback = ReportEpoch()
# self.fit(x, y, epochs=epochs)


def get_config(self):
base_config = super().get_config()
# parameter_config = {hp: getattr(self, hp) for hp in self.hyperparameters}
Expand Down Expand Up @@ -1584,7 +1588,7 @@ def __init__(
dropout_alpha=0.1,
batch_size=128,
epochs=2,
kernel_reg="l22345",
kernel_reg="l2",
l1_weight=0.01,
l2_weight=0.01,
sgd_momentum=0.9,
Expand Down

0 comments on commit 3b360d3

Please sign in to comment.