diff --git a/smt/applications/moe.py b/smt/applications/moe.py
index 65ebbe8d1..7057ae386 100644
--- a/smt/applications/moe.py
+++ b/smt/applications/moe.py
@@ -16,6 +16,7 @@
 
 from smt.applications.application import SurrogateBasedApplication
 from smt.surrogate_models.surrogate_model import SurrogateModel
+from smt.utils import persistence
 
 MOE_EXPERT_NAMES = [
     "KRG",
@@ -172,6 +173,26 @@ def enabled_experts(self):
         self._enabled_expert_types = self._get_enabled_expert_types()
         return list(self._enabled_expert_types.keys())
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # The compiled kernels held by IDW, RMTB/RMTC and RBF experts cannot
+        # be pickled; clear them here and rebuild them in __setstate__.
+        for expert in self._experts:
+            if expert.name == "IDW":
+                expert.idwc = None
+            elif expert.name in ("RMTB", "RMTC"):
+                expert.rmtsc = None
+            elif expert.name == "RBF":
+                expert.rbfc = None
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        # Rebuild the compiled kernels cleared in __getstate__.
+        for expert in self._experts:
+            if expert.name in ("IDW", "RMTB", "RMTC", "RBF"):
+                expert._setup()
+
     def set_training_values(self, xt, yt, name=None):
         """
         Set training data (values).
@@ -760,3 +781,10 @@ def _proba_cluster(self, x, distribs=None):
         )
         return probs
+
+    def save(self, filename):
+        persistence.save(self, filename)
+
+    @staticmethod
+    def load(filename):
+        return persistence.load(filename)
diff --git a/smt/applications/tests/test_save_load_moe_mfk.py b/smt/applications/tests/test_save_load_moe_mfk.py
new file mode 100644
index 000000000..43af239eb
--- /dev/null
+++ b/smt/applications/tests/test_save_load_moe_mfk.py
@@ -0,0 +1,91 @@
+import os
+import unittest
+import numpy as np
+
+from smt.applications.mfk import MFK, NestedLHS
+from smt.applications.mfkpls import MFKPLS
+from smt.applications.mfkplsk import MFKPLSK
+from smt.applications.moe import MOE
+from smt.sampling_methods import FullFactorial
+
+
+class TestSaveLoad(unittest.TestCase):
+    def function_test_1d(self, x):
+        x = np.reshape(x, (-1,))
+        y = np.zeros(x.shape)
+        y[x < 0.4] = x[x < 0.4] ** 2
+        y[(x >= 0.4) & (x < 0.8)] = 3 * x[(x >= 0.4) & (x < 0.8)] + 1
+        y[x >= 0.8] = np.sin(10 * x[x >= 0.8])
+        return y.reshape((-1, 1))
+
+    def test_save_load_moe(self):
+        filename = "moe_save_test"
+        nt = 35
+        x = np.linspace(0, 1, 100)
+
+        sampling = FullFactorial(xlimits=np.array([[0, 1]]), clip=True)
+        np.random.seed(0)
+        xt = sampling(nt)
+        yt = self.function_test_1d(xt)
+
+        moe1 = MOE(n_clusters=1)
+        moe1.set_training_values(xt, yt)
+        moe1.train()
+        y_moe1 = moe1.predict_values(x)
+
+        moe1.save(filename)
+
+        moe2 = MOE.load(filename)
+        y_moe2 = moe2.predict_values(x)
+
+        np.testing.assert_allclose(y_moe1, y_moe2)
+
+        os.remove(filename)
+
+    def lf_function(self, x):
+        return 0.5 * ((x * 6 - 2) ** 2) * np.sin((x * 6 - 2) * 2) + (x - 0.5) * 10.0 - 5
+
+    def hf_function(self, x):
+        return ((x * 6 - 2) ** 2) * np.sin((x * 6 - 2) * 2)
+
+    def _setup_MFKs(self):
+        xlimits = np.array([[0.0, 1.0]])
+        xdoes = NestedLHS(nlevel=2, xlimits=xlimits, random_state=0)
+        xt_c, xt_e = xdoes(7)
+        yt_e = self.hf_function(xt_e)
+        yt_c = self.lf_function(xt_c)
+        x = np.linspace(0, 1, 101, endpoint=True).reshape(-1, 1)
+        return (xt_c, xt_e, yt_c, yt_e, x)
+
+    def test_save_load_mfk_mfkpls_mfkplsk(self):
+        filename = "MFKs_save_test"
+        MFKs = [MFK, MFKPLS, MFKPLSK]
+        xt_c, xt_e, yt_c, yt_e, x = self._setup_MFKs()
+        ncomp = 1
+        for mfk in MFKs:
+            if mfk == MFK:
+                application = MFK(theta0=xt_e.shape[1] * [1.0], corr="squar_exp")
+            elif mfk == MFKPLS:
+                application = MFKPLS(n_comp=ncomp, theta0=ncomp * [1.0])
+            else:
+                application = MFKPLSK(n_comp=ncomp, theta0=ncomp * [1.0])
+
+            application.set_training_values(xt_c, yt_c, name=0)
+            application.set_training_values(xt_e, yt_e)
+
+            application.train()
+            application.save(filename)
+
+            x = np.linspace(0, 1, 101, endpoint=True).reshape(-1, 1)
+            y1 = application.predict_values(x)
+
+            mfk2 = MFK.load(filename)
+            y2 = mfk2.predict_values(x)
+
+            np.testing.assert_allclose(y1, y2)
+
+            os.remove(filename)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/smt/surrogate_models/tests/test_save_load.py b/smt/surrogate_models/tests/test_save_load.py
index f1a6d8e49..3785912ef 100644
--- a/smt/surrogate_models/tests/test_save_load.py
+++ b/smt/surrogate_models/tests/test_save_load.py
@@ -2,6 +2,7 @@
 import unittest
 
 import numpy as np
+
 from smt.problems import Sphere
 from smt.sampling_methods import LHS
 from smt.surrogate_models import (
@@ -24,13 +25,14 @@ class TestSaveLoad(unittest.TestCase):
 
     def test_save_load_GEKPLS(self):
         filename = "sm_save_test"
-        fun = Sphere(ndim=2)
+        ndim = 2
+        fun = Sphere(ndim=ndim)
 
         sampling = LHS(xlimits=fun.xlimits, criterion="m")
         xt = sampling(20)
         yt = fun(xt)
 
-        for i in range(2):
+        for i in range(ndim):
             yd = fun(xt, kx=i)
             yt = np.concatenate((yt, yd), axis=1)
 
@@ -42,7 +44,7 @@ def test_save_load_GEKPLS(self):
         sm = GEKPLS(print_global=False)
         sm.set_training_values(xt, yt[:, 0])
 
-        for i in range(2):
+        for i in range(ndim):
             sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i)
         sm.train()
         for i in range(X.shape[0]):
diff --git a/smt/utils/persistence.py b/smt/utils/persistence.py
index 69d1133e5..78ef7c9af 100644
--- a/smt/utils/persistence.py
+++ b/smt/utils/persistence.py
@@ -1,13 +1,19 @@
 import pickle
+import zlib
 
 
 def save(self, filename):
+    serialized_data = pickle.dumps(self, protocol=5)
+    compressed_data = zlib.compress(serialized_data)
     with open(filename, "wb") as file:
-        pickle.dump(self, file)
+        file.write(compressed_data)
 
 
 def load(filename):
-    sm = None
     with open(filename, "rb") as file:
-        sm = pickle.load(file)
+        compressed_data = file.read()
+
+    serialized_data = zlib.decompress(compressed_data)
+    sm = pickle.loads(serialized_data)
+
     return sm
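
For reference, a minimal usage sketch of the MOE save/load API added above, mirroring the new test. This is not part of the patch; the file name "moe_model" and the toy training data are placeholders.

    import numpy as np
    from smt.applications.moe import MOE

    # Toy 1-D training data (placeholder values).
    xt = np.linspace(0.0, 1.0, 30).reshape(-1, 1)
    yt = np.sin(10.0 * xt)

    moe = MOE(n_clusters=1)
    moe.set_training_values(xt, yt)
    moe.train()
    y1 = moe.predict_values(xt)

    # save() delegates to smt.utils.persistence: the model is pickled
    # (protocol 5) and zlib-compressed before being written to disk.
    moe.save("moe_model")

    # load() decompresses and unpickles the file; MOE.__setstate__ re-runs
    # _setup() on experts that hold compiled kernels (IDW, RBF, RMTB, RMTC).
    moe2 = MOE.load("moe_model")
    np.testing.assert_allclose(y1, moe2.predict_values(xt))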