From 669f210c62fffccd693e015115d78f14b778880d Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 21 Nov 2024 11:41:33 -0500 Subject: [PATCH] fix: lint and then set raw_samples to None to keep current pruning behaviour --- botorch/optim/optimize.py | 2 +- botorch/optim/optimize_homotopy.py | 4 ++++ test/optim/test_optimize.py | 17 +++++++++++++---- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py index a4a008b705..f5e5f4e637 100644 --- a/botorch/optim/optimize.py +++ b/botorch/optim/optimize.py @@ -286,7 +286,7 @@ def _optimize_acqf_sequential_q( def _combine_initial_conditions( provided_initial_conditions: Tensor | None = None, generated_initial_conditions: Tensor | None = None, - dim=0 + dim=0, ) -> Tensor: if ( provided_initial_conditions is not None diff --git a/botorch/optim/optimize_homotopy.py b/botorch/optim/optimize_homotopy.py index 84ffbcaf91..5b995135b1 100644 --- a/botorch/optim/optimize_homotopy.py +++ b/botorch/optim/optimize_homotopy.py @@ -191,6 +191,10 @@ def optimize_acqf_homotopy( ) homotopy.step() + # Set raw_samples to None such that pruned restarts are not repopulated + # at each step in the homotopy. + shared_optimize_acqf_kwargs["raw_samples"] = None + # Prune candidates candidates = prune_candidates( candidates=candidates.squeeze(1), diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index 91f2930f5d..d913882c1d 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -222,12 +222,12 @@ def test_optimize_acqf_joint( self.assertTrue(torch.equal(acq_vals, mock_acq_values)) self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt) - # test generation with provided initial conditions less than num_restarts + # test generation with batch initial conditions less than num_restarts candidates, acq_vals = optimize_acqf( acq_function=mock_acq_function, bounds=bounds, q=q, - num_restarts=num_restarts+1, + num_restarts=num_restarts + 1, raw_samples=raw_samples, options=options, return_best_only=False, @@ -590,9 +590,18 @@ def test_optimize_acqf_batch_limit(self) -> None: (num_restarts, q, dim), ] - for gen_candidates, (ic_shape, expected_acqf_shape, expected_candidates_shape) in product( + for gen_candidates, ( + ic_shape, + expected_acqf_shape, + expected_candidates_shape, + ) in product( [gen_candidates_scipy, gen_candidates_torch], - zip(initial_conditions, expected_acqf_shapes, expected_candidates_shapes, strict=True), + zip( + initial_conditions, + expected_acqf_shapes, + expected_candidates_shapes, + strict=True, + ), ): ics = torch.ones(ic_shape) if ic_shape is not None else None with self.subTest(gen_candidates=gen_candidates, initial_conditions=ics):