Skip to content

Commit

Permalink
Add gif of trust region optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
khurram-ghani committed Aug 24, 2023
1 parent 47f64b6 commit 97b954c
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 6 deletions.
86 changes: 83 additions & 3 deletions docs/notebooks/batch_trust_region.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@
bo = trieste.bayesian_optimizer.BayesianOptimizer(observer, search_space)

num_steps = 5
result = bo.optimize(
num_steps, initial_data, model, acq_rule, track_state=False
)
result = bo.optimize(num_steps, initial_data, model, acq_rule, track_state=True)
dataset = result.try_get_final_dataset()

# %% [markdown]
Expand Down Expand Up @@ -158,6 +156,88 @@
ax.set_ylabel("Regret")
ax.set_xlabel("# evaluations")

# %% [markdown]
# Next we visualize the progress of the optimization by plotting the trust regions at each step. The trust regions are shown as translucent boxes, with the current optimum point in each region shown in matching color.

# %%
import base64
import io

import imageio
import IPython
from matplotlib.colors import rgb2hex
from matplotlib.patches import Rectangle
from matplotlib.pyplot import cm

colors = [
rgb2hex(color) for color in cm.rainbow(np.linspace(0, 1, num_query_points))
]
frames = []

for step, hist in enumerate(result.history + [result.final_result.unwrap()]):
state = hist.acquisition_state
if state is None:
continue

# Plot branin contour.
fig, ax = plot_function_2d(
branin,
search_space.lower,
search_space.upper,
grid_density=40,
contour=True,
)

query_points = hist.dataset.query_points
new_points_mask = np.zeros(query_points.shape[0], dtype=bool)
new_points_mask[-num_query_points:] = True

assert isinstance(state, trieste.acquisition.rule.BatchTrustRegionBox.State)
acquisition_space = state.acquisition_space

# Plot trust regions.
for i, tag in enumerate(acquisition_space.subspace_tags):
lb = acquisition_space.get_subspace(tag).lower
ub = acquisition_space.get_subspace(tag).upper
ax[0, 0].add_patch(
Rectangle(
(lb[0], lb[1]),
ub[0] - lb[0],
ub[1] - lb[1],
facecolor=colors[i],
edgecolor=colors[i],
alpha=0.3,
)
)

# Plot new query points, using failure mask to color them.
plot_bo_points(
query_points,
ax[0, 0],
num_initial_data_points,
mask_fail=new_points_mask,
c_pass="black",
c_fail=colors,
)

fig.suptitle(f"step number {step}")
fig.canvas.draw()
size_pix = fig.get_size_inches() * fig.dpi
image = np.frombuffer(fig.canvas.tostring_rgb(), dtype="uint8")
frames.append(image.reshape(list(size_pix[::-1].astype(int)) + [3]))
plt.close(fig)


# Create and show the GIF.
gif_file = io.BytesIO()
imageio.mimsave(gif_file, frames, format="gif", loop=0, duration=5000) # type: ignore
gif = IPython.display.HTML(
'<img src="data:image/gif;base64,{0}"/>'.format(
base64.b64encode(gif_file.getvalue()).decode()
)
)
IPython.display.display(gif)

# %% [markdown]
# ## LICENSE
#
Expand Down
1 change: 1 addition & 0 deletions docs/notebooks/constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ gym==0.26.2
gym-notices==0.0.8
h5py==3.8.0
idna==3.4
imageio==2.31.1
ipykernel==6.23.2
ipython==8.14.0
isoduration==20.11.0
Expand Down
1 change: 1 addition & 0 deletions docs/notebooks/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ jupytext
gym[box2d]
box2d
box2d-kengz
imageio
6 changes: 3 additions & 3 deletions trieste/experimental/plotting/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import Callable, Optional, Sequence
from typing import Callable, List, Optional, Sequence, Union

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -234,7 +234,7 @@ def format_point_markers(
m_init: str = "x",
m_add: str = "o",
c_pass: str = "tab:green",
c_fail: str = "tab:red",
c_fail: Union[str, List[str]] = "tab:red",
c_best: str = "tab:purple",
) -> tuple[TensorType, TensorType]:
"""
Expand Down Expand Up @@ -275,7 +275,7 @@ def plot_bo_points(
m_init: str = "x",
m_add: str = "o",
c_pass: str = "tab:green",
c_fail: str = "tab:red",
c_fail: Union[str, List[str]] = "tab:red",
c_best: str = "tab:purple",
) -> None:
"""
Expand Down

0 comments on commit 97b954c

Please sign in to comment.