Skip to content

Commit

Permalink
Raise DataRequiredError from transforms that need data but did not …
Browse files Browse the repository at this point in the history
…get it

Summary: Unlike a generic `ValueError`, `DataRequiredError` is handled in generation strategy, scheduler, and AxService API. In `Scheduler`, that error type will make the scheduler wait instead of failing.

Reviewed By: bernardbeckerman

Differential Revision: D35853790

fbshipit-source-id: 4008a5beca402f74aadde47ec8d3687494151da1
  • Loading branch information
lena-kashtelyan authored and facebook-github-bot committed Apr 25, 2022
1 parent 776f5a3 commit 9ed7857
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 8 deletions.
3 changes: 2 additions & 1 deletion ax/modelbridge/tests/test_percentile_y_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
from ax.core.observation import ObservationData
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.percentile_y import PercentileY
from ax.utils.common.testutils import TestCase

Expand Down Expand Up @@ -67,7 +68,7 @@ def setUp(self):
)

def testInit(self):
with self.assertRaises(ValueError):
with self.assertRaises(DataRequiredError):
PercentileY(search_space=None, observation_features=[], observation_data=[])

def testTransformObservations(self):
Expand Down
3 changes: 2 additions & 1 deletion ax/modelbridge/tests/test_standardize_y_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ax.core.optimization_config import OptimizationConfig
from ax.core.outcome_constraint import OutcomeConstraint, ScalarizedOutcomeConstraint
from ax.core.types import ComparisonOp
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.standardize_y import StandardizeY
from ax.utils.common.testutils import TestCase

Expand Down Expand Up @@ -46,7 +47,7 @@ def setUp(self):
def testInit(self):
self.assertEqual(self.t.Ymean, {"m1": 1.0, "m2": 1.5})
self.assertEqual(self.t.Ystd, {"m1": 1.0, "m2": sqrt(1 / 3)})
with self.assertRaises(ValueError):
with self.assertRaises(DataRequiredError):
StandardizeY(
search_space=None, observation_features=None, observation_data=[]
)
Expand Down
4 changes: 3 additions & 1 deletion ax/modelbridge/tests/test_winsorize_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
OutcomeConstraint,
ScalarizedOutcomeConstraint,
)
from ax.exceptions.core import DataRequiredError
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.modelbridge.transforms.winsorize import (
_get_auto_winsorization_cutoffs_outcome_constraint,
Expand Down Expand Up @@ -146,7 +147,8 @@ def testInit(self):
self.assertEqual(self.t2.cutoffs["m1"], (0.0, float("inf")))
self.assertEqual(self.t2.cutoffs["m2"], (0.0, float("inf")))
with self.assertRaisesRegex(
ValueError, "Winsorize transform requires non-empty observation data."
DataRequiredError,
"`Winsorize` transform requires non-empty observation data.",
):
Winsorize(search_space=None, observation_features=[], observation_data=[])
obsd = [deepcopy(self.obsd1)]
Expand Down
3 changes: 2 additions & 1 deletion ax/modelbridge/tests/test_winsorize_transform_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
from ax.core.observation import ObservationData
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.winsorize import Winsorize
from ax.utils.common.testutils import TestCase

Expand Down Expand Up @@ -131,7 +132,7 @@ def testInit(self):
self.assertEqual(self.t1.cutoffs["m2"], (-float("inf"), 1.0))
self.assertEqual(self.t2.cutoffs["m1"], (0.0, float("inf")))
self.assertEqual(self.t2.cutoffs["m2"], (0.0, float("inf")))
with self.assertRaises(ValueError):
with self.assertRaises(DataRequiredError):
Winsorize(search_space=None, observation_features=[], observation_data=[])

def testTransformObservations(self):
Expand Down
4 changes: 3 additions & 1 deletion ax/modelbridge/transforms/percentile_y.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@

from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.search_space import SearchSpace
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.utils import get_data
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast
from scipy import stats


if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401 # pragma: no cover
Expand All @@ -37,7 +39,7 @@ def __init__(
config: Optional[TConfig] = None,
) -> None:
if len(observation_data) == 0:
raise ValueError(
raise DataRequiredError(
"Percentile transform requires non-empty observation data."
)
metric_values = get_data(observation_data=observation_data)
Expand Down
5 changes: 3 additions & 2 deletions ax/modelbridge/transforms/standardize_y.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ax.core.outcome_constraint import ScalarizedOutcomeConstraint
from ax.core.search_space import SearchSpace
from ax.core.types import TParamValue
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.utils import get_data
from ax.models.types import TConfig
Expand Down Expand Up @@ -41,8 +42,8 @@ def __init__(
config: Optional[TConfig] = None,
) -> None:
if len(observation_data) == 0:
raise ValueError(
"StandardizeY transform requires non-empty observation data."
raise DataRequiredError(
"`StandardizeY` transform requires non-empty observation data."
)
Ys = get_data(observation_data=observation_data)
# Compute means and SDs
Expand Down
5 changes: 4 additions & 1 deletion ax/modelbridge/transforms/winsorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ScalarizedOutcomeConstraint,
)
from ax.core.search_space import SearchSpace
from ax.exceptions.core import DataRequiredError
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.utils import get_data
Expand Down Expand Up @@ -114,7 +115,9 @@ def __init__(
config: Optional[TConfig] = None,
) -> None:
if len(observation_data) == 0:
raise ValueError("Winsorize transform requires non-empty observation data.")
raise DataRequiredError(
"`Winsorize` transform requires non-empty observation data."
)
if config is None:
raise ValueError(
"Transform config for `Winsorize` transform must be specified and "
Expand Down

0 comments on commit 9ed7857

Please sign in to comment.