Skip to content

Commit

Permalink
Add derivatives prediction methods to GPX (#543)
Browse files Browse the repository at this point in the history
* Implement derivatives predictions in GPX

* Test GPX derivatives predictions

* Require egobox 0.17

* Add predict gradients in GPX API

* Skip test if GPX not available

* Relax test tolerance to pass on macosx

* Use pytest -n auto

* Do not run test in parallel and add verbosity

* Pin pytest to 8.0.x (try to fix github actions hanging)

* Do not run slow test in parallel

* Upgrade to egobox 0.18

* Avoid warning about GPX until GPX is actually used

* Update tests when GPX not available

* Fix error message
  • Loading branch information
relf authored Apr 10, 2024
1 parent fb82f9d commit 161a22c
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ jobs:
pip install -e .
- name: Test with pytest
run: pytest --durations=0 -n 4 smt
run: pytest -v --durations=0 smt


2 changes: 1 addition & 1 deletion .github/workflows/tests_coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
- name: Test with pytest and coverage
run: |
RUN_SLOW_TESTS=1 pytest --cov=smt -n 4
RUN_SLOW_TESTS=1 pytest --cov=smt
- name: Coveralls
uses: AndreMiras/coveralls-python-action@develop
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests_minimal.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Test with pytest
run: |
pip install pytest pytest-xdist
pytest -n 4 smt
pip install pytest==8.0.1 pytest-xdist
pytest -n auto smt
2 changes: 1 addition & 1 deletion doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ numpydoc
matplotlib
ConfigSpace ~= 0.6.1
jenn >= 1.0.2, <2.0
egobox ~= 0.16.0
egobox ~= 0.18.0
git+https://github.com/hwangjt/sphinx_auto_embed.git # for doc generation
14 changes: 7 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ numpy
scipy
scikit-learn
pyDOE3
numba # JIT compiler
matplotlib # used in examples and tests
pytest # tests runner
pytest-xdist # allows running parallel testing with pytest -n <num_workers>
pytest-cov # allows to get coverage report
ruff # format and lint code
numba # JIT compiler
matplotlib # used in examples and tests
pytest ~= 8.0.1 # tests runner
pytest-xdist # allows running parallel testing with pytest -n <num_workers>
pytest-cov # allows to get coverage report
ruff # format and lint code
ConfigSpace ~= 0.6.1
jenn >= 1.0.2, <2.0
egobox ~= 0.16.0
egobox ~= 0.18.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@
"cs": [ # pip install smt[cs]
"ConfigSpace~=0.6.1",
],
"gpx": ["egobox~=0.16"], # pip install smt[gpx]
"gpx": ["egobox~=0.18"], # pip install smt[gpx]
},
python_requires=">=3.8",
zip_safe=False,
Expand Down
38 changes: 31 additions & 7 deletions smt/surrogate_models/gpx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import numpy as np

from smt.surrogate_models.surrogate_model import SurrogateModel
Expand All @@ -22,14 +20,19 @@
}
except ImportError:
GPX_AVAILABLE = False
warnings.warn("To use GPX you have to install dependencies: pip install smt['gpx']")


class GPX(SurrogateModel):
name = "GPX"

def _initialize(self):
super(GPX, self)._initialize()

if not GPX_AVAILABLE:
raise RuntimeError(
'GPX not available. Please install GPX dependencies with: pip install smt["gpx"]'
)

declare = self.options.declare

declare(
Expand Down Expand Up @@ -80,7 +83,9 @@ def _initialize(self):
)

supports = self.supports
supports["derivatives"] = True
supports["variances"] = True
supports["variance_derivatives"] = True

self._gpx = None

Expand All @@ -101,8 +106,27 @@ def _train(self):

self._gpx = egx.Gpx.builder(**config).fit(xt, yt)

def _predict_values(self, xt):
return self._gpx.predict_values(xt)
def _predict_values(self, x):
return self._gpx.predict(x)

def _predict_variances(self, x):
return self._gpx.predict_var(x)

def _predict_derivatives(self, x, kx):
return self._gpx.predict_gradients(x)[:, kx : kx + 1]

def _predict_variance_derivatives(self, x, kx):
return self._gpx.predict_var_gradients(x)[:, kx : kx + 1]

def predict_gradients(self, x):
"""Predict derivatives wrt to all x components (eg the gradient)
at n points given as [n, nx] matrix where nx is the dimension of x.
Returns all gradients at the given x points as a [n, nx] matrix
"""
return self._gpx.predict_gradients(x)

def _predict_variances(self, xt):
return self._gpx.predict_variances(xt)
def predict_variance_gradients(self, x):
"""Predict variance derivatives wrt to all x components (eg the variance gradient)
at n points given as [n, nx] matrix where nx is the dimension of x.
Returns all variance gradients at the given x points as a [n, nx] matrix"""
return self._gpx.predict_var_gradients(x)
61 changes: 51 additions & 10 deletions smt/surrogate_models/tests/test_gpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,76 @@

from smt.problems import Sphere
from smt.sampling_methods import LHS
from smt.surrogate_models import GPX
from smt.surrogate_models import GPX, KRG
from smt.surrogate_models.gpx import GPX_AVAILABLE


class TestGPX(unittest.TestCase):
@unittest.skipIf(not GPX_AVAILABLE, "GPX not available")
def test_gpx(self):
ndim = 2
num = 50
num = 20
problem = Sphere(ndim=ndim)
xlimits = problem.xlimits
sampling = LHS(xlimits=xlimits, criterion="ese")

xt = sampling(num)
yt = problem(xt)

sm = GPX()
sm.set_training_values(xt, yt)
sm.train()
gpx = GPX(print_global=False, seed=42)
gpx.set_training_values(xt, yt)
gpx.train()

xe = sampling(10)
ye = problem(xe)

ytest = sm.predict_values(xe)
e_error = np.linalg.norm(ytest - ye) / np.linalg.norm(ye)
self.assertLessEqual(e_error, 2e-2)
# Prediction should be pretty good
gpx_y = gpx.predict_values(xe)
e_error = np.linalg.norm(gpx_y - ye) / np.linalg.norm(ye)
self.assertLessEqual(e_error, 1e-3)

vars = sm.predict_variances(xt)
self.assertLessEqual(np.linalg.norm(vars), 1e-6)
gpx_var = gpx.predict_variances(xe)
self.assertLessEqual(np.linalg.norm(gpx_var), 1e-3)

@unittest.skipIf(not GPX_AVAILABLE, "GPX not available")
def test_gpx_vs_krg(self):
ndim = 3
num = 30
problem = Sphere(ndim=ndim)
xlimits = problem.xlimits
sampling = LHS(xlimits=xlimits, criterion="ese", random_state=42)

xt = sampling(num)
yt = problem(xt)

gpx = GPX(print_global=False, seed=42)
gpx.set_training_values(xt, yt)
gpx.train()

xe = sampling(10)

gpx_y = gpx.predict_values(xe)
gpx_var = gpx.predict_variances(xe)

# Compare against KRG
krg = KRG(print_global=False)
krg.set_training_values(xt, yt)
krg.train()

krg_y = krg.predict_values(xe)
np.testing.assert_allclose(gpx_y, krg_y, atol=1e-2)

krg_var = krg.predict_variances(xe)
np.testing.assert_allclose(gpx_var, krg_var, atol=1e-2)

for kx in range(ndim):
dy = gpx.predict_derivatives(xe, kx)
krg_dy = krg.predict_derivatives(xe, kx)
np.testing.assert_allclose(dy, krg_dy, atol=1e-3)

dvar = gpx.predict_variance_derivatives(xe, kx)
krg_dvar = krg.predict_variance_derivatives(xe, kx)
np.testing.assert_allclose(dvar, krg_dvar, atol=1e-3)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion smt/tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def setUp(self):
sms = OrderedDict()
sms["LS"] = LS()
sms["QP"] = QP()
sms["GPX"] = GPX()
if GPX_AVAILABLE:
sms["GPX"] = GPX()
sms["KRG"] = KRG(theta0=[1e-2] * ndim)
sms["KPLS"] = KPLS(theta0=[1e-2] * ncomp, n_comp=ncomp)
sms["KPLSK"] = KPLSK(theta0=[1] * ncomp, n_comp=ncomp)
Expand Down

0 comments on commit 161a22c

Please sign in to comment.