From 5722846cbdea68230a636a99e876bc0b397a181c Mon Sep 17 00:00:00 2001 From: Antoine-Averland Date: Thu, 12 Dec 2024 15:21:02 +0100 Subject: [PATCH] Add compression of data at the save to reduce the size of the file (#694) --- smt/surrogate_models/tests/test_save_load.py | 7 ++++--- smt/utils/persistence.py | 12 +++++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/smt/surrogate_models/tests/test_save_load.py b/smt/surrogate_models/tests/test_save_load.py index 00419ce88..3785912ef 100644 --- a/smt/surrogate_models/tests/test_save_load.py +++ b/smt/surrogate_models/tests/test_save_load.py @@ -25,13 +25,14 @@ class TestSaveLoad(unittest.TestCase): def test_save_load_GEKPLS(self): filename = "sm_save_test" - fun = Sphere(ndim=2) + ndim = 2 + fun = Sphere(ndim=ndim) sampling = LHS(xlimits=fun.xlimits, criterion="m") xt = sampling(20) yt = fun(xt) - for i in range(2): + for i in range(ndim): yd = fun(xt, kx=i) yt = np.concatenate((yt, yd), axis=1) @@ -43,7 +44,7 @@ def test_save_load_GEKPLS(self): sm = GEKPLS(print_global=False) sm.set_training_values(xt, yt[:, 0]) - for i in range(2): + for i in range(ndim): sm.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i) sm.train() for i in range(X.shape[0]): diff --git a/smt/utils/persistence.py b/smt/utils/persistence.py index 69d1133e5..78ef7c9af 100644 --- a/smt/utils/persistence.py +++ b/smt/utils/persistence.py @@ -1,13 +1,19 @@ import pickle +import zlib def save(self, filename): + serialized_data = pickle.dumps(self, protocol=5) + compressed_data = zlib.compress(serialized_data) with open(filename, "wb") as file: - pickle.dump(self, file) + file.write(compressed_data) def load(filename): - sm = None with open(filename, "rb") as file: - sm = pickle.load(file) + compressed_data = file.read() + + serialized_data = zlib.decompress(compressed_data) + sm = pickle.loads(serialized_data) + return sm