Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul-Saves committed Feb 29, 2024
1 parent c9324fb commit acd9a9f
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 61 deletions.
53 changes: 32 additions & 21 deletions smt/utils/design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def __init__(self, design_variables: List[DesignVariable] = None):
self._is_cat_mask = None
self._is_conditionally_acting_mask = None
self.seed = None

@property
def design_variables(self) -> List[DesignVariable]:
if self._design_variables is None:
Expand Down Expand Up @@ -443,9 +443,9 @@ def fold_x(

i_x_unfold = 0
for i, dv in enumerate(self.design_variables):
if (isinstance(dv, CategoricalVariable) or isinstance(dv, OrdinalVariable)) and (
fold_mask is None or fold_mask[i]
):
if (
isinstance(dv, CategoricalVariable) or isinstance(dv, OrdinalVariable)
) and (fold_mask is None or fold_mask[i]):
n_dim_cat = dv.n_values

# Categorical values are folded by reversed one-hot encoding:
Expand Down Expand Up @@ -745,9 +745,9 @@ def _is_num(val):
cs_vars[name] = FixedIntegerParam(
name, lower=dv.lower, upper=dv.upper
)
listvalues= []
for i in range(dv.upper-dv.lower+1):
listvalues.append(str(i+dv.lower))
listvalues = []
for i in range(dv.upper - dv.lower + 1):
listvalues.append(str(i + dv.lower))
cs_vars2[name] = CategoricalHyperparameter(name, choices=listvalues)
self.isinteger = True
elif isinstance(dv, OrdinalVariable):
Expand Down Expand Up @@ -872,7 +872,7 @@ def add_value_constraint(

constraint_clause = ForbiddenAndConjunction(clause1, clause2)
self._cs.add_forbidden_clause(constraint_clause)

# Get parameters
param1 = self._get_param2(var1)
param2 = self._get_param2(var2)
Expand All @@ -887,14 +887,15 @@ def add_value_constraint(
else:
clause2 = ForbiddenEqualsClause(param2, str(value2))

constraint_clause = ForbiddenAndConjunction(clause1, clause2)
constraint_clause = ForbiddenAndConjunction(clause1, clause2)
self._cs2.add_forbidden_clause(constraint_clause)

def _get_param(self, idx):
try:
return self._cs.get_hyperparameter(f"x{idx}")
except KeyError:
raise KeyError(f"Variable not found: {idx}")

def _get_param2(self, idx):
try:
return self._cs2.get_hyperparameter(f"x{idx}")
Expand Down Expand Up @@ -1018,14 +1019,22 @@ def _get_correct_config(self, vector: np.ndarray) -> Configuration:
# to find out which parameters should be inactive
while True:
try:
if self.isinteger :
if self.isinteger:
vector2 = np.copy(vector)
self._cs_denormalize_x(np.atleast_2d(vector2))
indvec=0
for hp in self._cs2 :
if (str(self._cs.get_hyperparameter(hp)).split()[2]) == "UniformInteger," and (str(self._cs2.get_hyperparameter(hp)).split()[2][:3]) == "Cat" and not(np.isnan(vector2[indvec])):
vector2[indvec] = int(vector2[indvec])- int(str(self._cs2.get_hyperparameter(hp)).split()[4][1:-1])
indvec +=1
indvec = 0
for hp in self._cs2:
if (
(str(self._cs.get_hyperparameter(hp)).split()[2])
== "UniformInteger,"
and (str(self._cs2.get_hyperparameter(hp)).split()[2][:3])
== "Cat"
and not (np.isnan(vector2[indvec]))
):
vector2[indvec] = int(vector2[indvec]) - int(
str(self._cs2.get_hyperparameter(hp)).split()[4][1:-1]
)
indvec += 1
self._normalize_x_no_integer(np.atleast_2d(vector2))
config2 = Configuration(self._cs2, vector=vector2)
config2.is_valid_configuration()
Expand Down Expand Up @@ -1056,14 +1065,16 @@ def _get_correct_config(self, vector: np.ndarray) -> Configuration:
vector = config.get_array().copy()
indvec = 0
vector2 = np.copy(vector)
for hp in self._cs2 :
if (str(self._cs2.get_hyperparameter(hp)).split()[2][:3]) == "Cat" and not(np.isnan(vector2[indvec])):

for hp in self._cs2:
if (
str(self._cs2.get_hyperparameter(hp)).split()[2][:3]
) == "Cat" and not (np.isnan(vector2[indvec])):

vector2[indvec] = int(vector2[indvec])
indvec +=1
indvec += 1

config2 = Configuration(self._cs2, vector=vector2)
config3 = get_random_neighbor(config2, seed=self.seed)
config3 = get_random_neighbor(config2, seed=self.seed)
vector3 = config3.get_array().copy()
config4 = Configuration(self._cs, vector=vector3)
return config4
Expand Down
75 changes: 35 additions & 40 deletions smt/utils/test/test_design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_design_space(self):
ds = DesignSpace(
[
CategoricalVariable(["A", "B", "C"]),
CategoricalVariable(["E", "F"]), ### todo
CategoricalVariable(["E", "F"]), ### todo
IntegerVariable(-1, 2),
FloatVariable(0.5, 1.5),
],
Expand Down Expand Up @@ -564,74 +564,69 @@ def test_check_conditionally_acting_2(self):
def test_restrictive_value_constraint_ordinal(self):
ds = DesignSpace(
[
OrdinalVariable(["0","1","2"]),
OrdinalVariable(["0","1","2"]),
OrdinalVariable(["0", "1", "2"]),
OrdinalVariable(["0", "1", "2"]),
]
)
assert list(ds._cs.values())[0].default_value == "0"

ds.add_value_constraint(var1=0, value1="1", var2=1, value2="1")
ds.sample_valid_x(100, random_state=42)

x_cartesian = np.array(list(itertools.product([0,1,2], [0,1,2])))
x_cartesian = np.array(list(itertools.product([0, 1, 2], [0, 1, 2])))
x_cartesian2, _ = ds.correct_get_acting(x_cartesian)
np.testing.assert_array_equal(np.array([[0, 0],
[0, 1],
[0, 2],
[1, 0],
[0, 1],
[1, 2],
[2, 0],
[2, 1],
[2, 2]]), x_cartesian2)
np.testing.assert_array_equal(
np.array(
[[0, 0], [0, 1], [0, 2], [1, 0], [0, 1], [1, 2], [2, 0], [2, 1], [2, 2]]
),
x_cartesian2,
)

def test_restrictive_value_constraint_integer(self):
ds = DesignSpace(
[
IntegerVariable(0,2),
IntegerVariable(0,2),
IntegerVariable(0, 2),
IntegerVariable(0, 2),
]
)
assert list(ds._cs.values())[0].default_value == 1

ds.add_value_constraint(var1=0, value1=1, var2=1, value2=1)
ds.sample_valid_x(100, random_state=42)
x_cartesian = np.array(list(itertools.product([0,1,2], [0,1,2])))

x_cartesian = np.array(list(itertools.product([0, 1, 2], [0, 1, 2])))
ds.correct_get_acting(x_cartesian)
x_cartesian2, _= ds.correct_get_acting(x_cartesian)
x_cartesian2, _ = ds.correct_get_acting(x_cartesian)
print(x_cartesian2)
np.testing.assert_array_equal(np.array([[0, 0],
[0, 1],
[0, 2],
[1, 0],
[2, 0],
[1, 2],
[2, 0],
[2, 1],
[2, 2]]), x_cartesian2)
np.testing.assert_array_equal(
np.array(
[[0, 0], [0, 1], [0, 2], [1, 0], [2, 0], [1, 2], [2, 0], [2, 1], [2, 2]]
),
x_cartesian2,
)

def test_restrictive_value_constraint_categorical(self):
ds = DesignSpace(
[
CategoricalVariable(["a","b","c"]),
CategoricalVariable(["a","b","c"]),
CategoricalVariable(["a", "b", "c"]),
CategoricalVariable(["a", "b", "c"]),
]
)
assert list(ds._cs.values())[0].default_value == "a"

ds.add_value_constraint(var1=0, value1="b", var2=1, value2="b")
ds.sample_valid_x(100, random_state=42)

x_cartesian = np.array(list(itertools.product([0,1,2], [0,1,2])))
x_cartesian = np.array(list(itertools.product([0, 1, 2], [0, 1, 2])))
ds.correct_get_acting(x_cartesian)
x_cartesian2, _ = ds.correct_get_acting(x_cartesian)
np.testing.assert_array_equal(np.array([[0, 0],
[0, 1],
[0, 2],
[1, 0],
[0, 1],
[1, 2],
[2, 0],
[2, 1],
[2, 2]]), x_cartesian2)
np.testing.assert_array_equal(
np.array(
[[0, 0], [0, 1], [0, 2], [1, 0], [0, 1], [1, 2], [2, 0], [2, 1], [2, 2]]
),
x_cartesian2,
)


if __name__ == "__main__":
unittest.main()

0 comments on commit acd9a9f

Please sign in to comment.