diff --git a/src/ert/run_models/everest_run_model.py b/src/ert/run_models/everest_run_model.py index 8d7f9fdef12..cfa316b9274 100644 --- a/src/ert/run_models/everest_run_model.py +++ b/src/ert/run_models/everest_run_model.py @@ -11,12 +11,14 @@ from collections import defaultdict from collections.abc import Callable, Mapping from dataclasses import dataclass +from enum import IntEnum from pathlib import Path from types import TracebackType from typing import ( TYPE_CHECKING, Any, - Literal, + Callable, + Mapping, Protocol, ) @@ -147,6 +149,10 @@ def from_seba_optimal_result( ) +class EverestRunModelExitCode(IntEnum): + MAX_BATCH_NUM_REACHED = 1 + + class EverestRunModel(BaseRunModel): def __init__( self, @@ -175,9 +181,7 @@ def __init__( ) self._display_all_jobs = display_all_jobs self._result: OptimalResult | None = None - self._exit_code: Literal["max_batch_num_reached"] | OptimizerExitCode | None = ( - None - ) + self._exit_code: EverestRunModelExitCode | OptimizerExitCode | None = None self._max_batch_num_reached = False self._simulator_cache: SimulatorCache | None = None if ( @@ -285,7 +289,7 @@ def run_experiment( ) self._exit_code = ( - "max_batch_num_reached" + EverestRunModelExitCode.MAX_BATCH_NUM_REACHED if self._max_batch_num_reached else optimizer_exit_code ) @@ -431,7 +435,7 @@ def description(cls) -> str: @property def exit_code( self, - ) -> Literal["max_batch_num_reached"] | OptimizerExitCode | None: + ) -> EverestRunModelExitCode | OptimizerExitCode | None: return self._exit_code @property diff --git a/src/everest/bin/everest_script.py b/src/everest/bin/everest_script.py index 4cc91ae7f59..a93502a9ca5 100755 --- a/src/everest/bin/everest_script.py +++ b/src/everest/bin/everest_script.py @@ -9,12 +9,12 @@ import threading from functools import partial -from ert.run_models.everest_run_model import EverestRunModel from everest.config import EverestConfig, ServerConfig from everest.detached import ( ServerStatus, everserver_status, server_is_running, + start_experiment, start_server, wait_for_server, ) @@ -114,7 +114,11 @@ async def run_everest(options): except ValueError as exc: raise SystemExit(f"Config validation error: {exc}") from exc - if EverestRunModel.create(options.config).check_if_runpath_exists(): + if ( + options.config.simulation_dir is not None + and os.path.exists(options.config.simulation_dir) + and any(os.listdir(options.config.simulation_dir)) + ): warn_user_that_runpath_is_nonempty() try: @@ -128,6 +132,12 @@ async def run_everest(options): print("Waiting for server ...") wait_for_server(options.config.output_dir, timeout=600) print("Everest server found!") + + start_experiment( + server_context=ServerConfig.get_server_context(options.config.output_dir), + config=options.config, + ) + run_detached_monitor( server_context=ServerConfig.get_server_context(options.config.output_dir), optimization_output_dir=options.config.optimization_output_dir, diff --git a/src/everest/detached/__init__.py b/src/everest/detached/__init__.py index 503ea932928..e2c1b3bbda7 100644 --- a/src/everest/detached/__init__.py +++ b/src/everest/detached/__init__.py @@ -36,6 +36,7 @@ OPT_PROGRESS_ID, SIM_PROGRESS_ENDPOINT, SIM_PROGRESS_ID, + START_ENDPOINT, STOP_ENDPOINT, ) @@ -52,6 +53,25 @@ # everest.log file instead +def start_experiment( + server_context: Tuple[str, str, Tuple[str, str]], + config: EverestConfig, +) -> None: + try: + url, cert, auth = server_context + start_endpoint = "/".join([url, START_ENDPOINT]) + response = requests.post( + start_endpoint, + verify=cert, + 
auth=auth, + proxies=PROXY, # type: ignore + json=config.to_dict(), + ) + response.raise_for_status() + except: + raise ValueError("Failed to start experiment") from None + + async def start_server(config: EverestConfig, debug: bool = False) -> Driver: """ Start an Everest server running the optimization defined in the config diff --git a/src/everest/detached/jobs/everest_server_api.py b/src/everest/detached/jobs/everest_server_api.py new file mode 100644 index 00000000000..8ff2b15e23a --- /dev/null +++ b/src/everest/detached/jobs/everest_server_api.py @@ -0,0 +1,351 @@ +import json +import logging +import os +import socket +import ssl +import threading +import traceback +from base64 import b64encode +from datetime import datetime, timedelta +from functools import partial +from typing import Optional + +import uvicorn +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status +from fastapi.encoders import jsonable_encoder +from fastapi.responses import ( + JSONResponse, + Response, +) +from fastapi.security import ( + HTTPBasic, + HTTPBasicCredentials, +) +from pydantic import BaseModel +from ropt.enums import OptimizerExitCode + +from ert.config import QueueSystem +from ert.ensemble_evaluator import EvaluatorServerConfig +from ert.run_models.everest_run_model import EverestRunModel, EverestRunModelExitCode +from ert.shared import get_machine_name as ert_shared_get_machine_name +from everest.config import EverestConfig, ServerConfig +from everest.detached import get_opt_status +from everest.strings import ( + OPT_PROGRESS_ENDPOINT, + SHARED_DATA_ENDPOINT, + SIM_PROGRESS_ENDPOINT, + START_ENDPOINT, + STOP_ENDPOINT, +) +from everest.util import makedirs_if_needed + + +def _find_open_port(host: str, lower: int, upper: int) -> int: + for port in range(lower, upper): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind((host, port)) + sock.close() + return port + except socket.error: + logging.getLogger("everserver").info( + "Port {} for host {} is taken".format(port, host) + ) + msg = "No open port for host {} in the range {}-{}".format(host, lower, upper) + logging.getLogger("everserver").exception(msg) + raise Exception(msg) + + +def _write_hostfile(host_file_path, host, port, cert, auth) -> None: + if not os.path.exists(os.path.dirname(host_file_path)): + os.makedirs(os.path.dirname(host_file_path)) + data = { + "host": host, + "port": port, + "cert": cert, + "auth": auth, + } + json_string = json.dumps(data) + + with open(host_file_path, "w", encoding="utf-8") as f: + f.write(json_string) + + +def _generate_authentication() -> str: + n_bytes = 128 + random_bytes = bytes(os.urandom(n_bytes)) + return b64encode(random_bytes).decode("utf-8") + + +def _generate_certificate(cert_folder: str): + """Generate a private key and a certificate signed with it + + Both the certificate and the key are written to files in the folder given + by `get_certificate_dir(config)`. The key is encrypted before being + stored. 
+ Returns the path to the certificate file, the path to the key file, and + the password used for encrypting the key + """ + # Generate private key + key = rsa.generate_private_key( + public_exponent=65537, key_size=4096, backend=default_backend() + ) + + # Generate the certificate and sign it with the private key + cert_name = ert_shared_get_machine_name() + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "NO"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Bergen"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "Sandsli"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Equinor"), + x509.NameAttribute(NameOID.COMMON_NAME, "{}".format(cert_name)), + ] + ) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.utcnow()) + .not_valid_after(datetime.utcnow() + timedelta(days=365)) # 1 year + .add_extension( + x509.SubjectAlternativeName([x509.DNSName("{}".format(cert_name))]), + critical=False, + ) + .sign(key, hashes.SHA256(), default_backend()) + ) + + # Write certificate and key to disk + makedirs_if_needed(cert_folder) + cert_path = os.path.join(cert_folder, cert_name + ".crt") + with open(cert_path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + key_path = os.path.join(cert_folder, cert_name + ".key") + pw = bytes(os.urandom(28)) + with open(key_path, "wb") as f: + f.write( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.BestAvailableEncryption(pw), + ) + ) + return cert_path, key_path, pw + + +def _sim_monitor(context_status, event=None, shared_data=None): + status = context_status["status"] + assert shared_data + shared_data[SIM_PROGRESS_ENDPOINT] = { + "batch_number": context_status["batch_number"], + "status": { + "running": status.get("Running", 0), + "waiting": status.get("Waiting", 0), + "pending": status.get("Pending", 0), + "complete": status.get("Finished", 0), + "failed": status.get("Failed", 0), + }, + "progress": context_status["progress"], + } + + if shared_data[STOP_ENDPOINT]: + return "stop_queue" + + +def _opt_monitor(shared_data=None): + assert shared_data + if shared_data[STOP_ENDPOINT]: + return "stop_optimization" + + +class ExperimentRunnerStatus(BaseModel): + status: Optional[str] = None + exit_code: Optional[EverestRunModelExitCode | OptimizerExitCode] = None + message: Optional[str] = None + + +class ExperimentRunner(threading.Thread): + def __init__(self, everest_config, state: dict): + super().__init__() + + self.everest_config = everest_config + self.state = state + self.status: Optional[ExperimentRunnerStatus] = None + + def run(self): + run_model = EverestRunModel.create( + self.everest_config, + simulation_callback=partial(_sim_monitor, shared_data=self.state), + optimization_callback=partial(_opt_monitor, shared_data=self.state), + ) + + evaluator_server_config = EvaluatorServerConfig( + custom_port_range=range(49152, 51819) + if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL + else None + ) + + try: + run_model.run_experiment(evaluator_server_config) + self.status = ExperimentRunnerStatus( + status="Experiment finished", exit_code=run_model.exit_code + ) + except Exception: + self.status = ExperimentRunnerStatus( + status="Experiment failed", message=traceback.format_exc() + ) + + def get_status(self) -> 
Optional[ExperimentRunnerStatus]: + return self.status + + +security = HTTPBasic() + + +class EverestServerAPI(threading.Thread): + def __init__(self, output_dir: str, optimization_output_dir: str): + super().__init__() + + self.output_dir = output_dir + self.optimization_output_dir = optimization_output_dir + + self.app = FastAPI() + + self.router = APIRouter() + self.router.add_api_route("/", self.get_status, methods=["GET"]) + self.router.add_api_route("/" + STOP_ENDPOINT, self.stop, methods=["POST"]) + self.router.add_api_route( + "/" + SIM_PROGRESS_ENDPOINT, self.get_sim_progress, methods=["GET"] + ) + self.router.add_api_route( + "/" + OPT_PROGRESS_ENDPOINT, self.get_opt_progress, methods=["GET"] + ) + self.router.add_api_route( + "/" + START_ENDPOINT, self.start_experiment, methods=["POST"] + ) + self.router.add_api_route( + "/" + SHARED_DATA_ENDPOINT, self.get_state, methods=["GET"] + ) + + self.app.include_router(self.router) + + self.state = { + SIM_PROGRESS_ENDPOINT: {}, + STOP_ENDPOINT: False, + } + + self.runner: Optional[ExperimentRunner] = None + + # same code is in ensemble evaluator + self.authentication = _generate_authentication() + + # same code is in ensemble evaluator + self.cert_path, self.key_path, self.key_pw = _generate_certificate( + ServerConfig.get_certificate_dir(self.output_dir) + ) + self.host = ert_shared_get_machine_name() + self.port = _find_open_port(self.host, lower=5000, upper=5800) + + host_file = ServerConfig.get_hostfile_path(self.output_dir) + _write_hostfile( + host_file, self.host, self.port, self.cert_path, self.authentication + ) + + def run(self): + uvicorn.run( + self.app, + host="0.0.0.0", + port=self.port, + ssl_keyfile=self.key_path, + ssl_certfile=self.cert_path, + ssl_version=ssl.PROTOCOL_SSLv23, + ssl_keyfile_password=self.key_pw, + log_level=logging.CRITICAL, + ) + + def _check_user(self, credentials: HTTPBasicCredentials) -> None: + if credentials.password != self.authentication: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid credentials", + headers={"WWW-Authenticate": "Basic"}, + ) + + def _log(self, request: Request) -> None: + logging.getLogger("everserver").info( + f"{request.scope['path']} entered from {request.client.host if request.client else 'unknown host'} with HTTP {request.method}" + ) + + def get_status( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> JSONResponse: + self._log(request) + self._check_user(credentials) + + if self.state[STOP_ENDPOINT] == True: + return JSONResponse( + jsonable_encoder( + ExperimentRunnerStatus(status="Everest server stopped") + ) + ) + + if not self.runner: + return JSONResponse( + jsonable_encoder( + ExperimentRunnerStatus(status="Everest server is running") + ) + ) + + return JSONResponse(jsonable_encoder(self.runner.get_status())) + + def stop( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> Response: + self._log(request) + self._check_user(credentials) + self.state[STOP_ENDPOINT] = True + return Response("Raise STOP flag succeeded. 
Everest initiates shutdown..", 200) + + def get_sim_progress( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> JSONResponse: + self._log(request) + self._check_user(credentials) + progress = self.state[SIM_PROGRESS_ENDPOINT] + return JSONResponse(jsonable_encoder(progress)) + + def get_opt_progress( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> JSONResponse: + self._log(request) + self._check_user(credentials) + progress = get_opt_status(self.optimization_output_dir) + return JSONResponse(jsonable_encoder(progress)) + + def start_experiment( + self, + config: EverestConfig, + request: Request, + credentials: HTTPBasicCredentials = Depends(security), + ) -> Response: + self._log(request) + self._check_user(credentials) + + self.runner = ExperimentRunner(config, self.state) + self.runner.start() + + return Response("Everest experiment started", 200) + + def get_state( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> JSONResponse: + self._log(request) + self._check_user(credentials) + return JSONResponse(jsonable_encoder(self.state)) diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index 936ccedf236..12388ab6f6c 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -1,199 +1,78 @@ import argparse import json import logging -import os -import socket -import ssl -import threading +import time import traceback -from base64 import b64encode -from datetime import datetime, timedelta -from functools import partial from pathlib import Path from typing import Any -import uvicorn -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.x509.oid import NameOID -from dns import resolver, reversename -from fastapi import Depends, FastAPI, HTTPException, Request, status -from fastapi.encoders import jsonable_encoder -from fastapi.responses import ( - JSONResponse, - PlainTextResponse, - Response, -) -from fastapi.security import ( - HTTPBasic, - HTTPBasicCredentials, -) +import requests from ropt.enums import OptimizerExitCode -from ert.config import QueueSystem -from ert.ensemble_evaluator import EvaluatorServerConfig -from ert.run_models.everest_run_model import EverestRunModel from everest import export_to_csv, export_with_progress from everest.config import EverestConfig, ServerConfig -from everest.detached import ServerStatus, get_opt_status, update_everserver_status +from everest.detached import ( + PROXY, + ServerStatus, + update_everserver_status, +) +from everest.detached.jobs.everest_server_api import ( + EverestServerAPI, + ExperimentRunnerStatus, +) from everest.export import check_for_errors from everest.simulator import JOB_FAILURE from everest.strings import ( DEFAULT_LOGGING_FORMAT, EVEREST, OPT_FAILURE_REALIZATIONS, - OPT_PROGRESS_ENDPOINT, + SHARED_DATA_ENDPOINT, SIM_PROGRESS_ENDPOINT, STOP_ENDPOINT, ) from everest.util import get_azure_logging_handler, makedirs_if_needed, version_info -def _get_machine_name() -> str: - """Returns a name that can be used to identify this machine in a network - - A fully qualified domain name is returned if available. 
Otherwise returns - the string `localhost` - """ - hostname = socket.gethostname() - try: - # We need the ip-address to perform a reverse lookup to deal with - # differences in how the clusters are getting their fqdn's - ip_addr = socket.gethostbyname(hostname) - reverse_name = reversename.from_address(ip_addr) - resolved_hosts = [ - str(ptr_record).rstrip(".") - for ptr_record in resolver.resolve(reverse_name, "PTR") - ] - resolved_hosts.sort() - return resolved_hosts[0] - except (resolver.NXDOMAIN, resolver.NoResolverConfiguration): - # If local address and reverse lookup not working - fallback - # to socket fqdn which are using /etc/hosts to retrieve this name - return socket.getfqdn() - except socket.gaierror: - logging.debug(traceback.format_exc()) - return "localhost" - - -def _sim_monitor(context_status, shared_data=None): - status = context_status["status"] - shared_data[SIM_PROGRESS_ENDPOINT] = { - "batch_number": context_status["batch_number"], - "status": { - "running": status.get("Running", 0), - "waiting": status.get("Waiting", 0), - "pending": status.get("Pending", 0), - "complete": status.get("Finished", 0), - "failed": status.get("Failed", 0), - }, - "progress": context_status["progress"], - } +def _get_optimization_status(exit_code, shared_data): + if exit_code == "max_batch_num_reached": + return ServerStatus.completed, "Maximum number of batches reached." - if shared_data[STOP_ENDPOINT]: - return "stop_queue" + if exit_code == OptimizerExitCode.MAX_FUNCTIONS_REACHED: + return ServerStatus.completed, "Maximum number of function evaluations reached." + if exit_code == OptimizerExitCode.USER_ABORT: + return ServerStatus.stopped, "Optimization aborted." -def _opt_monitor(shared_data=None): - if shared_data[STOP_ENDPOINT]: - return "stop_optimization" + if exit_code == OptimizerExitCode.TOO_FEW_REALIZATIONS: + status = ( + ServerStatus.stopped if shared_data[STOP_ENDPOINT] else ServerStatus.failed + ) + messages = _failed_realizations_messages(shared_data) + for msg in messages: + logging.getLogger(EVEREST).error(msg) + return status, "\n".join(messages) + return ServerStatus.completed, "Optimization completed." -def _everserver_thread(shared_data, server_config) -> None: - app = FastAPI() - security = HTTPBasic() - def _check_user(credentials: HTTPBasicCredentials) -> None: - if credentials.password != server_config["authentication"]: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid credentials", - headers={"WWW-Authenticate": "Basic"}, +def _failed_realizations_messages(shared_data): + messages = [OPT_FAILURE_REALIZATIONS] + failed = shared_data[SIM_PROGRESS_ENDPOINT]["status"]["failed"] + if failed > 0: + # Find the set of jobs that failed. To keep the order in which they + # are found in the queue, use a dict as sets are not ordered. 
+ failed_jobs = dict.fromkeys( + ( + job["name"] + for queue in shared_data[SIM_PROGRESS_ENDPOINT]["progress"] + for job in queue + if job["status"] == JOB_FAILURE ) - - def _log(request: Request) -> None: - logging.getLogger("everserver").info( - f"{request.scope['path']} entered from {request.client.host if request.client else 'unknown host'} with HTTP {request.method}" + ).keys() + messages.append( + "{} job failures caused by: {}".format(failed, ", ".join(failed_jobs)) ) - - @app.get("/") - def get_status( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> PlainTextResponse: - _log(request) - _check_user(credentials) - return PlainTextResponse("Everest is running") - - @app.post("/" + STOP_ENDPOINT) - def stop( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> Response: - _log(request) - _check_user(credentials) - shared_data[STOP_ENDPOINT] = True - return Response("Raise STOP flag succeeded. Everest initiates shutdown..", 200) - - @app.get("/" + SIM_PROGRESS_ENDPOINT) - def get_sim_progress( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> JSONResponse: - _log(request) - _check_user(credentials) - progress = shared_data[SIM_PROGRESS_ENDPOINT] - return JSONResponse(jsonable_encoder(progress)) - - @app.get("/" + OPT_PROGRESS_ENDPOINT) - def get_opt_progress( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> JSONResponse: - _log(request) - _check_user(credentials) - progress = get_opt_status(server_config["optimization_output_dir"]) - return JSONResponse(jsonable_encoder(progress)) - - uvicorn.run( - app, - host="0.0.0.0", - port=server_config["port"], - ssl_keyfile=server_config["key_path"], - ssl_certfile=server_config["cert_path"], - ssl_version=ssl.PROTOCOL_SSLv23, - ssl_keyfile_password=server_config["key_passwd"], - log_level=logging.CRITICAL, - ) - - -def _find_open_port(host, lower, upper) -> int: - for port in range(lower, upper): - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind((host, port)) - sock.close() - return port - except OSError: - logging.getLogger("everserver").info( - f"Port {port} for host {host} is taken" - ) - msg = f"No open port for host {host} in the range {lower}-{upper}" - logging.getLogger("everserver").exception(msg) - raise Exception(msg) - - -def _write_hostfile(host_file_path, host, port, cert, auth) -> None: - if not os.path.exists(os.path.dirname(host_file_path)): - os.makedirs(os.path.dirname(host_file_path)) - data = { - "host": host, - "port": port, - "cert": cert, - "auth": auth, - } - json_string = json.dumps(data) - - with open(host_file_path, "w", encoding="utf-8") as f: - f.write(json_string) + return messages def _configure_loggers(detached_dir: Path, log_dir: Path, logging_level: int) -> None: @@ -251,7 +130,6 @@ def main(): if options.debug: config.logging_level = "debug" status_path = ServerConfig.get_everserver_status_path(config.output_dir) - host_file = ServerConfig.get_hostfile_path(config.output_dir) try: _configure_loggers( @@ -269,34 +147,21 @@ def main(): logging.getLogger(EVEREST).info(f"Output directory: {config.output_dir}") logging.getLogger(EVEREST).debug(str(options)) - authentication = _generate_authentication() - cert_path, key_path, key_pw = _generate_certificate( - ServerConfig.get_certificate_dir(config.output_dir) - ) - host = _get_machine_name() - port = _find_open_port(host, lower=5000, upper=5800) - _write_hostfile(host_file, host, port, cert_path, authentication) - 
shared_data = { SIM_PROGRESS_ENDPOINT: {}, STOP_ENDPOINT: False, } - server_config = { - "optimization_output_dir": config.optimization_output_dir, - "port": port, - "cert_path": cert_path, - "key_path": key_path, - "key_passwd": key_pw, - "authentication": authentication, - } - - everserver_instance = threading.Thread( - target=_everserver_thread, - args=(shared_data, server_config), + everest_server_api = EverestServerAPI( + output_dir=config.output_dir, + optimization_output_dir=config.optimization_output_dir, ) - everserver_instance.daemon = True - everserver_instance.start() + everest_server_api.daemon = True + everest_server_api.start() + + server_context = (ServerConfig.get_server_context(config.output_dir),) + url, cert, auth = server_context[0] + except: update_everserver_status( status_path, @@ -306,26 +171,60 @@ def main(): return try: + # wait until the api server is running + is_running = False + while not is_running: + try: + requests.get(url + "/", verify=cert, auth=auth, proxies=PROXY) # type: ignore + is_running = True + except: + time.sleep(1) + update_everserver_status(status_path, ServerStatus.running) - run_model = EverestRunModel.create( - config, - simulation_callback=partial(_sim_monitor, shared_data=shared_data), - optimization_callback=partial(_opt_monitor, shared_data=shared_data), - ) + is_done = False + while not is_done: + resp: requests.Response = requests.get( + url + "/", + verify=cert, + auth=auth, + proxies=PROXY, # type: ignore + ) + server_status = ExperimentRunnerStatus.model_validate_json( + resp.text if hasattr(resp, "text") else resp.body + ) + if ( + server_status.message + and "Everest server is running" in server_status.message + ): + time.sleep(1) + else: + is_done = True + + if server_status.message: + update_everserver_status( + status_path, + ServerStatus.failed, + message=server_status.message, + ) + return - evaluator_server_config = EvaluatorServerConfig( - custom_port_range=range(49152, 51819) - if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL - else None + response: requests.Response = requests.get( + url + "/" + SHARED_DATA_ENDPOINT, + verify=cert, + auth=auth, + proxies=PROXY, # type: ignore ) + if json_body := json.loads( + response.text if hasattr(response, "text") else response.body + ): + shared_data = json_body - run_model.run_experiment(evaluator_server_config) - - status, message = _get_optimization_status(run_model.exit_code, shared_data) + status, message = _get_optimization_status(server_status.exit_code, shared_data) if status != ServerStatus.completed: update_everserver_status(status_path, status, message) return + except: if shared_data[STOP_ENDPOINT]: update_everserver_status( @@ -370,107 +269,3 @@ def main(): return update_everserver_status(status_path, ServerStatus.completed, message=message) - - -def _get_optimization_status(exit_code, shared_data): - if exit_code == "max_batch_num_reached": - return ServerStatus.completed, "Maximum number of batches reached." - - if exit_code == OptimizerExitCode.MAX_FUNCTIONS_REACHED: - return ServerStatus.completed, "Maximum number of function evaluations reached." - - if exit_code == OptimizerExitCode.USER_ABORT: - return ServerStatus.stopped, "Optimization aborted." 
- - if exit_code == OptimizerExitCode.TOO_FEW_REALIZATIONS: - status = ( - ServerStatus.stopped if shared_data[STOP_ENDPOINT] else ServerStatus.failed - ) - messages = _failed_realizations_messages(shared_data) - for msg in messages: - logging.getLogger(EVEREST).error(msg) - return status, "\n".join(messages) - - return ServerStatus.completed, "Optimization completed." - - -def _failed_realizations_messages(shared_data): - messages = [OPT_FAILURE_REALIZATIONS] - failed = shared_data[SIM_PROGRESS_ENDPOINT]["status"]["failed"] - if failed > 0: - # Find the set of jobs that failed. To keep the order in which they - # are found in the queue, use a dict as sets are not ordered. - failed_jobs = dict.fromkeys( - job["name"] - for queue in shared_data[SIM_PROGRESS_ENDPOINT]["progress"] - for job in queue - if job["status"] == JOB_FAILURE - ).keys() - messages.append( - "{} job failures caused by: {}".format(failed, ", ".join(failed_jobs)) - ) - return messages - - -def _generate_certificate(cert_folder: str): - """Generate a private key and a certificate signed with it - - Both the certificate and the key are written to files in the folder given - by `get_certificate_dir(config)`. The key is encrypted before being - stored. - Returns the path to the certificate file, the path to the key file, and - the password used for encrypting the key - """ - # Generate private key - key = rsa.generate_private_key( - public_exponent=65537, key_size=4096, backend=default_backend() - ) - - # Generate the certificate and sign it with the private key - cert_name = _get_machine_name() - subject = issuer = x509.Name( - [ - x509.NameAttribute(NameOID.COUNTRY_NAME, "NO"), - x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Bergen"), - x509.NameAttribute(NameOID.LOCALITY_NAME, "Sandsli"), - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Equinor"), - x509.NameAttribute(NameOID.COMMON_NAME, f"{cert_name}"), - ] - ) - cert = ( - x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) - .public_key(key.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.utcnow()) - .not_valid_after(datetime.utcnow() + timedelta(days=365)) # 1 year - .add_extension( - x509.SubjectAlternativeName([x509.DNSName(f"{cert_name}")]), - critical=False, - ) - .sign(key, hashes.SHA256(), default_backend()) - ) - - # Write certificate and key to disk - makedirs_if_needed(cert_folder) - cert_path = os.path.join(cert_folder, cert_name + ".crt") - with open(cert_path, "wb") as f: - f.write(cert.public_bytes(serialization.Encoding.PEM)) - key_path = os.path.join(cert_folder, cert_name + ".key") - pw = bytes(os.urandom(28)) - with open(key_path, "wb") as f: - f.write( - key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.BestAvailableEncryption(pw), - ) - ) - return cert_path, key_path, pw - - -def _generate_authentication(): - n_bytes = 128 - random_bytes = bytes(os.urandom(n_bytes)) - return b64encode(random_bytes).decode("utf-8") diff --git a/src/everest/strings.py b/src/everest/strings.py index 50be1da326b..791eca51ae6 100644 --- a/src/everest/strings.py +++ b/src/everest/strings.py @@ -29,5 +29,9 @@ SIMULATOR_END = "end" SIM_PROGRESS_ENDPOINT = "sim_progress" SIM_PROGRESS_ID = "simulation_progress" +START_ENDPOINT = "start" STOP_ENDPOINT = "stop" STORAGE_DIR = "simulation_results" +STATUS_ENDPOINT = "status" +SHARED_DATA_ENDPOINT = "shared_data" +EXIT_CODE_ENDPOINT = "exit_code" diff --git 
a/tests/everest/entry_points/test_everest_entry.py b/tests/everest/entry_points/test_everest_entry.py index d2efd7d81d6..ec2701aebc4 100644 --- a/tests/everest/entry_points/test_everest_entry.py +++ b/tests/everest/entry_points/test_everest_entry.py @@ -78,7 +78,9 @@ def run_detached_monitor_mock(status=ServerStatus.completed, error=None, **kwarg "everest.bin.everest_script.everserver_status", return_value={"status": ServerStatus.never_run, "message": None}, ) +@patch("everest.bin.everest_script.start_experiment") def test_everest_entry_debug( + mock_start_experiment, everserver_status_mock, start_server_mock, wait_for_server_mock, @@ -94,6 +96,7 @@ def test_everest_entry_debug( wait_for_server_mock.assert_called_once() start_monitor_mock.assert_called_once() everserver_status_mock.assert_called() + mock_start_experiment.assert_called() # the config file itself is dumped at DEBUG level assert '"controls"' in logstream @@ -109,7 +112,9 @@ def test_everest_entry_debug( "everest.bin.everest_script.everserver_status", return_value={"status": ServerStatus.never_run, "message": None}, ) +@patch("everest.bin.everest_script.start_experiment") def test_everest_entry( + mock_start_experiment, everserver_status_mock, start_server_mock, wait_for_server_mock, @@ -122,6 +127,7 @@ def test_everest_entry( wait_for_server_mock.assert_called_once() start_monitor_mock.assert_called_once() everserver_status_mock.assert_called() + mock_start_experiment.assert_called() @patch("everest.bin.everest_script.server_is_running", return_value=False) @@ -132,7 +138,9 @@ def test_everest_entry( "everest.bin.everest_script.everserver_status", return_value={"status": ServerStatus.completed, "message": None}, ) +@patch("everest.bin.everest_script.start_experiment") def test_everest_entry_detached_already_run( + mock_start_experiment, everserver_status_mock, start_server_mock, wait_for_server_mock, @@ -151,6 +159,7 @@ def test_everest_entry_detached_already_run( server_is_running_mock.assert_called_once() everserver_status_mock.assert_called() everserver_status_mock.reset_mock() + mock_start_experiment.assert_not_called() # stopping the server has no effect kill_entry([CONFIG_FILE_MINIMAL]) @@ -297,7 +306,9 @@ def test_everest_entry_monitor_no_run( "everest.bin.everest_script.everserver_status", return_value={"status": ServerStatus.never_run, "message": None}, ) +@patch("everest.bin.everest_script.start_experiment") def test_everest_entry_show_all_jobs( + mock_start_experiment, everserver_status_mock, get_opt_status_mock, get_server_context_mock, @@ -331,7 +342,9 @@ def test_everest_entry_show_all_jobs( "everest.bin.everest_script.everserver_status", return_value={"status": ServerStatus.never_run, "message": None}, ) +@patch("everest.bin.everest_script.start_experiment") def test_everest_entry_no_show_all_jobs( + mock_start_experiment, everserver_status_mock, get_opt_status_mock, get_server_context_mock, @@ -430,7 +443,9 @@ def test_monitor_entry_no_show_all_jobs( ) @patch("everest.bin.everest_script.wait_for_server") @patch("everest.bin.everest_script.start_server") +@patch("everest.bin.everest_script.start_experiment") def test_exception_raised_when_server_run_fails( + mock_start_experiment, start_server_mock, wait_for_server_mock, start_monitor_mock, @@ -462,7 +477,9 @@ def test_exception_raised_when_server_run_fails_monitor( ) @patch("everest.bin.everest_script.wait_for_server") @patch("everest.bin.everest_script.start_server") +@patch("everest.bin.everest_script.start_experiment") def 
test_complete_status_for_normal_run( + mock_start_experiment, start_server_mock, wait_for_server_mock, start_monitor_mock, diff --git a/tests/everest/test_detached.py b/tests/everest/test_detached.py index ed400f5b870..9903556233b 100644 --- a/tests/everest/test_detached.py +++ b/tests/everest/test_detached.py @@ -60,7 +60,7 @@ async def test_https_requests(copy_math_func_test_data_to_tmp): raise e server_status = everserver_status(status_path) - assert ServerStatus.running == server_status["status"] + assert server_status["status"] in [ServerStatus.running, ServerStatus.starting] url, cert, auth = ServerConfig.get_server_context(everest_config.output_dir) result = requests.get(url, verify=cert, auth=auth, proxies=PROXY) # noqa: ASYNC210 diff --git a/tests/everest/test_everest_output.py b/tests/everest/test_everest_output.py index 5287251aca4..81a534206c4 100644 --- a/tests/everest/test_everest_output.py +++ b/tests/everest/test_everest_output.py @@ -102,12 +102,13 @@ def useless_cb(*args, **kwargs): @patch("everest.bin.everest_script.server_is_running", return_value=False) @patch("everest.bin.everest_script.run_detached_monitor") @patch("everest.bin.everest_script.wait_for_server") +@patch("everest.bin.everest_script.start_experiment") @patch("everest.bin.everest_script.start_server") @patch( "everest.bin.everest_script.everserver_status", return_value={"status": ServerStatus.never_run, "message": None}, ) -def test_save_running_config(_, _1, _2, _3, _4, copy_math_func_test_data_to_tmp): +def test_save_running_config(_, _1, _2, _3, _4, _5, copy_math_func_test_data_to_tmp): """Test everest detached, when an optimization has already run""" # optimization already run, notify the user file_name = "config_minimal.yml" diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py index 3c501440eec..52599565b32 100644 --- a/tests/everest/test_everserver.py +++ b/tests/everest/test_everserver.py @@ -1,18 +1,27 @@ import json import os import ssl -from functools import partial from pathlib import Path from unittest.mock import patch +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse, PlainTextResponse from ropt.enums import OptimizerExitCode -from seba_sqlite.snapshot import SebaSnapshot from everest.config import EverestConfig, ServerConfig -from everest.detached import ServerStatus, everserver_status +from everest.detached import PROXY, ServerStatus, everserver_status from everest.detached.jobs import everserver +from everest.detached.jobs.everest_server_api import ( + ExitCode, + _generate_certificate, + _write_hostfile, +) from everest.simulator import JOB_FAILURE, JOB_SUCCESS -from everest.strings import OPT_FAILURE_REALIZATIONS, SIM_PROGRESS_ENDPOINT +from everest.strings import ( + OPT_FAILURE_REALIZATIONS, + SIM_PROGRESS_ENDPOINT, + STOP_ENDPOINT, +) def configure_everserver_logger(*args, **kwargs): @@ -54,7 +63,7 @@ def set_shared_status(*args, progress, shared_data): def test_certificate_generation(copy_math_func_test_data_to_tmp): config = EverestConfig.load_file("config_minimal.yml") - cert, key, pw = everserver._generate_certificate( + cert, key, pw = _generate_certificate( ServerConfig.get_certificate_dir(config.output_dir) ) @@ -77,7 +86,7 @@ def test_hostfile_storage(tmp_path, monkeypatch): "cert": "/a/b/c.cert", "auth": "1234", } - everserver._write_hostfile(host_file_path, **expected_result) + _write_hostfile(host_file_path, **expected_result) assert os.path.exists(host_file_path) with open(host_file_path, 
encoding="utf-8") as f: result = json.load(f) @@ -89,7 +98,7 @@ def test_hostfile_storage(tmp_path, monkeypatch): "everest.detached.jobs.everserver._configure_loggers", side_effect=configure_everserver_logger, ) -def test_everserver_status_failure(_1, copy_math_func_test_data_to_tmp): +def test_everserver_status_failure(mocked_logger, copy_math_func_test_data_to_tmp): config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) everserver.main() @@ -101,91 +110,103 @@ def test_everserver_status_failure(_1, copy_math_func_test_data_to_tmp): assert "Exception: Configuring logger failed" in status["message"] +import pytest +import requests + + +@pytest.mark.integration_test @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"]) @patch("everest.detached.jobs.everserver._configure_loggers") -@patch("everest.detached.jobs.everserver._generate_authentication") -@patch( - "everest.detached.jobs.everserver._generate_certificate", - return_value=(None, None, None), -) -@patch( - "everest.detached.jobs.everserver._find_open_port", - return_value=42, -) -@patch( - "everest.detached.jobs.everserver._write_hostfile", - side_effect=partial(check_status, status=ServerStatus.starting), -) -@patch("everest.detached.jobs.everserver._everserver_thread") -@patch( - "ert.run_models.everest_run_model.EverestRunModel.run_experiment", - autospec=True, - side_effect=lambda self, evaluator_server_config, restart=False: check_status( - ServerConfig.get_hostfile_path(self.everest_config.output_dir), - status=ServerStatus.running, - ), -) -@patch( - "everest.detached.jobs.everserver.check_for_errors", - return_value=([], False), -) @patch("everest.detached.jobs.everserver.export_to_csv") +@patch("requests.get") def test_everserver_status_running_complete( - _1, _2, _3, _4, _5, _6, _7, _8, _9, copy_math_func_test_data_to_tmp + mocked_get, mocked_export_to_csv, mocked_logger, copy_math_func_test_data_to_tmp ): config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) + + def mocked_server(url, verify, auth, proxies): + if "/exit_code" in url: + return JSONResponse( + jsonable_encoder( + ExitCode(exit_code=OptimizerExitCode.OPTIMIZER_STEP_FINISHED) + ) + ) + if "/shared_data" in url: + return JSONResponse( + jsonable_encoder( + { + SIM_PROGRESS_ENDPOINT: {}, + STOP_ENDPOINT: False, + } + ) + ) + + return PlainTextResponse("Everest is running") + + mocked_get.side_effect = mocked_server + everserver.main() + status = everserver_status( ServerConfig.get_everserver_status_path(config.output_dir) ) assert status["status"] == ServerStatus.completed - assert status["message"] == "Optimization completed." 
@patch("sys.argv", ["name", "--config-file", "config_minimal.yml"]) @patch("everest.detached.jobs.everserver._configure_loggers") -@patch("everest.detached.jobs.everserver._generate_authentication") -@patch( - "everest.detached.jobs.everserver._generate_certificate", - return_value=(None, None, None), -) -@patch( - "everest.detached.jobs.everserver._find_open_port", - return_value=42, -) -@patch("everest.detached.jobs.everserver._write_hostfile") -@patch("everest.detached.jobs.everserver._everserver_thread") -@patch( - "ert.run_models.everest_run_model.EverestRunModel.run_experiment", - autospec=True, - side_effect=lambda self, evaluator_server_config, restart=False: fail_optimization( - self, from_ropt=True - ), -) -@patch( - "everest.detached.jobs.everserver._sim_monitor", - side_effect=partial( - set_shared_status, - progress=[ - [ - {"name": "job1", "status": JOB_FAILURE}, - {"name": "job1", "status": JOB_FAILURE}, - ], - [ - {"name": "job2", "status": JOB_SUCCESS}, - {"name": "job2", "status": JOB_FAILURE}, - ], - ], - ), -) +@patch("requests.get") +@patch("requests.post") def test_everserver_status_failed_job( - _1, _2, _3, _4, _5, _6, _7, _8, copy_math_func_test_data_to_tmp + mocked_post, + mocked_get, + mocked_logger, + copy_math_func_test_data_to_tmp, ): config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) + + def mocked_server(url, verify, auth, proxies): + if "/exit_code" in url: + return JSONResponse( + jsonable_encoder( + ExitCode(exit_code=OptimizerExitCode.TOO_FEW_REALIZATIONS) + ) + ) + if "/shared_data" in url: + return JSONResponse( + jsonable_encoder( + { + SIM_PROGRESS_ENDPOINT: { + "status": {"failed": 3}, + "progress": [ + [ + {"name": "job1", "status": JOB_FAILURE}, + {"name": "job1", "status": JOB_FAILURE}, + ], + [ + {"name": "job2", "status": JOB_SUCCESS}, + {"name": "job2", "status": JOB_FAILURE}, + ], + ], + }, + STOP_ENDPOINT: False, + } + ) + ) + return PlainTextResponse("Everest is running") + + mocked_get.side_effect = mocked_server + + mocked_post.side_effect = lambda url, verify, auth, proxies: PlainTextResponse("") + everserver.main() + + url, cert, auth = ServerConfig.get_server_context(config.output_dir) + requests.post(url + "/start", verify=cert, auth=auth, proxies=PROXY) # type: ignore + status = everserver_status( ServerConfig.get_everserver_status_path(config.output_dir) ) @@ -198,34 +219,45 @@ def test_everserver_status_failed_job( @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"]) @patch("everest.detached.jobs.everserver._configure_loggers") -@patch("everest.detached.jobs.everserver._generate_authentication") -@patch( - "everest.detached.jobs.everserver._generate_certificate", - return_value=(None, None, None), -) -@patch( - "everest.detached.jobs.everserver._find_open_port", - return_value=42, -) -@patch("everest.detached.jobs.everserver._write_hostfile") -@patch("everest.detached.jobs.everserver._everserver_thread") -@patch( - "ert.run_models.everest_run_model.EverestRunModel.run_experiment", - autospec=True, - side_effect=lambda self, evaluator_server_config, restart=False: fail_optimization( - self, from_ropt=False - ), -) -@patch( - "everest.detached.jobs.everserver._sim_monitor", - side_effect=partial(set_shared_status, progress=[]), -) +@patch("requests.get") +@patch("requests.post") def test_everserver_status_exception( - _1, _2, _3, _4, _5, _6, _7, _8, copy_math_func_test_data_to_tmp + mocked_post, + mocked_get, + mocked_logger, + copy_math_func_test_data_to_tmp, ): config_file = 
"config_minimal.yml" config = EverestConfig.load_file(config_file) + + def mocked_server(url, verify, auth, proxies): + if "/exit_code" in url: + return JSONResponse( + jsonable_encoder(ExitCode(message="Exception: Failed optimization")) + ) + if "/shared_data" in url: + return JSONResponse( + jsonable_encoder( + { + SIM_PROGRESS_ENDPOINT: { + "status": {}, + "progress": [], + }, + STOP_ENDPOINT: False, + } + ) + ) + return PlainTextResponse("Everest is running") + + mocked_get.side_effect = mocked_server + + mocked_post.side_effect = lambda url, verify, auth, proxies: PlainTextResponse("") + everserver.main() + + url, cert, auth = ServerConfig.get_server_context(config.output_dir) + requests.post(url + "/start", verify=cert, auth=auth, proxies=PROXY) # type: ignore + status = everserver_status( ServerConfig.get_everserver_status_path(config.output_dir) ) @@ -234,40 +266,3 @@ def test_everserver_status_exception( # start_optimization raised. assert status["status"] == ServerStatus.failed assert "Exception: Failed optimization" in status["message"] - - -@patch("sys.argv", ["name", "--config-file", "config_one_batch.yml"]) -@patch("everest.detached.jobs.everserver._configure_loggers") -@patch("everest.detached.jobs.everserver._generate_authentication") -@patch( - "everest.detached.jobs.everserver._generate_certificate", - return_value=(None, None, None), -) -@patch( - "everest.detached.jobs.everserver._find_open_port", - return_value=42, -) -@patch("everest.detached.jobs.everserver._write_hostfile") -@patch("everest.detached.jobs.everserver._everserver_thread") -@patch( - "everest.detached.jobs.everserver._sim_monitor", - side_effect=partial(set_shared_status, progress=[]), -) -def test_everserver_status_max_batch_num( - _1, _2, _3, _4, _5, _6, _7, copy_math_func_test_data_to_tmp -): - config_file = "config_one_batch.yml" - config = EverestConfig.load_file(config_file) - everserver.main() - status = everserver_status( - ServerConfig.get_everserver_status_path(config.output_dir) - ) - - # The server should complete without error. - assert status["status"] == ServerStatus.completed - - # Check that there is only one batch. - snapshot = SebaSnapshot(config.optimization_output_dir).get_snapshot( - filter_out_gradient=False, batches=None - ) - assert {data.batch for data in snapshot.simulation_data} == {0} diff --git a/tests/everest/test_logging.py b/tests/everest/test_logging.py index e86337a5d0a..49210c86ad6 100644 --- a/tests/everest/test_logging.py +++ b/tests/everest/test_logging.py @@ -5,7 +5,11 @@ from ert.scheduler.event import FinishedEvent from everest.config import EverestConfig, ServerConfig -from everest.detached import start_server, wait_for_server +from everest.detached import ( + start_experiment, + start_server, + wait_for_server, +) from everest.util import makedirs_if_needed CONFIG_FILE = "config_fm_failure.yml" @@ -32,6 +36,10 @@ async def server_running(): driver = await start_server(everest_config, debug=True) try: wait_for_server(everest_config.output_dir, 60) + start_experiment( + server_context=ServerConfig.get_server_context(everest_config.output_dir), + config=everest_config, + ) except (SystemExit, RuntimeError) as e: raise e await server_running()