Skip to content

Commit

Permalink
Merge pull request #113 from automl/fix/search_space_test
Browse files Browse the repository at this point in the history
Fix/search space test
  • Loading branch information
TheEimer authored Feb 9, 2024
2 parents afbd479 + 193d512 commit bc6ef39
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
8 changes: 5 additions & 3 deletions carl/context/search_space_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ def search_space_to_config_space(
-------
ConfigurationSpace
"""
if type(search_space) == str:
if isinstance(search_space, str):
with open(search_space, "r") as f:
jason_string = f.read()
cs = csjson.read(jason_string)
elif type(search_space) == DictConfig:
elif isinstance(search_space, DictConfig):
# reorder hyperparameters as List[Dict]
hyperparameters = []
for name, cfg in search_space.hyperparameters.items():
Expand All @@ -130,8 +130,10 @@ def search_space_to_config_space(

jason_string = json.dumps(search_space, cls=JSONCfgEncoder)
cs = csjson.read(jason_string)
elif type(search_space) == ConfigurationSpace:
elif isinstance(search_space, ConfigurationSpace):
cs = search_space
elif isinstance(search_space, dict):
cs = csjson.read(json.dumps(search_space))
else:
raise ValueError(
f"search_space must be of type str or DictConfig. Got {type(search_space)}."
Expand Down
18 changes: 11 additions & 7 deletions test/test_search_space_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"lower": -512.0,
"upper": 512.0,
"default": -3.0,
"q": None,
},
{
"name": "x1",
Expand All @@ -29,6 +30,7 @@
"lower": -512.0,
"upper": 512.0,
"default": -4.0,
"q": None,
},
],
"conditions": [],
Expand All @@ -51,19 +53,14 @@ def setUp(self):
self.test_space = ConfigurationSpace(name="myspace", space=dict_space)
return super().setUp()

def test_init(self):
self.test_space = ConfigurationSpace(name="myspace", space=dict_space_2)

self.test_space = ConfigurationSpace(name="myspace", space=str_space)

def test_config_spaces(self):
def test_ss_as_cs(self):
try:
search_space_to_config_space(self.test_space)
except Exception as e:
print(f"Cannot encode search space -- {self.test_space}.")
raise e

def test_dict_configs(self):
def test_ss_as_dictconfig(self):
try:
dict_space = DictConfig({"hyperparameters": {}})

Expand All @@ -72,6 +69,13 @@ def test_dict_configs(self):
print(f"Cannot encode search space -- {dict_space}.")
raise e

def test_ss_as_dict(self):
try:
search_space_to_config_space(dict_space_2)
except Exception as e:
print(f"Cannot encode search space -- {dict_space_2}.")
raise e


if __name__ == "__main__":
unittest.main()

0 comments on commit bc6ef39

Please sign in to comment.