From 8384f14e2b49529cc92c47909a9e18b4bea8085d Mon Sep 17 00:00:00 2001 From: Florian <33749653+flo-schu@users.noreply.github.com> Date: Tue, 19 Nov 2024 16:07:47 +0100 Subject: [PATCH] Use SingleTaskGP instead of FixedNoiseGP as recommended in the Deprecation Warning Co-authored-by: r.jaepel --- CADETProcess/optimization/axAdapater.py | 8 ++++---- pyproject.toml | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/CADETProcess/optimization/axAdapater.py b/CADETProcess/optimization/axAdapater.py index a3a2c782..3c9abb9e 100644 --- a/CADETProcess/optimization/axAdapater.py +++ b/CADETProcess/optimization/axAdapater.py @@ -20,7 +20,7 @@ from ax.utils.common.result import Err, Ok from ax.service.utils.report_utils import exp_to_df from botorch.utils.sampling import manual_seed -from botorch.models.gp_regression import FixedNoiseGP +from botorch.models.gp_regression import SingleTaskGP from botorch.acquisition.analytic import ( LogExpectedImprovement ) @@ -521,7 +521,7 @@ class BotorchModular(SingleObjectiveAxInterface): surrogate_model: Model class """ acquisition_fn = Typed(ty=type, default=LogExpectedImprovement) - surrogate_model = Typed(ty=type, default=FixedNoiseGP) + surrogate_model = Typed(ty=type, default=SingleTaskGP) _specific_options = [ 'acquisition_fn', 'surrogate_model' @@ -552,7 +552,7 @@ class NEHVI(MultiObjectiveAxInterface): supports_single_objective = False def __repr__(self): - smn = 'FixedNoiseGP' + smn = 'SingleTaskGP' afn = 'NEHVI' return f'{smn}+{afn}' @@ -578,7 +578,7 @@ class qNParEGO(MultiObjectiveAxInterface): supports_single_objective = False def __repr__(self): - smn = 'FixedNoiseGP' + smn = 'SingleTaskGP' afn = 'qNParEGO' return f'{smn}+{afn}' diff --git a/pyproject.toml b/pyproject.toml index d68809e6..8704b608 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ testing = [ "certifi", # tries to prevent certificate problems on windows "pytest", "pre-commit", # system tests run pre-commit - "ax-platform >=0.3.5,<0.4.3" + "ax-platform >=0.3.5" ] docs = [ "myst-nb>=0.17.1", @@ -56,7 +56,7 @@ docs = [ ] ax = [ - "ax-platform >=0.3.5,<0.4.3" + "ax-platform >=0.3.5" ] [project.urls]