Skip to content

Commit

Permalink
Custom objects support when pickling keras models
Browse files Browse the repository at this point in the history
  • Loading branch information
mthiboust committed Jun 18, 2024
1 parent f6cf6a0 commit f4865cc
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 7 deletions.
29 changes: 29 additions & 0 deletions keras/src/models/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,27 @@
from keras.src.models.functional import Functional
from keras.src.models.model import Model
from keras.src.models.model import model_from_json
from keras.src.saving.object_registration import register_keras_serializable


@pytest.fixture
def my_custom_dense():
@register_keras_serializable(package="MyLayers", name="CustomDense")
class CustomDense(layers.Layer):
def __init__(self, units, **kwargs):
super().__init__(**kwargs)
self.units = units
self.dense = layers.Dense(units)

def call(self, x):
return self.dense(x)

def get_config(self):
config = super().get_config()
config.update({"units": self.units})
return config

return CustomDense


def _get_model():
Expand Down Expand Up @@ -68,6 +89,13 @@ def _get_model_multi_outputs_dict():
return model


def _get_model_custom_layer():
x = Input(shape=(3,), name="input_a")
output_a = my_custom_dense()(10, name="output_a")(x)
model = Model(x, output_a)
return model


@pytest.mark.requires_trainable_backend
class ModelTest(testing.TestCase, parameterized.TestCase):
def test_functional_rerouting(self):
Expand Down Expand Up @@ -127,6 +155,7 @@ def call(self, x):
("single_list_output_2", _get_model_single_output_list),
("single_list_output_3", _get_model_single_output_list),
("single_list_output_4", _get_model_single_output_list),
("custom_layer", _get_model_custom_layer),
)
def test_functional_pickling(self, model_fn):
model = model_fn()
Expand Down
37 changes: 30 additions & 7 deletions keras/src/saving/keras_saveable.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import io
import pickle

from keras.src.saving.object_registration import get_custom_objects


class KerasSaveable:
Expand All @@ -14,13 +17,28 @@ def _obj_type(self):
)

@classmethod
def _unpickle_model(cls, bytesio):
def _unpickle_model(cls, model_buf, *args):
import keras.src.saving.saving_lib as saving_lib

# pickle is not safe regardless of what you do.
return saving_lib._load_model_from_fileobj(
bytesio, custom_objects=None, compile=True, safe_mode=False
)

if len(args) == 0:
return saving_lib._load_model_from_fileobj(
model_buf,
custom_objects=None,
compile=True,
safe_mode=False,
)

else:
custom_objects_buf = args[0]
custom_objects = pickle.load(custom_objects_buf)
return saving_lib._load_model_from_fileobj(
model_buf,
custom_objects=custom_objects,
compile=True,
safe_mode=False,
)

def __reduce__(self):
"""__reduce__ is used to customize the behavior of `pickle.pickle()`.
Expand All @@ -30,9 +48,14 @@ def __reduce__(self):
keras saving library."""
import keras.src.saving.saving_lib as saving_lib

buf = io.BytesIO()
saving_lib._save_model_to_fileobj(self, buf, "h5")
model_buf = io.BytesIO()
saving_lib._save_model_to_fileobj(self, model_buf, "h5")

custom_objects_buf = io.BytesIO()
pickle.dump(get_custom_objects(), custom_objects_buf)
custom_objects_buf.seek(0)

return (
self._unpickle_model,
(buf,),
(model_buf, custom_objects_buf),
)

0 comments on commit f4865cc

Please sign in to comment.