Skip to content

Commit

Permalink
final implementation of an save/load option for the surrogates with p…
Browse files Browse the repository at this point in the history
…ickle
  • Loading branch information
Antoine-Averland committed Dec 9, 2024
1 parent 9614c8f commit 2936d69
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 151 deletions.
8 changes: 8 additions & 0 deletions smt/surrogate_models/genn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from jenn.model import NeuralNet

from smt.surrogate_models.surrogate_model import SurrogateModel
from smt.utils import persistance

# The missing type
SMTrainingPoints = Dict[Union[int, None], Dict[int, List[np.ndarray]]]
Expand Down Expand Up @@ -212,6 +213,13 @@ def _train(self):
)
self.model.fit(X, Y, J, **kwargs)

def save(self, filename):
persistance.save(self, filename)

@staticmethod
def load(filename):
return (persistance.load(filename))

def _predict_values(self, x):
return self.model.predict(x.T).T

Expand Down
23 changes: 2 additions & 21 deletions smt/surrogate_models/gpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,33 +164,14 @@ def predict_variance_gradients(self, x):
Returns all variance gradients at the given x points as a [n, nx] matrix"""
return self._gpx.predict_var_gradients(x)

def _save(self, filename):
def save(self, filename):
"""Save the trained model in the given filepath
Arguments
---------
filename (string): path to the json file
"""
if filename is None:
filename = self.filename

try:
with open(filename, 'wb') as file:
pickle.dump(self, file)
print("model saved")
except:
print("Couldn't save the model")
# self._gpx.save(filepath)



# def _load(self, filename):
# if filename is None:
# return ("file is not found")
# else:
# with open(filename, "rb") as file:
# sm2 = pickle.load(file)
# return sm2
self._gpx.save(filename)


@staticmethod
Expand Down
26 changes: 6 additions & 20 deletions smt/surrogate_models/krg_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
import warnings
from copy import deepcopy
from enum import Enum
from joblib import dump, load

import numpy as np
import pickle
from scipy import linalg, optimize
from scipy.stats import multivariate_normal as m_norm

from smt.sampling_methods import LHS
from smt.surrogate_models.surrogate_model import SurrogateModel
from smt.utils import persistance

from smt.kernels import (
SquarSinExp,
Expand Down Expand Up @@ -78,7 +77,6 @@ class KrgBased(SurrogateModel):
"act_exp": ActExp,
}
name = "KrigingBased"
filename = "kriging_model"

def _initialize(self):
super(KrgBased, self)._initialize()
Expand Down Expand Up @@ -1879,24 +1877,12 @@ def _predict_variances(
s2[s2 < 0.0] = 0.0
return s2

def _save(self, filename=None):
if filename is None:
filename = self.filename
def save(self, filename):
persistance.save(self, filename)

try:
with open(filename, 'wb') as file:
pickle.dump(self, file)
print("model saved")
except:
print("Couldn't save the model")

def _load(self, filename):
if filename is None:
return ("file is not found")
else:
with open(filename, "rb") as file:
sm2 = pickle.load(file)
return sm2
@staticmethod
def load(filename):
return (persistance.load(filename))

def _predict_variance_derivatives(self, x, kx):
"""
Expand Down
25 changes: 7 additions & 18 deletions smt/surrogate_models/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sklearn import linear_model

from smt.surrogate_models.surrogate_model import SurrogateModel
from smt.utils import persistance
from smt.utils.caching import cached_operation


Expand Down Expand Up @@ -102,21 +103,9 @@ def _predict_derivatives(self, x, kx):
y = np.ones((n_eval, self.ny)) * self.mod.coef_[:, kx]
return y

def _save(self, filename=None):
if filename is None:
filename = self.filename

try:
with open(filename, 'wb') as file:
pickle.dump(self, file)
print("model saved")
except:
print("Couldn't save the model")

def _load(self, filename):
if filename is None:
return ("file is not found")
else:
with open(filename, "rb") as file:
sm2 = pickle.load(file)
return sm2
def save(self, filename):
persistance.save(self, filename)

@staticmethod
def load(filename):
return (persistance.load(filename))
25 changes: 7 additions & 18 deletions smt/surrogate_models/qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pickle

from smt.surrogate_models.surrogate_model import SurrogateModel
from smt.utils import persistance
from smt.utils.caching import cached_operation
from smt.utils.misc import standardization

Expand Down Expand Up @@ -165,21 +166,9 @@ def _predict_values(self, x):
y = y.reshape((x.shape[0], self.ny))
return y

def _save(self, filename=None):
if filename is None:
filename = self.filename

try:
with open(filename, 'wb') as file:
pickle.dump(self, file)
print("model saved")
except:
print("Couldn't save the model")

