From 161a22c4b225e14cea8d745f9ac293e1be364cfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Lafage?= Date: Wed, 10 Apr 2024 11:02:43 +0200 Subject: [PATCH] Add derivatives prediction methods to GPX (#543) * Implement derivatives predictions in GPX * Test GPX derivatives predictions * Require egobox 0.17 * Add predict gradients in GPX API * Skip test if GPX not available * Relax test tolerance to pass on macosx * Use pytest -n auto * Do not run test in parallel and add verbosity * Pin pytest to 8.0.x (try to fix github actions hanging) * Do not run slow test in parallel * Upgrade to egobox 0.18 * Avoid warning about GPX until GPX is actually used * Update tests when GPX not available * Fix error message --- .github/workflows/tests.yml | 2 +- .github/workflows/tests_coverage.yml | 2 +- .github/workflows/tests_minimal.yml | 4 +- doc/requirements.txt | 2 +- requirements.txt | 14 +++--- setup.py | 2 +- smt/surrogate_models/gpx.py | 38 +++++++++++++--- smt/surrogate_models/tests/test_gpx.py | 61 +++++++++++++++++++++----- smt/tests/test_all.py | 3 +- 9 files changed, 97 insertions(+), 31 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a44926597..dce35fc5c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -37,6 +37,6 @@ jobs: pip install -e . 
- name: Test with pytest - run: pytest --durations=0 -n 4 smt + run: pytest -v --durations=0 smt diff --git a/.github/workflows/tests_coverage.yml b/.github/workflows/tests_coverage.yml index 6f2ee3116..94b13f875 100644 --- a/.github/workflows/tests_coverage.yml +++ b/.github/workflows/tests_coverage.yml @@ -38,7 +38,7 @@ jobs: - name: Test with pytest and coverage run: | - RUN_SLOW_TESTS=1 pytest --cov=smt -n 4 + RUN_SLOW_TESTS=1 pytest --cov=smt - name: Coveralls uses: AndreMiras/coveralls-python-action@develop diff --git a/.github/workflows/tests_minimal.yml b/.github/workflows/tests_minimal.yml index e3d37742e..b1849bcb6 100644 --- a/.github/workflows/tests_minimal.yml +++ b/.github/workflows/tests_minimal.yml @@ -24,7 +24,7 @@ jobs: - name: Test with pytest run: | - pip install pytest pytest-xdist - pytest -n 4 smt + pip install pytest==8.0.1 pytest-xdist + pytest -n auto smt diff --git a/doc/requirements.txt b/doc/requirements.txt index 48bcc78c8..575e80d2b 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -7,5 +7,5 @@ numpydoc matplotlib ConfigSpace ~= 0.6.1 jenn >= 1.0.2, <2.0 -egobox ~= 0.16.0 +egobox ~= 0.18.0 git+https://github.com/hwangjt/sphinx_auto_embed.git # for doc generation \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 29de020d3..a7c37c751 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,12 +3,12 @@ numpy scipy scikit-learn pyDOE3 -numba # JIT compiler -matplotlib # used in examples and tests -pytest # tests runner -pytest-xdist # allows running parallel testing with pytest -n -pytest-cov # allows to get coverage report -ruff # format and lint code +numba # JIT compiler +matplotlib # used in examples and tests +pytest ~= 8.0.1 # tests runner +pytest-xdist # allows running parallel testing with pytest -n +pytest-cov # allows to get coverage report +ruff # format and lint code ConfigSpace ~= 0.6.1 jenn >= 1.0.2, <2.0 -egobox ~= 0.16.0 +egobox ~= 0.18.0 diff --git a/setup.py b/setup.py 
index 516dae916..83548ec3c 100644 --- a/setup.py +++ b/setup.py @@ -120,7 +120,7 @@ "cs": [ # pip install smt[cs] "ConfigSpace~=0.6.1", ], - "gpx": ["egobox~=0.16"], # pip install smt[gpx] + "gpx": ["egobox~=0.18"], # pip install smt[gpx] }, python_requires=">=3.8", zip_safe=False, diff --git a/smt/surrogate_models/gpx.py b/smt/surrogate_models/gpx.py index 29aa012c0..5b662f853 100644 --- a/smt/surrogate_models/gpx.py +++ b/smt/surrogate_models/gpx.py @@ -1,5 +1,3 @@ -import warnings - import numpy as np from smt.surrogate_models.surrogate_model import SurrogateModel @@ -22,7 +20,6 @@ } except ImportError: GPX_AVAILABLE = False - warnings.warn("To use GPX you have to install dependencies: pip install smt['gpx']") class GPX(SurrogateModel): @@ -30,6 +27,12 @@ class GPX(SurrogateModel): def _initialize(self): super(GPX, self)._initialize() + + if not GPX_AVAILABLE: + raise RuntimeError( + 'GPX not available. Please install GPX dependencies with: pip install smt["gpx"]' + ) + declare = self.options.declare declare( @@ -80,7 +83,9 @@ def _initialize(self): ) supports = self.supports + supports["derivatives"] = True supports["variances"] = True + supports["variance_derivatives"] = True self._gpx = None @@ -101,8 +106,27 @@ def _train(self): self._gpx = egx.Gpx.builder(**config).fit(xt, yt) - def _predict_values(self, xt): - return self._gpx.predict_values(xt) + def _predict_values(self, x): + return self._gpx.predict(x) + + def _predict_variances(self, x): + return self._gpx.predict_var(x) + + def _predict_derivatives(self, x, kx): + return self._gpx.predict_gradients(x)[:, kx : kx + 1] + + def _predict_variance_derivatives(self, x, kx): + return self._gpx.predict_var_gradients(x)[:, kx : kx + 1] + + def predict_gradients(self, x): + """Predict derivatives with respect to all x components (i.e. the gradient) + at n points given as [n, nx] matrix where nx is the dimension of x. 
+ Returns all gradients at the given x points as a [n, nx] matrix + """ + return self._gpx.predict_gradients(x) - def _predict_variances(self, xt): - return self._gpx.predict_variances(xt) + def predict_variance_gradients(self, x): + """Predict variance derivatives with respect to all x components (i.e. the variance gradient) + at n points given as [n, nx] matrix where nx is the dimension of x. + Returns all variance gradients at the given x points as a [n, nx] matrix""" + return self._gpx.predict_var_gradients(x) diff --git a/smt/surrogate_models/tests/test_gpx.py b/smt/surrogate_models/tests/test_gpx.py index 8e449cb5a..641af2008 100644 --- a/smt/surrogate_models/tests/test_gpx.py +++ b/smt/surrogate_models/tests/test_gpx.py @@ -4,7 +4,7 @@ from smt.problems import Sphere from smt.sampling_methods import LHS -from smt.surrogate_models import GPX +from smt.surrogate_models import GPX, KRG from smt.surrogate_models.gpx import GPX_AVAILABLE @@ -12,7 +12,7 @@ class TestGPX(unittest.TestCase): @unittest.skipIf(not GPX_AVAILABLE, "GPX not available") def test_gpx(self): ndim = 2 - num = 50 + num = 20 problem = Sphere(ndim=ndim) xlimits = problem.xlimits sampling = LHS(xlimits=xlimits, criterion="ese") @@ -20,19 +20,60 @@ def test_gpx(self): xt = sampling(num) yt = problem(xt) - sm = GPX() - sm.set_training_values(xt, yt) - sm.train() + gpx = GPX(print_global=False, seed=42) + gpx.set_training_values(xt, yt) + gpx.train() xe = sampling(10) ye = problem(xe) - ytest = sm.predict_values(xe) - e_error = np.linalg.norm(ytest - ye) / np.linalg.norm(ye) - self.assertLessEqual(e_error, 2e-2) + # Prediction should be pretty good + gpx_y = gpx.predict_values(xe) + e_error = np.linalg.norm(gpx_y - ye) / np.linalg.norm(ye) + self.assertLessEqual(e_error, 1e-3) - vars = sm.predict_variances(xt) - self.assertLessEqual(np.linalg.norm(vars), 1e-6) + gpx_var = gpx.predict_variances(xe) + self.assertLessEqual(np.linalg.norm(gpx_var), 1e-3) + + @unittest.skipIf(not GPX_AVAILABLE, "GPX not 
available") + def test_gpx_vs_krg(self): + ndim = 3 + num = 30 + problem = Sphere(ndim=ndim) + xlimits = problem.xlimits + sampling = LHS(xlimits=xlimits, criterion="ese", random_state=42) + + xt = sampling(num) + yt = problem(xt) + + gpx = GPX(print_global=False, seed=42) + gpx.set_training_values(xt, yt) + gpx.train() + + xe = sampling(10) + + gpx_y = gpx.predict_values(xe) + gpx_var = gpx.predict_variances(xe) + + # Compare against KRG + krg = KRG(print_global=False) + krg.set_training_values(xt, yt) + krg.train() + + krg_y = krg.predict_values(xe) + np.testing.assert_allclose(gpx_y, krg_y, atol=1e-2) + + krg_var = krg.predict_variances(xe) + np.testing.assert_allclose(gpx_var, krg_var, atol=1e-2) + + for kx in range(ndim): + dy = gpx.predict_derivatives(xe, kx) + krg_dy = krg.predict_derivatives(xe, kx) + np.testing.assert_allclose(dy, krg_dy, atol=1e-3) + + dvar = gpx.predict_variance_derivatives(xe, kx) + krg_dvar = krg.predict_variance_derivatives(xe, kx) + np.testing.assert_allclose(dvar, krg_dvar, atol=1e-3) if __name__ == "__main__": diff --git a/smt/tests/test_all.py b/smt/tests/test_all.py index 6f4f421a6..4e7492578 100644 --- a/smt/tests/test_all.py +++ b/smt/tests/test_all.py @@ -55,7 +55,8 @@ def setUp(self): sms = OrderedDict() sms["LS"] = LS() sms["QP"] = QP() - sms["GPX"] = GPX() + if GPX_AVAILABLE: + sms["GPX"] = GPX() sms["KRG"] = KRG(theta0=[1e-2] * ndim) sms["KPLS"] = KPLS(theta0=[1e-2] * ncomp, n_comp=ncomp) sms["KPLSK"] = KPLSK(theta0=[1] * ncomp, n_comp=ncomp)