diff --git a/smt/surrogate_models/genn.py b/smt/surrogate_models/genn.py index adb6ebb37..2e13b8d90 100644 --- a/smt/surrogate_models/genn.py +++ b/smt/surrogate_models/genn.py @@ -12,6 +12,7 @@ from jenn.model import NeuralNet from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistance # The missing type SMTrainingPoints = Dict[Union[int, None], Dict[int, List[np.ndarray]]] @@ -212,6 +213,13 @@ def _train(self): ) self.model.fit(X, Y, J, **kwargs) + def save(self, filename): + persistance.save(self, filename) + + @staticmethod + def load(filename): + return (persistance.load(filename)) + def _predict_values(self, x): return self.model.predict(x.T).T diff --git a/smt/surrogate_models/gpx.py b/smt/surrogate_models/gpx.py index 6de583fbc..8a2993d9b 100644 --- a/smt/surrogate_models/gpx.py +++ b/smt/surrogate_models/gpx.py @@ -164,33 +164,14 @@ def predict_variance_gradients(self, x): Returns all variance gradients at the given x points as a [n, nx] matrix""" return self._gpx.predict_var_gradients(x) - def _save(self, filename): + def save(self, filename): """Save the trained model in the given filepath Arguments --------- filename (string): path to the json file """ - if filename is None: - filename = self.filename - - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") - # self._gpx.save(filepath) - - - - # def _load(self, filename): - # if filename is None: - # return ("file is not found") - # else: - # with open(filename, "rb") as file: - # sm2 = pickle.load(file) - # return sm2 + self._gpx.save(filename) @staticmethod diff --git a/smt/surrogate_models/krg_based.py b/smt/surrogate_models/krg_based.py index 2270a68d9..af5374934 100644 --- a/smt/surrogate_models/krg_based.py +++ b/smt/surrogate_models/krg_based.py @@ -8,15 +8,14 @@ import warnings from copy import deepcopy from enum import Enum -from joblib import dump, load import numpy as np -import pickle from scipy import linalg, optimize from scipy.stats import multivariate_normal as m_norm from smt.sampling_methods import LHS from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistance from smt.kernels import ( SquarSinExp, @@ -78,7 +77,6 @@ class KrgBased(SurrogateModel): "act_exp": ActExp, } name = "KrigingBased" - filename = "kriging_model" def _initialize(self): super(KrgBased, self)._initialize() @@ -1872,24 +1870,12 @@ def _predict_variances( s2[s2 < 0.0] = 0.0 return s2 - def _save(self, filename=None): - if filename is None: - filename = self.filename + def save(self, filename): + persistance.save(self, filename) - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") - - def _load(self, filename): - if filename is None: - return ("file is not found") - else: - with open(filename, "rb") as file: - sm2 = pickle.load(file) - return sm2 + @staticmethod + def load(filename): + return (persistance.load(filename)) def _predict_variance_derivatives(self, x, kx): """ diff --git a/smt/surrogate_models/ls.py b/smt/surrogate_models/ls.py index bc6c42d14..6e58cbe86 100644 --- a/smt/surrogate_models/ls.py +++ b/smt/surrogate_models/ls.py @@ -13,6 +13,7 @@ from sklearn import linear_model from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistance from smt.utils.caching import cached_operation @@ -102,21 +103,9 @@ def _predict_derivatives(self, x, kx): y = np.ones((n_eval, self.ny)) * self.mod.coef_[:, kx] return y - def _save(self, filename=None): - if filename is None: - filename = self.filename - - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") - - def _load(self, filename): - if filename is None: - return ("file is not found") - else: - with open(filename, "rb") as file: - sm2 = pickle.load(file) - return sm2 + def save(self, filename): + persistance.save(self, filename) + + @staticmethod + def load(filename): + return (persistance.load(filename)) diff --git a/smt/surrogate_models/qp.py b/smt/surrogate_models/qp.py index e47ca47be..d56846942 100644 --- a/smt/surrogate_models/qp.py +++ b/smt/surrogate_models/qp.py @@ -13,6 +13,7 @@ import pickle from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistance from smt.utils.caching import cached_operation from smt.utils.misc import standardization @@ -165,21 +166,9 @@ def _predict_values(self, x): y = y.reshape((x.shape[0], self.ny)) return y - def _save(self, filename=None): - if filename is None: - filename = self.filename - - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") - - def _load(self, filename): - if filename is None: - return ("file is not found") - else: - with open(filename, "rb") as file: - sm2 = pickle.load(file) - return sm2 \ No newline at end of file + def save(self, filename): + persistance.save(self, filename) + + @staticmethod + def load(filename): + return (persistance.load(filename)) \ No newline at end of file diff --git a/smt/surrogate_models/surrogate_model.py b/smt/surrogate_models/surrogate_model.py index 9f90ccc51..cecd57a65 100644 --- a/smt/surrogate_models/surrogate_model.py +++ b/smt/surrogate_models/surrogate_model.py @@ -571,8 +571,8 @@ def _check_xdim(self, x): This method is used as a guard in preamble of predict methods""" check_nx(self.nx, x) - def _save(self, filename): + def save(self, filename): """ Implemented by surrogate models to save the surrogate object in a file """ - pass + raise NotImplementedError("save() has to be implemented by the given surrogate") diff --git a/smt/tests/test_identification_attibuts_krig.py b/smt/tests/test_identification_attibuts_krig.py deleted file mode 100644 index a3cde0436..000000000 --- a/smt/tests/test_identification_attibuts_krig.py +++ /dev/null @@ -1,46 +0,0 @@ -import unittest -import numpy as np -from smt.surrogate_models import KRG - -class TestIdKrig(unittest.TestCase): - - def test_attrib_krig(self): - - xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) - yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) - - sm = KRG() - sm.set_training_values(xt, yt) - sm.train() - nx = sm.nx - nt = sm.nt - design_space = sm.design_space - X_offset = sm.X_offset - X_scale = sm.X_scale - X_norma = sm.X_norma - optimal_par = sm.optimal_par - y_mean = sm.y_mean - y_std = sm.y_std - ny = sm.ny - print(f"nx = ", nx) - print(f"nt = ", nt) - print(f"design space = ", design_space) - - self.assertIsNotNone(nx) - - sm2 = KRG() - sm2.nx = nx - sm2.nt = nt - sm2.design_space = design_space - sm2.X_offset = X_offset - sm2.X_scale = X_scale - sm2.X_norma = X_norma - sm2.optimal_par = optimal_par - sm2.y_mean = y_mean - sm2.y_std = y_std - sm2.ny = ny - sm2.predict_values(np.array([5, 1])) - print - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/smt/tests/test_save_load.py b/smt/tests/test_save_load.py index 19b26d757..1e10d9a9e 100644 --- a/smt/tests/test_save_load.py +++ b/smt/tests/test_save_load.py @@ -1,51 +1,98 @@ import os import unittest import numpy as np -from smt.surrogate_models import KRG -# from smt.surrogate_models import IDW //TODO comprendre pourquoi l"import de IDW génère une erreur -from smt.surrogate_models import LS -from smt.surrogate_models import KPLS -from smt.surrogate_models import GEKPLS -from smt.surrogate_models import GPX -from smt.surrogate_models import KPLSK -from smt.surrogate_models import MGP -from smt.surrogate_models import QP -# from smt.surrogate_models import RBF -# from smt.surrogate_models import RMTS -from smt.surrogate_models import SGP + +from smt.problems import Sphere +from smt.sampling_methods import LHS +from smt.surrogate_models import KRG, LS, KPLS, GEKPLS, KPLSK, MGP, QP, SGP, GENN, QP class TestSaveLoad(unittest.TestCase): - def test_save_load(self): - - filename = "kriging_save_test" + def test_save_load_GEKPLS(self): - xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) - yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) + filename = "sm_save_test" + fun = Sphere(ndim=2) + + sampling = LHS(xlimits=fun.xlimits, criterion="m") + xt = sampling(20) + yt = fun(xt) + + for i in range(2): + yd = fun(xt, kx=i) + yt = np.concatenate((yt, yd), axis=1) - sm = SGP() - sm.set_training_values(xt, yt) + sm = GEKPLS() + sm.set_training_values(xt, yt[:, 0]) + for i in range(2): + sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i) sm.train() - sm._save(filename) + sm.save(filename) self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.") file_size = os.path.getsize(filename) print(f"Taille du fichier : {file_size} octets") - sm2 = sm._load(filename) + sm2 = GEKPLS.load(filename) self.assertIsNotNone(sm2) - num = 100 - x = np.linspace(0.0, 4.0, num).reshape(-1, 1) + X = np.arange(fun.xlimits[0, 0], fun.xlimits[0, 1], 0.25) + Y = np.arange(fun.xlimits[1, 0], fun.xlimits[1, 1], 0.25) + X, Y = np.meshgrid(X, Y) + Z = np.zeros((X.shape[0], X.shape[1])) - y = sm2.predict_values(x) + for i in range(X.shape[0]): + for j in range(X.shape[1]): + Z[i, j] = sm.predict_values( + np.hstack((X[i, j], Y[i, j])).reshape((1, 2)) + ).item() - self.assertIsNotNone(y) - print("Prédictions avec le modèle chargé :", y) + self.assertIsNotNone(Z) + print("Prédictions avec le modèle chargé :", Z) os.remove(filename) + def test_save_krigs(self): + + krigs = [KRG, KPLS, KPLSK, MGP, SGP, QP, GENN, LS] + rng = np.random.RandomState(1) + N_inducing = 30 + + xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) + yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) + + for surrogate in krigs: + + filename = "sm_save_test" + + sm = surrogate() + sm.set_training_values(xt, yt) + + if surrogate == SGP: + sm.Z = 2 * rng.rand(N_inducing, 1) - 1 + sm.set_inducing_inputs(Z=sm.Z) + + sm.train() + + sm.save(filename) + self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.") + + file_size = os.path.getsize(filename) + print(f"Taille du fichier : {file_size} octets") + + sm2 = surrogate.load(filename) + self.assertIsNotNone(sm2) + + num = 100 + x = np.linspace(0.0, 4.0, num).reshape(-1, 1) + + y = sm2.predict_values(x) + + self.assertIsNotNone(y) + print("Prédictions avec le modèle chargé :", y) + + os.remove(filename) + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/smt/utils/persistance.py b/smt/utils/persistance.py new file mode 100644 index 000000000..f50d9d4e2 --- /dev/null +++ b/smt/utils/persistance.py @@ -0,0 +1,16 @@ +import pickle + + +def save(self, filename): + try: + with open(filename, 'wb') as file: + pickle.dump(self, file) + print("model saved") + except: + print("Couldn't save the model") + +@staticmethod +def load(filename): + with open(filename, "rb") as file: + sm = pickle.load(file) + return sm \ No newline at end of file