diff --git a/botorch/acquisition/multi_objective/utils.py b/botorch/acquisition/multi_objective/utils.py index cca67ad55e..ca45c869e6 100644 --- a/botorch/acquisition/multi_objective/utils.py +++ b/botorch/acquisition/multi_objective/utils.py @@ -10,7 +10,6 @@ from __future__ import annotations -import math import warnings from collections.abc import Callable from math import ceil @@ -18,16 +17,12 @@ import torch from botorch.acquisition import monte_carlo # noqa F401 -from botorch.acquisition.multi_objective.objective import ( - IdentityMCMultiOutputObjective, - MCMultiOutputObjective, -) +from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective +from botorch.acquisition.utils import _prune_inferior_shared_processing from botorch.exceptions.errors import UnsupportedError from botorch.exceptions.warnings import BotorchWarning from botorch.models.deterministic import GenericDeterministicModel -from botorch.models.fully_bayesian import MCMC_DIM from botorch.models.model import Model -from botorch.sampling.get_sampler import get_sampler from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model from botorch.utils.multi_objective.box_decompositions.box_decomposition import ( BoxDecomposition, @@ -39,9 +34,8 @@ DominatedPartitioning, ) from botorch.utils.multi_objective.pareto import is_non_dominated -from botorch.utils.objective import compute_feasibility_indicator from botorch.utils.sampling import draw_sobol_samples -from botorch.utils.transforms import is_ensemble +from pyre_extensions import assert_is_instance from torch import Tensor @@ -115,40 +109,14 @@ def prune_inferior_points_multi_objective( with `N_nz` the number of points in `X` that have non-zero (empirical, under `num_samples` samples) probability of being pareto optimal. """ - if marginalize_dim is None and is_ensemble(model): - # TODO: Properly deal with marginalizing fully Bayesian models - marginalize_dim = MCMC_DIM - - if X.ndim > 2: - # TODO: support batched inputs (req. dealing with ragged tensors) - raise UnsupportedError( - "Batched inputs `X` are currently unsupported by " - "prune_inferior_points_multi_objective" - ) - if X.size(-2) == 0: - raise ValueError("X must have at least one point.") - if max_frac <= 0 or max_frac > 1.0: - raise ValueError(f"max_frac must take values in (0, 1], is {max_frac}") - max_points = math.ceil(max_frac * X.size(-2)) - with torch.no_grad(): - posterior = model.posterior(X=X) - sampler = get_sampler(posterior, sample_shape=torch.Size([num_samples])) - samples = sampler(posterior) - if objective is None: - objective = IdentityMCMultiOutputObjective() - obj_vals = objective(samples, X=X) - if obj_vals.ndim > 3: - if obj_vals.ndim == 4 and marginalize_dim is not None: - obj_vals = obj_vals.mean(dim=marginalize_dim) - else: - # TODO: support batched inputs (req. dealing with ragged tensors) - raise UnsupportedError( - "Models with multiple batch dims are currently unsupported by" - " prune_inferior_points_multi_objective." 
- ) - infeas = ~compute_feasibility_indicator( + max_points, obj_vals, infeas = _prune_inferior_shared_processing( + model=model, + X=X, + is_moo=True, + objective=objective, constraints=constraints, - samples=samples, + num_samples=num_samples, + max_frac=max_frac, marginalize_dim=marginalize_dim, ) if infeas.any(): @@ -168,9 +136,9 @@ def prune_inferior_points_multi_objective( def compute_sample_box_decomposition( pareto_fronts: Tensor, - partitioning: BoxDecomposition = DominatedPartitioning, + partitioning: type[BoxDecomposition] = DominatedPartitioning, maximize: bool = True, - num_constraints: int | None = 0, + num_constraints: int = 0, ) -> Tensor: r"""Computes the box decomposition associated with some sampled optimal objectives. This also supports the single-objective and constrained optimization @@ -195,7 +163,10 @@ def compute_sample_box_decomposition( the hyper-rectangles. The number `J` is the smallest number of boxes needed to partition all the Pareto samples. """ - tkwargs = {"dtype": pareto_fronts.dtype, "device": pareto_fronts.device} + tkwargs: dict[str, Any] = { + "dtype": pareto_fronts.dtype, + "device": pareto_fronts.device, + } # We will later compute `norm.log_prob(NEG_INF)`, this is `-inf` if `NEG_INF` is # too small. NEG_INF = -1e10 @@ -214,16 +185,18 @@ def compute_sample_box_decomposition( if M == 1: # Only consider a Pareto front with one element. - extreme_values = weight * torch.max(weight * pareto_fronts, dim=-2).values + extreme_values = assert_is_instance( + weight * torch.max(weight * pareto_fronts, dim=-2).values, Tensor + ) ref_point = weight * ref_point.expand(extreme_values.shape) if maximize: hypercell_bounds = torch.stack( - [ref_point, extreme_values], axis=-2 + [ref_point, extreme_values], dim=-2 ).unsqueeze(-1) else: hypercell_bounds = torch.stack( - [extreme_values, ref_point], axis=-2 + [extreme_values, ref_point], dim=-2 ).unsqueeze(-1) else: bd_list = [] @@ -244,9 +217,7 @@ def compute_sample_box_decomposition( # Add an extra box for the inequality constraint. if K > 0: # `num_pareto_samples x 2 x (J - 1) x K` - feasible_boxes = torch.zeros( - hypercell_bounds.shape[:-1] + torch.Size([K]), **tkwargs - ) + feasible_boxes = torch.zeros(hypercell_bounds.shape[:-1] + (K,), **tkwargs) feasible_boxes[..., 0, :, :] = NEG_INF # `num_pareto_samples x 2 x (J - 1) x (M + K)` @@ -254,7 +225,7 @@ def compute_sample_box_decomposition( # `num_pareto_samples x 2 x 1 x (M + K)` infeasible_box = torch.zeros( - hypercell_bounds.shape[:-2] + torch.Size([1, M + K]), **tkwargs + hypercell_bounds.shape[:-2] + (1, M + K), **tkwargs ) infeasible_box[..., 1, :, M:] = -NEG_INF infeasible_box[..., 0, :, 0:M] = NEG_INF @@ -292,11 +263,12 @@ def random_search_optimizer( - A `num_points x M`-dim Tensor containing the collection of optimal objectives. 
""" - tkwargs = {"dtype": bounds.dtype, "device": bounds.device} + tkwargs: dict[str, Any] = {"dtype": bounds.dtype, "device": bounds.device} weight = 1.0 if maximize else -1.0 optimal_inputs = torch.tensor([], **tkwargs) optimal_outputs = torch.tensor([], **tkwargs) num_tries = 0 + num_found = 0 ratio = 2 while ratio > 1 and num_tries < max_tries: X = draw_sobol_samples(bounds=bounds, n=pop_size, q=1).squeeze(-2) diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index d486629b76..a930488680 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -15,7 +15,6 @@ import torch from botorch.acquisition.objective import ( - IdentityMCObjective, MCAcquisitionObjective, PosteriorTransform, ScalarizedPosteriorTransform, @@ -34,6 +33,7 @@ from botorch.utils.sampling import optimize_posterior_samples from botorch.utils.transforms import is_ensemble, normalize_indices from gpytorch.models import GP +from pyre_extensions import none_throws from torch import Tensor @@ -244,6 +244,76 @@ def objective(Y: Tensor, X: Tensor | None = None): return -(lb.clamp_max(0.0)) +def _prune_inferior_shared_processing( + model: Model, + X: Tensor, + is_moo: bool, + objective: MCAcquisitionObjective | None = None, + posterior_transform: PosteriorTransform | None = None, + constraints: list[Callable[[Tensor], Tensor]] | None = None, + num_samples: int = 2048, + max_frac: float = 1.0, + sampler: MCSampler | None = None, + marginalize_dim: int | None = None, +) -> tuple[int, Tensor, Tensor]: + r"""Shared data processing for `prune_inferior_points` and + `prune_inferior_points_multi_objective`. + + Returns: + - max_points: The maximum number of points to keep. + - obj_vals: The objective values of the points in `X`. + - infeas: A boolean tensor indicating feasibility of `X`. + """ + func_name = ( + "prune_inferior_points_multi_objective" if is_moo else "prune_inferior_points" + ) + if marginalize_dim is None and is_ensemble(model): + marginalize_dim = MCMC_DIM + + if X.ndim > 2: + raise UnsupportedError( + f"Batched inputs `X` are currently unsupported by `{func_name}`" + ) + if X.size(-2) == 0: + raise ValueError("X must have at least one point.") + if max_frac <= 0 or max_frac > 1.0: + raise ValueError(f"max_frac must take values in (0, 1], is {max_frac}") + max_points = math.ceil(max_frac * X.size(-2)) + with torch.no_grad(): + posterior = model.posterior(X=X, posterior_transform=posterior_transform) + if sampler is None: + sampler = get_sampler( + posterior=posterior, sample_shape=torch.Size([num_samples]) + ) + samples = sampler(posterior) + if objective is not None: + obj_vals = objective(samples=samples, X=X) + elif is_moo: + obj_vals = samples + else: + obj_vals = samples.squeeze(-1) + if obj_vals.ndim > (2 + is_moo): + if obj_vals.ndim == (3 + is_moo) and marginalize_dim is not None: + if marginalize_dim < 0: + # Update `marginalize_dim` to be positive while accounting for + # removal of output dimension in SOO. + marginalize_dim = (not is_moo) + none_throws( + normalize_indices([marginalize_dim], d=obj_vals.ndim) + )[0] + obj_vals = obj_vals.mean(dim=marginalize_dim) + else: + raise UnsupportedError( + "Models with multiple batch dims are currently unsupported by " + f"`{func_name}`." 
+ ) + infeas = ~compute_feasibility_indicator( + constraints=constraints, + samples=samples, + marginalize_dim=marginalize_dim, + ) + return max_points, obj_vals, infeas + + def prune_inferior_points( model: Model, X: Tensor, @@ -292,48 +362,16 @@ def prune_inferior_points( with `N_nz` the number of points in `X` that have non-zero (empirical, under `num_samples` samples) probability of being the best point. """ - if marginalize_dim is None and is_ensemble(model): - # TODO: Properly deal with marginalizing fully Bayesian models - marginalize_dim = MCMC_DIM - - if X.ndim > 2: - # TODO: support batched inputs (req. dealing with ragged tensors) - raise UnsupportedError( - "Batched inputs `X` are currently unsupported by prune_inferior_points" - ) - if X.size(-2) == 0: - raise ValueError("X must have at least one point.") - if max_frac <= 0 or max_frac > 1.0: - raise ValueError(f"max_frac must take values in (0, 1], is {max_frac}") - max_points = math.ceil(max_frac * X.size(-2)) - with torch.no_grad(): - posterior = model.posterior(X=X, posterior_transform=posterior_transform) - if sampler is None: - sampler = get_sampler( - posterior=posterior, sample_shape=torch.Size([num_samples]) - ) - samples = sampler(posterior) - if objective is None: - objective = IdentityMCObjective() - obj_vals = objective(samples, X=X) - if obj_vals.ndim > 2: - if obj_vals.ndim == 3 and marginalize_dim is not None: - if marginalize_dim < 0: - # we do this again in compute_feasibility_indicator, but that will - # have no effect since marginalize_dim will be non-negative - marginalize_dim = ( - 1 + normalize_indices([marginalize_dim], d=obj_vals.ndim)[0] - ) - obj_vals = obj_vals.mean(dim=marginalize_dim) - else: - # TODO: support batched inputs (req. dealing with ragged tensors) - raise UnsupportedError( - "Models with multiple batch dims are currently unsupported by" - " prune_inferior_points." - ) - infeas = ~compute_feasibility_indicator( + max_points, obj_vals, infeas = _prune_inferior_shared_processing( + model=model, + X=X, + is_moo=False, + objective=objective, + posterior_transform=posterior_transform, constraints=constraints, - samples=samples, + num_samples=num_samples, + max_frac=max_frac, + sampler=sampler, marginalize_dim=marginalize_dim, ) if infeas.any(): diff --git a/botorch/generation/gen.py b/botorch/generation/gen.py index 6e6e047f05..0d70681fd5 100644 --- a/botorch/generation/gen.py +++ b/botorch/generation/gen.py @@ -22,10 +22,7 @@ from botorch.acquisition import AcquisitionFunction from botorch.exceptions.errors import OptimizationGradientError from botorch.exceptions.warnings import OptimizationWarning -from botorch.generation.utils import ( - _convert_nonlinear_inequality_constraints, - _remove_fixed_features_from_optimization, -) +from botorch.generation.utils import _remove_fixed_features_from_optimization from botorch.logging import logger from botorch.optim.parameter_constraints import ( _arrayify, @@ -136,16 +133,6 @@ def gen_candidates_scipy( else: reduced_domain = None not in fixed_features.values() - if nonlinear_inequality_constraints: - if not isinstance(nonlinear_inequality_constraints, list): - raise ValueError( - "`nonlinear_inequality_constraints` must be a list of tuples, " - f"got {type(nonlinear_inequality_constraints)}." 
- ) - nonlinear_inequality_constraints = _convert_nonlinear_inequality_constraints( - nonlinear_inequality_constraints - ) - if reduced_domain: _no_fixed_features = _remove_fixed_features_from_optimization( fixed_features=fixed_features, diff --git a/botorch/generation/utils.py b/botorch/generation/utils.py index 44bba19c24..a6cbaa67ef 100644 --- a/botorch/generation/utils.py +++ b/botorch/generation/utils.py @@ -6,12 +6,10 @@ from __future__ import annotations -import warnings from collections.abc import Callable from dataclasses import dataclass import torch - from botorch.acquisition import AcquisitionFunction, FixedFeatureAcquisitionFunction from botorch.optim.parameter_constraints import ( _generate_unfixed_lin_constraints, @@ -20,34 +18,6 @@ from torch import Tensor -def _convert_nonlinear_inequality_constraints( - nonlinear_inequality_constraints: list[Callable | tuple[Callable, bool]], -) -> list[tuple[Callable, bool]]: - """Convert legacy defintions of nonlinear inequality constraints into the new - format. Assumes intra-point constraints. - """ - nlcs = [] - legacy = False - # return nonlinear_inequality_constraints - for nlc in nonlinear_inequality_constraints: - if callable(nlc): - # old style --> convert - nlcs.append((nlc, True)) - legacy = True - else: - nlcs.append(nlc) - if legacy: - warnings.warn( - "The `nonlinear_inequality_constraints` argument is expected " - "take a list of tuples. Passing a list of callables " - "will result in an error in future versions.", - DeprecationWarning, - stacklevel=3, - ) - - return nlcs - - def _flip_sub_unique(x: Tensor, k: int) -> Tensor: """Get the first k unique elements of a single-dimensional tensor, traversing the tensor from the back. diff --git a/botorch/optim/optimize_mixed.py b/botorch/optim/optimize_mixed.py index ac153b9e23..105952b3b9 100644 --- a/botorch/optim/optimize_mixed.py +++ b/botorch/optim/optimize_mixed.py @@ -39,6 +39,10 @@ MAX_ITER_ALTER = 64 # Maximum number of alternating iterations. MAX_ITER_DISCRETE = 4 # Maximum number of discrete iterations. MAX_ITER_CONT = 8 # Maximum number of continuous iterations. +# Maximum number of discrete values for a discrete dimension. +# If there are more values for a dimension, we will use continuous +# relaxation to optimize it. +MAX_DISCRETE_VALUES = 20 # Maximum number of iterations for optimizing the continuous relaxation # during initialization MAX_ITER_INIT = 100 @@ -52,6 +56,7 @@ "maxiter_discrete", "maxiter_continuous", "maxiter_init", + "max_discrete_values", "num_spray_points", "std_cont_perturbation", "batch_limit", @@ -60,6 +65,40 @@ SUPPORTED_INITIALIZATION = {"continuous_relaxation", "equally_spaced", "random"} +def _setup_continuous_relaxation( + discrete_dims: list[int], + bounds: Tensor, + max_discrete_values: int, + post_processing_func: Callable[[Tensor], Tensor] | None, +) -> tuple[list[int], Callable[[Tensor], Tensor] | None]: + r"""Update `discrete_dims` and `post_processing_func` to use + continuous relaxation for discrete dimensions that have more than + `max_discrete_values` values. These dimensions are removed from + `discrete_dims` and `post_processing_func` is updated to round + them to the nearest integer. + """ + discrete_dims_t = torch.tensor(discrete_dims, dtype=torch.long) + num_discrete_values = ( + bounds[1, discrete_dims_t] - bounds[0, discrete_dims_t] + ).cpu() + dims_to_relax = discrete_dims_t[num_discrete_values > max_discrete_values] + if dims_to_relax.numel() == 0: + # No dimension needs continuous relaxation. 
+ return discrete_dims, post_processing_func + # Remove relaxed dims from `discrete_dims`. + discrete_dims = list(set(discrete_dims).difference(dims_to_relax.tolist())) + + def new_post_processing_func(X: Tensor) -> Tensor: + r"""Round the relaxed dimensions to the nearest integer and apply the original + `post_processing_func`.""" + X[..., dims_to_relax] = X[..., dims_to_relax].round() + if post_processing_func is not None: + X = post_processing_func(X) + return X + + return discrete_dims, new_post_processing_func + + def _filter_infeasible( X: Tensor, inequality_constraints: list[tuple[Tensor, Tensor, float]] | None ) -> Tensor: @@ -532,6 +571,9 @@ def optimize_acqf_mixed_alternating( iterations. NOTE: This method assumes that all discrete variables are integer valued. + The discrete dimensions that have more than + `options.get("max_discrete_values", MAX_DISCRETE_VALUES)` values will + be optimized using continuous relaxation. # TODO: Support categorical variables. @@ -549,6 +591,9 @@ def optimize_acqf_mixed_alternating( Defaults to 4. - "maxiter_continuous": Maximum number of iterations in each continuous step. Defaults to 8. + - "max_discrete_values": Maximum number of values for a discrete dimension + to be optimized using discrete step / local search. The discrete dimensions + with more values will be optimized using continuous relaxation. - "num_spray_points": Number of spray points (around `X_baseline`) to add to the points generated by the initialization strategy. Defaults to 20 if all discrete variables are binary and to 0 otherwise. @@ -598,6 +643,17 @@ def optimize_acqf_mixed_alternating( f"Received an unsupported option {unsupported_keys}. {SUPPORTED_OPTIONS=}." ) + # Update discrete dims and post processing functions to account for any + # dimensions that should be using continuous relaxation. + discrete_dims, post_processing_func = _setup_continuous_relaxation( + discrete_dims=discrete_dims, + bounds=bounds, + max_discrete_values=assert_is_instance( + options.get("max_discrete_values", MAX_DISCRETE_VALUES), int + ), + post_processing_func=post_processing_func, + ) + opt_inputs = OptimizeAcqfInputs( acq_function=acq_function, bounds=bounds, @@ -623,7 +679,7 @@ def optimize_acqf_mixed_alternating( # Remove fixed features from dims, so they don't get optimized. discrete_dims = [dim for dim in discrete_dims if dim not in fixed_features] if len(discrete_dims) == 0: - raise ValueError("There must be at least one discrete parameter.") + return _optimize_acqf(opt_inputs=opt_inputs) if not ( isinstance(discrete_dims, list) and len(set(discrete_dims)) == len(discrete_dims) diff --git a/botorch/utils/containers.py b/botorch/utils/containers.py index f4e4c01e80..8cd2aabe76 100644 --- a/botorch/utils/containers.py +++ b/botorch/utils/containers.py @@ -8,10 +8,14 @@ from __future__ import annotations +import dataclasses + from abc import ABC, abstractmethod from dataclasses import dataclass, fields from typing import Any +import torch + from torch import device as Device, dtype as Dtype, LongTensor, Size, Tensor @@ -102,6 +106,9 @@ def _validate(self) -> None: f"`event shape` {self.event_shape}." ) + def clone(self) -> DenseContainer: + return dataclasses.replace(self) + @dataclass(eq=False) class SliceContainer(BotorchContainer): @@ -149,3 +156,10 @@ def _validate(self) -> None: f"Shapes of `values` {values.shape} and `indices` " f"{indices.shape} incompatible with `event_shape` {event_shape}." 
) + + def clone(self) -> SliceContainer: + return type(self)( + values=self.values.clone(), + indices=self.indices.clone(), + event_shape=torch.Size(self.event_shape), + ) diff --git a/botorch/utils/datasets.py b/botorch/utils/datasets.py index f11f5c80e7..62d99d82fe 100644 --- a/botorch/utils/datasets.py +++ b/botorch/utils/datasets.py @@ -8,7 +8,8 @@ from __future__ import annotations -import warnings +import copy + from typing import Any import torch @@ -71,6 +72,7 @@ def __init__( self._Yvar = Yvar self.feature_names = feature_names self.outcome_names = outcome_names + self.validate_init = validate_init if validate_init: self._validate() @@ -148,37 +150,50 @@ def __eq__(self, other: Any) -> bool: and self.outcome_names == other.outcome_names ) + def clone( + self, deepcopy: bool = False, mask: Tensor | None = None + ) -> SupervisedDataset: + """Return a copy of the dataset. -class FixedNoiseDataset(SupervisedDataset): - r"""A SupervisedDataset with an additional field `Yvar` that stipulates - observations variances so that `Y[i] ~ N(f(X[i]), Yvar[i])`. - - NOTE: This is deprecated. Use `SupervisedDataset` instead. - Will be removed in a future release (~v0.11). - """ + Args: + deepcopy: If True, perform a deep copy. Otherwise, use the same + tensors/lists. + mask: A `n`-dim boolean mask indicating which rows to keep. This is used + along the -2 dimension. - def __init__( - self, - X: BotorchContainer | Tensor, - Y: BotorchContainer | Tensor, - Yvar: BotorchContainer | Tensor, - feature_names: list[str], - outcome_names: list[str], - validate_init: bool = True, - ) -> None: - r"""Initialize a `FixedNoiseDataset` -- deprecated!""" - warnings.warn( - "`FixedNoiseDataset` is deprecated. Use `SupervisedDataset` instead.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__( - X=X, - Y=Y, + Returns: + The new dataset. + """ + new_X = self._X + new_Y = self._Y + new_Yvar = self._Yvar + feature_names = self.feature_names + outcome_names = self.outcome_names + if mask is not None: + if any(isinstance(x, BotorchContainer) for x in [new_X, new_Y, new_Yvar]): + raise NotImplementedError( + "Masking is not supported for BotorchContainers." + ) + new_X = new_X[..., mask, :] + new_Y = new_Y[..., mask, :] + if new_Yvar is not None: + new_Yvar = new_Yvar[..., mask, :] + if deepcopy: + new_X = new_X.clone() + new_Y = new_Y.clone() + new_Yvar = new_Yvar.clone() if new_Yvar is not None else None + feature_names = copy.copy(self.feature_names) + outcome_names = copy.copy(self.outcome_names) + kwargs = {} + if new_Yvar is not None: + kwargs = {"Yvar": new_Yvar} + return type(self)( + X=new_X, + Y=new_Y, feature_names=feature_names, outcome_names=outcome_names, - Yvar=Yvar, - validate_init=validate_init, + validate_init=self.validate_init, + **kwargs, ) @@ -373,7 +388,7 @@ def from_joint_dataset( outcome_names=[outcome_name], ) datasets.append(new_dataset) - # Return the new + # Return the new dataset return cls( datasets=datasets, target_outcome_name=outcome_names_per_task.get( @@ -500,6 +515,37 @@ def __eq__(self, other: Any) -> bool: and self.task_feature_index == other.task_feature_index ) + def clone( + self, deepcopy: bool = False, mask: Tensor | None = None + ) -> MultiTaskDataset: + """Return a copy of the dataset. + + Args: + deepcopy: If True, perform a deep copy. Otherwise, use the same + tensors/lists/datasets. + mask: A `n`-dim boolean mask indicating which rows to keep from the target + dataset. This is used along the -2 dimension. + + Returns: + The new dataset. 
+ """ + datasets = list(self.datasets.values()) + if mask is not None or deepcopy: + new_datasets = [] + for outcome, ds in self.datasets.items(): + new_datasets.append( + ds.clone( + deepcopy=deepcopy, + mask=mask if outcome == self.target_outcome_name else None, + ) + ) + datasets = new_datasets + return MultiTaskDataset( + datasets=datasets, + target_outcome_name=self.target_outcome_name, + task_feature_index=self.task_feature_index, + ) + class ContextualDataset(SupervisedDataset): """This is a contextual dataset that is constructed from either a single @@ -661,3 +707,33 @@ def _validate_decompositions(self) -> None: raise InputDataError( f"{outcome} is missing in metric_decomposition." ) + + def clone( + self, deepcopy: bool = False, mask: Tensor | None = None + ) -> ContextualDataset: + """Return a copy of the dataset. + + Args: + deepcopy: If True, perform a deep copy. Otherwise, use the same + tensors/lists/datasets. + mask: A `n`-dim boolean mask indicating which rows to keep. This is used + along the -2 dimension. `n` here corresponds to the number of rows in + an individual dataset. + + Returns: + The new dataset. + """ + datasets = list(self.datasets.values()) + if mask is not None or deepcopy: + datasets = [ds.clone(deepcopy=deepcopy, mask=mask) for ds in datasets] + if deepcopy: + parameter_decomposition = copy.deepcopy(self.parameter_decomposition) + metric_decomposition = copy.deepcopy(self.metric_decomposition) + else: + parameter_decomposition = self.parameter_decomposition + metric_decomposition = self.metric_decomposition + return ContextualDataset( + datasets=datasets, + parameter_decomposition=parameter_decomposition, + metric_decomposition=metric_decomposition, + ) diff --git a/botorch/utils/testing.py b/botorch/utils/testing.py index 5f5b6a6df1..3c468f9e6b 100644 --- a/botorch/utils/testing.py +++ b/botorch/utils/testing.py @@ -17,7 +17,11 @@ import torch from botorch.acquisition.objective import PosteriorTransform -from botorch.exceptions.warnings import BotorchTensorDimensionWarning, InputDataWarning +from botorch.exceptions.warnings import ( + BotorchTensorDimensionWarning, + InputDataWarning, + NumericsWarning, +) from botorch.models.model import FantasizeMixin, Model from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.posteriors.posterior import Posterior @@ -68,6 +72,16 @@ def setUp(self, suppress_input_warnings: bool = True) -> None: message=r"Data \(input features\) is not", category=InputDataWarning, ) + warnings.filterwarnings( + "ignore", + message="has known numerical issues", + category=NumericsWarning, + ) + warnings.filterwarnings( + "ignore", + message="Model converter code is deprecated", + category=DeprecationWarning, + ) def assertAllClose( self, diff --git a/test/generation/test_utils.py b/test/generation/test_utils.py index c5fd0e7c83..6c6176a2b9 100644 --- a/test/generation/test_utils.py +++ b/test/generation/test_utils.py @@ -12,7 +12,6 @@ from botorch.acquisition import FixedFeatureAcquisitionFunction from botorch.generation.utils import ( - _convert_nonlinear_inequality_constraints, _flip_sub_unique, _remove_fixed_features_from_optimization, ) @@ -20,27 +19,6 @@ class TestGenerationUtils(BotorchTestCase): - def test_convert_nonlinear_inequality_constraints(self): - def nlc(x): - return x[..., 2] - - def nlc2(x): - return x[..., 3] - - nlcs = [nlc] - with self.assertWarns(DeprecationWarning): - new_nlcs = _convert_nonlinear_inequality_constraints(nlcs) - self.assertEqual(new_nlcs, [(nlc, True)]) - - nlcs = [(nlc, False)] - 
new_nlcs = _convert_nonlinear_inequality_constraints(nlcs) - self.assertEqual(new_nlcs, [(nlc, False)]) - - nlcs = [(nlc, False), nlc2] - with self.assertWarns(DeprecationWarning): - new_nlcs = _convert_nonlinear_inequality_constraints(nlcs) - self.assertEqual(new_nlcs, [(nlc, False), (nlc2, True)]) - def test_flip_sub_unique(self): for dtype in (torch.float, torch.double): tkwargs = {"device": self.device, "dtype": dtype} diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index fad021c61b..331b86be55 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -953,20 +953,6 @@ def nlc4(x): ) self.assertEqual(candidates.size(), torch.Size([1, 3])) - # Constraints must be passed in as lists - with self.assertRaisesRegex( - ValueError, - "`nonlinear_inequality_constraints` must be a list of tuples, " - "got .", - ): - optimize_acqf( - acq_function=mock_acq_function, - bounds=bounds, - q=1, - nonlinear_inequality_constraints=nlc1, - num_restarts=num_restarts, - batch_initial_conditions=batch_initial_conditions, - ) # batch_initial_conditions must be feasible with self.assertRaisesRegex( ValueError, diff --git a/test/optim/test_optimize_mixed.py b/test/optim/test_optimize_mixed.py index 1ab5fce7ea..f358f0a537 100644 --- a/test/optim/test_optimize_mixed.py +++ b/test/optim/test_optimize_mixed.py @@ -19,12 +19,14 @@ from botorch.models.gp_regression import SingleTaskGP from botorch.optim.optimize import _optimize_acqf, OptimizeAcqfInputs from botorch.optim.optimize_mixed import ( + _setup_continuous_relaxation, complement_indices, continuous_step, discrete_step, generate_starting_points, get_nearest_neighbors, get_spray_points, + MAX_DISCRETE_VALUES, optimize_acqf_mixed_alternating, sample_feasible_points, ) @@ -544,11 +546,10 @@ def test_optimize_acqf_mixed_binary_only(self) -> None: self.assertEqual(candidates.shape[-1], dim) c_binary = candidates[:, binary_dims + [2]] self.assertTrue(((c_binary == 0) | (c_binary == 1)).all()) - # Only continuous parameters will raise an error. - with self.assertRaisesRegex( - ValueError, - "There must be at least one discrete parameter", - ): + # Only continuous parameters should fallback to optimize_acqf. + with mock.patch( + f"{OPT_MODULE}._optimize_acqf", wraps=_optimize_acqf + ) as wrapped_optimize: optimize_acqf_mixed_alternating( acq_function=acqf, bounds=bounds, @@ -556,8 +557,18 @@ def test_optimize_acqf_mixed_binary_only(self) -> None: options=options, q=1, raw_samples=20, - num_restarts=20, + num_restarts=2, + ) + wrapped_optimize.assert_called_once_with( + opt_inputs=_make_opt_inputs( + acq_function=acqf, + bounds=bounds, + options=options, + q=1, + raw_samples=20, + num_restarts=2, ) + ) # Only discrete works fine. candidates, _ = optimize_acqf_mixed_alternating( acq_function=acqf, @@ -720,3 +731,71 @@ def test_optimize_acqf_mixed_integer(self) -> None: wrapped_sample_feasible.assert_called_once() # Should request 4 candidates, since all 4 are infeasible. self.assertEqual(wrapped_sample_feasible.call_args.kwargs["num_points"], 4) + + def test_optimize_acqf_mixed_continuous_relaxation(self) -> None: + # Testing with integer variables. + train_X, train_Y, binary_dims, cont_dims = self._get_data() + # Update the data to introduce integer dimensions. 
+ binary_dims = [0] + integer_dims = [3, 4] + discrete_dims = binary_dims + integer_dims + bounds = torch.tensor( + [[0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 40.0, 15.0]], + dtype=torch.double, + device=self.device, + ) + # Update the model to have a different optimizer. + root = torch.tensor([0.0, 0.0, 0.0, 25.0, 10.0], device=self.device) + model = QuadraticDeterministicModel(root) + acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X) + + for max_discrete_values, post_processing_func in ( + (None, None), + (5, lambda X: X + 10), + ): + options = { + "batch_limit": 5, + "init_batch_limit": 20, + "maxiter_alternating": 1, + } + if max_discrete_values is not None: + options["max_discrete_values"] = max_discrete_values + with mock.patch( + f"{OPT_MODULE}._setup_continuous_relaxation", + wraps=_setup_continuous_relaxation, + ) as wrapped_setup, mock.patch( + f"{OPT_MODULE}.discrete_step", wraps=discrete_step + ) as wrapped_discrete: + candidates, _ = optimize_acqf_mixed_alternating( + acq_function=acqf, + bounds=bounds, + discrete_dims=discrete_dims, + q=3, + raw_samples=32, + num_restarts=4, + options=options, + post_processing_func=post_processing_func, + ) + wrapped_setup.assert_called_once_with( + discrete_dims=discrete_dims, + bounds=bounds, + max_discrete_values=max_discrete_values or MAX_DISCRETE_VALUES, + post_processing_func=post_processing_func, + ) + discrete_call_args = wrapped_discrete.call_args.kwargs + expected_dims = [0, 4] if max_discrete_values is None else [0] + self.assertAllClose( + discrete_call_args["discrete_dims"], + torch.tensor(expected_dims, device=self.device), + ) + # Check that dim 3 is rounded. + X = torch.ones(1, 5, device=self.device) * 0.6 + X_expected = X.clone() + X_expected[0, 3] = 1.0 + if max_discrete_values is not None: + X_expected[0, 4] = 1.0 + if post_processing_func is not None: + X_expected = post_processing_func(X_expected) + self.assertAllClose( + discrete_call_args["opt_inputs"].post_processing_func(X), X_expected + ) diff --git a/test/test_utils/test_mock.py b/test/test_utils/test_mock.py index 19dc68eee6..43867bbeea 100644 --- a/test/test_utils/test_mock.py +++ b/test/test_utils/test_mock.py @@ -98,7 +98,7 @@ def test_mock_optimize_mixed_alternating(self) -> None: ) as mock_neighbors: optimize_acqf_mixed_alternating( acq_function=SinAcqusitionFunction(), - bounds=torch.tensor([[-2.0, 0.0], [2.0, 200.0]]), + bounds=torch.tensor([[-2.0, 0.0], [2.0, 20.0]]), discrete_dims=[1], num_restarts=1, ) diff --git a/test/utils/test_containers.py b/test/utils/test_containers.py index 47ebb3d1e0..c57c7eb745 100644 --- a/test/utils/test_containers.py +++ b/test/utils/test_containers.py @@ -84,6 +84,9 @@ def test_dense(self): # Test `__call__` self.assertTrue(X().equal(values)) + # Test `clone` + self.assertEqual(X.clone(), X) + def test_slice(self): for arity in (2, 4): for vals in ( diff --git a/test/utils/test_datasets.py b/test/utils/test_datasets.py index 22d8c24a50..65f570ce4f 100644 --- a/test/utils/test_datasets.py +++ b/test/utils/test_datasets.py @@ -5,12 +5,13 @@ # LICENSE file in the root directory of this source tree. 
+from itertools import product + import torch from botorch.exceptions.errors import InputDataError, UnsupportedError from botorch.utils.containers import DenseContainer, SliceContainer from botorch.utils.datasets import ( ContextualDataset, - FixedNoiseDataset, MultiTaskDataset, RankingDataset, SupervisedDataset, @@ -40,6 +41,60 @@ def make_dataset( ) +def make_contextual_dataset( + has_yvar: bool = False, contextual_outcome: bool = False +) -> tuple[ContextualDataset, list[SupervisedDataset]]: + num_contexts = 3 + feature_names = [f"x_c{i}" for i in range(num_contexts)] + parameter_decomposition = { + "context_2": ["x_c2"], + "context_1": ["x_c1"], + "context_0": ["x_c0"], + } + context_buckets = list(parameter_decomposition.keys()) + if contextual_outcome: + context_outcome_list = [f"y:context_{i}" for i in range(num_contexts)] + metric_decomposition = {f"{c}": [f"y:{c}"] for c in context_buckets} + + dataset_list2 = [ + make_dataset( + d=1 * num_contexts, + has_yvar=has_yvar, + feature_names=feature_names, + outcome_names=[context_outcome_list[0]], + ) + ] + for mname in context_outcome_list[1:]: + dataset_list2.append( + SupervisedDataset( + X=dataset_list2[0].X, + Y=rand(dataset_list2[0].Y.size()), + Yvar=rand(dataset_list2[0].Yvar.size()) if has_yvar else None, + feature_names=feature_names, + outcome_names=[mname], + ) + ) + context_dt = ContextualDataset( + datasets=dataset_list2, + parameter_decomposition=parameter_decomposition, + metric_decomposition=metric_decomposition, + ) + return context_dt, dataset_list2 + dataset_list1 = [ + make_dataset( + d=num_contexts, + has_yvar=has_yvar, + feature_names=feature_names, + outcome_names=["y"], + ) + ] + context_dt = ContextualDataset( + datasets=dataset_list1, + parameter_decomposition=parameter_decomposition, + ) + return context_dt, dataset_list1 + + class TestDatasets(BotorchTestCase): def test_supervised(self): # Generate some data @@ -122,6 +177,70 @@ def test_supervised(self): self.assertNotEqual(dataset, dataset2) self.assertNotEqual(dataset2, dataset) + def test_clone(self, supervised: bool = True) -> None: + has_yvar_options = [False] + if supervised: + has_yvar_options.append(True) + for has_yvar in has_yvar_options: + if supervised: + dataset = make_dataset(has_yvar=has_yvar) + else: + X_val = rand(16, 2) + X_idx = stack([randperm(len(X_val))[:3] for _ in range(1)]) + X = SliceContainer( + X_val, X_idx, event_shape=Size([3 * X_val.shape[-1]]) + ) + dataset = RankingDataset( + X=X, + Y=tensor([[0, 1, 1]]), + feature_names=["x1", "x2"], + outcome_names=["ranking indices"], + ) + + for use_deepcopy in [False, True]: + dataset2 = dataset.clone(deepcopy=use_deepcopy) + self.assertEqual(dataset, dataset2) + self.assertTrue(torch.equal(dataset.X, dataset2.X)) + self.assertTrue(torch.equal(dataset.Y, dataset2.Y)) + if has_yvar: + self.assertTrue(torch.equal(dataset.Yvar, dataset2.Yvar)) + else: + self.assertIsNone(dataset2.Yvar) + self.assertEqual(dataset.feature_names, dataset2.feature_names) + self.assertEqual(dataset.outcome_names, dataset2.outcome_names) + if use_deepcopy: + self.assertIsNot(dataset.X, dataset2.X) + self.assertIsNot(dataset.Y, dataset2.Y) + if has_yvar: + self.assertIsNot(dataset.Yvar, dataset2.Yvar) + self.assertIsNot(dataset.feature_names, dataset2.feature_names) + self.assertIsNot(dataset.outcome_names, dataset2.outcome_names) + else: + self.assertIs(dataset._X, dataset2._X) + self.assertIs(dataset._Y, dataset2._Y) + self.assertIs(dataset._Yvar, dataset2._Yvar) + self.assertIs(dataset.feature_names, 
dataset2.feature_names) + self.assertIs(dataset.outcome_names, dataset2.outcome_names) + # test with mask + mask = torch.tensor([0, 1, 1], dtype=torch.bool) + if supervised: + dataset2 = dataset.clone(deepcopy=use_deepcopy, mask=mask) + self.assertTrue(torch.equal(dataset.X[1:], dataset2.X)) + self.assertTrue(torch.equal(dataset.Y[1:], dataset2.Y)) + if has_yvar: + self.assertTrue(torch.equal(dataset.Yvar[1:], dataset2.Yvar)) + else: + self.assertIsNone(dataset2.Yvar) + else: + with self.assertRaisesRegex( + NotImplementedError, + "Masking is not supported for BotorchContainers.", + ): + dataset.clone(deepcopy=use_deepcopy, mask=mask) + + def test_clone_ranking(self) -> None: + self.test_clone(supervised=False) + def test_fixedNoise(self): # Generate some data X = rand(3, 2) @@ -129,7 +248,7 @@ def test_fixedNoise(self): Yvar = rand(3, 1) feature_names = ["x1", "x2"] outcome_names = ["y"] - dataset = FixedNoiseDataset( + dataset = SupervisedDataset( X=X, Y=Y, Yvar=Yvar, @@ -142,17 +261,6 @@ def test_fixedNoise(self): self.assertEqual(dataset.feature_names, feature_names) self.assertEqual(dataset.outcome_names, outcome_names) - with self.assertRaisesRegex( - ValueError, "`Y` and `Yvar`" - ), self.assertWarnsRegex(DeprecationWarning, "SupervisedDataset"): - FixedNoiseDataset( - X=X, - Y=Y, - Yvar=Yvar.squeeze(), - feature_names=feature_names, - outcome_names=outcome_names, - ) - def test_ranking(self): # Test `_validate` X_val = rand(16, 2) @@ -365,6 +473,52 @@ def test_multi_task(self): MultiTaskDataset(datasets=[dataset_1, dataset_5], target_outcome_name="z"), ) + def test_clone_multitask(self) -> None: + for has_yvar in [False, True]: + dataset_1 = make_dataset(outcome_names=["y"], has_yvar=has_yvar) + dataset_2 = make_dataset(outcome_names=["z"], has_yvar=has_yvar) + mt_dataset = MultiTaskDataset( + datasets=[dataset_1, dataset_2], + target_outcome_name="z", + ) + for use_deepcopy in [False, True]: + mt_dataset2 = mt_dataset.clone(deepcopy=use_deepcopy) + self.assertEqual(mt_dataset, mt_dataset2) + self.assertTrue(torch.equal(mt_dataset.X, mt_dataset2.X)) + self.assertTrue(torch.equal(mt_dataset.Y, mt_dataset2.Y)) + if has_yvar: + self.assertTrue(torch.equal(mt_dataset.Yvar, mt_dataset2.Yvar)) + else: + self.assertIsNone(mt_dataset2.Yvar) + self.assertEqual(mt_dataset.feature_names, mt_dataset2.feature_names) + self.assertEqual(mt_dataset.outcome_names, mt_dataset2.outcome_names) + if use_deepcopy: + for ds, ds2 in zip( + mt_dataset.datasets.values(), mt_dataset2.datasets.values() + ): + self.assertIsNot(ds, ds2) + else: + for ds, ds2 in zip( + mt_dataset.datasets.values(), mt_dataset2.datasets.values() + ): + self.assertIs(ds, ds2) + # test with mask + mask = torch.tensor([0, 1, 1], dtype=torch.bool) + mt_dataset2 = mt_dataset.clone(deepcopy=use_deepcopy, mask=mask) + # mask should only apply to target dataset. + # All non-target datasets should be included. 
+ full_mask = torch.tensor([1, 1, 1, 0, 1, 1], dtype=torch.bool) + self.assertTrue(torch.equal(mt_dataset.X[full_mask], mt_dataset2.X)) + self.assertTrue(torch.equal(mt_dataset.Y[full_mask], mt_dataset2.Y)) + if has_yvar: + self.assertTrue( + torch.equal(mt_dataset.Yvar[full_mask], mt_dataset2.Yvar) + ) + else: + self.assertIsNone(mt_dataset2.Yvar) + self.assertEqual(mt_dataset.feature_names, mt_dataset2.feature_names) + self.assertEqual(mt_dataset.outcome_names, mt_dataset2.outcome_names) + def test_contextual_datasets(self): num_contexts = 3 feature_names = [f"x_c{i}" for i in range(num_contexts)] @@ -378,17 +532,8 @@ def test_contextual_datasets(self): metric_decomposition = {f"{c}": [f"y:{c}"] for c in context_buckets} # test construction of agg outcome - dataset_list1 = [ - make_dataset( - d=1 * num_contexts, - has_yvar=True, - feature_names=feature_names, - outcome_names=["y"], - ) - ] - context_dt = ContextualDataset( - datasets=dataset_list1, - parameter_decomposition=parameter_decomposition, + context_dt, dataset_list1 = make_contextual_dataset( + has_yvar=True, contextual_outcome=False ) self.assertEqual(len(context_dt.datasets), len(dataset_list1)) self.assertListEqual(context_dt.context_buckets, context_buckets) @@ -400,28 +545,8 @@ def test_contextual_datasets(self): self.assertIs(context_dt.Yvar, dataset_list1[0].Yvar) # test construction of context outcome - dataset_list2 = [ - make_dataset( - d=1 * num_contexts, - has_yvar=True, - feature_names=feature_names, - outcome_names=[context_outcome_list[0]], - ) - ] - for m in context_outcome_list[1:]: - dataset_list2.append( - SupervisedDataset( - X=dataset_list2[0].X, - Y=rand(dataset_list2[0].Y.size()), - Yvar=rand(dataset_list2[0].Yvar.size()), - feature_names=feature_names, - outcome_names=[m], - ) - ) - context_dt = ContextualDataset( - datasets=dataset_list2, - parameter_decomposition=parameter_decomposition, - metric_decomposition=metric_decomposition, + context_dt, dataset_list2 = make_contextual_dataset( + has_yvar=True, contextual_outcome=True ) self.assertEqual(len(context_dt.datasets), len(dataset_list2)) # Ordering should match datasets, not parameter_decomposition @@ -438,30 +563,10 @@ def test_contextual_datasets(self): self.assertIs(context_dt.datasets[dt.outcome_names[0]], dt) # Test handling None Yvar - dataset_list3 = [ - make_dataset( - d=1 * num_contexts, - has_yvar=False, - feature_names=feature_names, - outcome_names=[context_outcome_list[0]], - ) - ] - for m in context_outcome_list[1:]: - dataset_list3.append( - SupervisedDataset( - X=dataset_list3[0].X, - Y=rand(dataset_list3[0].Y.size()), - Yvar=None, - feature_names=feature_names, - outcome_names=[m], - ) - ) - context_dt3 = ContextualDataset( - datasets=dataset_list3, - parameter_decomposition=parameter_decomposition, - metric_decomposition=metric_decomposition, + context_dt, dataset_list3 = make_contextual_dataset( + has_yvar=False, contextual_outcome=True ) - self.assertIsNone(context_dt3.Yvar) + self.assertIsNone(context_dt.Yvar) # test dataset validation wrong_metric_decomposition1 = { @@ -557,3 +662,54 @@ def test_contextual_datasets(self): parameter_decomposition=parameter_decomposition, metric_decomposition=wrong_metric_decomposition, ) + + def test_clone_contextual_dataset(self): + for has_yvar, contextual_outcome in product((False, True), (False, True)): + context_dt, _ = make_contextual_dataset( + has_yvar=has_yvar, contextual_outcome=contextual_outcome + ) + for use_deepcopy in [False, True]: + context_dt2 = 
context_dt.clone(deepcopy=use_deepcopy) + self.assertEqual(context_dt, context_dt2) + self.assertTrue(torch.equal(context_dt.X, context_dt2.X)) + self.assertTrue(torch.equal(context_dt.Y, context_dt2.Y)) + if has_yvar: + self.assertTrue(torch.equal(context_dt.Yvar, context_dt2.Yvar)) + else: + self.assertIsNone(context_dt.Yvar) + self.assertEqual(context_dt.feature_names, context_dt2.feature_names) + self.assertEqual(context_dt.outcome_names, context_dt2.outcome_names) + if use_deepcopy: + for ds, ds2 in zip( + context_dt.datasets.values(), context_dt2.datasets.values() + ): + self.assertIsNot(ds, ds2) + else: + for ds, ds2 in zip( + context_dt.datasets.values(), context_dt2.datasets.values() + ): + self.assertIs(ds, ds2) + # test with mask + mask = torch.tensor([0, 1, 1], dtype=torch.bool) + context_dt2 = context_dt.clone(deepcopy=use_deepcopy, mask=mask) + self.assertTrue(torch.equal(context_dt.X[mask], context_dt2.X)) + self.assertTrue(torch.equal(context_dt.Y[mask], context_dt2.Y)) + if has_yvar: + self.assertTrue( + torch.equal(context_dt.Yvar[mask], context_dt2.Yvar) + ) + else: + self.assertIsNone(context_dt2.Yvar) + self.assertEqual(context_dt.feature_names, context_dt2.feature_names) + self.assertEqual(context_dt.outcome_names, context_dt2.outcome_names) + self.assertEqual( + context_dt.parameter_decomposition, + context_dt2.parameter_decomposition, + ) + if contextual_outcome: + self.assertEqual( + context_dt.metric_decomposition, + context_dt2.metric_decomposition, + ) + else: + self.assertIsNone(context_dt2.metric_decomposition)
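
Illustrative usage sketch (not part of the patch): the snippet below exercises the new `SupervisedDataset.clone(deepcopy=..., mask=...)` API added in `botorch/utils/datasets.py` above. It assumes BoTorch with this change applied; the tensor values and variable names are arbitrary.

import torch
from botorch.utils.datasets import SupervisedDataset

# A small supervised dataset with 3 rows, 2 features, and 1 outcome.
X = torch.rand(3, 2)
Y = torch.rand(3, 1)
dataset = SupervisedDataset(
    X=X, Y=Y, feature_names=["x1", "x2"], outcome_names=["y"]
)

# Shallow clone (default): the new dataset shares the underlying tensors.
shallow = dataset.clone()
assert shallow._X is dataset._X

# Deep clone with a row mask: keeps only the masked rows (along the -2
# dimension) and copies the underlying storage.
mask = torch.tensor([False, True, True])
pruned = dataset.clone(deepcopy=True, mask=mask)
assert pruned.X.shape == (2, 2)
assert pruned.X is not dataset.X

Per the docstrings added in the patch, `clone` is shallow by default (tensors, name lists, and nested datasets are reused) and only copies storage when `deepcopy=True`; for `MultiTaskDataset.clone`, a mask is applied only to the target-outcome dataset.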