def _load(self, filename):
if filename is None:
return ("file is not found")
else:
with open(filename, "rb") as file:
sm2 = pickle.load(file)
return sm2
def save(self, filename):
persistance.save(self, filename)

@staticmethod
def load(filename):
return (persistance.load(filename))
4 changes: 2 additions & 2 deletions smt/surrogate_models/surrogate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,8 +571,8 @@ def _check_xdim(self, x):
This method is used as a guard in preamble of predict methods"""
check_nx(self.nx, x)

def _save(self, filename):
def save(self, filename):
"""
Implemented by surrogate models to save the surrogate object in a file
"""
pass
raise NotImplementedError("save() has to be implemented by the given surrogate")
46 changes: 0 additions & 46 deletions smt/tests/test_identification_attibuts_krig.py

This file was deleted.

99 changes: 73 additions & 26 deletions smt/tests/test_save_load.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,98 @@
import os
import unittest
import numpy as np
from smt.surrogate_models import KRG
# from smt.surrogate_models import IDW //TODO comprendre pourquoi l"import de IDW génère une erreur
from smt.surrogate_models import LS
from smt.surrogate_models import KPLS
from smt.surrogate_models import GEKPLS
from smt.surrogate_models import GPX
from smt.surrogate_models import KPLSK
from smt.surrogate_models import MGP
from smt.surrogate_models import QP
# from smt.surrogate_models import RBF
# from smt.surrogate_models import RMTS
from smt.surrogate_models import SGP

from smt.problems import Sphere
from smt.sampling_methods import LHS
from smt.surrogate_models import KRG, LS, KPLS, GEKPLS, KPLSK, MGP, QP, SGP, GENN, QP


class TestSaveLoad(unittest.TestCase):

def test_save_load(self):

filename = "kriging_save_test"
def test_save_load_GEKPLS(self):

xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0])
filename = "sm_save_test"
fun = Sphere(ndim=2)

sampling = LHS(xlimits=fun.xlimits, criterion="m")
xt = sampling(20)
yt = fun(xt)

for i in range(2):
yd = fun(xt, kx=i)
yt = np.concatenate((yt, yd), axis=1)

sm = SGP()
sm.set_training_values(xt, yt)
sm = GEKPLS()
sm.set_training_values(xt, yt[:, 0])
for i in range(2):
sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i)
sm.train()

sm._save(filename)
sm.save(filename)
self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.")

file_size = os.path.getsize(filename)
print(f"Taille du fichier : {file_size} octets")

sm2 = sm._load(filename)
sm2 = GEKPLS.load(filename)
self.assertIsNotNone(sm2)

num = 100
x = np.linspace(0.0, 4.0, num).reshape(-1, 1)
X = np.arange(fun.xlimits[0, 0], fun.xlimits[0, 1], 0.25)
Y = np.arange(fun.xlimits[1, 0], fun.xlimits[1, 1], 0.25)
X, Y = np.meshgrid(X, Y)
Z = np.zeros((X.shape[0], X.shape[1]))

y = sm2.predict_values(x)
for i in range(X.shape[0]):
for j in range(X.shape[1]):
Z[i, j] = sm.predict_values(
np.hstack((X[i, j], Y[i, j])).reshape((1, 2))
).item()

self.assertIsNotNone(y)
print("Prédictions avec le modèle chargé :", y)
self.assertIsNotNone(Z)
print("Prédictions avec le modèle chargé :", Z)

os.remove(filename)

def test_save_krigs(self):

krigs = [KRG, KPLS, KPLSK, MGP, SGP, QP, GENN, LS]
rng = np.random.RandomState(1)
N_inducing = 30

xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0])

for surrogate in krigs:

filename = "sm_save_test"

sm = surrogate()
sm.set_training_values(xt, yt)

if surrogate == SGP:
sm.Z = 2 * rng.rand(N_inducing, 1) - 1
sm.set_inducing_inputs(Z=sm.Z)

sm.train()

sm.save(filename)
self.assertTrue(os.path.exists(filename), f"Le fichier {filename} n'a pas été créé.")

file_size = os.path.getsize(filename)
print(f"Taille du fichier : {file_size} octets")

sm2 = surrogate.load(filename)
self.assertIsNotNone(sm2)

num = 100
x = np.linspace(0.0, 4.0, num).reshape(-1, 1)

y = sm2.predict_values(x)

self.assertIsNotNone(y)
print("Prédictions avec le modèle chargé :", y)

os.remove(filename)

if __name__ == "__main__":
unittest.main()
16 changes: 16 additions & 0 deletions smt/utils/persistance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pickle


def save(self, filename):
try:
with open(filename, 'wb') as file:
pickle.dump(self, file)
print("model saved")
except:
print("Couldn't save the model")

@staticmethod
def load(filename):
with open(filename, "rb") as file:
sm = pickle.load(file)
return sm

0 comments on commit 2936d69

Please sign in to comment.