Skip to content

Commit

Permalink
Implementing router-dealer pattern with custom acknowledgments with zmq
Browse files Browse the repository at this point in the history
 - dealers always wait for acknowledgment from the evaluator
 - removing websockets, no more wait_for_evaluator
 - Settup encryption with curve
 - each dealer (client, dispatcher) will get a unique name
 - Make sure to check cancellation error when sending event from client
 - Monitor is an advanced version Client
 - Make sure to wait to ensemble start before initiating monitor
 - Use TCP protocol only when using LSF, SLURM or TORQUE queues
 - Remove certificate
 - Use ipc_protocol when using LOCAL driver
  • Loading branch information
xjules committed Dec 16, 2024
1 parent ebe548e commit 7762f6f
Show file tree
Hide file tree
Showing 33 changed files with 780 additions and 1,082 deletions.
1 change: 0 additions & 1 deletion docs/ert/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
("py:class", "pydantic.types.PositiveInt"),
("py:class", "LibresFacade"),
("py:class", "pandas.core.frame.DataFrame"),
("py:class", "websockets.server.WebSocketServerProtocol"),
("py:class", "EnsembleReader"),
]
nitpick_ignore_regex = [
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ dependencies = [
"python-dateutil",
"python-multipart", # extra dependency for fastapi
"pyyaml",
"pyzmq",
"qtpy",
"requests",
"resfo",
Expand All @@ -68,7 +69,6 @@ dependencies = [
"tqdm>=4.62.0",
"typing_extensions>=4.5",
"uvicorn >= 0.17.0",
"websockets",
"xarray",
"xtgeo >= 3.3.0",
]
Expand Down
9 changes: 1 addition & 8 deletions src/_ert/forward_model_runner/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,14 @@ def _setup_reporters(
ens_id,
dispatch_url,
ee_token=None,
ee_cert_path=None,
experiment_id=None,
) -> list[reporting.Reporter]:
reporters: list[reporting.Reporter] = []
if is_interactive_run:
reporters.append(reporting.Interactive())
elif ens_id and experiment_id is None:
reporters.append(reporting.File())
reporters.append(
reporting.Event(
evaluator_url=dispatch_url, token=ee_token, cert_path=ee_cert_path
)
)
reporters.append(reporting.Event(evaluator_url=dispatch_url, token=ee_token))
else:
reporters.append(reporting.File())
return reporters
Expand Down Expand Up @@ -123,7 +118,6 @@ def main(args):
experiment_id = jobs_data.get("experiment_id")
ens_id = jobs_data.get("ens_id")
ee_token = jobs_data.get("ee_token")
ee_cert_path = jobs_data.get("ee_cert_path")
dispatch_url = jobs_data.get("dispatch_url")

is_interactive_run = len(parsed_args.job) > 0
Expand All @@ -132,7 +126,6 @@ def main(args):
ens_id,
dispatch_url,
ee_token,
ee_cert_path,
experiment_id,
)

Expand Down
209 changes: 121 additions & 88 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from __future__ import annotations

import asyncio
import logging
import ssl
from typing import Any, AnyStr, Self

from websockets.asyncio.client import ClientConnection, connect
from websockets.datastructures import Headers
from websockets.exceptions import (
ConnectionClosedError,
ConnectionClosedOK,
InvalidHandshake,
InvalidURI,
)
import uuid
from typing import Any, Self

import zmq
import zmq.asyncio

from _ert.async_utils import new_event_loop

Expand All @@ -25,108 +21,145 @@ class ClientConnectionClosedOK(Exception):
pass


CONNECT_MSG = "CONNECT"
DISCONNECT_MSG = "DISCONNECT"
ACK_MSG = b"ACK"


class Client:
DEFAULT_MAX_RETRIES = 10
DEFAULT_TIMEOUT_MULTIPLIER = 5
CONNECTION_TIMEOUT = 60
DEFAULT_MAX_RETRIES = 5
DEFAULT_ACK_TIMEOUT = 5
_receiver_task: asyncio.Task[None] | None

def __enter__(self) -> Self:
self.loop.run_until_complete(self.__aenter__())
return self

def term(self) -> None:
self.socket.close()
self.context.term()

def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
if self.websocket is not None:
self.loop.run_until_complete(self.websocket.close())
self.loop.run_until_complete(self.__aexit__(exc_type, exc_value, exc_traceback))
self.loop.close()

async def __aenter__(self) -> "Client":
async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(
self, exc_type: Any, exc_value: Any, exc_traceback: Any
) -> None:
if self.websocket is not None:
await self.websocket.close()
try:
await self._send(DISCONNECT_MSG)
except ClientConnectionError:
logger.error("No ack for dealer disconnection. Connection is down!")
finally:
self.socket.disconnect(self.url)
await self._term_receiver_task()
self.term()

async def _term_receiver_task(self) -> None:
if self._receiver_task and not self._receiver_task.done():
self._receiver_task.cancel()
await asyncio.gather(self._receiver_task, return_exceptions=True)
self._receiver_task = None

def __init__(
self,
url: str,
token: str | None = None,
cert: str | bytes | None = None,
max_retries: int | None = None,
timeout_multiplier: int | None = None,
dealer_name: str | None = None,
ack_timeout: float | None = None,
) -> None:
if max_retries is None:
max_retries = self.DEFAULT_MAX_RETRIES
if timeout_multiplier is None:
timeout_multiplier = self.DEFAULT_TIMEOUT_MULTIPLIER
if url is None:
raise ValueError("url was None")
self._ack_timeout = ack_timeout or self.DEFAULT_ACK_TIMEOUT
self.url = url
self.token = token
self._additional_headers = Headers()

# Set up ZeroMQ context and socke
self._ack_event: asyncio.Event = asyncio.Event()
self.context = zmq.asyncio.Context()
self.socket = self.context.socket(zmq.DEALER)
self.socket.setsockopt(zmq.LINGER, 0)
if dealer_name is None:
self.dealer_id = f"dispatch-{uuid.uuid4().hex[:8]}"
else:
self.dealer_id = dealer_name
self.socket.setsockopt_string(zmq.IDENTITY, self.dealer_id)
print(f"Created: {self.dealer_id=} {token=} {self._ack_timeout=}")
if token is not None:
self._additional_headers["token"] = token

# Mimics the behavior of the ssl argument when connection to
# websockets. If none is specified it will deduce based on the url,
# if True it will enforce TLS, and if you want to use self signed
# certificates you need to pass an ssl_context with the certificate
# loaded.
self._ssl_context: bool | ssl.SSLContext | None = None
if cert is not None:
self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self._ssl_context.load_verify_locations(cadata=cert)
elif url.startswith("wss"):
self._ssl_context = True

self._max_retries = max_retries
self._timeout_multiplier = timeout_multiplier
self.websocket: ClientConnection | None = None
client_public, client_secret = zmq.curve_keypair()
self.socket.curve_secretkey = client_secret
self.socket.curve_publickey = client_public
self.socket.curve_serverkey = token.encode("utf-8")

self.loop = new_event_loop()
self._receiver_task = None

async def connect(self) -> None:
self.socket.connect(self.url)
await self._term_receiver_task()
self._receiver_task = asyncio.create_task(self._receiver())
try:
await self._send(CONNECT_MSG, retries=1)
except ClientConnectionError:
await self._term_receiver_task()
self.term()
raise

def send(self, message: str, retries: int | None = None) -> None:
self.loop.run_until_complete(self._send(message, retries))

async def process_message(self, msg: str) -> None:
pass

async def _receiver(self) -> None:
while True:
try:
_, raw_msg = await self.socket.recv_multipart()
if raw_msg == ACK_MSG:
self._ack_event.set()
else:
await self.process_message(raw_msg.decode("utf-8"))
except zmq.ZMQError as exc:
logger.debug(
f"{self.dealer_id} connection to evaluator went down, reconnecting: {exc}"
)
await asyncio.sleep(1)
self.socket.connect(self.url)

async def get_websocket(self) -> ClientConnection:
return await connect(
self.url,
ssl=self._ssl_context,
additional_headers=self._additional_headers,
open_timeout=self.CONNECTION_TIMEOUT,
ping_timeout=self.CONNECTION_TIMEOUT,
ping_interval=self.CONNECTION_TIMEOUT,
close_timeout=self.CONNECTION_TIMEOUT,
)
async def _send(self, message: str, retries: int | None = None) -> None:
self._ack_event.clear()

async def _send(self, msg: AnyStr) -> None:
for retry in range(self._max_retries + 1):
backoff = 1
retries = retries or self.DEFAULT_MAX_RETRIES
while retries >= 0:
try:
if self.websocket is None:
self.websocket = await self.get_websocket()
await self.websocket.send(msg)
return
except ConnectionClosedOK as exception:
error_msg = (
f"Connection closed received from the server {self.url}! "
f" Exception from {type(exception)}: {exception!s}"
)
raise ClientConnectionClosedOK(error_msg) from exception
except (TimeoutError, InvalidHandshake, InvalidURI, OSError) as exception:
if retry == self._max_retries:
error_msg = (
f"Not able to establish the "
f"websocket connection {self.url}! Max retries reached!"
" Check for firewall issues."
f" Exception from {type(exception)}: {exception!s}"
await self.socket.send_multipart([b"", message.encode("utf-8")])
try:
await asyncio.wait_for(
self._ack_event.wait(), timeout=self._ack_timeout
)
raise ClientConnectionError(error_msg) from exception
except ConnectionClosedError as exception:
if retry == self._max_retries:
error_msg = (
f"Not been able to send the event"
f" to {self.url}! Max retries reached!"
f" Exception from {type(exception)}: {exception!s}"
return
except TimeoutError:
logger.warning(
f"{self.dealer_id} failed to get acknowledgment on the {message}. Resending."
)
raise ClientConnectionError(error_msg) from exception
await asyncio.sleep(0.2 + self._timeout_multiplier * retry)
self.websocket = None

def send(self, msg: AnyStr) -> None:
self.loop.run_until_complete(self._send(msg))
except zmq.ZMQError as exc:
logger.debug(
f"{self.dealer_id} connection to evaluator went down, reconnecting: {exc}"
)
await asyncio.sleep(1)
self.socket.connect(self.url)
except asyncio.CancelledError:
self.term()
raise

retries -= 1
if retries > 0:
logger.info(f"Retrying... ({retries} attempts left)")
await asyncio.sleep(backoff)
backoff = min(backoff * 2, 10) # Exponential backoff
raise ClientConnectionError(
f"{self.dealer_id} Failed to send {message=} after {retries=}"
)
Loading

0 comments on commit 7762f6f

Please sign in to comment.