diff --git a/gordo/machine/model/models.py b/gordo/machine/model/models.py index d68303d1f..41b491047 100644 --- a/gordo/machine/model/models.py +++ b/gordo/machine/model/models.py @@ -315,6 +315,10 @@ def get_params(self, **params): params = super().get_params(**params) params.update({"kind": self.kind}) params.update(self.kwargs) + if self.kwargs.get("callbacks") is not None and any( + isinstance(callback, dict) for callback in self.kwargs["callbacks"] + ): + params["callbacks"] = serializer.build_callbacks(self.kwargs["callbacks"]) return params def _prepare_model(self): diff --git a/gordo/serializer/__init__.py b/gordo/serializer/__init__.py index 6749d1f4b..fb5305d26 100644 --- a/gordo/serializer/__init__.py +++ b/gordo/serializer/__init__.py @@ -1,4 +1,8 @@ -from .from_definition import from_definition, load_params_from_definition +from .from_definition import ( + from_definition, + load_params_from_definition, + build_callbacks, +) from .into_definition import into_definition, load_definition_from_params from .serializer import ( dump, diff --git a/gordo/serializer/from_definition.py b/gordo/serializer/from_definition.py index 2db482d2d..8a0610c61 100644 --- a/gordo/serializer/from_definition.py +++ b/gordo/serializer/from_definition.py @@ -248,28 +248,6 @@ def _build_step( ) -def _build_callbacks(definitions: list): - """ - Parameters - ---------- - definitions - List of callbacks definitions - - Examples - -------- - >>> callbacks=_build_callbacks([{'tensorflow.keras.callbacks.EarlyStopping': {'monitor': 'val_loss,', 'patience': 10}}]) - >>> type(callbacks[0]) - - - Returns - ------- - """ - callbacks = [] - for callback in definitions: - callbacks.append(_build_step(callback)) - return callbacks - - def _load_param_classes(params: dict): """ Inspect the params' values and determine if any can be loaded as a class. @@ -350,7 +328,7 @@ def _load_param_classes(params: dict): kwargs = _load_param_classes(sub_params) params[key] = create_instance(Model, **kwargs) # type: ignore elif key == "callbacks" and isinstance(value, list): - params[key] = _build_callbacks(value) + params[key] = build_callbacks(value) return params @@ -367,3 +345,25 @@ def load_params_from_definition(definition: dict) -> dict: "Expected definition to be a dict," f"found: {type(definition)}" ) return _load_param_classes(definition) + + +def build_callbacks(definitions: list): + """ + Parameters + ---------- + definitions + List of callbacks definitions + + Examples + -------- + >>> callbacks=build_callbacks([{'tensorflow.keras.callbacks.EarlyStopping': {'monitor': 'val_loss,', 'patience': 10}}]) + >>> type(callbacks[0]) + + + Returns + ------- + """ + callbacks = [] + for callback in definitions: + callbacks.append(_build_step(callback)) + return callbacks diff --git a/tests/gordo/machine/model/test_model.py b/tests/gordo/machine/model/test_model.py index 2bcb96abe..82cc01555 100644 --- a/tests/gordo/machine/model/test_model.py +++ b/tests/gordo/machine/model/test_model.py @@ -357,6 +357,8 @@ def test_keras_autoencoder_fits_callbacks(): assert isinstance(first_callback, EarlyStopping) assert first_callback.monitor == "val_loss" assert first_callback.patience == 10 + X, y = np.random.rand(10, 10), np.random.rand(10, 10) + model.fit(X, y) def test_parse_module_path():