diff --git a/smt/surrogate_models/genn.py b/smt/surrogate_models/genn.py index adb6ebb37..532c8b23f 100644 --- a/smt/surrogate_models/genn.py +++ b/smt/surrogate_models/genn.py @@ -12,6 +12,7 @@ from jenn.model import NeuralNet from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistence # The missing type SMTrainingPoints = Dict[Union[int, None], Dict[int, List[np.ndarray]]] @@ -212,6 +213,13 @@ def _train(self): ) self.model.fit(X, Y, J, **kwargs) + def save(self, filename): + persistence.save(self, filename) + + @staticmethod + def load(filename): + return persistence.load(filename) + def _predict_values(self, x): return self.model.predict(x.T).T diff --git a/smt/surrogate_models/gpx.py b/smt/surrogate_models/gpx.py index 9d1980506..1b5253b45 100644 --- a/smt/surrogate_models/gpx.py +++ b/smt/surrogate_models/gpx.py @@ -162,14 +162,14 @@ def predict_variance_gradients(self, x): Returns all variance gradients at the given x points as a [n, nx] matrix""" return self._gpx.predict_var_gradients(x) - def save(self, filepath): + def save(self, filename): """Save the trained model in the given filepath Arguments --------- filename (string): path to the json file """ - self._gpx.save(filepath) + self._gpx.save(filename) @staticmethod def load(filepath): diff --git a/smt/surrogate_models/krg_based.py b/smt/surrogate_models/krg_based.py index c17bcd856..aae65eee8 100644 --- a/smt/surrogate_models/krg_based.py +++ b/smt/surrogate_models/krg_based.py @@ -15,6 +15,7 @@ from smt.sampling_methods import LHS from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistence from smt.kernels import ( SquarSinExp, @@ -1870,6 +1871,13 @@ def _predict_variances( s2[s2 < 0.0] = 0.0 return s2 + def save(self, filename): + persistence.save(self, filename) + + @staticmethod + def load(filename): + return persistence.load(filename) + def _predict_variance_derivatives(self, x, kx): """ Provide the derivatives of the variance of the model at a set of points diff --git a/smt/surrogate_models/ls.py b/smt/surrogate_models/ls.py index f33135e50..82c4b6bbb 100644 --- a/smt/surrogate_models/ls.py +++ b/smt/surrogate_models/ls.py @@ -12,6 +12,7 @@ from sklearn import linear_model from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistence from smt.utils.caching import cached_operation @@ -99,3 +100,10 @@ def _predict_derivatives(self, x, kx): n_eval, n_features_x = x.shape y = np.ones((n_eval, self.ny)) * self.mod.coef_[:, kx] return y + + def save(self, filename): + persistence.save(self, filename) + + @staticmethod + def load(filename): + return persistence.load(filename) diff --git a/smt/surrogate_models/qp.py b/smt/surrogate_models/qp.py index 4eb7fe1e4..badfb78ec 100644 --- a/smt/surrogate_models/qp.py +++ b/smt/surrogate_models/qp.py @@ -12,6 +12,7 @@ import scipy from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistence from smt.utils.caching import cached_operation from smt.utils.misc import standardization @@ -162,3 +163,10 @@ def _predict_values(self, x): y = (self.y_mean + self.y_std * y_).ravel() y = y.reshape((x.shape[0], self.ny)) return y + + def save(self, filename): + persistence.save(self, filename) + + @staticmethod + def load(filename): + return persistence.load(filename) diff --git a/smt/surrogate_models/surrogate_model.py b/smt/surrogate_models/surrogate_model.py index 8513296c9..cecd57a65 100644 --- a/smt/surrogate_models/surrogate_model.py +++ b/smt/surrogate_models/surrogate_model.py @@ -570,3 +570,9 @@ def _check_xdim(self, x): """Raise a ValueError if x dimension is not consistent with surrogate model training data dimension. This method is used as a guard in preamble of predict methods""" check_nx(self.nx, x) + + def save(self, filename): + """ + Implemented by surrogate models to save the surrogate object in a file + """ + raise NotImplementedError("save() has to be implemented by the given surrogate") diff --git a/smt/surrogate_models/tests/test_save_load.py b/smt/surrogate_models/tests/test_save_load.py new file mode 100644 index 000000000..81fe94ce9 --- /dev/null +++ b/smt/surrogate_models/tests/test_save_load.py @@ -0,0 +1,87 @@ +import os +import unittest +import numpy as np + +from smt.problems import Sphere +from smt.sampling_methods import LHS +from smt.surrogate_models import KRG, LS, KPLS, GEKPLS, KPLSK, MGP, QP, SGP, GENN + + +class TestSaveLoad(unittest.TestCase): + def test_save_load_GEKPLS(self): + filename = "sm_save_test" + fun = Sphere(ndim=2) + + sampling = LHS(xlimits=fun.xlimits, criterion="m") + xt = sampling(20) + yt = fun(xt) + + for i in range(2): + yd = fun(xt, kx=i) + yt = np.concatenate((yt, yd), axis=1) + + X = np.arange(fun.xlimits[0, 0], fun.xlimits[0, 1], 0.25) + Y = np.arange(fun.xlimits[1, 0], fun.xlimits[1, 1], 0.25) + X, Y = np.meshgrid(X, Y) + Z1 = np.zeros((X.shape[0], X.shape[1])) + Z2 = np.zeros((X.shape[0], X.shape[1])) + + sm = GEKPLS(print_global=False) + sm.set_training_values(xt, yt[:, 0]) + for i in range(2): + sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i) + sm.train() + for i in range(X.shape[0]): + for j in range(X.shape[1]): + Z1[i, j] = sm.predict_values( + np.hstack((X[i, j], Y[i, j])).reshape((1, 2)) + ).item() + sm.save(filename) + + sm2 = GEKPLS.load(filename) + self.assertIsNotNone(sm2) + + for i in range(X.shape[0]): + for j in range(X.shape[1]): + Z2[i, j] = sm2.predict_values( + np.hstack((X[i, j], Y[i, j])).reshape((1, 2)) + ).item() + + np.testing.assert_allclose(Z1, Z2) + + os.remove(filename) + + def test_save_load_surrogates(self): + surrogates = [KRG, KPLS, KPLSK, MGP, SGP, QP, GENN, LS] + rng = np.random.RandomState(1) + N_inducing = 30 + num = 100 + + xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) + yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) + x = np.linspace(0.0, 4.0, num).reshape(-1, 1) + + for surrogate in surrogates: + filename = "sm_save_test" + + sm = surrogate(print_global=False) + sm.set_training_values(xt, yt) + + if surrogate == SGP: + sm.Z = 2 * rng.rand(N_inducing, 1) - 1 + sm.set_inducing_inputs(Z=sm.Z) + + sm.train() + y1 = sm.predict_values(x) + sm.save(filename) + + sm2 = surrogate.load(filename) + y2 = sm2.predict_values(x) + + np.testing.assert_allclose(y1, y2) + + os.remove(filename) + + +if __name__ == "__main__": + unittest.main() diff --git a/smt/utils/persistence.py b/smt/utils/persistence.py new file mode 100644 index 000000000..69d1133e5 --- /dev/null +++ b/smt/utils/persistence.py @@ -0,0 +1,13 @@ +import pickle + + +def save(self, filename): + with open(filename, "wb") as file: + pickle.dump(self, file) + + +def load(filename): + sm = None + with open(filename, "rb") as file: + sm = pickle.load(file) + return sm