diff --git a/docs/ert/conf.py b/docs/ert/conf.py index 71a069bd046..74b0a485078 100644 --- a/docs/ert/conf.py +++ b/docs/ert/conf.py @@ -67,7 +67,6 @@ ("py:class", "pydantic.types.PositiveInt"), ("py:class", "LibresFacade"), ("py:class", "pandas.core.frame.DataFrame"), - ("py:class", "websockets.server.WebSocketServerProtocol"), ("py:class", "EnsembleReader"), ] nitpick_ignore_regex = [ diff --git a/pyproject.toml b/pyproject.toml index 7a226b0c46a..6776ed9a154 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dependencies = [ "python-dateutil", "python-multipart", # extra dependency for fastapi "pyyaml", + "pyzmq", "qtpy", "requests", "resfo", @@ -68,7 +69,6 @@ dependencies = [ "tqdm>=4.62.0", "typing_extensions>=4.5", "uvicorn >= 0.17.0", - "websockets", "xarray", "xtgeo >= 3.3.0", ] diff --git a/src/_ert/forward_model_runner/cli.py b/src/_ert/forward_model_runner/cli.py index e176d066f4d..c3ab16bdc97 100644 --- a/src/_ert/forward_model_runner/cli.py +++ b/src/_ert/forward_model_runner/cli.py @@ -22,7 +22,6 @@ def _setup_reporters( ens_id, dispatch_url, ee_token=None, - ee_cert_path=None, experiment_id=None, ) -> list[reporting.Reporter]: reporters: list[reporting.Reporter] = [] @@ -30,11 +29,7 @@ def _setup_reporters( reporters.append(reporting.Interactive()) elif ens_id and experiment_id is None: reporters.append(reporting.File()) - reporters.append( - reporting.Event( - evaluator_url=dispatch_url, token=ee_token, cert_path=ee_cert_path - ) - ) + reporters.append(reporting.Event(evaluator_url=dispatch_url, token=ee_token)) else: reporters.append(reporting.File()) return reporters @@ -123,7 +118,6 @@ def main(args): experiment_id = jobs_data.get("experiment_id") ens_id = jobs_data.get("ens_id") ee_token = jobs_data.get("ee_token") - ee_cert_path = jobs_data.get("ee_cert_path") dispatch_url = jobs_data.get("dispatch_url") is_interactive_run = len(parsed_args.job) > 0 @@ -132,7 +126,6 @@ def main(args): ens_id, dispatch_url, ee_token, - ee_cert_path, experiment_id, ) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index ea798522b86..2c75dd1518f 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -1,18 +1,13 @@ +from __future__ import annotations + import asyncio import logging -import ssl -from typing import Any, AnyStr, Self - -from websockets.asyncio.client import ClientConnection, connect -from websockets.datastructures import Headers -from websockets.exceptions import ( - ConnectionClosedError, - ConnectionClosedOK, - InvalidHandshake, - InvalidURI, -) +import uuid +from abc import abstractmethod +from typing import Any, Self -from _ert.async_utils import new_event_loop +import zmq +import zmq.asyncio logger = logging.getLogger(__name__) @@ -21,112 +16,140 @@ class ClientConnectionError(Exception): pass -class ClientConnectionClosedOK(Exception): - pass +CONNECT_MSG = b"CONNECT" +DISCONNECT_MSG = b"DISCONNECT" +ACK_MSG = b"ACK" class Client: DEFAULT_MAX_RETRIES = 10 - DEFAULT_TIMEOUT_MULTIPLIER = 5 - CONNECTION_TIMEOUT = 60 - - def __enter__(self) -> Self: - return self - - def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None: - if self.websocket is not None: - self.loop.run_until_complete(self.websocket.close()) - self.loop.close() - - async def __aenter__(self) -> "Client": - return self - - async def __aexit__( - self, exc_type: Any, exc_value: Any, exc_traceback: Any - ) -> None: - if self.websocket is not None: - await self.websocket.close() + DEFAULT_ACK_TIMEOUT = 5 def __init__( self, url: str, token: str | None = None, - cert: str | bytes | None = None, - max_retries: int | None = None, - timeout_multiplier: int | None = None, + dealer_name: str | None = None, + ack_timeout: float | None = None, ) -> None: - if max_retries is None: - max_retries = self.DEFAULT_MAX_RETRIES - if timeout_multiplier is None: - timeout_multiplier = self.DEFAULT_TIMEOUT_MULTIPLIER - if url is None: - raise ValueError("url was None") + self._ack_timeout = ack_timeout or self.DEFAULT_ACK_TIMEOUT self.url = url self.token = token - self._additional_headers = Headers() + + self._ack_event: asyncio.Event = asyncio.Event() + self.context = zmq.asyncio.Context() + self.socket = self.context.socket(zmq.DEALER) + # this is to avoid blocking the event loop when closing the socket + # wherein the linger is set to 0 to discard all messages in the queue + self.socket.setsockopt(zmq.LINGER, 0) + self.dealer_id = dealer_name or f"dispatch-{uuid.uuid4().hex[:8]}" + self.socket.setsockopt_string(zmq.IDENTITY, self.dealer_id) + if token is not None: - self._additional_headers["token"] = token - - # Mimics the behavior of the ssl argument when connection to - # websockets. If none is specified it will deduce based on the url, - # if True it will enforce TLS, and if you want to use self signed - # certificates you need to pass an ssl_context with the certificate - # loaded. - self._ssl_context: bool | ssl.SSLContext | None = None - if cert is not None: - self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - self._ssl_context.load_verify_locations(cadata=cert) - elif url.startswith("wss"): - self._ssl_context = True - - self._max_retries = max_retries - self._timeout_multiplier = timeout_multiplier - self.websocket: ClientConnection | None = None - self.loop = new_event_loop() - - async def get_websocket(self) -> ClientConnection: - return await connect( - self.url, - ssl=self._ssl_context, - additional_headers=self._additional_headers, - open_timeout=self.CONNECTION_TIMEOUT, - ping_timeout=self.CONNECTION_TIMEOUT, - ping_interval=self.CONNECTION_TIMEOUT, - close_timeout=self.CONNECTION_TIMEOUT, - ) + client_public, client_secret = zmq.curve_keypair() + self.socket.curve_secretkey = client_secret + self.socket.curve_publickey = client_public + self.socket.curve_serverkey = token.encode("utf-8") + + self._receiver_task: asyncio.Task[None] | None = None + + async def __aenter__(self) -> Self: + await self.connect() + return self - async def _send(self, msg: AnyStr) -> None: - for retry in range(self._max_retries + 1): + async def __aexit__( + self, exc_type: Any, exc_value: Any, exc_traceback: Any + ) -> None: + try: + await self.send(DISCONNECT_MSG) + except ClientConnectionError: + logger.error("No ack for dealer disconnection. Connection is down!") + finally: + self.socket.disconnect(self.url) + await self._term_receiver_task() + self.term() + + def term(self) -> None: + self.socket.close() + self.context.term() + + async def _term_receiver_task(self) -> None: + if self._receiver_task and not self._receiver_task.done(): + self._receiver_task.cancel() + await asyncio.gather(self._receiver_task, return_exceptions=True) + self._receiver_task = None + + async def connect(self) -> None: + self.socket.connect(self.url) + await self._term_receiver_task() + self._receiver_task = asyncio.create_task(self._receiver()) + try: + await self.send(CONNECT_MSG, retries=1) + except ClientConnectionError: + await self._term_receiver_task() + self.term() + raise + + @abstractmethod + async def process_message(self, msg: str) -> None: + """ + This method is implemented in the Monitor, which stores the messages in a queue. + + Args: + msg (str): Message (event) to be processed + """ + + async def _receiver(self) -> None: + while True: try: - if self.websocket is None: - self.websocket = await self.get_websocket() - await self.websocket.send(msg) - return - except ConnectionClosedOK as exception: - error_msg = ( - f"Connection closed received from the server {self.url}! " - f" Exception from {type(exception)}: {exception!s}" + _, raw_msg = await self.socket.recv_multipart() + if raw_msg == ACK_MSG: + self._ack_event.set() + else: + await self.process_message(raw_msg.decode("utf-8")) + except zmq.ZMQError as exc: + logger.debug( + f"{self.dealer_id} connection to evaluator went down, reconnecting: {exc}" ) - raise ClientConnectionClosedOK(error_msg) from exception - except (TimeoutError, InvalidHandshake, InvalidURI, OSError) as exception: - if retry == self._max_retries: - error_msg = ( - f"Not able to establish the " - f"websocket connection {self.url}! Max retries reached!" - " Check for firewall issues." - f" Exception from {type(exception)}: {exception!s}" + await asyncio.sleep(0) + self.socket.connect(self.url) + + async def send(self, message: str | bytes, retries: int | None = None) -> None: + self._ack_event.clear() + + if isinstance(message, str): + message = message.encode("utf-8") + + backoff = 1 + if retries is None: + retries = self.DEFAULT_MAX_RETRIES + while retries >= 0: + try: + await self.socket.send_multipart([b"", message]) + try: + await asyncio.wait_for( + self._ack_event.wait(), timeout=self._ack_timeout ) - raise ClientConnectionError(error_msg) from exception - except ConnectionClosedError as exception: - if retry == self._max_retries: - error_msg = ( - f"Not been able to send the event" - f" to {self.url}! Max retries reached!" - f" Exception from {type(exception)}: {exception!s}" + return + except TimeoutError: + logger.warning( + f"{self.dealer_id} failed to get acknowledgment on the {message!r}. Resending." ) - raise ClientConnectionError(error_msg) from exception - await asyncio.sleep(0.2 + self._timeout_multiplier * retry) - self.websocket = None - - def send(self, msg: AnyStr) -> None: - self.loop.run_until_complete(self._send(msg)) + except zmq.ZMQError as exc: + logger.debug( + f"{self.dealer_id} connection to evaluator went down, reconnecting: {exc}" + ) + except asyncio.CancelledError: + self.term() + raise + + retries -= 1 + if retries > 0: + logger.info(f"Retrying... ({retries} attempts left)") + await asyncio.sleep(backoff) + # this call is idempotent + self.socket.connect(self.url) + backoff = min(backoff * 2, 10) # Exponential backoff + raise ClientConnectionError( + f"{self.dealer_id} Failed to send {message!r} after retries!" + ) diff --git a/src/_ert/forward_model_runner/reporting/event.py b/src/_ert/forward_model_runner/reporting/event.py index 81cbb43e682..28c38d5c0a6 100644 --- a/src/_ert/forward_model_runner/reporting/event.py +++ b/src/_ert/forward_model_runner/reporting/event.py @@ -1,9 +1,9 @@ from __future__ import annotations +import asyncio import logging import queue import threading -from datetime import datetime, timedelta from pathlib import Path from typing import Final @@ -16,11 +16,7 @@ ForwardModelStepSuccess, event_to_json, ) -from _ert.forward_model_runner.client import ( - Client, - ClientConnectionClosedOK, - ClientConnectionError, -) +from _ert.forward_model_runner.client import Client, ClientConnectionError from _ert.forward_model_runner.reporting.base import Reporter from _ert.forward_model_runner.reporting.message import ( _JOB_EXIT_FAILED_STRING, @@ -59,14 +55,16 @@ class Event(Reporter): _sentinel: Final = EventSentinel() - def __init__(self, evaluator_url, token=None, cert_path=None): + def __init__( + self, + evaluator_url: str, + token: str | None = None, + ack_timeout: float | None = None, + max_retries: int | None = None, + finished_event_timeout: float | None = None, + ): self._evaluator_url = evaluator_url self._token = token - if cert_path is not None: - with open(cert_path, encoding="utf-8") as f: - self._cert = f.read() - else: - self._cert = None self._statemachine = StateMachine() self._statemachine.add_handler((Init,), self._init_handler) @@ -78,53 +76,54 @@ def __init__(self, evaluator_url, token=None, cert_path=None): self._real_id = None self._event_queue: queue.Queue[events.Event | EventSentinel] = queue.Queue() self._event_publisher_thread = ErtThread(target=self._event_publisher) - self._timeout_timestamp = None - self._timestamp_lock = threading.Lock() - # seconds to timeout the reporter the thread after Finish() was received - self._reporter_timeout = 60 + self._done = threading.Event() + self._ack_timeout = ack_timeout + self._max_retries = max_retries + if finished_event_timeout is not None: + self._finished_event_timeout = finished_event_timeout + else: + self._finished_event_timeout = 60 - def stop(self) -> None: + def stop(self): self._event_queue.put(Event._sentinel) - with self._timestamp_lock: - self._timeout_timestamp = datetime.now() + timedelta( - seconds=self._reporter_timeout - ) + self._done.set() if self._event_publisher_thread.is_alive(): self._event_publisher_thread.join() def _event_publisher(self): - logger.debug("Publishing event.") - with Client( - url=self._evaluator_url, - token=self._token, - cert=self._cert, - ) as client: - event = None - while True: - with self._timestamp_lock: - if ( - self._timeout_timestamp is not None - and datetime.now() > self._timeout_timestamp - ): - self._timeout_timestamp = None - break - if event is None: - # if we successfully sent the event we can proceed - # to next one - event = self._event_queue.get() - if event is self._sentinel: - break - try: - client.send(event_to_json(event)) - event = None - except ClientConnectionError as exception: - # Possible intermittent failure, we retry sending the event - logger.error(str(exception)) - except ClientConnectionClosedOK as exception: - # The receiving end has closed the connection, we stop - # sending events - logger.debug(str(exception)) - break + async def publisher(): + async with Client( + url=self._evaluator_url, + token=self._token, + ack_timeout=self._ack_timeout, + ) as client: + event = None + start_time = None + while True: + try: + if self._done.is_set() and start_time is None: + start_time = asyncio.get_event_loop().time() + if event is None: + event = self._event_queue.get() + if event is self._sentinel: + break + if ( + start_time + and (asyncio.get_event_loop().time() - start_time) + > self._finished_event_timeout + ): + break + await client.send(event_to_json(event), self._max_retries) + event = None + except asyncio.CancelledError: + return + except ClientConnectionError as exc: + logger.error(f"Failed to send event: {exc}") + + try: + asyncio.run(publisher()) + except ClientConnectionError as exc: + raise ClientConnectionError("Couldn't connect to evaluator") from exc def report(self, msg): self._statemachine.transition(msg) diff --git a/src/ert/cli/main.py b/src/ert/cli/main.py index 6131dbe3b9c..547ec323cd0 100644 --- a/src/ert/cli/main.py +++ b/src/ert/cli/main.py @@ -104,7 +104,10 @@ def run_cli(args: Namespace, plugin_manager: ErtPluginManager | None = None) -> # most unix flavors https://en.wikipedia.org/wiki/Ephemeral_port args.port_range = range(49152, 51819) - evaluator_server_config = EvaluatorServerConfig(custom_port_range=args.port_range) + use_ipc_protocol = model.queue_system == QueueSystem.LOCAL + evaluator_server_config = EvaluatorServerConfig( + custom_port_range=args.port_range, use_ipc_protocol=use_ipc_protocol + ) if model.check_if_runpath_exists(): print( diff --git a/src/ert/ensemble_evaluator/__init__.py b/src/ert/ensemble_evaluator/__init__.py index bcef41a53d8..642199a0926 100644 --- a/src/ert/ensemble_evaluator/__init__.py +++ b/src/ert/ensemble_evaluator/__init__.py @@ -1,6 +1,5 @@ from ._ensemble import LegacyEnsemble as Ensemble from ._ensemble import Realization -from ._wait_for_evaluator import wait_for_evaluator from .config import EvaluatorServerConfig from .evaluator import EnsembleEvaluator from .event import EndEvent, FullSnapshotEvent, SnapshotUpdateEvent @@ -19,5 +18,4 @@ "Realization", "RealizationSnapshot", "SnapshotUpdateEvent", - "wait_for_evaluator", ] diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index b09a0a81536..05dcb38ecb1 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -6,10 +6,7 @@ from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass from functools import partialmethod -from typing import ( - Any, - Protocol, -) +from typing import Any, Protocol from _ert.events import ( Event, @@ -25,13 +22,8 @@ from ert.run_arg import RunArg from ert.scheduler import Scheduler, create_driver -from ._wait_for_evaluator import wait_for_evaluator from .config import EvaluatorServerConfig -from .snapshot import ( - EnsembleSnapshot, - FMStepSnapshot, - RealizationSnapshot, -) +from .snapshot import EnsembleSnapshot, FMStepSnapshot, RealizationSnapshot from .state import ( ENSEMBLE_STATE_CANCELLED, ENSEMBLE_STATE_FAILED, @@ -198,11 +190,10 @@ async def send_event( url: str, event: Event, token: str | None = None, - cert: str | bytes | None = None, retries: int = 10, ) -> None: - async with Client(url, token, cert, max_retries=retries) as client: - await client._send(event_to_json(event)) + async with Client(url, token) as client: + await client.send(event_to_json(event), retries) def generate_event_creator(self) -> Callable[[Id.ENSEMBLE_TYPES], Event]: def event_builder(status: str) -> Event: @@ -227,21 +218,18 @@ async def evaluate( ce_unary_send_method_name, partialmethod( self.__class__.send_event, - self._config.dispatch_uri, + self._config.get_connection_info().router_uri, token=self._config.token, - cert=self._config.cert, ), ) - await wait_for_evaluator( - base_url=self._config.url, - token=self._config.token, - cert=self._config.cert, - ) - await self._evaluate_inner( - event_unary_send=getattr(self, ce_unary_send_method_name), - scheduler_queue=scheduler_queue, - manifest_queue=manifest_queue, - ) + try: + await self._evaluate_inner( + event_unary_send=getattr(self, ce_unary_send_method_name), + scheduler_queue=scheduler_queue, + manifest_queue=manifest_queue, + ) + except asyncio.CancelledError: + print("Cancelling evaluator task!") async def _evaluate_inner( # pylint: disable=too-many-branches self, @@ -279,8 +267,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches max_running=self._queue_config.max_running, submit_sleep=self._queue_config.submit_sleep, ens_id=self.id_, - ee_uri=self._config.dispatch_uri, - ee_cert=self._config.cert, + ee_uri=self._config.get_connection_info().router_uri, ee_token=self._config.token, ) logger.info( diff --git a/src/ert/ensemble_evaluator/_wait_for_evaluator.py b/src/ert/ensemble_evaluator/_wait_for_evaluator.py deleted file mode 100644 index f97fb758a6b..00000000000 --- a/src/ert/ensemble_evaluator/_wait_for_evaluator.py +++ /dev/null @@ -1,77 +0,0 @@ -import asyncio -import logging -import ssl -import time - -import aiohttp - -logger = logging.getLogger(__name__) - -WAIT_FOR_EVALUATOR_TIMEOUT = 60 - - -def get_ssl_context(cert: str | bytes | None) -> ssl.SSLContext | bool: - if cert is None: - return False - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_context.load_verify_locations(cadata=cert) - return ssl_context - - -async def attempt_connection( - url: str, - token: str | None = None, - cert: str | bytes | None = None, - connection_timeout: float = 2, -) -> None: - timeout = aiohttp.ClientTimeout(connect=connection_timeout) - headers = {} if token is None else {"token": token} - async with ( - aiohttp.ClientSession() as session, - session.request( - method="get", - url=url, - ssl=get_ssl_context(cert), - headers=headers, - timeout=timeout, - ) as resp, - ): - resp.raise_for_status() - - -async def wait_for_evaluator( - base_url: str, - token: str | None = None, - cert: str | bytes | None = None, - healthcheck_endpoint: str = "/healthcheck", - timeout: float | None = None, # noqa: ASYNC109 - connection_timeout: float = 2, -) -> None: - if timeout is None: - timeout = WAIT_FOR_EVALUATOR_TIMEOUT - healthcheck_url = base_url + healthcheck_endpoint - start = time.time() - sleep_time = 0.2 - sleep_time_max = 5.0 - while time.time() - start < timeout: - try: - await attempt_connection( - url=healthcheck_url, - token=token, - cert=cert, - connection_timeout=connection_timeout, - ) - return - except aiohttp.ClientError: - sleep_time = min(sleep_time_max, sleep_time * 2) - remaining_time = max(0, timeout - (time.time() - start) + 0.1) - await asyncio.sleep(min(sleep_time, remaining_time)) - - # We have timed out, but we make one last attempt to ensure that - # we have tried to connect at both ends of the time window - await attempt_connection( - url=healthcheck_url, - token=token, - cert=cert, - connection_timeout=connection_timeout, - ) diff --git a/src/ert/ensemble_evaluator/config.py b/src/ert/ensemble_evaluator/config.py index 51b059a6ce1..882824889f1 100644 --- a/src/ert/ensemble_evaluator/config.py +++ b/src/ert/ensemble_evaluator/config.py @@ -1,19 +1,9 @@ -import ipaddress import logging -import os -import pathlib import socket -import ssl -import tempfile +import uuid import warnings -from base64 import b64encode -from datetime import UTC, datetime, timedelta -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.x509.oid import NameOID +import zmq from ert.shared import find_available_socket from ert.shared import get_machine_name as ert_shared_get_machine_name @@ -32,143 +22,42 @@ def get_machine_name() -> str: return ert_shared_get_machine_name() -def _generate_authentication() -> str: - n_bytes = 128 - random_bytes = bytes(os.urandom(n_bytes)) - token = b64encode(random_bytes).decode("utf-8") - return token - - -def _generate_certificate( - ip_address: str, -) -> tuple[str, bytes, bytes]: - """Generate a private key and a certificate signed with it - The key is encrypted before being stored. - Returns the certificate as a string, the key as bytes (encrypted), and - the password used for encrypting the key - """ - # Generate private key - key = rsa.generate_private_key( - public_exponent=65537, key_size=4096, backend=default_backend() - ) - - # Generate the certificate and sign it with the private key - cert_name = ert_shared_get_machine_name() - subject = issuer = x509.Name( - [ - x509.NameAttribute(NameOID.COUNTRY_NAME, "NO"), - x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Bergen"), - x509.NameAttribute(NameOID.LOCALITY_NAME, "Sandsli"), - x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Equinor"), - x509.NameAttribute(NameOID.COMMON_NAME, f"{cert_name}"), - ] - ) - cert = ( - x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) - .public_key(key.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.now(UTC)) - .not_valid_after(datetime.now(UTC) + timedelta(days=365)) # 1 year - .add_extension( - x509.SubjectAlternativeName( - [ - x509.DNSName(f"{cert_name}"), - x509.DNSName(ip_address), - x509.IPAddress(ipaddress.ip_address(ip_address)), - ] - ), - critical=False, - ) - .sign(key, hashes.SHA256(), default_backend()) - ) - - cert_str = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") - pw = bytes(os.urandom(28)) - key_bytes = key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.BestAvailableEncryption(pw), - ) - return cert_str, key_bytes, pw - - class EvaluatorServerConfig: - """ - This class is responsible for identifying a host:port-combo and then provide - low-level sockets bound to said combo. The problem is that these sockets may - be closed by underlying code, while the EvaluatorServerConfig-instance is - still alive and expected to provide a bound low-level socket. Thus we risk - that the host:port is hijacked by another process in the meantime. - - To prevent this, we keep a handle to the bound socket and every time - a socket is requested we return a duplicate of this. The duplicate will be - bound similarly to the handle, but when closed the handle stays open and - holds the port. - - In particular, the websocket-server closes the websocket when exiting a - context: - - https://github.com/aaugustin/websockets/blob/c439f1d52aafc05064cc11702d1c3014046799b0/src/websockets/legacy/server.py#L890 - - and digging into the cpython-implementation of asyncio, we see that causes - the asyncio code to also close the underlying socket: - - https://github.com/python/cpython/blob/b34dd58fee707b8044beaf878962a6fa12b304dc/Lib/asyncio/selector_events.py#L607-L611 - - """ - def __init__( self, custom_port_range: range | None = None, use_token: bool = True, - generate_cert: bool = True, custom_host: str | None = None, + use_ipc_protocol: bool = True, ) -> None: - self._socket_handle = find_available_socket( - custom_range=custom_port_range, custom_host=custom_host - ) - host, port = self._socket_handle.getsockname() - self.protocol = "wss" if generate_cert else "ws" - self.url = f"{self.protocol}://{host}:{port}" - self.client_uri = f"{self.url}/client" - self.dispatch_uri = f"{self.url}/dispatch" - if generate_cert: - cert, key, pw = _generate_certificate(host) - else: - cert, key, pw = None, None, None - self.cert = cert - self._key: bytes | None = key - self._key_pw = pw - - self.token = _generate_authentication() if use_token else None - - def get_socket(self) -> socket.socket: - return self._socket_handle.dup() + self.host: str | None = None + self.router_port: int | None = None + self.url = f"ipc:///tmp/socket-{uuid.uuid4().hex[:8]}" + self.token: str | None = None + self._socket_handle: socket.socket | None = None + + self.server_public_key: bytes | None = None + self.server_secret_key: bytes | None = None + if not use_ipc_protocol: + self._socket_handle = find_available_socket( + custom_range=custom_port_range, + custom_host=custom_host, + will_close_then_reopen_socket=True, + ) + self.host, self.router_port = self._socket_handle.getsockname() + self.url = f"tcp://{self.host}:{self.router_port}" + + if use_token: + self.server_public_key, self.server_secret_key = zmq.curve_keypair() + self.token = self.server_public_key.decode("utf-8") + + def get_socket(self) -> socket.socket | None: + if self._socket_handle: + return self._socket_handle.dup() + return None def get_connection_info(self) -> EvaluatorConnectionInfo: return EvaluatorConnectionInfo( self.url, - self.cert, self.token, ) - - def get_server_ssl_context( - self, protocol: int = ssl.PROTOCOL_TLS_SERVER - ) -> ssl.SSLContext | None: - if self.cert is None: - return None - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_path = pathlib.Path(tmp_dir) - cert_path = tmp_path / "ee.crt" - with open(cert_path, "w", encoding="utf-8") as filehandle_1: - filehandle_1.write(self.cert) - - key_path = tmp_path / "ee.key" - if self._key is not None: - with open(key_path, "wb") as filehandle_2: - filehandle_2.write(self._key) - context = ssl.SSLContext(protocol=protocol) - context.load_cert_chain(cert_path, key_path, self._key_pw) - return context diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 7078ce412e4..3696660524f 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -1,26 +1,13 @@ +from __future__ import annotations + import asyncio import datetime import logging import traceback -from collections.abc import ( - AsyncIterator, - Awaitable, - Callable, - Generator, - Iterable, - Sequence, -) -from contextlib import asynccontextmanager, contextmanager -from http import HTTPStatus -from typing import ( - Any, - get_args, -) +from collections.abc import Awaitable, Callable, Iterable, Sequence +from typing import Any, get_args -from pydantic_core._pydantic_core import ValidationError -from websockets.asyncio.server import ServerConnection, serve -from websockets.exceptions import ConnectionClosedError -from websockets.http11 import Request, Response +import zmq.asyncio from _ert.events import ( EESnapshot, @@ -40,6 +27,7 @@ event_from_json, event_to_json, ) +from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG from ert.ensemble_evaluator import identifiers as ids from ._ensemble import FMStepSnapshot @@ -64,15 +52,11 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._loop: asyncio.AbstractEventLoop | None = None - self._clients: set[ServerConnection] = set() - self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue() - self._events: asyncio.Queue[Event] = asyncio.Queue() self._events_to_send: asyncio.Queue[Event] = asyncio.Queue() self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue() self._ee_tasks: list[asyncio.Task[None]] = [] - self._server_started: asyncio.Event = asyncio.Event() self._server_done: asyncio.Event = asyncio.Event() # batching section @@ -82,14 +66,22 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._max_batch_size: int = 500 self._batching_interval: float = 2.0 self._complete_batch: asyncio.Event = asyncio.Event() + self._server_started: asyncio.Event = asyncio.Event() + self._clients_connected: set[bytes] = set() + self._clients_empty: asyncio.Event = asyncio.Event() + self._clients_empty.set() + self._dispatchers_connected: set[bytes] = set() + self._dispatchers_empty: asyncio.Event = asyncio.Event() + self._dispatchers_empty.set() async def _publisher(self) -> None: + await self._server_started.wait() while True: event = await self._events_to_send.get() - await asyncio.gather( - *[client.send(event_to_json(event)) for client in self._clients], - return_exceptions=True, - ) + for identity in self._clients_connected: + await self._router_socket.send_multipart( + [identity, b"", event_to_json(event).encode("utf-8")] + ) self._events_to_send.task_done() async def _append_message(self, snapshot_update_event: EnsembleSnapshot) -> None: @@ -204,140 +196,128 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None: def ensemble(self) -> Ensemble: return self._ensemble - @contextmanager - def store_client(self, websocket: ServerConnection) -> Generator[None, None, None]: - self._clients.add(websocket) - yield - self._clients.remove(websocket) - - async def handle_client(self, websocket: ServerConnection) -> None: - with self.store_client(websocket): + async def handle_client(self, dealer: bytes, frame: bytes) -> None: + if frame == CONNECT_MSG: + self._clients_connected.add(dealer) + self._clients_empty.clear() current_snapshot_dict = self._ensemble.snapshot.to_dict() event: Event = EESnapshot( - snapshot=current_snapshot_dict, ensemble=self.ensemble.id_ + snapshot=current_snapshot_dict, + ensemble=self.ensemble.id_, ) - await websocket.send(event_to_json(event)) - - async for raw_msg in websocket: - event = event_from_json(raw_msg) - logger.debug(f"got message from client: {event}") - if type(event) is EEUserCancel: - logger.debug(f"Client {websocket.remote_address} asked to cancel.") - self._signal_cancel() - - elif type(event) is EEUserDone: - logger.debug(f"Client {websocket.remote_address} signalled done.") - self.stop() - - @asynccontextmanager - async def count_dispatcher(self) -> AsyncIterator[None]: - await self._dispatchers_connected.put(None) - yield - await self._dispatchers_connected.get() - self._dispatchers_connected.task_done() - - async def handle_dispatch(self, websocket: ServerConnection) -> None: - async with self.count_dispatcher(): - try: - async for raw_msg in websocket: - try: - event = dispatch_event_from_json(raw_msg) - if event.ensemble != self.ensemble.id_: - logger.info( - "Got event from evaluator " - f"{event.ensemble}. " - f"Ignoring since I am {self.ensemble.id_}" - ) - continue - if type(event) is ForwardModelStepChecksum: - await self.forward_checksum(event) - else: - await self._events.put(event) - except ValidationError as ex: - logger.warning( - "cannot handle event - " - f"closing connection to dispatcher: {ex}" - ) - await websocket.close( - code=1011, reason=f"failed handling message {raw_msg!r}" - ) - return - - if type(event) in {EnsembleSucceeded, EnsembleFailed}: - return - except ConnectionClosedError as connection_error: - # Dispatchers may close the connection abruptly in the case of - # * flaky network (then the dispatcher will try to reconnect) - # * job being killed due to MAX_RUNTIME - # * job being killed by user - logger.error( - f"a dispatcher abruptly closed a websocket: {connection_error!s}" + await self._router_socket.send_multipart( + [dealer, b"", event_to_json(event).encode("utf-8")] + ) + elif frame == DISCONNECT_MSG: + self._clients_connected.discard(dealer) + if not self._clients_connected: + self._clients_empty.set() + else: + event = event_from_json(frame.decode("utf-8")) + if type(event) is EEUserCancel: + logger.debug("Client asked to cancel.") + self._signal_cancel() + elif type(event) is EEUserDone: + logger.debug("Client signalled done.") + self.stop() + + async def handle_dispatch(self, dealer: bytes, frame: bytes) -> None: + if frame == CONNECT_MSG: + self._dispatchers_connected.add(dealer) + self._dispatchers_empty.clear() + elif frame == DISCONNECT_MSG: + self._dispatchers_connected.discard(dealer) + if not self._dispatchers_connected: + self._dispatchers_empty.set() + else: + event = dispatch_event_from_json(frame.decode("utf-8")) + if event.ensemble != self.ensemble.id_: + logger.info( + "Got event from evaluator " + f"{event.ensemble}. " + f"Ignoring since I am {self.ensemble.id_}" ) + return + if type(event) is ForwardModelStepChecksum: + await self.forward_checksum(event) + else: + await self._events.put(event) + + async def listen_for_messages(self) -> None: + await self._server_started.wait() + while True: + try: + dealer, _, frame = await self._router_socket.recv_multipart() + await self._router_socket.send_multipart([dealer, b"", ACK_MSG]) + sender = dealer.decode("utf-8") + if sender.startswith("client"): + await self.handle_client(dealer, frame) + elif sender.startswith("dispatch"): + await self.handle_dispatch(dealer, frame) + else: + logger.info(f"Connection attempt to unknown sender: {sender}.") + except zmq.error.ZMQError as e: + if e.errno == zmq.ENOTSOCK: + logger.warning( + "Evaluator receiver closed, no new messages are received" + ) + else: + logger.error(f"Unexpected error when listening to messages: {e}") + except asyncio.CancelledError: + self._router_socket.close() + return async def forward_checksum(self, event: Event) -> None: # clients still need to receive events via ws await self._events_to_send.put(event) await self._manifest_queue.put(event) - async def connection_handler(self, websocket: ServerConnection) -> None: - if websocket.request is not None: - path = websocket.request.path - elements = path.split("/") - if elements[1] == "client": - await self.handle_client(websocket) - elif elements[1] == "dispatch": - await self.handle_dispatch(websocket) - else: - logger.info(f"Connection attempt to unknown path: {path}.") - else: - logger.info("No request to handle.") - - async def process_request( - self, connection: ServerConnection, request: Request - ) -> Response | None: - if request.headers.get("token") != self._config.token: - return connection.respond(HTTPStatus.UNAUTHORIZED, "") - if request.path == "/healthcheck": - return connection.respond(HTTPStatus.OK, "") - return None - async def _server(self) -> None: - async with serve( - self.connection_handler, - sock=self._config.get_socket(), - ssl=self._config.get_server_ssl_context(), - process_request=self.process_request, - max_size=2**26, - ping_timeout=60, - ping_interval=60, - close_timeout=60, - ) as server: + zmq_context = zmq.asyncio.Context() + try: + self._router_socket: zmq.asyncio.Socket = zmq_context.socket(zmq.ROUTER) + self._router_socket.setsockopt(zmq.LINGER, 0) + if self._config.server_public_key and self._config.server_secret_key: + self._router_socket.curve_secretkey = self._config.server_secret_key + self._router_socket.curve_publickey = self._config.server_public_key + self._router_socket.curve_server = True + + if self._config.router_port: + self._router_socket.bind(f"tcp://*:{self._config.router_port}") + else: + self._router_socket.bind(self._config.url) self._server_started.set() + except zmq.error.ZMQError as e: + logger.error(f"ZMQ error encountered {e} during evaluator initialization") + raise + try: await self._server_done.wait() - server.close(close_connections=False) - if self._dispatchers_connected is not None: - logger.debug( - f"Got done signal. {self._dispatchers_connected.qsize()} " - "dispatchers to disconnect..." + try: + await asyncio.wait_for(self._dispatchers_empty.wait(), timeout=5) + except TimeoutError: + logger.warning( + "Not all dispatchers were disconnected when closing zmq server!" ) - try: # Wait for dispatchers to disconnect - await asyncio.wait_for( - self._dispatchers_connected.join(), timeout=20 - ) - except TimeoutError: - logger.debug("Timed out waiting for dispatchers to disconnect") - else: - logger.debug("Got done signal. No dispatchers connected") - - logger.debug("Sending termination-message to clients...") - await self._events.join() await self._complete_batch.wait() await self._batch_processing_queue.join() event = EETerminated(ensemble=self._ensemble.id_) await self._events_to_send.put(event) await self._events_to_send.join() - logger.debug("Async server exiting.") + try: + await asyncio.wait_for(self._clients_empty.wait(), timeout=5) + except TimeoutError: + logger.warning( + "Not all clients were disconnected when closing zmq server!" + ) + logger.debug("Async server exiting.") + finally: + try: + self._router_socket.close() + zmq_context.destroy() + except Exception as exc: + logger.warning(f"Failed to clean up zmq context {exc}") + logger.info("ZMQ cleanup done!") def stop(self) -> None: self._server_done.set() @@ -370,10 +350,10 @@ async def _start_running(self) -> None: ), asyncio.create_task(self._process_event_buffer(), name="processing_task"), asyncio.create_task(self._publisher(), name="publisher_task"), + asyncio.create_task(self.listen_for_messages(), name="listener_task"), ] - # now we wait for the server to actually start - await self._server_started.wait() + await self._server_started.wait() self._ee_tasks.append( asyncio.create_task( self._ensemble.evaluate( @@ -405,9 +385,11 @@ async def _monitor_and_handle_tasks(self) -> None: raise task_exception elif task.get_name() == "server_task": return - elif task.get_name() == "ensemble_task": + elif task.get_name() == "ensemble_task" or task.get_name() in { + "ensemble_task", + "listener_task", + }: timeout = self.CLOSE_SERVER_TIMEOUT - continue else: msg = ( f"Something went wrong, {task.get_name()} is done prematurely!" @@ -433,6 +415,9 @@ async def run_and_get_successful_realizations(self) -> list[int]: try: await self._monitor_and_handle_tasks() finally: + self._server_done.set() + self._clients_empty.set() + self._dispatchers_empty.set() for task in self._ee_tasks: if not task.done(): task.cancel() @@ -442,7 +427,7 @@ async def run_and_get_successful_realizations(self) -> list[int]: result, Exception ): logger.error(str(result)) - raise result + raise RuntimeError(result) from result logger.debug("Evaluator is done") return self._ensemble.get_successful_realizations() diff --git a/src/ert/ensemble_evaluator/evaluator_connection_info.py b/src/ert/ensemble_evaluator/evaluator_connection_info.py index e01326c5c99..ac8ec35ef0c 100644 --- a/src/ert/ensemble_evaluator/evaluator_connection_info.py +++ b/src/ert/ensemble_evaluator/evaluator_connection_info.py @@ -5,18 +5,5 @@ class EvaluatorConnectionInfo: """Read only server-info""" - url: str - cert: str | bytes | None = None + router_uri: str token: str | None = None - - @property - def dispatch_uri(self) -> str: - return f"{self.url}/dispatch" - - @property - def client_uri(self) -> str: - return f"{self.url}/client" - - @property - def result_uri(self) -> str: - return f"{self.url}/result" diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index d3f549377c6..d55a50b9661 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -1,13 +1,10 @@ +from __future__ import annotations + import asyncio import logging -import ssl import uuid from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Any, Final - -from aiohttp import ClientError -from websockets import ConnectionClosed, Headers -from websockets.asyncio.client import ClientConnection, connect +from typing import TYPE_CHECKING, Final from _ert.events import ( EETerminated, @@ -17,7 +14,7 @@ event_from_json, event_to_json, ) -from ert.ensemble_evaluator._wait_for_evaluator import wait_for_evaluator +from _ert.forward_model_runner.client import Client if TYPE_CHECKING: from ert.ensemble_evaluator.evaluator_connection_info import EvaluatorConnectionInfo @@ -30,60 +27,37 @@ class EventSentinel: pass -class Monitor: +class Monitor(Client): _sentinel: Final = EventSentinel() - def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None: - self._ee_con_info = ee_con_info + def __init__(self, ee_con_info: EvaluatorConnectionInfo) -> None: self._id = str(uuid.uuid1()).split("-", maxsplit=1)[0] self._event_queue: asyncio.Queue[Event | EventSentinel] = asyncio.Queue() - self._connection: ClientConnection | None = None - self._receiver_task: asyncio.Task[None] | None = None - self._connected: asyncio.Future[None] = asyncio.Future() - self._connection_timeout: float = 120.0 self._receiver_timeout: float = 60.0 + super().__init__( + ee_con_info.router_uri, + ee_con_info.token, + dealer_name=f"client-{self._id}", + ) - async def __aenter__(self) -> "Monitor": - self._receiver_task = asyncio.create_task(self._receiver()) - try: - await asyncio.wait_for(self._connected, timeout=self._connection_timeout) - except TimeoutError as exc: - msg = "Couldn't establish connection with the ensemble evaluator!" - logger.error(msg) - self._receiver_task.cancel() - raise RuntimeError(msg) from exc - return self - - async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - if self._receiver_task: - if not self._receiver_task.done(): - self._receiver_task.cancel() - # we are done and not interested in errors when cancelling - await asyncio.gather( - self._receiver_task, - return_exceptions=True, - ) - if self._connection: - await self._connection.close() + async def process_message(self, msg: str) -> None: + event = event_from_json(msg) + await self._event_queue.put(event) async def signal_cancel(self) -> None: - if not self._connection: - return await self._event_queue.put(Monitor._sentinel) logger.debug(f"monitor-{self._id} asking server to cancel...") cancel_event = EEUserCancel(monitor=self._id) - await self._connection.send(event_to_json(cancel_event)) + await self.send(event_to_json(cancel_event)) logger.debug(f"monitor-{self._id} asked server to cancel") async def signal_done(self) -> None: - if not self._connection: - return await self._event_queue.put(Monitor._sentinel) logger.debug(f"monitor-{self._id} informing server monitor is done...") done_event = EEUserDone(monitor=self._id) - await self._connection.send(event_to_json(done_event)) + await self.send(event_to_json(done_event)) logger.debug(f"monitor-{self._id} informed server monitor is done") async def track( @@ -116,45 +90,3 @@ async def track( break if event is not None: self._event_queue.task_done() - - async def _receiver(self) -> None: - tls: ssl.SSLContext | None = None - if self._ee_con_info.cert: - tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - tls.load_verify_locations(cadata=self._ee_con_info.cert) - headers = Headers() - if self._ee_con_info.token: - headers["token"] = self._ee_con_info.token - try: - await wait_for_evaluator( - base_url=self._ee_con_info.url, - token=self._ee_con_info.token, - cert=self._ee_con_info.cert, - timeout=5, - ) - except Exception as e: - self._connected.set_exception(e) - return - async for conn in connect( - self._ee_con_info.client_uri, - ssl=tls, - additional_headers=headers, - max_size=2**26, - max_queue=500, - open_timeout=5, - ping_timeout=60, - ping_interval=60, - close_timeout=60, - ): - try: - self._connection = conn - self._connected.set_result(None) - async for raw_msg in self._connection: - event = event_from_json(raw_msg) - await self._event_queue.put(event) - except (ConnectionRefusedError, ConnectionClosed, ClientError) as exc: - self._connection = None - self._connected = asyncio.Future() - logger.debug( - f"Monitor connection to EnsembleEvaluator went down, reconnecting: {exc}" - ) diff --git a/src/ert/gui/simulation/run_dialog.py b/src/ert/gui/simulation/run_dialog.py index 93908ac2c4d..dee4e941a0b 100644 --- a/src/ert/gui/simulation/run_dialog.py +++ b/src/ert/gui/simulation/run_dialog.py @@ -349,9 +349,13 @@ def run_experiment(self, restart: bool = False) -> None: self._tab_widget.clear() port_range = None + use_ipc_protocol = False if self._run_model.queue_system == QueueSystem.LOCAL: port_range = range(49152, 51819) - evaluator_server_config = EvaluatorServerConfig(custom_port_range=port_range) + use_ipc_protocol = True + evaluator_server_config = EvaluatorServerConfig( + custom_port_range=port_range, use_ipc_protocol=use_ipc_protocol + ) def run() -> None: self._run_model.start_simulations_thread( diff --git a/src/ert/logging/logger.conf b/src/ert/logging/logger.conf index 012c8366cff..d9959f765f7 100644 --- a/src/ert/logging/logger.conf +++ b/src/ert/logging/logger.conf @@ -33,8 +33,8 @@ loggers: level: INFO subscript: level: INFO - websockets: - level: WARNING + zmq: + level: INFO root: diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index ce7e2d6ba76..9a2de1b3720 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -18,12 +18,7 @@ import numpy as np -from _ert.events import ( - EESnapshot, - EESnapshotUpdate, - EETerminated, - Event, -) +from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated, Event from ert.analysis import ( AnalysisEvent, AnalysisStatusEvent, @@ -514,7 +509,6 @@ async def run_monitor( event, iteration, ) - if event.snapshot.get(STATUS) in { ENSEMBLE_STATE_STOPPED, ENSEMBLE_STATE_FAILED, @@ -567,6 +561,7 @@ async def run_ensemble_evaluator_async( evaluator_task = asyncio.create_task( evaluator.run_and_get_successful_realizations() ) + await evaluator._server_started.wait() if not (await self.run_monitor(ee_config, ensemble.iteration)): return [] diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index a1610930b26..6495ec28052 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -9,7 +9,6 @@ from collections.abc import Iterable, MutableMapping, Sequence from contextlib import suppress from dataclasses import asdict -from pathlib import Path from typing import TYPE_CHECKING, Any import orjson @@ -17,7 +16,6 @@ from _ert.async_utils import get_running_loop from _ert.events import Event, ForwardModelStepChecksum, Id, event_from_dict -from ert.constant_filenames import CERT_FILE from .driver import Driver from .event import FinishedEvent, StartedEvent @@ -35,7 +33,6 @@ class _JobsJson: real_id: int dispatch_url: str | None ee_token: str | None - ee_cert_path: str | None experiment_id: str | None @@ -69,7 +66,6 @@ def __init__( submit_sleep: float = 0.0, ens_id: str | None = None, ee_uri: str | None = None, - ee_cert: str | None = None, ee_token: str | None = None, ) -> None: self.driver = driver @@ -103,7 +99,6 @@ def __init__( self._max_running = max_running self._ee_uri = ee_uri self._ens_id = ens_id - self._ee_cert = ee_cert self._ee_token = ee_token self.checksum: dict[str, dict[str, Any]] = {} @@ -330,22 +325,12 @@ async def _process_event_queue(self) -> None: job.returncode.set_result(event.returncode) def _update_jobs_json(self, iens: int, runpath: str) -> None: - cert_path = f"{runpath}/{CERT_FILE}" - try: - if self._ee_cert is not None: - Path(cert_path).write_text(self._ee_cert, encoding="utf-8") - except OSError as err: - error_msg = f"Could not write ensemble certificate: {err}" - self._jobs[iens].unschedule(error_msg) - logger.error(error_msg) - return jobs = _JobsJson( experiment_id=None, ens_id=self._ens_id, real_id=iens, dispatch_url=self._ee_uri, ee_token=self._ee_token, - ee_cert_path=cert_path if self._ee_cert is not None else None, ) jobs_path = os.path.join(runpath, "jobs.json") try: diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index 3c36a71fd71..49686acaa3b 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -313,12 +313,12 @@ def main(): simulation_callback=partial(_sim_monitor, shared_data=shared_data), optimization_callback=partial(_opt_monitor, shared_data=shared_data), ) - - evaluator_server_config = EvaluatorServerConfig( - custom_port_range=range(49152, 51819) - if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL - else None - ) + if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL: + evaluator_server_config = EvaluatorServerConfig() + else: + evaluator_server_config = EvaluatorServerConfig( + custom_port_range=range(49152, 51819), use_ipc_protocol=False + ) run_model.run_experiment(evaluator_server_config) diff --git a/tests/ert/conftest.py b/tests/ert/conftest.py index 07b847a52de..39502e495a3 100644 --- a/tests/ert/conftest.py +++ b/tests/ert/conftest.py @@ -413,8 +413,8 @@ class MockESConfig(EvaluatorServerConfig): def __init__(self, *args, **kwargs): if "use_token" not in kwargs: kwargs["use_token"] = False - if "generate_cert" not in kwargs: - kwargs["generate_cert"] = False + if sys.platform != "linux": + kwargs["use_ipc_protocol"] = True super().__init__(*args, **kwargs) monkeypatch.setattr("ert.cli.main.EvaluatorServerConfig", MockESConfig) diff --git a/tests/ert/ui_tests/cli/test_cli.py b/tests/ert/ui_tests/cli/test_cli.py index 9daca798a79..d56595802b4 100644 --- a/tests/ert/ui_tests/cli/test_cli.py +++ b/tests/ert/ui_tests/cli/test_cli.py @@ -12,21 +12,17 @@ import numpy as np import pandas as pd import pytest -import websockets.exceptions import xtgeo +import zmq from psutil import NoSuchProcess, Popen, Process, ZombieProcess from resdata.summary import Summary import _ert.threading import ert.shared from _ert.forward_model_runner.client import Client -from ert import LibresFacade, ensemble_evaluator +from ert import LibresFacade from ert.cli.main import ErtCliError -from ert.config import ( - ConfigValidationError, - ConfigWarning, - ErtConfig, -) +from ert.config import ConfigValidationError, ConfigWarning, ErtConfig from ert.enkf_main import sample_prior from ert.ensemble_evaluator import EnsembleEvaluator from ert.mode_definitions import ( @@ -106,9 +102,6 @@ def test_that_the_cli_raises_exceptions_when_no_weight_provided_for_es_mda(): @pytest.mark.usefixtures("copy_snake_oil_field") def test_field_init_file_not_readable(monkeypatch): - monkeypatch.setattr( - ensemble_evaluator._wait_for_evaluator, "WAIT_FOR_EVALUATOR_TIMEOUT", 5 - ) config_file_name = "snake_oil_field.ert" field_file_rel_path = "fields/permx0.grdecl" os.chmod(field_file_rel_path, 0x0) @@ -197,10 +190,12 @@ def test_that_the_model_raises_exception_if_successful_realizations_less_than_mi else: fout.write(line) fout.write( - dedent(""" + dedent( + """ INSTALL_JOB failing_fm FAILING_FM FORWARD_MODEL failing_fm - """) + """ + ) ) Path("FAILING_FM").write_text("EXECUTABLE failing_fm.py", encoding="utf-8") Path("failing_fm.py").write_text( @@ -957,14 +952,13 @@ def test_tracking_missing_ecl(monkeypatch, tmp_path, caplog): def test_that_connection_errors_do_not_effect_final_result( monkeypatch: pytest.MonkeyPatch, ): - monkeypatch.setattr(Client, "DEFAULT_MAX_RETRIES", 0) - monkeypatch.setattr(Client, "DEFAULT_TIMEOUT_MULTIPLIER", 0) - monkeypatch.setattr(Client, "CONNECTION_TIMEOUT", 1) + monkeypatch.setattr(Client, "DEFAULT_MAX_RETRIES", 1) + monkeypatch.setattr(Client, "DEFAULT_ACK_TIMEOUT", 1) monkeypatch.setattr(EnsembleEvaluator, "CLOSE_SERVER_TIMEOUT", 0.01) monkeypatch.setattr(Job, "DEFAULT_CHECKSUM_TIMEOUT", 0) def raise_connection_error(*args, **kwargs): - raise websockets.exceptions.ConnectionClosedError(None, None) + raise zmq.error.ZMQError(None, None) with patch( "ert.ensemble_evaluator.evaluator.dispatch_event_from_json", diff --git a/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py b/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py index 3088ed0a131..8e64fcdf12a 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py +++ b/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py @@ -1,36 +1,8 @@ -import asyncio - -import websockets - -from _ert.async_utils import new_event_loop from ert.config import QueueConfig from ert.ensemble_evaluator import Ensemble from ert.ensemble_evaluator._ensemble import ForwardModelStep, Realization -def _mock_ws(host, port, messages, delay_startup=0): - loop = new_event_loop() - done = loop.create_future() - - async def _handler(websocket): - while True: - msg = await websocket.recv() - messages.append(msg) - if msg == "stop": - done.set_result(None) - break - - async def _run_server(): - await asyncio.sleep(delay_startup) - async with websockets.server.serve( - _handler, host, port, ping_timeout=1, ping_interval=1 - ): - await done - - loop.run_until_complete(_run_server()) - loop.close() - - class TestEnsemble(Ensemble): __test__ = False diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py index 6b6fc294530..75501453140 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py @@ -1,68 +1,48 @@ -from functools import partial - import pytest from _ert.forward_model_runner.client import Client, ClientConnectionError -from _ert.threading import ErtThread - -from .ensemble_evaluator_utils import _mock_ws +from tests.ert.utils import MockZMQServer @pytest.mark.integration_test -def test_invalid_server(): +async def test_invalid_server(): port = 7777 host = "localhost" - url = f"ws://{host}:{port}" + url = f"tcp://{host}:{port}" - with ( - Client(url, max_retries=2, timeout_multiplier=2) as c1, - pytest.raises(ClientConnectionError), - ): - c1.send("hei") + with pytest.raises(ClientConnectionError): + async with Client(url, ack_timeout=1.0): + pass -def test_successful_sending(unused_tcp_port): +async def test_successful_sending(unused_tcp_port): host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" - messages = [] - mock_ws_thread = ErtThread( - target=partial(_mock_ws, messages=messages), args=(host, unused_tcp_port) - ) - - mock_ws_thread.start() - messages_c1 = ["test_1", "test_2", "test_3", "stop"] - - with Client(url) as c1: - for msg in messages_c1: - c1.send(msg) - - mock_ws_thread.join() + url = f"tcp://{host}:{unused_tcp_port}" + messages_c1 = ["test_1", "test_2", "test_3"] + async with MockZMQServer(unused_tcp_port) as mock_server, Client(url) as c1: + for message in messages_c1: + await c1.send(message) for msg in messages_c1: - assert msg in messages + assert msg in mock_server.messages -@pytest.mark.integration_test -def test_retry(unused_tcp_port): +async def test_retry(unused_tcp_port): host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" - messages = [] - mock_ws_thread = ErtThread( - target=partial(_mock_ws, messages=messages, delay_startup=2), - args=( - host, - unused_tcp_port, - ), - ) - - mock_ws_thread.start() - messages_c1 = ["test_1", "test_2", "test_3", "stop"] - - with Client(url, max_retries=2, timeout_multiplier=2) as c1: - for msg in messages_c1: - c1.send(msg) - - mock_ws_thread.join() - - for msg in messages_c1: - assert msg in messages + url = f"tcp://{host}:{unused_tcp_port}" + client_connection_error_set = False + messages_c1 = ["test_1", "test_2", "test_3"] + async with ( + MockZMQServer(unused_tcp_port, signal=2) as mock_server, + Client(url, ack_timeout=0.5) as c1, + ): + for message in messages_c1: + try: + await c1.send(message, retries=1) + except ClientConnectionError: + client_connection_error_set = True + mock_server.signal(0) + assert client_connection_error_set + assert mock_server.messages.count("test_1") == 2 + assert mock_server.messages.count("test_2") == 1 + assert mock_server.messages.count("test_3") == 1 diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py index 2d2cad74f3b..3959beaa273 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py @@ -2,12 +2,11 @@ import datetime from functools import partial from typing import cast -from unittest.mock import MagicMock import pytest from hypothesis import given from hypothesis import strategies as st -from websockets.server import WebSocketServerProtocol +from pydantic import ValidationError from _ert.events import ( EESnapshot, @@ -21,7 +20,7 @@ RealizationSuccess, event_to_json, ) -from _ert.forward_model_runner.client import Client +from _ert.forward_model_runner.client import CONNECT_MSG, DISCONNECT_MSG, Client from ert.ensemble_evaluator import ( EnsembleEvaluator, EnsembleSnapshot, @@ -55,7 +54,10 @@ async def test_when_task_fails_evaluator_raises_exception( async def mock_failure(message, *args, **kwargs): raise RuntimeError(message) - evaluator = EnsembleEvaluator(TestEnsemble(0, 2, 2, id_="0"), make_ee_config()) + evaluator = EnsembleEvaluator( + TestEnsemble(0, 2, 2, id_="0"), make_ee_config(use_token=False) + ) + monkeypatch.setattr( EnsembleEvaluator, task, @@ -65,17 +67,27 @@ async def mock_failure(message, *args, **kwargs): await evaluator.run_and_get_successful_realizations() -async def test_when_dispatch_is_given_invalid_event_the_socket_is_closed( +async def test_evaluator_raises_on_invalid_dispatch_event( make_ee_config, ): evaluator = EnsembleEvaluator(TestEnsemble(0, 2, 2, id_="0"), make_ee_config()) - socket = MagicMock(spec=WebSocketServerProtocol) - socket.__aiter__.return_value = ["invalid_json"] - await evaluator.handle_dispatch(socket) - socket.close.assert_called_once_with( - code=1011, reason="failed handling message 'invalid_json'" - ) + with pytest.raises(ValidationError): + await evaluator.handle_dispatch(b"dispatcher-1", b"This is not an event!!") + + +async def test_evaluator_handles_dispatchers_connected( + make_ee_config, +): + evaluator = EnsembleEvaluator(TestEnsemble(0, 2, 2, id_="0"), make_ee_config()) + + await evaluator.handle_dispatch(b"dispatcher-1", CONNECT_MSG) + await evaluator.handle_dispatch(b"dispatcher-2", CONNECT_MSG) + assert not evaluator._dispatchers_empty.is_set() + assert evaluator._dispatchers_connected == {b"dispatcher-1", b"dispatcher-2"} + await evaluator.handle_dispatch(b"dispatcher-1", DISCONNECT_MSG) + await evaluator.handle_dispatch(b"dispatcher-2", DISCONNECT_MSG) + assert evaluator._dispatchers_empty.is_set() async def test_no_config_raises_valueerror_when_running(): @@ -110,32 +122,26 @@ async def mock_done_prematurely(message, *args, **kwargs): await evaluator.run_and_get_successful_realizations() -async def test_new_connections_are_denied_when_evaluator_is_closing_down( +async def test_new_connections_are_no_problem_when_evaluator_is_closing_down( evaluator_to_use, ): evaluator = evaluator_to_use - class TestMonitor(Monitor): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._connection_timeout = 1 - async def new_connection(): await evaluator._server_done.wait() - async with TestMonitor(evaluator._config.get_connection_info()): + async with Monitor(evaluator._config.get_connection_info()): pass new_connection_task = asyncio.create_task(new_connection()) evaluator.stop() - with pytest.raises(RuntimeError): - await new_connection_task + await new_connection_task @pytest.fixture(name="evaluator_to_use") async def evaluator_to_use_fixture(make_ee_config): ensemble = TestEnsemble(0, 2, 2, id_="0") - evaluator = EnsembleEvaluator(ensemble, make_ee_config()) + evaluator = EnsembleEvaluator(ensemble, make_ee_config(use_token=False)) evaluator._batching_interval = 0.5 # batching can be faster for tests run_task = asyncio.create_task(evaluator.run_and_get_successful_realizations()) await evaluator._server_started.wait() @@ -149,8 +155,7 @@ async def evaluator_to_use_fixture(make_ee_config): async def test_restarted_jobs_do_not_have_error_msgs(evaluator_to_use): evaluator = evaluator_to_use token = evaluator._config.token - cert = evaluator._config.cert - url = evaluator._config.url + url = evaluator._config.get_connection_info().router_uri config_info = evaluator._config.get_connection_info() async with Monitor(config_info) as monitor: @@ -161,11 +166,9 @@ async def test_restarted_jobs_do_not_have_error_msgs(evaluator_to_use): assert snapshot.status == ENSEMBLE_STATE_UNKNOWN # two dispatch endpoint clients connect async with Client( - url + "/dispatch", - cert=cert, + url, token=token, - max_retries=1, - timeout_multiplier=1, + dealer_name="dispatch_from_test_1", ) as dispatch: event = ForwardModelStepRunning( ensemble=evaluator.ensemble.id_, @@ -173,7 +176,7 @@ async def test_restarted_jobs_do_not_have_error_msgs(evaluator_to_use): fm_step="0", current_memory_usage=1000, ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event = ForwardModelStepFailure( ensemble=evaluator.ensemble.id_, @@ -181,7 +184,7 @@ async def test_restarted_jobs_do_not_have_error_msgs(evaluator_to_use): fm_step="0", error_msg="error", ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) def is_completed_snapshot(snapshot: EnsembleSnapshot) -> bool: try: @@ -201,11 +204,8 @@ def is_completed_snapshot(snapshot: EnsembleSnapshot) -> bool: break async with Client( - url + "/dispatch", - cert=cert, + url, token=token, - max_retries=1, - timeout_multiplier=1, ) as dispatch: event = ForwardModelStepSuccess( ensemble=evaluator.ensemble.id_, @@ -213,7 +213,7 @@ def is_completed_snapshot(snapshot: EnsembleSnapshot) -> bool: fm_step="0", current_memory_usage=1000, ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) # reconnect new monitor async with Monitor(config_info) as new_monitor: @@ -243,25 +243,18 @@ async def test_new_monitor_can_pick_up_where_we_left_off(evaluator_to_use): evaluator = evaluator_to_use token = evaluator._config.token - cert = evaluator._config.cert - url = evaluator._config.url + url = evaluator._config.get_connection_info().router_uri config_info = evaluator._config.get_connection_info() async with Monitor(config_info) as monitor: async with ( Client( - url + "/dispatch", - cert=cert, + url, token=token, - max_retries=1, - timeout_multiplier=1, ) as dispatch1, Client( - url + "/dispatch", - cert=cert, + url, token=token, - max_retries=1, - timeout_multiplier=1, ) as dispatch2, ): # first dispatch endpoint client informs that forward model 0 is running @@ -271,7 +264,7 @@ async def test_new_monitor_can_pick_up_where_we_left_off(evaluator_to_use): fm_step="0", current_memory_usage=1000, ) - await dispatch1._send(event_to_json(event)) + await dispatch1.send(event_to_json(event)) # second dispatch endpoint client informs that forward model 0 is running event = ForwardModelStepRunning( ensemble=evaluator.ensemble.id_, @@ -279,7 +272,7 @@ async def test_new_monitor_can_pick_up_where_we_left_off(evaluator_to_use): fm_step="0", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) # second dispatch endpoint client informs that forward model 1 is running event = ForwardModelStepRunning( ensemble=evaluator.ensemble.id_, @@ -287,7 +280,7 @@ async def test_new_monitor_can_pick_up_where_we_left_off(evaluator_to_use): fm_step="1", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) final_snapshot = EnsembleSnapshot() @@ -318,11 +311,8 @@ def check_if_all_fm_running(snapshot: EnsembleSnapshot) -> bool: # take down first monitor by leaving context async with Client( - url + "/dispatch", - cert=cert, + url, token=token, - max_retries=1, - timeout_multiplier=1, ) as dispatch2: # second dispatch endpoint client informs that job 0 is done event = ForwardModelStepSuccess( @@ -331,12 +321,12 @@ def check_if_all_fm_running(snapshot: EnsembleSnapshot) -> bool: fm_step="0", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) # second dispatch endpoint client informs that job 1 is failed event = ForwardModelStepFailure( ensemble=evaluator.ensemble.id_, real="1", fm_step="1", error_msg="error" ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) def check_if_final_snapshot_is_complete(final_snapshot: EnsembleSnapshot) -> bool: try: @@ -378,9 +368,8 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e async with Monitor(conn_info) as monitor: events = monitor.track() token = evaluator._config.token - cert = evaluator._config.cert - url = evaluator._config.url + url = conn_info.router_uri # first snapshot before any event occurs snapshot_event = await anext(events) assert type(snapshot_event) is EESnapshot @@ -389,18 +378,12 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e # two dispatch endpoint clients connect async with ( Client( - url + "/dispatch", - cert=cert, + url, token=token, - max_retries=1, - timeout_multiplier=1, ) as dispatch1, Client( - url + "/dispatch", - cert=cert, + url, token=token, - max_retries=1, - timeout_multiplier=1, ) as dispatch2, ): # first dispatch endpoint client informs that real 0 fm 0 is running @@ -410,7 +393,7 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e fm_step="0", current_memory_usage=1000, ) - await dispatch1._send(event_to_json(event)) + await dispatch1.send(event_to_json(event)) # second dispatch endpoint client informs that real 1 fm 0 is running event = ForwardModelStepRunning( ensemble=evaluator.ensemble.id_, @@ -418,7 +401,7 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e fm_step="0", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) # second dispatch endpoint client informs that real 1 fm 0 is done event = ForwardModelStepSuccess( ensemble=evaluator.ensemble.id_, @@ -426,7 +409,7 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e fm_step="0", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) # second dispatch endpoint client informs that real 1 fm 1 is failed event = ForwardModelStepFailure( ensemble=evaluator.ensemble.id_, @@ -434,7 +417,7 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e fm_step="1", error_msg="error", ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) event = await anext(events) snapshot = EnsembleSnapshot.from_nested_dict(event.snapshot) @@ -491,24 +474,23 @@ async def test_ensure_multi_level_events_in_order(evaluator_to_use): events = monitor.track() token = evaluator._config.token - cert = evaluator._config.cert - url = evaluator._config.url + url = config_info.router_uri snapshot_event = await anext(events) assert type(snapshot_event) is EESnapshot - async with Client(url + "/dispatch", cert=cert, token=token) as dispatch: + async with Client(url, token=token) as dispatch: event = EnsembleStarted(ensemble=evaluator.ensemble.id_) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event = RealizationSuccess( ensemble=evaluator.ensemble.id_, real="0", queue_event_type="" ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event = RealizationSuccess( ensemble=evaluator.ensemble.id_, real="1", queue_event_type="" ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event = EnsembleSucceeded(ensemble=evaluator.ensemble.id_) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) await monitor.signal_done() diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator_config.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator_config.py index 0049d6e656c..2d8f2189fe7 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator_config.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator_config.py @@ -3,38 +3,35 @@ from ert.ensemble_evaluator.config import EvaluatorServerConfig -def test_load_config(unused_tcp_port): +def test_ensemble_evaluator_config_tcp_protocol(unused_tcp_port): fixed_port = range(unused_tcp_port, unused_tcp_port) serv_config = EvaluatorServerConfig( custom_port_range=fixed_port, custom_host="127.0.0.1", + use_ipc_protocol=False, ) expected_host = "127.0.0.1" expected_port = unused_tcp_port - expected_url = f"wss://{expected_host}:{expected_port}" - expected_client_uri = f"{expected_url}/client" - expected_dispatch_uri = f"{expected_url}/dispatch" + expected_url = f"tcp://{expected_host}:{expected_port}" url = urlparse(serv_config.url) assert url.hostname == expected_host assert url.port == expected_port assert serv_config.url == expected_url - assert serv_config.client_uri == expected_client_uri - assert serv_config.dispatch_uri == expected_dispatch_uri assert serv_config.token is not None - assert serv_config.cert is not None + assert serv_config.server_public_key is not None + assert serv_config.server_secret_key is not None sock = serv_config.get_socket() assert sock is not None assert not sock._closed sock.close() - ee_config = EvaluatorServerConfig( - custom_port_range=range(1024, 65535), - custom_host="127.0.0.1", - use_token=False, - generate_cert=False, - ) - sock = ee_config.get_socket() - assert sock is not None - assert not sock._closed - sock.close() + +def test_ensemble_evaluator_config_ipc_protocol(): + serv_config = EvaluatorServerConfig(use_ipc_protocol=True, use_token=False) + + assert serv_config.url.startswith("ipc:///tmp/socket-") + assert serv_config.token is None + assert serv_config.server_public_key is None + assert serv_config.server_secret_key is None + assert serv_config.get_socket() is None diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_legacy.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_legacy.py index a657a872571..f11d693b15a 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_legacy.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_legacy.py @@ -1,11 +1,9 @@ import asyncio -import contextlib import os from contextlib import asynccontextmanager from unittest.mock import MagicMock import pytest -from websockets.exceptions import ConnectionClosed from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated from ert.config import QueueConfig @@ -44,11 +42,10 @@ async def test_run_legacy_ensemble( custom_port_range=custom_port_range, custom_host="127.0.0.1", use_token=False, - generate_cert=False, ) async with ( evaluator_to_use(ensemble, config) as evaluator, - Monitor(config) as monitor, + Monitor(config.get_connection_info()) as monitor, ): async for event in monitor.track(): if type(event) in { @@ -80,29 +77,25 @@ async def test_run_and_cancel_legacy_ensemble( custom_port_range=custom_port_range, custom_host="127.0.0.1", use_token=False, - generate_cert=False, ) terminated_event = False async with ( evaluator_to_use(ensemble, config) as evaluator, - Monitor(config) as monitor, + Monitor(config.get_connection_info()) as monitor, ): # on lesser hardware the realizations might be killed by max_runtime # and the ensemble is set to STOPPED monitor._receiver_timeout = 10.0 cancel = True - with contextlib.suppress( - ConnectionClosed - ): # monitor throws some variant of CC if dispatcher dies - async for event in monitor.track(heartbeat_interval=0.1): - # Cancel the ensemble upon the arrival of the first event - if cancel: - await monitor.signal_cancel() - cancel = False - if type(event) is EETerminated: - terminated_event = True + async for event in monitor.track(heartbeat_interval=0.1): + # Cancel the ensemble upon the arrival of the first event + if cancel: + await monitor.signal_cancel() + cancel = False + if type(event) is EETerminated: + terminated_event = True if terminated_event: assert evaluator._ensemble.status == state.ENSEMBLE_STATE_CANCELLED diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py index e4615649c72..a8d762ce59c 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py @@ -1,88 +1,123 @@ import asyncio import logging -from http import HTTPStatus -from typing import NoReturn -from urllib.parse import urlparse import pytest -from websockets.asyncio import server -from websockets.exceptions import ConnectionClosedOK +import zmq +import zmq.asyncio -import ert -import ert.ensemble_evaluator from _ert.events import EEUserCancel, EEUserDone, event_from_json +from _ert.forward_model_runner.client import ( + ACK_MSG, + CONNECT_MSG, + DISCONNECT_MSG, + ClientConnectionError, +) from ert.ensemble_evaluator import Monitor from ert.ensemble_evaluator.config import EvaluatorConnectionInfo -async def _mock_ws( - set_when_done: asyncio.Event, handler, ee_config: EvaluatorConnectionInfo -): - async def process_request(connection, request): - if request.path == "/healthcheck": - return connection.respond(HTTPStatus.OK, "") +async def async_zmq_server(port, handler): + zmq_context = zmq.asyncio.Context() + router_socket = zmq_context.socket(zmq.ROUTER) + router_socket.setsockopt(zmq.LINGER, 0) + router_socket.bind(f"tcp://*:{port}") + await handler(router_socket) + router_socket.close() + zmq_context.destroy() + + +async def test_monitor_connects_and_disconnects_successfully(unused_tcp_port): + ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}") + monitor = Monitor(ee_con_info) + + messages = [] + + async def mock_event_handler(router_socket): + nonlocal messages + while True: + dealer, _, frame = await router_socket.recv_multipart() + await router_socket.send_multipart([dealer, b"", ACK_MSG]) + messages.append((dealer.decode("utf-8"), frame)) + if frame == DISCONNECT_MSG: + break - url = urlparse(ee_config.url) - async with server.serve( - handler, url.hostname, url.port, process_request=process_request - ): - await set_when_done.wait() + websocket_server_task = asyncio.create_task( + async_zmq_server(unused_tcp_port, mock_event_handler) + ) + async with monitor: + pass + await websocket_server_task + dealer, msg = messages[0] + assert dealer.startswith("client-") + assert msg == CONNECT_MSG + dealer, msg = messages[1] + assert dealer.startswith("client-") + assert msg == DISCONNECT_MSG async def test_no_connection_established(make_ee_config): ee_config = make_ee_config() monitor = Monitor(ee_config.get_connection_info()) - monitor._connection_timeout = 0.1 - with pytest.raises( - RuntimeError, match="Couldn't establish connection with the ensemble evaluator!" - ): + monitor._ack_timeout = 0.1 + with pytest.raises(ClientConnectionError): async with monitor: pass async def test_immediate_stop(unused_tcp_port): - ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}") - - set_when_done = asyncio.Event() - - async def mock_ws_event_handler(websocket): - async for raw_msg in websocket: - event = event_from_json(raw_msg) - assert type(event) is EEUserDone - break - await websocket.close() + ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}") + + connected = False + + async def mock_event_handler(router_socket): + nonlocal connected + while True: + dealer, _, frame = await router_socket.recv_multipart() + await router_socket.send_multipart([dealer, b"", ACK_MSG]) + dealer = dealer.decode("utf-8") + if frame == CONNECT_MSG: + connected = True + elif frame == DISCONNECT_MSG: + connected = False + return + else: + event = event_from_json(frame.decode("utf-8")) + assert connected + assert type(event) is EEUserDone websocket_server_task = asyncio.create_task( - _mock_ws(set_when_done, mock_ws_event_handler, ee_con_info) + async_zmq_server(unused_tcp_port, mock_event_handler) ) async with Monitor(ee_con_info) as monitor: + assert connected is True await monitor.signal_done() - set_when_done.set() await websocket_server_task + assert connected is False -async def test_unexpected_close(unused_tcp_port): - ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}") +async def test_unexpected_close_after_connection_successful( + monkeypatch, unused_tcp_port +): + ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}") - set_when_done = asyncio.Event() - socket_closed = asyncio.Event() + monkeypatch.setattr(Monitor, "DEFAULT_MAX_RETRIES", 0) + monkeypatch.setattr(Monitor, "DEFAULT_ACK_TIMEOUT", 1) - async def mock_ws_event_handler(websocket): - await websocket.close() - socket_closed.set() + async def mock_event_handler(router_socket): + dealer, _, frame = await router_socket.recv_multipart() + await router_socket.send_multipart([dealer, b"", ACK_MSG]) + dealer = dealer.decode("utf-8") + assert dealer.startswith("client-") + assert frame == CONNECT_MSG + router_socket.close() websocket_server_task = asyncio.create_task( - _mock_ws(set_when_done, mock_ws_event_handler, ee_con_info) + async_zmq_server(unused_tcp_port, mock_event_handler) ) async with Monitor(ee_con_info) as monitor: - # this expects Event send to fail - # but no attempt on resubmitting - # since connection closed via websocket.close - with pytest.raises(ConnectionClosedOK): - await socket_closed.wait() + with pytest.raises(ClientConnectionError): await monitor.signal_done() - set_when_done.set() await websocket_server_task @@ -90,20 +125,29 @@ async def test_that_monitor_track_can_exit_without_terminated_event_from_evaluat unused_tcp_port, caplog ): caplog.set_level(logging.ERROR) - ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}") - - set_when_done = asyncio.Event() - - async def mock_ws_event_handler(websocket): - async for raw_msg in websocket: - event = event_from_json(raw_msg) - assert type(event) is EEUserCancel - break - await websocket.close() + ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}") + + connected = False + + async def mock_event_handler(router_socket): + nonlocal connected + while True: + dealer, _, frame = await router_socket.recv_multipart() + await router_socket.send_multipart([dealer, b"", ACK_MSG]) + if frame == CONNECT_MSG: + connected = True + elif frame == DISCONNECT_MSG: + connected = False + return + else: + event = event_from_json(frame.decode("utf-8")) + assert connected + assert type(event) is EEUserCancel websocket_server_task = asyncio.create_task( - _mock_ws(set_when_done, mock_ws_event_handler, ee_con_info) + async_zmq_server(unused_tcp_port, mock_event_handler) ) + async with Monitor(ee_con_info) as monitor: monitor._receiver_timeout = 0.1 await monitor.signal_cancel() @@ -115,7 +159,6 @@ async def mock_ws_event_handler(websocket): "Evaluator did not send the TERMINATED event!" ) in caplog.messages, "Monitor receiver did not stop!" - set_when_done.set() await websocket_server_task @@ -124,11 +167,18 @@ async def test_that_monitor_can_emit_heartbeats(unused_tcp_port): exit anytime. A heartbeat is a None event. If the heartbeat is never sent, this test function will hang and then timeout.""" - ee_con_info = EvaluatorConnectionInfo(f"ws://127.0.0.1:{unused_tcp_port}") + ee_con_info = EvaluatorConnectionInfo(f"tcp://127.0.0.1:{unused_tcp_port}") + + async def mock_event_handler(router_socket): + while True: + try: + dealer, _, __ = await router_socket.recv_multipart() + await router_socket.send_multipart([dealer, b"", ACK_MSG]) + except asyncio.CancelledError: + break - set_when_done = asyncio.Event() websocket_server_task = asyncio.create_task( - _mock_ws(set_when_done, None, ee_con_info) + async_zmq_server(unused_tcp_port, mock_event_handler) ) async with Monitor(ee_con_info) as monitor: @@ -136,24 +186,6 @@ async def test_that_monitor_can_emit_heartbeats(unused_tcp_port): if event is None: break - set_when_done.set() # shuts down websocket server - await websocket_server_task - - -@pytest.mark.timeout(10) -async def test_that_monitor_will_raise_exception_if_wait_for_evaluator_fails( - monkeypatch, -): - async def mock_failing_wait_for_evaluator(*args, **kwargs) -> NoReturn: - raise ValueError() - - monkeypatch.setattr( - ert.ensemble_evaluator.monitor, - "wait_for_evaluator", - mock_failing_wait_for_evaluator, - ) - ee_con_info = EvaluatorConnectionInfo("") - - with pytest.raises(ValueError): - async with Monitor(ee_con_info): - pass + if not websocket_server_task.done(): + websocket_server_task.cancel() + asyncio.gather(websocket_server_task, return_exceptions=True) diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_scheduler.py b/tests/ert/unit_tests/ensemble_evaluator/test_scheduler.py index 0160f2053d9..43abdad5d3a 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_scheduler.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_scheduler.py @@ -25,7 +25,7 @@ async def rename_and_wait(): Path("real_0/test").rename("real_0/job_test_file") async def _run_monitor(): - async with Monitor(config) as monitor: + async with Monitor(config.get_connection_info()) as monitor: async for event in monitor.track(): if type(event) is ForwardModelStepChecksum: # Monitor got the checksum message renaming the file @@ -60,7 +60,6 @@ def create_manifest_file(): custom_port_range=custom_port_range, custom_host="127.0.0.1", use_token=False, - generate_cert=False, ) evaluator = EnsembleEvaluator(ensemble, config) with caplog.at_level(logging.DEBUG): diff --git a/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py b/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py index 0575e78b954..237c80a66d7 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py +++ b/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py @@ -1,7 +1,5 @@ import os -import sys import time -from unittest.mock import patch import pytest @@ -12,10 +10,6 @@ ForwardModelStepSuccess, event_from_json, ) -from _ert.forward_model_runner.client import ( - ClientConnectionClosedOK, - ClientConnectionError, -) from _ert.forward_model_runner.forward_model_step import ForwardModelStep from _ert.forward_model_runner.reporting import Event from _ert.forward_model_runner.reporting.message import ( @@ -27,7 +21,7 @@ Start, ) from _ert.forward_model_runner.reporting.statemachine import TransitionError -from tests.ert.utils import _mock_ws_thread +from tests.ert.utils import MockZMQServer def _wait_until(condition, timeout, fail_msg): @@ -39,19 +33,18 @@ def _wait_until(condition, timeout, fail_msg): def test_report_with_successful_start_message_argument(unused_tcp_port): host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" + url = f"tcp://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Start(fmstep1)) reporter.report(Finish()) - assert len(lines) == 1 - event = event_from_json(lines[0]) + assert len(mock_server.messages) == 1 + event = event_from_json(mock_server.messages[0]) assert type(event) is ForwardModelStepStart assert event.ensemble == "ens_id" assert event.real == "0" @@ -62,15 +55,14 @@ def test_report_with_successful_start_message_argument(unused_tcp_port): def test_report_with_failed_start_message_argument(unused_tcp_port): host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" + url = f"tcp://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) msg = Start(fmstep1).with_error("massive_failure") @@ -78,67 +70,64 @@ def test_report_with_failed_start_message_argument(unused_tcp_port): reporter.report(msg) reporter.report(Finish()) - assert len(lines) == 2 - event = event_from_json(lines[1]) + assert len(mock_server.messages) == 2 + event = event_from_json(mock_server.messages[1]) assert type(event) is ForwardModelStepFailure assert event.error_msg == "massive_failure" -def test_report_with_successful_exit_message_argument(unused_tcp_port): +async def test_report_with_successful_exit_message_argument(unused_tcp_port): host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" + url = f"tcp://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Exited(fmstep1, 0)) reporter.report(Finish().with_error("failed")) - assert len(lines) == 1 - event = event_from_json(lines[0]) + assert len(mock_server.messages) == 1 + event = event_from_json(mock_server.messages[0]) assert type(event) is ForwardModelStepSuccess def test_report_with_failed_exit_message_argument(unused_tcp_port): host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" + url = f"tcp://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Exited(fmstep1, 1).with_error("massive_failure")) reporter.report(Finish()) - assert len(lines) == 1 - event = event_from_json(lines[0]) + assert len(mock_server.messages) == 1 + event = event_from_json(mock_server.messages[0]) assert type(event) is ForwardModelStepFailure assert event.error_msg == "massive_failure" def test_report_with_running_message_argument(unused_tcp_port): host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" + url = f"tcp://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) reporter.report(Finish()) - assert len(lines) == 1 - event = event_from_json(lines[0]) + assert len(mock_server.messages) == 1 + event = event_from_json(mock_server.messages[0]) assert type(event) is ForwardModelStepRunning assert event.max_memory_usage == 100 assert event.current_memory_usage == 10 @@ -146,46 +135,42 @@ def test_report_with_running_message_argument(unused_tcp_port): def test_report_only_job_running_for_successful_run(unused_tcp_port): host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" + url = f"tcp://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) reporter.report(Finish()) - assert len(lines) == 1 + assert len(mock_server.messages) == 1 def test_report_with_failed_finish_message_argument(unused_tcp_port): host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" + url = f"tcp://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) reporter.report(Finish().with_error("massive_failure")) - assert len(lines) == 1 + assert len(mock_server.messages) == 1 def test_report_inconsistent_events(unused_tcp_port): host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" + url = f"tcp://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) - lines = [] with ( - _mock_ws_thread(host, unused_tcp_port, lines), pytest.raises( TransitionError, match=r"Illegal transition None -> \(MessageType,\)", @@ -194,7 +179,6 @@ def test_report_inconsistent_events(unused_tcp_port): reporter.report(Finish()) -@pytest.mark.integration_test def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port): # this is to show when the reporter fails ert won't crash nor # staying hanging but instead finishes up the job; @@ -202,134 +186,52 @@ def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port): # also assert reporter._timeout_timestamp is None # meaning Finish event initiated _timeout and timeout was reached # which then sets _timeout_timestamp=None - mock_send_retry_time = 2 - - def mock_send(msg): - time.sleep(mock_send_retry_time) - raise ClientConnectionError("Sending failed!") host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" - reporter = Event(evaluator_url=url) - reporter._reporter_timeout = 4 - fmstep1 = ForwardModelStep( - {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 - ) - lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - with patch( - "_ert.forward_model_runner.client.Client.send", lambda x, y: mock_send(y) - ): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) - # set _stop_timestamp - reporter.report(Finish()) + url = f"tcp://{host}:{unused_tcp_port}" + with MockZMQServer(unused_tcp_port) as mock_server: + reporter = Event( + evaluator_url=url, ack_timeout=2, max_retries=0, finished_event_timeout=2 + ) + fmstep1 = ForwardModelStep( + {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 + ) + + mock_server.signal(1) # prevent router to receive messages + reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) + reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) + reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) + reporter.report(Finish()) if reporter._event_publisher_thread.is_alive(): reporter._event_publisher_thread.join() - # set _stop_timestamp to None only when timer stopped - assert reporter._timeout_timestamp is None - assert len(lines) == 0, "expected 0 Job running messages" + assert reporter._done.is_set() + assert len(mock_server.messages) == 0, "expected 0 Job running messages" -@pytest.mark.integration_test -@pytest.mark.flaky(reruns=5) -@pytest.mark.skipif( - sys.platform.startswith("darwin"), reason="Performance can be flaky" -) def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port): # this is to show when the reporter fails but reconnects # reporter still manages to send events and completes fine # see assert reporter._timeout_timestamp is not None # meaning Finish event initiated _timeout but timeout wasn't reached since # it finished succesfully - mock_send_retry_time = 0.1 - - def send_func(msg): - time.sleep(mock_send_retry_time) - raise ClientConnectionError("Sending failed!") host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" - reporter = Event(evaluator_url=url) - fmstep1 = ForwardModelStep( - {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 - ) - lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): - with patch("_ert.forward_model_runner.client.Client.send") as patched_send: - patched_send.side_effect = send_func - - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10))) - - _wait_until( - condition=lambda: patched_send.call_count == 3, - timeout=10, - fail_msg="10 seconds should be sufficient to send three events", - ) - - # reconnect and continue sending events - # set _stop_timestamp - reporter.report(Finish()) - if reporter._event_publisher_thread.is_alive(): - reporter._event_publisher_thread.join() - # set _stop_timestamp was not set to None since the reporter finished on time - assert reporter._timeout_timestamp is not None - assert len(lines) == 3, "expected 3 Job running messages" - - -@pytest.mark.integration_test -def test_report_with_closed_received_exiting_gracefully(unused_tcp_port): - # Whenever the receiver end closes the connection, a ConnectionClosedOK is raised - # The reporter should exit the publisher thread gracefully and not send any - # more events - mock_send_retry_time = 3 - - def mock_send(msg): - time.sleep(mock_send_retry_time) - raise ClientConnectionClosedOK("Connection Closed") + url = f"tcp://{host}:{unused_tcp_port}" + with MockZMQServer(unused_tcp_port) as mock_server: + reporter = Event(evaluator_url=url, ack_timeout=1, max_retries=1) + fmstep1 = ForwardModelStep( + {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 + ) - host = "localhost" - url = f"ws://{host}:{unused_tcp_port}" - reporter = Event(evaluator_url=url) - fmstep1 = ForwardModelStep( - {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 - ) - lines = [] - with _mock_ws_thread(host, unused_tcp_port, lines): + mock_server.signal(1) # prevent router to receive messages reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10))) - - # sleep until both Running events have been received - _wait_until( - condition=lambda: len(lines) == 2, - timeout=10, - fail_msg="Should not take 10 seconds to send two events", - ) - - with patch( - "_ert.forward_model_runner.client.Client.send", lambda x, y: mock_send(y) - ): - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10))) - # Make sure the publisher thread exits because it got - # ClientConnectionClosedOK. If it hangs it could indicate that the - # exception is not caught/handled correctly - if reporter._event_publisher_thread.is_alive(): - reporter._event_publisher_thread.join() - - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=400, rss=10))) + reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) + reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) + mock_server.signal(0) # enable router to receive messages reporter.report(Finish()) - - # set _stop_timestamp was not set to None since the reporter finished on time - assert reporter._timeout_timestamp is not None - - # The Running(fmstep1, 300, 10) is popped from the queue, but never sent. - # The following Running is added to queue along with the sentinel - assert reporter._event_queue.qsize() == 2 - # None of the messages after ClientConnectionClosedOK was raised, has been sent - assert len(lines) == 2, "expected 2 Job running messages" + if reporter._event_publisher_thread.is_alive(): + reporter._event_publisher_thread.join() + assert reporter._done.is_set() + assert len(mock_server.messages) == 3, "expected 3 Job running messages" diff --git a/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py b/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py index baa02bb61d6..b7e9102ca50 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py +++ b/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py @@ -23,7 +23,7 @@ from _ert.forward_model_runner.reporting import Event, Interactive from _ert.forward_model_runner.reporting.message import Finish, Init from _ert.threading import ErtThread -from tests.ert.utils import _mock_ws_thread, wait_until +from tests.ert.utils import MockZMQServer, wait_until from .test_event_reporter import _wait_until @@ -302,7 +302,7 @@ def test_retry_of_jobs_json_file_read(unused_tcp_port, tmp_path, monkeypatch, ca jobs_json = json.dumps( { "ens_id": "_id_", - "dispatch_url": f"ws://localhost:{unused_tcp_port}", + "dispatch_url": f"tcp://localhost:{unused_tcp_port}", "jobList": [], } ) @@ -316,7 +316,7 @@ def create_jobs_file_after_lock(): (tmp_path / JOBS_FILE).write_text(jobs_json) lock.release() - with _mock_ws_thread("localhost", unused_tcp_port, []): + with MockZMQServer(unused_tcp_port): thread = ErtThread(target=create_jobs_file_after_lock) thread.start() main(args=["script.py", str(tmp_path)]) @@ -344,10 +344,11 @@ def test_setup_reporters(is_interactive_run, ens_id): @pytest.mark.usefixtures("use_tmpdir") -def test_fm_dispatch_kills_itself_after_unsuccessful_job(unused_tcp_port): - host = "localhost" +def test_job_dispatch_kills_itself_after_unsuccessful_job(unused_tcp_port): port = unused_tcp_port - jobs_json = json.dumps({"ens_id": "_id_", "dispatch_url": f"ws://localhost:{port}"}) + jobs_json = json.dumps( + {"ens_id": "_id_", "dispatch_url": f"tcp://localhost:{port}"} + ) with ( patch("_ert.forward_model_runner.cli.os.killpg") as mock_killpg, @@ -361,7 +362,7 @@ def test_fm_dispatch_kills_itself_after_unsuccessful_job(unused_tcp_port): ] mock_getpgid.return_value = 17 - with _mock_ws_thread(host, port, []): + with MockZMQServer(port): main(["script.py"]) mock_killpg.assert_called_with(17, signal.SIGKILL) diff --git a/tests/ert/unit_tests/scheduler/test_scheduler.py b/tests/ert/unit_tests/scheduler/test_scheduler.py index fb33ae3de89..ba8df1e97f8 100644 --- a/tests/ert/unit_tests/scheduler/test_scheduler.py +++ b/tests/ert/unit_tests/scheduler/test_scheduler.py @@ -11,7 +11,6 @@ from _ert.events import Id, RealizationFailed, RealizationTimeout from ert.config import QueueConfig -from ert.constant_filenames import CERT_FILE from ert.ensemble_evaluator import Realization from ert.load_status import LoadResult, LoadStatus from ert.run_arg import RunArg @@ -124,10 +123,9 @@ async def kill(): async def test_add_dispatch_information_to_jobs_file( storage, tmp_path: Path, mock_driver ): - test_ee_uri = "ws://test_ee_uri.com/121/" + test_ee_uri = "tcp://test_ee_uri.com/121/" test_ens_id = "test_ens_id121" test_ee_token = "test_ee_token_t0k€n121" - test_ee_cert = "test_ee_cert121.pem" ensemble_size = 10 @@ -144,7 +142,6 @@ async def test_add_dispatch_information_to_jobs_file( realizations=realizations, ens_id=test_ens_id, ee_uri=test_ee_uri, - ee_cert=test_ee_cert, ee_token=test_ee_token, ) @@ -155,15 +152,12 @@ async def test_add_dispatch_information_to_jobs_file( for realization in realizations: job_file_path = Path(realization.run_arg.runpath) / "jobs.json" - cert_file_path = Path(realization.run_arg.runpath) / CERT_FILE content: dict = json.loads(job_file_path.read_text(encoding="utf-8")) assert content["ens_id"] == test_ens_id assert content["real_id"] == realization.iens assert content["dispatch_url"] == test_ee_uri assert content["ee_token"] == test_ee_token - assert content["ee_cert_path"] == str(cert_file_path) assert type(content["jobList"]) == list and len(content["jobList"]) == 0 - assert cert_file_path.read_text(encoding="utf-8") == test_ee_cert @pytest.mark.parametrize("max_submit", [1, 2, 3]) diff --git a/tests/ert/unit_tests/test_tracking.py b/tests/ert/unit_tests/test_tracking.py index ab9dd76d41b..5f42a9fb1e6 100644 --- a/tests/ert/unit_tests/test_tracking.py +++ b/tests/ert/unit_tests/test_tracking.py @@ -188,7 +188,6 @@ def test_tracking( custom_port_range=range(1024, 65535), custom_host="127.0.0.1", use_token=False, - generate_cert=False, ) thread = ErtThread( @@ -279,7 +278,6 @@ def test_setting_env_context_during_run( custom_port_range=range(1024, 65535), custom_host="127.0.0.1", use_token=False, - generate_cert=False, ) queue = Events() model = create_model( @@ -356,7 +354,6 @@ def test_run_information_present_as_env_var_in_fm_context( custom_port_range=range(1024, 65535), custom_host="127.0.0.1", use_token=False, - generate_cert=False, ) queue = Events() model = create_model(ert_config, storage, parsed, queue) diff --git a/tests/ert/utils.py b/tests/ert/utils.py index 732f816f8cd..4f66f93e157 100644 --- a/tests/ert/utils.py +++ b/tests/ert/utils.py @@ -3,13 +3,13 @@ import asyncio import contextlib import time -from functools import partial from pathlib import Path from typing import TYPE_CHECKING -import websockets.server +import zmq +import zmq.asyncio -from _ert.forward_model_runner.client import Client +from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG from _ert.threading import ErtThread from ert.scheduler.event import FinishedEvent, StartedEvent @@ -61,47 +61,72 @@ def wait_until(func, interval=0.5, timeout=30): ) -def _mock_ws(host, port, messages, delay_startup=0): - loop = asyncio.new_event_loop() - done = loop.create_future() - - async def _handler(websocket, path): +class MockZMQServer: + def __init__(self, port, signal=0): + """Mock ZMQ server for testing + signal = 0: normal operation + signal = 1: don't send ACK and don't receive messages + signal = 2: don't send ACK, but receive messages + """ + self.port = port + self.messages = [] + self.value = signal + self.loop = None + self.server_task = None + self.handler_task = None + + def start_event_loop(self): + asyncio.set_event_loop(self.loop) + self.loop.run_until_complete(self.mock_zmq_server()) + + def __enter__(self): + self.loop = asyncio.new_event_loop() + self.thread = ErtThread(target=self.start_event_loop) + self.thread.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.handler_task and not self.handler_task.done(): + self.loop.call_soon_threadsafe(self.handler_task.cancel) + self.thread.join() + self.loop.close() + + async def __aenter__(self): + self.server_task = asyncio.create_task(self.mock_zmq_server()) + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if not self.server_task.done(): + self.server_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self.server_task + + async def mock_zmq_server(self): + zmq_context = zmq.asyncio.Context() + self.router_socket = zmq_context.socket(zmq.ROUTER) + self.router_socket.bind(f"tcp://*:{self.port}") + + self.handler_task = asyncio.create_task(self._handler()) + try: + await self.handler_task + finally: + self.router_socket.close() + zmq_context.term() + + def signal(self, value): + self.value = value + + async def _handler(self): while True: - msg = await websocket.recv() - messages.append(msg) - if msg == "stop": - done.set_result(None) + try: + dealer, __, frame = await self.router_socket.recv_multipart() + if frame in {CONNECT_MSG, DISCONNECT_MSG} or self.value == 0: + await self.router_socket.send_multipart([dealer, b"", ACK_MSG]) + if frame not in {CONNECT_MSG, DISCONNECT_MSG} and self.value != 1: + self.messages.append(frame.decode("utf-8")) + except asyncio.CancelledError: break - async def _run_server(): - await asyncio.sleep(delay_startup) - async with websockets.server.serve(_handler, host, port): - await done - - loop.run_until_complete(_run_server()) - loop.close() - - -@contextlib.contextmanager -def _mock_ws_thread(host, port, messages): - mock_ws_thread = ErtThread( - target=partial(_mock_ws, messages=messages), - args=( - host, - port, - ), - ) - mock_ws_thread.start() - try: - yield - # Make sure to join the thread even if an exception occurs - finally: - url = f"ws://{host}:{port}" - with Client(url) as client: - client.send("stop") - mock_ws_thread.join() - messages.pop() - async def poll(driver: Driver, expected: set[int], *, started=None, finished=None): """Poll driver until expected realisations finish