Skip to content

Commit

Permalink
add restart tests
Browse files Browse the repository at this point in the history
  • Loading branch information
roussel-ryan committed Nov 4, 2024
1 parent 5fb8603 commit 09c5a36
Showing 1 changed file with 32 additions and 3 deletions.
35 changes: 32 additions & 3 deletions xopt/tests/generators/bayesian/test_turbo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import math
import os
from copy import deepcopy
from unittest import TestCase
from unittest.mock import patch
import json

import numpy as np
import pandas as pd
import pytest
import yaml

from xopt import Evaluator, VOCS, Xopt
from xopt.generators.bayesian import UpperConfidenceBoundGenerator
from xopt.generators.bayesian.bax.algorithms import GridOptimize
Expand Down Expand Up @@ -72,6 +76,12 @@ def test_turbo_validation(self):
vocs=test_vocs, turbo_controller=EntropyTurboController(test_vocs)
)

# test validation from serialized turbo controller
gen = BayesianGenerator(vocs=test_vocs, turbo_controller=turbo_controller)
gen.add_data(TEST_VOCS_DATA)
gen_dict = json.loads(gen.to_json())
gen.from_dict(gen_dict | {"vocs": test_vocs})

@patch.multiple(BayesianGenerator, __abstractmethods__=set())
def test_get_trust_region(self):
# test in 1D
Expand Down Expand Up @@ -341,20 +351,31 @@ def test_serialization(self):
evaluator = Evaluator(function=sin_function)
for name in ["optimize", "safety"]:
generator = UpperConfidenceBoundGenerator(vocs=vocs, turbo_controller=name)
X = Xopt(evaluator=evaluator, generator=generator, vocs=vocs)
X = Xopt(evaluator=evaluator, generator=generator, vocs=vocs,
dump_file="dump.yml")

yaml_str = X.yaml()
X2 = Xopt.from_yaml(yaml_str)
assert X2.generator.turbo_controller.name == name

X2.random_evaluate(3)
X2.step()

config = yaml.safe_load(open("dump.yml"))

# test restart
X3 = Xopt.model_validate(config)
X3.generator.train_model()
X3.step()

def test_entropy_turbo(self):
# define variables and function objectives
vocs = VOCS(
variables={"x": [0, 2 * math.pi]},
observables=["y1"],
)

def sin_function(input_dict):
def basic_sin_function(input_dict):
return {"y1": np.sin(input_dict["x"])}

# Prepare BAX algorithm and generator options
Expand All @@ -370,7 +391,7 @@ def sin_function(input_dict):
)

# construct evaluator
evaluator = Evaluator(function=sin_function)
evaluator = Evaluator(function=basic_sin_function)

# construct Xopt optimizer
X = Xopt(evaluator=evaluator, generator=generator, vocs=vocs)
Expand All @@ -387,3 +408,11 @@ def sin_function(input_dict):
algorithm=algorithm,
turbo_controller=OptimizeTurboController(vocs),
)

@pytest.fixture(scope="module", autouse=True)
def clean_up(self):
yield
files = ["dump.yml"]
for f in files:
if os.path.exists(f):
os.remove(f)

0 comments on commit 09c5a36

Please sign in to comment.