forked from SMTorg/smt
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor
smt.design_space
(SMTorg#665)
* smt use smt_design_space_ext design space impl if installed * other base classes belongs to smt * only need HAS_DESIGN_SPACE_EXT * Remove constants managed in __init__ * Cleanup tests
- Loading branch information
Showing
5 changed files
with
35 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,33 @@ | ||
import importlib | ||
|
||
spec = importlib.util.find_spec("smt_design_space") | ||
if spec: | ||
HAS_DESIGN_SPACE_EXT = True | ||
HAS_CONFIG_SPACE = True | ||
HAS_ADSG = True | ||
else: | ||
HAS_DESIGN_SPACE_EXT = False | ||
HAS_CONFIG_SPACE = False | ||
HAS_ADSG = False | ||
|
||
|
||
if HAS_DESIGN_SPACE_EXT: | ||
from smt_design_space.design_space import ( | ||
CategoricalVariable, | ||
from smt.design_space.design_space import ( | ||
CategoricalVariable, | ||
BaseDesignSpace, | ||
FloatVariable, | ||
IntegerVariable, | ||
OrdinalVariable, | ||
) | ||
|
||
try: | ||
from smt_design_space_ext import ( | ||
DesignSpace, | ||
BaseDesignSpace, | ||
FloatVariable, | ||
IntegerVariable, | ||
OrdinalVariable, | ||
ensure_design_space, | ||
) | ||
|
||
else: | ||
from smt.design_space.design_space import ( | ||
CategoricalVariable, | ||
HAS_DESIGN_SPACE_EXT = True | ||
except ImportError: | ||
from .design_space import ( | ||
DesignSpace, | ||
FloatVariable, | ||
IntegerVariable, | ||
OrdinalVariable, | ||
ensure_design_space, | ||
BaseDesignSpace, | ||
) | ||
|
||
if HAS_DESIGN_SPACE_EXT: | ||
from smt_design_space.design_space import DesignSpaceGraph | ||
else: | ||
|
||
class DesignSpaceGraph: | ||
pass | ||
|
||
HAS_DESIGN_SPACE_EXT = False | ||
|
||
__all__ = [ | ||
"HAS_DESIGN_SPACE_EXT", | ||
"HAS_CONFIG_SPACE", | ||
"HAS_ADSG", | ||
"BaseDesignSpace", | ||
"DesignSpace", | ||
"FloatVariable", | ||
"IntegerVariable", | ||
"OrdinalVariable", | ||
"CategoricalVariable", | ||
"DesignSpace", | ||
"ensure_design_space", | ||
"HAS_DESIGN_SPACE_EXT", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,6 @@ | |
Author: Jasper Bussemaker <[email protected]> | ||
""" | ||
|
||
import contextlib | ||
import itertools | ||
import unittest | ||
|
||
|
@@ -11,7 +10,6 @@ | |
from smt.sampling_methods import LHS | ||
|
||
|
||
import smt.design_space.design_space as ds | ||
from smt.design_space.design_space import ( | ||
BaseDesignSpace, | ||
CategoricalVariable, | ||
|
@@ -22,16 +20,6 @@ | |
) | ||
|
||
|
||
@contextlib.contextmanager | ||
def simulate_no_config_space(do_simulate=True): | ||
if ds.HAS_CONFIG_SPACE and do_simulate: | ||
ds.HAS_CONFIG_SPACE = False | ||
yield | ||
ds.HAS_CONFIG_SPACE = True | ||
else: | ||
yield | ||
|
||
|
||
class Test(unittest.TestCase): | ||
def test_design_variables(self): | ||
with self.assertRaises(ValueError): | ||
|
@@ -193,8 +181,6 @@ def test_base_design_space(self): | |
|
||
def test_create_design_space(self): | ||
DesignSpace([FloatVariable(0, 1)]) | ||
with simulate_no_config_space(): | ||
DesignSpace([FloatVariable(0, 1)]) | ||
|
||
def test_design_space(self): | ||
ds = DesignSpace( | ||
|
@@ -421,22 +407,20 @@ def test_design_space_hierarchical(self): | |
assert len(seen_is_acting) == 2 | ||
|
||
def test_check_conditionally_acting_2(self): | ||
for simulate_no_cs in [True, False]: | ||
with simulate_no_config_space(simulate_no_cs): | ||
ds = DesignSpace( | ||
[ | ||
CategoricalVariable(["A", "B", "C"]), # x0 | ||
CategoricalVariable(["E", "F"]), # x1 | ||
IntegerVariable(0, 1), # x2 | ||
FloatVariable(0, 1), # x3 | ||
], | ||
random_state=42, | ||
) | ||
ds.declare_decreed_var( | ||
decreed_var=0, meta_var=1, meta_value="E" | ||
) # Activate x3 if x0 == A | ||
ds = DesignSpace( | ||
[ | ||
CategoricalVariable(["A", "B", "C"]), # x0 | ||
CategoricalVariable(["E", "F"]), # x1 | ||
IntegerVariable(0, 1), # x2 | ||
FloatVariable(0, 1), # x3 | ||
], | ||
random_state=42, | ||
) | ||
ds.declare_decreed_var( | ||
decreed_var=0, meta_var=1, meta_value="E" | ||
) # Activate x3 if x0 == A | ||
|
||
ds.sample_valid_x(10, random_state=42) | ||
ds.sample_valid_x(10, random_state=42) | ||
|
||
|
||
if __name__ == "__main__": | ||
|