From 5d3760633cae6d76b10df57e9f8478557e4e1946 Mon Sep 17 00:00:00 2001
From: Sait Cakmak
Date: Thu, 21 Nov 2024 07:56:01 -0800
Subject: [PATCH] Add support for continuous relaxation within
 optimize_acqf_mixed_alternating (#2635)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/2635

`optimize_acqf_mixed_alternating` uses local search to optimize discrete
dimensions. This works well when the discrete dimensions take only a small
number of values, but it does not scale well as the number of values increases.
To address this, we have been transforming the high-cardinality dimensions in
Ax and only passing in the low-cardinality dimensions as part of
`discrete_dims`.

This diff adds support for using continuous relaxation for discrete dimensions
that have more than `max_discrete_values` values (configurable via `options`).
It also updates the optimizer to fall back to `optimize_acqf` if there are no
discrete dimensions left. This is more user-friendly than erroring out
(particularly when used through Ax).

Reviewed By: Balandat

Differential Revision: D66239005

fbshipit-source-id: 0878115eb08ea75acb34ad8e891cf88393d4e36c
---
 botorch/optim/optimize_mixed.py   | 58 +++++++++++++++++++-
 test/optim/test_optimize_mixed.py | 91 +++++++++++++++++++++++++++++--
 test/test_utils/test_mock.py      |  2 +-
 3 files changed, 143 insertions(+), 8 deletions(-)

diff --git a/botorch/optim/optimize_mixed.py b/botorch/optim/optimize_mixed.py
index ac153b9e23..105952b3b9 100644
--- a/botorch/optim/optimize_mixed.py
+++ b/botorch/optim/optimize_mixed.py
@@ -39,6 +39,10 @@
 MAX_ITER_ALTER = 64  # Maximum number of alternating iterations.
 MAX_ITER_DISCRETE = 4  # Maximum number of discrete iterations.
 MAX_ITER_CONT = 8  # Maximum number of continuous iterations.
+# Maximum number of discrete values for a discrete dimension.
+# If there are more values for a dimension, we will use continuous
+# relaxation to optimize it.
+MAX_DISCRETE_VALUES = 20
 # Maximum number of iterations for optimizing the continuous relaxation
 # during initialization
 MAX_ITER_INIT = 100
@@ -52,6 +56,7 @@
     "maxiter_discrete",
     "maxiter_continuous",
     "maxiter_init",
+    "max_discrete_values",
     "num_spray_points",
     "std_cont_perturbation",
     "batch_limit",
@@ -60,6 +65,40 @@
 SUPPORTED_INITIALIZATION = {"continuous_relaxation", "equally_spaced", "random"}
 
 
+def _setup_continuous_relaxation(
+    discrete_dims: list[int],
+    bounds: Tensor,
+    max_discrete_values: int,
+    post_processing_func: Callable[[Tensor], Tensor] | None,
+) -> tuple[list[int], Callable[[Tensor], Tensor] | None]:
+    r"""Update `discrete_dims` and `post_processing_func` to use
+    continuous relaxation for discrete dimensions that have more than
+    `max_discrete_values` values. These dimensions are removed from
+    `discrete_dims` and `post_processing_func` is updated to round
+    them to the nearest integer.
+    """
+    discrete_dims_t = torch.tensor(discrete_dims, dtype=torch.long)
+    num_discrete_values = (
+        bounds[1, discrete_dims_t] - bounds[0, discrete_dims_t]
+    ).cpu()
+    dims_to_relax = discrete_dims_t[num_discrete_values > max_discrete_values]
+    if dims_to_relax.numel() == 0:
+        # No dimension needs continuous relaxation.
+        return discrete_dims, post_processing_func
+    # Remove relaxed dims from `discrete_dims`.
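+    # These dims are instead optimized as continuous variables; the wrapped
+    # post-processing function defined below rounds them back to integers.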
+    discrete_dims = list(set(discrete_dims).difference(dims_to_relax.tolist()))
+
+    def new_post_processing_func(X: Tensor) -> Tensor:
+        r"""Round the relaxed dimensions to the nearest integer and apply the
+        original `post_processing_func`."""
+        X[..., dims_to_relax] = X[..., dims_to_relax].round()
+        if post_processing_func is not None:
+            X = post_processing_func(X)
+        return X
+
+    return discrete_dims, new_post_processing_func
+
+
 def _filter_infeasible(
     X: Tensor, inequality_constraints: list[tuple[Tensor, Tensor, float]] | None
 ) -> Tensor:
@@ -532,6 +571,9 @@ def optimize_acqf_mixed_alternating(
     iterations.
 
     NOTE: This method assumes that all discrete variables are integer valued.
+    The discrete dimensions that have more than
+    `options.get("max_discrete_values", MAX_DISCRETE_VALUES)` values will
+    be optimized using continuous relaxation.
 
     # TODO: Support categorical variables.
 
@@ -549,6 +591,9 @@
            Defaults to 4.
        - "maxiter_continuous": Maximum number of iterations in each continuous
            step. Defaults to 8.
+        - "max_discrete_values": Maximum number of values for a discrete
+            dimension to be optimized via the discrete step (local search).
+            Dimensions with more values are optimized via continuous relaxation.
        - "num_spray_points": Number of spray points (around `X_baseline`)
            to add to the points generated by the initialization strategy.
            Defaults to 20 if all discrete variables are binary and to 0 otherwise.
@@ -598,6 +643,17 @@
             f"Received an unsupported option {unsupported_keys}. {SUPPORTED_OPTIONS=}."
         )
 
+    # Update discrete dims and the post-processing function to account for any
+    # dimensions that should be using continuous relaxation.
+    discrete_dims, post_processing_func = _setup_continuous_relaxation(
+        discrete_dims=discrete_dims,
+        bounds=bounds,
+        max_discrete_values=assert_is_instance(
+            options.get("max_discrete_values", MAX_DISCRETE_VALUES), int
+        ),
+        post_processing_func=post_processing_func,
+    )
+
     opt_inputs = OptimizeAcqfInputs(
         acq_function=acq_function,
         bounds=bounds,
@@ -623,7 +679,7 @@
     # Remove fixed features from dims, so they don't get optimized.
     discrete_dims = [dim for dim in discrete_dims if dim not in fixed_features]
     if len(discrete_dims) == 0:
-        raise ValueError("There must be at least one discrete parameter.")
+        return _optimize_acqf(opt_inputs=opt_inputs)
     if not (
         isinstance(discrete_dims, list)
         and len(set(discrete_dims)) == len(discrete_dims)
diff --git a/test/optim/test_optimize_mixed.py b/test/optim/test_optimize_mixed.py
index 1ab5fce7ea..f358f0a537 100644
--- a/test/optim/test_optimize_mixed.py
+++ b/test/optim/test_optimize_mixed.py
@@ -19,12 +19,14 @@
 from botorch.models.gp_regression import SingleTaskGP
 from botorch.optim.optimize import _optimize_acqf, OptimizeAcqfInputs
 from botorch.optim.optimize_mixed import (
+    _setup_continuous_relaxation,
     complement_indices,
     continuous_step,
     discrete_step,
     generate_starting_points,
     get_nearest_neighbors,
     get_spray_points,
+    MAX_DISCRETE_VALUES,
     optimize_acqf_mixed_alternating,
     sample_feasible_points,
 )
@@ -544,11 +546,10 @@ def test_optimize_acqf_mixed_binary_only(self) -> None:
         self.assertEqual(candidates.shape[-1], dim)
         c_binary = candidates[:, binary_dims + [2]]
         self.assertTrue(((c_binary == 0) | (c_binary == 1)).all())
-        # Only continuous parameters will raise an error.
-        with self.assertRaisesRegex(
-            ValueError,
-            "There must be at least one discrete parameter",
-        ):
+        # Only continuous parameters should fall back to `optimize_acqf`.
+        with mock.patch(
+            f"{OPT_MODULE}._optimize_acqf", wraps=_optimize_acqf
+        ) as wrapped_optimize:
             optimize_acqf_mixed_alternating(
                 acq_function=acqf,
                 bounds=bounds,
@@ -556,8 +557,18 @@
                 options=options,
                 q=1,
                 raw_samples=20,
-                num_restarts=20,
+                num_restarts=2,
+            )
+            wrapped_optimize.assert_called_once_with(
+                opt_inputs=_make_opt_inputs(
+                    acq_function=acqf,
+                    bounds=bounds,
+                    options=options,
+                    q=1,
+                    raw_samples=20,
+                    num_restarts=2,
+                )
             )
         # Only discrete works fine.
         candidates, _ = optimize_acqf_mixed_alternating(
             acq_function=acqf,
@@ -720,3 +731,71 @@ def test_optimize_acqf_mixed_integer(self) -> None:
         wrapped_sample_feasible.assert_called_once()
         # Should request 4 candidates, since all 4 are infeasible.
         self.assertEqual(wrapped_sample_feasible.call_args.kwargs["num_points"], 4)
+
+    def test_optimize_acqf_mixed_continuous_relaxation(self) -> None:
+        # Testing with integer variables.
+        train_X, train_Y, binary_dims, cont_dims = self._get_data()
+        # Update the data to introduce integer dimensions.
+        binary_dims = [0]
+        integer_dims = [3, 4]
+        discrete_dims = binary_dims + integer_dims
+        bounds = torch.tensor(
+            [[0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 40.0, 15.0]],
+            dtype=torch.double,
+            device=self.device,
+        )
+        # Update the model to have a different optimizer.
+        root = torch.tensor([0.0, 0.0, 0.0, 25.0, 10.0], device=self.device)
+        model = QuadraticDeterministicModel(root)
+        acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X)
+
+        for max_discrete_values, post_processing_func in (
+            (None, None),
+            (5, lambda X: X + 10),
+        ):
+            options = {
+                "batch_limit": 5,
+                "init_batch_limit": 20,
+                "maxiter_alternating": 1,
+            }
+            if max_discrete_values is not None:
+                options["max_discrete_values"] = max_discrete_values
+            with mock.patch(
+                f"{OPT_MODULE}._setup_continuous_relaxation",
+                wraps=_setup_continuous_relaxation,
+            ) as wrapped_setup, mock.patch(
+                f"{OPT_MODULE}.discrete_step", wraps=discrete_step
+            ) as wrapped_discrete:
+                candidates, _ = optimize_acqf_mixed_alternating(
+                    acq_function=acqf,
+                    bounds=bounds,
+                    discrete_dims=discrete_dims,
+                    q=3,
+                    raw_samples=32,
+                    num_restarts=4,
+                    options=options,
+                    post_processing_func=post_processing_func,
+                )
+            wrapped_setup.assert_called_once_with(
+                discrete_dims=discrete_dims,
+                bounds=bounds,
+                max_discrete_values=max_discrete_values or MAX_DISCRETE_VALUES,
+                post_processing_func=post_processing_func,
+            )
+            discrete_call_args = wrapped_discrete.call_args.kwargs
+            expected_dims = [0, 4] if max_discrete_values is None else [0]
+            self.assertAllClose(
+                discrete_call_args["discrete_dims"],
+                torch.tensor(expected_dims, device=self.device),
+            )
+            # Check that dim 3 is rounded.
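+            # (Dim 3 spans [0, 40], exceeding the default threshold of
+            # MAX_DISCRETE_VALUES=20, so it is always relaxed; dim 4 spans
+            # [0, 15] and is relaxed only when max_discrete_values=5.)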
+            X = torch.ones(1, 5, device=self.device) * 0.6
+            X_expected = X.clone()
+            X_expected[0, 3] = 1.0
+            if max_discrete_values is not None:
+                X_expected[0, 4] = 1.0
+            if post_processing_func is not None:
+                X_expected = post_processing_func(X_expected)
+            self.assertAllClose(
+                discrete_call_args["opt_inputs"].post_processing_func(X), X_expected
+            )
diff --git a/test/test_utils/test_mock.py b/test/test_utils/test_mock.py
index 19dc68eee6..43867bbeea 100644
--- a/test/test_utils/test_mock.py
+++ b/test/test_utils/test_mock.py
@@ -98,7 +98,7 @@ def test_mock_optimize_mixed_alternating(self) -> None:
         ) as mock_neighbors:
             optimize_acqf_mixed_alternating(
                 acq_function=SinAcqusitionFunction(),
-                bounds=torch.tensor([[-2.0, 0.0], [2.0, 200.0]]),
+                bounds=torch.tensor([[-2.0, 0.0], [2.0, 20.0]]),
                 discrete_dims=[1],
                 num_restarts=1,
             )
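
The snippet below is a usage sketch, not part of the patch: it exercises the
new behavior end to end, using an unfit `SingleTaskGP` on random toy data
purely for illustration. Dimension 2 takes 101 integer values, which exceeds
the default threshold of 20 (`MAX_DISCRETE_VALUES`), so it is optimized via
continuous relaxation and rounded back to integers by the wrapped
post-processing function.

    import torch
    from botorch.acquisition import qLogNoisyExpectedImprovement
    from botorch.models import SingleTaskGP
    from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating

    # Toy data: dims 0 and 1 are continuous in [0, 1]; dim 2 is an integer
    # in [0, 100].
    train_X = torch.rand(10, 3, dtype=torch.double)
    train_X[:, 2] = (train_X[:, 2] * 100).round()
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
    acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X)

    bounds = torch.tensor(
        [[0.0, 0.0, 0.0], [1.0, 1.0, 100.0]], dtype=torch.double
    )
    # Dim 2 spans more than MAX_DISCRETE_VALUES values, so it is relaxed
    # internally; the returned candidates still take integer values in dim 2.
    candidates, value = optimize_acqf_mixed_alternating(
        acq_function=acqf,
        bounds=bounds,
        discrete_dims=[2],
        q=1,
        num_restarts=2,
        raw_samples=32,
    )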