diff --git a/tests/integration/test_ask_tell_optimization.py b/tests/integration/test_ask_tell_optimization.py index e471bd695d..3be8bc28d7 100644 --- a/tests/integration/test_ask_tell_optimization.py +++ b/tests/integration/test_ask_tell_optimization.py @@ -16,29 +16,31 @@ import copy import pickle import tempfile -from typing import Callable +from typing import Callable, Tuple, Union import numpy.testing as npt import pytest import tensorflow as tf from tests.util.misc import random_seed -from trieste.acquisition import LocalPenalization +from trieste.acquisition import LocalPenalization, ParallelContinuousThompsonSampling from trieste.acquisition.rule import ( AcquisitionRule, AsynchronousGreedy, AsynchronousRuleState, BatchTrustRegionBox, EfficientGlobalOptimization, + SingleObjectiveTrustRegionBox, TREGOBox, ) +from trieste.acquisition.utils import copy_to_local_models from trieste.ask_tell_optimization import AskTellOptimizer from trieste.bayesian_optimizer import OptimizationResult, Record from trieste.logging import set_step_number, tensorboard_writer from trieste.models import TrainableProbabilisticModel from trieste.models.gpflow import GaussianProcessRegression, build_gpr from trieste.objectives import ScaledBranin, SimpleQuadratic -from trieste.objectives.utils import mk_observer +from trieste.objectives.utils import mk_batch_observer, mk_observer from trieste.observer import OBJECTIVE from trieste.space import Box, SearchSpace from trieste.types import State, TensorType @@ -59,7 +61,10 @@ id="EfficientGlobalOptimization/reload_state", ), pytest.param( - 15, False, lambda: BatchTrustRegionBox(TREGOBox(ScaledBranin.search_space)), id="TREGO" + 15, + False, + lambda: BatchTrustRegionBox(TREGOBox(ScaledBranin.search_space)), + id="TREGO", ), pytest.param( 16, @@ -67,6 +72,33 @@ lambda: BatchTrustRegionBox(TREGOBox(ScaledBranin.search_space)), id="TREGO/reload_state", ), + pytest.param( + 10, + False, + lambda: BatchTrustRegionBox( + [SingleObjectiveTrustRegionBox(ScaledBranin.search_space) for _ in range(3)], + EfficientGlobalOptimization( + ParallelContinuousThompsonSampling(), + num_query_points=3, + ), + ), + id="BatchTrustRegionBox", + ), + pytest.param( + 10, + False, + ( + lambda: BatchTrustRegionBox( + [SingleObjectiveTrustRegionBox(ScaledBranin.search_space) for _ in range(3)], + EfficientGlobalOptimization( + ParallelContinuousThompsonSampling(), + num_query_points=2, + ), + ), + 3, + ), + id="BatchTrustRegionBox/LocalModels", + ), pytest.param( 10, False, @@ -92,23 +124,26 @@ ) -@random_seed -@pytest.mark.slow # to run this, add --runslow yes to the pytest command -@pytest.mark.parametrize(*OPTIMIZER_PARAMS) -def test_ask_tell_optimizer_finds_minima_of_the_scaled_branin_function( - num_steps: int, - reload_state: bool, - acquisition_rule_fn: Callable[ - [], AcquisitionRule[TensorType, SearchSpace, TrainableProbabilisticModel] - ] - | Callable[ +AcquisitionRuleFunction = Union[ + Callable[[], AcquisitionRule[TensorType, SearchSpace, TrainableProbabilisticModel]], + Callable[ [], AcquisitionRule[ - State[TensorType, AsynchronousRuleState | BatchTrustRegionBox.State], + State[TensorType, Union[AsynchronousRuleState, BatchTrustRegionBox.State]], Box, TrainableProbabilisticModel, ], ], +] + + +@random_seed +@pytest.mark.slow # to run this, add --runslow yes to the pytest command +@pytest.mark.parametrize(*OPTIMIZER_PARAMS) +def test_ask_tell_optimizer_finds_minima_of_the_scaled_branin_function( + num_steps: int, + reload_state: bool, + acquisition_rule_fn: 
AcquisitionRuleFunction | Tuple[AcquisitionRuleFunction, int], ) -> None: _test_ask_tell_optimization_finds_minima(True, num_steps, reload_state, acquisition_rule_fn) @@ -118,17 +153,7 @@ def test_ask_tell_optimizer_finds_minima_of_the_scaled_branin_function( def test_ask_tell_optimizer_finds_minima_of_simple_quadratic( num_steps: int, reload_state: bool, - acquisition_rule_fn: Callable[ - [], AcquisitionRule[TensorType, SearchSpace, TrainableProbabilisticModel] - ] - | Callable[ - [], - AcquisitionRule[ - State[TensorType, AsynchronousRuleState | BatchTrustRegionBox.State], - Box, - TrainableProbabilisticModel, - ], - ], + acquisition_rule_fn: AcquisitionRuleFunction | Tuple[AcquisitionRuleFunction, int], ) -> None: # for speed reasons we sometimes test with a simple quadratic defined on the same search space # branin; currently assume that every rule should be able to solve this in 5 steps @@ -141,17 +166,7 @@ def _test_ask_tell_optimization_finds_minima( optimize_branin: bool, num_steps: int, reload_state: bool, - acquisition_rule_fn: Callable[ - [], AcquisitionRule[TensorType, SearchSpace, TrainableProbabilisticModel] - ] - | Callable[ - [], - AcquisitionRule[ - State[TensorType, AsynchronousRuleState | BatchTrustRegionBox.State], - Box, - TrainableProbabilisticModel, - ], - ], + acquisition_rule_fn: AcquisitionRuleFunction | Tuple[AcquisitionRuleFunction, int], ) -> None: # For the case when optimization state is saved and reload on each iteration # we need to use new acquisition function object to imitate real life usage @@ -160,17 +175,27 @@ def _test_ask_tell_optimization_finds_minima( search_space = ScaledBranin.search_space initial_query_points = search_space.sample(5) observer = mk_observer(ScaledBranin.objective if optimize_branin else SimpleQuadratic.objective) + batch_observer = mk_batch_observer(observer) initial_data = observer(initial_query_points) + if isinstance(acquisition_rule_fn, tuple): + acquisition_rule_fn, num_models = acquisition_rule_fn + else: + num_models = 1 + model = GaussianProcessRegression( build_gpr(initial_data, search_space, likelihood_variance=1e-7) ) + models = copy_to_local_models(model, num_models) if num_models > 1 else {OBJECTIVE: model} + initial_dataset = {OBJECTIVE: initial_data} with tempfile.TemporaryDirectory() as tmpdirname: summary_writer = tf.summary.create_file_writer(tmpdirname) with tensorboard_writer(summary_writer): set_step_number(0) - ask_tell = AskTellOptimizer(search_space, initial_data, model, acquisition_rule_fn()) + ask_tell = AskTellOptimizer( + search_space, initial_dataset, models, acquisition_rule_fn() + ) for i in range(1, num_steps + 1): # two scenarios are tested here, depending on `reload_state` parameter @@ -185,7 +210,11 @@ def _test_ask_tell_optimization_finds_minima( ] = ask_tell.to_record() written_state = pickle.dumps(state) - new_data_point = observer(new_point) + # If query points are rank 3, then use a batched observer. 
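+                # (Rank-3 points have shape [num_query_points, batch_size, dims]; the batch
+                # observer returns one dataset per batch element, keyed by LocalizedTag.)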
+                if tf.rank(new_point) == 3:
+                    new_data_point = batch_observer(new_point)
+                else:
+                    new_data_point = observer(new_point)

                if reload_state:
                    state = pickle.loads(written_state)
diff --git a/tests/integration/test_bayesian_optimization.py b/tests/integration/test_bayesian_optimization.py
index 0b6adac15a..6cc0cc27fa 100644
--- a/tests/integration/test_bayesian_optimization.py
+++ b/tests/integration/test_bayesian_optimization.py
@@ -1,4 +1,4 @@
 # Copyright 2021 The Trieste Contributors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
 import tempfile
 from functools import partial
 from pathlib import Path
-from typing import Any, List, Mapping, Optional, Tuple, Type, cast
+from typing import Any, List, Mapping, Optional, Tuple, Type, Union, cast

 import dill
 import gpflow
@@ -58,6 +58,7 @@
     TREGOBox,
 )
 from trieste.acquisition.sampler import ThompsonSamplerFromTrajectory
+from trieste.acquisition.utils import copy_to_local_models
 from trieste.bayesian_optimizer import (
     BayesianOptimizer,
     FrozenRecord,
@@ -227,6 +228,23 @@ def GPR_OPTIMIZER_PARAMS() -> Tuple[str, List[ParameterSet]]:
             ),
             id="BatchTrustRegionBox",
         ),
+        pytest.param(
+            10,
+            (
+                BatchTrustRegionBox(
+                    [
+                        SingleObjectiveTrustRegionBox(ScaledBranin.search_space)
+                        for _ in range(3)
+                    ],
+                    EfficientGlobalOptimization(
+                        ParallelContinuousThompsonSampling(),
+                        num_query_points=2,
+                    ),
+                ),
+                3,
+            ),
+            id="BatchTrustRegionBox/LocalModels",
+        ),
         pytest.param(15, DiscreteThompsonSampling(500, 5), id="DiscreteThompsonSampling"),
         pytest.param(
             15,
@@ -262,20 +280,29 @@ def GPR_OPTIMIZER_PARAMS() -> Tuple[str, List[ParameterSet]]:
 )

+AcquisitionRuleType = Union[
+    AcquisitionRule[TensorType, SearchSpace, TrainableProbabilisticModelType],
+    AcquisitionRule[
+        State[TensorType, Union[AsynchronousRuleState, BatchTrustRegion.State]],
+        Box,
+        TrainableProbabilisticModelType,
+    ],
+]
+
+
 @random_seed
 @pytest.mark.slow  # to run this, add --runslow yes to the pytest command
 @pytest.mark.parametrize(*GPR_OPTIMIZER_PARAMS())
 def test_bayesian_optimizer_with_gpr_finds_minima_of_scaled_branin(
     num_steps: int,
-    acquisition_rule: AcquisitionRule[TensorType, SearchSpace, GaussianProcessRegression]
-    | AcquisitionRule[
-        State[TensorType, AsynchronousRuleState | BatchTrustRegion.State],
-        Box,
-        GaussianProcessRegression,
-    ],
+    acquisition_rule: AcquisitionRuleType[GaussianProcessRegression]
+    | Tuple[AcquisitionRuleType[GaussianProcessRegression], int],
 ) -> None:
     _test_optimizer_finds_minimum(
-        GaussianProcessRegression, num_steps, acquisition_rule, optimize_branin=True
+        GaussianProcessRegression,
+        num_steps,
+        acquisition_rule,
+        optimize_branin=True,
     )

@@ -283,12 +310,8 @@ def test_bayesian_optimizer_with_gpr_finds_minima_of_scaled_branin(
 @pytest.mark.parametrize(*GPR_OPTIMIZER_PARAMS())
 def test_bayesian_optimizer_with_gpr_finds_minima_of_simple_quadratic(
     num_steps: int,
-    acquisition_rule: AcquisitionRule[TensorType, SearchSpace, GaussianProcessRegression]
-    | AcquisitionRule[
-        State[TensorType, AsynchronousRuleState | BatchTrustRegion.State],
-        Box,
-        GaussianProcessRegression,
-    ],
+    acquisition_rule: AcquisitionRuleType[GaussianProcessRegression]
+    | Tuple[AcquisitionRuleType[GaussianProcessRegression], int],
 ) -> None:
     # for speed reasons we sometimes test with a simple quadratic defined on the same search space
     # branin; currently assume that every rule should be able to solve this in 6
steps @@ -556,12 +579,8 @@ def test_bayesian_optimizer_with_PCTS_and_deep_ensemble_finds_minima_of_simple_q def _test_optimizer_finds_minimum( model_type: Type[TrainableProbabilisticModelType], num_steps: Optional[int], - acquisition_rule: AcquisitionRule[TensorType, SearchSpace, TrainableProbabilisticModelType] - | AcquisitionRule[ - State[TensorType, AsynchronousRuleState | BatchTrustRegion.State], - Box, - TrainableProbabilisticModelType, - ], + acquisition_rule: AcquisitionRuleType[TrainableProbabilisticModelType] + | Tuple[AcquisitionRuleType[TrainableProbabilisticModelType], int], optimize_branin: bool = False, model_args: Optional[Mapping[str, Any]] = None, check_regret: bool = False, @@ -590,6 +609,11 @@ def _test_optimizer_finds_minimum( observer = mk_observer(ScaledBranin.objective if optimize_branin else SimpleQuadratic.objective) initial_data = observer(initial_query_points) + if isinstance(acquisition_rule, tuple): + acquisition_rule, num_models = acquisition_rule + else: + num_models = 1 + model: TrainableProbabilisticModel # (really TPMType, but that's too complicated for mypy) if model_type is GaussianProcessRegression: @@ -647,13 +671,17 @@ def _test_optimizer_finds_minimum( else: raise ValueError(f"Unsupported model_type '{model_type}'") + model = cast(TrainableProbabilisticModelType, model) + models = copy_to_local_models(model, num_models) if num_models > 1 else {OBJECTIVE: model} + dataset = {OBJECTIVE: initial_data} + with tempfile.TemporaryDirectory() as tmpdirname: summary_writer = tf.summary.create_file_writer(tmpdirname) with tensorboard_writer(summary_writer): result = BayesianOptimizer(observer, search_space).optimize( num_steps or 2, - initial_data, - cast(TrainableProbabilisticModelType, model), + dataset, + models, acquisition_rule, track_state=True, track_path=Path(tmpdirname) / "history", diff --git a/tests/unit/acquisition/test_rule.py b/tests/unit/acquisition/test_rule.py index 3bee1fd575..b184d31306 100644 --- a/tests/unit/acquisition/test_rule.py +++ b/tests/unit/acquisition/test_rule.py @@ -22,8 +22,9 @@ import numpy.testing as npt import pytest import tensorflow as tf +import tensorflow_probability as tfp -from tests.util.misc import empty_dataset, quadratic, random_seed +from tests.util.misc import empty_dataset, mk_dataset, quadratic, random_seed from tests.util.models.gpflow.models import ( GaussianProcess, QuadraticMeanAndRBFKernel, @@ -32,6 +33,7 @@ from trieste.acquisition import ( AcquisitionFunction, AcquisitionFunctionBuilder, + MultipleOptimismNegativeLowerConfidenceBound, NegativeLowerConfidenceBound, ParallelContinuousThompsonSampling, SingleModelAcquisitionBuilder, @@ -59,12 +61,15 @@ ThompsonSampler, ThompsonSamplerFromTrajectory, ) +from trieste.acquisition.utils import copy_to_local_models from trieste.data import Dataset from trieste.models import ProbabilisticModel from trieste.models.interfaces import TrainableSupportsGetKernel +from trieste.objectives.utils import mk_batch_observer from trieste.observer import OBJECTIVE from trieste.space import Box, SearchSpace, TaggedMultiSearchSpace from trieste.types import State, Tag, TensorType +from trieste.utils.misc import LocalizedTag, get_value_for_tag def _line_search_maximize( @@ -540,16 +545,23 @@ def test_async_keeps_track_of_pending_points( npt.assert_allclose(state.pending_points, tf.concat([point2, point3], axis=0)) -@pytest.mark.parametrize("datasets", [{}, {"foo": empty_dataset([1], [1])}]) +@pytest.mark.parametrize( + "datasets", + [ + {}, + {"foo": empty_dataset([1], [1])}, + 
{OBJECTIVE: empty_dataset([1], [1]), "foo": empty_dataset([1], [1])}, + ], +) @pytest.mark.parametrize( "models", [{}, {"foo": QuadraticMeanAndRBFKernel()}, {OBJECTIVE: QuadraticMeanAndRBFKernel()}] ) def test_trego_raises_for_missing_datasets_key( - datasets: dict[Tag, Dataset], models: dict[Tag, ProbabilisticModel] + datasets: Mapping[Tag, Dataset], models: dict[Tag, ProbabilisticModel] ) -> None: search_space = Box([-1], [1]) rule = BatchTrustRegionBox(TREGOBox(search_space)) # type: ignore[var-annotated] - with pytest.raises(ValueError, match="tag 'OBJECTIVE' not found"): + with pytest.raises(ValueError, match="a single OBJECTIVE dataset must be provided"): rule.acquire(search_space, models, datasets=datasets)(None) @@ -589,7 +601,7 @@ def test_trego_for_default_state( assert state is not None subspace = state.acquisition_space.get_subspace("0") assert isinstance(subspace, TREGOBox) - npt.assert_array_almost_equal(query_point, expected_query_point, 5) + npt.assert_array_almost_equal(query_point, [expected_query_point], 5) npt.assert_array_almost_equal(subspace.lower, lower_bound) npt.assert_array_almost_equal(subspace.upper, upper_bound) npt.assert_array_almost_equal(subspace._y_min, [0.012]) @@ -650,7 +662,7 @@ def test_trego_successful_global_to_global_trust_region_unchanged( assert isinstance(current_subspace, TREGOBox) npt.assert_array_almost_equal(current_subspace._eps, eps) assert current_subspace._is_global - npt.assert_array_almost_equal(query_point, expected_query_point, 5) + npt.assert_array_almost_equal(query_point, [expected_query_point], 5) npt.assert_array_almost_equal(current_subspace.lower, lower_bound) npt.assert_array_almost_equal(current_subspace.upper, upper_bound) @@ -692,7 +704,7 @@ def test_trego_for_unsuccessful_global_to_local_trust_region_unchanged( assert not current_subspace._is_global npt.assert_array_less(lower_bound, current_subspace.lower) npt.assert_array_less(current_subspace.upper, upper_bound) - assert query_point[0] in current_state.acquisition_space + assert query_point[0][0] in current_state.acquisition_space @pytest.mark.parametrize( @@ -773,6 +785,27 @@ def test_trego_for_unsuccessful_local_to_global_trust_region_reduced( npt.assert_array_almost_equal(current_subspace.upper, upper_bound) +def test_trego_always_uses_global_dataset() -> None: + search_space = Box([0.0, 0.0], [1.0, 1.0]) + dataset = Dataset( + tf.constant([[0.1, 0.2], [-0.1, -0.2], [1.1, 2.3]]), tf.constant([[0.4], [0.5], [0.6]]) + ) + tr = BatchTrustRegionBox(TREGOBox(search_space)) # type: ignore[var-annotated] + new_data = Dataset( + tf.constant([[0.5, -0.2], [0.7, 0.2], [1.1, 0.3], [0.5, 0.5]]), + tf.constant([[0.7], [0.8], [0.9], [1.0]]), + ) + updated_datasets = tr.filter_datasets({LocalizedTag(OBJECTIVE, 0): dataset + new_data}) + + # Both the local and global datasets should match. + assert updated_datasets.keys() == {OBJECTIVE, LocalizedTag(OBJECTIVE, 0)} + # Updated dataset should contain all the points, including ones outside the search space. 
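+    # (TREGOBox overrides get_datasets_filter_mask to keep all points, so filter_datasets
+    # drops nothing here.)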
+ exp_dataset = dataset + new_data + for key in updated_datasets.keys(): + npt.assert_array_equal(exp_dataset.query_points, updated_datasets[key].query_points) + npt.assert_array_equal(exp_dataset.observations, updated_datasets[key].observations) + + def test_trego_state_deepcopy() -> None: dataset = Dataset(tf.constant([[0.1, 0.2], [-0.1, -0.2]]), tf.constant([[0.4], [0.5]])) search_space = Box(tf.constant([1.2]), tf.constant([3.4])) @@ -1169,16 +1202,25 @@ def test_turbo_state_deepcopy() -> None: npt.assert_allclose(tr_state_copy.y_min, tr_state.y_min) -# get_local_min raises if dataset is None. -def test_trust_region_box_get_local_min_raises_if_dataset_is_none() -> None: +@pytest.mark.parametrize( + "datasets", + [ + {}, + {"foo": empty_dataset([1], [1])}, + {OBJECTIVE: empty_dataset([1], [1]), "foo": empty_dataset([1], [1])}, + ], +) +def test_trust_region_box_get_dataset_min_raises_if_dataset_is_faulty( + datasets: Mapping[Tag, Dataset] +) -> None: search_space = Box([0.0, 0.0], [1.0, 1.0]) trb = SingleObjectiveTrustRegionBox(search_space) - with pytest.raises(ValueError, match="dataset must be provided"): - trb.get_local_min(None) + with pytest.raises(ValueError, match="a single OBJECTIVE dataset must be provided"): + trb.get_dataset_min(datasets) -# get_local_min picks the minimum x and y values from the dataset. -def test_trust_region_box_get_local_min() -> None: +# get_dataset_min picks the minimum x and y values from the dataset. +def test_trust_region_box_get_dataset_min() -> None: search_space = Box([0.0, 0.0], [1.0, 1.0]) dataset = Dataset( tf.constant([[0.1, 0.1], [0.5, 0.5], [0.3, 0.4], [0.8, 0.8], [0.4, 0.4]], dtype=tf.float64), @@ -1187,21 +1229,21 @@ def test_trust_region_box_get_local_min() -> None: trb = SingleObjectiveTrustRegionBox(search_space) trb._lower = tf.constant([0.2, 0.2], dtype=tf.float64) trb._upper = tf.constant([0.7, 0.7], dtype=tf.float64) - x_min, y_min = trb.get_local_min(dataset) + x_min, y_min = trb.get_dataset_min({OBJECTIVE: dataset}) npt.assert_array_equal(x_min, tf.constant([0.3, 0.4], dtype=tf.float64)) npt.assert_array_equal(y_min, tf.constant([0.2], dtype=tf.float64)) -# get_local_min returns first x value and inf y value when points in dataset are outside the +# get_dataset_min returns first x value and inf y value when points in dataset are outside the # search space. 
-def test_trust_region_box_get_local_min_outside_search_space() -> None:
+def test_trust_region_box_get_dataset_min_outside_search_space() -> None:
     search_space = Box([0.0, 0.0], [1.0, 1.0])
     dataset = Dataset(
         tf.constant([[1.2, 1.3], [-0.4, -0.5]], dtype=tf.float64),
         tf.constant([[0.7], [0.9]], dtype=tf.float64),
     )
     trb = SingleObjectiveTrustRegionBox(search_space)
-    x_min, y_min = trb.get_local_min(dataset)
+    x_min, y_min = trb.get_dataset_min({OBJECTIVE: dataset})
     npt.assert_array_equal(x_min, tf.constant([1.2, 1.3], dtype=tf.float64))
     npt.assert_array_equal(y_min, tf.constant([np.inf], dtype=tf.float64))

@@ -1391,7 +1433,7 @@ def test_multi_trust_region_box_acquire_no_state() -> None:
     assert isinstance(state.acquisition_space, TaggedMultiSearchSpace)
     assert len(state.acquisition_space.subspace_tags) == 2

-    for index, (tag, point) in enumerate(zip(state.acquisition_space.subspace_tags, points)):
+    for index, (tag, point) in enumerate(zip(state.acquisition_space.subspace_tags, points[0])):
         subspace = state.acquisition_space.get_subspace(tag)
         assert subspace == subspaces[index]
         assert isinstance(subspace, SingleObjectiveTrustRegionBox)
@@ -1447,9 +1489,11 @@ def __init__(
         beta: float = 0.7,
         kappa: float = 1e-4,
         min_eps: float = 1e-2,
+        init_eps: float = 0.07,
     ):
         super().__init__(global_search_space, beta, kappa, min_eps)
         self._location = fixed_location
+        self._init_eps_val = init_eps

     @property
     def location(self) -> TensorType:
@@ -1460,7 +1504,7 @@ def location(self, location: TensorType) -> None:
         ...

     def _init_eps(self) -> None:
-        self.eps = tf.constant(0.07, dtype=tf.float64)
+        self.eps = tf.constant(self._init_eps_val, dtype=tf.float64)

 # Start with a defined state and dataset. Acquire should return an updated state.
@@ -1500,13 +1544,13 @@ def test_multi_trust_region_box_acquire_with_state() -> None:
     next_state, points = state_func(state)
     assert next_state is not None

-    assert len(points) == 3
+    assert points.shape == [1, 3, 2]
     # The regions correspond to first, third and first points in the dataset.
     # First two regions should be updated.
     # The third region should be initialized and not updated, as it is too close to the first
     # subspace.
     for point, subspace, exp_obs, exp_eps in zip(
-        points,
+        points[0],
         subspaces,
         [dataset.observations[0], dataset.observations[2], dataset.observations[0]],
         [0.1, 0.1, 0.07],  # First two regions updated, third region initialized.
@@ -1517,13 +1561,212 @@ def test_multi_trust_region_box_acquire_with_state() -> None:
         npt.assert_allclose(subspace.eps, exp_eps)

+# Test case with multiple local models and multiple regions for batch trust regions.
+# It checks that the correct model is passed to each region, and that the correct dataset is
+# passed to each instance of the base rule (note: the base rule is deep-copied for each region).
+# This is done by mapping each region to a model. For each region the model has a local quadratic
+# shape with the minimum at the center of the region. The overall model is created by taking a
+# product of the local quadratics over all of its regions. The expected end result is that each
+# region should find its center after optimization. If a region uses the wrong model, it would
+# instead find one of its boundaries.
+# Note that the implementation of this test is more general than strictly required. It can support
+# fewer models than regions (as long as the number of regions is a multiple of the number of
+# models).
However, currently trieste only supports either a global model or a one to one mapping +# between models and regions. +@pytest.mark.parametrize("use_global_model", [True, False]) +@pytest.mark.parametrize("use_global_dataset", [True, False]) +@pytest.mark.parametrize("num_regions", [2, 4]) +@pytest.mark.parametrize("num_query_points_per_region", [1, 2]) +def test_multi_trust_region_box_with_multiple_models_and_regions( + use_global_model: bool, + use_global_dataset: bool, + num_regions: int, + num_query_points_per_region: int, +) -> None: + search_space = Box([0.0, 0.0], [6.0, 6.0]) + base_shift = tf.constant([2.0, 2.0], dtype=tf.float64) # Common base shift for all regions. + eps = 0.9 + subspaces = [ + TestTrustRegionBox(base_shift + i, search_space, init_eps=eps) for i in range(num_regions) + ] + + # Define the models and acquisition functions for each region + noise_variance = tf.constant(1e-6, dtype=tf.float64) + kernel_variance = tf.constant(1e-3, dtype=tf.float64) + + global_dataset = Dataset( + tf.constant([[0.0, 0.0]], dtype=tf.float64), + tf.constant([[1.0]], dtype=tf.float64), + ) + init_datasets = {OBJECTIVE: global_dataset} + models = {} + r = range(1) if use_global_model else range(num_regions) + for i in r: + if use_global_model: + tag = OBJECTIVE + num_models = 1 + else: + tag = LocalizedTag(OBJECTIVE, i) + num_models = num_regions + + num_regions_per_model = num_regions // num_models + query_points = tf.stack([base_shift + j for j in range(i, num_regions, num_models)]) + observations = tf.constant([0.0] * num_regions_per_model, dtype=tf.float64)[:, None] + + if not use_global_dataset: + init_datasets[tag] = Dataset(query_points, observations) + + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(kernel_variance) + + # Overall mean function is a product of local mean functions. + def mean_function(x: TensorType, i: int = i) -> TensorType: + return tf.reduce_prod( + tf.stack( + [ + quadratic(x - tf.cast(base_shift + j, dtype=x.dtype)) + for j in range(i, num_regions, num_models) + ] + ), + axis=0, + ) + + models[tag] = GaussianProcess([mean_function], [kernel], noise_variance) + models[tag]._exp_dataset = ( # type: ignore[attr-defined] + global_dataset if use_global_dataset else init_datasets[tag] + ) + + if use_global_model: + # Global model; acquire in parallel. + num_query_points = num_regions * num_query_points_per_region + else: + # Local models; acquire sequentially. + num_query_points = num_query_points_per_region + + class TestMultipleOptimismNegativeLowerConfidenceBound( + MultipleOptimismNegativeLowerConfidenceBound + ): + # Override the prepare_acquisition_function method to check that the dataset is correct. + def prepare_acquisition_function( + self, + model: ProbabilisticModel, + dataset: Optional[Dataset] = None, + ) -> AcquisitionFunction: + assert dataset is model._exp_dataset # type: ignore[attr-defined] + return super().prepare_acquisition_function(model, dataset) + + base_rule = EfficientGlobalOptimization( # type: ignore[var-annotated] + builder=TestMultipleOptimismNegativeLowerConfidenceBound(search_space), + num_query_points=num_query_points, + ) + + mtb = BatchTrustRegionBox(subspaces, base_rule) + _, points = mtb.acquire(search_space, models, init_datasets)(None) + + npt.assert_array_equal(points.shape, [num_query_points_per_region, num_regions, 2]) + + # Each region should find the minimum of its local model, which will be the center of + # the region. 
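+    # (exp_points is tiled to [num_query_points_per_region, num_regions, 2] to match the
+    # rank-3 batch shape asserted above.)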
+ exp_points = tf.stack([base_shift + i for i in range(num_regions)]) + exp_points = tf.tile(exp_points[None, :, :], [num_query_points_per_region, 1, 1]) + npt.assert_allclose(points, exp_points) + + +# This test ensures that the datasets for each region are updated correctly. The datasets should +# contain filtered data, i.e. only points in the respective regions. +@pytest.mark.parametrize( + "datasets, exp_num_init_points", + [ + ({OBJECTIVE: mk_dataset([[0.0], [1.0], [2.0]], [[1.0], [1.0], [1.0]])}, 1), + ( + { + OBJECTIVE: mk_dataset( + [[0.0], [1.0], [0.3], [2.0], [0.7], [1.7]], + [[1.0], [1.0], [1.0], [1.0], [1.0], [1.0]], + ) + }, + 2, + ), + ( + { + OBJECTIVE: mk_dataset([[-1.0]], [[-1.0]]), # Should be ignored. + LocalizedTag(OBJECTIVE, 0): mk_dataset([[0.0]], [[1.0]]), + LocalizedTag(OBJECTIVE, 1): mk_dataset([[1.0]], [[1.0]]), + LocalizedTag(OBJECTIVE, 2): mk_dataset([[2.0]], [[1.0]]), + }, + 1, + ), + ( + { + OBJECTIVE: mk_dataset([[-1.0]], [[-1.0]]), # Should be ignored. + LocalizedTag(OBJECTIVE, 0): mk_dataset([[0.0], [1.0]], [[1.0], [1.0]]), + LocalizedTag(OBJECTIVE, 1): mk_dataset([[2.0], [1.0]], [[1.0], [1.0]]), + LocalizedTag(OBJECTIVE, 2): mk_dataset([[2.0], [3.0]], [[1.0], [1.0]]), + }, + 1, + ), + ], +) +@pytest.mark.parametrize("num_query_points_per_region", [1, 2]) +def test_multi_trust_region_box_updated_datasets_are_in_regions( + datasets: Mapping[Tag, Dataset], exp_num_init_points: int, num_query_points_per_region: int +) -> None: + num_local_models = 3 + search_space = Box([0.0], [3.0]) + # Non-overlapping regions. + subspaces = [ + TestTrustRegionBox(tf.constant([i], dtype=tf.float64), search_space, init_eps=0.4) + for i in range(num_local_models) + ] + models = copy_to_local_models(QuadraticMeanAndRBFKernel(), num_local_models) + base_rule = EfficientGlobalOptimization( # type: ignore[var-annotated] + builder=MultipleOptimismNegativeLowerConfidenceBound(search_space), + num_query_points=num_query_points_per_region, + ) + rule = BatchTrustRegionBox(subspaces, base_rule) + _, points = rule.acquire(search_space, models, datasets)(None) + observer = mk_batch_observer(quadratic) + new_data = observer(points) + assert not isinstance(new_data, Dataset) + + updated_datasets = {} + for tag in new_data: + _, dataset = get_value_for_tag(datasets, *[tag, LocalizedTag.from_tag(tag).global_tag]) + assert dataset is not None + updated_datasets[tag] = dataset + new_data[tag] + datasets = rule.filter_datasets(updated_datasets) + + # Check local datasets. + for i, subspace in enumerate(subspaces): + assert ( + datasets[LocalizedTag(OBJECTIVE, i)].query_points.shape[0] + == exp_num_init_points + num_query_points_per_region + ) + assert np.all(subspace.contains(datasets[LocalizedTag(OBJECTIVE, i)].query_points)) + + # Check global dataset. + assert datasets[OBJECTIVE].query_points.shape[0] == num_local_models * ( + exp_num_init_points + num_query_points_per_region + ) + # Each point should be in at least one region. + for point in datasets[OBJECTIVE].query_points: + assert any(subspace.contains(point) for subspace in subspaces) + # Global dataset should be the concatenation of all local datasets. 
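+    # (filter_datasets rebuilds the global dataset by concatenating the filtered local
+    # datasets, so the expected order follows the local-tag indices.)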
+ exp_query_points = tf.concat( + [datasets[LocalizedTag(OBJECTIVE, i)].query_points for i in range(num_local_models)], axis=0 + ) + npt.assert_array_almost_equal(datasets[OBJECTIVE].query_points, exp_query_points) + + def test_multi_trust_region_box_state_deepcopy() -> None: search_space = Box([0.0, 0.0], [1.0, 1.0]) dataset = Dataset( tf.constant([[0.25, 0.25], [0.5, 0.5], [0.75, 0.75]], dtype=tf.float64), tf.constant([[1.0], [1.0], [1.0]], dtype=tf.float64), ) - subspaces = [SingleObjectiveTrustRegionBox(search_space, 0.07, 1e-5, 1e-3) for _ in range(3)] + subspaces = [ + SingleObjectiveTrustRegionBox(search_space, beta=0.07, kappa=1e-5, min_eps=1e-3) + for _ in range(3) + ] for _subspace in subspaces: _subspace.initialize(datasets={OBJECTIVE: dataset}) state = BatchTrustRegionBox.State(acquisition_space=TaggedMultiSearchSpace(subspaces)) diff --git a/tests/unit/acquisition/test_utils.py b/tests/unit/acquisition/test_utils.py index 7975cab3a0..9e578587ca 100644 --- a/tests/unit/acquisition/test_utils.py +++ b/tests/unit/acquisition/test_utils.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import Any +from typing import Any, Optional from unittest.mock import MagicMock import numpy as np @@ -22,6 +22,7 @@ from trieste.acquisition import AcquisitionFunction from trieste.acquisition.utils import ( + copy_to_local_models, get_local_dataset, get_unique_points_mask, select_nth_output, @@ -29,6 +30,8 @@ ) from trieste.data import Dataset from trieste.space import Box, SearchSpaceType +from trieste.types import Tag +from trieste.utils.misc import LocalizedTag @pytest.mark.parametrize( @@ -100,6 +103,18 @@ def test_get_local_dataset_works() -> None: assert tf.shape(get_local_dataset(search_space_2, combined).query_points)[0] == 20 +@pytest.mark.parametrize("num_local_models", [1, 3]) +@pytest.mark.parametrize("key", [None, "a"]) +def test_copy_to_local_models(num_local_models: int, key: Optional[Tag]) -> None: + global_model = MagicMock() + local_models = copy_to_local_models(global_model, num_local_models=num_local_models, key=key) + assert len(local_models) == num_local_models + for i, (k, m) in enumerate(local_models.items()): + assert k == LocalizedTag(key, i) + assert isinstance(m, MagicMock) + assert m is not global_model + + @pytest.mark.parametrize( "points, tolerance, expected_mask", [ diff --git a/tests/unit/objectives/test_utils.py b/tests/unit/objectives/test_utils.py index 52cbbdb827..d1832b9443 100644 --- a/tests/unit/objectives/test_utils.py +++ b/tests/unit/objectives/test_utils.py @@ -11,10 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Callable, Mapping, Set, Union + import numpy.testing as npt +import pytest import tensorflow as tf -from trieste.objectives.utils import mk_multi_observer, mk_observer +from trieste.data import Dataset +from trieste.objectives.utils import mk_batch_observer, mk_multi_observer, mk_observer +from trieste.observer import Observer +from trieste.types import Tag, TensorType +from trieste.utils.misc import LocalizedTag def test_mk_observer() -> None: @@ -49,3 +56,49 @@ def test_mk_multi_observer() -> None: npt.assert_array_equal(ys["foo"].observations, x_ + 1) npt.assert_array_equal(ys["bar"].query_points, x_) npt.assert_array_equal(ys["bar"].observations, x_ - 1) + + +@pytest.mark.parametrize( + "input_objective, exp_o_call", + [ + (lambda x: x, {"baz": lambda x: x}), + (lambda x: Dataset(x, x), {"baz": lambda x: x}), + ( + mk_multi_observer(foo=lambda x: x + 1, bar=lambda x: x - 1), + {"foo": lambda x: x + 1, "bar": lambda x: x - 1}, + ), + ], +) +@pytest.mark.parametrize("batch_size", [1, 2, 3]) +@pytest.mark.parametrize("num_query_points_per_batch", [1, 2]) +def test_mk_batch_observer( + input_objective: Union[Callable[[TensorType], TensorType], Observer], + exp_o_call: Mapping[Tag, Callable[[TensorType], TensorType]], + batch_size: int, + num_query_points_per_batch: int, +) -> None: + x_ = tf.reshape( + tf.constant(range(batch_size * num_query_points_per_batch), tf.float64), + (num_query_points_per_batch, batch_size, 1), + ) + ys = mk_batch_observer(input_objective, "baz")(x_) + + assert isinstance(ys, dict) + + # Check keys. + exp_keys: Set[Union[Tag, LocalizedTag]] = set() + for key in exp_o_call: + exp_keys.update({LocalizedTag(key, i) for i in range(batch_size)}) + exp_keys.add(key) + assert ys.keys() == exp_keys + + # Check datasets. + for key, call in exp_o_call.items(): + # Get expected observations. 
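+        # (x_ is rank 3 with shape [num_query_points_per_batch, batch_size, 1]; the global-tag
+        # dataset is flattened to rank 2, while LocalizedTag(key, i) holds batch element i.)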
+ exp_o = call(x_) + + npt.assert_array_equal(ys[key].query_points, tf.reshape(x_, [-1, 1])) + npt.assert_array_equal(ys[key].observations, tf.reshape(exp_o, [-1, 1])) + for i in range(batch_size): + npt.assert_array_equal(ys[LocalizedTag(key, i)].query_points, x_[:, i]) + npt.assert_array_equal(ys[LocalizedTag(key, i)].observations, exp_o[:, i]) diff --git a/tests/unit/test_ask_tell_optimization.py b/tests/unit/test_ask_tell_optimization.py index 7c638f2baf..973f9ce352 100644 --- a/tests/unit/test_ask_tell_optimization.py +++ b/tests/unit/test_ask_tell_optimization.py @@ -15,19 +15,28 @@ from typing import Mapping, Optional +import numpy.testing as npt import pytest import tensorflow as tf from tests.util.misc import FixedAcquisitionRule, assert_datasets_allclose, mk_dataset -from tests.util.models.gpflow.models import GaussianProcess, PseudoTrainableProbModel, rbf +from tests.util.models.gpflow.models import ( + GaussianProcess, + PseudoTrainableProbModel, + QuadraticMeanAndRBFKernel, + rbf, +) from trieste.acquisition.rule import AcquisitionRule +from trieste.acquisition.utils import copy_to_local_models from trieste.ask_tell_optimization import AskTellOptimizer from trieste.bayesian_optimizer import OptimizationResult, Record from trieste.data import Dataset from trieste.models.interfaces import ProbabilisticModel, TrainableProbabilisticModel +from trieste.objectives.utils import mk_batch_observer from trieste.observer import OBJECTIVE from trieste.space import Box from trieste.types import State, Tag, TensorType +from trieste.utils.misc import LocalizedTag # tags TAG1: Tag = "1" @@ -427,3 +436,85 @@ def __deepcopy__(self, memo: dict[int, object]) -> _UncopyableModel: with pytest.raises(NotImplementedError): ask_tell.to_result() assert ask_tell.to_result(copy=False).final_result.is_ok + + +class DatasetChecker(QuadraticMeanAndRBFKernel, PseudoTrainableProbModel): + def __init__( + self, + use_global_model: bool, + use_global_init_dataset: bool, + init_data: Mapping[Tag, Dataset], + query_points: TensorType, + ) -> None: + super().__init__() + self.update_count = 0 + self._tag = OBJECTIVE + self.use_global_model = use_global_model + self.use_global_init_dataset = use_global_init_dataset + self.init_data = init_data + self.query_points = query_points + + def update(self, dataset: Dataset) -> None: + if self.use_global_model: + exp_init_qps = self.init_data[OBJECTIVE].query_points + else: + if self.use_global_init_dataset: + exp_init_qps = self.init_data[OBJECTIVE].query_points + else: + exp_init_qps = self.init_data[self._tag].query_points + + if self.update_count == 0: + # Initial model training. + exp_qps = exp_init_qps + else: + # Subsequent model training. + if self.use_global_model: + exp_qps = tf.concat([exp_init_qps, tf.reshape(self.query_points, [-1, 1])], 0) + else: + index = LocalizedTag.from_tag(self._tag).local_index + exp_qps = tf.concat([exp_init_qps, self.query_points[:, index]], 0) + + npt.assert_array_equal(exp_qps, dataset.query_points) + self.update_count += 1 + + +# Check that the correct dataset is routed to the model. +# Note: this test is almost identical to the one in test_bayesian_optimizer.py. 
+@pytest.mark.parametrize("use_global_model", [True, False]) +@pytest.mark.parametrize("use_global_init_dataset", [True, False]) +@pytest.mark.parametrize("num_query_points_per_batch", [1, 2]) +def test_ask_tell_optimizer_creates_correct_datasets_for_rank3_points( + use_global_model: bool, use_global_init_dataset: bool, num_query_points_per_batch: int +) -> None: + batch_size = 4 + if use_global_init_dataset: + init_data = {OBJECTIVE: mk_dataset([[0.5], [1.5]], [[0.25], [0.35]])} + else: + init_data = { + LocalizedTag(OBJECTIVE, i): mk_dataset([[0.5 + i], [1.5 + i]], [[0.25], [0.35]]) + for i in range(batch_size) + } + init_data[OBJECTIVE] = mk_dataset([[0.5], [1.5]], [[0.25], [0.35]]) + + query_points = tf.reshape( + tf.constant(range(batch_size * num_query_points_per_batch), tf.float64), + (num_query_points_per_batch, batch_size, 1), + ) + + search_space = Box([-1], [1]) + + model = DatasetChecker(use_global_model, use_global_init_dataset, init_data, query_points) + if use_global_model: + models = {OBJECTIVE: model} + else: + models = copy_to_local_models(model, batch_size) # type: ignore[assignment] + for tag, model in models.items(): + model._tag = tag + + observer = mk_batch_observer(lambda x: Dataset(x, x)) + rule = FixedAcquisitionRule(query_points) + ask_tell = AskTellOptimizer(search_space, init_data, models, rule) + + points = ask_tell.ask() + new_data = observer(points) + ask_tell.tell(new_data) diff --git a/tests/unit/test_bayesian_optimizer.py b/tests/unit/test_bayesian_optimizer.py index 99017dff90..cca9d3b952 100644 --- a/tests/unit/test_bayesian_optimizer.py +++ b/tests/unit/test_bayesian_optimizer.py @@ -23,6 +23,7 @@ import tensorflow as tf from check_shapes import inherit_check_shapes +from tests.unit.test_ask_tell_optimization import DatasetChecker from tests.util.misc import ( FixedAcquisitionRule, assert_datasets_allclose, @@ -38,6 +39,7 @@ rbf, ) from trieste.acquisition.rule import AcquisitionRule +from trieste.acquisition.utils import copy_to_local_models from trieste.bayesian_optimizer import BayesianOptimizer, FrozenRecord, OptimizationResult, Record from trieste.data import Dataset from trieste.models import ProbabilisticModel, TrainableProbabilisticModel @@ -45,6 +47,7 @@ from trieste.space import Box, SearchSpace from trieste.types import State, Tag, TensorType from trieste.utils import Err, Ok +from trieste.utils.misc import LocalizedTag # tags FOO: Tag = "foo" @@ -236,6 +239,44 @@ def __call__(self, x: tf.Tensor) -> Dataset: assert observer.call_count == steps +# Check that the correct dataset is routed to the model. +# Note: this test is almost identical to the one in test_ask_tell_optimization.py. 
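+# (DatasetChecker is imported from that module; each local model copy is tagged below so its
+# update() can assert against the matching local dataset.)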
+@pytest.mark.parametrize("use_global_model", [True, False]) +@pytest.mark.parametrize("use_global_init_dataset", [True, False]) +@pytest.mark.parametrize("num_query_points_per_batch", [1, 2]) +def test_bayesian_optimizer_creates_correct_datasets_for_rank3_points( + use_global_model: bool, use_global_init_dataset: bool, num_query_points_per_batch: int +) -> None: + batch_size = 4 + if use_global_init_dataset: + init_data = {OBJECTIVE: mk_dataset([[0.5], [1.5]], [[0.25], [0.35]])} + else: + init_data = { + LocalizedTag(OBJECTIVE, i): mk_dataset([[0.5 + i], [1.5 + i]], [[0.25], [0.35]]) + for i in range(batch_size) + } + init_data[OBJECTIVE] = mk_dataset([[0.5], [1.5]], [[0.25], [0.35]]) + + query_points = tf.reshape( + tf.constant(range(batch_size * num_query_points_per_batch), tf.float64), + (num_query_points_per_batch, batch_size, 1), + ) + + search_space = Box([-1], [1]) + + model = DatasetChecker(use_global_model, use_global_init_dataset, init_data, query_points) + if use_global_model: + models = {OBJECTIVE: model} + else: + models = copy_to_local_models(model, batch_size) # type: ignore[assignment] + for tag, model in models.items(): + model._tag = tag + + optimizer = BayesianOptimizer(lambda x: Dataset(x, x), search_space) + rule = FixedAcquisitionRule(query_points) + optimizer.optimize(1, init_data, models, rule).final_result.unwrap() + + @pytest.mark.parametrize("mode", ["early", "fail", "full"]) def test_bayesian_optimizer_continue_optimization(mode: str) -> None: class _CountingObserver: diff --git a/tests/unit/utils/test_misc.py b/tests/unit/utils/test_misc.py index ac6dd642eb..528f7aabf6 100644 --- a/tests/unit/utils/test_misc.py +++ b/tests/unit/utils/test_misc.py @@ -14,7 +14,7 @@ from __future__ import annotations from time import sleep -from typing import Any +from typing import Any, Optional, Union import numpy as np import numpy.testing as npt @@ -23,9 +23,10 @@ from tests.util.misc import TF_DEBUGGING_ERROR_TYPES, ShapeLike, various_shapes from trieste.observer import OBJECTIVE -from trieste.types import TensorType +from trieste.types import Tag, TensorType from trieste.utils.misc import ( Err, + LocalizedTag, Ok, Timer, flatten_leading_dims, @@ -97,20 +98,56 @@ def test_err() -> None: def test_get_value_for_tag_returns_none_if_mapping_is_none() -> None: - assert get_value_for_tag(None) is None + assert get_value_for_tag(None) == (None, None) def test_get_value_for_tag_raises_if_tag_not_in_mapping() -> None: - with pytest.raises(ValueError, match="tag 'baz' not found in mapping"): + with pytest.raises(ValueError, match="none of the tags '.'baz',.' 
found in mapping"): get_value_for_tag({"foo": "bar"}, "baz") def test_get_value_for_tag_returns_value_for_default_tag() -> None: - assert get_value_for_tag({"foo": "bar", OBJECTIVE: "baz"}) == "baz" + assert get_value_for_tag({"foo": "bar", OBJECTIVE: "baz"}) == (OBJECTIVE, "baz") def test_get_value_for_tag_returns_value_for_specified_tag() -> None: - assert get_value_for_tag({"foo": "bar", OBJECTIVE: "baz"}, "foo") == "bar" + assert get_value_for_tag({"foo": "bar", OBJECTIVE: "baz"}, "foo") == ("foo", "bar") + + +def test_get_value_for_tag_returns_first_matching_tag() -> None: + assert get_value_for_tag( + {"foo": "bar", OBJECTIVE: "baz", "qux": "quux", "bar": "baz"}, *["far", "qux", "foo"] + ) == ("qux", "quux") + + +@pytest.mark.parametrize("tag_name", ["test_tag_1", "test_tag_2"]) +@pytest.mark.parametrize("tag_index", [0, 2, None]) +def test_localized_tag_creation(tag_name: str, tag_index: Optional[int]) -> None: + tag = LocalizedTag(tag_name, tag_index) + is_local = True if tag_index is not None else False + # Ensure a duplicate tag is equal. + tag2 = LocalizedTag(tag_name, tag_index) + + assert tag.is_local == is_local + assert tag.global_tag == tag_name + assert tag.local_index == tag_index + assert tag == tag2 + assert hash(tag) == hash(tag2) + assert repr(tag) == f"LocalizedTag(global_tag='{tag_name}', local_index={tag_index})" + + +@pytest.mark.parametrize( + "tag, exp_tag", + [ + ("test_tag_1", LocalizedTag("test_tag_1", None)), + (LocalizedTag("test_tag_1", 3), LocalizedTag("test_tag_1", 3)), + (LocalizedTag("test_tag", None), LocalizedTag("test_tag", None)), + ], +) +def test_localized_tag_from_tag(tag: Union[Tag, LocalizedTag], exp_tag: LocalizedTag) -> None: + ltag = LocalizedTag.from_tag(tag) + assert ltag.global_tag == exp_tag.global_tag + assert ltag.local_index == exp_tag.local_index def test_Timer() -> None: diff --git a/trieste/acquisition/rule.py b/trieste/acquisition/rule.py index 50acaa21b2..dc172ab2a6 100644 --- a/trieste/acquisition/rule.py +++ b/trieste/acquisition/rule.py @@ -20,9 +20,22 @@ import copy import math from abc import ABC, abstractmethod +from collections import Counter from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Callable, Generic, Optional, Sequence, Tuple, TypeVar, Union, cast, overload +from typing import ( + Any, + Callable, + Generic, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, + cast, + overload, +) import numpy as np from check_shapes import check_shapes, inherit_check_shapes @@ -50,7 +63,7 @@ from ..observer import OBJECTIVE from ..space import Box, SearchSpace, TaggedMultiSearchSpace from ..types import State, Tag, TensorType -from ..utils.misc import get_value_for_tag +from ..utils.misc import LocalizedTag from .function import ( BatchMonteCarloExpectedImprovement, ExpectedImprovement, @@ -81,6 +94,9 @@ SearchSpaceType = TypeVar("SearchSpaceType", bound=SearchSpace, contravariant=True) """ Contravariant type variable bound to :class:`~trieste.space.SearchSpace`. """ +T = TypeVar("T") +""" Unbound type variable. """ + class AcquisitionRule(ABC, Generic[ResultType, SearchSpaceType, ProbabilisticModelType]): """ @@ -148,6 +164,17 @@ def acquire_single( datasets=None if dataset is None else {OBJECTIVE: dataset}, ) + def filter_datasets(self, datasets: Mapping[Tag, Dataset]) -> Mapping[Tag, Dataset]: + """ + Filter the post-acquisition datasets before they are used for model training. 
For example, + this can be used to remove points from the datasets that are no longer in the search space. + + :param datasets: The datasets to filter. + :return: The filtered datasets. + """ + # No filtering by default. + return datasets + class EfficientGlobalOptimization( AcquisitionRule[TensorType, SearchSpaceType, ProbabilisticModelType] @@ -944,6 +971,14 @@ def acquire( class UpdatableTrustRegion(SearchSpace): """A search space that can be updated.""" + def __init__(self, region_index: Optional[int] = None) -> None: + """ + :param region_index: The index of the region in a multi-region search space. This is used to + identify the local models and datasets to use for acquisition. If `None`, the + global models and datasets are used. + """ + self.region_index = region_index + @abstractmethod def initialize( self, @@ -972,6 +1007,78 @@ def update( """ ... + def _get_tags(self, tags: Set[Tag]) -> Tuple[Set[Tag], Set[Tag]]: + # Separate tags into local (matching index) and global tags (without matching + # local tag). + local_gtags = set() # Set of global part of all local tags. + global_tags = set() # Set of all global tags. + for tag in tags: + ltag = LocalizedTag.from_tag(tag) + if not ltag.is_local: + global_tags.add(tag) + elif ltag.local_index == self.region_index: + local_gtags.add(ltag.global_tag) + + # Only keep global tags that don't have a matching local tag. + global_tags -= local_gtags + + return local_gtags, global_tags + + def select_in_region(self, mapping: Optional[Mapping[Tag, T]]) -> Optional[Mapping[Tag, T]]: + """ + Select items belonging to this region for acquisition. + + :param mapping: The mapping of items for each tag. + :return: The items belonging to this region (or `None` if there aren't any). + """ + if mapping is None: + _mapping = {} + elif self.region_index is None: + # If no index, then return the global items. + _mapping = { + tag: item + for tag, item in mapping.items() + if not LocalizedTag.from_tag(tag).is_local + } + else: + # Prefer matching local item for each tag, otherwise select the global item. + local_gtags, global_tags = self._get_tags(set(mapping)) + + _mapping = {} + for tag in local_gtags: + ltag = LocalizedTag(tag, self.region_index) + _mapping[ltag] = mapping[ltag] + for tag in global_tags: + _mapping[tag] = mapping[tag] + + return _mapping if _mapping else None + + def get_datasets_filter_mask( + self, datasets: Optional[Mapping[Tag, Dataset]] + ) -> Optional[Mapping[Tag, tf.Tensor]]: + """ + Return a boolean mask that can be used to filter out points from the datasets that + belong to this region. + + :param datasets: The dataset for each tag. + :return: A mapping for each tag belonging to this region, to a boolean mask that can be + used to filter out points from the datasets. A value of `True` indicates that the + corresponding point should be kept. + """ + # Only select the region datasets for filtering. Don't directly filter the global dataset. + assert ( + self.region_index is not None + ), "the region_index should be set for filtering local datasets" + if datasets is None: + return None + else: + # Only keep points that are in the region. + return { + tag: self.contains(dataset.query_points) + for tag, dataset in datasets.items() + if LocalizedTag.from_tag(tag).local_index == self.region_index + } + UpdatableTrustRegionType = TypeVar("UpdatableTrustRegionType", bound=UpdatableTrustRegion) """ A type variable bound to :class:`UpdatableTrustRegion`. 
""" @@ -1024,9 +1131,16 @@ def __init__( if not isinstance(init_subspaces, Sequence): init_subspaces = [init_subspaces] self._init_subspaces = tuple(init_subspaces) + for index, subspace in enumerate(self._init_subspaces): + subspace.region_index = index # Override the index. self._tags = tuple([str(index) for index in range(len(init_subspaces))]) self._rule = rule + # The rules for each subspace. These are only used when we have local models to run the + # base rule sequentially for each subspace. Theses are set in `acquire`. + self._rules: Optional[ + Sequence[AcquisitionRule[TensorType, SearchSpace, ProbabilisticModelType]] + ] = None def __repr__(self) -> str: """""" @@ -1055,10 +1169,38 @@ def acquire( points from the previous acquisition state. """ + # Subspaces should be set by the time we call `acquire`. + assert self._tags is not None + assert self._init_subspaces is not None + + num_local_models = Counter( + LocalizedTag.from_tag(tag).global_tag + for tag in models + if LocalizedTag.from_tag(tag).is_local + ) + num_local_models_vals = set(num_local_models.values()) + assert ( + len(num_local_models_vals) <= 1 + ), f"The number of local models should be the same for all tags, got {num_local_models}" + _num_local_models = sum(num_local_models_vals) + + num_subspaces = len(self._tags) + assert _num_local_models in [0, num_subspaces], ( + f"When using local models, the number of subspaces {num_subspaces} should be equal to " + f"the number of local models {_num_local_models}" + ) + + # If we have local models, run the (deepcopied) base rule sequentially for each subspace. + # Otherwise, run the base rule as is, once with all models and datasets. + # Note: this should only trigger on the first call to `acquire`, as after that we will + # have a list of rules in `self._rules`. + if _num_local_models > 0 and self._rules is None: + self._rules = [copy.deepcopy(self._rule) for _ in range(num_subspaces)] + def state_func( state: BatchTrustRegion.State | None, ) -> Tuple[BatchTrustRegion.State | None, TensorType]: - # Subspaces should be set by the time we call `acquire`. + # Check again to keep mypy happy. assert self._tags is not None assert self._init_subspaces is not None @@ -1092,9 +1234,32 @@ def state_func( acquisition_space = state.acquisition_space state_ = BatchTrustRegion.State(acquisition_space) - points = self._rule.acquire(acquisition_space, models, datasets=datasets) - return state_, points + # If the base rule is a sequence, run it sequentially for each subspace. + # See earlier comments. + if self._rules is not None: + _points = [] + for subspace, rule in zip(subspaces, self._rules): + _models = subspace.select_in_region(models) + _datasets = subspace.select_in_region(datasets) + assert _models is not None + # Remap all local tags to global ones. One reason is that single model + # acquisition builders expect OBJECTIVE to exist. + _models = { + LocalizedTag.from_tag(tag).global_tag: model + for tag, model in _models.items() + } + if _datasets is not None: + _datasets = { + LocalizedTag.from_tag(tag).global_tag: dataset + for tag, dataset in _datasets.items() + } + _points.append(rule.acquire(subspace, _models, _datasets)) + points = tf.stack(_points, axis=1) + else: + points = self._rule.acquire(acquisition_space, models, datasets) + + return state_, tf.reshape(points, [-1, len(subspaces), points.shape[-1]]) return state_func @@ -1143,6 +1308,53 @@ def get_initialize_subspaces_mask( """ ... 
+ def filter_datasets(self, datasets: Mapping[Tag, Dataset]) -> Mapping[Tag, Dataset]: + # Filter out points that are not in any of the subspaces. This is done by creating a mask + # for each local dataset that is True for points that are in any subspace. + used_masks = { + tag: tf.zeros(dataset.query_points.shape[:-1], dtype=tf.bool) + for tag, dataset in datasets.items() + if LocalizedTag.from_tag(tag).is_local + } + + # Global datasets to re-generate. + global_tags = {LocalizedTag.from_tag(tag).global_tag for tag in used_masks} + + # Using init_subspaces here relies on the users not creating new subspaces after + # initialization. This is a reasonable assumption for now. + assert self._init_subspaces is not None + for subspace in self._init_subspaces: + in_region_masks = subspace.get_datasets_filter_mask(datasets) + if in_region_masks is not None: + for tag, in_region in in_region_masks.items(): + ltag = LocalizedTag.from_tag(tag) + assert ltag.is_local, f"can only filter local tags, got {tag}" + used_masks[tag] = tf.logical_or(used_masks[tag], in_region) + + filtered_datasets = {} + for tag, used_mask in used_masks.items(): + filtered_datasets[tag] = Dataset( + tf.boolean_mask(datasets[tag].query_points, used_mask), + tf.boolean_mask(datasets[tag].observations, used_mask), + ) + + # Include global datasets. + for gtag in global_tags: + # Create global dataset from local datasets. This is done by concatenating the local + # datasets. + local_datasets = [ + value + for tag, value in filtered_datasets.items() + if LocalizedTag.from_tag(tag).global_tag == gtag + ] + # Note there is no ordering assumption for the local datasets. They are simply + # concatenated and information about which local dataset they came from is lost. + qps = tf.concat([dataset.query_points for dataset in local_datasets], axis=0) + obs = tf.concat([dataset.observations for dataset in local_datasets], axis=0) + filtered_datasets[gtag] = Dataset(qps, obs) + + return filtered_datasets + class SingleObjectiveTrustRegionBox(Box, UpdatableTrustRegion): """An updatable box search space for use with trust region acquisition rules.""" @@ -1153,6 +1365,7 @@ def __init__( beta: float = 0.7, kappa: float = 1e-4, min_eps: float = 1e-2, + region_index: Optional[int] = None, ): """ Calculates the bounds of the box from the location/centre and global bounds. @@ -1163,6 +1376,9 @@ def __init__( considered a success. :param min_eps: The minimal size of the search space. If the size of the search space is smaller than this, the search space is reinitialized. + :param region_index: The index of the region in a multi-region search space. This is used to + identify the local models and datasets to use for acquisition. If `None`, the + global models and datasets are used. """ self._global_search_space = global_search_space @@ -1171,6 +1387,7 @@ def __init__( self._min_eps = min_eps super().__init__(global_search_space.lower, global_search_space.upper) + super(Box, self).__init__(region_index) @property def global_search_space(self) -> SearchSpace: @@ -1199,13 +1416,13 @@ def initialize( Initialize the box by sampling a location from the global search space and setting the bounds. 
""" - dataset = get_value_for_tag(datasets) + datasets = self.select_in_region(datasets) self.location = tf.squeeze(self.global_search_space.sample(1), axis=0) self._step_is_success = False self._init_eps() self._update_bounds() - _, self._y_min = self.get_local_min(dataset) + _, self._y_min = self.get_dataset_min(datasets) def update( self, @@ -1223,13 +1440,13 @@ def update( ``1 / beta``. Conversely, if it was unsuccessful, the size is reduced by the factor ``beta``. """ - dataset = get_value_for_tag(datasets) + datasets = self.select_in_region(datasets) if tf.reduce_any(self.eps < self._min_eps): self.initialize(models, datasets) return - x_min, y_min = self.get_local_min(dataset) + x_min, y_min = self.get_dataset_min(datasets) self.location = x_min tr_volume = tf.reduce_prod(self.upper - self.lower) @@ -1242,10 +1459,17 @@ def update( "return[0]: [D]", "return[1]: []", ) - def get_local_min(self, dataset: Optional[Dataset]) -> Tuple[TensorType, TensorType]: - """Calculate the local minimum of the box using the given dataset.""" - if dataset is None: - raise ValueError("""dataset must be provided""") + def get_dataset_min( + self, datasets: Optional[Mapping[Tag, Dataset]] + ) -> Tuple[TensorType, TensorType]: + """Calculate the minimum of the box using the given dataset.""" + if ( + datasets is None + or len(datasets) != 1 + or LocalizedTag.from_tag(next(iter(datasets))).global_tag != OBJECTIVE + ): + raise ValueError("""a single OBJECTIVE dataset must be provided""") + dataset = next(iter(datasets.values())) in_tr = self.contains(dataset.query_points) in_tr_obs = tf.where( @@ -1286,6 +1510,8 @@ def acquire( self._init_subspaces = tuple( [SingleObjectiveTrustRegionBox(search_space) for _ in range(num_query_points)] ) + for index, subspace in enumerate(self._init_subspaces): + subspace.region_index = index # Override the index. self._tags = tuple([str(index) for index in range(len(self._init_subspaces))]) # Ensure passed in global search space is always the same as the search space passed to @@ -1339,8 +1565,9 @@ def __init__( beta: float = 0.7, kappa: float = 1e-4, min_eps: float = 1e-2, + region_index: Optional[int] = None, ): - super().__init__(global_search_space, beta, kappa, min_eps) + super().__init__(global_search_space, beta, kappa, min_eps, region_index) self._is_global = False self._initialized = False @@ -1382,10 +1609,35 @@ def initialize( super().initialize(models, datasets) + def get_datasets_filter_mask( + self, datasets: Optional[Mapping[Tag, Dataset]] + ) -> Optional[Mapping[Tag, tf.Tensor]]: + # Only select the region datasets for filtering. Don't directly filter the global dataset. + assert ( + self.region_index is not None + ), "the region_index should be set for filtering local datasets" + if datasets is None: + return None + else: + # Don't filter out any points from the dataset. Always keep the entire dataset. 
+ return { + tag: tf.ones(tf.shape(dataset.query_points)[:-1], dtype=tf.bool) + for tag, dataset in datasets.items() + if LocalizedTag.from_tag(tag).local_index == self.region_index + } + @inherit_check_shapes - def get_local_min(self, dataset: Optional[Dataset]) -> Tuple[TensorType, TensorType]: - if dataset is None: - raise ValueError("""dataset must be provided""") + def get_dataset_min( + self, datasets: Optional[Mapping[Tag, Dataset]] + ) -> Tuple[TensorType, TensorType]: + """Calculate the minimum of the box using the given dataset.""" + if ( + datasets is None + or len(datasets) != 1 + or LocalizedTag.from_tag(next(iter(datasets))).global_tag != OBJECTIVE + ): + raise ValueError("""a single OBJECTIVE dataset must be provided""") + dataset = next(iter(datasets.values())) # Always return the global minimum. ix = tf.argmin(dataset.observations) diff --git a/trieste/acquisition/utils.py b/trieste/acquisition/utils.py index 8afe5d07eb..8a6780de49 100644 --- a/trieste/acquisition/utils.py +++ b/trieste/acquisition/utils.py @@ -11,15 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import functools -from typing import Tuple, Union +from typing import Mapping, Tuple, Union import tensorflow as tf from check_shapes import check_shapes from ..data import Dataset +from ..models import ProbabilisticModelType +from ..observer import OBJECTIVE from ..space import SearchSpaceType -from ..types import TensorType +from ..types import Tag, TensorType +from ..utils.misc import LocalizedTag from .interface import AcquisitionFunction from .optimizer import AcquisitionOptimizer @@ -139,6 +143,22 @@ def get_local_dataset(local_space: SearchSpaceType, dataset: Dataset) -> Dataset return local_dataset +def copy_to_local_models( + global_model: ProbabilisticModelType, + num_local_models: int, + key: Tag = OBJECTIVE, +) -> Mapping[Tag, ProbabilisticModelType]: + """ + Helper method to copy a global model to local models. + + :param global_model: The global model. + :param num_local_models: The number of local models to create. + :param key: The tag prefix for the local models. + :return: A mapping of the local models. + """ + return {LocalizedTag(key, i): copy.deepcopy(global_model) for i in range(num_local_models)} + + @check_shapes( "points: [n_points, ...]", "return: [n_points]", diff --git a/trieste/ask_tell_optimization.py b/trieste/ask_tell_optimization.py index b6530023b2..27a36c38d5 100644 --- a/trieste/ask_tell_optimization.py +++ b/trieste/ask_tell_optimization.py @@ -49,6 +49,7 @@ from .space import SearchSpace from .types import State, Tag, TensorType from .utils import Ok, Timer +from .utils.misc import LocalizedTag, get_value_for_tag, ignoring_local_tags StateType = TypeVar("StateType") """ Unbound type variable. """ @@ -188,18 +189,28 @@ def __init__( if not datasets or not models: raise ValueError("dicts of datasets and models must be populated.") + # Copy the dataset so we don't change the one provided by the user. 
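Usage of the new `copy_to_local_models` helper might look as follows. This is a sketch: `global_model` stands in for any trieste model, such as the `GaussianProcessRegression` built in the tests.

```python
from trieste.acquisition.utils import copy_to_local_models
from trieste.observer import OBJECTIVE
from trieste.utils.misc import LocalizedTag

# Three independent deep copies of one global model, keyed by local tags.
local_models = copy_to_local_models(global_model, num_local_models=3)
assert set(local_models) == {LocalizedTag(OBJECTIVE, i) for i in range(3)}
```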
+ datasets = deepcopy(datasets)
+
 if isinstance(datasets, Dataset):
 datasets = {OBJECTIVE: datasets}
- models = {OBJECTIVE: models} # type: ignore[dict-item]
+ if not isinstance(models, Mapping):
+ models = {OBJECTIVE: models}
+
+ self._filtered_datasets = datasets
 # reassure the type checker that everything is tagged
 datasets = cast(Dict[Tag, Dataset], datasets)
 models = cast(Dict[Tag, TrainableProbabilisticModelType], models)
- if datasets.keys() != models.keys():
+ # Get set of dataset and model keys, ignoring any local tag index. That is, only the
+ # global tag part is considered.
+ datasets_keys = {LocalizedTag.from_tag(tag).global_tag for tag in datasets.keys()}
+ models_keys = {LocalizedTag.from_tag(tag).global_tag for tag in models.keys()}
+ if datasets_keys != models_keys:
 raise ValueError(
- f"datasets and models should contain the same keys. Got {datasets.keys()} and"
- f" {models.keys()} respectively."
+ f"datasets and models should contain the same keys. Got {datasets_keys} and"
+ f" {models_keys} respectively."
 )
 self._datasets = datasets
@@ -233,7 +244,10 @@ def __init__(
 if fit_model:
 with Timer() as initial_model_fitting_timer:
 for tag, model in self._models.items():
- dataset = datasets[tag]
+ # Prefer local dataset if available.
+ tags = [tag, LocalizedTag.from_tag(tag).global_tag]
+ _, dataset = get_value_for_tag(datasets, *tags)
+ assert dataset is not None
 model.update(dataset)
 optimize_model_and_save_result(model, dataset)
@@ -258,10 +272,12 @@ def datasets(self) -> Mapping[Tag, Dataset]:
 @property
 def dataset(self) -> Dataset:
 """The current dataset when there is just one dataset."""
- if len(self.datasets) == 1:
- return next(iter(self.datasets.values()))
+ # Ignore local datasets.
+ datasets: Mapping[Tag, Dataset] = ignoring_local_tags(self.datasets)
+ if len(datasets) == 1:
+ return next(iter(datasets.values()))
 else:
- raise ValueError(f"Expected a single dataset, found {len(self.datasets)}")
+ raise ValueError(f"Expected a single dataset, found {len(datasets)}")
 @property
 def models(self) -> Mapping[Tag, TrainableProbabilisticModelType]:
@@ -281,10 +297,12 @@ def models(self, models: Mapping[Tag, TrainableProbabilisticModelType]) -> None:
 @property
 def model(self) -> TrainableProbabilisticModel:
 """The current model when there is just one model."""
- if len(self.models) == 1:
- return next(iter(self.models.values()))
+ # Ignore local models.
+ models: Mapping[Tag, TrainableProbabilisticModel] = ignoring_local_tags(self.models)
+ if len(models) == 1:
+ return next(iter(models.values()))
 else:
- raise ValueError(f"Expected a single model, found {len(self.models)}")
+ raise ValueError(f"Expected a single model, found {len(models)}")
 @model.setter
 def model(self, model: TrainableProbabilisticModelType) -> None:
@@ -392,7 +410,7 @@ def ask(self) -> TensorType:
 with Timer() as query_point_generation_timer:
 points_or_stateful = self._acquisition_rule.acquire(
- self._search_space, self._models, datasets=self._datasets
+ self._search_space, self._models, datasets=self._filtered_datasets
 )
 if callable(points_or_stateful):
@@ -423,18 +441,48 @@ def tell(self, new_data: Mapping[Tag, Dataset] | Dataset) -> None:
 if isinstance(new_data, Dataset):
 new_data = {OBJECTIVE: new_data}
- if self._datasets.keys() != new_data.keys():
+ # The new data must have the same keys as the existing datasets. The only exception is
+ # when the existing datasets are all global, in which case the datasets will be
+ # appropriately updated below for the next iteration.
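Because only the global portion of each tag is compared, a single global dataset can legitimately be paired with several local models. A small illustration using only the helpers introduced in this diff:

```python
from trieste.observer import OBJECTIVE
from trieste.utils.misc import LocalizedTag

dataset_tags = [OBJECTIVE]                                             # one global dataset
model_tags = [LocalizedTag(OBJECTIVE, 0), LocalizedTag(OBJECTIVE, 1)]  # local models

datasets_keys = {LocalizedTag.from_tag(t).global_tag for t in dataset_tags}
models_keys = {LocalizedTag.from_tag(t).global_tag for t in model_tags}
assert datasets_keys == models_keys == {OBJECTIVE}                     # the check passes
```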
+ datasets_indices = {LocalizedTag.from_tag(tag).local_index for tag in self._datasets.keys()}
+ if self._datasets.keys() != new_data.keys() and datasets_indices != {None}:
 raise ValueError(
 f"new_data keys {new_data.keys()} doesn't "
 f"match dataset keys {self._datasets.keys()}"
 )
- for tag in self._datasets:
- self._datasets[tag] += new_data[tag]
+ # In order to support local datasets, account for the case where there may be an initial
+ # dataset that is not tagged per region. In this case, only the global dataset will exist
+ # in datasets. We want to copy this initial dataset to all the regions.
+ # If a tag from new_data does not exist in datasets, then add it to
+ # datasets by copying the data from datasets with the same global tag. Otherwise keep the
+ # existing data from datasets.
+ #
+ # Note: this replication of initial data can potentially cause an issue when a global model
+ # is being used with local datasets, as the points may be repeated. This will only be an
+ # issue if two regions overlap and both contain that initial data-point -- as filtering
+ # (in BatchTrustRegion) would otherwise remove duplicates. The main way to avoid the issue
+ # in this scenario is to provide local initial datasets, instead of a global initial
+ # dataset.
+ sorted_tags = sorted( # We need to process the local tags first, then the global tags.
+ new_data, key=lambda tag: not LocalizedTag.from_tag(tag).is_local
+ )
+ for tag in sorted_tags:
+ new_dataset = new_data[tag]
+ if tag in self._datasets:
+ self._datasets[tag] += new_dataset
+ else:
+ global_tag = LocalizedTag.from_tag(tag).global_tag
+ if global_tag not in self._datasets:
+ raise ValueError(f"global tag '{global_tag}' not found in dataset")
+ self._datasets[tag] = self._datasets[global_tag] + new_dataset
+ self._filtered_datasets = self._acquisition_rule.filter_datasets(self._datasets)
 with Timer() as model_fitting_timer:
 for tag, model in self._models.items():
- dataset = self._datasets[tag]
+ # Always use the matching dataset to the model. If the model is
+ # local, then the dataset should be too by this stage.
+ dataset = self._filtered_datasets[tag]
 model.update(dataset)
 optimize_model_and_save_result(model, dataset)
diff --git a/trieste/bayesian_optimizer.py b/trieste/bayesian_optimizer.py
index be37acf834..a2e6b844b2 100644
--- a/trieste/bayesian_optimizer.py
+++ b/trieste/bayesian_optimizer.py
@@ -58,10 +58,12 @@
 from .acquisition.rule import TURBO, AcquisitionRule, EfficientGlobalOptimization
 from .data import Dataset
 from .models import SupportsCovarianceWithTopFidelity, TrainableProbabilisticModel
+from .objectives.utils import mk_batch_observer
 from .observer import OBJECTIVE, Observer
 from .space import SearchSpace
 from .types import State, Tag, TensorType
 from .utils import Err, Ok, Result, Timer
+from .utils.misc import LocalizedTag, get_value_for_tag, ignoring_local_tags
 StateType = TypeVar("StateType")
 """ Unbound type variable. """
@@ -97,18 +99,22 @@ class Record(Generic[StateType]):
 @property
 def dataset(self) -> Dataset:
 """The dataset when there is just one dataset."""
- if len(self.datasets) == 1:
- return next(iter(self.datasets.values()))
+ # Ignore local datasets.
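The `tell()` update above is order-sensitive: local tags are processed before global ones, so a region's first local dataset is seeded from the still-unextended global dataset. A sketch of those semantics with toy `Dataset`s (tensor values are illustrative):

```python
import tensorflow as tf
from trieste.data import Dataset
from trieste.observer import OBJECTIVE
from trieste.utils.misc import LocalizedTag

initial = Dataset(tf.zeros([2, 1], tf.float64), tf.zeros([2, 1], tf.float64))
new_point = Dataset(tf.ones([1, 1], tf.float64), tf.ones([1, 1], tf.float64))

datasets = {OBJECTIVE: initial}  # global-only initial data
new_data = {LocalizedTag(OBJECTIVE, 0): new_point, OBJECTIVE: new_point}

# Local tags first: is_local == True sorts before False.
for tag in sorted(new_data, key=lambda t: not LocalizedTag.from_tag(t).is_local):
    if tag in datasets:
        datasets[tag] += new_data[tag]
    else:  # seed the new local dataset from the not-yet-extended global one
        datasets[tag] = datasets[LocalizedTag.from_tag(tag).global_tag] + new_data[tag]

# The local dataset holds the 2 initial points plus its 1 new point.
assert datasets[LocalizedTag(OBJECTIVE, 0)].query_points.shape[0] == 3
```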
+ datasets: Mapping[Tag, Dataset] = ignoring_local_tags(self.datasets) + if len(datasets) == 1: + return next(iter(datasets.values())) else: - raise ValueError(f"Expected a single dataset, found {len(self.datasets)}") + raise ValueError(f"Expected a single dataset, found {len(datasets)}") @property def model(self) -> TrainableProbabilisticModel: """The model when there is just one dataset.""" - if len(self.models) == 1: - return next(iter(self.models.values())) + # Ignore local models. + models: Mapping[Tag, TrainableProbabilisticModel] = ignoring_local_tags(self.models) + if len(models) == 1: + return next(iter(models.values())) else: - raise ValueError(f"Expected a single model, found {len(self.models)}") + raise ValueError(f"Expected a single model, found {len(models)}") def save(self, path: Path | str) -> FrozenRecord[StateType]: """Save the record to disk. Will overwrite any existing file at the same path.""" @@ -227,6 +233,8 @@ def try_get_final_dataset(self) -> Dataset: :raise ValueError: If the optimization was not a single dataset run. """ datasets = self.try_get_final_datasets() + # Ignore local datasets. + datasets = ignoring_local_tags(datasets) if len(datasets) == 1: return next(iter(datasets.values())) else: @@ -270,6 +278,8 @@ def try_get_final_model(self) -> TrainableProbabilisticModel: :raise ValueError: If the optimization was not a single model run. """ models = self.try_get_final_models() + # Ignore local models. + models = ignoring_local_tags(models) if len(models) == 1: return next(iter(models.values())) else: @@ -626,10 +636,15 @@ def optimize( - ``datasets`` or ``models`` are empty - the default `acquisition_rule` is used and the tags are not `OBJECTIVE`. """ + # Copy the dataset so we don't change the one provided by the user. + datasets = copy.deepcopy(datasets) + if isinstance(datasets, Dataset): datasets = {OBJECTIVE: datasets} - models = {OBJECTIVE: models} # type: ignore[dict-item] + if not isinstance(models, Mapping): + models = {OBJECTIVE: models} + filtered_datasets = datasets # reassure the type checker that everything is tagged datasets = cast(Dict[Tag, Dataset], datasets) models = cast(Dict[Tag, TrainableProbabilisticModelType], models) @@ -637,10 +652,14 @@ def optimize( if num_steps < 0: raise ValueError(f"num_steps must be at least 0, got {num_steps}") - if datasets.keys() != models.keys(): + # Get set of dataset and model keys, ignoring any local tag index. That is, only the + # global tag part is considered. + datasets_keys = {LocalizedTag.from_tag(tag).global_tag for tag in datasets.keys()} + models_keys = {LocalizedTag.from_tag(tag).global_tag for tag in models.keys()} + if datasets_keys != models_keys: raise ValueError( - f"datasets and models should contain the same keys. Got {datasets.keys()} and" - f" {models.keys()} respectively." + f"datasets and models should contain the same keys. Got {datasets_keys} and" + f" {models_keys} respectively." ) if not datasets: @@ -718,7 +737,10 @@ def optimize( if step == 1 and fit_model and fit_initial_model: with Timer() as initial_model_fitting_timer: for tag, model in models.items(): - dataset = datasets[tag] + # Prefer local dataset if available. 
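The "prefer local dataset" lookups rely on the reworked `get_value_for_tag` accepting several tags: the model's own (possibly local) tag is tried first, then its global tag. A sketch, assuming only helpers from this diff (the string value stands in for a `Dataset`):

```python
from trieste.observer import OBJECTIVE
from trieste.utils.misc import LocalizedTag, get_value_for_tag

data = {OBJECTIVE: "global data"}  # no local datasets present
tag = LocalizedTag(OBJECTIVE, 2)   # the tag of some local model

# The local tag misses, so the lookup falls back to the global tag.
matched, value = get_value_for_tag(data, tag, LocalizedTag.from_tag(tag).global_tag)
assert matched == OBJECTIVE and value == "global data"
```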
+ tags = [tag, LocalizedTag.from_tag(tag).global_tag] + _, dataset = get_value_for_tag(datasets, *tags) + assert dataset is not None model.update(dataset) optimize_model_and_save_result(model, dataset) if summary_writer: @@ -732,14 +754,18 @@ def optimize( with Timer() as total_step_wallclock_timer: with Timer() as query_point_generation_timer: points_or_stateful = acquisition_rule.acquire( - self._search_space, models, datasets=datasets + self._search_space, models, datasets=filtered_datasets ) if callable(points_or_stateful): acquisition_state, query_points = points_or_stateful(acquisition_state) else: query_points = points_or_stateful - observer_output = self._observer(query_points) + observer = self._observer + # If query_points are rank 3, then use a batched observer. + if tf.rank(query_points) == 3: + observer = mk_batch_observer(observer) + observer_output = observer(query_points) tagged_output = ( observer_output @@ -747,11 +773,28 @@ def optimize( else {OBJECTIVE: observer_output} ) - datasets = {tag: datasets[tag] + tagged_output[tag] for tag in tagged_output} + # See explanation in ask_tell_optimization.tell(). + # We need to process the local tags first, then the global tags. + sorted_tags = sorted( + tagged_output, key=lambda tag: not LocalizedTag.from_tag(tag).is_local + ) + for tag in sorted_tags: + new_dataset = tagged_output[tag] + if tag in datasets: + datasets[tag] += new_dataset + else: + global_tag = LocalizedTag.from_tag(tag).global_tag + if global_tag not in datasets: + raise ValueError(f"global tag '{global_tag}' not found in dataset") + datasets[tag] = datasets[global_tag] + new_dataset + filtered_datasets = acquisition_rule.filter_datasets(datasets) + with Timer() as model_fitting_timer: if fit_model: for tag, model in models.items(): - dataset = datasets[tag] + # Always use the matching dataset to the model. If the model is + # local, then the dataset should be too by this stage. + dataset = filtered_datasets[tag] model.update(dataset) optimize_model_and_save_result(model, dataset) @@ -882,7 +925,11 @@ def write_summary_initial_model_fit( """Write TensorBoard summary for the model fitting to the initial data.""" for tag, model in models.items(): with tf.name_scope(f"{tag}.model"): - model.log(datasets[tag]) + # Prefer local dataset if available. 
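When a rule returns rank-3 query points (one batch slice per region), the observer is wrapped with `mk_batch_observer`, which emits a global dataset plus one `LocalizedTag`-keyed dataset per batch slice. A minimal sketch of the resulting shapes with a toy objective:

```python
import tensorflow as tf
from trieste.objectives.utils import mk_batch_observer
from trieste.observer import OBJECTIVE
from trieste.utils.misc import LocalizedTag

observer = mk_batch_observer(lambda x: tf.reduce_sum(x, axis=-1, keepdims=True))
query_points = tf.random.uniform([5, 3, 2], dtype=tf.float64)  # [n_points, batch, dims]

output = observer(query_points)
assert output[OBJECTIVE].query_points.shape == [15, 2]                  # flattened batch
assert output[LocalizedTag(OBJECTIVE, 0)].query_points.shape == [5, 2]  # one slice
```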
+ tags = [tag, LocalizedTag.from_tag(tag).global_tag]
+ _, dataset = get_value_for_tag(datasets, *tags)
+ assert dataset is not None
+ model.log(dataset)
 logging.scalar(
 "wallclock/model_fitting",
 model_fitting_timer.time,
@@ -929,7 +976,7 @@ def write_summary_observations(
 observation_plot_dfs: MutableMapping[Tag, pd.DataFrame],
 ) -> None:
 """Write TensorBoard summary for the current step observations."""
- for tag in datasets:
+ for tag in models:
 with tf.name_scope(f"{tag}.model"):
 models[tag].log(datasets[tag])
diff --git a/trieste/experimental/plotting/plotting.py b/trieste/experimental/plotting/plotting.py
index 8faf81e865..92183e503a 100644
--- a/trieste/experimental/plotting/plotting.py
+++ b/trieste/experimental/plotting/plotting.py
@@ -33,9 +33,11 @@
 from trieste.acquisition import AcquisitionFunction
 from trieste.acquisition.multi_objective.dominance import non_dominated
 from trieste.bayesian_optimizer import FrozenRecord, Record, StateType
+from trieste.observer import OBJECTIVE
 from trieste.space import TaggedMultiSearchSpace
 from trieste.types import TensorType
 from trieste.utils import to_numpy
+from trieste.utils.misc import LocalizedTag
 def create_grid(
@@ -234,7 +236,7 @@ def batched_func(x: TensorType) -> TensorType:
 def format_point_markers(
 num_pts: int,
- num_init: Optional[int] = None,
+ num_init: Optional[Union[int, TensorType]] = None,
 idx_best: Optional[TensorType] = None,
 mask_fail: Optional[TensorType] = None,
 m_init: str = "x",
@@ -247,7 +249,7 @@ def format_point_markers(
 Prepares point marker styles according to some BO factors.
 :param num_pts: total number of BO points
- :param num_init: initial number of BO points
+ :param num_init: initial number of BO points; can also be a mask
 :param idx_best: index of the best BO point(s)
 :param mask_fail: Bool vector, True if the corresponding observation violates the constraint(s)
 :param m_init: marker for the initial BO points
@@ -262,7 +264,10 @@ def format_point_markers(
 col_pts = np.repeat(c_pass, num_pts)
 col_pts = col_pts.astype("<U15")
[...]
+ if len(spaces) > 1:
+ # Expect there to be an objective dataset for each subspace.
+ datasets = [history.datasets[LocalizedTag(OBJECTIVE, i)] for i in range(len(spaces))]
+
+ _new_points_mask = [
+ np.zeros(dataset.query_points.shape[0], dtype=bool) for dataset in datasets
+ ]
+ # Last point in each dataset is the new point.
+ for mask in _new_points_mask:
+ mask[-1] = True
+ # Concatenate the masks.
+ new_points_mask = np.concatenate(_new_points_mask)
+
+ if num_init is not None:
+ _num_init_mask = [
+ np.zeros(dataset.query_points.shape[0], dtype=bool) for dataset in datasets
+ ]
+ # First num_init points in each dataset are the init points.
+ for mask in _num_init_mask:
+ mask[:num_init] = True
+ # Concatenate the masks.
+ num_init = np.concatenate(_num_init_mask)
+
+ # Get the overall query points.
+ query_points = np.concatenate([dataset.query_points for dataset in datasets])
+ else:
+ query_points = history.dataset.query_points # All query points.
+ new_points_mask = np.zeros(query_points.shape[0], dtype=bool)
+ new_points_mask[-num_query_points:] = True
 # Plot trust regions.
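The mask bookkeeping above generalises `num_init` from a count into a boolean mask over the concatenated per-region datasets. The same construction in isolation, for two regions with 4 and 3 points and 2 initial points each:

```python
import numpy as np

sizes = [4, 3]  # points per local dataset

new_masks = [np.zeros(n, dtype=bool) for n in sizes]
for mask in new_masks:
    mask[-1] = True  # the last point in each dataset is the new point
new_points_mask = np.concatenate(new_masks)  # [F F F T F F T]

init_masks = [np.zeros(n, dtype=bool) for n in sizes]
for mask in init_masks:
    mask[:2] = True  # the first two points in each dataset are initial points
num_init = np.concatenate(init_masks)  # [T T F F T T F]
```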
colors = [rgb2hex(color) for color in cm.rainbow(np.linspace(0, 1, num_query_points))] diff --git a/trieste/objectives/utils.py b/trieste/objectives/utils.py index 31e70a52eb..b074738ed1 100644 --- a/trieste/objectives/utils.py +++ b/trieste/objectives/utils.py @@ -20,11 +20,14 @@ from __future__ import annotations from collections.abc import Callable -from typing import Optional, overload +from typing import Mapping, Optional, Union, overload + +from check_shapes import check_shapes from ..data import Dataset -from ..observer import MultiObserver, Observer, SingleObserver +from ..observer import OBJECTIVE, MultiObserver, Observer, SingleObserver from ..types import Tag, TensorType +from ..utils.misc import LocalizedTag, flatten_leading_dims @overload @@ -57,3 +60,49 @@ def mk_multi_observer(**kwargs: Callable[[TensorType], TensorType]) -> MultiObse :return: An multi-observer returning the data from ``kwargs``. """ return lambda qp: {key: Dataset(qp, objective(qp)) for key, objective in kwargs.items()} + + +def mk_batch_observer( + objective_or_observer: Union[Callable[[TensorType], TensorType], Observer], + default_key: Tag = OBJECTIVE, +) -> Observer: + """ + Create an observer that returns the data from ``objective`` or an existing ``observer`` + separately for each query point in a batch. + + :param objective_or_observer: An objective or an existing observer. + :param default_key: The default key to use if ``objective_or_observer`` is an objective or + does not return a mapping. + :return: A multi-observer across the batch dimension of query points, returning the data from + ``objective_or_observer``. + """ + + @check_shapes("qps: [n_points, batch_size, n_dims]") + def _observer(qps: TensorType) -> Mapping[Tag, Dataset]: + # Call objective with rank 2 query points by flattening batch dimension. + # Some objectives might only expect rank 2 query points, so this is safer. + batch_size = qps.shape[1] + flat_qps, unflatten = flatten_leading_dims(qps) + obs_or_dataset = objective_or_observer(flat_qps) + + if not isinstance(obs_or_dataset, (Mapping, Dataset)): + # Just a single observation, so wrap in a dataset. + obs_or_dataset = Dataset(flat_qps, obs_or_dataset) + + if isinstance(obs_or_dataset, Dataset): + # Convert to a mapping with a default key. + obs_or_dataset = {default_key: obs_or_dataset} + + datasets = {} + for key, dataset in obs_or_dataset.items(): + # Include overall dataset and per batch dataset. + flat_obs = dataset.observations + qps = unflatten(flat_qps) + obs = unflatten(flat_obs) + datasets[key] = dataset + for i in range(batch_size): + datasets[LocalizedTag(key, i)] = Dataset(qps[:, i], obs[:, i]) + + return datasets + + return _observer diff --git a/trieste/utils/misc.py b/trieste/utils/misc.py index c29b3c4e51..34f76ae8da 100644 --- a/trieste/utils/misc.py +++ b/trieste/utils/misc.py @@ -14,9 +14,10 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from time import perf_counter from types import TracebackType -from typing import Any, Callable, Generic, Mapping, NoReturn, Optional, Tuple, Type, TypeVar +from typing import Any, Callable, Generic, Mapping, NoReturn, Optional, Tuple, Type, TypeVar, Union import numpy as np import tensorflow as tf @@ -220,21 +221,67 @@ def map_values(f: Callable[[U], V], mapping: Mapping[K, U]) -> Mapping[K, V]: """ An unbound type variable. 
""" -def get_value_for_tag(mapping: Optional[Mapping[Tag, T]], tag: Tag = OBJECTIVE) -> Optional[T]: - """Return the value of a tag in a mapping. +def get_value_for_tag( + mapping: Optional[Mapping[Tag, T]], *tags: Tag +) -> Tuple[Optional[Tag], Optional[T]]: + """Return the value from a mapping for the first tag found from a sequence of tags. :param mapping: A mapping from tags to values. - :param tag: A tag. - :return: The value of the tag in the mapping, or None if the mapping is None. - :raises ValueError: If the tag is not in the mapping and the mapping is not None. + :param tags: A tag or a sequence of tags. Sequence is searched in order. If no tags are + provided, the default tag OBJECTIVE is used. + :return: The chosen tag and value of the tag in the mapping, or None for each if the mapping is + None. + :raises ValueError: If none of the tags are in the mapping and the mapping is not None. """ + if not tags: + tags = (OBJECTIVE,) + if mapping is None: - return None - elif tag in mapping.keys(): - return mapping[tag] + return None, None else: - raise ValueError(f"tag '{tag}' not found in mapping") + matched_tag = next((tag for tag in tags if tag in mapping), None) + if matched_tag is None: + raise ValueError(f"none of the tags '{tags}' found in mapping") + return matched_tag, mapping[matched_tag] + + +@dataclass(frozen=True) +class LocalizedTag: + """Manage a tag for a local model or dataset. These have a global tag and a local index.""" + + global_tag: Tag + """ The global portion of the tag. """ + + local_index: Optional[int] + """ The local index of the tag. """ + + def __post_init__(self) -> None: + if self.local_index is not None and self.local_index < 0: + raise ValueError(f"local index must be non-negative, got {self.local_index}") + + @property + def is_local(self) -> bool: + """Return True if the tag is a local tag.""" + return self.local_index is not None + + @staticmethod + def from_tag(tag: Union[Tag, LocalizedTag]) -> LocalizedTag: + """Return a LocalizedTag from a given tag.""" + if isinstance(tag, LocalizedTag): + return tag + else: + return LocalizedTag(tag, None) + + +def ignoring_local_tags(mapping: Mapping[Tag, T]) -> Mapping[Tag, T]: + """ + Filter out local tags from a mapping, returning a new mapping with only global tags. + + :param mapping: A mapping from tags to values. + :return: A new mapping with only global tags. + """ + return {k: v for k, v in mapping.items() if not LocalizedTag.from_tag(k).is_local} class Timer: