Skip to content

Commit

Permalink
fix ego mi
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul-Saves committed Jan 30, 2024
1 parent 4e18506 commit 3e10ac0
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions smt/applications/tests/test_ego.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def test_ego_mixed_integer(self):
seed=42,
)
samp = MixedIntegerSamplingMethod(
LHS, ds, criterion="ese", random_state=ds.seed
LHS, design_space, criterion="ese", random_state=design_space.seed
)
xdoe = samp(n_doe)

Expand Down Expand Up @@ -423,7 +423,7 @@ def test_ego_mixed_integer_gower_distance(self):
seed=random_state,
)
samp = MixedIntegerSamplingMethod(
LHS, ds, criterion="ese", random_state=ds.seed
LHS, design_space, criterion="ese", random_state=design_space.seed
)
xdoe = samp(n_doe)

Expand Down Expand Up @@ -503,7 +503,7 @@ def f_hv(X):
# x7 is active when x0 >= 3
design_space.declare_decreed_var(decreed_var=7, meta_var=0, meta_value=3)

n_doe = 4
n_doe = 5

neutral_var_ds = DesignSpace(design_space.design_variables[1:])
sampling = MixedIntegerSamplingMethod(
Expand Down Expand Up @@ -539,7 +539,7 @@ def f_hv(X):
Xt = np.concatenate((xdoe1, xdoe2, xdoe3), axis=0)
# Yt = np.concatenate((ydoe1, ydoe2, ydoe3), axis=0)

n_iter = 6
n_iter = 9
criterion = "EI"

ego = EGO(
Expand Down Expand Up @@ -1011,16 +1011,9 @@ def f_obj(X):
n_start=15,
)
x_opt, y_opt, dnk, x_data, y_data = ego.optimize(fun=f_obj)
if ds.HAS_CONFIG_SPACE: # results differs wrt config_space impl
if platform.startswith("linux"): # results differs wrt platform
self.assertAlmostEqual(np.sum(y_data), 1.0355815090110578, delta=1e-12)
self.assertAlmostEqual(np.sum(x_data), 38.56885202767958, delta=1e-12)
else:
self.assertAlmostEqual(np.sum(y_data), 0.9606415626557894, delta=1e-12)
self.assertAlmostEqual(np.sum(x_data), 38.23494224077761, delta=1e-12)
else:
self.assertAlmostEqual(np.sum(y_data), 1.8911720770059735, delta=1e-12)
self.assertAlmostEqual(np.sum(x_data), 47.56885202767958, delta=1e-12)

self.assertAlmostEqual(np.sum(y_data), 6.846225752638086, delta=1e-10)
self.assertAlmostEqual(np.sum(x_data), 33.81192549170815, delta=1e-10)

def test_ego_gek(self):
ego, fun = self.initialize_ego_gek()
Expand Down

0 comments on commit 3e10ac0

Please sign in to comment.