diff --git a/carl/context/search_space_encoding.py b/carl/context/search_space_encoding.py index 893a71d4..ae955d3a 100644 --- a/carl/context/search_space_encoding.py +++ b/carl/context/search_space_encoding.py @@ -106,11 +106,11 @@ def search_space_to_config_space( ------- ConfigurationSpace """ - if type(search_space) == str: + if isinstance(search_space, str): with open(search_space, "r") as f: jason_string = f.read() cs = csjson.read(jason_string) - elif type(search_space) == DictConfig: + elif isinstance(search_space, DictConfig): # reorder hyperparameters as List[Dict] hyperparameters = [] for name, cfg in search_space.hyperparameters.items(): @@ -130,8 +130,10 @@ def search_space_to_config_space( jason_string = json.dumps(search_space, cls=JSONCfgEncoder) cs = csjson.read(jason_string) - elif type(search_space) == ConfigurationSpace: + elif isinstance(search_space, ConfigurationSpace): cs = search_space + elif isinstance(search_space, dict): + cs = csjson.read(json.dumps(search_space)) else: raise ValueError( f"search_space must be of type str or DictConfig. Got {type(search_space)}." diff --git a/test/test_search_space_encoding.py b/test/test_search_space_encoding.py index 42a0489d..2c1a4e84 100644 --- a/test/test_search_space_encoding.py +++ b/test/test_search_space_encoding.py @@ -21,6 +21,7 @@ "lower": -512.0, "upper": 512.0, "default": -3.0, + "q": None, }, { "name": "x1", @@ -29,6 +30,7 @@ "lower": -512.0, "upper": 512.0, "default": -4.0, + "q": None, }, ], "conditions": [], @@ -51,19 +53,14 @@ def setUp(self): self.test_space = ConfigurationSpace(name="myspace", space=dict_space) return super().setUp() - def test_init(self): - self.test_space = ConfigurationSpace(name="myspace", space=dict_space_2) - - self.test_space = ConfigurationSpace(name="myspace", space=str_space) - - def test_config_spaces(self): + def test_ss_as_cs(self): try: search_space_to_config_space(self.test_space) except Exception as e: print(f"Cannot encode search space -- {self.test_space}.") raise e - def test_dict_configs(self): + def test_ss_as_dictconfig(self): try: dict_space = DictConfig({"hyperparameters": {}}) @@ -72,6 +69,13 @@ def test_dict_configs(self): print(f"Cannot encode search space -- {dict_space}.") raise e + def test_ss_as_dict(self): + try: + search_space_to_config_space(dict_space_2) + except Exception as e: + print(f"Cannot encode search space -- {dict_space_2}.") + raise e + if __name__ == "__main__": unittest.main()