diff --git a/conftest.py b/conftest.py
index 5c27d947c13..d71de30a616 100644
--- a/conftest.py
+++ b/conftest.py
@@ -15,6 +15,9 @@
 import pytest  # noqa: E402
 
 from keras.src.backend import backend  # noqa: E402
+from keras.src.saving.object_registration import (  # noqa: E402
+    get_custom_objects,
+)
 
 
 def pytest_configure(config):
@@ -32,3 +35,9 @@ def pytest_collection_modifyitems(config, items):
     for item in items:
         if "requires_trainable_backend" in item.keywords:
             item.add_marker(requires_trainable_backend)
+
+
+# Ensure each test is run in isolation regarding the custom objects dict
+@pytest.fixture(autouse=True)
+def reset_custom_objects_global_dictionary(request):
+    get_custom_objects().clear()
diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py
index 29d062bcd8a..cd06e154616 100644
--- a/keras/src/models/model_test.py
+++ b/keras/src/models/model_test.py
@@ -1,5 +1,7 @@
 import pickle
+import sys
 
+import cloudpickle
 import numpy as np
 import pytest
 from absl.testing import parameterized
@@ -14,26 +16,6 @@
 from keras.src.saving.object_registration import register_keras_serializable
 
 
-@pytest.fixture
-def my_custom_dense():
-    @register_keras_serializable(package="MyLayers", name="CustomDense")
-    class CustomDense(layers.Layer):
-        def __init__(self, units, **kwargs):
-            super().__init__(**kwargs)
-            self.units = units
-            self.dense = layers.Dense(units)
-
-        def call(self, x):
-            return self.dense(x)
-
-        def get_config(self):
-            config = super().get_config()
-            config.update({"units": self.units})
-            return config
-
-    return CustomDense
-
-
 def _get_model():
     input_a = Input(shape=(3,), batch_size=2, name="input_a")
     input_b = Input(shape=(3,), batch_size=2, name="input_b")
@@ -89,11 +71,15 @@ def _get_model_multi_outputs_dict():
     return model
 
 
-def _get_model_custom_layer():
-    x = Input(shape=(3,), name="input_a")
-    output_a = my_custom_dense()(10, name="output_a")(x)
-    model = Model(x, output_a)
-    return model
+@pytest.fixture
+def fake_main_module(request, monkeypatch):
+    original_main = sys.modules["__main__"]
+
+    def restore_main_module():
+        sys.modules["__main__"] = original_main
+
+    request.addfinalizer(restore_main_module)
+    sys.modules["__main__"] = sys.modules[__name__]
 
 
 @pytest.mark.requires_trainable_backend
@@ -155,7 +141,6 @@ def call(self, x):
         ("single_list_output_2", _get_model_single_output_list),
         ("single_list_output_3", _get_model_single_output_list),
         ("single_list_output_4", _get_model_single_output_list),
-        ("custom_layer", _get_model_custom_layer),
     )
     def test_functional_pickling(self, model_fn):
         model = model_fn()
@@ -170,6 +155,45 @@ def test_functional_pickling(self, model_fn):
 
         self.assertAllClose(np.array(pred_reloaded), np.array(pred))
 
+    # Fake the __main__ module because cloudpickle only serializes
+    # functions & classes if they are defined in the __main__ module.
+    @pytest.mark.usefixtures("fake_main_module")
+    def test_functional_pickling_custom_layer(self):
+        @register_keras_serializable()
+        class CustomDense(layers.Layer):
+            def __init__(self, units, **kwargs):
+                super().__init__(**kwargs)
+                self.units = units
+                self.dense = layers.Dense(units)
+
+            def call(self, x):
+                return self.dense(x)
+
+            def get_config(self):
+                config = super().get_config()
+                config.update({"units": self.units})
+                return config
+
+        x = Input(shape=(3,), name="input_a")
+        output_a = CustomDense(10, name="output_a")(x)
+        model = Model(x, output_a)
+
+        self.assertIsInstance(model, Functional)
+        model.compile()
+        x = np.random.rand(8, 3)
+
+        dumped_pickle = cloudpickle.dumps(model)
+
+        # Verify that we can load the dumped pickle even if the custom object
+        # is not available in the loading environment.
+        del CustomDense
+        reloaded_pickle = cloudpickle.loads(dumped_pickle)
+
+        pred_reloaded = reloaded_pickle.predict(x)
+        pred = model.predict(x)
+
+        self.assertAllClose(np.array(pred_reloaded), np.array(pred))
+
     @parameterized.named_parameters(
         ("single_output_1", _get_model_single_output, None),
         ("single_output_2", _get_model_single_output, "list"),
diff --git a/keras/src/saving/keras_saveable.py b/keras/src/saving/keras_saveable.py
index eb85231cf8a..e122a4fa8d3 100644
--- a/keras/src/saving/keras_saveable.py
+++ b/keras/src/saving/keras_saveable.py
@@ -1,5 +1,4 @@
 import io
-import pickle
 
 from keras.src.saving.object_registration import get_custom_objects
 
@@ -17,28 +16,24 @@ def _obj_type(self):
         )
 
     @classmethod
-    def _unpickle_model(cls, model_buf, *args):
+    def _unpickle_model(cls, data):
         import keras.src.saving.saving_lib as saving_lib
 
         # pickle is not safe regardless of what you do.
 
-        if len(args) == 0:
-            return saving_lib._load_model_from_fileobj(
-                model_buf,
-                custom_objects=None,
-                compile=True,
-                safe_mode=False,
-            )
+        if "custom_objects_buf" in data.keys():
+            import pickle
+            custom_objects = pickle.load(data["custom_objects_buf"])
         else:
-            custom_objects_buf = args[0]
-            custom_objects = pickle.load(custom_objects_buf)
-            return saving_lib._load_model_from_fileobj(
-                model_buf,
-                custom_objects=custom_objects,
-                compile=True,
-                safe_mode=False,
-            )
+            custom_objects = None
+
+        return saving_lib._load_model_from_fileobj(
+            data["model_buf"],
+            custom_objects=custom_objects,
+            compile=True,
+            safe_mode=False,
+        )
 
     def __reduce__(self):
         """__reduce__ is used to customize the behavior of `pickle.pickle()`.
@@ -48,14 +43,23 @@ def __reduce__(self):
         keras saving library."""
         import keras.src.saving.saving_lib as saving_lib
 
+        data = {}
+
         model_buf = io.BytesIO()
         saving_lib._save_model_to_fileobj(self, model_buf, "h5")
+        data["model_buf"] = model_buf
+
+        try:
+            import cloudpickle
 
-        custom_objects_buf = io.BytesIO()
-        pickle.dump(get_custom_objects(), custom_objects_buf)
-        custom_objects_buf.seek(0)
+            custom_objects_buf = io.BytesIO()
+            cloudpickle.dump(get_custom_objects(), custom_objects_buf)
+            custom_objects_buf.seek(0)
+            data["custom_objects_buf"] = custom_objects_buf
+        except ImportError:
+            pass
 
         return (
             self._unpickle_model,
-            (model_buf, custom_objects_buf),
+            (data,),
        )
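Not part of the patch: a minimal round-trip sketch of the behavior this diff enables, mirroring the new test (assumes `cloudpickle` is installed; the `CustomDense` layer, the `MyLayers` package name, and the shapes are illustrative). Because cloudpickle only serializes classes by value when they are defined in `__main__`, run this as a script rather than importing it:

```python
import cloudpickle
import numpy as np

import keras
from keras.saving import register_keras_serializable


# Hypothetical custom layer, analogous to the one in the new test.
@register_keras_serializable(package="MyLayers")
class CustomDense(keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.dense = keras.layers.Dense(units)

    def call(self, x):
        return self.dense(x)

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config


inputs = keras.Input(shape=(3,))
model = keras.Model(inputs, CustomDense(10)(inputs))
model.compile()

# __reduce__ now bundles the cloudpickled custom-objects registry
# alongside the h5 model buffer, so the payload is self-contained.
payload = cloudpickle.dumps(model)

# _unpickle_model restores the registry before rebuilding the model,
# so this also works in a process that never defined CustomDense.
restored = cloudpickle.loads(payload)

x = np.random.rand(8, 3)
np.testing.assert_allclose(restored.predict(x), model.predict(x))
```

If `cloudpickle` is not installed at dump time, `__reduce__` simply omits `custom_objects_buf`, and `_unpickle_model` falls back to `custom_objects=None`; plain-`pickle` round trips therefore keep working without the new optional dependency.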