Skip to content

Commit

Permalink
better test's assertions
Browse files Browse the repository at this point in the history
  • Loading branch information
Antoine-Averland committed Dec 9, 2024
1 parent cbdd50f commit 404b28d
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 54 deletions.
2 changes: 1 addition & 1 deletion smt/surrogate_models/genn.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def save(self, filename):

@staticmethod
def load(filename):
return (persistance.load(filename))
return persistance.load(filename)

def _predict_values(self, x):
return self.model.predict(x.T).T
Expand Down
8 changes: 3 additions & 5 deletions smt/surrogate_models/krg_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def design_space(self) -> BaseDesignSpace:
xt=xt, xlimits=ds_input, design_space=ds_input
)
return self.options["design_space"]

@design_space.setter
def design_space(self, d_s_value):
if not isinstance(d_s_value, BaseDesignSpace):
Expand Down Expand Up @@ -1869,13 +1869,13 @@ def _predict_variances(
# machine precision: force to zero!
s2[s2 < 0.0] = 0.0
return s2

def save(self, filename):
persistance.save(self, filename)

@staticmethod
def load(filename):
return (persistance.load(filename))
return persistance.load(filename)

def _predict_variance_derivatives(self, x, kx):
"""
Expand Down Expand Up @@ -2556,5 +2556,3 @@ def compute_n_param(design_space, cat_kernel, d, n_comp, mat_dim):
if cat_kernel == MixIntKernelType.CONT_RELAX:
n_param += int(n_values)
return n_param


5 changes: 2 additions & 3 deletions smt/surrogate_models/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
"""

import numpy as np
import pickle
from sklearn import linear_model

from smt.surrogate_models.surrogate_model import SurrogateModel
Expand Down Expand Up @@ -102,10 +101,10 @@ def _predict_derivatives(self, x, kx):
n_eval, n_features_x = x.shape
y = np.ones((n_eval, self.ny)) * self.mod.coef_[:, kx]
return y

def save(self, filename):
persistance.save(self, filename)

@staticmethod
def load(filename):
return (persistance.load(filename))
return persistance.load(filename)
3 changes: 1 addition & 2 deletions smt/surrogate_models/qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import numpy as np
import scipy
import pickle

from smt.surrogate_models.surrogate_model import SurrogateModel
from smt.utils import persistance
Expand Down Expand Up @@ -171,4 +170,4 @@ def save(self, filename):

@staticmethod
def load(filename):
return (persistance.load(filename))
return persistance.load(filename)
59 changes: 24 additions & 35 deletions smt/tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

from smt.problems import Sphere
from smt.sampling_methods import LHS
from smt.surrogate_models import KRG, LS, KPLS, GEKPLS, KPLSK, MGP, QP, SGP, GENN, QP
from smt.surrogate_models import KRG, LS, KPLS, GEKPLS, KPLSK, MGP, QP, SGP, GENN


class TestSaveLoad(unittest.TestCase):

def test_save_load_GEKPLS(self):

filename = "sm_save_test"
fun = Sphere(ndim=2)

Expand All @@ -22,77 +20,68 @@ def test_save_load_GEKPLS(self):
yd = fun(xt, kx=i)
yt = np.concatenate((yt, yd), axis=1)

sm = GEKPLS()
X = np.arange(fun.xlimits[0, 0], fun.xlimits[0, 1], 0.25)
Y = np.arange(fun.xlimits[1, 0], fun.xlimits[1, 1], 0.25)
X, Y = np.meshgrid(X, Y)
Z1 = np.zeros((X.shape[0], X.shape[1]))
Z2 = np.zeros((X.shape[0], X.shape[1]))

sm = GEKPLS(print_global=False)
sm.set_training_values(xt, yt[:, 0])
for i in range(2):
sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i)
sm.train()

for i in range(X.shape[0]):
for j in range(X.shape[1]):
Z1[i, j] = sm.predict_values(
np.hstack((X[i, j], Y[i, j])).reshape((1, 2))
).item()
sm.save(filename)
self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.")

file_size = os.path.getsize(filename)
print(f"Taille du fichier : {file_size} octets")

sm2 = GEKPLS.load(filename)
self.assertIsNotNone(sm2)

X = np.arange(fun.xlimits[0, 0], fun.xlimits[0, 1], 0.25)
Y = np.arange(fun.xlimits[1, 0], fun.xlimits[1, 1], 0.25)
X, Y = np.meshgrid(X, Y)
Z = np.zeros((X.shape[0], X.shape[1]))

for i in range(X.shape[0]):
for j in range(X.shape[1]):
Z[i, j] = sm.predict_values(
Z2[i, j] = sm2.predict_values(
np.hstack((X[i, j], Y[i, j])).reshape((1, 2))
).item()

self.assertIsNotNone(Z)
print("Prédictions avec le modèle chargé :", Z)
np.testing.assert_allclose(Z1, Z2)

os.remove(filename)

def test_save_krigs(self):

def test_save_load_surrogates(self):
krigs = [KRG, KPLS, KPLSK, MGP, SGP, QP, GENN, LS]
rng = np.random.RandomState(1)
N_inducing = 30
num = 100

xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0])
x = np.linspace(0.0, 4.0, num).reshape(-1, 1)

for surrogate in krigs:

filename = "sm_save_test"

sm = surrogate()
sm = surrogate(print_global=False)
sm.set_training_values(xt, yt)

if surrogate == SGP:
sm.Z = 2 * rng.rand(N_inducing, 1) - 1
sm.set_inducing_inputs(Z=sm.Z)

sm.train()

y1 = sm.predict_values(x)
sm.save(filename)
self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.")

file_size = os.path.getsize(filename)
print(f"Taille du fichier : {file_size} octets")

sm2 = surrogate.load(filename)
self.assertIsNotNone(sm2)

num = 100
x = np.linspace(0.0, 4.0, num).reshape(-1, 1)
y2 = sm2.predict_values(x)

y = sm2.predict_values(x)

self.assertIsNotNone(y)
print("Prédictions avec le modèle chargé :", y)
np.testing.assert_allclose(y1, y2)

os.remove(filename)


if __name__ == "__main__":
unittest.main()
unittest.main()
13 changes: 5 additions & 8 deletions smt/utils/persistance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@


def save(self, filename):
try:
with open(filename, 'wb') as file:
pickle.dump(self, file)
print("model saved")
except:
print("Couldn't save the model")
with open(filename, "wb") as file:
pickle.dump(self, file)


@staticmethod
def load(filename):
sm = None
with open(filename, "rb") as file:
sm = pickle.load(file)
return sm
return sm

0 comments on commit 404b28d

Please sign in to comment.