Skip to content

Commit

Permalink
New save or load functions for the surrogate models using pickle (#689)
Browse files Browse the repository at this point in the history
* new save and load functions for the surrogate models

* new save functions for the surrogates and identification of the necessary attributes for predict_values()

* final implementation of an save/load option for the surrogates with pickle

* better test's assertions

* ruff check modifications

* changes to validate the pull request
  • Loading branch information
Antoine-Averland authored Dec 10, 2024
1 parent 5fa7fa9 commit 5bfec9a
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 2 deletions.
8 changes: 8 additions & 0 deletions smt/surrogate_models/genn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from jenn.model import NeuralNet

from smt.surrogate_models.surrogate_model import SurrogateModel
from smt.utils import persistence

# The missing type
SMTrainingPoints = Dict[Union[int, None], Dict[int, List[np.ndarray]]]
Expand Down Expand Up @@ -212,6 +213,13 @@ def _train(self):
)
self.model.fit(X, Y, J, **kwargs)

def save(self, filename):
persistence.save(self, filename)

@staticmethod
def load(filename):
return persistence.load(filename)

def _predict_values(self, x):
return self.model.predict(x.T).T

Expand Down
4 changes: 2 additions & 2 deletions smt/surrogate_models/gpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,14 @@ def predict_variance_gradients(self, x):
Returns all variance gradients at the given x points as a [n, nx] matrix"""
return self._gpx.predict_var_gradients(x)

def save(self, filepath):
def save(self, filename):
"""Save the trained model in the given filepath
Arguments
---------
filename (string): path to the json file
"""
self._gpx.save(filepath)
self._gpx.save(filename)

@staticmethod
def load(filepath):
Expand Down
8 changes: 8 additions & 0 deletions smt/surrogate_models/krg_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from smt.sampling_methods import LHS
from smt.surrogate_models.surrogate_model import SurrogateModel
from smt.utils import persistence

from smt.kernels import (
SquarSinExp,
Expand Down Expand Up @@ -1870,6 +1871,13 @@ def _predict_variances(
s2[s2 < 0.0] = 0.0
return s2

def save(self, filename):
persistence.save(self, filename)

@staticmethod
def load(filename):
return persistence.load(filename)

def _predict_variance_derivatives(self, x, kx):
"""
Provide the derivatives of the variance of the model at a set of points
Expand Down
8 changes: 8 additions & 0 deletions smt/surrogate_models/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sklearn import linear_model

from smt.surrogate_models.surrogate_model import SurrogateModel
from smt.utils import persistence
from smt.utils.caching import cached_operation


Expand Down Expand Up @@ -99,3 +100,10 @@ def _predict_derivatives(self, x, kx):
n_eval, n_features_x = x.shape
y = np.ones((n_eval, self.ny)) * self.mod.coef_[:, kx]
return y

def save(self, filename):
persistence.save(self, filename)

@staticmethod
def load(filename):
return persistence.load(filename)
8 changes: 8 additions & 0 deletions smt/surrogate_models/qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import scipy

from smt.surrogate_models.surrogate_model import SurrogateModel
from smt.utils import persistence
from smt.utils.caching import cached_operation
from smt.utils.misc import standardization

Expand Down Expand Up @@ -162,3 +163,10 @@ def _predict_values(self, x):
y = (self.y_mean + self.y_std * y_).ravel()
y = y.reshape((x.shape[0], self.ny))
return y

def save(self, filename):
persistence.save(self, filename)

@staticmethod
def load(filename):
return persistence.load(filename)
6 changes: 6 additions & 0 deletions smt/surrogate_models/surrogate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,3 +570,9 @@ def _check_xdim(self, x):
"""Raise a ValueError if x dimension is not consistent with surrogate model training data dimension.
This method is used as a guard in preamble of predict methods"""
check_nx(self.nx, x)

def save(self, filename):
"""
Implemented by surrogate models to save the surrogate object in a file
"""
raise NotImplementedError("save() has to be implemented by the given surrogate")
87 changes: 87 additions & 0 deletions smt/surrogate_models/tests/test_save_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os
import unittest
import numpy as np

from smt.problems import Sphere
from smt.sampling_methods import LHS
from smt.surrogate_models import KRG, LS, KPLS, GEKPLS, KPLSK, MGP, QP, SGP, GENN


class TestSaveLoad(unittest.TestCase):
def test_save_load_GEKPLS(self):
filename = "sm_save_test"
fun = Sphere(ndim=2)

sampling = LHS(xlimits=fun.xlimits, criterion="m")
xt = sampling(20)
yt = fun(xt)

for i in range(2):
yd = fun(xt, kx=i)
yt = np.concatenate((yt, yd), axis=1)

X = np.arange(fun.xlimits[0, 0], fun.xlimits[0, 1], 0.25)
Y = np.arange(fun.xlimits[1, 0], fun.xlimits[1, 1], 0.25)
X, Y = np.meshgrid(X, Y)
Z1 = np.zeros((X.shape[0], X.shape[1]))
Z2 = np.zeros((X.shape[0], X.shape[1]))

sm = GEKPLS(print_global=False)
sm.set_training_values(xt, yt[:, 0])
for i in range(2):
sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i)
sm.train()
for i in range(X.shape[0]):
for j in range(X.shape[1]):
Z1[i, j] = sm.predict_values(
np.hstack((X[i, j], Y[i, j])).reshape((1, 2))
).item()
sm.save(filename)

sm2 = GEKPLS.load(filename)
self.assertIsNotNone(sm2)

for i in range(X.shape[0]):
for j in range(X.shape[1]):
Z2[i, j] = sm2.predict_values(
np.hstack((X[i, j], Y[i, j])).reshape((1, 2))
).item()

np.testing.assert_allclose(Z1, Z2)

os.remove(filename)

def test_save_load_surrogates(self):
surrogates = [KRG, KPLS, KPLSK, MGP, SGP, QP, GENN, LS]
rng = np.random.RandomState(1)
N_inducing = 30
num = 100

xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0])
x = np.linspace(0.0, 4.0, num).reshape(-1, 1)

for surrogate in surrogates:
filename = "sm_save_test"

sm = surrogate(print_global=False)
sm.set_training_values(xt, yt)

if surrogate == SGP:
sm.Z = 2 * rng.rand(N_inducing, 1) - 1
sm.set_inducing_inputs(Z=sm.Z)

sm.train()
y1 = sm.predict_values(x)
sm.save(filename)

sm2 = surrogate.load(filename)
y2 = sm2.predict_values(x)

np.testing.assert_allclose(y1, y2)

os.remove(filename)


if __name__ == "__main__":
unittest.main()
13 changes: 13 additions & 0 deletions smt/utils/persistence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pickle


def save(self, filename):
with open(filename, "wb") as file:
pickle.dump(self, file)


def load(filename):
sm = None
with open(filename, "rb") as file:
sm = pickle.load(file)
return sm

0 comments on commit 5bfec9a

Please sign in to comment.