diff --git a/optimizerapi/optimizer.py b/optimizerapi/optimizer.py index bf34811..fdd115c 100644 --- a/optimizerapi/optimizer.py +++ b/optimizerapi/optimizer.py @@ -194,7 +194,7 @@ def process_result(result, optimizer, dimensions, cfg, extras, data, space): experiment_suggestion_count = extras["experimentSuggestionCount"] next_exp = optimizer.ask(n_points=experiment_suggestion_count) - if not any(isinstance(x, list) for x in next_exp): next_exp = [next_exp] + if len(next_exp) > 0 and not any(isinstance(x, list) for x in next_exp): next_exp = [next_exp] result_details["next"] = round_to_length_scales(next_exp, optimizer.space) if len(data) >= cfg["initialPoints"]: diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 41cdfa3..9394f2c 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -44,7 +44,7 @@ def validateResult(result): assert "result" in result assert "models" in result["result"] assert "next" in result["result"] - assert len(result["result"]["next"]) == len(sampleConfig["space"]) + assert all(len(x) == len(sampleConfig["space"]) for x in result["result"]["next"]) assert "pickled" in result["result"] assert len(result["result"]["pickled"]) > 1 if len(result["result"]["models"]) > 0: