diff --git a/gordo_components/model/factories/feedforward_autoencoder.py b/gordo_components/model/factories/feedforward_autoencoder.py index 660535b79..8d384b840 100644 --- a/gordo_components/model/factories/feedforward_autoencoder.py +++ b/gordo_components/model/factories/feedforward_autoencoder.py @@ -70,3 +70,32 @@ def feedforward_model( model.compile(optimizer="adam", loss="mean_squared_error", metrics=["accuracy"]) return model + + +@register_model_builder(type="KerasAutoEncoder") +@register_model_builder(type="KerasBaseEstimator") +def feedforward_symmetric( + n_features: int, dims: List[int], funcs: List[str], **kwargs +) -> keras.models.Sequential: + """ + Builds a symmetrical feedforward model + + Parameters: + ---------- + n_features: int + Number of input and output neurons + dim: List[int] + Number of neurons per layers for the encoder, reversed for the decoder. + Must have len > 0 + funcs: List[str] + Activation functions for the internal layers + + Returns: + ------- + keras.models.Sequential + + """ + if len(dims) == 0: + raise ValueError("Parameter dims must have len > 0") + return feedforward_model(n_features, dims, dims[::-1], funcs, funcs[::-1], **kwargs) + diff --git a/gordo_components/serializer/serializer.py b/gordo_components/serializer/serializer.py index fa07a0315..f383bd213 100644 --- a/gordo_components/serializer/serializer.py +++ b/gordo_components/serializer/serializer.py @@ -35,7 +35,7 @@ def dumps(model: GordoBase) -> bytes: >>> from gordo_components.model.models import KerasAutoEncoder >>> from gordo_components import serializer - >>> model = KerasAutoEncoder('feedforward_model') + >>> model = KerasAutoEncoder('feedforward_symmetric') >>> serialized = serializer.dumps(model) >>> assert isinstance(serialized, bytes) >>> model_clone = serializer.loads(serialized) diff --git a/tests/test_feedforward_autoencoder.py b/tests/test_feedforward_autoencoder.py new file mode 100644 index 000000000..5a8f6fc4d --- /dev/null +++ b/tests/test_feedforward_autoencoder.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +import unittest +from unittest import mock + +from gordo_components.model.factories.feedforward_autoencoder import ( + feedforward_symmetric, +) + + +def feedforward_model_mocker(n_features: int, enc_dim, dec_dim, enc_func, dec_func): + return n_features, enc_dim, dec_dim, enc_func, dec_func + + +class FeedForwardAutoEncoderTestCase(unittest.TestCase): + @mock.patch( + "gordo_components.model.factories.feedforward_autoencoder.feedforward_model", + side_effect=feedforward_model_mocker, + ) + def test_feedforward_symmetric_basic(self, _): + """ + Test that feedforward_symmetric calls feedforward_model correctly + """ + n_features, enc_dim, dec_dim, enc_func, dec_func = feedforward_symmetric( + 5, [4, 3, 2, 1], ["relu", "relu", "tanh", "tanh"] + ) + self.assertEqual(n_features, 5) + self.assertEqual(enc_dim, [4, 3, 2, 1]) + self.assertEqual(dec_dim, [1, 2, 3, 4]) + self.assertEqual(enc_func, ["relu", "relu", "tanh", "tanh"]) + self.assertEqual(dec_func, ["tanh", "tanh", "relu", "relu"]) + + def test_feedforward_symmetric_checks_dims(self): + """ + Test that feedforward_symmetric validates parameter requirements + """ + with self.assertRaises(ValueError): + feedforward_symmetric(4, [], [])