Skip to content

Commit

Permalink
add keras3 regression test
Browse files Browse the repository at this point in the history
  • Loading branch information
charlie-becker committed Feb 28, 2024
1 parent 751ee85 commit 3627d95
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions mlguess/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from mlguess.keras.models import BaseRegressor as RegressorDNN
from mlguess.keras.models import GaussianRegressorDNN
from mlguess.keras.models import EvidentialRegressorDNN
from mlguess.keras.models import CategoricalDNN_keras3
from mlguess.keras.losses import DirichletEvidentialLoss, EvidentialCatLoss
from mlguess.keras.models import CategoricalDNN_keras3, EvidentialRegressorDNN_keras3
from mlguess.keras.losses import DirichletEvidentialLoss, EvidentialCatLoss, EvidentialRegLoss
from keras.models import load_model

class TestModels(unittest.TestCase):
Expand Down Expand Up @@ -100,5 +100,18 @@ def test_evi_cat(self):
model.save("test_model2.keras")
load_model("test_model2.keras")

def test_evi_reg(self):

x_train = np.random.random(size=(10000, 10)).astype('float32')
y_train = np.random.random(size=(10000, 1)).astype('float32')
model = EvidentialRegressorDNN_keras3(hidden_layers=2)
model.compile(loss=EvidentialRegLoss(0.01), optimizer="adam")
model.fit(x_train, y_train)
model.save("test_model3.keras")
load_model("test_model3.keras")




if __name__ == "__main__":
unittest.main()

0 comments on commit 3627d95

Please sign in to comment.