Skip to content

Commit

Permalink
Add a warning on kriging-based surrogates being trained with multiple…
Browse files Browse the repository at this point in the history
… outputs (#686)

* Add a warnings when multiple outputs are used to train a kriging based surrogate

* Test warning presence

* Adjust message
  • Loading branch information
relf authored Dec 5, 2024
1 parent 9972225 commit 7da5e1c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
9 changes: 9 additions & 0 deletions smt/surrogate_models/krg_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,14 @@ def set_training_values(
Matrix specifying which of the design variables is acting in a hierarchical design space
"""
super().set_training_values(xt, yt, name=name)
if self.ny > 1:
warnings.warn(
"Kriging-based surrogate is not intended to handle multiple "
f"training output data (yt dim should be 1, got {self.ny}). "
"The quality of the resulting surrogate might not be as good as "
"if each training output is used separately to build a dedicated surrogate. "
"This warning might become a hard error in future SMT versions."
)
if is_acting is not None:
self.is_acting_points[name] = is_acting

Expand Down Expand Up @@ -397,6 +405,7 @@ def _new_train(self):
# Sampling points X and y
X = self.training_points[None][0][0]
y = self.training_points[None][0][1]

# Get is_acting status from design space model if needed (might correct training points)
is_acting = self.is_acting_points.get(None)
if is_acting is None and not self.is_continuous:
Expand Down
5 changes: 5 additions & 0 deletions smt/surrogate_models/tests/test_krg_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def test_theta0_erroneous_init(self):
krg.set_training_values(np.array([[1, 2, 3]]), np.array([[1]])) # erroneous
self.assertRaises(ValueError, krg._check_param)

def test_multiple_training_outputs_warning(self):
krg = KrgBased()
with self.assertWarns(UserWarning):
krg.set_training_values(np.array([[1, 2, 3]]), np.array([[1, 1]]))

def test_less_almost_squar_exp(self):
nobs = 50 # number of obsertvations
np.random.seed(0) # a seed for reproducibility
Expand Down

0 comments on commit 7da5e1c

Please sign in to comment.