Skip to content

Commit

Permalink
Split everserver functionality between starting server and submitting…
Browse files Browse the repository at this point in the history
… experiment
  • Loading branch information
frode-aarstad committed Nov 28, 2024
1 parent ca1806e commit bc30caf
Show file tree
Hide file tree
Showing 4 changed files with 484 additions and 377 deletions.
371 changes: 371 additions & 0 deletions src/everest/detached/jobs/everest_server_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,371 @@
import json
import logging
import os
import socket
import ssl
import threading
import traceback
from base64 import b64encode
from datetime import datetime, timedelta
from functools import partial
from typing import Any, Optional

import uvicorn
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from dns import resolver, reversename
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import (
JSONResponse,
PlainTextResponse,
Response,
)
from fastapi.security import (
HTTPBasic,
HTTPBasicCredentials,
)
from pydantic import BaseModel

from ert.config import QueueSystem
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.run_models.everest_run_model import EverestRunModel
from everest.config import EverestConfig, ServerConfig
from everest.detached import get_opt_status
from everest.strings import (
EXIT_CODE_ENDPOINT,
OPT_PROGRESS_ENDPOINT,
SHARED_DATA_ENDPOINT,
SIM_PROGRESS_ENDPOINT,
START_ENDPOINT,
STOP_ENDPOINT,
)
from everest.util import makedirs_if_needed


def _get_machine_name() -> str:
"""Returns a name that can be used to identify this machine in a network
A fully qualified domain name is returned if available. Otherwise returns
the string `localhost`
"""
hostname = socket.gethostname()
try:
# We need the ip-address to perform a reverse lookup to deal with
# differences in how the clusters are getting their fqdn's
ip_addr = socket.gethostbyname(hostname)
reverse_name = reversename.from_address(ip_addr)
resolved_hosts = [
str(ptr_record).rstrip(".")
for ptr_record in resolver.resolve(reverse_name, "PTR") # type: ignore
]
resolved_hosts.sort()
return resolved_hosts[0]
except (resolver.NXDOMAIN, resolver.NoResolverConfiguration):
# If local address and reverse lookup not working - fallback
# to socket fqdn which are using /etc/hosts to retrieve this name
return socket.getfqdn()
except socket.gaierror:
logging.debug(traceback.format_exc())
return "localhost"


def _find_open_port(host: str, lower: int, upper: int) -> int:
for port in range(lower, upper):
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind((host, port))
sock.close()
return port
except socket.error:
logging.getLogger("everserver").info(
"Port {} for host {} is taken".format(port, host)
)
msg = "No open port for host {} in the range {}-{}".format(host, lower, upper)
logging.getLogger("everserver").exception(msg)
raise Exception(msg)


def _write_hostfile(host_file_path, host, port, cert, auth) -> None:
if not os.path.exists(os.path.dirname(host_file_path)):
os.makedirs(os.path.dirname(host_file_path))
data = {
"host": host,
"port": port,
"cert": cert,
"auth": auth,
}
json_string = json.dumps(data)

with open(host_file_path, "w", encoding="utf-8") as f:
f.write(json_string)


def _generate_authentication() -> str:
n_bytes = 128
random_bytes = bytes(os.urandom(n_bytes))
return b64encode(random_bytes).decode("utf-8")


def _generate_certificate(cert_folder: str):
"""Generate a private key and a certificate signed with it
Both the certificate and the key are written to files in the folder given
by `get_certificate_dir(config)`. The key is encrypted before being
stored.
Returns the path to the certificate file, the path to the key file, and
the password used for encrypting the key
"""
# Generate private key
key = rsa.generate_private_key(
public_exponent=65537, key_size=4096, backend=default_backend()
)

# Generate the certificate and sign it with the private key
cert_name = _get_machine_name()
subject = issuer = x509.Name(
[
x509.NameAttribute(NameOID.COUNTRY_NAME, "NO"),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Bergen"),
x509.NameAttribute(NameOID.LOCALITY_NAME, "Sandsli"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Equinor"),
x509.NameAttribute(NameOID.COMMON_NAME, "{}".format(cert_name)),
]
)
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.utcnow())
.not_valid_after(datetime.utcnow() + timedelta(days=365)) # 1 year
.add_extension(
x509.SubjectAlternativeName([x509.DNSName("{}".format(cert_name))]),
critical=False,
)
.sign(key, hashes.SHA256(), default_backend())
)

# Write certificate and key to disk
makedirs_if_needed(cert_folder)
cert_path = os.path.join(cert_folder, cert_name + ".crt")
with open(cert_path, "wb") as f:
f.write(cert.public_bytes(serialization.Encoding.PEM))
key_path = os.path.join(cert_folder, cert_name + ".key")
pw = bytes(os.urandom(28))
with open(key_path, "wb") as f:
f.write(
key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.BestAvailableEncryption(pw),
)
)
return cert_path, key_path, pw


def _sim_monitor(context_status, event=None, shared_data=None):
status = context_status["status"]
assert shared_data
shared_data[SIM_PROGRESS_ENDPOINT] = {
"batch_number": context_status["batch_number"],
"status": {
"running": status.running,
"waiting": status.waiting,
"pending": status.pending,
"complete": status.complete,
"failed": status.failed,
},
"progress": context_status["progress"],
"event": event,
}

