From e30bd52eba18207931d59ea47b8d5a2bb1124085 Mon Sep 17 00:00:00 2001 From: ANTOINE AVERLAND Date: Tue, 3 Dec 2024 16:41:32 +0100 Subject: [PATCH 1/6] new save and load functions for the surrogate models --- smt/surrogate_models/krg_based.py | 24 +++++++++++++++ smt/surrogate_models/ls.py | 21 +++++++++++++ smt/surrogate_models/surrogate_model.py | 6 ++++ smt/tests/test_save_load.py | 39 +++++++++++++++++++++++++ 4 files changed, 90 insertions(+) create mode 100644 smt/tests/test_save_load.py diff --git a/smt/surrogate_models/krg_based.py b/smt/surrogate_models/krg_based.py index c17bcd856..18027888a 100644 --- a/smt/surrogate_models/krg_based.py +++ b/smt/surrogate_models/krg_based.py @@ -8,8 +8,10 @@ import warnings from copy import deepcopy from enum import Enum +from joblib import dump, load import numpy as np +import pickle from scipy import linalg, optimize from scipy.stats import multivariate_normal as m_norm @@ -76,6 +78,7 @@ class KrgBased(SurrogateModel): "act_exp": ActExp, } name = "KrigingBased" + filename = "kriging_model" def _initialize(self): super(KrgBased, self)._initialize() @@ -1869,6 +1872,25 @@ def _predict_variances( # machine precision: force to zero! s2[s2 < 0.0] = 0.0 return s2 + + def _save(self, filename=None): + if filename is None: + filename = self.filename + + try: + with open(filename, 'wb') as file: + pickle.dump(self, file) + print("model saved") + except: + print("Couldn't save the model") + + def _load(self, filename): + if filename is None: + return ("file is not found") + else: + with open(filename, "rb") as file: + sm2 = pickle.load(file) + return sm2 def _predict_variance_derivatives(self, x, kx): """ @@ -2549,3 +2571,5 @@ def compute_n_param(design_space, cat_kernel, d, n_comp, mat_dim): if cat_kernel == MixIntKernelType.CONT_RELAX: n_param += int(n_values) return n_param + + diff --git a/smt/surrogate_models/ls.py b/smt/surrogate_models/ls.py index f33135e50..bc6c42d14 100644 --- a/smt/surrogate_models/ls.py +++ b/smt/surrogate_models/ls.py @@ -9,6 +9,7 @@ """ import numpy as np +import pickle from sklearn import linear_model from smt.surrogate_models.surrogate_model import SurrogateModel @@ -23,6 +24,7 @@ class LS(SurrogateModel): """ name = "LS" + filename = "least_square" def _initialize(self): super(LS, self)._initialize() @@ -99,3 +101,22 @@ def _predict_derivatives(self, x, kx): n_eval, n_features_x = x.shape y = np.ones((n_eval, self.ny)) * self.mod.coef_[:, kx] return y + + def _save(self, filename=None): + if filename is None: + filename = self.filename + + try: + with open(filename, 'wb') as file: + pickle.dump(self, file) + print("model saved") + except: + print("Couldn't save the model") + + def _load(self, filename): + if filename is None: + return ("file is not found") + else: + with open(filename, "rb") as file: + sm2 = pickle.load(file) + return sm2 diff --git a/smt/surrogate_models/surrogate_model.py b/smt/surrogate_models/surrogate_model.py index 8513296c9..2a7a708f9 100644 --- a/smt/surrogate_models/surrogate_model.py +++ b/smt/surrogate_models/surrogate_model.py @@ -570,3 +570,9 @@ def _check_xdim(self, x): """Raise a ValueError if x dimension is not consistent with surrogate model training data dimension. This method is used as a guard in preamble of predict methods""" check_nx(self.nx, x) + + def _save(filename): + """ + Implemented by surrogate models to save the surrogate object in a file + """ + pass diff --git a/smt/tests/test_save_load.py b/smt/tests/test_save_load.py new file mode 100644 index 000000000..b8a97f7f5 --- /dev/null +++ b/smt/tests/test_save_load.py @@ -0,0 +1,39 @@ +import os +import unittest +import numpy as np +from smt.surrogate_models import KRG + +class TestKRG(unittest.TestCase): + + def test_save_load(self): + + filename = "kriging_save_test" + + xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) + yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) + + sm = KRG() + sm.set_training_values(xt, yt) + sm.train() + + sm._save(filename) + self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.") + + file_size = os.path.getsize(filename) + print(f"Taille du fichier : {file_size} octets") + + sm2 = sm._load(filename) + self.assertIsNotNone(sm2) + + num = 100 + x = np.linspace(0.0, 4.0, num).reshape(-1, 1) + + y = sm2.predict_values(x) + + self.assertIsNotNone(y) + print("Prédictions avec le modèle chargé :", y) + + os.remove(filename) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 9614c8f0dc78759e4605abb77387e73d372c29a6 Mon Sep 17 00:00:00 2001 From: ANTOINE AVERLAND Date: Wed, 4 Dec 2024 17:18:17 +0100 Subject: [PATCH 2/6] new save functions for the surrogates and identification of the necessary attributes for predict_values() --- smt/surrogate_models/gpx.py | 26 ++++++++++- smt/surrogate_models/idw.py | 20 ++++++++ smt/surrogate_models/krg_based.py | 6 +++ smt/surrogate_models/qp.py | 21 +++++++++ smt/surrogate_models/rbf.py | 21 +++++++++ smt/surrogate_models/rmts.py | 21 +++++++++ smt/surrogate_models/surrogate_model.py | 2 +- .../test_identification_attibuts_krig.py | 46 +++++++++++++++++++ smt/tests/test_save_load.py | 18 ++++++-- 9 files changed, 175 insertions(+), 6 deletions(-) create mode 100644 smt/tests/test_identification_attibuts_krig.py diff --git a/smt/surrogate_models/gpx.py b/smt/surrogate_models/gpx.py index 9d1980506..6de583fbc 100644 --- a/smt/surrogate_models/gpx.py +++ b/smt/surrogate_models/gpx.py @@ -1,4 +1,5 @@ import numpy as np +import pickle from smt.surrogate_models.surrogate_model import SurrogateModel @@ -29,6 +30,7 @@ class GPX(SurrogateModel): name = "GPX" + filename = "gpx_save" def _initialize(self): super(GPX, self)._initialize() @@ -162,14 +164,34 @@ def predict_variance_gradients(self, x): Returns all variance gradients at the given x points as a [n, nx] matrix""" return self._gpx.predict_var_gradients(x) - def save(self, filepath): + def _save(self, filename): """Save the trained model in the given filepath Arguments --------- filename (string): path to the json file """ - self._gpx.save(filepath) + if filename is None: + filename = self.filename + + try: + with open(filename, 'wb') as file: + pickle.dump(self, file) + print("model saved") + except: + print("Couldn't save the model") + # self._gpx.save(filepath) + + + + # def _load(self, filename): + # if filename is None: + # return ("file is not found") + # else: + # with open(filename, "rb") as file: + # sm2 = pickle.load(file) + # return sm2 + @staticmethod def load(filepath): diff --git a/smt/surrogate_models/idw.py b/smt/surrogate_models/idw.py index aa08b3776..bb489aba7 100644 --- a/smt/surrogate_models/idw.py +++ b/smt/surrogate_models/idw.py @@ -5,6 +5,7 @@ This package is distributed under New BSD license. """ +import pickle import numpy as np from smt.surrogate_models.idwclib import PyIDW @@ -132,3 +133,22 @@ def _predict_output_derivatives(self, x): dy_dyt = {None: jac} return dy_dyt + + def _save(self, filename=None): + if filename is None: + filename = self.filename + + try: + with open(filename, 'wb') as file: + pickle.dump(self, file) + print("model saved") + except: + print("Couldn't save the model") + + def _load(self, filename): + if filename is None: + return ("file is not found") + else: + with open(filename, "rb") as file: + sm2 = pickle.load(file) + return sm2 diff --git a/smt/surrogate_models/krg_based.py b/smt/surrogate_models/krg_based.py index 18027888a..0a72da713 100644 --- a/smt/surrogate_models/krg_based.py +++ b/smt/surrogate_models/krg_based.py @@ -286,6 +286,12 @@ def design_space(self) -> BaseDesignSpace: xt=xt, xlimits=ds_input, design_space=ds_input ) return self.options["design_space"] + + @design_space.setter + def design_space(self, d_s_value): + if not isinstance(d_s_value, BaseDesignSpace): + raise TypeError("design_space must be of type BaseDesignSpace") + self.options["design_space"] = d_s_value @property def is_continuous(self) -> bool: diff --git a/smt/surrogate_models/qp.py b/smt/surrogate_models/qp.py index 4eb7fe1e4..e47ca47be 100644 --- a/smt/surrogate_models/qp.py +++ b/smt/surrogate_models/qp.py @@ -10,6 +10,7 @@ import numpy as np import scipy +import pickle from smt.surrogate_models.surrogate_model import SurrogateModel from smt.utils.caching import cached_operation @@ -22,6 +23,7 @@ class QP(SurrogateModel): """ name = "QP" + filename = "qp_save" def _initialize(self): super(QP, self)._initialize() @@ -162,3 +164,22 @@ def _predict_values(self, x): y = (self.y_mean + self.y_std * y_).ravel() y = y.reshape((x.shape[0], self.ny)) return y + + def _save(self, filename=None): + if filename is None: + filename = self.filename + + try: + with open(filename, 'wb') as file: + pickle.dump(self, file) + print("model saved") + except: + print("Couldn't save the model") + + def _load(self, filename): + if filename is None: + return ("file is not found") + else: + with open(filename, "rb") as file: + sm2 = pickle.load(file) + return sm2 \ No newline at end of file diff --git a/smt/surrogate_models/rbf.py b/smt/surrogate_models/rbf.py index 024299901..fbe5c213f 100644 --- a/smt/surrogate_models/rbf.py +++ b/smt/surrogate_models/rbf.py @@ -5,6 +5,7 @@ """ import numpy as np +import pickle from scipy.sparse import csc_matrix from smt.surrogate_models.rbfclib import PyRBF @@ -19,6 +20,7 @@ class RBF(SurrogateModel): """ name = "RBF" + filename = "rbf_save" def _initialize(self): super(RBF, self)._initialize() @@ -213,3 +215,22 @@ def _predict_output_derivatives(self, x): dy_dyt = (dytl_dyt.T.dot(dstates_dytl.T).dot(dy_dstates.T)).T dy_dyt = np.einsum("ij,k->ijk", dy_dyt, np.ones(ny)) return {None: dy_dyt} + + def _save(self, filename=None): + if filename is None: + filename = self.filename + + try: + with open(filename, 'wb') as file: + pickle.dump(self, file) + print("model saved") + except: + print("Couldn't save the model") + + def _load(self, filename): + if filename is None: + return ("file is not found") + else: + with open(filename, "rb") as file: + sm2 = pickle.load(file) + return sm2 diff --git a/smt/surrogate_models/rmts.py b/smt/surrogate_models/rmts.py index 9dc903c44..c0ebac2a8 100644 --- a/smt/surrogate_models/rmts.py +++ b/smt/surrogate_models/rmts.py @@ -7,6 +7,7 @@ from numbers import Integral import numpy as np +import pickle import scipy.sparse from smt.surrogate_models.surrogate_model import SurrogateModel @@ -21,6 +22,7 @@ class RMTS(SurrogateModel): """ name = "RMTS" + filename = "rmts_save" def _initialize(self): super(RMTS, self)._initialize() @@ -583,3 +585,22 @@ def _predict_output_derivatives(self, x): dy_dyt[kx - 1] = np.einsum("ij,jkl->ikl", dy_dw, dw_dyt) return dy_dyt + + def _save(self, filename=None): + if filename is None: + filename = self.filename + + try: + with open(filename, 'wb') as file: + pickle.dump(self, file) + print("model saved") + except: + print("Couldn't save the model") + + def _load(self, filename): + if filename is None: + return ("file is not found") + else: + with open(filename, "rb") as file: + sm2 = pickle.load(file) + return sm2 diff --git a/smt/surrogate_models/surrogate_model.py b/smt/surrogate_models/surrogate_model.py index 2a7a708f9..9f90ccc51 100644 --- a/smt/surrogate_models/surrogate_model.py +++ b/smt/surrogate_models/surrogate_model.py @@ -571,7 +571,7 @@ def _check_xdim(self, x): This method is used as a guard in preamble of predict methods""" check_nx(self.nx, x) - def _save(filename): + def _save(self, filename): """ Implemented by surrogate models to save the surrogate object in a file """ diff --git a/smt/tests/test_identification_attibuts_krig.py b/smt/tests/test_identification_attibuts_krig.py new file mode 100644 index 000000000..a3cde0436 --- /dev/null +++ b/smt/tests/test_identification_attibuts_krig.py @@ -0,0 +1,46 @@ +import unittest +import numpy as np +from smt.surrogate_models import KRG + +class TestIdKrig(unittest.TestCase): + + def test_attrib_krig(self): + + xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) + yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) + + sm = KRG() + sm.set_training_values(xt, yt) + sm.train() + nx = sm.nx + nt = sm.nt + design_space = sm.design_space + X_offset = sm.X_offset + X_scale = sm.X_scale + X_norma = sm.X_norma + optimal_par = sm.optimal_par + y_mean = sm.y_mean + y_std = sm.y_std + ny = sm.ny + print(f"nx = ", nx) + print(f"nt = ", nt) + print(f"design space = ", design_space) + + self.assertIsNotNone(nx) + + sm2 = KRG() + sm2.nx = nx + sm2.nt = nt + sm2.design_space = design_space + sm2.X_offset = X_offset + sm2.X_scale = X_scale + sm2.X_norma = X_norma + sm2.optimal_par = optimal_par + sm2.y_mean = y_mean + sm2.y_std = y_std + sm2.ny = ny + sm2.predict_values(np.array([5, 1])) + print + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/smt/tests/test_save_load.py b/smt/tests/test_save_load.py index b8a97f7f5..19b26d757 100644 --- a/smt/tests/test_save_load.py +++ b/smt/tests/test_save_load.py @@ -2,8 +2,20 @@ import unittest import numpy as np from smt.surrogate_models import KRG - -class TestKRG(unittest.TestCase): +# from smt.surrogate_models import IDW //TODO comprendre pourquoi l"import de IDW génère une erreur +from smt.surrogate_models import LS +from smt.surrogate_models import KPLS +from smt.surrogate_models import GEKPLS +from smt.surrogate_models import GPX +from smt.surrogate_models import KPLSK +from smt.surrogate_models import MGP +from smt.surrogate_models import QP +# from smt.surrogate_models import RBF +# from smt.surrogate_models import RMTS +from smt.surrogate_models import SGP + + +class TestSaveLoad(unittest.TestCase): def test_save_load(self): @@ -12,7 +24,7 @@ def test_save_load(self): xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) - sm = KRG() + sm = SGP() sm.set_training_values(xt, yt) sm.train() From 2936d695d210533bfa48af63fa0ee22be537d4c4 Mon Sep 17 00:00:00 2001 From: ANTOINE AVERLAND Date: Mon, 9 Dec 2024 10:01:35 +0100 Subject: [PATCH 3/6] final implementation of an save/load option for the surrogates with pickle --- smt/surrogate_models/genn.py | 8 ++ smt/surrogate_models/gpx.py | 23 +---- smt/surrogate_models/krg_based.py | 26 ++--- smt/surrogate_models/ls.py | 25 ++--- smt/surrogate_models/qp.py | 25 ++--- smt/surrogate_models/surrogate_model.py | 4 +- .../test_identification_attibuts_krig.py | 46 --------- smt/tests/test_save_load.py | 99 ++++++++++++++----- smt/utils/persistance.py | 16 +++ 9 files changed, 121 insertions(+), 151 deletions(-) delete mode 100644 smt/tests/test_identification_attibuts_krig.py create mode 100644 smt/utils/persistance.py diff --git a/smt/surrogate_models/genn.py b/smt/surrogate_models/genn.py index adb6ebb37..2e13b8d90 100644 --- a/smt/surrogate_models/genn.py +++ b/smt/surrogate_models/genn.py @@ -12,6 +12,7 @@ from jenn.model import NeuralNet from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistance # The missing type SMTrainingPoints = Dict[Union[int, None], Dict[int, List[np.ndarray]]] @@ -212,6 +213,13 @@ def _train(self): ) self.model.fit(X, Y, J, **kwargs) + def save(self, filename): + persistance.save(self, filename) + + @staticmethod + def load(filename): + return (persistance.load(filename)) + def _predict_values(self, x): return self.model.predict(x.T).T diff --git a/smt/surrogate_models/gpx.py b/smt/surrogate_models/gpx.py index 6de583fbc..8a2993d9b 100644 --- a/smt/surrogate_models/gpx.py +++ b/smt/surrogate_models/gpx.py @@ -164,33 +164,14 @@ def predict_variance_gradients(self, x): Returns all variance gradients at the given x points as a [n, nx] matrix""" return self._gpx.predict_var_gradients(x) - def _save(self, filename): + def save(self, filename): """Save the trained model in the given filepath Arguments --------- filename (string): path to the json file """ - if filename is None: - filename = self.filename - - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") - # self._gpx.save(filepath) - - - - # def _load(self, filename): - # if filename is None: - # return ("file is not found") - # else: - # with open(filename, "rb") as file: - # sm2 = pickle.load(file) - # return sm2 + self._gpx.save(filename) @staticmethod diff --git a/smt/surrogate_models/krg_based.py b/smt/surrogate_models/krg_based.py index 0a72da713..a8bb891dc 100644 --- a/smt/surrogate_models/krg_based.py +++ b/smt/surrogate_models/krg_based.py @@ -8,15 +8,14 @@ import warnings from copy import deepcopy from enum import Enum -from joblib import dump, load import numpy as np -import pickle from scipy import linalg, optimize from scipy.stats import multivariate_normal as m_norm from smt.sampling_methods import LHS from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistance from smt.kernels import ( SquarSinExp, @@ -78,7 +77,6 @@ class KrgBased(SurrogateModel): "act_exp": ActExp, } name = "KrigingBased" - filename = "kriging_model" def _initialize(self): super(KrgBased, self)._initialize() @@ -1879,24 +1877,12 @@ def _predict_variances( s2[s2 < 0.0] = 0.0 return s2 - def _save(self, filename=None): - if filename is None: - filename = self.filename + def save(self, filename): + persistance.save(self, filename) - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") - - def _load(self, filename): - if filename is None: - return ("file is not found") - else: - with open(filename, "rb") as file: - sm2 = pickle.load(file) - return sm2 + @staticmethod + def load(filename): + return (persistance.load(filename)) def _predict_variance_derivatives(self, x, kx): """ diff --git a/smt/surrogate_models/ls.py b/smt/surrogate_models/ls.py index bc6c42d14..6e58cbe86 100644 --- a/smt/surrogate_models/ls.py +++ b/smt/surrogate_models/ls.py @@ -13,6 +13,7 @@ from sklearn import linear_model from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistance from smt.utils.caching import cached_operation @@ -102,21 +103,9 @@ def _predict_derivatives(self, x, kx): y = np.ones((n_eval, self.ny)) * self.mod.coef_[:, kx] return y - def _save(self, filename=None): - if filename is None: - filename = self.filename - - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") - - def _load(self, filename): - if filename is None: - return ("file is not found") - else: - with open(filename, "rb") as file: - sm2 = pickle.load(file) - return sm2 + def save(self, filename): + persistance.save(self, filename) + + @staticmethod + def load(filename): + return (persistance.load(filename)) diff --git a/smt/surrogate_models/qp.py b/smt/surrogate_models/qp.py index e47ca47be..d56846942 100644 --- a/smt/surrogate_models/qp.py +++ b/smt/surrogate_models/qp.py @@ -13,6 +13,7 @@ import pickle from smt.surrogate_models.surrogate_model import SurrogateModel +from smt.utils import persistance from smt.utils.caching import cached_operation from smt.utils.misc import standardization @@ -165,21 +166,9 @@ def _predict_values(self, x): y = y.reshape((x.shape[0], self.ny)) return y - def _save(self, filename=None): - if filename is None: - filename = self.filename - - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") - - def _load(self, filename): - if filename is None: - return ("file is not found") - else: - with open(filename, "rb") as file: - sm2 = pickle.load(file) - return sm2 \ No newline at end of file + def save(self, filename): + persistance.save(self, filename) + + @staticmethod + def load(filename): + return (persistance.load(filename)) \ No newline at end of file diff --git a/smt/surrogate_models/surrogate_model.py b/smt/surrogate_models/surrogate_model.py index 9f90ccc51..cecd57a65 100644 --- a/smt/surrogate_models/surrogate_model.py +++ b/smt/surrogate_models/surrogate_model.py @@ -571,8 +571,8 @@ def _check_xdim(self, x): This method is used as a guard in preamble of predict methods""" check_nx(self.nx, x) - def _save(self, filename): + def save(self, filename): """ Implemented by surrogate models to save the surrogate object in a file """ - pass + raise NotImplementedError("save() has to be implemented by the given surrogate") diff --git a/smt/tests/test_identification_attibuts_krig.py b/smt/tests/test_identification_attibuts_krig.py deleted file mode 100644 index a3cde0436..000000000 --- a/smt/tests/test_identification_attibuts_krig.py +++ /dev/null @@ -1,46 +0,0 @@ -import unittest -import numpy as np -from smt.surrogate_models import KRG - -class TestIdKrig(unittest.TestCase): - - def test_attrib_krig(self): - - xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) - yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) - - sm = KRG() - sm.set_training_values(xt, yt) - sm.train() - nx = sm.nx - nt = sm.nt - design_space = sm.design_space - X_offset = sm.X_offset - X_scale = sm.X_scale - X_norma = sm.X_norma - optimal_par = sm.optimal_par - y_mean = sm.y_mean - y_std = sm.y_std - ny = sm.ny - print(f"nx = ", nx) - print(f"nt = ", nt) - print(f"design space = ", design_space) - - self.assertIsNotNone(nx) - - sm2 = KRG() - sm2.nx = nx - sm2.nt = nt - sm2.design_space = design_space - sm2.X_offset = X_offset - sm2.X_scale = X_scale - sm2.X_norma = X_norma - sm2.optimal_par = optimal_par - sm2.y_mean = y_mean - sm2.y_std = y_std - sm2.ny = ny - sm2.predict_values(np.array([5, 1])) - print - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/smt/tests/test_save_load.py b/smt/tests/test_save_load.py index 19b26d757..1e10d9a9e 100644 --- a/smt/tests/test_save_load.py +++ b/smt/tests/test_save_load.py @@ -1,51 +1,98 @@ import os import unittest import numpy as np -from smt.surrogate_models import KRG -# from smt.surrogate_models import IDW //TODO comprendre pourquoi l"import de IDW génère une erreur -from smt.surrogate_models import LS -from smt.surrogate_models import KPLS -from smt.surrogate_models import GEKPLS -from smt.surrogate_models import GPX -from smt.surrogate_models import KPLSK -from smt.surrogate_models import MGP -from smt.surrogate_models import QP -# from smt.surrogate_models import RBF -# from smt.surrogate_models import RMTS -from smt.surrogate_models import SGP + +from smt.problems import Sphere +from smt.sampling_methods import LHS +from smt.surrogate_models import KRG, LS, KPLS, GEKPLS, KPLSK, MGP, QP, SGP, GENN, QP class TestSaveLoad(unittest.TestCase): - def test_save_load(self): - - filename = "kriging_save_test" + def test_save_load_GEKPLS(self): - xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) - yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) + filename = "sm_save_test" + fun = Sphere(ndim=2) + + sampling = LHS(xlimits=fun.xlimits, criterion="m") + xt = sampling(20) + yt = fun(xt) + + for i in range(2): + yd = fun(xt, kx=i) + yt = np.concatenate((yt, yd), axis=1) - sm = SGP() - sm.set_training_values(xt, yt) + sm = GEKPLS() + sm.set_training_values(xt, yt[:, 0]) + for i in range(2): + sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i) sm.train() - sm._save(filename) + sm.save(filename) self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.") file_size = os.path.getsize(filename) print(f"Taille du fichier : {file_size} octets") - sm2 = sm._load(filename) + sm2 = GEKPLS.load(filename) self.assertIsNotNone(sm2) - num = 100 - x = np.linspace(0.0, 4.0, num).reshape(-1, 1) + X = np.arange(fun.xlimits[0, 0], fun.xlimits[0, 1], 0.25) + Y = np.arange(fun.xlimits[1, 0], fun.xlimits[1, 1], 0.25) + X, Y = np.meshgrid(X, Y) + Z = np.zeros((X.shape[0], X.shape[1])) - y = sm2.predict_values(x) + for i in range(X.shape[0]): + for j in range(X.shape[1]): + Z[i, j] = sm.predict_values( + np.hstack((X[i, j], Y[i, j])).reshape((1, 2)) + ).item() - self.assertIsNotNone(y) - print("Prédictions avec le modèle chargé :", y) + self.assertIsNotNone(Z) + print("Prédictions avec le modèle chargé :", Z) os.remove(filename) + def test_save_krigs(self): + + krigs = [KRG, KPLS, KPLSK, MGP, SGP, QP, GENN, LS] + rng = np.random.RandomState(1) + N_inducing = 30 + + xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) + yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) + + for surrogate in krigs: + + filename = "sm_save_test" + + sm = surrogate() + sm.set_training_values(xt, yt) + + if surrogate == SGP: + sm.Z = 2 * rng.rand(N_inducing, 1) - 1 + sm.set_inducing_inputs(Z=sm.Z) + + sm.train() + + sm.save(filename) + self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.") + + file_size = os.path.getsize(filename) + print(f"Taille du fichier : {file_size} octets") + + sm2 = surrogate.load(filename) + self.assertIsNotNone(sm2) + + num = 100 + x = np.linspace(0.0, 4.0, num).reshape(-1, 1) + + y = sm2.predict_values(x) + + self.assertIsNotNone(y) + print("Prédictions avec le modèle chargé :", y) + + os.remove(filename) + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/smt/utils/persistance.py b/smt/utils/persistance.py new file mode 100644 index 000000000..f50d9d4e2 --- /dev/null +++ b/smt/utils/persistance.py @@ -0,0 +1,16 @@ +import pickle + + +def save(self, filename): + try: + with open(filename, 'wb') as file: + pickle.dump(self, file) + print("model saved") + except: + print("Couldn't save the model") + +@staticmethod +def load(filename): + with open(filename, "rb") as file: + sm = pickle.load(file) + return sm \ No newline at end of file From 360d6c78d704ca970b50738941b215ba45316a76 Mon Sep 17 00:00:00 2001 From: ANTOINE AVERLAND Date: Mon, 9 Dec 2024 14:21:29 +0100 Subject: [PATCH 4/6] better test's assertions --- smt/surrogate_models/genn.py | 2 +- smt/surrogate_models/krg_based.py | 8 ++--- smt/surrogate_models/ls.py | 5 ++- smt/surrogate_models/qp.py | 3 +- smt/tests/test_save_load.py | 59 +++++++++++++------------------ smt/utils/persistance.py | 13 +++---- 6 files changed, 36 insertions(+), 54 deletions(-) diff --git a/smt/surrogate_models/genn.py b/smt/surrogate_models/genn.py index 2e13b8d90..8c0320fd1 100644 --- a/smt/surrogate_models/genn.py +++ b/smt/surrogate_models/genn.py @@ -218,7 +218,7 @@ def save(self, filename): @staticmethod def load(filename): - return (persistance.load(filename)) + return persistance.load(filename) def _predict_values(self, x): return self.model.predict(x.T).T diff --git a/smt/surrogate_models/krg_based.py b/smt/surrogate_models/krg_based.py index a8bb891dc..1130cd917 100644 --- a/smt/surrogate_models/krg_based.py +++ b/smt/surrogate_models/krg_based.py @@ -284,7 +284,7 @@ def design_space(self) -> BaseDesignSpace: xt=xt, xlimits=ds_input, design_space=ds_input ) return self.options["design_space"] - + @design_space.setter def design_space(self, d_s_value): if not isinstance(d_s_value, BaseDesignSpace): @@ -1876,13 +1876,13 @@ def _predict_variances( # machine precision: force to zero! s2[s2 < 0.0] = 0.0 return s2 - + def save(self, filename): persistance.save(self, filename) @staticmethod def load(filename): - return (persistance.load(filename)) + return persistance.load(filename) def _predict_variance_derivatives(self, x, kx): """ @@ -2563,5 +2563,3 @@ def compute_n_param(design_space, cat_kernel, d, n_comp, mat_dim): if cat_kernel == MixIntKernelType.CONT_RELAX: n_param += int(n_values) return n_param - - diff --git a/smt/surrogate_models/ls.py b/smt/surrogate_models/ls.py index 6e58cbe86..99379d7d0 100644 --- a/smt/surrogate_models/ls.py +++ b/smt/surrogate_models/ls.py @@ -9,7 +9,6 @@ """ import numpy as np -import pickle from sklearn import linear_model from smt.surrogate_models.surrogate_model import SurrogateModel @@ -102,10 +101,10 @@ def _predict_derivatives(self, x, kx): n_eval, n_features_x = x.shape y = np.ones((n_eval, self.ny)) * self.mod.coef_[:, kx] return y - + def save(self, filename): persistance.save(self, filename) @staticmethod def load(filename): - return (persistance.load(filename)) + return persistance.load(filename) diff --git a/smt/surrogate_models/qp.py b/smt/surrogate_models/qp.py index d56846942..67335ea26 100644 --- a/smt/surrogate_models/qp.py +++ b/smt/surrogate_models/qp.py @@ -10,7 +10,6 @@ import numpy as np import scipy -import pickle from smt.surrogate_models.surrogate_model import SurrogateModel from smt.utils import persistance @@ -171,4 +170,4 @@ def save(self, filename): @staticmethod def load(filename): - return (persistance.load(filename)) \ No newline at end of file + return persistance.load(filename) diff --git a/smt/tests/test_save_load.py b/smt/tests/test_save_load.py index 1e10d9a9e..08d6803de 100644 --- a/smt/tests/test_save_load.py +++ b/smt/tests/test_save_load.py @@ -4,13 +4,11 @@ from smt.problems import Sphere from smt.sampling_methods import LHS -from smt.surrogate_models import KRG, LS, KPLS, GEKPLS, KPLSK, MGP, QP, SGP, GENN, QP +from smt.surrogate_models import KRG, LS, KPLS, GEKPLS, KPLSK, MGP, QP, SGP, GENN class TestSaveLoad(unittest.TestCase): - def test_save_load_GEKPLS(self): - filename = "sm_save_test" fun = Sphere(ndim=2) @@ -22,51 +20,51 @@ def test_save_load_GEKPLS(self): yd = fun(xt, kx=i) yt = np.concatenate((yt, yd), axis=1) - sm = GEKPLS() + X = np.arange(fun.xlimits[0, 0], fun.xlimits[0, 1], 0.25) + Y = np.arange(fun.xlimits[1, 0], fun.xlimits[1, 1], 0.25) + X, Y = np.meshgrid(X, Y) + Z1 = np.zeros((X.shape[0], X.shape[1])) + Z2 = np.zeros((X.shape[0], X.shape[1])) + + sm = GEKPLS(print_global=False) sm.set_training_values(xt, yt[:, 0]) for i in range(2): sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i) sm.train() - + for i in range(X.shape[0]): + for j in range(X.shape[1]): + Z1[i, j] = sm.predict_values( + np.hstack((X[i, j], Y[i, j])).reshape((1, 2)) + ).item() sm.save(filename) - self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.") - - file_size = os.path.getsize(filename) - print(f"Taille du fichier : {file_size} octets") sm2 = GEKPLS.load(filename) self.assertIsNotNone(sm2) - X = np.arange(fun.xlimits[0, 0], fun.xlimits[0, 1], 0.25) - Y = np.arange(fun.xlimits[1, 0], fun.xlimits[1, 1], 0.25) - X, Y = np.meshgrid(X, Y) - Z = np.zeros((X.shape[0], X.shape[1])) - for i in range(X.shape[0]): for j in range(X.shape[1]): - Z[i, j] = sm.predict_values( + Z2[i, j] = sm2.predict_values( np.hstack((X[i, j], Y[i, j])).reshape((1, 2)) ).item() - self.assertIsNotNone(Z) - print("Prédictions avec le modèle chargé :", Z) + np.testing.assert_allclose(Z1, Z2) os.remove(filename) - def test_save_krigs(self): - + def test_save_load_surrogates(self): krigs = [KRG, KPLS, KPLSK, MGP, SGP, QP, GENN, LS] rng = np.random.RandomState(1) N_inducing = 30 + num = 100 xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) + x = np.linspace(0.0, 4.0, num).reshape(-1, 1) for surrogate in krigs: - filename = "sm_save_test" - sm = surrogate() + sm = surrogate(print_global=False) sm.set_training_values(xt, yt) if surrogate == SGP: @@ -74,25 +72,16 @@ def test_save_krigs(self): sm.set_inducing_inputs(Z=sm.Z) sm.train() - + y1 = sm.predict_values(x) sm.save(filename) - self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.") - - file_size = os.path.getsize(filename) - print(f"Taille du fichier : {file_size} octets") sm2 = surrogate.load(filename) - self.assertIsNotNone(sm2) - - num = 100 - x = np.linspace(0.0, 4.0, num).reshape(-1, 1) + y2 = sm2.predict_values(x) - y = sm2.predict_values(x) - - self.assertIsNotNone(y) - print("Prédictions avec le modèle chargé :", y) + np.testing.assert_allclose(y1, y2) os.remove(filename) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/smt/utils/persistance.py b/smt/utils/persistance.py index f50d9d4e2..69d1133e5 100644 --- a/smt/utils/persistance.py +++ b/smt/utils/persistance.py @@ -2,15 +2,12 @@ def save(self, filename): - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") + with open(filename, "wb") as file: + pickle.dump(self, file) + -@staticmethod def load(filename): + sm = None with open(filename, "rb") as file: sm = pickle.load(file) - return sm \ No newline at end of file + return sm From c324dca05e386c34e808ca84991a52ce97e7c363 Mon Sep 17 00:00:00 2001 From: ANTOINE AVERLAND Date: Mon, 9 Dec 2024 14:29:06 +0100 Subject: [PATCH 5/6] ruff check modifications --- smt/surrogate_models/gpx.py | 2 -- smt/surrogate_models/idw.py | 20 -------------------- smt/surrogate_models/rbf.py | 20 -------------------- smt/surrogate_models/rmts.py | 20 -------------------- 4 files changed, 62 deletions(-) diff --git a/smt/surrogate_models/gpx.py b/smt/surrogate_models/gpx.py index 8a2993d9b..844a1b5e7 100644 --- a/smt/surrogate_models/gpx.py +++ b/smt/surrogate_models/gpx.py @@ -1,5 +1,4 @@ import numpy as np -import pickle from smt.surrogate_models.surrogate_model import SurrogateModel @@ -173,7 +172,6 @@ def save(self, filename): """ self._gpx.save(filename) - @staticmethod def load(filepath): """Load the model from a previously saved GPX model. diff --git a/smt/surrogate_models/idw.py b/smt/surrogate_models/idw.py index bb489aba7..aa08b3776 100644 --- a/smt/surrogate_models/idw.py +++ b/smt/surrogate_models/idw.py @@ -5,7 +5,6 @@ This package is distributed under New BSD license. """ -import pickle import numpy as np from smt.surrogate_models.idwclib import PyIDW @@ -133,22 +132,3 @@ def _predict_output_derivatives(self, x): dy_dyt = {None: jac} return dy_dyt - - def _save(self, filename=None): - if filename is None: - filename = self.filename - - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") - - def _load(self, filename): - if filename is None: - return ("file is not found") - else: - with open(filename, "rb") as file: - sm2 = pickle.load(file) - return sm2 diff --git a/smt/surrogate_models/rbf.py b/smt/surrogate_models/rbf.py index fbe5c213f..90912f282 100644 --- a/smt/surrogate_models/rbf.py +++ b/smt/surrogate_models/rbf.py @@ -5,7 +5,6 @@ """ import numpy as np -import pickle from scipy.sparse import csc_matrix from smt.surrogate_models.rbfclib import PyRBF @@ -215,22 +214,3 @@ def _predict_output_derivatives(self, x): dy_dyt = (dytl_dyt.T.dot(dstates_dytl.T).dot(dy_dstates.T)).T dy_dyt = np.einsum("ij,k->ijk", dy_dyt, np.ones(ny)) return {None: dy_dyt} - - def _save(self, filename=None): - if filename is None: - filename = self.filename - - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") - - def _load(self, filename): - if filename is None: - return ("file is not found") - else: - with open(filename, "rb") as file: - sm2 = pickle.load(file) - return sm2 diff --git a/smt/surrogate_models/rmts.py b/smt/surrogate_models/rmts.py index c0ebac2a8..a672e4e3f 100644 --- a/smt/surrogate_models/rmts.py +++ b/smt/surrogate_models/rmts.py @@ -7,7 +7,6 @@ from numbers import Integral import numpy as np -import pickle import scipy.sparse from smt.surrogate_models.surrogate_model import SurrogateModel @@ -585,22 +584,3 @@ def _predict_output_derivatives(self, x): dy_dyt[kx - 1] = np.einsum("ij,jkl->ikl", dy_dw, dw_dyt) return dy_dyt - - def _save(self, filename=None): - if filename is None: - filename = self.filename - - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") - - def _load(self, filename): - if filename is None: - return ("file is not found") - else: - with open(filename, "rb") as file: - sm2 = pickle.load(file) - return sm2 From fce8cce1211a6da022cd4065a2db4a11856ca74c Mon Sep 17 00:00:00 2001 From: ANTOINE AVERLAND Date: Mon, 9 Dec 2024 16:57:49 +0100 Subject: [PATCH 6/6] changes to validate the pull request --- smt/surrogate_models/genn.py | 6 +++--- smt/surrogate_models/gpx.py | 1 - smt/surrogate_models/krg_based.py | 12 +++--------- smt/surrogate_models/ls.py | 7 +++---- smt/surrogate_models/qp.py | 7 +++---- smt/surrogate_models/rbf.py | 1 - smt/surrogate_models/rmts.py | 1 - smt/{ => surrogate_models}/tests/test_save_load.py | 4 ++-- smt/utils/{persistance.py => persistence.py} | 0 9 files changed, 14 insertions(+), 25 deletions(-) rename smt/{ => surrogate_models}/tests/test_save_load.py (96%) rename smt/utils/{persistance.py => persistence.py} (100%) diff --git a/smt/surrogate_models/genn.py b/smt/surrogate_models/genn.py index 8c0320fd1..532c8b23f 100644 --- a/smt/surrogate_models/genn.py +++ b/smt/surrogate_models/genn.py @@ -12,7 +12,7 @@ from jenn.model import NeuralNet from smt.surrogate_models.surrogate_model import SurrogateModel -from smt.utils import persistance +from smt.utils import persistence # The missing type SMTrainingPoints = Dict[Union[int, None], Dict[int, List[np.ndarray]]] @@ -214,11 +214,11 @@ def _train(self): self.model.fit(X, Y, J, **kwargs) def save(self, filename): - persistance.save(self, filename) + persistence.save(self, filename) @staticmethod def load(filename): - return persistance.load(filename) + return persistence.load(filename) def _predict_values(self, x): return self.model.predict(x.T).T diff --git a/smt/surrogate_models/gpx.py b/smt/surrogate_models/gpx.py index 844a1b5e7..1b5253b45 100644 --- a/smt/surrogate_models/gpx.py +++ b/smt/surrogate_models/gpx.py @@ -29,7 +29,6 @@ class GPX(SurrogateModel): name = "GPX" - filename = "gpx_save" def _initialize(self): super(GPX, self)._initialize() diff --git a/smt/surrogate_models/krg_based.py b/smt/surrogate_models/krg_based.py index 1130cd917..aae65eee8 100644 --- a/smt/surrogate_models/krg_based.py +++ b/smt/surrogate_models/krg_based.py @@ -15,7 +15,7 @@ from smt.sampling_methods import LHS from smt.surrogate_models.surrogate_model import SurrogateModel -from smt.utils import persistance +from smt.utils import persistence from smt.kernels import ( SquarSinExp, @@ -285,12 +285,6 @@ def design_space(self) -> BaseDesignSpace: ) return self.options["design_space"] - @design_space.setter - def design_space(self, d_s_value): - if not isinstance(d_s_value, BaseDesignSpace): - raise TypeError("design_space must be of type BaseDesignSpace") - self.options["design_space"] = d_s_value - @property def is_continuous(self) -> bool: return self.design_space.is_all_cont @@ -1878,11 +1872,11 @@ def _predict_variances( return s2 def save(self, filename): - persistance.save(self, filename) + persistence.save(self, filename) @staticmethod def load(filename): - return persistance.load(filename) + return persistence.load(filename) def _predict_variance_derivatives(self, x, kx): """ diff --git a/smt/surrogate_models/ls.py b/smt/surrogate_models/ls.py index 99379d7d0..82c4b6bbb 100644 --- a/smt/surrogate_models/ls.py +++ b/smt/surrogate_models/ls.py @@ -12,7 +12,7 @@ from sklearn import linear_model from smt.surrogate_models.surrogate_model import SurrogateModel -from smt.utils import persistance +from smt.utils import persistence from smt.utils.caching import cached_operation @@ -24,7 +24,6 @@ class LS(SurrogateModel): """ name = "LS" - filename = "least_square" def _initialize(self): super(LS, self)._initialize() @@ -103,8 +102,8 @@ def _predict_derivatives(self, x, kx): return y def save(self, filename): - persistance.save(self, filename) + persistence.save(self, filename) @staticmethod def load(filename): - return persistance.load(filename) + return persistence.load(filename) diff --git a/smt/surrogate_models/qp.py b/smt/surrogate_models/qp.py index 67335ea26..badfb78ec 100644 --- a/smt/surrogate_models/qp.py +++ b/smt/surrogate_models/qp.py @@ -12,7 +12,7 @@ import scipy from smt.surrogate_models.surrogate_model import SurrogateModel -from smt.utils import persistance +from smt.utils import persistence from smt.utils.caching import cached_operation from smt.utils.misc import standardization @@ -23,7 +23,6 @@ class QP(SurrogateModel): """ name = "QP" - filename = "qp_save" def _initialize(self): super(QP, self)._initialize() @@ -166,8 +165,8 @@ def _predict_values(self, x): return y def save(self, filename): - persistance.save(self, filename) + persistence.save(self, filename) @staticmethod def load(filename): - return persistance.load(filename) + return persistence.load(filename) diff --git a/smt/surrogate_models/rbf.py b/smt/surrogate_models/rbf.py index 90912f282..024299901 100644 --- a/smt/surrogate_models/rbf.py +++ b/smt/surrogate_models/rbf.py @@ -19,7 +19,6 @@ class RBF(SurrogateModel): """ name = "RBF" - filename = "rbf_save" def _initialize(self): super(RBF, self)._initialize() diff --git a/smt/surrogate_models/rmts.py b/smt/surrogate_models/rmts.py index a672e4e3f..9dc903c44 100644 --- a/smt/surrogate_models/rmts.py +++ b/smt/surrogate_models/rmts.py @@ -21,7 +21,6 @@ class RMTS(SurrogateModel): """ name = "RMTS" - filename = "rmts_save" def _initialize(self): super(RMTS, self)._initialize() diff --git a/smt/tests/test_save_load.py b/smt/surrogate_models/tests/test_save_load.py similarity index 96% rename from smt/tests/test_save_load.py rename to smt/surrogate_models/tests/test_save_load.py index 08d6803de..81fe94ce9 100644 --- a/smt/tests/test_save_load.py +++ b/smt/surrogate_models/tests/test_save_load.py @@ -52,7 +52,7 @@ def test_save_load_GEKPLS(self): os.remove(filename) def test_save_load_surrogates(self): - krigs = [KRG, KPLS, KPLSK, MGP, SGP, QP, GENN, LS] + surrogates = [KRG, KPLS, KPLSK, MGP, SGP, QP, GENN, LS] rng = np.random.RandomState(1) N_inducing = 30 num = 100 @@ -61,7 +61,7 @@ def test_save_load_surrogates(self): yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) x = np.linspace(0.0, 4.0, num).reshape(-1, 1) - for surrogate in krigs: + for surrogate in surrogates: filename = "sm_save_test" sm = surrogate(print_global=False) diff --git a/smt/utils/persistance.py b/smt/utils/persistence.py similarity index 100% rename from smt/utils/persistance.py rename to smt/utils/persistence.py