diff --git a/src/everest/detached/jobs/__init__.py b/src/everest/detached/jobs/__init__.py index d7657dc1afa..e69de29bb2d 100644 --- a/src/everest/detached/jobs/__init__.py +++ b/src/everest/detached/jobs/__init__.py @@ -1,177 +0,0 @@ - -import argparse -import json -import logging -import os -import socket -import ssl -import threading -import traceback -from base64 import b64encode -from datetime import datetime, timedelta -from functools import partial - -import uvicorn -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.x509.oid import NameOID -from dns import resolver, reversename -from fastapi import Depends, FastAPI, HTTPException, Request, status -from fastapi.encoders import jsonable_encoder -from fastapi.responses import ( - JSONResponse, - PlainTextResponse, - Response, -) -from fastapi.security import ( - HTTPBasic, - HTTPBasicCredentials, -) -from ropt.enums import OptimizerExitCode - -from ert.config import QueueSystem -from ert.ensemble_evaluator import EvaluatorServerConfig -from ert.run_models.everest_run_model import EverestRunModel -from everest import export_to_csv, export_with_progress -from everest.config import EverestConfig, ServerConfig -from everest.detached import ServerStatus, get_opt_status, update_everserver_status -from everest.export import check_for_errors -from everest.simulator import JOB_FAILURE -from everest.strings import ( - EVEREST, - OPT_FAILURE_REALIZATIONS, - OPT_PROGRESS_ENDPOINT, - SIM_PROGRESS_ENDPOINT, - STOP_ENDPOINT, -) -from everest.util import configure_logger, makedirs_if_needed, version_info - - -def _get_machine_name() -> str: - """Returns a name that can be used to identify this machine in a network - - A fully qualified domain name is returned if available. Otherwise returns - the string `localhost` - """ - hostname = socket.gethostname() - try: - # We need the ip-address to perform a reverse lookup to deal with - # differences in how the clusters are getting their fqdn's - ip_addr = socket.gethostbyname(hostname) - reverse_name = reversename.from_address(ip_addr) - resolved_hosts = [ - str(ptr_record).rstrip(".") - for ptr_record in resolver.resolve(reverse_name, "PTR") - ] - resolved_hosts.sort() - return resolved_hosts[0] - except (resolver.NXDOMAIN, resolver.NoResolverConfiguration): - # If local address and reverse lookup not working - fallback - # to socket fqdn which are using /etc/hosts to retrieve this name - return socket.getfqdn() - except socket.gaierror: - logging.debug(traceback.format_exc()) - return "localhost" - - - - -def _find_open_port(host, lower, upper) -> int: - for port in range(lower, upper): - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind((host, port)) - sock.close() - return port - except socket.error: - logging.getLogger("everserver").info( - "Port {} for host {} is taken".format(port, host) - ) - msg = "No open port for host {} in the range {}-{}".format(host, lower, upper) - logging.getLogger("everserver").exception(msg) - raise Exception(msg) - - -def _write_hostfile(host_file_path, host, port, cert, auth) -> None: - if not os.path.exists(os.path.dirname(host_file_path)): - os.makedirs(os.path.dirname(host_file_path)) - data = { - "host": host, - "port": port, - "cert": cert, - "auth": auth, - } - json_string = json.dumps(data) - - with open(host_file_path, "w", encoding="utf-8") as f: - f.write(json_string) - - - - - -def _generate_authentication(): - n_bytes = 128 - random_bytes = bytes(os.urandom(n_bytes)) - return b64encode(random_bytes).decode("utf-8") - - - -def _generate_certificate(cert_folder: str): - """Generate a private key and a certificate signed with it - - Both the certificate and the key are written to files in the folder given - by `get_certificate_dir(config)`. The key is encrypted before being - stored. - Returns the path to the certificate file, the path to the key file, and - the password used for encrypting the key - """ - # Generate private key - key = rsa.generate_private_key( - public_exponent=65537, key_size=4096, backend=default_backend() - ) - - # Generate the certificate and sign it with the private key - cert_name = _get_machine_name() - subject = issuer = x509.Name( - [ - x509.NameAttribute(NameOID.COUNTRY_NAME, "NO"), - x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Bergen"), - x509.NameAttribute(NameOID.LOCALITY_NAME, "Sandsli"), - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Equinor"), - x509.NameAttribute(NameOID.COMMON_NAME, "{}".format(cert_name)), - ] - ) - cert = ( - x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) - .public_key(key.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.utcnow()) - .not_valid_after(datetime.utcnow() + timedelta(days=365)) # 1 year - .add_extension( - x509.SubjectAlternativeName([x509.DNSName("{}".format(cert_name))]), - critical=False, - ) - .sign(key, hashes.SHA256(), default_backend()) - ) - - # Write certificate and key to disk - makedirs_if_needed(cert_folder) - cert_path = os.path.join(cert_folder, cert_name + ".crt") - with open(cert_path, "wb") as f: - f.write(cert.public_bytes(serialization.Encoding.PEM)) - key_path = os.path.join(cert_folder, cert_name + ".key") - pw = bytes(os.urandom(28)) - with open(key_path, "wb") as f: - f.write( - key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.BestAvailableEncryption(pw), - ) - ) - return cert_path, key_path, pw \ No newline at end of file diff --git a/src/everest/detached/jobs/everest_server_api.py b/src/everest/detached/jobs/everest_server_api.py index 4cb175f73dc..b1e67482d42 100644 --- a/src/everest/detached/jobs/everest_server_api.py +++ b/src/everest/detached/jobs/everest_server_api.py @@ -1,4 +1,3 @@ -import argparse import json import logging import os @@ -17,7 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID from dns import resolver, reversename -from fastapi import Depends, FastAPI, HTTPException, Request, status +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status from fastapi.encoders import jsonable_encoder from fastapi.responses import ( JSONResponse, @@ -28,125 +27,306 @@ HTTPBasic, HTTPBasicCredentials, ) -from ropt.enums import OptimizerExitCode from ert.config import QueueSystem from ert.ensemble_evaluator import EvaluatorServerConfig from ert.run_models.everest_run_model import EverestRunModel -from everest import export_to_csv, export_with_progress from everest.config import EverestConfig, ServerConfig -from everest.detached import ServerStatus, get_opt_status, update_everserver_status -from everest.detached.jobs import _find_open_port, _generate_authentication, _generate_certificate, _get_machine_name, _write_hostfile -from everest.export import check_for_errors -from everest.simulator import JOB_FAILURE +from everest.detached import get_opt_status from everest.strings import ( - EVEREST, - OPT_FAILURE_REALIZATIONS, - OPT_PROGRESS_ENDPOINT, SIM_PROGRESS_ENDPOINT, STOP_ENDPOINT, ) -from everest.util import configure_logger, makedirs_if_needed, version_info +from everest.util import makedirs_if_needed +def _get_machine_name() -> str: + """Returns a name that can be used to identify this machine in a network -def everest_server_api(): + A fully qualified domain name is returned if available. Otherwise returns + the string `localhost` + """ + hostname = socket.gethostname() + try: + # We need the ip-address to perform a reverse lookup to deal with + # differences in how the clusters are getting their fqdn's + ip_addr = socket.gethostbyname(hostname) + reverse_name = reversename.from_address(ip_addr) + resolved_hosts = [ + str(ptr_record).rstrip(".") + for ptr_record in resolver.resolve(reverse_name, "PTR") + ] + resolved_hosts.sort() + return resolved_hosts[0] + except (resolver.NXDOMAIN, resolver.NoResolverConfiguration): + # If local address and reverse lookup not working - fallback + # to socket fqdn which are using /etc/hosts to retrieve this name + return socket.getfqdn() + except socket.gaierror: + logging.debug(traceback.format_exc()) + return "localhost" - def __init__(output_dir:str, optimization_output_dir:str): - # same code is in ensemble evaluator - authentication = _generate_authentication() - - # same code is in ensemble evaluator - cert_path, key_path, key_pw = _generate_certificate( - ServerConfig.get_certificate_dir(output_dir) + +def _find_open_port(host, lower, upper) -> int: + for port in range(lower, upper): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind((host, port)) + sock.close() + return port + except socket.error: + logging.getLogger("everserver").info( + "Port {} for host {} is taken".format(port, host) + ) + msg = "No open port for host {} in the range {}-{}".format(host, lower, upper) + logging.getLogger("everserver").exception(msg) + raise Exception(msg) + + +def _write_hostfile(host_file_path, host, port, cert, auth) -> None: + if not os.path.exists(os.path.dirname(host_file_path)): + os.makedirs(os.path.dirname(host_file_path)) + data = { + "host": host, + "port": port, + "cert": cert, + "auth": auth, + } + json_string = json.dumps(data) + + with open(host_file_path, "w", encoding="utf-8") as f: + f.write(json_string) + + +def _generate_authentication(): + n_bytes = 128 + random_bytes = bytes(os.urandom(n_bytes)) + return b64encode(random_bytes).decode("utf-8") + + +def _generate_certificate(cert_folder: str): + """Generate a private key and a certificate signed with it + + Both the certificate and the key are written to files in the folder given + by `get_certificate_dir(config)`. The key is encrypted before being + stored. + Returns the path to the certificate file, the path to the key file, and + the password used for encrypting the key + """ + # Generate private key + key = rsa.generate_private_key( + public_exponent=65537, key_size=4096, backend=default_backend() + ) + + # Generate the certificate and sign it with the private key + cert_name = _get_machine_name() + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "NO"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Bergen"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "Sandsli"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Equinor"), + x509.NameAttribute(NameOID.COMMON_NAME, "{}".format(cert_name)), + ] + ) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.utcnow()) + .not_valid_after(datetime.utcnow() + timedelta(days=365)) # 1 year + .add_extension( + x509.SubjectAlternativeName([x509.DNSName("{}".format(cert_name))]), + critical=False, ) - host = _get_machine_name() - port = _find_open_port(host, lower=5000, upper=5800) + .sign(key, hashes.SHA256(), default_backend()) + ) + + # Write certificate and key to disk + makedirs_if_needed(cert_folder) + cert_path = os.path.join(cert_folder, cert_name + ".crt") + with open(cert_path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + key_path = os.path.join(cert_folder, cert_name + ".key") + pw = bytes(os.urandom(28)) + with open(key_path, "wb") as f: + f.write( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.BestAvailableEncryption(pw), + ) + ) + return cert_path, key_path, pw + + +def _sim_monitor(context_status, event=None, shared_data=None): + status = context_status["status"] + assert shared_data + shared_data[SIM_PROGRESS_ENDPOINT] = { + "batch_number": context_status["batch_number"], + "status": { + "running": status.running, + "waiting": status.waiting, + "pending": status.pending, + "complete": status.complete, + "failed": status.failed, + }, + "progress": context_status["progress"], + "event": event, + } + + if shared_data[STOP_ENDPOINT]: + return "stop_queue" - host_file = ServerConfig.get_hostfile_path(output_dir) - _write_hostfile(host_file, host, port, cert_path, authentication) +def _opt_monitor(shared_data=None): + assert shared_data + if shared_data[STOP_ENDPOINT]: + return "stop_optimization" + + +class ExperimentRunner(threading.Thread): + def __init__(self, everest_config, shared_data): + super().__init__() + + self.everest_config = everest_config + self.shared_data = shared_data + self.exit_code = None + + def run(self): run_model = EverestRunModel.create( - config, - simulation_callback=partial(_sim_monitor, shared_data=shared_data), - optimization_callback=partial(_opt_monitor, shared_data=shared_data), + self.everest_config, + simulation_callback=partial(_sim_monitor, shared_data=self.shared_data), + optimization_callback=partial(_opt_monitor, shared_data=self.shared_data), ) + evaluator_server_config = EvaluatorServerConfig( + custom_port_range=range(49152, 51819) + if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL + else None + ) + run_model.run_experiment(evaluator_server_config) + self.exit_code = run_model.exit_code + print("RUN DONE") - def start(self): - uvicorn.run( - app, - host="0.0.0.0", - port=server_config["port"], - ssl_keyfile=server_config["key_path"], - ssl_certfile=server_config["cert_path"], - ssl_version=ssl.PROTOCOL_SSLv23, - ssl_keyfile_password=server_config["key_passwd"], + def exit_code(self): + return self.exit_code + + +security = HTTPBasic() + + +class EverestServerAPI(threading.Thread): + def __init__(self, everest_config: EverestConfig, shared_data: dict): + super().__init__() + + self.app = FastAPI() + + self.router = APIRouter() + self.router.add_api_route("/", self.get_status, methods=["GET"]) + self.router.add_api_route("/stop", self.stop, methods=["POST"]) + self.router.add_api_route( + "/sim_progress", self.get_sim_progress, methods=["GET"] + ) + self.router.add_api_route( + "/opt_progress", self.get_opt_progress, methods=["GET"] ) + self.router.add_api_route("/start", self.start_experiment, methods=["POST"]) + + self.router.add_api_route("/exit_code", self.get_exit_code, methods=["GET"]) - + self.app.include_router(self.router) + self.shared_data = shared_data + self.everest_config = everest_config + self.output_dir = everest_config.output_dir + self.optimization_output_dir = everest_config.optimization_output_dir + # same code is in ensemble evaluator + self.authentication = _generate_authentication() + # same code is in ensemble evaluator + self.cert_path, self.key_path, self.key_pw = _generate_certificate( + ServerConfig.get_certificate_dir(self.output_dir) + ) + self.host = _get_machine_name() + self.port = _find_open_port(self.host, lower=5000, upper=5800) + host_file = ServerConfig.get_hostfile_path(self.output_dir) + _write_hostfile( + host_file, self.host, self.port, self.cert_path, self.authentication + ) + def run(self): + uvicorn.run( + self.app, + host="0.0.0.0", + port=self.port, + ssl_keyfile=self.key_path, + ssl_certfile=self.cert_path, + ssl_version=ssl.PROTOCOL_SSLv23, + ssl_keyfile_password=self.key_pw, + ) - def _check_user(credentials: HTTPBasicCredentials) -> None: - if credentials.password != server_config["authentication"]: + def _check_user(self, credentials: HTTPBasicCredentials) -> None: + if credentials.password != self.authentication: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials", headers={"WWW-Authenticate": "Basic"}, ) - def _log(request: Request) -> None: + def _log(self, request: Request) -> None: logging.getLogger("everserver").info( f"{request.scope['path']} entered from {request.client.host if request.client else 'unknown host'} with HTTP {request.method}" ) - - - - app = FastAPI() - security = HTTPBasic() - - - @app.get("/") def get_status( - request: Request, credentials: HTTPBasicCredentials = Depends(security) + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) ) -> PlainTextResponse: - _log(request) - _check_user(credentials) + self._log(request) + self._check_user(credentials) return PlainTextResponse("Everest is running") - @app.post("/" + STOP_ENDPOINT) def stop( - request: Request, credentials: HTTPBasicCredentials = Depends(security) + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) ) -> Response: - _log(request) - _check_user(credentials) - shared_data[STOP_ENDPOINT] = True + self._log(request) + self._check_user(credentials) + self.shared_data[STOP_ENDPOINT] = True return Response("Raise STOP flag succeeded. Everest initiates shutdown..", 200) - @app.get("/" + SIM_PROGRESS_ENDPOINT) def get_sim_progress( - request: Request, credentials: HTTPBasicCredentials = Depends(security) + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) ) -> JSONResponse: - _log(request) - _check_user(credentials) - progress = shared_data[SIM_PROGRESS_ENDPOINT] + self._log(request) + self._check_user(credentials) + progress = self.shared_data[SIM_PROGRESS_ENDPOINT] + print(self.runner.exit_code) return JSONResponse(jsonable_encoder(progress)) - @app.get("/" + OPT_PROGRESS_ENDPOINT) + def get_exit_code( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> JSONResponse: + return JSONResponse({"exit_code": self.runner.exit_code}) + def get_opt_progress( - request: Request, credentials: HTTPBasicCredentials = Depends(security) + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) ) -> JSONResponse: - _log(request) - _check_user(credentials) - progress = get_opt_status(server_config["optimization_output_dir"]) + self._log(request) + self._check_user(credentials) + progress = get_opt_status(self.optimization_output_dir) return JSONResponse(jsonable_encoder(progress)) - + def start_experiment( + self, request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> JSONResponse: + self.runner = ExperimentRunner(self.everest_config, self.shared_data) + self.runner.start() + return JSONResponse("ok") diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index 46693da5247..785dfb7ba3f 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -2,78 +2,29 @@ import json import logging import os -import socket -import ssl -import threading +import time import traceback -from base64 import b64encode -from datetime import datetime, timedelta -from functools import partial - -import uvicorn -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.x509.oid import NameOID -from dns import resolver, reversename -from fastapi import Depends, FastAPI, HTTPException, Request, status -from fastapi.encoders import jsonable_encoder -from fastapi.responses import ( - JSONResponse, - PlainTextResponse, - Response, -) -from fastapi.security import ( - HTTPBasic, - HTTPBasicCredentials, -) + +import requests from ropt.enums import OptimizerExitCode -from ert.config import QueueSystem -from ert.ensemble_evaluator import EvaluatorServerConfig -from ert.run_models.everest_run_model import EverestRunModel from everest import export_to_csv, export_with_progress from everest.config import EverestConfig, ServerConfig -from everest.detached import ServerStatus, get_opt_status, update_everserver_status +from everest.detached import ( + PROXY, + ServerStatus, + update_everserver_status, +) +from everest.detached.jobs.everest_server_api import EverestServerAPI from everest.export import check_for_errors from everest.simulator import JOB_FAILURE from everest.strings import ( EVEREST, OPT_FAILURE_REALIZATIONS, - OPT_PROGRESS_ENDPOINT, SIM_PROGRESS_ENDPOINT, STOP_ENDPOINT, ) -from everest.util import configure_logger, makedirs_if_needed, version_info - - - - -def _sim_monitor(context_status, event=None, shared_data=None): - status = context_status["status"] - shared_data[SIM_PROGRESS_ENDPOINT] = { - "batch_number": context_status["batch_number"], - "status": { - "running": status.running, - "waiting": status.waiting, - "pending": status.pending, - "complete": status.complete, - "failed": status.failed, - }, - "progress": context_status["progress"], - "event": event, - } - - if shared_data[STOP_ENDPOINT]: - return "stop_queue" - - -def _opt_monitor(shared_data=None): - if shared_data[STOP_ENDPOINT]: - return "stop_optimization" - - +from everest.util import configure_logger, version_info def _get_optimization_status(exit_code, shared_data): @@ -155,9 +106,6 @@ def _configure_loggers( ) - - - def main(): arg_parser = argparse.ArgumentParser() arg_parser.add_argument("--config-file", type=str) @@ -168,7 +116,6 @@ def main(): config.logging_level = "debug" detached_dir = ServerConfig.get_detached_node_dir(config.output_dir) status_path = ServerConfig.get_everserver_status_path(config.output_dir) - try: _configure_loggers( @@ -183,35 +130,18 @@ def main(): ) logging.getLogger(EVEREST).debug(str(options)) - - shared_data = { SIM_PROGRESS_ENDPOINT: {}, STOP_ENDPOINT: False, } - server_config = { - "optimization_output_dir": config.optimization_output_dir, - "port": port, - "cert_path": cert_path, - "key_path": key_path, - "key_passwd": key_pw, - "authentication": authentication, - } + everest_server_api = EverestServerAPI(config, shared_data) + everest_server_api.daemon = True + everest_server_api.start() - everest_server_api = everest_server_api(config.output_dir, config.optimization_output_dir) + server_context = (ServerConfig.get_server_context(config.output_dir),) + url, cert, auth = server_context[0] - - everserver_instance = threading.Thread( - target=_everserver_thread, - args=(shared_data, server_config), - ) - everserver_instance.daemon = True - everserver_instance.start() - - - - except: update_everserver_status( status_path, @@ -221,25 +151,33 @@ def main(): return try: - update_everserver_status(status_path, ServerStatus.running) - - - evaluator_server_config = EvaluatorServerConfig( - custom_port_range=range(49152, 51819) - if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL - else None - ) - - run_model.run_experiment(evaluator_server_config) - - ## yield - - + # wait until the api server is running + is_running = False + while not is_running: + try: + requests.get(url + "/", verify=cert, auth=auth, proxies=PROXY) + # check return value + is_running = True + except: + time.sleep(1) + update_everserver_status(status_path, ServerStatus.running) + # start + response = requests.post(url + "/start", verify=cert, auth=auth, proxies=PROXY) + is_running = True + while is_running: + response = requests.get( + url + "/exit_code", verify=cert, auth=auth, proxies=PROXY + ) + if json_body := json.loads(response.text): + if exit_code := json_body["exit_code"]: + is_running = False + else: + time.sleep(1) - status, message = _get_optimization_status(run_model.exit_code, shared_data) + status, message = _get_optimization_status(exit_code, shared_data) if status != ServerStatus.completed: update_everserver_status(status_path, status, message) return @@ -287,6 +225,3 @@ def main(): return update_everserver_status(status_path, ServerStatus.completed, message=message) - - -