diff --git a/optimizerapi/optimizer.py b/optimizerapi/optimizer.py index 17e28fe..18229e7 100644 --- a/optimizerapi/optimizer.py +++ b/optimizerapi/optimizer.py @@ -244,7 +244,11 @@ def process_result(result, optimizer, dimensions, cfg, extras, data, space): if "experimentSuggestionCount" in extras: experiment_suggestion_count = extras["experimentSuggestionCount"] - next_exp = optimizer.ask(n_points=experiment_suggestion_count) + + if "constraints" in cfg and len(cfg["constraints"]) > 0: + next_exp = optimizer.ask(n_points=experiment_suggestion_count, strategy="cl_min") + else: + next_exp = optimizer.ask(n_points=experiment_suggestion_count) if len(next_exp) > 0 and not any(isinstance(x, list) for x in next_exp): next_exp = [next_exp] result_details["next"] = round_to_length_scales(next_exp, optimizer.space) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index bea92dd..fd4fdbc 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1,6 +1,8 @@ """ Test main optimizer module """ + +from unittest.mock import patch import copy import collections.abc from optimizerapi import optimizer @@ -37,6 +39,63 @@ samplePayload = {"data": sampleData, "optimizerConfig": sampleConfig} +brownie_with_constraints = { + "extras": { + "experimentSuggestionCount": 3, + "graphs": ["single"], + "includeModel": "false", + "objectivePars": "expected_minimum", + }, + "data": [], + "optimizerConfig": { + "baseEstimator": "GP", + "acqFunc": "EI", + "initialPoints": 3, + "kappa": 1.96, + "xi": 5, + "space": [ + {"type": "continuous", "name": "Cocoa", "from": 18, "to": 56}, + {"type": "continuous", "name": "Powdered sugar", "from": 79, "to": 237}, + {"type": "discrete", "name": "Egg whites", "from": 1, "to": 4}, + {"type": "discrete", "name": "Time", "from": 16, "to": 30}, + { + "type": "category", + "name": "Temperature", + "categories": ["160", "180", "200"], + }, + ], + "constraints": [{"type": "sum", "dimensions": [0, 1], "value": 200}], + }, +} +brownie_without_constraints = { + "extras": { + "experimentSuggestionCount": 3, + "graphs": ["single"], + "includeModel": "false", + "objectivePars": "expected_minimum", + }, + "data": [], + "optimizerConfig": { + "baseEstimator": "GP", + "acqFunc": "EI", + "initialPoints": 3, + "kappa": 1.96, + "xi": 5, + "space": [ + {"type": "continuous", "name": "Cocoa", "from": 18, "to": 56}, + {"type": "continuous", "name": "Powdered sugar", "from": 79, "to": 237}, + {"type": "discrete", "name": "Egg whites", "from": 1, "to": 4}, + {"type": "discrete", "name": "Time", "from": 16, "to": 30}, + { + "type": "category", + "name": "Temperature", + "categories": ["160", "180", "200"], + }, + ], + "constraints": [], + }, +} + def validateResult(result): assert "plots" in result @@ -152,3 +211,35 @@ def test_expected_minimum_contains_std_deviation(): assert "expected_minimum" in result["result"] expected_minimum = result["result"]["expected_minimum"] assert isinstance(expected_minimum[1], collections.abc.Sequence) + + +@patch("optimizerapi.optimizer.Optimizer") +def test_when_using_constraints_set_constraints_should_be_called(mock): + instance = mock.return_value + request = brownie_with_constraints + optimizer.run(body=request) + instance.set_constraints.assert_called_once() + + +@patch("optimizerapi.optimizer.Optimizer") +def test_when_not_using_constraints_set_constraints_should_not_be_called(mock): + instance = mock.return_value + request = brownie_without_constraints + optimizer.run(body=request) + instance.set_constraints.assert_not_called() + + +@patch("optimizerapi.optimizer.Optimizer") +def test_when_using_constraints_strategy_cl_min_should_be_used(mock): + instance = mock.return_value + request = brownie_with_constraints + optimizer.run(body=request) + instance.ask.assert_called_once_with(n_points=3, strategy="cl_min") + + +@patch("optimizerapi.optimizer.Optimizer") +def test_when_not_using_constraints_standard_strategy_should_be_used(mock): + instance = mock.return_value + request = brownie_without_constraints + optimizer.run(body=request) + instance.ask.assert_called_once_with(n_points=3)