Skip to content

Commit

Permalink
Add isinstance_af
Browse files Browse the repository at this point in the history
Summary: Creates a new helper method for checking both if a given AF is an instance of a class or if the given AF wraps a base AF that is an instance of a class

Differential Revision: D43127722

fbshipit-source-id: 0ec0131cf1c7512c10ab14e5d0a0a20cf3025688
  • Loading branch information
sdaulton authored and facebook-github-bot committed Feb 8, 2023
1 parent c5ad87b commit 3a396b2
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 3 deletions.
17 changes: 15 additions & 2 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from __future__ import annotations

import math
from typing import Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from botorch.acquisition import analytic, monte_carlo, multi_objective # noqa F401
Expand All @@ -22,6 +22,7 @@
MCAcquisitionObjective,
PosteriorTransform,
)
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from botorch.exceptions.errors import UnsupportedError
from botorch.models.fully_bayesian import MCMC_DIM
from botorch.models.model import Model
Expand Down Expand Up @@ -253,6 +254,18 @@ def objective(Y: Tensor, X: Optional[Tensor] = None):
return -(lb.clamp_max(0.0))


def isinstance_af(
__obj: object,
__class_or_tuple: Union[type, tuple[Union[type, tuple[Any, ...]], ...]],
) -> bool:
r"""A variant of isinstance first checks for the acq_func attribute on wrapped acquisition functions."""
if isinstance(__obj, AbstractAcquisitionFunctionWrapper):
isinstance_base_af = isinstance(__obj.acq_func, __class_or_tuple)
else:
isinstance_base_af = False
return isinstance_base_af or isinstance(__obj, __class_or_tuple)


def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
r"""Determine whether a given acquisition function is non-negative.
Expand All @@ -267,7 +280,7 @@ def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
>>> qEI = qExpectedImprovement(model, best_f=0.1)
>>> is_nonnegative(qEI) # returns True
"""
return isinstance(
return isinstance_af(
acq_function,
(
analytic.ExpectedImprovement,
Expand Down
61 changes: 60 additions & 1 deletion test/acquisition/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from unittest import mock

import torch
from botorch.acquisition import monte_carlo
from botorch.acquisition import analytic, monte_carlo, multi_objective
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
from botorch.acquisition.multi_objective import (
MCMultiOutputObjective,
monte_carlo as moo_monte_carlo,
Expand All @@ -18,10 +19,13 @@
MCAcquisitionObjective,
ScalarizedPosteriorTransform,
)
from botorch.acquisition.proximal import ProximalAcquisitionFunction
from botorch.acquisition.utils import (
expand_trace_observations,
get_acquisition_function,
get_infeasible_cost,
is_nonnegative,
isinstance_af,
project_to_sample_points,
project_to_target_fidelity,
prune_inferior_points,
Expand Down Expand Up @@ -606,6 +610,61 @@ def test_get_infeasible_cost(self):
self.assertAllClose(M4, torch.tensor([1.0], **tkwargs))


class TestIsNonnegative(BotorchTestCase):
def test_is_nonnegative(self):
nonneg_afs = (
analytic.ExpectedImprovement,
analytic.ConstrainedExpectedImprovement,
analytic.ProbabilityOfImprovement,
analytic.NoisyExpectedImprovement,
monte_carlo.qExpectedImprovement,
monte_carlo.qNoisyExpectedImprovement,
monte_carlo.qProbabilityOfImprovement,
multi_objective.analytic.ExpectedHypervolumeImprovement,
multi_objective.monte_carlo.qExpectedHypervolumeImprovement,
multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement,
)
mm = MockModel(
MockPosterior(
mean=torch.rand(1, 1, device=self.device),
variance=torch.ones(1, 1, device=self.device),
)
)
acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0)
with mock.patch(
"botorch.acquisition.utils.isinstance_af", return_value=True
) as mock_isinstance_af:
self.assertTrue(is_nonnegative(acq_function=acq_func))
mock_isinstance_af.assert_called_once()
cargs, _ = mock_isinstance_af.call_args
self.assertIs(cargs[0], acq_func)
self.assertEqual(cargs[1], nonneg_afs)
acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0)
self.assertFalse(is_nonnegative(acq_function=acq_func))


class TestIsinstanceAf(BotorchTestCase):
def test_isinstance_af(self):
mm = MockModel(
MockPosterior(
mean=torch.rand(1, 1, device=self.device),
variance=torch.ones(1, 1, device=self.device),
)
)
acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0)
self.assertTrue(isinstance_af(acq_func, analytic.ExpectedImprovement))
self.assertFalse(isinstance_af(acq_func, analytic.UpperConfidenceBound))
wrapped_af = FixedFeatureAcquisitionFunction(
acq_function=acq_func, d=2, columns=[1], values=[0.0]
)
# test base af class
self.assertTrue(isinstance_af(wrapped_af, analytic.ExpectedImprovement))
self.assertFalse(isinstance_af(wrapped_af, analytic.UpperConfidenceBound))
# test wrapper class
self.assertTrue(isinstance_af(wrapped_af, FixedFeatureAcquisitionFunction))
self.assertFalse(isinstance_af(wrapped_af, ProximalAcquisitionFunction))


class TestPruneInferiorPoints(BotorchTestCase):
def test_prune_inferior_points(self):
for dtype in (torch.float, torch.double):
Expand Down

0 comments on commit 3a396b2

Please sign in to comment.