Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the everest run model #9575

Merged
merged 12 commits into from
Dec 20, 2024
839 changes: 448 additions & 391 deletions src/ert/run_models/everest_run_model.py

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions src/everest/config/environment_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Literal
from typing import Literal, Self

from pydantic import BaseModel, Field, field_validator
from numpy.random import SeedSequence
from pydantic import BaseModel, Field, field_validator, model_validator

from everest.config.validation_utils import check_path_valid

Expand Down Expand Up @@ -43,3 +44,9 @@ class EnvironmentConfig(BaseModel, extra="forbid"): # type: ignore
def validate_output_folder(cls, output_folder): # pylint:disable=E0213
check_path_valid(output_folder)
return output_folder

@model_validator(mode="after")
def validate_random_seed(self) -> Self:
if self.random_seed is None:
self.random_seed = SeedSequence().entropy
return self
43 changes: 24 additions & 19 deletions src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@
HTTPBasic,
HTTPBasicCredentials,
)
from ropt.enums import OptimizerExitCode

from ert.config import QueueSystem
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.run_models.everest_run_model import EverestRunModel
from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel
from everest import export_to_csv, export_with_progress
from everest.config import EverestConfig, ServerConfig
from everest.detached import ServerStatus, get_opt_status, update_everserver_status
Expand Down Expand Up @@ -373,25 +372,31 @@ def main():


def _get_optimization_status(exit_code, shared_data):
if exit_code == "max_batch_num_reached":
return ServerStatus.completed, "Maximum number of batches reached."

if exit_code == OptimizerExitCode.MAX_FUNCTIONS_REACHED:
return ServerStatus.completed, "Maximum number of function evaluations reached."
match exit_code:
case EverestExitCode.MAX_BATCH_NUM_REACHED:
return ServerStatus.completed, "Maximum number of batches reached."

case EverestExitCode.MAX_FUNCTIONS_REACHED:
return (
ServerStatus.completed,
"Maximum number of function evaluations reached.",
)

if exit_code == OptimizerExitCode.USER_ABORT:
return ServerStatus.stopped, "Optimization aborted."
case EverestExitCode.USER_ABORT:
return ServerStatus.stopped, "Optimization aborted."

if exit_code == OptimizerExitCode.TOO_FEW_REALIZATIONS:
status = (
ServerStatus.stopped if shared_data[STOP_ENDPOINT] else ServerStatus.failed
)
messages = _failed_realizations_messages(shared_data)
for msg in messages:
logging.getLogger(EVEREST).error(msg)
return status, "\n".join(messages)

return ServerStatus.completed, "Optimization completed."
case EverestExitCode.TOO_FEW_REALIZATIONS:
status = (
ServerStatus.stopped
if shared_data[STOP_ENDPOINT]
else ServerStatus.failed
)
messages = _failed_realizations_messages(shared_data)
for msg in messages:
logging.getLogger(EVEREST).error(msg)
return status, "\n".join(messages)
case _:
return ServerStatus.completed, "Optimization completed."


def _failed_realizations_messages(shared_data):
Expand Down
3 changes: 0 additions & 3 deletions src/everest/simulator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from everest.simulator.simulator_cache import SimulatorCache

JOB_SUCCESS = "Finished"
JOB_WAITING = "Waiting"
JOB_RUNNING = "Running"
Expand Down Expand Up @@ -109,5 +107,4 @@
"JOB_RUNNING",
"JOB_SUCCESS",
"JOB_WAITING",
"SimulatorCache",
]
58 changes: 0 additions & 58 deletions src/everest/simulator/simulator_cache.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/everest/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_seed(copy_math_func_test_data_to_tmp):
config.environment.random_seed = random_seed

run_model = EverestRunModel.create(config)
assert random_seed == run_model.everest_config.environment.random_seed
assert random_seed == run_model._everest_config.environment.random_seed

# Res
ert_config = _everest_to_ert_config_dict(config)
Expand All @@ -26,5 +26,5 @@ def test_loglevel(copy_math_func_test_data_to_tmp):
config = EverestConfig.load_file(CONFIG_FILE)
config.environment.log_level = "info"
run_model = EverestRunModel.create(config)
config = run_model.everest_config
config = run_model._everest_config
assert len(EverestConfig.lint_config_dict(config.to_dict())) == 0
8 changes: 1 addition & 7 deletions tests/everest/test_everest_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,7 @@ async def test_everest_output(copy_mocked_test_data_to_tmp):
initial_folders = set(folders)
initial_files = set(files)

# Tests in this class used to fail when a callback was passed in
# Use a callback just to see that everything works fine, even though
# the callback does nothing
def useless_cb(*args, **kwargs):
pass

EverestRunModel.create(config, optimization_callback=useless_cb)
EverestRunModel.create(config)
oyvindeide marked this conversation as resolved.
Show resolved Hide resolved

# Check the output folder is created when stating the optimization
# in everest workflow
Expand Down
8 changes: 4 additions & 4 deletions tests/everest/test_everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from pathlib import Path
from unittest.mock import patch

from ropt.enums import OptimizerExitCode
from seba_sqlite.snapshot import SebaSnapshot

from ert.run_models.everest_run_model import EverestExitCode
from everest.config import EverestConfig, OptimizationConfig, ServerConfig
from everest.detached import ServerStatus, everserver_status
from everest.detached.jobs import everserver
Expand All @@ -33,8 +33,8 @@ def fail_optimization(self, from_ropt=False):
# shared_data (see set_shared_status() below).
self._sim_callback(None)
if from_ropt:
self._exit_code = OptimizerExitCode.TOO_FEW_REALIZATIONS
return OptimizerExitCode.TOO_FEW_REALIZATIONS
self._exit_code = EverestExitCode.TOO_FEW_REALIZATIONS
return EverestExitCode.TOO_FEW_REALIZATIONS

raise Exception("Failed optimization")

Expand Down Expand Up @@ -106,7 +106,7 @@ def test_everserver_status_failure(_1, copy_math_func_test_data_to_tmp):
"ert.run_models.everest_run_model.EverestRunModel.run_experiment",
autospec=True,
side_effect=lambda self, evaluator_server_config, restart=False: check_status(
ServerConfig.get_hostfile_path(self.everest_config.output_dir),
ServerConfig.get_hostfile_path(self._everest_config.output_dir),
status=ServerStatus.running,
),
)
Expand Down
2 changes: 1 addition & 1 deletion tests/everest/test_simulator_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def new_call(*args):
Path("everest_output/optimization_output/seba.db").unlink()

# The batch_id was used as a stopping criterion, so it must be reset:
run_model.batch_id = 0
run_model._batch_id = 0

run_model.run_experiment(evaluator_server_config)
assert n_evals == 0
Expand Down
8 changes: 6 additions & 2 deletions tests/everest/test_yaml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ def test_random_seed(tmp_path, monkeypatch, random_seed):
if random_seed:
config["environment"] = {"random_seed": random_seed}
ever_config = EverestConfig.with_defaults(**config)
assert ever_config.environment.random_seed == random_seed
ert_config = everest_to_ert_config(ever_config)
assert ert_config.random_seed == random_seed
if random_seed is None:
assert ever_config.environment.random_seed > 0
oyvindeide marked this conversation as resolved.
Show resolved Hide resolved
assert ert_config.random_seed > 0
else:
assert ever_config.environment.random_seed == random_seed
assert ert_config.random_seed == random_seed


def test_read_file(tmp_path, monkeypatch):
Expand Down
Loading