Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deduplicate shared logic in prune_inferior_points(_multi_objective) #2629

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 25 additions & 53 deletions botorch/acquisition/multi_objective/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,19 @@

from __future__ import annotations

import math
import warnings
from collections.abc import Callable
from math import ceil
from typing import Any

import torch
from botorch.acquisition import monte_carlo # noqa F401
from botorch.acquisition.multi_objective.objective import (
IdentityMCMultiOutputObjective,
MCMultiOutputObjective,
)
from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective
from botorch.acquisition.utils import _prune_inferior_shared_processing
from botorch.exceptions.errors import UnsupportedError
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.fully_bayesian import MCMC_DIM
from botorch.models.model import Model
from botorch.sampling.get_sampler import get_sampler
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
from botorch.utils.multi_objective.box_decompositions.box_decomposition import (
BoxDecomposition,
Expand All @@ -39,9 +34,8 @@
DominatedPartitioning,
)
from botorch.utils.multi_objective.pareto import is_non_dominated
from botorch.utils.objective import compute_feasibility_indicator
from botorch.utils.sampling import draw_sobol_samples
from botorch.utils.transforms import is_ensemble
from pyre_extensions import assert_is_instance
from torch import Tensor


Expand Down Expand Up @@ -115,40 +109,14 @@ def prune_inferior_points_multi_objective(
with `N_nz` the number of points in `X` that have non-zero (empirical,
under `num_samples` samples) probability of being pareto optimal.
"""
if marginalize_dim is None and is_ensemble(model):
# TODO: Properly deal with marginalizing fully Bayesian models
marginalize_dim = MCMC_DIM

if X.ndim > 2:
# TODO: support batched inputs (req. dealing with ragged tensors)
raise UnsupportedError(
"Batched inputs `X` are currently unsupported by "
"prune_inferior_points_multi_objective"
)
if X.size(-2) == 0:
raise ValueError("X must have at least one point.")
if max_frac <= 0 or max_frac > 1.0:
raise ValueError(f"max_frac must take values in (0, 1], is {max_frac}")
max_points = math.ceil(max_frac * X.size(-2))
with torch.no_grad():
posterior = model.posterior(X=X)
sampler = get_sampler(posterior, sample_shape=torch.Size([num_samples]))
samples = sampler(posterior)
if objective is None:
objective = IdentityMCMultiOutputObjective()
obj_vals = objective(samples, X=X)
if obj_vals.ndim > 3:
if obj_vals.ndim == 4 and marginalize_dim is not None:
obj_vals = obj_vals.mean(dim=marginalize_dim)
else:
# TODO: support batched inputs (req. dealing with ragged tensors)
raise UnsupportedError(
"Models with multiple batch dims are currently unsupported by"
" prune_inferior_points_multi_objective."
)
infeas = ~compute_feasibility_indicator(
max_points, obj_vals, infeas = _prune_inferior_shared_processing(
model=model,
X=X,
is_moo=True,
objective=objective,
constraints=constraints,
samples=samples,
num_samples=num_samples,
max_frac=max_frac,
marginalize_dim=marginalize_dim,
)
if infeas.any():
Expand All @@ -168,9 +136,9 @@ def prune_inferior_points_multi_objective(

def compute_sample_box_decomposition(
pareto_fronts: Tensor,
partitioning: BoxDecomposition = DominatedPartitioning,
partitioning: type[BoxDecomposition] = DominatedPartitioning,
maximize: bool = True,
num_constraints: int | None = 0,
num_constraints: int = 0,
) -> Tensor:
r"""Computes the box decomposition associated with some sampled optimal
objectives. This also supports the single-objective and constrained optimization
Expand All @@ -195,7 +163,10 @@ def compute_sample_box_decomposition(
the hyper-rectangles. The number `J` is the smallest number of boxes needed
to partition all the Pareto samples.
"""
tkwargs = {"dtype": pareto_fronts.dtype, "device": pareto_fronts.device}
tkwargs: dict[str, Any] = {
"dtype": pareto_fronts.dtype,
"device": pareto_fronts.device,
}
# We will later compute `norm.log_prob(NEG_INF)`, this is `-inf` if `NEG_INF` is
# too small.
NEG_INF = -1e10
Expand All @@ -214,16 +185,18 @@ def compute_sample_box_decomposition(

if M == 1:
# Only consider a Pareto front with one element.
extreme_values = weight * torch.max(weight * pareto_fronts, dim=-2).values
extreme_values = assert_is_instance(
weight * torch.max(weight * pareto_fronts, dim=-2).values, Tensor
)
ref_point = weight * ref_point.expand(extreme_values.shape)

if maximize:
hypercell_bounds = torch.stack(
[ref_point, extreme_values], axis=-2
[ref_point, extreme_values], dim=-2
).unsqueeze(-1)
else:
hypercell_bounds = torch.stack(
[extreme_values, ref_point], axis=-2
[extreme_values, ref_point], dim=-2
).unsqueeze(-1)
else:
bd_list = []
Expand All @@ -244,17 +217,15 @@ def compute_sample_box_decomposition(
# Add an extra box for the inequality constraint.
if K > 0:
# `num_pareto_samples x 2 x (J - 1) x K`
feasible_boxes = torch.zeros(
hypercell_bounds.shape[:-1] + torch.Size([K]), **tkwargs
)
feasible_boxes = torch.zeros(hypercell_bounds.shape[:-1] + (K,), **tkwargs)

feasible_boxes[..., 0, :, :] = NEG_INF
# `num_pareto_samples x 2 x (J - 1) x (M + K)`
hypercell_bounds = torch.cat([hypercell_bounds, feasible_boxes], dim=-1)

# `num_pareto_samples x 2 x 1 x (M + K)`
infeasible_box = torch.zeros(
hypercell_bounds.shape[:-2] + torch.Size([1, M + K]), **tkwargs
hypercell_bounds.shape[:-2] + (1, M + K), **tkwargs
)
infeasible_box[..., 1, :, M:] = -NEG_INF
infeasible_box[..., 0, :, 0:M] = NEG_INF
Expand Down Expand Up @@ -292,11 +263,12 @@ def random_search_optimizer(
- A `num_points x M`-dim Tensor containing the collection of optimal
objectives.
"""
tkwargs = {"dtype": bounds.dtype, "device": bounds.device}
tkwargs: dict[str, Any] = {"dtype": bounds.dtype, "device": bounds.device}
weight = 1.0 if maximize else -1.0
optimal_inputs = torch.tensor([], **tkwargs)
optimal_outputs = torch.tensor([], **tkwargs)
num_tries = 0
num_found = 0
ratio = 2
while ratio > 1 and num_tries < max_tries:
X = draw_sobol_samples(bounds=bounds, n=pop_size, q=1).squeeze(-2)
Expand Down
122 changes: 80 additions & 42 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import torch
from botorch.acquisition.objective import (
IdentityMCObjective,
MCAcquisitionObjective,
PosteriorTransform,
ScalarizedPosteriorTransform,
Expand All @@ -34,6 +33,7 @@
from botorch.utils.sampling import optimize_posterior_samples
from botorch.utils.transforms import is_ensemble, normalize_indices
from gpytorch.models import GP
from pyre_extensions import none_throws
from torch import Tensor


Expand Down Expand Up @@ -244,6 +244,76 @@ def objective(Y: Tensor, X: Tensor | None = None):
return -(lb.clamp_max(0.0))


def _prune_inferior_shared_processing(
model: Model,
X: Tensor,
is_moo: bool,
objective: MCAcquisitionObjective | None = None,
posterior_transform: PosteriorTransform | None = None,
constraints: list[Callable[[Tensor], Tensor]] | None = None,
num_samples: int = 2048,
max_frac: float = 1.0,
sampler: MCSampler | None = None,
marginalize_dim: int | None = None,
) -> tuple[int, Tensor, Tensor]:
r"""Shared data processing for `prune_inferior_points` and
`prune_inferior_points_multi_objective`.

Returns:
- max_points: The maximum number of points to keep.
- obj_vals: The objective values of the points in `X`.
- infeas: A boolean tensor indicating feasibility of `X`.
"""
func_name = (
"prune_inferior_points_multi_objective" if is_moo else "prune_inferior_points"
)
if marginalize_dim is None and is_ensemble(model):
marginalize_dim = MCMC_DIM

if X.ndim > 2:
raise UnsupportedError(
f"Batched inputs `X` are currently unsupported by `{func_name}`"
)
if X.size(-2) == 0:
raise ValueError("X must have at least one point.")
if max_frac <= 0 or max_frac > 1.0:
raise ValueError(f"max_frac must take values in (0, 1], is {max_frac}")
max_points = math.ceil(max_frac * X.size(-2))
with torch.no_grad():
posterior = model.posterior(X=X, posterior_transform=posterior_transform)
if sampler is None:
sampler = get_sampler(
posterior=posterior, sample_shape=torch.Size([num_samples])
)
samples = sampler(posterior)
if objective is not None:
obj_vals = objective(samples=samples, X=X)
elif is_moo:
obj_vals = samples
else:
obj_vals = samples.squeeze(-1)
if obj_vals.ndim > (2 + is_moo):
if obj_vals.ndim == (3 + is_moo) and marginalize_dim is not None:
if marginalize_dim < 0:
# Update `marginalize_dim` to be positive while accounting for
# removal of output dimension in SOO.
marginalize_dim = (not is_moo) + none_throws(
normalize_indices([marginalize_dim], d=obj_vals.ndim)
)[0]
obj_vals = obj_vals.mean(dim=marginalize_dim)
else:
raise UnsupportedError(
"Models with multiple batch dims are currently unsupported by "
f"`{func_name}`."
)
infeas = ~compute_feasibility_indicator(
constraints=constraints,
samples=samples,
marginalize_dim=marginalize_dim,
)
return max_points, obj_vals, infeas


def prune_inferior_points(
model: Model,
X: Tensor,
Expand Down Expand Up @@ -292,48 +362,16 @@ def prune_inferior_points(
with `N_nz` the number of points in `X` that have non-zero (empirical,
under `num_samples` samples) probability of being the best point.
"""
if marginalize_dim is None and is_ensemble(model):
# TODO: Properly deal with marginalizing fully Bayesian models
marginalize_dim = MCMC_DIM

if X.ndim > 2:
# TODO: support batched inputs (req. dealing with ragged tensors)
raise UnsupportedError(
"Batched inputs `X` are currently unsupported by prune_inferior_points"
)
if X.size(-2) == 0:
raise ValueError("X must have at least one point.")
if max_frac <= 0 or max_frac > 1.0:
raise ValueError(f"max_frac must take values in (0, 1], is {max_frac}")
max_points = math.ceil(max_frac * X.size(-2))
with torch.no_grad():
posterior = model.posterior(X=X, posterior_transform=posterior_transform)
if sampler is None:
sampler = get_sampler(
posterior=posterior, sample_shape=torch.Size([num_samples])
)
samples = sampler(posterior)
if objective is None:
objective = IdentityMCObjective()
obj_vals = objective(samples, X=X)
if obj_vals.ndim > 2:
if obj_vals.ndim == 3 and marginalize_dim is not None:
if marginalize_dim < 0:
# we do this again in compute_feasibility_indicator, but that will
# have no effect since marginalize_dim will be non-negative
marginalize_dim = (
1 + normalize_indices([marginalize_dim], d=obj_vals.ndim)[0]
)
obj_vals = obj_vals.mean(dim=marginalize_dim)
else:
# TODO: support batched inputs (req. dealing with ragged tensors)
raise UnsupportedError(
"Models with multiple batch dims are currently unsupported by"
" prune_inferior_points."
)
infeas = ~compute_feasibility_indicator(
max_points, obj_vals, infeas = _prune_inferior_shared_processing(
model=model,
X=X,
is_moo=False,
objective=objective,
posterior_transform=posterior_transform,
constraints=constraints,
samples=samples,
num_samples=num_samples,
max_frac=max_frac,
sampler=sampler,
marginalize_dim=marginalize_dim,
)
if infeas.any():
Expand Down
15 changes: 1 addition & 14 deletions botorch/generation/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@
from botorch.acquisition import AcquisitionFunction
from botorch.exceptions.errors import OptimizationGradientError
from botorch.exceptions.warnings import OptimizationWarning
from botorch.generation.utils import (
_convert_nonlinear_inequality_constraints,
_remove_fixed_features_from_optimization,
)
from botorch.generation.utils import _remove_fixed_features_from_optimization
from botorch.logging import logger
from botorch.optim.parameter_constraints import (
_arrayify,
Expand Down Expand Up @@ -136,16 +133,6 @@ def gen_candidates_scipy(
else:
reduced_domain = None not in fixed_features.values()

if nonlinear_inequality_constraints:
if not isinstance(nonlinear_inequality_constraints, list):
raise ValueError(
"`nonlinear_inequality_constraints` must be a list of tuples, "
f"got {type(nonlinear_inequality_constraints)}."
)
nonlinear_inequality_constraints = _convert_nonlinear_inequality_constraints(
nonlinear_inequality_constraints
)

if reduced_domain:
_no_fixed_features = _remove_fixed_features_from_optimization(
fixed_features=fixed_features,
Expand Down
30 changes: 0 additions & 30 deletions botorch/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@

from __future__ import annotations

import warnings
from collections.abc import Callable
from dataclasses import dataclass

import torch

from botorch.acquisition import AcquisitionFunction, FixedFeatureAcquisitionFunction
from botorch.optim.parameter_constraints import (
_generate_unfixed_lin_constraints,
Expand All @@ -20,34 +18,6 @@
from torch import Tensor


def _convert_nonlinear_inequality_constraints(
nonlinear_inequality_constraints: list[Callable | tuple[Callable, bool]],
) -> list[tuple[Callable, bool]]:
"""Convert legacy defintions of nonlinear inequality constraints into the new
format. Assumes intra-point constraints.
"""
nlcs = []
legacy = False
# return nonlinear_inequality_constraints
for nlc in nonlinear_inequality_constraints:
if callable(nlc):
# old style --> convert
nlcs.append((nlc, True))
legacy = True
else:
nlcs.append(nlc)
if legacy:
warnings.warn(
"The `nonlinear_inequality_constraints` argument is expected "
"take a list of tuples. Passing a list of callables "
"will result in an error in future versions.",
DeprecationWarning,
stacklevel=3,
)

return nlcs


def _flip_sub_unique(x: Tensor, k: int) -> Tensor:
"""Get the first k unique elements of a single-dimensional tensor, traversing the
tensor from the back.
Expand Down
Loading