Skip to content

Commit

Permalink
Merge branch 'dev-castano' of https://github.com/mcastanoUQ/smt into …
Browse files Browse the repository at this point in the history
…dev-castano
  • Loading branch information
mcastanoUQ committed Dec 14, 2024
2 parents 2bc9fd2 + f146c6b commit 036345b
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 6 deletions.
33 changes: 33 additions & 0 deletions smt/applications/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from smt.applications.application import SurrogateBasedApplication
from smt.surrogate_models.surrogate_model import SurrogateModel
from smt.utils import persistence

MOE_EXPERT_NAMES = [
"KRG",
Expand Down Expand Up @@ -172,6 +173,31 @@ def enabled_experts(self):
self._enabled_expert_types = self._get_enabled_expert_types()
return list(self._enabled_expert_types.keys())

def __getstate__(self):
state = self.__dict__.copy()
for expert in self._experts:
if expert == "IDW":
state["idwc"] = None
elif expert == "RMTC":
state["rmtsc"] = None
elif expert == "RMTB":
state["rmtsc"] = None
elif expert == "RBF":
state["rbfc"] = None
return state

def __setstate__(self, state):
self.__dict__.update(state)
for expert in self._experts:
if expert == "IDW":
expert._setup()
elif expert == "RMTC":
expert._setup()
elif expert == "RMTB":
expert._setup()
elif expert == "RBF":
expert._setup()

def set_training_values(self, xt, yt, name=None):
"""
Set training data (values).
Expand Down Expand Up @@ -760,3 +786,10 @@ def _proba_cluster(self, x, distribs=None):
)

return probs

def save(self, filename):
persistence.save(self, filename)

@staticmethod
def load(filename):
return persistence.load(filename)
91 changes: 91 additions & 0 deletions smt/applications/tests/test_save_load_moe_mfk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import os
import unittest
import numpy as np

from smt.applications.mfk import MFK, NestedLHS
from smt.applications.mfkpls import MFKPLS
from smt.applications.mfkplsk import MFKPLSK
from smt.applications.moe import MOE
from smt.sampling_methods import FullFactorial


class TestSaveLoad(unittest.TestCase):
def function_test_1d(self, x):
x = np.reshape(x, (-1,))
y = np.zeros(x.shape)
y[x < 0.4] = x[x < 0.4] ** 2
y[(x >= 0.4) & (x < 0.8)] = 3 * x[(x >= 0.4) & (x < 0.8)] + 1
y[x >= 0.8] = np.sin(10 * x[x >= 0.8])
return y.reshape((-1, 1))

def test_save_load_moe(self):
filename = "moe_save_test"
nt = 35
x = np.linspace(0, 1, 100)

sampling = FullFactorial(xlimits=np.array([[0, 1]]), clip=True)
np.random.seed(0)
xt = sampling(nt)
yt = self.function_test_1d(xt)

moe1 = MOE(n_clusters=1)
moe1.set_training_values(xt, yt)
moe1.train()
y_moe1 = moe1.predict_values(x)

moe1.save(filename)

moe2 = MOE.load(filename)
y_moe2 = moe2.predict_values(x)

np.testing.assert_allclose(y_moe1, y_moe2)

os.remove(filename)

def lf_function(self, x):
return 0.5 * ((x * 6 - 2) ** 2) * np.sin((x * 6 - 2) * 2) + (x - 0.5) * 10.0 - 5

def hf_function(self, x):
return ((x * 6 - 2) ** 2) * np.sin((x * 6 - 2) * 2)

def _setup_MFKs(self):
xlimits = np.array([[0.0, 1.0]])
xdoes = NestedLHS(nlevel=2, xlimits=xlimits, random_state=0)
xt_c, xt_e = xdoes(7)
yt_e = self.hf_function(xt_e)
yt_c = self.lf_function(xt_c)
x = np.linspace(0, 1, 101, endpoint=True).reshape(-1, 1)
return (xt_c, xt_e, yt_c, yt_e, x)

def test_save_load_mfk_mfkpls_mfkplsk(self):
filename = "MFKs_save_test"
MFKs = [MFK, MFKPLS, MFKPLSK]
xt_c, xt_e, yt_c, yt_e, x = self._setup_MFKs()
ncomp = 1
for mfk in MFKs:
if mfk == MFK:
application = MFK(theta0=xt_e.shape[1] * [1.0], corr="squar_exp")
elif mfk == MFKPLS:
application = MFKPLS(n_comp=ncomp, theta0=ncomp * [1.0])
else:
application = MFKPLSK(n_comp=ncomp, theta0=ncomp * [1.0])

application.set_training_values(xt_c, yt_c, name=0)
application.set_training_values(xt_e, yt_e)

application.train()
application.save(filename)

x = np.linspace(0, 1, 101, endpoint=True).reshape(-1, 1)
y1 = application.predict_values(x)

mfk2 = MFK.load(filename)
y2 = mfk2.predict_values(x)

np.testing.assert_allclose(y1, y2)

os.remove(filename)


if __name__ == "__main__":
unittest.main()
8 changes: 5 additions & 3 deletions smt/surrogate_models/tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest
import numpy as np


from smt.problems import Sphere
from smt.sampling_methods import LHS
from smt.surrogate_models import (
Expand All @@ -24,13 +25,14 @@
class TestSaveLoad(unittest.TestCase):
def test_save_load_GEKPLS(self):
filename = "sm_save_test"
fun = Sphere(ndim=2)
ndim = 2
fun = Sphere(ndim=ndim)

sampling = LHS(xlimits=fun.xlimits, criterion="m")
xt = sampling(20)
yt = fun(xt)

for i in range(2):
for i in range(ndim):
yd = fun(xt, kx=i)
yt = np.concatenate((yt, yd), axis=1)

Expand All @@ -42,7 +44,7 @@ def test_save_load_GEKPLS(self):

sm = GEKPLS(print_global=False)
sm.set_training_values(xt, yt[:, 0])
for i in range(2):
for i in range(ndim):
sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i)
sm.train()
for i in range(X.shape[0]):
Expand Down
12 changes: 9 additions & 3 deletions smt/utils/persistence.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import pickle
import zlib


def save(self, filename):
serialized_data = pickle.dumps(self, protocol=5)
compressed_data = zlib.compress(serialized_data)
with open(filename, "wb") as file:
pickle.dump(self, file)
file.write(compressed_data)


def load(filename):
sm = None
with open(filename, "rb") as file:
sm = pickle.load(file)
compressed_data = file.read()

serialized_data = zlib.decompress(compressed_data)
sm = pickle.loads(serialized_data)

return sm

0 comments on commit 036345b

Please sign in to comment.