Skip to content

Commit

Permalink
Refactor smt.design_space (SMTorg#665)
Browse files Browse the repository at this point in the history
* smt use smt_design_space_ext design space impl if installed

* other base classes belongs to smt
* only need HAS_DESIGN_SPACE_EXT

* Remove constants managed in __init__

* Cleanup tests
  • Loading branch information
relf authored Oct 18, 2024
1 parent 38f6e37 commit 712b336
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 79 deletions.
4 changes: 2 additions & 2 deletions smt/applications/tests/test_ego.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,7 @@ def f_obj(X):
LHS, design_space, criterion="ese", random_state=random_state
)
Xt = sampling(n_doe)
if ds.HAS_CONFIG_SPACE: # results differs wrt config_space impl
if ds.HAS_DESIGN_SPACE_EXT: # results differs wrt config_space impl
self.assertAlmostEqual(np.sum(Xt), 24.811925491708156, delta=1e-4)
else:
self.assertAlmostEqual(np.sum(Xt), 28.568852027679586, delta=1e-4)
Expand Down Expand Up @@ -1155,7 +1155,7 @@ def f_obj(X):
n_start=25,
)
x_opt, y_opt, dnk, x_data, y_data = ego.optimize(fun=f_obj)
if ds.HAS_CONFIG_SPACE: # results differs wrt config_space impl
if ds.HAS_DESIGN_SPACE_EXT: # results differs wrt config_space impl
self.assertAlmostEqual(np.sum(y_data), 8.846225704750577, delta=1e-4)
self.assertAlmostEqual(np.sum(x_data), 41.811925504901374, delta=1e-4)
else:
Expand Down
9 changes: 4 additions & 5 deletions smt/applications/tests/test_mixed_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
except ImportError:
NO_MATPLOTLIB = True

import smt.design_space as ds
from smt.design_space import (
HAS_CONFIG_SPACE,
HAS_DESIGN_SPACE_EXT,
DesignSpace,
CategoricalVariable,
FloatVariable,
Expand Down Expand Up @@ -464,7 +463,7 @@ def test_examples(self):
self.run_mixed_gower_example()
self.run_mixed_homo_gaussian_example()
self.run_mixed_homo_hyp_example()
if ds.HAS_CONFIG_SPACE:
if HAS_DESIGN_SPACE_EXT:
self.run_mixed_cs_example()
self.run_hierarchical_design_space_example() # works only with config space impl

Expand Down Expand Up @@ -918,7 +917,7 @@ def run_hierarchical_design_space_example(self):
self._sm = sm # to be ignored: just used for automated test

@unittest.skipIf(
not HAS_CONFIG_SPACE, "Hierarchy ConfigSpace dependency not installed"
not HAS_DESIGN_SPACE_EXT, "Hierarchy ConfigSpace dependency not installed"
)
def test_hierarchical_design_space_example(self):
self.run_hierarchical_design_space_example()
Expand Down Expand Up @@ -951,7 +950,7 @@ def test_hierarchical_design_space_example(self):
)

