diff --git a/botorch/acquisition/fixed_feature.py b/botorch/acquisition/fixed_feature.py index 0f3b85faa7..763226799e 100644 --- a/botorch/acquisition/fixed_feature.py +++ b/botorch/acquisition/fixed_feature.py @@ -16,11 +16,11 @@ import torch from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from torch import Tensor -from torch.nn import Module -class FixedFeatureAcquisitionFunction(AcquisitionFunction): +class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper): """A wrapper around AquisitionFunctions to fix a subset of features. Example: @@ -56,8 +56,7 @@ def __init__( combination of `Tensor`s and numbers which can be broadcasted to form a tensor with trailing dimension size of `d_f`. """ - Module.__init__(self) - self.acq_func = acq_function + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function) dtype = torch.float device = torch.device("cpu") self.d = d @@ -126,24 +125,13 @@ def forward(self, X: Tensor): X_full = self._construct_X_full(X) return self.acq_func(X_full) - @property - def X_pending(self): - r"""Return the `X_pending` of the base acquisition function.""" - try: - return self.acq_func.X_pending - except (ValueError, AttributeError): - raise ValueError( - f"Base acquisition function {type(self.acq_func).__name__} " - "does not have an `X_pending` attribute." - ) - - @X_pending.setter - def X_pending(self, X_pending: Optional[Tensor]): + def set_X_pending(self, X_pending: Optional[Tensor]): r"""Sets the `X_pending` of the base acquisition function.""" if X_pending is not None: - self.acq_func.X_pending = self._construct_X_full(X_pending) + full_X_pending = self._construct_X_full(X_pending) else: - self.acq_func.X_pending = X_pending + full_X_pending = None + self.acq_func.set_X_pending(full_X_pending) def _construct_X_full(self, X: Tensor) -> Tensor: r"""Constructs the full input for the base acquisition function. diff --git a/botorch/acquisition/penalized.py b/botorch/acquisition/penalized.py index b114362ea9..9ee8f1fee5 100644 --- a/botorch/acquisition/penalized.py +++ b/botorch/acquisition/penalized.py @@ -15,9 +15,8 @@ import torch from botorch.acquisition.acquisition import AcquisitionFunction -from botorch.acquisition.analytic import AnalyticAcquisitionFunction from botorch.acquisition.objective import GenericMCObjective -from botorch.exceptions import UnsupportedError +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from torch import Tensor @@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor: return regularization_term -class PenalizedAcquisitionFunction(AcquisitionFunction): +class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper): r"""Single-outcome acquisition function regularized by the given penalty. The usage is similar to: @@ -161,29 +160,16 @@ def __init__( penalty_func: The regularization function. regularization_parameter: Regularization parameter used in optimization. """ - super().__init__(model=raw_acqf.model) - self.raw_acqf = raw_acqf + AcquisitionFunction.__init__(self, model=raw_acqf.model) + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf) self.penalty_func = penalty_func self.regularization_parameter = regularization_parameter def forward(self, X: Tensor) -> Tensor: - raw_value = self.raw_acqf(X=X) + raw_value = self.acq_func(X=X) penalty_term = self.penalty_func(X) return raw_value - self.regularization_parameter * penalty_term - @property - def X_pending(self) -> Optional[Tensor]: - return self.raw_acqf.X_pending - - def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None: - if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction): - self.raw_acqf.set_X_pending(X_pending=X_pending) - else: - raise UnsupportedError( - "The raw acquisition function is Analytic and does not account " - "for X_pending yet." - ) - def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor: r"""Computes the group lasso regularization function for the given point. diff --git a/botorch/acquisition/probabilistic_reparameterization.py b/botorch/acquisition/probabilistic_reparameterization.py new file mode 100644 index 0000000000..5c6428985e --- /dev/null +++ b/botorch/acquisition/probabilistic_reparameterization.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Probabilistic Reparameterization (with gradients) using Monte Carlo estimators. + +See [Daulton2022bopr]_ for details. +""" + +from contextlib import ExitStack +from typing import Dict, List, Optional + +import torch +from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper +from botorch.models.transforms.factory import ( + get_probabilistic_reparameterization_input_transform, +) + +from botorch.models.transforms.input import ( + ChainedInputTransform, + InputTransform, + OneHotToNumeric, +) +from torch import Tensor +from torch.autograd import Function +from torch.nn.functional import one_hot + + +class _MCProbabilisticReparameterization(Function): + r"""Evaluate the acquisition function via probabistic reparameterization. + + This uses a score function gradient estimator. See [Daulton2022bopr]_ for details. + """ + + @staticmethod + def forward( + ctx, + X: Tensor, + acq_function: AcquisitionFunction, + input_tf: InputTransform, + batch_limit: Optional[int], + integer_indices: Tensor, + cont_indices: Tensor, + categorical_indices: Tensor, + use_ma_baseline: bool, + one_hot_to_numeric: Optional[OneHotToNumeric], + ma_counter: Optional[Tensor], + ma_hidden: Optional[Tensor], + ma_decay: Optional[float], + ): + """Evaluate the expectation of the acquisition function under + probabilistic reparameterization. Compute this in chunks of size + batch_limit to enable scaling to large numbers of samples from the + proposal distribution. + """ + with ExitStack() as es: + if ctx.needs_input_grad[0]: + es.enter_context(torch.enable_grad()) + if cont_indices.shape[0] > 0: + # only require gradient for continuous parameters + ctx.cont_X = X[..., cont_indices].detach().requires_grad_(True) + cont_idx = 0 + cols = [] + for col in range(X.shape[-1]): + # cont_indices is sorted in ascending order + if ( + cont_idx < cont_indices.shape[0] + and col == cont_indices[cont_idx] + ): + cols.append(ctx.cont_X[..., cont_idx]) + cont_idx += 1 + else: + cols.append(X[..., col]) + X = torch.stack(cols, dim=-1) + else: + ctx.cont_X = None + ctx.discrete_indices = input_tf["round"].discrete_indices + ctx.cont_indices = cont_indices + ctx.categorical_indices = categorical_indices + ctx.ma_counter = ma_counter + ctx.ma_hidden = ma_hidden + ctx.X_shape = X.shape + tilde_x_samples = input_tf(X.unsqueeze(-3)) + # save the rounding component + + rounding_component = tilde_x_samples.clone() + if integer_indices.shape[0] > 0: + X_integer_params = X[..., integer_indices].unsqueeze(-3) + rounding_component[..., integer_indices] = ( + (tilde_x_samples[..., integer_indices] - X_integer_params > 0) + | (X_integer_params == 1) + ).to(tilde_x_samples) + if categorical_indices.shape[0] > 0: + rounding_component[..., categorical_indices] = tilde_x_samples[ + ..., categorical_indices + ] + ctx.rounding_component = rounding_component[..., ctx.discrete_indices] + ctx.tau = input_tf["round"].tau + if hasattr(input_tf["round"], "base_samples"): + ctx.base_samples = input_tf["round"].base_samples.detach() + # save the probabilities + if "unnormalize" in input_tf: + unnormalized_X = input_tf["unnormalize"](X) + else: + unnormalized_X = X + # this is only for the integer parameters + ctx.prob = input_tf["round"].get_rounding_prob(unnormalized_X) + + if categorical_indices.shape[0] > 0: + ctx.base_samples_categorical = input_tf[ + "round" + ].base_samples_categorical.clone() + # compute the acquisition function where inputs are rounded according to base_samples < prob + ctx.tilde_x_samples = tilde_x_samples + ctx.use_ma_baseline = use_ma_baseline + acq_values_list = [] + start_idx = 0 + if one_hot_to_numeric is not None: + tilde_x_samples = one_hot_to_numeric(tilde_x_samples) + + while start_idx < tilde_x_samples.shape[-3]: + end_idx = min(start_idx + batch_limit, tilde_x_samples.shape[-3]) + acq_values = acq_function(tilde_x_samples[..., start_idx:end_idx, :, :]) + acq_values_list.append(acq_values) + start_idx += batch_limit + acq_values = torch.cat(acq_values_list, dim=-1) + ctx.mean_acq_values = acq_values.mean( + dim=-1 + ) # average over samples from proposal distribution + ctx.acq_values = acq_values + # update moving average baseline + ctx.ma_hidden = ma_hidden.clone() + ctx.ma_counter = ctx.ma_counter.clone() + ctx.ma_decay = ma_decay + # update in place + ma_counter.add_(1) + ma_hidden.sub_((ma_hidden - acq_values.detach().mean()) * (1 - ma_decay)) + return ctx.mean_acq_values.detach() + + @staticmethod + def backward(ctx, grad_output): + """ + Compute the gradient of the expectation of the acquisition function + with respect to the parameters of the proposal distribution using + Monte Carlo. + """ + # this is overwriting the entire gradient w.r.t. x' + # x' has shape batch_shape x q x d + if ctx.needs_input_grad[0]: + acq_values = ctx.acq_values + mean_acq_values = ctx.mean_acq_values + cont_indices = ctx.cont_indices + discrete_indices = ctx.discrete_indices + rounding_component = ctx.rounding_component + # retrieve only the ordinal parameters + expanded_acq_values = acq_values.view(*acq_values.shape, 1, 1).expand( + acq_values.shape + rounding_component.shape[-2:] + ) + prob = ctx.prob.unsqueeze(-3) + if not ctx.use_ma_baseline: + sample_level = expanded_acq_values * (rounding_component - prob) + else: + # use reinforce with the moving average baseline + if ctx.ma_counter == 0: + baseline = 0.0 + else: + baseline = ctx.ma_hidden / ( + 1.0 - torch.pow(ctx.ma_decay, ctx.ma_counter) + ) + sample_level = (expanded_acq_values - baseline) * ( + rounding_component - prob + ) + + grads = (sample_level / ctx.tau).mean(dim=-3) + + new_grads = ( + grad_output.view( + *grad_output.shape, + *[1 for _ in range(grads.ndim - grad_output.ndim)], + ) + .expand(*grad_output.shape, *ctx.X_shape[-2:]) + .clone() + ) + # multiply upstream grad_output by new gradients + new_grads[..., discrete_indices] *= grads + # use autograd for gradients w.r.t. the continuous parameters + if ctx.cont_X is not None: + auto_grad = torch.autograd.grad( + # note: this multiplies the gradient of mean_acq_values w.r.t to input + # by grad_output + mean_acq_values, + ctx.cont_X, + grad_outputs=grad_output, + )[0] + # overwrite grad_output since the previous step already applied the chain rule + new_grads[..., cont_indices] = auto_grad + return ( + new_grads, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + return None, None, None, None, None, None, None, None, None, None, None, None + + +class AbstractProbabilisticReparameterization(AbstractAcquisitionFunctionWrapper): + r"""Acquisition Function Wrapper that leverages probabilistic reparameterization. + + The forward method is abstract and must be implemented. + + See [Daulton2022bopr]_ for details. + """ + + input_transform: ChainedInputTransform + + def __init__( + self, + acq_function: AcquisitionFunction, + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + batch_limit: int = 32, + apply_numeric: bool = False, + **kwargs, + ) -> None: + r"""Initialize probabilistic reparameterization (PR). + + Args: + acq_function: The acquisition function. + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer parameters + categorical_features: A dictionary mapping indices to cardinalities + for the categorical features. + batch_limit: The chunk size used in evaluating PR to limit memory + overhead. + apply_numeric: A boolean indicated if categoricals should be supplied + to the underlying acquisition function in numeric representation. + """ + if categorical_features is None and integer_indices is None: + raise NotImplementedError( + "categorical_features or integer indices must be provided." + ) + super().__init__(acq_function=acq_function) + self.batch_limit = batch_limit + + if apply_numeric: + self.one_hot_to_numeric = OneHotToNumeric( + categorical_features=categorical_features, + transform_on_train=False, + transform_on_eval=True, + transform_on_fantasize=False, + ) + self.one_hot_to_numeric.eval() + else: + self.one_hot_to_numeric = None + discrete_indices = [] + if integer_indices is not None: + self.register_buffer( + "integer_indices", + torch.tensor( + integer_indices, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer("integer_bounds", one_hot_bounds[:, integer_indices]) + discrete_indices.extend(integer_indices) + else: + self.register_buffer( + "integer_indices", + torch.tensor([], dtype=torch.long, device=one_hot_bounds.device), + ) + self.register_buffer( + "integer_bounds", + torch.tensor( + [], dtype=one_hot_bounds.dtype, device=one_hot_bounds.device + ), + ) + dim = one_hot_bounds.shape[1] + if categorical_features is not None and len(categorical_features) > 0: + categorical_indices = list(range(min(categorical_features.keys()), dim)) + discrete_indices.extend(categorical_indices) + self.register_buffer( + "categorical_indices", + torch.tensor( + categorical_indices, + dtype=torch.long, + device=one_hot_bounds.device, + ), + ) + self.categorical_features = categorical_features + else: + self.register_buffer( + "categorical_indices", + torch.tensor( + [], + dtype=torch.long, + device=one_hot_bounds.device, + ), + ) + + self.register_buffer( + "cont_indices", + torch.tensor( + sorted(set(range(dim)) - set(discrete_indices)), + dtype=torch.long, + device=one_hot_bounds.device, + ), + ) + self.model = acq_function.model # for sample_around_best heuristic + # moving average baseline + self.register_buffer( + "ma_counter", + torch.zeros(1, dtype=one_hot_bounds.dtype, device=one_hot_bounds.device), + ) + self.register_buffer( + "ma_hidden", + torch.zeros(1, dtype=one_hot_bounds.dtype, device=one_hot_bounds.device), + ) + self.register_buffer( + "ma_baseline", + torch.zeros(1, dtype=one_hot_bounds.dtype, device=one_hot_bounds.device), + ) + + def sample_candidates(self, X: Tensor) -> Tensor: + if "unnormalize" in self.input_transform: + unnormalized_X = self.input_transform["unnormalize"](X) + else: + unnormalized_X = X.clone() + prob = self.input_transform["round"].get_rounding_prob(X=unnormalized_X) + discrete_idx = 0 + for i in self.integer_indices: + p = prob[..., discrete_idx] + rounding_component = torch.distributions.Bernoulli(probs=p).sample() + unnormalized_X[..., i] = unnormalized_X[..., i].floor() + rounding_component + discrete_idx += 1 + if len(self.integer_indices) > 0: + unnormalized_X[..., self.integer_indices] = torch.minimum( + torch.maximum( + unnormalized_X[..., self.integer_indices], self.integer_bounds[0] + ), + self.integer_bounds[1], + ) + # this is the starting index for the categoricals in unnormalized_X + raw_idx = self.cont_indices.shape[0] + discrete_idx + if self.categorical_indices.shape[0] > 0: + for cardinality in self.categorical_features.values(): + discrete_end = discrete_idx + cardinality + p = prob[..., discrete_idx:discrete_end] + z = one_hot( + torch.distributions.Categorical(probs=p).sample(), + num_classes=cardinality, + ) + raw_end = raw_idx + cardinality + unnormalized_X[..., raw_idx:raw_end] = z + discrete_idx = discrete_end + raw_idx = raw_end + # normalize X + if "normalize" in self.input_transform: + return self.input_transform["normalize"](unnormalized_X) + return unnormalized_X + + +class AnalyticProbabilisticReparameterization(AbstractProbabilisticReparameterization): + """Analytic probabilistic reparameterization. + + Note: this is only reasonable from a computation perspective for relatively + small numbers of discrete options (probably less than a few thousand). + """ + + def __init__( + self, + acq_function: AcquisitionFunction, + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + batch_limit: int = 32, + apply_numeric: bool = False, + tau: float = 0.1, + ) -> None: + """Initialize probabilistic reparameterization (PR). + + Args: + acq_function: The acquisition function. + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer parameters + categorical_features: A dictionary mapping indices to cardinalities + for the categorical features. + batch_limit: The chunk size used in evaluating PR to limit memory + overhead. + apply_numeric: A boolean indicated if categoricals should be supplied + to the underlying acquisition function in numeric representation. + tau: The temperature parameter used to determine the probabilities. + + """ + super().__init__( + acq_function=acq_function, + integer_indices=integer_indices, + one_hot_bounds=one_hot_bounds, + categorical_features=categorical_features, + batch_limit=batch_limit, + apply_numeric=apply_numeric, + ) + # create input transform + # need to compute cross product of discrete options and weights + self.input_transform = get_probabilistic_reparameterization_input_transform( + one_hot_bounds=one_hot_bounds, + use_analytic=True, + integer_indices=integer_indices, + categorical_features=categorical_features, + tau=tau, + ) + + def forward(self, X: Tensor) -> Tensor: + r"""Evaluate PR.""" + X_discrete_all = self.input_transform(X.unsqueeze(-3)) + acq_values_list = [] + start_idx = 0 + if self.one_hot_to_numeric is not None: + X_discrete_all = self.one_hot_to_numeric(X_discrete_all) + if X.shape[-2] != 1: + raise NotImplementedError + + # save the probabilities + if "unnormalize" in self.input_transform: + unnormalized_X = self.input_transform["unnormalize"](X) + else: + unnormalized_X = X + # this is batch_shape x n_discrete (after squeezing) + probs = self.input_transform["round"].get_probs(X=unnormalized_X).squeeze(-1) + # TODO: filter discrete configs with zero probability + # this would require padding because there may be a different number in each batch. + while start_idx < X_discrete_all.shape[-3]: + end_idx = min(start_idx + self.batch_limit, X_discrete_all.shape[-3]) + acq_values = self.acq_func(X_discrete_all[..., start_idx:end_idx, :, :]) + acq_values_list.append(acq_values) + start_idx += self.batch_limit + # this is batch_shape x n_discrete + acq_values = torch.cat(acq_values_list, dim=-1) + # now weight the acquisition values by probabilities + return (acq_values * probs).sum(dim=-1) + + +class MCProbabilisticReparameterization(AbstractProbabilisticReparameterization): + r"""MC-based probabilistic reparameterization. + + See [Daulton2022bopr]_ for details. + """ + + def __init__( + self, + acq_function: AcquisitionFunction, + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + batch_limit: int = 32, + apply_numeric: bool = False, + mc_samples: int = 128, + use_ma_baseline: bool = True, + tau: float = 0.1, + ma_decay: float = 0.7, + resample: bool = True, + ) -> None: + """Initialize probabilistic reparameterization (PR). + + Args: + acq_function: The acquisition function. + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer parameters + categorical_features: A dictionary mapping indices to cardinalities + for the categorical features. + batch_limit: The chunk size used in evaluating PR to limit memory + overhead. + apply_numeric: A boolean indicated if categoricals should be supplied + to the underlying acquisition function in numeric representation. + mc_samples: The number of MC samples for MC probabilistic + reparameterization. + use_ma_baseline: A boolean indicating whether to use a moving average + baseline for variance reduction. + tau: The temperature parameter used to determine the probabilities. + ma_decay: The decay parameter in the moving average baseline. + Default: 0.7 + resample: A boolean indicating whether to resample with MC + probabilistic reparameterization on each forward pass. + + """ + super().__init__( + acq_function=acq_function, + one_hot_bounds=one_hot_bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + batch_limit=batch_limit, + apply_numeric=apply_numeric, + ) + if self.batch_limit is None: + self.batch_limit = mc_samples + self.use_ma_baseline = use_ma_baseline + self._pr_acq_function = _MCProbabilisticReparameterization() + # create input transform + self.input_transform = get_probabilistic_reparameterization_input_transform( + integer_indices=integer_indices, + one_hot_bounds=one_hot_bounds, + categorical_features=categorical_features, + mc_samples=mc_samples, + tau=tau, + resample=resample, + ) + self.ma_decay = ma_decay + + def forward(self, X: Tensor) -> Tensor: + r"""Evaluate MC probabilistic reparameterization.""" + return self._pr_acq_function.apply( + X, + self.acq_func, + self.input_transform, + self.batch_limit, + self.integer_indices, + self.cont_indices, + self.categorical_indices, + self.use_ma_baseline, + self.one_hot_to_numeric, + self.ma_counter, + self.ma_hidden, + self.ma_decay, + ) diff --git a/botorch/acquisition/proximal.py b/botorch/acquisition/proximal.py index 9cd4aed7ad..b1d68edef1 100644 --- a/botorch/acquisition/proximal.py +++ b/botorch/acquisition/proximal.py @@ -15,6 +15,8 @@ import torch from botorch.acquisition import AcquisitionFunction + +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from botorch.exceptions.errors import UnsupportedError from botorch.models import ModelListGP from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel @@ -25,7 +27,7 @@ from torch.nn import Module -class ProximalAcquisitionFunction(AcquisitionFunction): +class ProximalAcquisitionFunction(AbstractAcquisitionFunctionWrapper): """A wrapper around AcquisitionFunctions to add proximal weighting of the acquisition function. The acquisition function is weighted via a squared exponential centered at the last training point, @@ -70,9 +72,7 @@ def __init__( beta: If not None, apply a softplus transform to the base acquisition function, allows negative base acquisition function values. """ - Module.__init__(self) - - self.acq_func = acq_function + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function) model = self.acq_func.model if hasattr(acq_function, "X_pending"): @@ -80,7 +80,6 @@ def __init__( raise UnsupportedError( "Proximal acquisition function requires `X_pending` to be None." ) - self.X_pending = acq_function.X_pending self.register_buffer("proximal_weights", proximal_weights) self.register_buffer( @@ -91,6 +90,12 @@ def __init__( _validate_model(model, proximal_weights) + def set_X_pending(self, X_pending: Optional[Tensor]) -> None: + r"""Sets the `X_pending` of the base acquisition function.""" + raise UnsupportedError( + "Proximal acquisition function does not support `X_pending`." + ) + @t_batch_mode_transform(expected_q=1, assert_output_shape=False) def forward(self, X: Tensor) -> Tensor: r"""Evaluate base acquisition function with proximal weighting. diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 486fdd0cff..ccbbf471b2 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -11,7 +11,7 @@ from __future__ import annotations import math -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from botorch.acquisition import analytic, monte_carlo, multi_objective # noqa F401 @@ -22,6 +22,7 @@ MCAcquisitionObjective, PosteriorTransform, ) +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from botorch.exceptions.errors import UnsupportedError from botorch.models.fully_bayesian import MCMC_DIM from botorch.models.model import Model @@ -253,6 +254,18 @@ def objective(Y: Tensor, X: Optional[Tensor] = None): return -(lb.clamp_max(0.0)) +def isinstance_af( + __obj: object, + __class_or_tuple: Union[type, tuple[Union[type, tuple[Any, ...]], ...]], +) -> bool: + r"""A variant of isinstance first checks for the acq_func attribute on wrapped acquisition functions.""" + if isinstance(__obj, AbstractAcquisitionFunctionWrapper): + isinstance_base_af = isinstance(__obj.acq_func, __class_or_tuple) + else: + isinstance_base_af = False + return isinstance_base_af or isinstance(__obj, __class_or_tuple) + + def is_nonnegative(acq_function: AcquisitionFunction) -> bool: r"""Determine whether a given acquisition function is non-negative. @@ -267,7 +280,7 @@ def is_nonnegative(acq_function: AcquisitionFunction) -> bool: >>> qEI = qExpectedImprovement(model, best_f=0.1) >>> is_nonnegative(qEI) # returns True """ - return isinstance( + return isinstance_af( acq_function, ( analytic.ExpectedImprovement, diff --git a/botorch/acquisition/wrapper.py b/botorch/acquisition/wrapper.py new file mode 100644 index 0000000000..08dfbd2849 --- /dev/null +++ b/botorch/acquisition/wrapper.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +A wrapper classes around AcquisitionFunctions to modify inputs and outputs. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional + +from botorch.acquisition.acquisition import AcquisitionFunction +from torch import Tensor +from torch.nn import Module + + +class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC): + r"""Abstract acquisition wrapper.""" + + def __init__(self, acq_function: AcquisitionFunction) -> None: + Module.__init__(self) + self.acq_func = acq_function + + @property + def X_pending(self) -> Optional[Tensor]: + r"""Return the `X_pending` of the base acquisition function.""" + try: + return self.acq_func.X_pending + except (ValueError, AttributeError): + raise ValueError( + f"Base acquisition function {type(self.acq_func).__name__} " + "does not have an `X_pending` attribute." + ) + + def set_X_pending(self, X_pending: Optional[Tensor]) -> None: + r"""Sets the `X_pending` of the base acquisition function.""" + self.acq_func.set_X_pending(X_pending) + + @abstractmethod + def forward(self, X: Tensor) -> Tensor: + r"""Evaluate the wrapped acquisition function on the candidate set X. + + Args: + X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim + design points each. + + Returns: + A `(b)`-dim Tensor of acquisition function values at the given + design points `X`. + """ + pass # pragma: no cover diff --git a/botorch/models/transforms/factory.py b/botorch/models/transforms/factory.py index 847fdf1b7c..486dbc3125 100644 --- a/botorch/models/transforms/factory.py +++ b/botorch/models/transforms/factory.py @@ -10,7 +10,9 @@ from typing import Dict, List, Optional from botorch.models.transforms.input import ( + AnalyticProbabilisticReparameterizationInputTransform, ChainedInputTransform, + MCProbabilisticReparameterizationInputTransform, Normalize, OneHotToNumeric, Round, @@ -123,3 +125,83 @@ def get_rounding_input_transform( tf.to(dtype=one_hot_bounds.dtype, device=one_hot_bounds.device) tf.eval() return tf + + +def get_probabilistic_reparameterization_input_transform( + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + use_analytic: bool = False, + mc_samples: int = 128, + resample: bool = False, + tau: float = 0.1, +) -> ChainedInputTransform: + r"""Construct InputTransform for Probabilistic Reparameterization. + + Note: this is intended to be used only for acquisition optimization + in via the AnalyticProbabilisticReparameterization and + MCProbabilisticReparameterization classes. This is not intended to be + attached to a botorch Model. + + See [Daulton2022bopr]_ for details. + + Args: + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer parameters + categorical_features: A dictionary mapping indices to cardinalities + for the categorical features. + use_analytic: A boolean indicating whether to use analytic + probabilistic reparameterization. + mc_samples: The number of MC samples for MC probabilistic + reparameterization. + resample: A boolean indicating whether to resample with MC + probabilistic reparameterization on each forward pass. + tau: The temperature parameter used to determine the probabilities. + + Returns: + The probabilistic reparameterization input transformation. + """ + tfs = OrderedDict() + if integer_indices is not None and len(integer_indices) > 0: + # unnormalize to integer space + tfs["unnormalize"] = Normalize( + d=one_hot_bounds.shape[1], + bounds=one_hot_bounds, + indices=integer_indices, + transform_on_train=False, + transform_on_eval=True, + transform_on_fantasize=False, + reverse=True, + ) + if use_analytic: + tfs["round"] = AnalyticProbabilisticReparameterizationInputTransform( + one_hot_bounds=one_hot_bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + tau=tau, + ) + else: + tfs["round"] = MCProbabilisticReparameterizationInputTransform( + one_hot_bounds=one_hot_bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + resample=resample, + mc_samples=mc_samples, + tau=tau, + ) + if integer_indices is not None and len(integer_indices) > 0: + # normalize to unit cube + tfs["normalize"] = Normalize( + d=one_hot_bounds.shape[1], + bounds=one_hot_bounds, + indices=integer_indices, + transform_on_train=False, + transform_on_eval=True, + transform_on_fantasize=False, + reverse=False, + ) + tf = ChainedInputTransform(**tfs) + tf.eval() + return tf diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index 09310163b5..0bc649dedf 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -25,6 +25,7 @@ from botorch.models.transforms.utils import subset_transform from botorch.models.utils import fantasize from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE +from botorch.utils.sampling import draw_sobol_samples from gpytorch import Module as GPyTorchModule from gpytorch.constraints import GreaterThan from gpytorch.priors import Prior @@ -1503,3 +1504,574 @@ def equals(self, other: InputTransform) -> bool: and (self.transform_on_fantasize == other.transform_on_fantasize) and self.categorical_features == other.categorical_features ) + + +class AnalyticProbabilisticReparameterizationInputTransform(InputTransform, Module): + r"""An input transform to prepare inputs for analytic PR. + + See [Daulton2022bopr]_ for details. + + This will typically be used in conjunction with normalization as + follows: + + In eval() mode (i.e. after training), the inputs pass + would typically be normalized to the unit cube (e.g. during candidate + optimization). + 1. These are unnormalized back to the raw input space. + 2. The discrete values are created. + 3. All values are normalized to the unitcube. + + TODO: consolidate this with MCProbabilisticReparameterizationInputTransform. + + """ + + def __init__( + self, + one_hot_bounds: Tensor = None, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + transform_on_train: bool = False, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + tau: float = 0.1, + ) -> None: + r"""Initialize transform. + + Args: + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer inputs. + categorical_features: The indices and cardinality of + each categorical feature. The features are assumed + to be one-hot encoded. TODO: generalize to support + alternative representations. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: A boolean indicating whether to apply the + transform when called from within a `fantasize` call. Default: True. + mc_samples: The number of MC samples. + resample: A boolean indicating whether to resample base samples + at each forward pass. + tau: The temperature parameter. + """ + super().__init__() + if integer_indices is None and categorical_features is None: + raise ValueError( + "integer_indices and/or categorical_features must be provided." + ) + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + discrete_indices = [] + if integer_indices is not None and len(integer_indices) > 0: + self.register_buffer( + "integer_indices", + torch.tensor( + integer_indices, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer("integer_bounds", one_hot_bounds[:, integer_indices]) + discrete_indices += integer_indices + else: + self.integer_indices = None + self.categorical_features = categorical_features + if self.categorical_features is not None: + self.categorical_start_idx = min(self.categorical_features.keys()) + # check that the trailing dimensions are categoricals + end = self.categorical_start_idx + err_msg = ( + f"{self.__class__.__name__} requires that the categorical " + "parameters are the rightmost elements." + ) + for start, card in self.categorical_features.items(): + # the end of one one-hot representation should be followed + # by the start of the next + if end != start: + raise ValueError(err_msg) + end = start + card + if end != one_hot_bounds.shape[1]: + # check end + raise ValueError(err_msg) + categorical_starts = [] + categorical_ends = [] + if self.categorical_features is not None: + start = None + for i, n_categories in categorical_features.items(): + if start is None: + start = i + end = start + n_categories + categorical_starts.append(start) + categorical_ends.append(end) + discrete_indices += list(range(start, end)) + start = end + self.register_buffer( + "discrete_indices", + torch.tensor( + discrete_indices, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer( + "categorical_starts", + torch.tensor( + categorical_starts, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer( + "categorical_ends", + torch.tensor( + categorical_ends, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.tau = tau + # create cartesian product of discrete options + discrete_options = [] + dim = one_hot_bounds.shape[1] + # get number of discrete parameters + num_discrete_params = 0 + if self.integer_indices is not None: + num_discrete_params += self.integer_indices.shape[0] + if self.categorical_features is not None: + num_discrete_params += len(self.categorical_features) + # add zeros for continuous params to simplify code + for _ in range(dim - len(discrete_indices)): + discrete_options.append( + torch.zeros( + 1, + dtype=torch.long, + device=one_hot_bounds.device, + ) + ) + if integer_indices is not None: + for i in range(self.integer_bounds.shape[-1]): + discrete_options.append( + torch.arange( + self.integer_bounds[0, i], + self.integer_bounds[1, i] + 1, + dtype=torch.long, + device=one_hot_bounds.device, + ) + ) + categorical_start_idx = len(discrete_options) + if categorical_features is not None: + for idx in sorted(categorical_features.keys()): + cardinality = categorical_features[idx] + discrete_options.append( + torch.arange( + cardinality, dtype=torch.long, device=one_hot_bounds.device + ) + ) + # categoricals are in numeric representation + all_discrete_options = torch.cartesian_prod(*discrete_options) + # one-hot encode the categoricals + if categorical_features is not None and len(categorical_features) > 0: + X_categ = torch.empty( + *all_discrete_options.shape[:-1], sum(categorical_features.values()) + ) + start = 0 + for i, (idx, cardinality) in enumerate( + sorted(categorical_features.items(), key=lambda kv: kv[0]) + ): + start = idx - categorical_start_idx + X_categ[..., start : start + cardinality] = one_hot( + all_discrete_options[..., i], + num_classes=cardinality, + ).to(X_categ) + all_discrete_options = torch.cat( + [all_discrete_options[..., : -len(categorical_features)], X_categ], + dim=-1, + ) + self.register_buffer("all_discrete_options", all_discrete_options) + + def get_rounding_prob(self, X: Tensor) -> Tensor: + # todo consolidate this the MCProbabilisticReparameterizationInputTransform + X_prob = X.detach().clone() + if self.integer_indices is not None: + # compute probabilities for integers + X_int = X_prob[..., self.integer_indices] + X_int_abs = X_int.abs() + offset = X_int_abs.floor() + if self.tau is not None: + X_prob[..., self.integer_indices] = torch.sigmoid( + (X_int_abs - offset - 0.5) / self.tau + ) + else: + X_prob[..., self.integer_indices] = X_int_abs - offset + # compute probabilities for categoricals + for start, end in zip(self.categorical_starts, self.categorical_ends): + X_categ = X_prob[..., start:end] + if self.tau is not None: + X_prob[..., start:end] = torch.softmax( + (X_categ - 0.5) / self.tau, dim=-1 + ) + else: + X_prob[..., start:end] = X_categ / X_categ.sum(dim=-1) + return X_prob[..., self.discrete_indices] + + def get_probs(self, X: Tensor) -> Tensor: + """ + Args: + X: a `batch_shape x n x d`-dim tensor + + Returns: + A `batch_shape x n_discrete x n`-dim tensors of probabilities of each discrete config under X. + """ + # note this method should be differentiable + X_prob = torch.ones( + *X.shape[:-2], + self.all_discrete_options.shape[0], + X.shape[-2], + dtype=X.dtype, + device=X.device, + ) + # n_discrete x batch_shape x n x d + all_discrete_options = self.all_discrete_options.view( + *([1] * (X.ndim - 2)), self.all_discrete_options.shape[0], *X.shape[-2:] + ).expand(*X.shape[:-2], self.all_discrete_options.shape[0], *X.shape[-2:]) + X = X.unsqueeze(-3) + if self.integer_indices is not None: + # compute probabilities for integers + X_int = X[..., self.integer_indices] + X_int_abs = X_int.abs() + offset = X_int_abs.floor() + # note we don't actually need the sigmoid here + X_prob_int = torch.sigmoid((X_int_abs - offset - 0.5) / self.tau) + # X_prob_int = X_int_abs - offset + for int_idx, idx in enumerate(self.integer_indices): + offset_i = offset[..., int_idx] + all_discrete_i = all_discrete_options[..., idx] + diff = (offset_i + 1) - all_discrete_i + round_up_mask = diff == 0 + round_down_mask = diff == 1 + neither_mask = ~(round_up_mask | round_down_mask) + prob = X_prob_int[..., int_idx].expand(round_up_mask.shape) + # need to be careful with in-place ops here for autograd + X_prob[round_up_mask] = X_prob[round_up_mask] * prob[round_up_mask] + X_prob[round_down_mask] = X_prob[round_down_mask] * ( + 1 - prob[round_down_mask] + ) + X_prob[neither_mask] = X_prob[neither_mask] * 0 + + # compute probabilities for categoricals + for start, end in zip(self.categorical_starts, self.categorical_ends): + X_categ = X[..., start:end] + X_prob_c = torch.softmax((X_categ - 0.5) / self.tau, dim=-1).expand( + *X_categ.shape[:-3], all_discrete_options.shape[-3], *X_categ.shape[-2:] + ) + for i in range(X_prob_c.shape[-1]): + mask = all_discrete_options[..., start + i] == 1 + X_prob[mask] = X_prob[mask] * X_prob_c[..., i][mask] + + return X_prob + + def transform(self, X: Tensor) -> Tensor: + r"""Round the inputs. + + This is not sample-path differentiable. + + Args: + X: A `batch_shape x 1 x n x d`-dim tensor of inputs. + + Returns: + A `batch_shape x n_discrete x n x d`-dim tensor of rounded inputs. + """ + n_discrete = self.discrete_indices.shape[0] + all_discrete_options = self.all_discrete_options.view( + *([1] * (X.ndim - 3)), self.all_discrete_options.shape[0], *X.shape[-2:] + ).expand(*X.shape[:-3], self.all_discrete_options.shape[0], *X.shape[-2:]) + if X.shape[-1] > n_discrete: + X = X.expand( + *X.shape[:-3], self.all_discrete_options.shape[0], *X.shape[-2:] + ) + return torch.cat( + [X[..., :-n_discrete], all_discrete_options[..., -n_discrete:]], dim=-1 + ) + return all_discrete_options + + def equals(self, other: InputTransform) -> bool: + r"""Check if another input transform is equivalent. + + Args: + other: Another input transform. + + Returns: + A boolean indicating if the other transform is equivalent. + """ + # TODO: update this + return super().equals(other=other) and torch.equal( + self.integer_indices, other.integer_indices + ) + + +class MCProbabilisticReparameterizationInputTransform(InputTransform, Module): + r"""An input transform to prepare inputs for analytic PR. + + See [Daulton2022bopr]_ for details. + + This will typically be used in conjunction with normalization as + follows: + + In eval() mode (i.e. after training), the inputs pass + would typically be normalized to the unit cube (e.g. during candidate + optimization). + 1. These are unnormalized back to the raw input space. + 2. The discrete ordinal valeus are sampled. + 3. All values are normalized to the unitcube. + """ + + def __init__( + self, + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + transform_on_train: bool = False, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + mc_samples: int = 128, + resample: bool = False, + tau: float = 0.1, + ) -> None: + r"""Initialize transform. + + Args: + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer inputs. + categorical_features: The indices and cardinality of + each categorical feature. The features are assumed + to be one-hot encoded. TODO: generalize to support + alternative representations. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: A boolean indicating whether to apply the + transform when called from within a `fantasize` call. Default: True. + mc_samples: The number of MC samples. + resample: A boolean indicating whether to resample base samples + at each forward pass. + tau: The temperature parameter. + """ + super().__init__() + if integer_indices is None and categorical_features is None: + raise ValueError( + "integer_indices and/or categorical_features must be provided." + ) + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + discrete_indices = [] + if integer_indices is not None and len(integer_indices) > 0: + self.register_buffer( + "integer_indices", torch.tensor(integer_indices, dtype=torch.long) + ) + discrete_indices += integer_indices + else: + self.integer_indices = None + self.categorical_features = categorical_features + if self.categorical_features is not None: + self.categorical_start_idx = min(self.categorical_features.keys()) + # check that the trailing dimensions are categoricals + end = self.categorical_start_idx + err_msg = ( + f"{self.__class__.__name__} requires that the categorical " + "parameters are the rightmost elements." + ) + for start, card in self.categorical_features.items(): + # the end of one one-hot representation should be followed + # by the start of the next + if end != start: + raise ValueError(err_msg) + end = start + card + if end != one_hot_bounds.shape[1]: + # check end + raise ValueError(err_msg) + categorical_starts = [] + categorical_ends = [] + if self.categorical_features is not None: + start = None + for i, n_categories in categorical_features.items(): + if start is None: + start = i + end = start + n_categories + categorical_starts.append(start) + categorical_ends.append(end) + discrete_indices += list(range(start, end)) + start = end + self.register_buffer( + "discrete_indices", + torch.tensor( + discrete_indices, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer( + "categorical_starts", + torch.tensor( + categorical_starts, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer( + "categorical_ends", + torch.tensor( + categorical_ends, dtype=torch.long, device=one_hot_bounds.device + ), + ) + if integer_indices is None: + self.register_buffer( + "integer_bounds", + torch.tensor([], dtype=torch.long, device=one_hot_bounds.device), + ) + else: + self.register_buffer("integer_bounds", one_hot_bounds[:, integer_indices]) + self.mc_samples = mc_samples + self.resample = resample + self.tau = tau + + def get_rounding_prob(self, X: Tensor) -> Tensor: + X_prob = X.detach().clone() + if self.integer_indices is not None: + # compute probabilities for integers + X_int = X_prob[..., self.integer_indices] + X_int_abs = X_int.abs() + offset = X_int_abs.floor() + if self.tau is not None: + X_prob[..., self.integer_indices] = torch.sigmoid( + (X_int_abs - offset - 0.5) / self.tau + ) + else: + X_prob[..., self.integer_indices] = X_int_abs - offset + # compute probabilities for categoricals + for start, end in zip(self.categorical_starts, self.categorical_ends): + X_categ = X_prob[..., start:end] + if self.tau is not None: + X_prob[..., start:end] = torch.softmax( + (X_categ - 0.5) / self.tau, dim=-1 + ) + else: + X_prob[..., start:end] = X_categ / X_categ.sum(dim=-1) + return X_prob[..., self.discrete_indices] + + def transform(self, X: Tensor) -> Tensor: + r"""Round the inputs. + + This is not sample-path differentiable. + + Args: + X: A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + A `batch_shape x n x d`-dim tensor of rounded inputs. + """ + X_expanded = X.expand(*X.shape[:-3], self.mc_samples, *X.shape[-2:]).clone() + X_prob = self.get_rounding_prob(X=X) + if self.integer_indices is not None: + X_int = X[..., self.integer_indices].detach() + assert X.ndim > 1 + if X.ndim == 2: + X.unsqueeze(-1) + if ( + not hasattr(self, "base_samples") + or self.base_samples.shape[-2:] != X_int.shape[-2:] + or self.resample + ): + # construct sobol base samples + bounds = torch.zeros( + 2, X_int.shape[-1], dtype=X_int.dtype, device=X_int.device + ) + bounds[1] = 1 + self.register_buffer( + "base_samples", + draw_sobol_samples( + bounds=bounds, + n=self.mc_samples, + q=X_int.shape[-2], + seed=torch.randint(0, 100000, (1,)).item(), + ), + ) + X_int_abs = X_int.abs() + # perform exact rounding + is_negative = X_int < 0 + offset = X_int_abs.floor() + prob = X_prob[..., : self.integer_indices.shape[0]] + rounding_component = (prob >= self.base_samples).to( + dtype=X.dtype, + ) + X_abs_rounded = offset + rounding_component + X_int_new = (-1) ** is_negative.to(offset) * X_abs_rounded + # clamp to bounds + X_expanded[..., self.integer_indices] = torch.minimum( + torch.maximum(X_int_new, self.integer_bounds[0]), self.integer_bounds[1] + ) + + # sample for categoricals + if self.categorical_features is not None and len(self.categorical_features) > 0: + if ( + not hasattr(self, "base_samples_categorical") + or self.base_samples_categorical.shape[-2] != X.shape[-2] + or self.resample + ): + bounds = torch.zeros( + 2, len(self.categorical_features), dtype=X.dtype, device=X.device + ) + bounds[1] = 1 + self.register_buffer( + "base_samples_categorical", + draw_sobol_samples( + bounds=bounds, + n=self.mc_samples, + q=X.shape[-2], + seed=torch.randint(0, 100000, (1,)).item(), + ), + ) + + # sample from multinomial as argmin_c [sample_c * exp(-x_c)] + sample_d_start_idx = 0 + X_categ_prob = X_prob + if self.integer_indices is not None: + n_ints = self.integer_indices.shape[0] + if n_ints > 0: + X_categ_prob = X_prob[..., n_ints:] + + for i, cardinality in enumerate(self.categorical_features.values()): + sample_d_end_idx = sample_d_start_idx + cardinality + start = self.categorical_starts[i] + end = self.categorical_ends[i] + cum_prob = X_categ_prob[ + ..., sample_d_start_idx:sample_d_end_idx + ].cumsum(dim=-1) + categories = ( + ( + (cum_prob > self.base_samples_categorical[..., i : i + 1]) + .long() + .cumsum(dim=-1) + == 1 + ) + .long() + .argmax(dim=-1) + ) + # one-hot encode + X_expanded[..., start:end] = one_hot( + categories, num_classes=cardinality + ).to(X) + sample_d_start_idx = sample_d_end_idx + + return X_expanded + + def equals(self, other: InputTransform) -> bool: + r"""Check if another input transform is equivalent. + + Args: + other: Another input transform. + + Returns: + A boolean indicating if the other transform is equivalent. + """ + return ( + super().equals(other=other) + and (self.resample == other.resample) + and torch.equal(self.base_samples, other.base_samples) + and torch.equal(self.integer_indices, other.integer_indices) + ) diff --git a/sphinx/source/acquisition.rst b/sphinx/source/acquisition.rst index 79f529826a..a5a429e46b 100644 --- a/sphinx/source/acquisition.rst +++ b/sphinx/source/acquisition.rst @@ -21,6 +21,11 @@ Analytic Acquisition Function API .. autoclass:: AnalyticAcquisitionFunction :members: +Acquisition Function Wrapper API +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.acquisition.wrapper + :members: + Cached Cholesky Acquisition Function API ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.cached_cholesky @@ -65,7 +70,7 @@ Multi-Objective Analytic Acquisition Functions .. automodule:: botorch.acquisition.multi_objective.analytic :members: :exclude-members: MultiObjectiveAnalyticAcquisitionFunction - + Multi-Objective Joint Entropy Search Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.joint_entropy_search @@ -86,7 +91,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.multi_fidelity :members: - + Multi-Objective Predictive Entropy Search Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search @@ -175,6 +180,11 @@ Penalized Acquisition Function Wrapper .. automodule:: botorch.acquisition.penalized :members: +Probabilistic Reparameterization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.acquisition.probabilistic_reparameterization + :members: + Proximal Acquisition Function Wrapper ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.proximal diff --git a/test/acquisition/test_fixed_feature.py b/test/acquisition/test_fixed_feature.py index 8dcc02f1df..b8f570e7e1 100644 --- a/test/acquisition/test_fixed_feature.py +++ b/test/acquisition/test_fixed_feature.py @@ -87,7 +87,7 @@ def test_fixed_features(self): qEI_ff.set_X_pending(X_pending[..., :-1]) self.assertAllClose(qEI.X_pending, X_pending) # test setting to None - qEI_ff.X_pending = None + qEI_ff.set_X_pending(None) self.assertIsNone(qEI_ff.X_pending) # test gradient diff --git a/test/acquisition/test_proximal.py b/test/acquisition/test_proximal.py index 795daa1b34..e17536ddd0 100644 --- a/test/acquisition/test_proximal.py +++ b/test/acquisition/test_proximal.py @@ -209,9 +209,15 @@ def test_proximal(self): # test for x_pending points pending_acq = DummyAcquisitionFunction(model) - pending_acq.set_X_pending(torch.rand(3, 3, device=self.device, dtype=dtype)) + X_pending = torch.rand(3, 3, device=self.device, dtype=dtype) + pending_acq.set_X_pending(X_pending) with self.assertRaises(UnsupportedError): ProximalAcquisitionFunction(pending_acq, proximal_weights) + # test setting pending points + pending_acq.set_X_pending(None) + af = ProximalAcquisitionFunction(pending_acq, proximal_weights) + with self.assertRaises(UnsupportedError): + af.set_X_pending(X_pending) # test model with multi-batch training inputs train_X = torch.rand(5, 2, 3, device=self.device, dtype=dtype) diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index d12b5f6da4..39b8017ea2 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -8,7 +8,8 @@ from unittest import mock import torch -from botorch.acquisition import monte_carlo +from botorch.acquisition import analytic, monte_carlo, multi_objective +from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction from botorch.acquisition.multi_objective import ( MCMultiOutputObjective, monte_carlo as moo_monte_carlo, @@ -18,10 +19,13 @@ MCAcquisitionObjective, ScalarizedPosteriorTransform, ) +from botorch.acquisition.proximal import ProximalAcquisitionFunction from botorch.acquisition.utils import ( expand_trace_observations, get_acquisition_function, get_infeasible_cost, + is_nonnegative, + isinstance_af, project_to_sample_points, project_to_target_fidelity, prune_inferior_points, @@ -606,6 +610,61 @@ def test_get_infeasible_cost(self): self.assertAllClose(M4, torch.tensor([1.0], **tkwargs)) +class TestIsNonnegative(BotorchTestCase): + def test_is_nonnegative(self): + nonneg_afs = ( + analytic.ExpectedImprovement, + analytic.ConstrainedExpectedImprovement, + analytic.ProbabilityOfImprovement, + analytic.NoisyExpectedImprovement, + monte_carlo.qExpectedImprovement, + monte_carlo.qNoisyExpectedImprovement, + monte_carlo.qProbabilityOfImprovement, + multi_objective.analytic.ExpectedHypervolumeImprovement, + multi_objective.monte_carlo.qExpectedHypervolumeImprovement, + multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement, + ) + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, device=self.device), + variance=torch.ones(1, 1, device=self.device), + ) + ) + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) + with mock.patch( + "botorch.acquisition.utils.isinstance_af", return_value=True + ) as mock_isinstance_af: + self.assertTrue(is_nonnegative(acq_function=acq_func)) + mock_isinstance_af.assert_called_once() + cargs, _ = mock_isinstance_af.call_args + self.assertIs(cargs[0], acq_func) + self.assertEqual(cargs[1], nonneg_afs) + acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0) + self.assertFalse(is_nonnegative(acq_function=acq_func)) + + +class TestIsinstanceAf(BotorchTestCase): + def test_isinstance_af(self): + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, device=self.device), + variance=torch.ones(1, 1, device=self.device), + ) + ) + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) + self.assertTrue(isinstance_af(acq_func, analytic.ExpectedImprovement)) + self.assertFalse(isinstance_af(acq_func, analytic.UpperConfidenceBound)) + wrapped_af = FixedFeatureAcquisitionFunction( + acq_function=acq_func, d=2, columns=[1], values=[0.0] + ) + # test base af class + self.assertTrue(isinstance_af(wrapped_af, analytic.ExpectedImprovement)) + self.assertFalse(isinstance_af(wrapped_af, analytic.UpperConfidenceBound)) + # test wrapper class + self.assertTrue(isinstance_af(wrapped_af, FixedFeatureAcquisitionFunction)) + self.assertFalse(isinstance_af(wrapped_af, ProximalAcquisitionFunction)) + + class TestPruneInferiorPoints(BotorchTestCase): def test_prune_inferior_points(self): for dtype in (torch.float, torch.double): diff --git a/test/acquisition/test_wrapper.py b/test/acquisition/test_wrapper.py new file mode 100644 index 0000000000..e35175fb9b --- /dev/null +++ b/test/acquisition/test_wrapper.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from botorch.acquisition.analytic import ExpectedImprovement +from botorch.acquisition.monte_carlo import qExpectedImprovement +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper +from botorch.exceptions.errors import UnsupportedError +from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior + + +class DummyWrapper(AbstractAcquisitionFunctionWrapper): + def forward(self, X): + return self.acq_func(X) + + +class TestAbstractAcquisitionFunctionWrapper(BotorchTestCase): + def test_abstract_acquisition_function_wrapper(self): + for dtype in (torch.float, torch.double): + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, dtype=dtype, device=self.device), + variance=torch.ones(1, 1, dtype=dtype, device=self.device), + ) + ) + acq_func = ExpectedImprovement(model=mm, best_f=-1.0) + wrapped_af = DummyWrapper(acq_function=acq_func) + self.assertIs(wrapped_af.acq_func, acq_func) + # test forward + X = torch.rand(1, 1, dtype=dtype, device=self.device) + with torch.no_grad(): + wrapped_val = wrapped_af(X) + af_val = acq_func(X) + self.assertEqual(wrapped_val.item(), af_val.item()) + + # test X_pending + with self.assertRaises(ValueError): + self.assertIsNone(wrapped_af.X_pending) + with self.assertRaises(UnsupportedError): + wrapped_af.set_X_pending(X) + acq_func = qExpectedImprovement(model=mm, best_f=-1.0) + wrapped_af = DummyWrapper(acq_function=acq_func) + self.assertIsNone(wrapped_af.X_pending) + wrapped_af.set_X_pending(X) + self.assertTrue(torch.equal(X, wrapped_af.X_pending)) + self.assertTrue(torch.equal(X, acq_func.X_pending)) + wrapped_af.set_X_pending(None) + self.assertIsNone(wrapped_af.X_pending) + self.assertIsNone(acq_func.X_pending)