From 128f7e798a815916b311898e3df8a59dbe2b22e8 Mon Sep 17 00:00:00 2001 From: Frode Aarstad Date: Mon, 18 Nov 2024 12:49:57 +0100 Subject: [PATCH] Replace Flask with FastAPI --- pyproject.toml | 7 +- src/everest/detached/jobs/everserver.py | 134 +++++++++++++----------- 2 files changed, 79 insertions(+), 62 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1e3f12f4996..2b29885f005 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,7 +146,7 @@ types = [ everest = [ "progressbar2", "ruamel.yaml", - "flask", + "fastapi", "decorator", "resdata", "colorama", @@ -235,6 +235,11 @@ allowed-confusables = ["–"] [tool.ruff.lint.pylint] max-args = 20 +[tool.ruff.lint.flake8-bugbear] +extend-immutable-calls = [ + "fastapi.Depends", +] + [tool.pyright] include = ["src"] exclude = ["tests"] diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index f75c06a1508..78fa537dd1c 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -8,15 +8,27 @@ import traceback from base64 import b64encode from datetime import datetime, timedelta -from functools import partial, wraps +from functools import partial +# from flask import Flask, Response, jsonify, request +import uvicorn from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID from dns import resolver, reversename -from flask import Flask, Response, jsonify, request +from fastapi import Depends, FastAPI, HTTPException, Request, status +from fastapi.encoders import jsonable_encoder +from fastapi.responses import ( + JSONResponse, + PlainTextResponse, + Response, +) +from fastapi.security import ( + HTTPBasic, + HTTPBasicCredentials, +) from ropt.enums import OptimizerExitCode from ert.config import QueueSystem @@ -37,7 +49,7 @@ from everest.util import configure_logger, makedirs_if_needed, version_info -def get_machine_name(): +def _get_machine_name() -> str: """Returns a name that can be used to identify this machine in a network A fully qualified domain name is returned if available. Otherwise returns @@ -88,71 +100,70 @@ def _opt_monitor(shared_data=None): return "stop_optimization" -def _everserver_thread(shared_data, server_config): - app = Flask(__name__) - - def check_user(password): - return password == server_config["authentication"] +def _everserver_thread(shared_data, server_config) -> None: + app = FastAPI() + security = HTTPBasic() - def requires_authenticated(f): - @wraps(f) - def decorated(*args, **kwargs): - auth = request.authorization - if not auth or not check_user(auth.password): - return "unauthorized", 401 - return f(*args, **kwargs) - - return decorated - - def log(f): - @wraps(f) - def decorated(*args, **kwargs): - url = request.path - method = request.method - ip = request.environ.get("HTTP_X_REAL_IP", request.remote_addr) - logging.getLogger("everserver").info( - "{} entered from {} with HTTP {}".format(url, ip, method) + def _check_user(credentials: HTTPBasicCredentials) -> None: + if credentials.password != server_config["authentication"]: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid credentials", + headers={"WWW-Authenticate": "Basic"}, ) - return f(*args, **kwargs) - return decorated - - @app.route("/") - @requires_authenticated - @log - def get_home(): - return "Everest is running" + def _log(request: Request) -> None: + logging.getLogger("everserver").info( + f"{request.scope['path']} entered from {request.client.host if request.client else 'unknown host'} with HTTP {request.method}" + ) - @app.route("/" + STOP_ENDPOINT, methods=["POST"]) - @requires_authenticated - @log - def stop(): + @app.get("/") + def get_status( + request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> PlainTextResponse: + _log(request) + _check_user(credentials) + return PlainTextResponse("Everest is running") + + @app.post("/" + STOP_ENDPOINT) + def stop( + request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> Response: + _log(request) + _check_user(credentials) shared_data[STOP_ENDPOINT] = True return Response("Raise STOP flag succeeded. Everest initiates shutdown..", 200) - @app.route("/" + SIM_PROGRESS_ENDPOINT) - @requires_authenticated - @log - def get_sim_progress(): - return jsonify(shared_data[SIM_PROGRESS_ENDPOINT]) - - @app.route("/" + OPT_PROGRESS_ENDPOINT) - @requires_authenticated - @log - def get_opt_progress(): + @app.get("/" + SIM_PROGRESS_ENDPOINT) + def get_sim_progress( + request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> JSONResponse: + _log(request) + _check_user(credentials) + progress = shared_data[SIM_PROGRESS_ENDPOINT] + return JSONResponse(jsonable_encoder(progress)) + + @app.get("/" + OPT_PROGRESS_ENDPOINT) + def get_opt_progress( + request: Request, credentials: HTTPBasicCredentials = Depends(security) + ) -> JSONResponse: + _log(request) + _check_user(credentials) progress = get_opt_status(server_config["optimization_output_dir"]) - return jsonify(progress) - - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ctx.load_cert_chain( - server_config["cert_path"], - server_config["key_path"], - server_config["key_passwd"], + return JSONResponse(jsonable_encoder(progress)) + + uvicorn.run( + app, + host="0.0.0.0", + port=server_config["port"], + ssl_keyfile=server_config["key_path"], + ssl_certfile=server_config["cert_path"], + ssl_version=ssl.PROTOCOL_SSLv23, + ssl_keyfile_password=server_config["key_passwd"], ) - app.run(host="0.0.0.0", port=server_config["port"], ssl_context=ctx) -def _find_open_port(host, lower, upper): +def _find_open_port(host, lower, upper) -> int: for port in range(lower, upper): try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -168,7 +179,7 @@ def _find_open_port(host, lower, upper): raise Exception(msg) -def _write_hostfile(host_file_path, host, port, cert, auth): +def _write_hostfile(host_file_path, host, port, cert, auth) -> None: if not os.path.exists(os.path.dirname(host_file_path)): os.makedirs(os.path.dirname(host_file_path)) data = { @@ -187,7 +198,7 @@ def _configure_loggers( detached_node_dir: str, everest_logs_dir: str, logging_level: int, -): +) -> None: configure_logger( name="res", file_path=os.path.join(detached_node_dir, "simulations.log"), @@ -249,7 +260,7 @@ def main(): cert_path, key_path, key_pw = _generate_certificate( ServerConfig.get_certificate_dir(config.output_dir) ) - host = get_machine_name() + host = _get_machine_name() port = _find_open_port(host, lower=5000, upper=5800) _write_hostfile(host_file, host, port, cert_path, authentication) @@ -344,6 +355,7 @@ def main(): message=traceback.format_exc(), ) return + update_everserver_status(status_path, ServerStatus.completed, message=message) @@ -404,7 +416,7 @@ def _generate_certificate(cert_folder: str): ) # Generate the certificate and sign it with the private key - cert_name = get_machine_name() + cert_name = _get_machine_name() subject = issuer = x509.Name( [ x509.NameAttribute(NameOID.COUNTRY_NAME, "NO"),