@unittest.skipIf(
not HAS_CONFIG_SPACE, "Hierarchy ConfigSpace dependency not installed"
not HAS_DESIGN_SPACE_EXT, "Hierarchy ConfigSpace dependency not installed"
)
def test_hierarchical_design_space_example_all_categorical_decreed(self):
ds = DesignSpace(
Expand Down
54 changes: 16 additions & 38 deletions smt/design_space/__init__.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,33 @@
import importlib

spec = importlib.util.find_spec("smt_design_space")
if spec:
HAS_DESIGN_SPACE_EXT = True
HAS_CONFIG_SPACE = True
HAS_ADSG = True
else:
HAS_DESIGN_SPACE_EXT = False
HAS_CONFIG_SPACE = False
HAS_ADSG = False


if HAS_DESIGN_SPACE_EXT:
from smt_design_space.design_space import (
CategoricalVariable,
from smt.design_space.design_space import (
CategoricalVariable,
BaseDesignSpace,
FloatVariable,
IntegerVariable,
OrdinalVariable,
)

try:
from smt_design_space_ext import (
DesignSpace,
BaseDesignSpace,
FloatVariable,
IntegerVariable,
OrdinalVariable,
ensure_design_space,
)

else:
from smt.design_space.design_space import (
CategoricalVariable,
HAS_DESIGN_SPACE_EXT = True
except ImportError:
from .design_space import (
DesignSpace,
FloatVariable,
IntegerVariable,
OrdinalVariable,
ensure_design_space,
BaseDesignSpace,
)

if HAS_DESIGN_SPACE_EXT:
from smt_design_space.design_space import DesignSpaceGraph
else:

class DesignSpaceGraph:
pass

HAS_DESIGN_SPACE_EXT = False

__all__ = [
"HAS_DESIGN_SPACE_EXT",
"HAS_CONFIG_SPACE",
"HAS_ADSG",
"BaseDesignSpace",
"DesignSpace",
"FloatVariable",
"IntegerVariable",
"OrdinalVariable",
"CategoricalVariable",
"DesignSpace",
"ensure_design_space",
"HAS_DESIGN_SPACE_EXT",
]
5 changes: 0 additions & 5 deletions smt/design_space/design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@
from typing import List, Optional, Sequence, Tuple, Union


HAS_DESIGN_SPACE_EXT = False
HAS_CONFIG_SPACE = False
HAS_ADSG = False


class Configuration:
pass

Expand Down
42 changes: 13 additions & 29 deletions smt/design_space/tests/test_design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Author: Jasper Bussemaker <[email protected]>
"""

import contextlib
import itertools
import unittest

Expand All @@ -11,7 +10,6 @@
from smt.sampling_methods import LHS


import smt.design_space.design_space as ds
from smt.design_space.design_space import (
BaseDesignSpace,
CategoricalVariable,
Expand All @@ -22,16 +20,6 @@
)


@contextlib.contextmanager
def simulate_no_config_space(do_simulate=True):
if ds.HAS_CONFIG_SPACE and do_simulate:
ds.HAS_CONFIG_SPACE = False
yield
ds.HAS_CONFIG_SPACE = True
else:
yield


class Test(unittest.TestCase):
def test_design_variables(self):
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -193,8 +181,6 @@ def test_base_design_space(self):

def test_create_design_space(self):
DesignSpace([FloatVariable(0, 1)])
with simulate_no_config_space():
DesignSpace([FloatVariable(0, 1)])

def test_design_space(self):
ds = DesignSpace(
Expand Down Expand Up @@ -421,22 +407,20 @@ def test_design_space_hierarchical(self):
assert len(seen_is_acting) == 2

def test_check_conditionally_acting_2(self):
for simulate_no_cs in [True, False]:
with simulate_no_config_space(simulate_no_cs):
ds = DesignSpace(
[
CategoricalVariable(["A", "B", "C"]), # x0
CategoricalVariable(["E", "F"]), # x1
IntegerVariable(0, 1), # x2
FloatVariable(0, 1), # x3
],
random_state=42,
)
ds.declare_decreed_var(
decreed_var=0, meta_var=1, meta_value="E"
) # Activate x3 if x0 == A
ds = DesignSpace(
[
CategoricalVariable(["A", "B", "C"]), # x0
CategoricalVariable(["E", "F"]), # x1
IntegerVariable(0, 1), # x2
FloatVariable(0, 1), # x3
],
random_state=42,
)
ds.declare_decreed_var(
decreed_var=0, meta_var=1, meta_value="E"
) # Activate x3 if x0 == A

ds.sample_valid_x(10, random_state=42)
ds.sample_valid_x(10, random_state=42)


if __name__ == "__main__":
Expand Down

0 comments on commit 712b336

Please sign in to comment.