Skip to content

Commit

Permalink
Use SingleTaskGP instead of FixedNoiseGP as recommended in the Deprec…
Browse files Browse the repository at this point in the history
…ation Warning

Co-authored-by: r.jaepel <[email protected]>
  • Loading branch information
flo-schu and ronald-jaepel authored Nov 19, 2024
1 parent ccdf509 commit 8384f14
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions CADETProcess/optimization/axAdapater.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ax.utils.common.result import Err, Ok
from ax.service.utils.report_utils import exp_to_df
from botorch.utils.sampling import manual_seed
from botorch.models.gp_regression import FixedNoiseGP
from botorch.models.gp_regression import SingleTaskGP
from botorch.acquisition.analytic import (
LogExpectedImprovement
)
Expand Down Expand Up @@ -521,7 +521,7 @@ class BotorchModular(SingleObjectiveAxInterface):
surrogate_model: Model class
"""
acquisition_fn = Typed(ty=type, default=LogExpectedImprovement)
surrogate_model = Typed(ty=type, default=FixedNoiseGP)
surrogate_model = Typed(ty=type, default=SingleTaskGP)

_specific_options = [
'acquisition_fn', 'surrogate_model'
Expand Down Expand Up @@ -552,7 +552,7 @@ class NEHVI(MultiObjectiveAxInterface):
supports_single_objective = False

def __repr__(self):
smn = 'FixedNoiseGP'
smn = 'SingleTaskGP'
afn = 'NEHVI'

return f'{smn}+{afn}'
Expand All @@ -578,7 +578,7 @@ class qNParEGO(MultiObjectiveAxInterface):
supports_single_objective = False

def __repr__(self):
smn = 'FixedNoiseGP'
smn = 'SingleTaskGP'
afn = 'qNParEGO'

return f'{smn}+{afn}'
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ testing = [
"certifi", # tries to prevent certificate problems on windows
"pytest",
"pre-commit", # system tests run pre-commit
"ax-platform >=0.3.5,<0.4.3"
"ax-platform >=0.3.5"
]
docs = [
"myst-nb>=0.17.1",
Expand All @@ -56,7 +56,7 @@ docs = [
]

ax = [
"ax-platform >=0.3.5,<0.4.3"
"ax-platform >=0.3.5"
]

[project.urls]
Expand Down

0 comments on commit 8384f14

Please sign in to comment.