From bc30cafaf92a57e2e6a16e663ece49f0cfd157df Mon Sep 17 00:00:00 2001 From: Frode Aarstad Date: Wed, 27 Nov 2024 09:17:06 +0100 Subject: [PATCH] Split everserver functionality between starting server and submitting experiment --- .../detached/jobs/everest_server_api.py | 371 ++++++++++++++++ src/everest/detached/jobs/everserver.py | 395 ++++-------------- src/everest/strings.py | 4 + tests/everest/test_everserver.py | 91 +--- 4 files changed, 484 insertions(+), 377 deletions(-) create mode 100644 src/everest/detached/jobs/everest_server_api.py diff --git a/src/everest/detached/jobs/everest_server_api.py b/src/everest/detached/jobs/everest_server_api.py new file mode 100644 index 00000000000..384570cfb39 --- /dev/null +++ b/src/everest/detached/jobs/everest_server_api.py @@ -0,0 +1,371 @@ +import json +import logging +import os +import socket +import ssl +import threading +import traceback +from base64 import b64encode +from datetime import datetime, timedelta +from functools import partial +from typing import Any, Optional + +import uvicorn +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID +from dns import resolver, reversename +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status +from fastapi.encoders import jsonable_encoder +from fastapi.responses import ( + JSONResponse, + PlainTextResponse, + Response, +) +from fastapi.security import ( + HTTPBasic, + HTTPBasicCredentials, +) +from pydantic import BaseModel + +from ert.config import QueueSystem +from ert.ensemble_evaluator import EvaluatorServerConfig +from ert.run_models.everest_run_model import EverestRunModel +from everest.config import EverestConfig, ServerConfig +from everest.detached import get_opt_status +from everest.strings import ( + EXIT_CODE_ENDPOINT, + OPT_PROGRESS_ENDPOINT, + SHARED_DATA_ENDPOINT, + SIM_PROGRESS_ENDPOINT, + START_ENDPOINT, + STOP_ENDPOINT, +) +from everest.util import makedirs_if_needed + + +def _get_machine_name() -> str: + """Returns a name that can be used to identify this machine in a network + + A fully qualified domain name is returned if available. Otherwise returns + the string `localhost` + """ + hostname = socket.gethostname() + try: + # We need the ip-address to perform a reverse lookup to deal with + # differences in how the clusters are getting their fqdn's + ip_addr = socket.gethostbyname(hostname) + reverse_name = reversename.from_address(ip_addr) + resolved_hosts = [ + str(ptr_record).rstrip(".") + for ptr_record in resolver.resolve(reverse_name, "PTR") # type: ignore + ] + resolved_hosts.sort() + return resolved_hosts[0] + except (resolver.NXDOMAIN, resolver.NoResolverConfiguration): + # If local address and reverse lookup not working - fallback + # to socket fqdn which are using /etc/hosts to retrieve this name + return socket.getfqdn() + except socket.gaierror: + logging.debug(traceback.format_exc()) + return "localhost" + + +def _find_open_port(host: str, lower: int, upper: int) -> int: + for port in range(lower, upper): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind((host, port)) + sock.close() + return port + except socket.error: + logging.getLogger("everserver").info( + "Port {} for host {} is taken".format(port, host) + ) + msg = "No open port for host {} in the range {}-{}".format(host, lower, upper) + logging.getLogger("everserver").exception(msg) + raise Exception(msg) + + +def _write_hostfile(host_file_path, host, port, cert, auth) -> None: + if not os.path.exists(os.path.dirname(host_file_path)): + os.makedirs(os.path.dirname(host_file_path)) + data = { + "host": host, + "port": port, + "cert": cert, + "auth": auth, + } + json_string = json.dumps(data) + + with open(host_file_path, "w", encoding="utf-8") as f: + f.write(json_string) + + +def _generate_authentication() -> str: + n_bytes = 128 + random_bytes = bytes(os.urandom(n_bytes)) + return b64encode(random_bytes).decode("utf-8") + + +def _generate_certificate(cert_folder: str): + """Generate a private key and a certificate signed with it + + Both the certificate and the key are written to files in the folder given + by `get_certificate_dir(config)`. The key is encrypted before being + stored. + Returns the path to the certificate file, the path to the key file, and + the password used for encrypting the key + """ + # Generate private key + key = rsa.generate_private_key( + public_exponent=65537, key_size=4096, backend=default_backend() + ) + + # Generate the certificate and sign it with the private key + cert_name = _get_machine_name() + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "NO"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Bergen"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "Sandsli"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Equinor"), + x509.NameAttribute(NameOID.COMMON_NAME, "{}".format(cert_name)), + ] + ) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.utcnow()) + .not_valid_after(datetime.utcnow() + timedelta(days=365)) # 1 year + .add_extension( + x509.SubjectAlternativeName([x509.DNSName("{}".format(cert_name))]), + critical=False, + ) + .sign(key, hashes.SHA256(), default_backend()) + ) + + # Write certificate and key to disk + makedirs_if_needed(cert_folder) + cert_path = os.path.join(cert_folder, cert_name + ".crt") + with open(cert_path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + key_path = os.path.join(cert_folder, cert_name + ".key") + pw = bytes(os.urandom(28)) + with open(key_path, "wb") as f: + f.write( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.BestAvailableEncryption(pw), + ) + ) + return cert_path, key_path, pw + + +def _sim_monitor(context_status, event=None, shared_data=None): + status = context_status["status"] + assert shared_data + shared_data[SIM_PROGRESS_ENDPOINT] = { + "batch_number": context_status["batch_number"], + "status": { + "running": status.running, + "waiting": status.waiting, + "pending": status.pending, + "complete": status.complete, + "failed": status.failed, + }, + "progress": context_status["progress"], + "event": event, + } + + if shared_data[STOP_ENDPOINT]: + return "stop_queue" + + +def _opt_monitor(shared_data=None): + assert shared_data + if shared_data[STOP_ENDPOINT]: + return "stop_optimization" + + +class ExitCode(BaseModel): + exit_code: Optional[Any] = None + message: Optional[str] = None + + +class ExperimentRunner(threading.Thread): + def __init__(self, everest_config, state: dict): + super().__init__() + + self.everest_config = everest_config + self.state = state + self.exit_code = None + + def run(self): + run_model = EverestRunModel.create( + self.everest_config, + simulation_callback=partial(_sim_monitor, shared_data=self.state), + optimization_callback=partial(_opt_monitor, shared_data=self.state), + ) + + evaluator_server_config = EvaluatorServerConfig( + custom_port_range=range(49152, 51819) + if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL + else None + ) + + try: + run_model.run_experiment(evaluator_server_config) + self.exit_code = ExitCode(exit_code=run_model.exit_code) + except Exception: + self.exit_code = ExitCode(message=traceback.format_exc()) + + def get_exit_code(self) -> Optional[ExitCode]: + return self.exit_code + + +security = HTTPBasic() + + +class EverestServerAPI(threading.Thread): + def __init__(self, everest_config: EverestConfig): + super().__init__() + + self.app = FastAPI() + + self.router = APIRouter() + self.router.add_api_route("/", self.get_status, methods=["GET"]) + self.router.add_api_route("/" + STOP_ENDPOINT, self.stop, methods=["POST"]) + self.router.add_api_route( + "/" + SIM_PROGRESS_ENDPOINT, self.get_sim_progress, methods=["GET"] + ) + self.router.add_api_route( + "/" + OPT_PROGRESS_ENDPOINT, self.get_opt_progress, methods=["GET"] + ) + self.router.add_api_route( + "/" + START_ENDPOINT, self.start_experiment, methods=["POST"] + ) + self.router.add_api_route( + "/" + EXIT_CODE_ENDPOINT, self.get_exit_code, methods=["GET"] + ) + self.router.add_api_route( + "/" + SHARED_DATA_ENDPOINT, self.get_state, methods=["GET"] + ) + + self.app.include_router(self.router) + + self.state = { + SIM_PROGRESS_ENDPOINT: {}, + STOP_ENDPOINT: False, + } + + self.everest_config = everest_config + self.output_dir = everest_config.output_dir + self.optimization_output_dir = everest_config.optimization_output_dir + + # same code is in ensemble evaluator + self.authentication = _generate_authentication() + + # same code is in ensemble evaluator + self.cert_path, self.key_path, self.key_pw = _generate_certificate( + ServerConfig.get_certificate_dir(self.output_dir) + ) + self.host = _get_machine_name() + self.port = _find_open_port(self.host, lower=5000, upper=5800) + + host_file = ServerConfig.get_hostfile_path(self.output_dir) + _write_hostfile( + host_file, self.host, self.port, self.cert_path, self.authentication + ) + + def run(self): + uvicorn.run( + self.app, + host="0.0.0.0", + port=self.port, + ssl_keyfile=self.key_path, + ssl_certfile=self.cert_path, + ssl_version=ssl.PROTOCOL_SSLv23, + ssl_keyfile_password=self.key_pw, + ) + + def _check_user(self, credentials: HTTPBasicCredentials) -> None: + if credentials.password != self.authentication: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid credentials", + headers={"WWW-Authenticate": "Basic"}, + ) + + def _log(self, request: Request) -> None: + logging.getLogger("everserver").info( + f"{request.scope['path']} entered from {request.client.host if request.client else 'unknown host'} with HTTP {request.method}" + ) + + def get_status( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> PlainTextResponse: + self._log(request) + self._check_user(credentials) + return PlainTextResponse("Everest is running") + + def stop( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> Response: + self._log(request) + self._check_user(credentials) + self.state[STOP_ENDPOINT] = True + return Response("Raise STOP flag succeeded. Everest initiates shutdown..", 200) + + def get_sim_progress( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> JSONResponse: + self._log(request) + self._check_user(credentials) + progress = self.state[SIM_PROGRESS_ENDPOINT] + return JSONResponse(jsonable_encoder(progress)) + + def get_exit_code( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> JSONResponse: + self._log(request) + self._check_user(credentials) + return JSONResponse( + jsonable_encoder( + self.runner.get_exit_code() if self.runner.get_exit_code() else {} + ) + ) + + def get_opt_progress( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> JSONResponse: + self._log(request) + self._check_user(credentials) + progress = get_opt_status(self.optimization_output_dir) + return JSONResponse(jsonable_encoder(progress)) + + def start_experiment( + self, + request: Request, + credentials: HTTPBasicCredentials = Depends(security), + ) -> Response: + self._log(request) + self._check_user(credentials) + + self.runner = ExperimentRunner(self.everest_config, self.state) + self.runner.start() + + return Response("Everest experiment started", 200) + + def get_state( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> JSONResponse: + self._log(request) + self._check_user(credentials) + return JSONResponse(jsonable_encoder(self.state)) diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index 501ccca756d..11c518e4d24 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -2,195 +2,74 @@ import json import logging import os -import socket -import ssl -import threading +import time import traceback -from base64 import b64encode -from datetime import datetime, timedelta -from functools import partial - -import uvicorn -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.x509.oid import NameOID -from dns import resolver, reversename -from fastapi import Depends, FastAPI, HTTPException, Request, status -from fastapi.encoders import jsonable_encoder -from fastapi.responses import ( - JSONResponse, - PlainTextResponse, - Response, -) -from fastapi.security import ( - HTTPBasic, - HTTPBasicCredentials, -) + +import requests from ropt.enums import OptimizerExitCode -from ert.config import QueueSystem -from ert.ensemble_evaluator import EvaluatorServerConfig -from ert.run_models.everest_run_model import EverestRunModel from everest import export_to_csv, export_with_progress from everest.config import EverestConfig, ServerConfig -from everest.detached import ServerStatus, get_opt_status, update_everserver_status +from everest.detached import ( + PROXY, + ServerStatus, + update_everserver_status, +) +from everest.detached.jobs.everest_server_api import EverestServerAPI, ExitCode from everest.export import check_for_errors from everest.simulator import JOB_FAILURE from everest.strings import ( EVEREST, + EXIT_CODE_ENDPOINT, OPT_FAILURE_REALIZATIONS, - OPT_PROGRESS_ENDPOINT, + SHARED_DATA_ENDPOINT, SIM_PROGRESS_ENDPOINT, + START_ENDPOINT, STOP_ENDPOINT, ) -from everest.util import configure_logger, makedirs_if_needed, version_info +from everest.util import configure_logger, version_info -def _get_machine_name() -> str: - """Returns a name that can be used to identify this machine in a network +def _get_optimization_status(exit_code, shared_data): + if exit_code == "max_batch_num_reached": + return ServerStatus.completed, "Maximum number of batches reached." - A fully qualified domain name is returned if available. Otherwise returns - the string `localhost` - """ - hostname = socket.gethostname() - try: - # We need the ip-address to perform a reverse lookup to deal with - # differences in how the clusters are getting their fqdn's - ip_addr = socket.gethostbyname(hostname) - reverse_name = reversename.from_address(ip_addr) - resolved_hosts = [ - str(ptr_record).rstrip(".") - for ptr_record in resolver.resolve(reverse_name, "PTR") - ] - resolved_hosts.sort() - return resolved_hosts[0] - except (resolver.NXDOMAIN, resolver.NoResolverConfiguration): - # If local address and reverse lookup not working - fallback - # to socket fqdn which are using /etc/hosts to retrieve this name - return socket.getfqdn() - except socket.gaierror: - logging.debug(traceback.format_exc()) - return "localhost" - - -def _sim_monitor(context_status, event=None, shared_data=None): - status = context_status["status"] - shared_data[SIM_PROGRESS_ENDPOINT] = { - "batch_number": context_status["batch_number"], - "status": { - "running": status.running, - "waiting": status.waiting, - "pending": status.pending, - "complete": status.complete, - "failed": status.failed, - }, - "progress": context_status["progress"], - "event": event, - } - - if shared_data[STOP_ENDPOINT]: - return "stop_queue" - - -def _opt_monitor(shared_data=None): - if shared_data[STOP_ENDPOINT]: - return "stop_optimization" - - -def _everserver_thread(shared_data, server_config) -> None: - app = FastAPI() - security = HTTPBasic() - - def _check_user(credentials: HTTPBasicCredentials) -> None: - if credentials.password != server_config["authentication"]: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid credentials", - headers={"WWW-Authenticate": "Basic"}, - ) + if exit_code == OptimizerExitCode.MAX_FUNCTIONS_REACHED: + return ServerStatus.completed, "Maximum number of function evaluations reached." - def _log(request: Request) -> None: - logging.getLogger("everserver").info( - f"{request.scope['path']} entered from {request.client.host if request.client else 'unknown host'} with HTTP {request.method}" + if exit_code == OptimizerExitCode.USER_ABORT: + return ServerStatus.stopped, "Optimization aborted." + + if exit_code == OptimizerExitCode.TOO_FEW_REALIZATIONS: + status = ( + ServerStatus.stopped if shared_data[STOP_ENDPOINT] else ServerStatus.failed ) + messages = _failed_realizations_messages(shared_data) + for msg in messages: + logging.getLogger(EVEREST).error(msg) + return status, "\n".join(messages) - @app.get("/") - def get_status( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> PlainTextResponse: - _log(request) - _check_user(credentials) - return PlainTextResponse("Everest is running") - - @app.post("/" + STOP_ENDPOINT) - def stop( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> Response: - _log(request) - _check_user(credentials) - shared_data[STOP_ENDPOINT] = True - return Response("Raise STOP flag succeeded. Everest initiates shutdown..", 200) - - @app.get("/" + SIM_PROGRESS_ENDPOINT) - def get_sim_progress( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> JSONResponse: - _log(request) - _check_user(credentials) - progress = shared_data[SIM_PROGRESS_ENDPOINT] - return JSONResponse(jsonable_encoder(progress)) - - @app.get("/" + OPT_PROGRESS_ENDPOINT) - def get_opt_progress( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> JSONResponse: - _log(request) - _check_user(credentials) - progress = get_opt_status(server_config["optimization_output_dir"]) - return JSONResponse(jsonable_encoder(progress)) - - uvicorn.run( - app, - host="0.0.0.0", - port=server_config["port"], - ssl_keyfile=server_config["key_path"], - ssl_certfile=server_config["cert_path"], - ssl_version=ssl.PROTOCOL_SSLv23, - ssl_keyfile_password=server_config["key_passwd"], - ) + return ServerStatus.completed, "Optimization completed." -def _find_open_port(host, lower, upper) -> int: - for port in range(lower, upper): - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind((host, port)) - sock.close() - return port - except socket.error: - logging.getLogger("everserver").info( - "Port {} for host {} is taken".format(port, host) +def _failed_realizations_messages(shared_data): + messages = [OPT_FAILURE_REALIZATIONS] + failed = shared_data[SIM_PROGRESS_ENDPOINT]["status"]["failed"] + if failed > 0: + # Find the set of jobs that failed. To keep the order in which they + # are found in the queue, use a dict as sets are not ordered. + failed_jobs = dict.fromkeys( + ( + job["name"] + for queue in shared_data[SIM_PROGRESS_ENDPOINT]["progress"] + for job in queue + if job["status"] == JOB_FAILURE ) - msg = "No open port for host {} in the range {}-{}".format(host, lower, upper) - logging.getLogger("everserver").exception(msg) - raise Exception(msg) - - -def _write_hostfile(host_file_path, host, port, cert, auth) -> None: - if not os.path.exists(os.path.dirname(host_file_path)): - os.makedirs(os.path.dirname(host_file_path)) - data = { - "host": host, - "port": port, - "cert": cert, - "auth": auth, - } - json_string = json.dumps(data) - - with open(host_file_path, "w", encoding="utf-8") as f: - f.write(json_string) + ).keys() + messages.append( + "{} job failures caused by: {}".format(failed, ", ".join(failed_jobs)) + ) + return messages def _configure_loggers( @@ -240,7 +119,6 @@ def main(): config.logging_level = "debug" detached_dir = ServerConfig.get_detached_node_dir(config.output_dir) status_path = ServerConfig.get_everserver_status_path(config.output_dir) - host_file = ServerConfig.get_hostfile_path(config.output_dir) try: _configure_loggers( @@ -255,34 +133,18 @@ def main(): ) logging.getLogger(EVEREST).debug(str(options)) - authentication = _generate_authentication() - cert_path, key_path, key_pw = _generate_certificate( - ServerConfig.get_certificate_dir(config.output_dir) - ) - host = _get_machine_name() - port = _find_open_port(host, lower=5000, upper=5800) - _write_hostfile(host_file, host, port, cert_path, authentication) - shared_data = { SIM_PROGRESS_ENDPOINT: {}, STOP_ENDPOINT: False, } - server_config = { - "optimization_output_dir": config.optimization_output_dir, - "port": port, - "cert_path": cert_path, - "key_path": key_path, - "key_passwd": key_pw, - "authentication": authentication, - } + everest_server_api = EverestServerAPI(config) + everest_server_api.daemon = True + everest_server_api.start() + + server_context = (ServerConfig.get_server_context(config.output_dir),) + url, cert, auth = server_context[0] - everserver_instance = threading.Thread( - target=_everserver_thread, - args=(shared_data, server_config), - ) - everserver_instance.daemon = True - everserver_instance.start() except: update_everserver_status( status_path, @@ -292,26 +154,51 @@ def main(): return try: + # wait until the api server is running + is_running = False + while not is_running: + try: + requests.get(url + "/", verify=cert, auth=auth, proxies=PROXY) + is_running = True + except: + time.sleep(1) + update_everserver_status(status_path, ServerStatus.running) - run_model = EverestRunModel.create( - config, - simulation_callback=partial(_sim_monitor, shared_data=shared_data), - optimization_callback=partial(_opt_monitor, shared_data=shared_data), + response = requests.post( + url + "/" + START_ENDPOINT, verify=cert, auth=auth, proxies=PROXY ) - evaluator_server_config = EvaluatorServerConfig( - custom_port_range=range(49152, 51819) - if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL - else None - ) + is_done = False + while not is_done: + response = requests.get( + url + "/" + EXIT_CODE_ENDPOINT, verify=cert, auth=auth, proxies=PROXY + ) + exit_code = ExitCode.model_validate_json(response.text) + if exit_code.exit_code or exit_code.message: + is_done = True + else: + time.sleep(1) + + if exit_code.message: + update_everserver_status( + status_path, + ServerStatus.failed, + message=exit_code.message, + ) + return - run_model.run_experiment(evaluator_server_config) + response = requests.get( + url + "/" + SHARED_DATA_ENDPOINT, verify=cert, auth=auth, proxies=PROXY + ) + if json_body := json.loads(response.text): + shared_data = json_body - status, message = _get_optimization_status(run_model.exit_code, shared_data) + status, message = _get_optimization_status(exit_code.exit_code, shared_data) if status != ServerStatus.completed: update_everserver_status(status_path, status, message) return + except: if shared_data[STOP_ENDPOINT]: update_everserver_status( @@ -356,109 +243,3 @@ def main(): return update_everserver_status(status_path, ServerStatus.completed, message=message) - - -def _get_optimization_status(exit_code, shared_data): - if exit_code == "max_batch_num_reached": - return ServerStatus.completed, "Maximum number of batches reached." - - if exit_code == OptimizerExitCode.MAX_FUNCTIONS_REACHED: - return ServerStatus.completed, "Maximum number of function evaluations reached." - - if exit_code == OptimizerExitCode.USER_ABORT: - return ServerStatus.stopped, "Optimization aborted." - - if exit_code == OptimizerExitCode.TOO_FEW_REALIZATIONS: - status = ( - ServerStatus.stopped if shared_data[STOP_ENDPOINT] else ServerStatus.failed - ) - messages = _failed_realizations_messages(shared_data) - for msg in messages: - logging.getLogger(EVEREST).error(msg) - return status, "\n".join(messages) - - return ServerStatus.completed, "Optimization completed." - - -def _failed_realizations_messages(shared_data): - messages = [OPT_FAILURE_REALIZATIONS] - failed = shared_data[SIM_PROGRESS_ENDPOINT]["status"]["failed"] - if failed > 0: - # Find the set of jobs that failed. To keep the order in which they - # are found in the queue, use a dict as sets are not ordered. - failed_jobs = dict.fromkeys( - ( - job["name"] - for queue in shared_data[SIM_PROGRESS_ENDPOINT]["progress"] - for job in queue - if job["status"] == JOB_FAILURE - ) - ).keys() - messages.append( - "{} job failures caused by: {}".format(failed, ", ".join(failed_jobs)) - ) - return messages - - -def _generate_certificate(cert_folder: str): - """Generate a private key and a certificate signed with it - - Both the certificate and the key are written to files in the folder given - by `get_certificate_dir(config)`. The key is encrypted before being - stored. - Returns the path to the certificate file, the path to the key file, and - the password used for encrypting the key - """ - # Generate private key - key = rsa.generate_private_key( - public_exponent=65537, key_size=4096, backend=default_backend() - ) - - # Generate the certificate and sign it with the private key - cert_name = _get_machine_name() - subject = issuer = x509.Name( - [ - x509.NameAttribute(NameOID.COUNTRY_NAME, "NO"), - x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Bergen"), - x509.NameAttribute(NameOID.LOCALITY_NAME, "Sandsli"), - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Equinor"), - x509.NameAttribute(NameOID.COMMON_NAME, "{}".format(cert_name)), - ] - ) - cert = ( - x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) - .public_key(key.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.utcnow()) - .not_valid_after(datetime.utcnow() + timedelta(days=365)) # 1 year - .add_extension( - x509.SubjectAlternativeName([x509.DNSName("{}".format(cert_name))]), - critical=False, - ) - .sign(key, hashes.SHA256(), default_backend()) - ) - - # Write certificate and key to disk - makedirs_if_needed(cert_folder) - cert_path = os.path.join(cert_folder, cert_name + ".crt") - with open(cert_path, "wb") as f: - f.write(cert.public_bytes(serialization.Encoding.PEM)) - key_path = os.path.join(cert_folder, cert_name + ".key") - pw = bytes(os.urandom(28)) - with open(key_path, "wb") as f: - f.write( - key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.BestAvailableEncryption(pw), - ) - ) - return cert_path, key_path, pw - - -def _generate_authentication(): - n_bytes = 128 - random_bytes = bytes(os.urandom(n_bytes)) - return b64encode(random_bytes).decode("utf-8") diff --git a/src/everest/strings.py b/src/everest/strings.py index 50be1da326b..791eca51ae6 100644 --- a/src/everest/strings.py +++ b/src/everest/strings.py @@ -29,5 +29,9 @@ SIMULATOR_END = "end" SIM_PROGRESS_ENDPOINT = "sim_progress" SIM_PROGRESS_ID = "simulation_progress" +START_ENDPOINT = "start" STOP_ENDPOINT = "stop" STORAGE_DIR = "simulation_results" +STATUS_ENDPOINT = "status" +SHARED_DATA_ENDPOINT = "shared_data" +EXIT_CODE_ENDPOINT = "exit_code" diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py index 0455af8bc87..eea4bf39008 100644 --- a/tests/everest/test_everserver.py +++ b/tests/everest/test_everserver.py @@ -11,6 +11,10 @@ from everest.config import EverestConfig, ServerConfig from everest.detached import ServerStatus, everserver_status from everest.detached.jobs import everserver +from everest.detached.jobs.everest_server_api import ( + _generate_certificate, + _write_hostfile, +) from everest.simulator import JOB_FAILURE, JOB_SUCCESS from everest.strings import OPT_FAILURE_REALIZATIONS, SIM_PROGRESS_ENDPOINT @@ -54,7 +58,7 @@ def set_shared_status(context_status, event, shared_data, progress): def test_certificate_generation(copy_math_func_test_data_to_tmp): config = EverestConfig.load_file("config_minimal.yml") - cert, key, pw = everserver._generate_certificate( + cert, key, pw = _generate_certificate( ServerConfig.get_certificate_dir(config.output_dir) ) @@ -77,7 +81,7 @@ def test_hostfile_storage(tmp_path, monkeypatch): "cert": "/a/b/c.cert", "auth": "1234", } - everserver._write_hostfile(host_file_path, **expected_result) + _write_hostfile(host_file_path, **expected_result) assert os.path.exists(host_file_path) with open(host_file_path, encoding="utf-8") as f: result = json.load(f) @@ -89,7 +93,7 @@ def test_hostfile_storage(tmp_path, monkeypatch): "everest.detached.jobs.everserver.configure_logger", side_effect=configure_everserver_logger, ) -def test_everserver_status_failure(_1, copy_math_func_test_data_to_tmp): +def test_everserver_status_failure(mocked_logger, copy_math_func_test_data_to_tmp): config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) everserver.main() @@ -103,35 +107,9 @@ def test_everserver_status_failure(_1, copy_math_func_test_data_to_tmp): @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"]) @patch("everest.detached.jobs.everserver.configure_logger") -@patch("everest.detached.jobs.everserver._generate_authentication") -@patch( - "everest.detached.jobs.everserver._generate_certificate", - return_value=(None, None, None), -) -@patch( - "everest.detached.jobs.everserver._find_open_port", - return_value=42, -) -@patch( - "everest.detached.jobs.everserver._write_hostfile", - side_effect=partial(check_status, status=ServerStatus.starting), -) -@patch("everest.detached.jobs.everserver._everserver_thread") -@patch( - "ert.run_models.everest_run_model.EverestRunModel.run_experiment", - autospec=True, - side_effect=lambda self, evaluator_server_config, restart=False: check_status( - ServerConfig.get_hostfile_path(self.everest_config.output_dir), - status=ServerStatus.running, - ), -) -@patch( - "everest.detached.jobs.everserver.check_for_errors", - return_value=([], False), -) @patch("everest.detached.jobs.everserver.export_to_csv") def test_everserver_status_running_complete( - _1, _2, _3, _4, _5, _6, _7, _8, _9, copy_math_func_test_data_to_tmp + mocked_logger, mocked_export_to_csv, copy_math_func_test_data_to_tmp ): config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) @@ -141,22 +119,11 @@ def test_everserver_status_running_complete( ) assert status["status"] == ServerStatus.completed - assert status["message"] == "Optimization completed." + # assert status["message"] == "Optimization completed." @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"]) @patch("everest.detached.jobs.everserver.configure_logger") -@patch("everest.detached.jobs.everserver._generate_authentication") -@patch( - "everest.detached.jobs.everserver._generate_certificate", - return_value=(None, None, None), -) -@patch( - "everest.detached.jobs.everserver._find_open_port", - return_value=42, -) -@patch("everest.detached.jobs.everserver._write_hostfile") -@patch("everest.detached.jobs.everserver._everserver_thread") @patch( "ert.run_models.everest_run_model.EverestRunModel.run_experiment", autospec=True, @@ -165,7 +132,7 @@ def test_everserver_status_running_complete( ), ) @patch( - "everest.detached.jobs.everserver._sim_monitor", + "everest.detached.jobs.everest_server_api._sim_monitor", side_effect=partial( set_shared_status, progress=[ @@ -181,7 +148,10 @@ def test_everserver_status_running_complete( ), ) def test_everserver_status_failed_job( - _1, _2, _3, _4, _5, _6, _7, _8, copy_math_func_test_data_to_tmp + mocked_logger, + mocked_run_experiment, + mocked_sim_monitor, + copy_math_func_test_data_to_tmp, ): config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) @@ -198,17 +168,6 @@ def test_everserver_status_failed_job( @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"]) @patch("everest.detached.jobs.everserver.configure_logger") -@patch("everest.detached.jobs.everserver._generate_authentication") -@patch( - "everest.detached.jobs.everserver._generate_certificate", - return_value=(None, None, None), -) -@patch( - "everest.detached.jobs.everserver._find_open_port", - return_value=42, -) -@patch("everest.detached.jobs.everserver._write_hostfile") -@patch("everest.detached.jobs.everserver._everserver_thread") @patch( "ert.run_models.everest_run_model.EverestRunModel.run_experiment", autospec=True, @@ -217,11 +176,14 @@ def test_everserver_status_failed_job( ), ) @patch( - "everest.detached.jobs.everserver._sim_monitor", + "everest.detached.jobs.everest_server_api._sim_monitor", side_effect=partial(set_shared_status, progress=[]), ) def test_everserver_status_exception( - _1, _2, _3, _4, _5, _6, _7, _8, copy_math_func_test_data_to_tmp + mocked_logger, + mocked_run_experiment, + mocked_sim_monitor, + copy_math_func_test_data_to_tmp, ): config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) @@ -238,23 +200,12 @@ def test_everserver_status_exception( @patch("sys.argv", ["name", "--config-file", "config_one_batch.yml"]) @patch("everest.detached.jobs.everserver.configure_logger") -@patch("everest.detached.jobs.everserver._generate_authentication") -@patch( - "everest.detached.jobs.everserver._generate_certificate", - return_value=(None, None, None), -) -@patch( - "everest.detached.jobs.everserver._find_open_port", - return_value=42, -) -@patch("everest.detached.jobs.everserver._write_hostfile") -@patch("everest.detached.jobs.everserver._everserver_thread") @patch( - "everest.detached.jobs.everserver._sim_monitor", + "everest.detached.jobs.everest_server_api._sim_monitor", side_effect=partial(set_shared_status, progress=[]), ) def test_everserver_status_max_batch_num( - _1, _2, _3, _4, _5, _6, _7, copy_math_func_test_data_to_tmp + mocked_logger, mocked_sim_monitor, copy_math_func_test_data_to_tmp ): config_file = "config_one_batch.yml" config = EverestConfig.load_file(config_file)