Skip to content

Commit

Permalink
fix: lint and then set raw_samples to None to keep current pruning be…
Browse files Browse the repository at this point in the history
…haviour
  • Loading branch information
CompRhys committed Nov 21, 2024
1 parent d75734f commit 669f210
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def _optimize_acqf_sequential_q(
def _combine_initial_conditions(
provided_initial_conditions: Tensor | None = None,
generated_initial_conditions: Tensor | None = None,
dim=0
dim=0,
) -> Tensor:
if (
provided_initial_conditions is not None
Expand Down
4 changes: 4 additions & 0 deletions botorch/optim/optimize_homotopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ def optimize_acqf_homotopy(
)
homotopy.step()

# Set raw_samples to None such that pruned restarts are not repopulated
# at each step in the homotopy.
shared_optimize_acqf_kwargs["raw_samples"] = None

# Prune candidates
candidates = prune_candidates(
candidates=candidates.squeeze(1),
Expand Down
17 changes: 13 additions & 4 deletions test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,12 @@ def test_optimize_acqf_joint(
self.assertTrue(torch.equal(acq_vals, mock_acq_values))
self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

# test generation with provided initial conditions less than num_restarts
# test generation with batch initial conditions less than num_restarts
candidates, acq_vals = optimize_acqf(
acq_function=mock_acq_function,
bounds=bounds,
q=q,
num_restarts=num_restarts+1,
num_restarts=num_restarts + 1,
raw_samples=raw_samples,
options=options,
return_best_only=False,
Expand Down Expand Up @@ -590,9 +590,18 @@ def test_optimize_acqf_batch_limit(self) -> None:
(num_restarts, q, dim),
]

for gen_candidates, (ic_shape, expected_acqf_shape, expected_candidates_shape) in product(
for gen_candidates, (
ic_shape,
expected_acqf_shape,
expected_candidates_shape,
) in product(
[gen_candidates_scipy, gen_candidates_torch],
zip(initial_conditions, expected_acqf_shapes, expected_candidates_shapes, strict=True),
zip(
initial_conditions,
expected_acqf_shapes,
expected_candidates_shapes,
strict=True,
),
):
ics = torch.ones(ic_shape) if ic_shape is not None else None
with self.subTest(gen_candidates=gen_candidates, initial_conditions=ics):
Expand Down

0 comments on commit 669f210

Please sign in to comment.