Skip to content

Commit

Permalink
Refactor: report experiment results through a typed server status instead of a separate exit-code endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Dec 12, 2024
1 parent c3f02ba commit ce686fd
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 56 deletions.
14 changes: 8 additions & 6 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import shutil
from collections import defaultdict
from dataclasses import dataclass
from enum import IntEnum
from pathlib import Path
from types import TracebackType
from typing import (
Expand All @@ -19,7 +20,6 @@
DefaultDict,
Dict,
List,
Literal,
Mapping,
Optional,
Protocol,
Expand Down Expand Up @@ -159,6 +159,10 @@ def from_seba_optimal_result(
)


class EverestRunModelExitCode(IntEnum):
    """Exit codes that the Everest run model itself can report.

    The run model stores one of these members as its exit code when it
    stops for its own reasons; otherwise the optimizer's exit code is used.
    Being an ``IntEnum``, every member also compares equal to its integer
    value.
    """

    # Set when the run ended because the maximum batch number was reached.
    MAX_BATCH_NUM_REACHED = 1


class EverestRunModel(BaseRunModel):
def __init__(
self,
Expand Down Expand Up @@ -187,9 +191,7 @@ def __init__(
)
self._display_all_jobs = display_all_jobs
self._result: Optional[OptimalResult] = None
self._exit_code: Optional[
Literal["max_batch_num_reached"] | OptimizerExitCode
] = None
self._exit_code: Optional[EverestRunModelExitCode | OptimizerExitCode] = None
self._max_batch_num_reached = False
self._simulator_cache: Optional[SimulatorCache] = None
if (
Expand Down Expand Up @@ -297,7 +299,7 @@ def run_experiment(
)

self._exit_code = (
"max_batch_num_reached"
EverestRunModelExitCode.MAX_BATCH_NUM_REACHED
if self._max_batch_num_reached
else optimizer_exit_code
)
Expand Down Expand Up @@ -443,7 +445,7 @@ def description(cls) -> str:
@property
def exit_code(
self,
) -> Optional[Literal["max_batch_num_reached"] | OptimizerExitCode]:
) -> Optional[EverestRunModelExitCode | OptimizerExitCode]:
return self._exit_code

@property
Expand Down
75 changes: 35 additions & 40 deletions src/everest/detached/jobs/everest_server_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from base64 import b64encode
from datetime import datetime, timedelta
from functools import partial
from typing import Any, Optional
from typing import Optional

import uvicorn
from cryptography import x509
Expand All @@ -20,23 +20,22 @@
from fastapi.encoders import jsonable_encoder
from fastapi.responses import (
JSONResponse,
PlainTextResponse,
Response,
)
from fastapi.security import (
HTTPBasic,
HTTPBasicCredentials,
)
from pydantic import BaseModel
from ropt.enums import OptimizerExitCode

from ert.config import QueueSystem
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.run_models.everest_run_model import EverestRunModel
from ert.run_models.everest_run_model import EverestRunModel, EverestRunModelExitCode
from ert.shared import get_machine_name as ert_shared_get_machine_name
from everest.config import EverestConfig, ServerConfig
from everest.detached import get_opt_status
from everest.strings import (
EXIT_CODE_ENDPOINT,
OPT_PROGRESS_ENDPOINT,
SHARED_DATA_ENDPOINT,
SIM_PROGRESS_ENDPOINT,
Expand Down Expand Up @@ -166,8 +165,9 @@ def _opt_monitor(shared_data=None):
return "stop_optimization"


class ExitCode(BaseModel):
exit_code: Optional[Any] = None
class ExperimentRunnerStatus(BaseModel):
    """Status payload exchanged between the Everest server and its clients.

    Every field is optional and defaults to ``None``, so an instance can
    carry just a status text, or a status text together with an exit code
    or a message.
    """

    # Short human-readable state, e.g. "Experiment finished",
    # "Experiment failed" or "Everest server stopped".
    status: str | None = None
    # Exit code taken from the run model after the experiment has run.
    exit_code: EverestRunModelExitCode | OptimizerExitCode | None = None
    # Additional detail; the runner stores the formatted traceback here
    # when the experiment raises.
    message: str | None = None


Expand All @@ -177,7 +177,7 @@ def __init__(self, everest_config, state: dict):

self.everest_config = everest_config
self.state = state
self.exit_code = None
self.status: Optional[ExperimentRunnerStatus] = None

def run(self):
run_model = EverestRunModel.create(
Expand All @@ -194,12 +194,16 @@ def run(self):

try:
run_model.run_experiment(evaluator_server_config)
self.exit_code = ExitCode(exit_code=run_model.exit_code)
self.status = ExperimentRunnerStatus(
status="Experiment finished", exit_code=run_model.exit_code
)
except Exception:
self.exit_code = ExitCode(message=traceback.format_exc())
self.status = ExperimentRunnerStatus(
status="Experiment failed", message=traceback.format_exc()
)

def get_exit_code(self) -> Optional[ExitCode]:
    """Return the exit-code record stored on this runner.

    The attribute is returned as-is, so the result is ``None`` when no
    record has been stored.
    """
    recorded = self.exit_code
    return recorded
def get_status(self) -> Optional[ExperimentRunnerStatus]:
    """Return the status record stored on this runner.

    The attribute is returned as-is, so the result is ``None`` when no
    status has been stored.
    """
    current = self.status
    return current


security = HTTPBasic()
Expand All @@ -226,9 +230,9 @@ def __init__(self, output_dir: str, optimization_output_dir: str):
self.router.add_api_route(
"/" + START_ENDPOINT, self.start_experiment, methods=["POST"]
)
self.router.add_api_route(
"/" + EXIT_CODE_ENDPOINT, self.get_exit_code, methods=["GET"]
)
# self.router.add_api_route(
# "/" + EXIT_CODE_ENDPOINT, self.get_exit_code, methods=["GET"]
# )
self.router.add_api_route(
"/" + SHARED_DATA_ENDPOINT, self.get_state, methods=["GET"]
)
Expand Down Expand Up @@ -284,10 +288,25 @@ def _log(self, request: Request) -> None:

def get_status(
    self, request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
    """Report the current server / experiment status as JSON.

    The request is logged and the credentials are checked first. The
    response body is always a serialized ``ExperimentRunnerStatus``:

    * ``"Everest server stopped"`` when the stop flag in the shared state
      is set;
    * ``"Everest server is running"`` when there is no runner yet, or the
      runner has not stored a status;
    * otherwise the status stored by the runner.
    """
    self._log(request)
    self._check_user(credentials)

    if self.state[STOP_ENDPOINT]:
        status = ExperimentRunnerStatus(status="Everest server stopped")
    else:
        # The runner's status may still be None; encoding that would produce
        # a JSON ``null`` body, which is not a valid ExperimentRunnerStatus
        # object for clients that parse the response into that model.
        status = self.runner.get_status() if self.runner else None
        if status is None:
            status = ExperimentRunnerStatus(status="Everest server is running")

    return JSONResponse(jsonable_encoder(status))

def stop(
self, request: Request, credentials: HTTPBasicCredentials = Depends(security)
Expand All @@ -305,30 +324,6 @@ def get_sim_progress(
progress = self.state[SIM_PROGRESS_ENDPOINT]
return JSONResponse(jsonable_encoder(progress))

def get_exit_code(
    self, request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
    """Return the experiment's exit-code record as JSON.

    The request is logged and the credentials are checked first. The body
    is then chosen as follows:

    * an ``ExitCode`` with the message "Everest server stopped" when the
      stop flag in the shared state equals ``True``;
    * an empty object when there is no runner, or the runner has no
      exit-code record;
    * otherwise the runner's exit-code record.
    """
    self._log(request)
    self._check_user(credentials)

    if self.state[STOP_ENDPOINT] == True:
        payload = ExitCode(message="Everest server stopped")
    elif not self.runner:
        payload = {}
    else:
        # Fall back to an empty object while the runner has nothing to report.
        payload = self.runner.get_exit_code() or {}

    return JSONResponse(jsonable_encoder(payload))

def get_opt_progress(
self, request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
Expand Down
25 changes: 15 additions & 10 deletions src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
ServerStatus,
update_everserver_status,
)
from everest.detached.jobs.everest_server_api import EverestServerAPI, ExitCode
from everest.detached.jobs.everest_server_api import (
EverestServerAPI,
ExperimentRunnerStatus,
)
from everest.export import check_for_errors
from everest.simulator import JOB_FAILURE
from everest.strings import (
DEFAULT_LOGGING_FORMAT,
EVEREST,
EXIT_CODE_ENDPOINT,
OPT_FAILURE_REALIZATIONS,
SHARED_DATA_ENDPOINT,
SIM_PROGRESS_ENDPOINT,
Expand Down Expand Up @@ -185,24 +187,27 @@ def main():
is_done = False
while not is_done:
resp: requests.Response = requests.get(
url + "/" + EXIT_CODE_ENDPOINT,
url + "/",
verify=cert,
auth=auth,
proxies=PROXY, # type: ignore
)
exit_code = ExitCode.model_validate_json(
server_status = ExperimentRunnerStatus.model_validate_json(
resp.text if hasattr(resp, "text") else resp.body
)
if exit_code.exit_code or exit_code.message:
is_done = True
else:
if (
server_status.message
and "Everest server is running" in server_status.message
):
time.sleep(1)
else:
is_done = True

if exit_code.message and exit_code.message != "Everest server stopped":
if server_status.message:
update_everserver_status(
status_path,
ServerStatus.failed,
message=exit_code.message,
message=server_status.message,
)
return

Expand All @@ -217,7 +222,7 @@ def main():
):
shared_data = json_body

status, message = _get_optimization_status(exit_code.exit_code, shared_data)
status, message = _get_optimization_status(server_status.exit_code, shared_data)
if status != ServerStatus.completed:
update_everserver_status(status_path, status, message)
return
Expand Down

0 comments on commit ce686fd

Please sign in to comment.