Implement pick_fitted_models_to_gen_from for ModelSelectionGenerationNode

Summary:
- Fill in the implementation of `_pick_fitted_model_to_gen_from` to compute cross-validation diagnostics for each candidate model and select the best model with the specified `BestModelSelector` (a usage sketch follows this summary)
- Add a `BestModelSelector` base class
- Implement `SingleDiagnosticBestModelSelector`, which aggregates the values of a single cross-validation diagnostic across metrics and selects the model with the best aggregated value
- Add unit tests for the above
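
A minimal usage sketch of how these pieces fit together (a sketch, not part of the commit: `fitted_model_specs` stands in for a list of already-fitted `ModelSpec` objects; class and enum names are taken from the diffs below):

    from ax.modelbridge.cross_validation import (
        DiagnosticCriterion,
        MetricAggregation,
        SingleDiagnosticBestModelSelector,
    )
    from ax.modelbridge.generation_node import ModelSelectionNode

    # Prefer the model whose mean "Fisher exact test p" across metrics is lowest.
    selector = SingleDiagnosticBestModelSelector(
        diagnostic="Fisher exact test p",
        criterion=DiagnosticCriterion.MIN,
        metric_aggregation=MetricAggregation.MEAN,
    )
    node = ModelSelectionNode(
        model_specs=fitted_model_specs,  # assumed: already-fitted ModelSpecs
        best_model_selector=selector,
    )
    # gen() cross-validates each spec and generates from the selected one.
    generator_run = node.gen(n=1)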

Reviewed By: lena-kashtelyan

Differential Revision: D32330645

fbshipit-source-id: 59f2074b65f75049cc9a9e81f4e6d8ec1c216f12
adamobeng authored and facebook-github-bot committed Dec 9, 2021
1 parent 61017fd commit 4d7e716
Showing 4 changed files with 178 additions and 27 deletions.
80 changes: 79 additions & 1 deletion ax/modelbridge/cross_validation.py
@@ -5,7 +5,10 @@
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
-from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
+from enum import Enum
+from functools import partial
+from numbers import Number
+from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple

import numpy as np
from ax.core.observation import Observation, ObservationData
@@ -379,3 +382,78 @@ def _fisher_exact_test_p(
# Compute the test statistic
_, p = fisher_exact(table, alternative="greater")
return float(p)


class BestModelSelector:
def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int:
"""
Return the index of the best diagnostic.
"""
...


class CallableEnum(Enum):
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.value(*args, **kwargs)


class MetricAggregation(CallableEnum):
    # Wrapped in `partial` so the function is stored as an enum value;
    # a bare function here would be bound as a method of the enum class.
    MEAN: Callable[[Iterable[Number]], Number] = partial(np.mean)


class DiagnosticCriterion(CallableEnum):
    MIN: Callable[[Iterable[Number]], Number] = partial(np.amin)


class SingleDiagnosticBestModelSelector(BestModelSelector):
"""Choose the best model using a single cross-validation diagnostic.
The input is a list of CVDiagnostics, each corresponding to one model.
The specified diagnostic is extracted from each of the CVDiagnostics,
its values (each of which corresponds to a separate metric) are
aggregated with the aggregation function, the best one is determined
with the criterion, and the index of the best diagnostic result is returned.
Example:
::
s = SingleDiagnosticBestModelSelector(
diagnostic = 'Fisher exact test p',
criterion = DiagnosticCriterion.MIN,
metric_aggregation = MetricAggregation.MEAN,
)
best_diagnostic_index = s.best_diagnostic(diagnostics)
Args:
diagnostic (str): The name of the diagnostic to use, which should be
a key in CVDiagnostic.
metric_aggregation (MetricAggregation): Callable
applied to the values of the diagnostic for a single model to
produce a single number.
criterion (DiagnosticCriterion): Callable used
to determine which of the (aggregated) diagnostics is the best.
Returns:
int: index of the selected best diagnostic.
"""

def __init__(
self,
diagnostic: str,
metric_aggregation: MetricAggregation,
criterion: DiagnosticCriterion,
) -> None:
self.diagnostic = diagnostic
self.metric_aggregation = metric_aggregation
self.criterion = criterion

def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int:
aggregated_diagnostic_values = [
self.metric_aggregation(list(d[self.diagnostic].values()))
for d in diagnostics
]
best_diagnostic = self.criterion(aggregated_diagnostic_values)
return [d == best_diagnostic for d in aggregated_diagnostic_values].index(True)
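
To make the selection arithmetic concrete, a small worked example (a sketch; the diagnostics values mirror the unit tests below):

    from ax.modelbridge.cross_validation import (
        DiagnosticCriterion,
        MetricAggregation,
        SingleDiagnosticBestModelSelector,
    )

    # Diagnostics for three models, two metrics each.
    diagnostics = [
        {"Fisher exact test p": {"y_a": 0.0, "y_b": 0.4}},  # mean = 0.20
        {"Fisher exact test p": {"y_a": 0.1, "y_b": 0.1}},  # mean = 0.10
        {"Fisher exact test p": {"y_a": 0.5, "y_b": 0.6}},  # mean = 0.55
    ]
    s = SingleDiagnosticBestModelSelector(
        diagnostic="Fisher exact test p",
        criterion=DiagnosticCriterion.MIN,
        metric_aggregation=MetricAggregation.MEAN,
    )
    # MEAN collapses each model's per-metric values to [0.20, 0.10, 0.55];
    # MIN then picks 0.10, so the best index is 1.
    assert s.best_diagnostic(diagnostics) == 1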
37 changes: 31 additions & 6 deletions ax/modelbridge/generation_node.py
@@ -17,6 +17,11 @@
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.cross_validation import (
cross_validate,
compute_diagnostics,
BestModelSelector,
)
from ax.modelbridge.model_spec import ModelSpec, FactoryFunctionModelSpec
from ax.modelbridge.registry import (
ModelRegistryBase,
@@ -245,10 +250,30 @@ def _unique_id(self) -> str:
return str(self.index)


-class ModelSelectionGenerationNode(GenerationNode):
+class ModelSelectionNode(GenerationNode):
model_specs: List[ModelSpec]

def __init__(
self,
model_specs: List[ModelSpec],
best_model_selector: BestModelSelector,
cvkwargs: Optional[Dict[str, Any]] = None,
) -> None:
# While `GenerationNode` only handles a single `ModelSpec` in the `gen`
# and `_pick_fitted_model_to_gen_from` methods, we validate the
# length of `model_specs` in `_pick_fitted_model_to_gen_from` in order
# to not require all `GenerationNode` subclasses to override an `__init__`
# method to bypass that validation.
self.model_specs = model_specs
self.best_model_selector = best_model_selector
self.cvkwargs = cvkwargs

     def _pick_fitted_model_to_gen_from(self) -> ModelSpec:
         """Select one model to generate from among the fitted models on this
         generation node.
         """
-        # TODO[adamobeng]: Add actual model-selection logic here
-        return self.model_specs[0]
+        cvkwargs = self.cvkwargs or {}
+        cv_diagnostics = []
+        for model_spec in self.model_specs:
+            cv_result = cross_validate(model_spec.fitted_model, **cvkwargs)
+            cv_diagnostics.append(compute_diagnostics(cv_result))
+
+        best_model_index = self.best_model_selector.best_diagnostic(cv_diagnostics)
+        return self.model_specs[best_model_index]
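
For reference, one iteration of the loop above expands to roughly the following (a sketch: `model_spec` is assumed to be a fitted `ModelSpec`; the `folds=-1` keyword, which requests leave-one-out CV and is believed to be the `cross_validate` default, stands in for whatever arrives via `cvkwargs`):

    from ax.modelbridge.cross_validation import compute_diagnostics, cross_validate

    # Cross-validate one fitted model; folds=-1 requests leave-one-out CV.
    cv_result = cross_validate(model_spec.fitted_model, folds=-1)
    # Maps each diagnostic name to per-metric values, e.g.
    # {"Fisher exact test p": {"y_a": 0.1, ...}, ...}
    diagnostics = compute_diagnostics(cv_result)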
32 changes: 32 additions & 0 deletions ax/modelbridge/tests/test_cross_validation.py
@@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List
from unittest import mock

import numpy as np
@@ -23,6 +24,8 @@
cross_validate,
cross_validate_by_trial,
has_good_opt_config_model_fit,
SingleDiagnosticBestModelSelector,
CVDiagnostics,
)
from ax.utils.common.testutils import TestCase

@@ -76,6 +79,11 @@ def setUp(self):
metric_names=["a", "b"],
)
] * 4
self.diagnostics: List[CVDiagnostics] = [
{"Fisher exact test p": {"y_a": 0.0, "y_b": 0.4}},
{"Fisher exact test p": {"y_a": 0.1, "y_b": 0.1}},
{"Fisher exact test p": {"y_a": 0.5, "y_b": 0.6}},
]

def testCrossValidate(self):
# Prepare input and output data
@@ -278,3 +286,27 @@ def testHasGoodOptConfigModelFit(self):
assess_model_fit_result=assess_model_fit_result,
)
self.assertFalse(has_good_fit)

def testSingleDiagnosticBestModelSelector_min_mean(self):
s = SingleDiagnosticBestModelSelector(
diagnostic="Fisher exact test p",
criterion=min,
metric_aggregation=np.mean,
)
self.assertEqual(s.best_diagnostic(self.diagnostics), 1)

def testSingleDiagnosticBestModelSelector_min_min(self):
s = SingleDiagnosticBestModelSelector(
diagnostic="Fisher exact test p",
criterion=min,
metric_aggregation=min,
)
self.assertEqual(s.best_diagnostic(self.diagnostics), 0)

def testSingleDiagnosticBestModelSelector_max_mean(self):
s = SingleDiagnosticBestModelSelector(
diagnostic="Fisher exact test p",
criterion=max,
metric_aggregation=np.mean,
)
self.assertEqual(s.best_diagnostic(self.diagnostics), 2)
56 changes: 36 additions & 20 deletions ax/modelbridge/tests/test_generation_node.py
@@ -6,16 +6,22 @@

from unittest.mock import patch

from ax.modelbridge.cross_validation import (
SingleDiagnosticBestModelSelector,
MetricAggregation,
DiagnosticCriterion,
)
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.generation_node import (
GenerationNode,
-    ModelSelectionGenerationNode,
+    ModelSelectionNode,
GenerationStep,
)
from ax.modelbridge.model_spec import ModelSpec, FactoryFunctionModelSpec
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
-from ax.utils.testing.core_stubs import get_branin_experiment
+from ax.utils.testing.core_stubs import get_branin_experiment_with_multi_objective


class TestGenerationNode(TestCase):
@@ -118,31 +124,41 @@ def test_properties(self):
self.assertEqual(self.sobol_generation_step._unique_id, "-1")


-class TestModelSelectionGenerationNode(TestCase):
+class TestModelSelectionNode(TestCase):
def setUp(self):
-        self.sobol_model_spec = ModelSpec(
-            model_enum=Models.SOBOL,
-            model_kwargs={"init_position": 3},
-            model_gen_kwargs={"some_gen_kwarg": "some_value"},
+        self.branin_experiment = get_branin_experiment_with_multi_objective()
+        sobol = Models.SOBOL(search_space=self.branin_experiment.search_space)
+        sobol_run = sobol.gen(n=20)
+        self.branin_experiment.new_batch_trial().add_generator_run(
+            sobol_run
+        ).run().mark_completed()
+        data = self.branin_experiment.fetch_data()
+
+        ms_gpei = ModelSpec(model_enum=Models.GPEI)
+        ms_gpei.fit(experiment=self.branin_experiment, data=data)
+
+        ms_gpkg = ModelSpec(model_enum=Models.GPKG)
+        ms_gpkg.fit(experiment=self.branin_experiment, data=data)
+
+        self.fitted_model_specs = [ms_gpei, ms_gpkg]
+
+        self.model_selection_node = ModelSelectionNode(
+            model_specs=self.fitted_model_specs,
+            best_model_selector=SingleDiagnosticBestModelSelector(
+                diagnostic="Fisher exact test p",
+                criterion=DiagnosticCriterion.MIN,
+                metric_aggregation=MetricAggregation.MEAN,
+            ),
         )
-        self.uniform_model_spec = ModelSpec(
-            model_enum=Models.UNIFORM,
-        )
-        self.model_selection_generation_node = ModelSelectionGenerationNode(
-            model_specs=[self.uniform_model_spec, self.sobol_model_spec]
-        )
-        self.branin_experiment = get_branin_experiment(with_completed_trial=True)

def test_gen(self):
self.model_selection_generation_node.fit(
self.model_selection_node.fit(
experiment=self.branin_experiment, data=self.branin_experiment.lookup_data()
)
# Check that with `ModelSelectionGenerationNode` generation from a node with
# Check that with `ModelSelectionNode` generation from a node with
# multiple model specs does not fail.
gr = self.model_selection_generation_node.gen(
n=1, pending_observations={"branin": []}
)
# Currently, `ModelSelectionGenerationNode` should just pick the first model
gr = self.model_selection_node.gen(n=1, pending_observations={"branin": []})
# Currently, `ModelSelectionNode` should just pick the first model
# spec as the one to generate from.
# TODO[adamobeng]: Test correct behavior here when implemented.
self.assertEqual(gr._model_key, "Uniform")
self.assertEqual(gr._model_key, "GPEI")
