From 82f929d5f31c4ca0e9e986236fd79a108f025f54 Mon Sep 17 00:00:00 2001
From: Khurram Ghani <113982802+khurram-ghani@users.noreply.github.com>
Date: Fri, 12 Jan 2024 11:55:08 +0000
Subject: [PATCH] Filter out local datasets when calling base rule (#805)

* Fix batch observer return type to be MultiObserver
* Only pass global datasets to EGO parallel rule
* Apply De Morgan's law to rules expression + add comments
---
 .../integration/test_ask_tell_optimization.py |  9 ++++--
 tests/unit/acquisition/test_rule.py           | 28 +++++++++++++++++++
 trieste/acquisition/rule.py                   | 18 ++++++++++--
 trieste/objectives/utils.py                   |  2 +-
 4 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/tests/integration/test_ask_tell_optimization.py b/tests/integration/test_ask_tell_optimization.py
index 3be8bc28d7..6e8fc22c4e 100644
--- a/tests/integration/test_ask_tell_optimization.py
+++ b/tests/integration/test_ask_tell_optimization.py
@@ -16,7 +16,7 @@
 import copy
 import pickle
 import tempfile
-from typing import Callable, Tuple, Union
+from typing import Callable, Mapping, Tuple, Union
 
 import numpy.testing as npt
 import pytest
@@ -36,6 +36,7 @@
 from trieste.acquisition.utils import copy_to_local_models
 from trieste.ask_tell_optimization import AskTellOptimizer
 from trieste.bayesian_optimizer import OptimizationResult, Record
+from trieste.data import Dataset
 from trieste.logging import set_step_number, tensorboard_writer
 from trieste.models import TrainableProbabilisticModel
 from trieste.models.gpflow import GaussianProcessRegression, build_gpr
@@ -43,7 +44,7 @@
 from trieste.objectives.utils import mk_batch_observer, mk_observer
 from trieste.observer import OBJECTIVE
 from trieste.space import Box, SearchSpace
-from trieste.types import State, TensorType
+from trieste.types import State, Tag, TensorType
 
 # Optimizer parameters for testing against the branin function.
 # We use a copy of these for a quicker test against a simple quadratic function
@@ -212,7 +213,9 @@ def _test_ask_tell_optimization_finds_minima(
 
         # If query points are rank 3, then use a batched observer.
         if tf.rank(new_point) == 3:
-            new_data_point = batch_observer(new_point)
+            new_data_point: Union[Mapping[Tag, Dataset], Dataset] = batch_observer(
+                new_point
+            )
         else:
             new_data_point = observer(new_point)
 
diff --git a/tests/unit/acquisition/test_rule.py b/tests/unit/acquisition/test_rule.py
index 1287c70393..1cab4f6a75 100644
--- a/tests/unit/acquisition/test_rule.py
+++ b/tests/unit/acquisition/test_rule.py
@@ -16,6 +16,7 @@
 import copy
 from collections.abc import Mapping
 from typing import Callable, Optional
+from unittest.mock import ANY, MagicMock
 
 import gpflow
 import numpy as np
@@ -1798,6 +1799,33 @@ def test_multi_trust_region_box_updated_datasets_are_in_regions(
     )
 
 
+def test_multi_trust_region_box_acquire_filters() -> None:
+    # Create some dummy models and datasets
+    models: Mapping[Tag, ANY] = {"global_tag": MagicMock()}
+    datasets: Mapping[Tag, ANY] = {
+        LocalizedTag("tag1", 1): MagicMock(),
+        LocalizedTag("tag1", 2): MagicMock(),
+        LocalizedTag("tag2", 1): MagicMock(),
+        LocalizedTag("tag2", 2): MagicMock(),
+        "global_tag": MagicMock(),
+    }
+
+    search_space = Box([0.0], [1.0])
+    mock_base_rule = MagicMock(spec=EfficientGlobalOptimization)
+    mock_base_rule.acquire.return_value = tf.constant([[[0.0], [0.0]]], dtype=tf.float64)
+
+    # Create a BatchTrustRegionBox instance with the mock base_rule.
+    subspaces = [SingleObjectiveTrustRegionBox(search_space) for _ in range(2)]
+    rule: BatchTrustRegionBox[ProbabilisticModel] = BatchTrustRegionBox(subspaces, mock_base_rule)
+
+    rule.acquire(search_space, models, datasets)(None)
+
+    # Only the global tags should be passed to the base_rule acquire call.
+    mock_base_rule.acquire.assert_called_once_with(
+        ANY, models, {"global_tag": datasets["global_tag"]}
+    )
+
+
 def test_multi_trust_region_box_state_deepcopy() -> None:
     search_space = Box([0.0, 0.0], [1.0, 1.0])
     dataset = Dataset(
diff --git a/trieste/acquisition/rule.py b/trieste/acquisition/rule.py
index 1ed6353303..352f446736 100644
--- a/trieste/acquisition/rule.py
+++ b/trieste/acquisition/rule.py
@@ -1234,8 +1234,8 @@ def acquire(
         # Otherwise, run the base rule as is (i.e as a batch), once with all models and datasets.
         # Note: this should only trigger on the first call to `acquire`, as after that we will
         # have a list of rules in `self._rules`.
-        if self._rules is None and (
-            _num_local_models > 0 or not isinstance(self._rule, EfficientGlobalOptimization)
+        if self._rules is None and not (
+            _num_local_models == 0 and isinstance(self._rule, EfficientGlobalOptimization)
         ):
             self._rules = [copy.deepcopy(self._rule) for _ in range(num_subspaces)]
 
@@ -1282,7 +1282,19 @@ def state_func(
                     _points.append(rule.acquire(subspace, _models, _datasets))
                 points = tf.stack(_points, axis=1)
             else:
-                points = self._rule.acquire(acquisition_space, models, datasets)
+                # Filter out local datasets as this is a rule (currently only EGO) with normal
+                # acquisition functions that don't expect local datasets.
+                # Note: no need to filter out local models, as setups with local models
+                # are handled above (i.e. we run the base rule sequentially for each subspace).
+                if datasets is not None:
+                    _datasets = {
+                        tag: dataset
+                        for tag, dataset in datasets.items()
+                        if not LocalizedTag.from_tag(tag).is_local
+                    }
+                else:
+                    _datasets = None
+                points = self._rule.acquire(acquisition_space, models, _datasets)
 
             # We may modify the regions in filter_datasets later, so return a copy.
             state_ = BatchTrustRegion.State(copy.deepcopy(acquisition_space))
diff --git a/trieste/objectives/utils.py b/trieste/objectives/utils.py
index b074738ed1..088391571e 100644
--- a/trieste/objectives/utils.py
+++ b/trieste/objectives/utils.py
@@ -65,7 +65,7 @@ def mk_multi_observer(**kwargs: Callable[[TensorType], TensorType]) -> MultiObse
 def mk_batch_observer(
     objective_or_observer: Union[Callable[[TensorType], TensorType], Observer],
     default_key: Tag = OBJECTIVE,
-) -> Observer:
+) -> MultiObserver:
     """
     Create an observer that returns the data from ``objective`` or an existing ``observer``
     separately for each query point in a batch.
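
Illustrative sketch (not part of the patch): the two behavioural changes in trieste/acquisition/rule.py are small enough to show in isolation. Region-local dataset entries are dropped before the base rule's acquire is called, and the guard that decides whether to copy the base rule per subspace is rewritten via De Morgan's law. The code below is a minimal, self-contained approximation under stated assumptions, not trieste code: _LocalTag, TagLike, keep_global_only and the example tag names are hypothetical stand-ins for trieste's LocalizedTag and the inline dict comprehension in the patch.

from dataclasses import dataclass
from typing import Mapping, Optional, Union


@dataclass(frozen=True)
class _LocalTag:
    """Hypothetical stand-in for a localized tag: a global name plus a region index."""

    global_tag: str
    local_index: Optional[int] = None

    @property
    def is_local(self) -> bool:
        # An entry is "local" when it is bound to a particular trust-region index.
        return self.local_index is not None


TagLike = Union[str, _LocalTag]


def keep_global_only(
    datasets: Optional[Mapping[TagLike, object]]
) -> Optional[Mapping[TagLike, object]]:
    """Drop region-local entries, mirroring the filter added before the base rule call."""
    if datasets is None:
        return None
    return {
        tag: dataset
        for tag, dataset in datasets.items()
        if not (isinstance(tag, _LocalTag) and tag.is_local)
    }


if __name__ == "__main__":
    datasets = {
        _LocalTag("OBJECTIVE", 0): "local data for region 0",
        _LocalTag("OBJECTIVE", 1): "local data for region 1",
        "OBJECTIVE": "global data",
    }
    # Only the global entry survives, matching the assertion in the new unit test.
    assert keep_global_only(datasets) == {"OBJECTIVE": "global data"}

    # De Morgan rewrite of the guard: "not (no local models and the rule is plain EGO)"
    # is equivalent to "some local models, or the rule is not plain EGO".
    for num_local_models in (0, 3):
        for rule_is_ego in (False, True):
            assert (not (num_local_models == 0 and rule_is_ego)) == (
                num_local_models > 0 or not rule_is_ego
            )

The stand-in needs an isinstance check because it mixes plain strings with tag objects; the patch itself avoids this by routing every tag through LocalizedTag.from_tag, which, judging by the new unit test, reports plain string tags such as "global_tag" as non-local.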