Skip to content

Commit

Permalink
Add metric computation functions for surrogate performance evaluation (
Browse files Browse the repository at this point in the history
…#698)

* Add error calculation functions and tests

* rename of compute_rms_error() to compute_relative_error()

* changes for the validation of the pull request

* format tutorial with ruff

* rename functions
  • Loading branch information
Antoine-Averland authored Dec 19, 2024
1 parent 61bf8c1 commit 223cec8
Show file tree
Hide file tree
Showing 21 changed files with 230 additions and 148 deletions.
2 changes: 1 addition & 1 deletion smt/applications/tests/test_mfck_1fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from smt.problems import TensorProduct
from smt.sampling_methods import LHS

# from smt.utils.misc import compute_rms_error
# from smt.utils.misc import compute_relative_error
# from smt.utils.silence import Silence
from smt.utils.sm_test_case import SMTestCase
from smt.applications.mfck import MFCK
Expand Down
14 changes: 7 additions & 7 deletions smt/applications/tests/test_mfk.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from smt.applications.mfk import MFK, NestedLHS
from smt.problems import Sphere, TensorProduct
from smt.sampling_methods import LHS, FullFactorial
from smt.utils.misc import compute_rms_error
from smt.utils.misc import compute_relative_error
from smt.utils.silence import Silence
from smt.utils.sm_test_case import SMTestCase

Expand Down Expand Up @@ -84,8 +84,8 @@ def test_mfk(self):
with Silence():
sm.train()

t_error = compute_rms_error(sm)
e_error = compute_rms_error(sm, xe, ye)
t_error = compute_relative_error(sm)
e_error = compute_relative_error(sm, xe, ye)

self.assert_error(t_error, 0.0, 1)
self.assert_error(e_error, 0.0, 1)
Expand Down Expand Up @@ -123,10 +123,10 @@ def test_mfk_derivs(self):
with Silence():
sm.train()

t_error = compute_rms_error(sm)
e_error = compute_rms_error(sm, xe, ye)
e_error0 = compute_rms_error(sm, xe, dye[0], 0)
e_error1 = compute_rms_error(sm, xe, dye[1], 1)
t_error = compute_relative_error(sm)
e_error = compute_relative_error(sm, xe, ye)
e_error0 = compute_relative_error(sm, xe, dye[0], 0)
e_error1 = compute_relative_error(sm, xe, dye[1], 1)

if print_output:
print(
Expand Down
6 changes: 3 additions & 3 deletions smt/applications/tests/test_mfk_1fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from smt.applications.mfk import MFK
from smt.problems import TensorProduct
from smt.sampling_methods import LHS
from smt.utils.misc import compute_rms_error
from smt.utils.misc import compute_relative_error
from smt.utils.silence import Silence
from smt.utils.sm_test_case import SMTestCase

Expand Down Expand Up @@ -53,8 +53,8 @@ def test_mfk_1fidelity(self):
with Silence():
sm.train()

t_error = compute_rms_error(sm)
e_error = compute_rms_error(sm, xv, yv)
t_error = compute_relative_error(sm)
e_error = compute_relative_error(sm, xv, yv)

self.assert_error(t_error, 0.0, 3e-3)
self.assert_error(e_error, 0.0, 3e-3)
Expand Down
10 changes: 5 additions & 5 deletions smt/applications/tests/test_mfkpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from smt.applications.mfkpls import MFKPLS
from smt.problems import Sphere, TensorProduct
from smt.sampling_methods import LHS, FullFactorial
from smt.utils.misc import compute_rms_error
from smt.utils.misc import compute_relative_error
from smt.utils.silence import Silence
from smt.utils.sm_test_case import SMTestCase

Expand Down Expand Up @@ -79,8 +79,8 @@ def test_mfkpls(self):
with Silence():
sm.train()

t_error = compute_rms_error(sm)
e_error = compute_rms_error(sm, xe, ye)
t_error = compute_relative_error(sm)
e_error = compute_relative_error(sm, xe, ye)

self.assert_error(t_error, 0.0, 1.5)
self.assert_error(e_error, 0.0, 1.5)
Expand Down Expand Up @@ -127,8 +127,8 @@ def test_mfkpls_derivs(self):
with Silence():
sm.train()

e_error0 = compute_rms_error(sm, xe, dye[0], 0)
e_error1 = compute_rms_error(sm, xe, dye[1], 1)
e_error0 = compute_relative_error(sm, xe, dye[0], 0)
e_error1 = compute_relative_error(sm, xe, dye[1], 1)

self.assert_error(e_error0, 0.0, 1e-1)
self.assert_error(e_error1, 0.0, 1e-1)
Expand Down
14 changes: 7 additions & 7 deletions smt/applications/tests/test_mfkplsk.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from smt.applications.mfkplsk import MFKPLSK
from smt.problems import Sphere, TensorProduct
from smt.sampling_methods import LHS, FullFactorial
from smt.utils.misc import compute_rms_error
from smt.utils.misc import compute_relative_error
from smt.utils.silence import Silence
from smt.utils.sm_test_case import SMTestCase

Expand Down Expand Up @@ -79,8 +79,8 @@ def test_mfkplsk(self):
with Silence():
sm.train()

t_error = compute_rms_error(sm)
e_error = compute_rms_error(sm, xe, ye)
t_error = compute_relative_error(sm)
e_error = compute_relative_error(sm, xe, ye)

self.assert_error(t_error, 0.0, 1.5)
self.assert_error(e_error, 0.0, 1.5)
Expand Down Expand Up @@ -130,10 +130,10 @@ def test_mfkplsk_derivs(self):
with Silence():
sm.train()

_t_error = compute_rms_error(sm)
_e_error = compute_rms_error(sm, xe, ye)
e_error0 = compute_rms_error(sm, xe, dye[0], 0)
e_error1 = compute_rms_error(sm, xe, dye[1], 1)
_t_error = compute_relative_error(sm)
_e_error = compute_relative_error(sm, xe, ye)
e_error0 = compute_relative_error(sm, xe, dye[0], 0)
e_error1 = compute_relative_error(sm, xe, dye[1], 1)

self.assert_error(e_error0, 0.0, 1e-1)
self.assert_error(e_error1, 0.0, 1e-1)
Expand Down
12 changes: 6 additions & 6 deletions smt/applications/tests/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from smt.problems import Branin, LpNorm
from smt.sampling_methods import LHS, FullFactorial
from smt.surrogate_models import RMTB, RMTC
from smt.utils.misc import compute_rms_error
from smt.utils.misc import compute_relative_error
from smt.utils.sm_test_case import SMTestCase


Expand Down Expand Up @@ -64,7 +64,7 @@ def test_1d_50(self):
xe = np.random.sample(self.ne)
ye = self.function_test_1d(xe)

rms_error = compute_rms_error(moe, xe, ye)
rms_error = compute_relative_error(moe, xe, ye)
self.assert_error(rms_error, 0.0, 3e-1)

self.assertRaises(RuntimeError, lambda: moe.predict_variances(xe))
Expand Down Expand Up @@ -109,7 +109,7 @@ def test_1d_50_var(self):
xe = np.random.sample(self.ne)
ye = self.function_test_1d(xe)

rms_error = compute_rms_error(moe, xe, ye)
rms_error = compute_relative_error(moe, xe, ye)
self.assert_error(rms_error, 0.0, 3e-1)

moe.predict_variances(xe)
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_1d_50_surrogate_model(self):
xe = np.random.sample(self.ne)
ye = self.function_test_1d(xe)

rms_error = compute_rms_error(moe, xe, ye)
rms_error = compute_relative_error(moe, xe, ye)
self.assert_error(rms_error, 0.0, 3e-1)

self.assertRaises(RuntimeError, lambda: moe.predict_variances(xe))
Expand Down Expand Up @@ -229,7 +229,7 @@ def test_norm1_2d_200(self):
xe = sampling(self.ne)
ye = prob(xe)

rms_error = compute_rms_error(moe, xe, ye)
rms_error = compute_relative_error(moe, xe, ye)
self.assert_error(rms_error, 0.0, 1e-1)

if TestMOE.plot:
Expand Down Expand Up @@ -273,7 +273,7 @@ def test_branin_2d_200(self):
xe = sampling(self.ne)
ye = prob(xe)

rms_error = compute_rms_error(moe, xe, ye)
rms_error = compute_relative_error(moe, xe, ye)
self.assert_error(rms_error, 0.0, 1e-1)

if TestMOE.plot:
Expand Down
4 changes: 2 additions & 2 deletions smt/applications/tests/test_vfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from smt.problems import WaterFlow, WaterFlowLFidelity
from smt.sampling_methods import LHS
from smt.utils.misc import compute_rms_error
from smt.utils.misc import compute_relative_error
from smt.utils.silence import Silence
from smt.utils.sm_test_case import SMTestCase

Expand Down Expand Up @@ -111,7 +111,7 @@ def test_vfm(self):
)

# Prediction of the validation points
rms_error = compute_rms_error(vfm, xtest, ytest)
rms_error = compute_relative_error(vfm, xtest, ytest)
self.assert_error(rms_error, 0.0, 3e-1)

@staticmethod
Expand Down
Loading

0 comments on commit 223cec8

Please sign in to comment.