diff --git a/docs/notebooks/batch_trust_region.pct.py b/docs/notebooks/batch_trust_region.pct.py index 78f4f3c050..ab46471fe6 100644 --- a/docs/notebooks/batch_trust_region.pct.py +++ b/docs/notebooks/batch_trust_region.pct.py @@ -86,9 +86,7 @@ bo = trieste.bayesian_optimizer.BayesianOptimizer(observer, search_space) num_steps = 5 -result = bo.optimize( - num_steps, initial_data, model, acq_rule, track_state=False -) +result = bo.optimize(num_steps, initial_data, model, acq_rule, track_state=True) dataset = result.try_get_final_dataset() # %% [markdown] @@ -158,6 +156,88 @@ ax.set_ylabel("Regret") ax.set_xlabel("# evaluations") +# %% [markdown] +# Next we visualize the progress of the optimization by plotting the trust regions at each step. The trust regions are shown as translucent boxes, with the current optimum point in each region shown in matching color. + +# %% +import base64 +import io + +import imageio +import IPython +from matplotlib.colors import rgb2hex +from matplotlib.patches import Rectangle +from matplotlib.pyplot import cm + +colors = [ + rgb2hex(color) for color in cm.rainbow(np.linspace(0, 1, num_query_points)) +] +frames = [] + +for step, hist in enumerate(result.history + [result.final_result.unwrap()]): + state = hist.acquisition_state + if state is None: + continue + + # Plot branin contour. + fig, ax = plot_function_2d( + branin, + search_space.lower, + search_space.upper, + grid_density=40, + contour=True, + ) + + query_points = hist.dataset.query_points + new_points_mask = np.zeros(query_points.shape[0], dtype=bool) + new_points_mask[-num_query_points:] = True + + assert isinstance(state, trieste.acquisition.rule.BatchTrustRegionBox.State) + acquisition_space = state.acquisition_space + + # Plot trust regions. + for i, tag in enumerate(acquisition_space.subspace_tags): + lb = acquisition_space.get_subspace(tag).lower + ub = acquisition_space.get_subspace(tag).upper + ax[0, 0].add_patch( + Rectangle( + (lb[0], lb[1]), + ub[0] - lb[0], + ub[1] - lb[1], + facecolor=colors[i], + edgecolor=colors[i], + alpha=0.3, + ) + ) + + # Plot new query points, using failure mask to color them. + plot_bo_points( + query_points, + ax[0, 0], + num_initial_data_points, + mask_fail=new_points_mask, + c_pass="black", + c_fail=colors, + ) + + fig.suptitle(f"step number {step}") + fig.canvas.draw() + size_pix = fig.get_size_inches() * fig.dpi + image = np.frombuffer(fig.canvas.tostring_rgb(), dtype="uint8") + frames.append(image.reshape(list(size_pix[::-1].astype(int)) + [3])) + plt.close(fig) + + +# Create and show the GIF. +gif_file = io.BytesIO() +imageio.mimsave(gif_file, frames, format="gif", loop=0, duration=5000) # type: ignore +gif = IPython.display.HTML( + ''.format( + base64.b64encode(gif_file.getvalue()).decode() + ) +) +IPython.display.display(gif) + # %% [markdown] # ## LICENSE # diff --git a/docs/notebooks/constraints.txt b/docs/notebooks/constraints.txt index 47b2c39cfd..724356c348 100644 --- a/docs/notebooks/constraints.txt +++ b/docs/notebooks/constraints.txt @@ -68,6 +68,7 @@ gym==0.26.2 gym-notices==0.0.8 h5py==3.8.0 idna==3.4 +imageio==2.31.1 ipykernel==6.23.2 ipython==8.14.0 isoduration==20.11.0 diff --git a/docs/notebooks/requirements.txt b/docs/notebooks/requirements.txt index f3ffb86896..cc0cdaed07 100644 --- a/docs/notebooks/requirements.txt +++ b/docs/notebooks/requirements.txt @@ -21,3 +21,4 @@ jupytext gym[box2d] box2d box2d-kengz +imageio diff --git a/trieste/experimental/plotting/plotting.py b/trieste/experimental/plotting/plotting.py index f5e26b1aa8..f4a8d8722a 100644 --- a/trieste/experimental/plotting/plotting.py +++ b/trieste/experimental/plotting/plotting.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Callable, Optional, Sequence +from typing import Callable, List, Optional, Sequence, Union import matplotlib.pyplot as plt import numpy as np @@ -234,7 +234,7 @@ def format_point_markers( m_init: str = "x", m_add: str = "o", c_pass: str = "tab:green", - c_fail: str = "tab:red", + c_fail: Union[str, List[str]] = "tab:red", c_best: str = "tab:purple", ) -> tuple[TensorType, TensorType]: """ @@ -275,7 +275,7 @@ def plot_bo_points( m_init: str = "x", m_add: str = "o", c_pass: str = "tab:green", - c_fail: str = "tab:red", + c_fail: Union[str, List[str]] = "tab:red", c_best: str = "tab:purple", ) -> None: """