From 360d6c78d704ca970b50738941b215ba45316a76 Mon Sep 17 00:00:00 2001 From: ANTOINE AVERLAND Date: Mon, 9 Dec 2024 14:21:29 +0100 Subject: [PATCH] better test's assertions --- smt/surrogate_models/genn.py | 2 +- smt/surrogate_models/krg_based.py | 8 ++--- smt/surrogate_models/ls.py | 5 ++- smt/surrogate_models/qp.py | 3 +- smt/tests/test_save_load.py | 59 +++++++++++++------------------ smt/utils/persistance.py | 13 +++---- 6 files changed, 36 insertions(+), 54 deletions(-) diff --git a/smt/surrogate_models/genn.py b/smt/surrogate_models/genn.py index 2e13b8d90..8c0320fd1 100644 --- a/smt/surrogate_models/genn.py +++ b/smt/surrogate_models/genn.py @@ -218,7 +218,7 @@ def save(self, filename): @staticmethod def load(filename): - return (persistance.load(filename)) + return persistance.load(filename) def _predict_values(self, x): return self.model.predict(x.T).T diff --git a/smt/surrogate_models/krg_based.py b/smt/surrogate_models/krg_based.py index a8bb891dc..1130cd917 100644 --- a/smt/surrogate_models/krg_based.py +++ b/smt/surrogate_models/krg_based.py @@ -284,7 +284,7 @@ def design_space(self) -> BaseDesignSpace: xt=xt, xlimits=ds_input, design_space=ds_input ) return self.options["design_space"] - + @design_space.setter def design_space(self, d_s_value): if not isinstance(d_s_value, BaseDesignSpace): @@ -1876,13 +1876,13 @@ def _predict_variances( # machine precision: force to zero! s2[s2 < 0.0] = 0.0 return s2 - + def save(self, filename): persistance.save(self, filename) @staticmethod def load(filename): - return (persistance.load(filename)) + return persistance.load(filename) def _predict_variance_derivatives(self, x, kx): """ @@ -2563,5 +2563,3 @@ def compute_n_param(design_space, cat_kernel, d, n_comp, mat_dim): if cat_kernel == MixIntKernelType.CONT_RELAX: n_param += int(n_values) return n_param - - diff --git a/smt/surrogate_models/ls.py b/smt/surrogate_models/ls.py index 6e58cbe86..99379d7d0 100644 --- a/smt/surrogate_models/ls.py +++ b/smt/surrogate_models/ls.py @@ -9,7 +9,6 @@ """ import numpy as np -import pickle from sklearn import linear_model from smt.surrogate_models.surrogate_model import SurrogateModel @@ -102,10 +101,10 @@ def _predict_derivatives(self, x, kx): n_eval, n_features_x = x.shape y = np.ones((n_eval, self.ny)) * self.mod.coef_[:, kx] return y - + def save(self, filename): persistance.save(self, filename) @staticmethod def load(filename): - return (persistance.load(filename)) + return persistance.load(filename) diff --git a/smt/surrogate_models/qp.py b/smt/surrogate_models/qp.py index d56846942..67335ea26 100644 --- a/smt/surrogate_models/qp.py +++ b/smt/surrogate_models/qp.py @@ -10,7 +10,6 @@ import numpy as np import scipy -import pickle from smt.surrogate_models.surrogate_model import SurrogateModel from smt.utils import persistance @@ -171,4 +170,4 @@ def save(self, filename): @staticmethod def load(filename): - return (persistance.load(filename)) \ No newline at end of file + return persistance.load(filename) diff --git a/smt/tests/test_save_load.py b/smt/tests/test_save_load.py index 1e10d9a9e..08d6803de 100644 --- a/smt/tests/test_save_load.py +++ b/smt/tests/test_save_load.py @@ -4,13 +4,11 @@ from smt.problems import Sphere from smt.sampling_methods import LHS -from smt.surrogate_models import KRG, LS, KPLS, GEKPLS, KPLSK, MGP, QP, SGP, GENN, QP +from smt.surrogate_models import KRG, LS, KPLS, GEKPLS, KPLSK, MGP, QP, SGP, GENN class TestSaveLoad(unittest.TestCase): - def test_save_load_GEKPLS(self): - filename = "sm_save_test" fun = Sphere(ndim=2) @@ -22,51 +20,51 @@ def test_save_load_GEKPLS(self): yd = fun(xt, kx=i) yt = np.concatenate((yt, yd), axis=1) - sm = GEKPLS() + X = np.arange(fun.xlimits[0, 0], fun.xlimits[0, 1], 0.25) + Y = np.arange(fun.xlimits[1, 0], fun.xlimits[1, 1], 0.25) + X, Y = np.meshgrid(X, Y) + Z1 = np.zeros((X.shape[0], X.shape[1])) + Z2 = np.zeros((X.shape[0], X.shape[1])) + + sm = GEKPLS(print_global=False) sm.set_training_values(xt, yt[:, 0]) for i in range(2): sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i) sm.train() - + for i in range(X.shape[0]): + for j in range(X.shape[1]): + Z1[i, j] = sm.predict_values( + np.hstack((X[i, j], Y[i, j])).reshape((1, 2)) + ).item() sm.save(filename) - self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.") - - file_size = os.path.getsize(filename) - print(f"Taille du fichier : {file_size} octets") sm2 = GEKPLS.load(filename) self.assertIsNotNone(sm2) - X = np.arange(fun.xlimits[0, 0], fun.xlimits[0, 1], 0.25) - Y = np.arange(fun.xlimits[1, 0], fun.xlimits[1, 1], 0.25) - X, Y = np.meshgrid(X, Y) - Z = np.zeros((X.shape[0], X.shape[1])) - for i in range(X.shape[0]): for j in range(X.shape[1]): - Z[i, j] = sm.predict_values( + Z2[i, j] = sm2.predict_values( np.hstack((X[i, j], Y[i, j])).reshape((1, 2)) ).item() - self.assertIsNotNone(Z) - print("Prédictions avec le modèle chargé :", Z) + np.testing.assert_allclose(Z1, Z2) os.remove(filename) - def test_save_krigs(self): - + def test_save_load_surrogates(self): krigs = [KRG, KPLS, KPLSK, MGP, SGP, QP, GENN, LS] rng = np.random.RandomState(1) N_inducing = 30 + num = 100 xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0]) + x = np.linspace(0.0, 4.0, num).reshape(-1, 1) for surrogate in krigs: - filename = "sm_save_test" - sm = surrogate() + sm = surrogate(print_global=False) sm.set_training_values(xt, yt) if surrogate == SGP: @@ -74,25 +72,16 @@ def test_save_krigs(self): sm.set_inducing_inputs(Z=sm.Z) sm.train() - + y1 = sm.predict_values(x) sm.save(filename) - self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.") - - file_size = os.path.getsize(filename) - print(f"Taille du fichier : {file_size} octets") sm2 = surrogate.load(filename) - self.assertIsNotNone(sm2) - - num = 100 - x = np.linspace(0.0, 4.0, num).reshape(-1, 1) + y2 = sm2.predict_values(x) - y = sm2.predict_values(x) - - self.assertIsNotNone(y) - print("Prédictions avec le modèle chargé :", y) + np.testing.assert_allclose(y1, y2) os.remove(filename) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/smt/utils/persistance.py b/smt/utils/persistance.py index f50d9d4e2..69d1133e5 100644 --- a/smt/utils/persistance.py +++ b/smt/utils/persistance.py @@ -2,15 +2,12 @@ def save(self, filename): - try: - with open(filename, 'wb') as file: - pickle.dump(self, file) - print("model saved") - except: - print("Couldn't save the model") + with open(filename, "wb") as file: + pickle.dump(self, file) + -@staticmethod def load(filename): + sm = None with open(filename, "rb") as file: sm = pickle.load(file) - return sm \ No newline at end of file + return sm