Skip to content

Commit

Permalink
Add model constructor feedforward_symmetric
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Parmann authored and epa095 committed Feb 25, 2019
1 parent d89fbff commit 9001ad5
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 1 deletion.
29 changes: 29 additions & 0 deletions gordo_components/model/factories/feedforward_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,32 @@ def feedforward_model(

model.compile(optimizer="adam", loss="mean_squared_error", metrics=["accuracy"])
return model


@register_model_builder(type="KerasAutoEncoder")
@register_model_builder(type="KerasBaseEstimator")
def feedforward_symmetric(
n_features: int, dims: List[int], funcs: List[str], **kwargs
) -> keras.models.Sequential:
"""
Builds a symmetrical feedforward model
Parameters:
----------
n_features: int
Number of input and output neurons
dim: List[int]
Number of neurons per layers for the encoder, reversed for the decoder.
Must have len > 0
funcs: List[str]
Activation functions for the internal layers
Returns:
-------
keras.models.Sequential
"""
if len(dims) == 0:
raise ValueError("Parameter dims must have len > 0")
return feedforward_model(n_features, dims, dims[::-1], funcs, funcs[::-1], **kwargs)

2 changes: 1 addition & 1 deletion gordo_components/serializer/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def dumps(model: GordoBase) -> bytes:
>>> from gordo_components.model.models import KerasAutoEncoder
>>> from gordo_components import serializer
>>> model = KerasAutoEncoder('feedforward_model')
>>> model = KerasAutoEncoder('feedforward_symmetric')
>>> serialized = serializer.dumps(model)
>>> assert isinstance(serialized, bytes)
>>> model_clone = serializer.loads(serialized)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_feedforward_autoencoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-

import unittest
from unittest import mock

from gordo_components.model.factories.feedforward_autoencoder import (
feedforward_symmetric,
)


def feedforward_model_mocker(n_features: int, enc_dim, dec_dim, enc_func, dec_func):
return n_features, enc_dim, dec_dim, enc_func, dec_func


class FeedForwardAutoEncoderTestCase(unittest.TestCase):
@mock.patch(
"gordo_components.model.factories.feedforward_autoencoder.feedforward_model",
side_effect=feedforward_model_mocker,
)
def test_feedforward_symmetric_basic(self, _):
"""
Test that feedforward_symmetric calls feedforward_model correctly
"""
n_features, enc_dim, dec_dim, enc_func, dec_func = feedforward_symmetric(
5, [4, 3, 2, 1], ["relu", "relu", "tanh", "tanh"]
)
self.assertEqual(n_features, 5)
self.assertEqual(enc_dim, [4, 3, 2, 1])
self.assertEqual(dec_dim, [1, 2, 3, 4])
self.assertEqual(enc_func, ["relu", "relu", "tanh", "tanh"])
self.assertEqual(dec_func, ["tanh", "tanh", "relu", "relu"])

def test_feedforward_symmetric_checks_dims(self):
"""
Test that feedforward_symmetric validates parameter requirements
"""
with self.assertRaises(ValueError):
feedforward_symmetric(4, [], [])

0 comments on commit 9001ad5

Please sign in to comment.