Skip to content

Commit

Permalink
Merge branch 'master' into master-2
Browse files Browse the repository at this point in the history
  • Loading branch information
relf authored Mar 18, 2024
2 parents 5e51050 + 34df4e2 commit 9ed04a6
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion doc/_src_docs/surrogate_models/krg.rst

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions smt/applications/tests/test_ego.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,8 +1009,8 @@ def f_obj(X):
)
x_opt, y_opt, dnk, x_data, y_data = ego.optimize(fun=f_obj)
if ds.HAS_CONFIG_SPACE: # results differs wrt config_space impl
self.assertAlmostEqual(np.sum(y_data), 6.768616104127338, delta=1e-6)
self.assertAlmostEqual(np.sum(x_data), 34.205904294464716, delta=1e-6)
self.assertAlmostEqual(np.sum(y_data), 5.4385331120184475, delta=1e-3)
self.assertAlmostEqual(np.sum(x_data), 39.711522540205394, delta=1e-3)
else:
self.assertAlmostEqual(np.sum(y_data), 1.8911720670620835, delta=1e-6)
self.assertAlmostEqual(np.sum(x_data), 47.56885202767958, delta=1e-6)
Expand Down
9 changes: 8 additions & 1 deletion smt/surrogate_models/krg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@ def _initialize(self):
declare(
"corr",
"squar_exp",
values=("pow_exp", "abs_exp", "squar_exp", "matern52", "matern32"),
values=(
"pow_exp",
"abs_exp",
"squar_exp",
"squar_sin_exp",
"matern52",
"matern32",
),
desc="Correlation function type",
types=(str),
)
Expand Down
4 changes: 3 additions & 1 deletion smt/surrogate_models/krg_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,9 @@ def design_space(self) -> BaseDesignSpace:
xt = xt[0][0]

if self.options["design_space"] is None:
self.options["design_space"] = ensure_design_space(xt=xt)
self.options["design_space"] = ensure_design_space(
xt=xt, xlimits=self.options["xlimits"]
)

elif not isinstance(self.options["design_space"], BaseDesignSpace):
ds_input = self.options["design_space"]
Expand Down
2 changes: 1 addition & 1 deletion smt/tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def run_test(self):
elif pname == "tanh" and sname in ["KPLS", "RMTB"]:
self.assertLessEqual(e_error, self.e_errors[sname] + 0.4)
elif pname == "exp" and sname in ["GENN"]:
self.assertLessEqual(e_error, 1e-1)
self.assertLessEqual(e_error, 1.5e-1)
elif pname == "exp" and sname in ["RMTB"]:
self.assertLessEqual(e_error, self.e_errors[sname] + 0.5)
else:
Expand Down
2 changes: 2 additions & 0 deletions smt/utils/design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,8 @@ def _normalize_x(self, x: np.ndarray, cs_normalize=True):
for i, dv in enumerate(self.design_variables):
if isinstance(dv, FloatVariable):
if cs_normalize:
dv.lower = min(np.min(x[:, i]), dv.lower)
dv.upper = max(np.max(x[:, i]), dv.upper)
x[:, i] = np.clip(
(x[:, i] - dv.lower) / (dv.upper - dv.lower + 1e-16), 0, 1
)
Expand Down

0 comments on commit 9ed04a6

Please sign in to comment.