if shared_data[STOP_ENDPOINT]:
return "stop_queue"


def _opt_monitor(shared_data=None):
assert shared_data
if shared_data[STOP_ENDPOINT]:
return "stop_optimization"


class ExitCode(BaseModel):
exit_code: Optional[Any] = None
message: Optional[str] = None


class ExperimentRunner(threading.Thread):
def __init__(self, everest_config, state: dict):
super().__init__()

self.everest_config = everest_config
self.state = state
self.exit_code = None

def run(self):
run_model = EverestRunModel.create(
self.everest_config,
simulation_callback=partial(_sim_monitor, shared_data=self.state),
optimization_callback=partial(_opt_monitor, shared_data=self.state),
)

evaluator_server_config = EvaluatorServerConfig(
custom_port_range=range(49152, 51819)
if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL
else None
)

try:
run_model.run_experiment(evaluator_server_config)
self.exit_code = ExitCode(exit_code=run_model.exit_code)
except Exception:
self.exit_code = ExitCode(message=traceback.format_exc())

def get_exit_code(self) -> Optional[ExitCode]:
return self.exit_code


security = HTTPBasic()


class EverestServerAPI(threading.Thread):
def __init__(self, everest_config: EverestConfig):
super().__init__()

self.app = FastAPI()

self.router = APIRouter()
self.router.add_api_route("/", self.get_status, methods=["GET"])
self.router.add_api_route("/" + STOP_ENDPOINT, self.stop, methods=["POST"])
self.router.add_api_route(
"/" + SIM_PROGRESS_ENDPOINT, self.get_sim_progress, methods=["GET"]
)
self.router.add_api_route(
"/" + OPT_PROGRESS_ENDPOINT, self.get_opt_progress, methods=["GET"]
)
self.router.add_api_route(
"/" + START_ENDPOINT, self.start_experiment, methods=["POST"]
)
self.router.add_api_route(
"/" + EXIT_CODE_ENDPOINT, self.get_exit_code, methods=["GET"]
)
self.router.add_api_route(
"/" + SHARED_DATA_ENDPOINT, self.get_state, methods=["GET"]
)

self.app.include_router(self.router)

self.state = {
SIM_PROGRESS_ENDPOINT: {},
STOP_ENDPOINT: False,
}

self.everest_config = everest_config
self.output_dir = everest_config.output_dir
self.optimization_output_dir = everest_config.optimization_output_dir

# same code is in ensemble evaluator
self.authentication = _generate_authentication()

# same code is in ensemble evaluator
self.cert_path, self.key_path, self.key_pw = _generate_certificate(
ServerConfig.get_certificate_dir(self.output_dir)
)
self.host = _get_machine_name()
self.port = _find_open_port(self.host, lower=5000, upper=5800)

host_file = ServerConfig.get_hostfile_path(self.output_dir)
_write_hostfile(
host_file, self.host, self.port, self.cert_path, self.authentication
)

def run(self):
uvicorn.run(
self.app,
host="0.0.0.0",
port=self.port,
ssl_keyfile=self.key_path,
ssl_certfile=self.cert_path,
ssl_version=ssl.PROTOCOL_SSLv23,
ssl_keyfile_password=self.key_pw,
)

def _check_user(self, credentials: HTTPBasicCredentials) -> None:
if credentials.password != self.authentication:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials",
headers={"WWW-Authenticate": "Basic"},
)

def _log(self, request: Request) -> None:
logging.getLogger("everserver").info(
f"{request.scope['path']} entered from {request.client.host if request.client else 'unknown host'} with HTTP {request.method}"
)

def get_status(
self, request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> PlainTextResponse:
self._log(request)
self._check_user(credentials)
return PlainTextResponse("Everest is running")

def stop(
self, request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> Response:
self._log(request)
self._check_user(credentials)
self.state[STOP_ENDPOINT] = True
return Response("Raise STOP flag succeeded. Everest initiates shutdown..", 200)

def get_sim_progress(
self, request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
self._log(request)
self._check_user(credentials)
progress = self.state[SIM_PROGRESS_ENDPOINT]
return JSONResponse(jsonable_encoder(progress))

def get_exit_code(
self, request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
self._log(request)
self._check_user(credentials)
return JSONResponse(
jsonable_encoder(
self.runner.get_exit_code() if self.runner.get_exit_code() else {}
)
)

def get_opt_progress(
self, request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
self._log(request)
self._check_user(credentials)
progress = get_opt_status(self.optimization_output_dir)
return JSONResponse(jsonable_encoder(progress))

def start_experiment(
self,
request: Request,
credentials: HTTPBasicCredentials = Depends(security),
) -> Response:
self._log(request)
self._check_user(credentials)

self.runner = ExperimentRunner(self.everest_config, self.state)
self.runner.start()

return Response("Everest experiment started", 200)

def get_state(
self, request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
self._log(request)
self._check_user(credentials)
return JSONResponse(jsonable_encoder(self.state))
Loading

0 comments on commit bc30caf

Please sign in to comment.