Skip to content

Commit

Permalink
Merge branch 'master' into dev-castano
Browse files Browse the repository at this point in the history
  • Loading branch information
mcastanoUQ authored Dec 14, 2024
2 parents 27e50b8 + 5722846 commit f146c6b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
7 changes: 4 additions & 3 deletions smt/surrogate_models/tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
class TestSaveLoad(unittest.TestCase):
def test_save_load_GEKPLS(self):
filename = "sm_save_test"
fun = Sphere(ndim=2)
ndim = 2
fun = Sphere(ndim=ndim)

sampling = LHS(xlimits=fun.xlimits, criterion="m")
xt = sampling(20)
yt = fun(xt)

for i in range(2):
for i in range(ndim):
yd = fun(xt, kx=i)
yt = np.concatenate((yt, yd), axis=1)

Expand All @@ -43,7 +44,7 @@ def test_save_load_GEKPLS(self):

sm = GEKPLS(print_global=False)
sm.set_training_values(xt, yt[:, 0])
for i in range(2):
for i in range(ndim):
sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i)
sm.train()
for i in range(X.shape[0]):
Expand Down
12 changes: 9 additions & 3 deletions smt/utils/persistence.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import pickle
import zlib


def save(self, filename):
serialized_data = pickle.dumps(self, protocol=5)
compressed_data = zlib.compress(serialized_data)
with open(filename, "wb") as file:
pickle.dump(self, file)
file.write(compressed_data)


def load(filename):
sm = None
with open(filename, "rb") as file:
sm = pickle.load(file)
compressed_data = file.read()

serialized_data = zlib.decompress(compressed_data)
sm = pickle.loads(serialized_data)

return sm

0 comments on commit f146c6b

Please sign in to comment.