From bd2e161f37a1bf2eb6c9ea7296364abf56ff20d7 Mon Sep 17 00:00:00 2001 From: xjules Date: Fri, 13 Dec 2024 16:41:32 +0100 Subject: [PATCH] Create a mock zmq server class to be used in tests --- .../forward_model_runner/reporting/event.py | 8 +- .../test_ensemble_client.py | 22 +-- .../test_event_reporter.py | 72 ++++----- .../forward_model_runner/test_job_dispatch.py | 6 +- tests/ert/utils.py | 142 +++++++----------- 5 files changed, 103 insertions(+), 147 deletions(-) diff --git a/src/_ert/forward_model_runner/reporting/event.py b/src/_ert/forward_model_runner/reporting/event.py index 43b2104ddfb..8ca50302131 100644 --- a/src/_ert/forward_model_runner/reporting/event.py +++ b/src/_ert/forward_model_runner/reporting/event.py @@ -61,6 +61,7 @@ def __init__( token=None, ack_timeout=None, max_retries=None, + finished_event_timeout=None, ): self._evaluator_url = evaluator_url self._token = token @@ -78,6 +79,7 @@ def __init__( self._done = threading.Event() self._ack_timeout = ack_timeout self._max_retries = max_retries + self._finished_event_timeout = finished_event_timeout or 60 def stop(self): self._event_queue.put(Event._sentinel) @@ -102,7 +104,11 @@ async def publisher(): event = self._event_queue.get() if event is self._sentinel: break - if start_time and (time.time() - start_time) > 60: + if ( + start_time + and (time.time() - start_time) + > self._finished_event_timeout + ): break await client._send(event_to_json(event), self._max_retries) event = None diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py index 90d6555f574..3a87cdac364 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py @@ -1,9 +1,7 @@ -import asyncio - import pytest from _ert.forward_model_runner.client import Client, ClientConnectionError -from tests.ert.utils import mock_zmq_task +from tests.ert.utils import MockZMQServer @pytest.mark.integration_test @@ -20,26 +18,22 @@ async def test_invalid_server(): async def test_successful_sending(unused_tcp_port): host = "localhost" url = f"tcp://{host}:{unused_tcp_port}" - messages = [] messages_c1 = ["test_1", "test_2", "test_3"] - async with mock_zmq_task(unused_tcp_port, messages), Client(url) as c1: + async with MockZMQServer(unused_tcp_port) as mock_server, Client(url) as c1: for message in messages_c1: await c1._send(message) for msg in messages_c1: - assert msg in messages + assert msg in mock_server.messages async def test_retry(unused_tcp_port): host = "localhost" url = f"tcp://{host}:{unused_tcp_port}" - messages = [] - signal_queue = asyncio.Queue() - await signal_queue.put(2) client_connection_error_set = False messages_c1 = ["test_1", "test_2", "test_3"] async with ( - mock_zmq_task(unused_tcp_port, messages, signal_queue), + MockZMQServer(unused_tcp_port, signal=2) as mock_server, Client(url, ack_timeout=0.5) as c1, ): for message in messages_c1: @@ -47,8 +41,8 @@ async def test_retry(unused_tcp_port): await c1._send(message, retries=1) except ClientConnectionError: client_connection_error_set = True - await signal_queue.put(0) + mock_server.signal(0) assert client_connection_error_set - assert messages.count("test_1") == 2 - assert messages.count("test_2") == 1 - assert messages.count("test_3") == 1 + assert mock_server.messages.count("test_1") == 2 + assert mock_server.messages.count("test_2") == 1 + assert mock_server.messages.count("test_3") == 1 diff --git a/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py b/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py index 15c63ccb7a7..32c62f574b0 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py +++ b/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py @@ -1,5 +1,4 @@ import os -import queue import time import pytest @@ -22,7 +21,7 @@ Start, ) from _ert.forward_model_runner.reporting.statemachine import TransitionError -from tests.ert.utils import mock_zmq_thread +from tests.ert.utils import MockZMQServer def _wait_until(condition, timeout, fail_msg): @@ -39,14 +38,13 @@ def test_report_with_successful_start_message_argument(unused_tcp_port): fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with mock_zmq_thread(unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Start(fmstep1)) reporter.report(Finish()) - assert len(lines) == 1 - event = event_from_json(lines[0]) + assert len(mock_server.messages) == 1 + event = event_from_json(mock_server.messages[0]) assert type(event) is ForwardModelStepStart assert event.ensemble == "ens_id" assert event.real == "0" @@ -64,8 +62,7 @@ def test_report_with_failed_start_message_argument(unused_tcp_port): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with mock_zmq_thread(unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) msg = Start(fmstep1).with_error("massive_failure") @@ -73,13 +70,13 @@ def test_report_with_failed_start_message_argument(unused_tcp_port): reporter.report(msg) reporter.report(Finish()) - assert len(lines) == 2 - event = event_from_json(lines[1]) + assert len(mock_server.messages) == 2 + event = event_from_json(mock_server.messages[1]) assert type(event) is ForwardModelStepFailure assert event.error_msg == "massive_failure" -def test_report_with_successful_exit_message_argument(unused_tcp_port): +async def test_report_with_successful_exit_message_argument(unused_tcp_port): host = "localhost" url = f"tcp://{host}:{unused_tcp_port}" reporter = Event(evaluator_url=url) @@ -87,14 +84,13 @@ def test_report_with_successful_exit_message_argument(unused_tcp_port): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with mock_zmq_thread(unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Exited(fmstep1, 0)) reporter.report(Finish().with_error("failed")) - assert len(lines) == 1 - event = event_from_json(lines[0]) + assert len(mock_server.messages) == 1 + event = event_from_json(mock_server.messages[0]) assert type(event) is ForwardModelStepSuccess @@ -106,14 +102,13 @@ def test_report_with_failed_exit_message_argument(unused_tcp_port): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with mock_zmq_thread(unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Exited(fmstep1, 1).with_error("massive_failure")) reporter.report(Finish()) - assert len(lines) == 1 - event = event_from_json(lines[0]) + assert len(mock_server.messages) == 1 + event = event_from_json(mock_server.messages[0]) assert type(event) is ForwardModelStepFailure assert event.error_msg == "massive_failure" @@ -126,14 +121,13 @@ def test_report_with_running_message_argument(unused_tcp_port): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with mock_zmq_thread(unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) reporter.report(Finish()) - assert len(lines) == 1 - event = event_from_json(lines[0]) + assert len(mock_server.messages) == 1 + event = event_from_json(mock_server.messages[0]) assert type(event) is ForwardModelStepRunning assert event.max_memory_usage == 100 assert event.current_memory_usage == 10 @@ -147,13 +141,12 @@ def test_report_only_job_running_for_successful_run(unused_tcp_port): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with mock_zmq_thread(unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) reporter.report(Finish()) - assert len(lines) == 1 + assert len(mock_server.messages) == 1 def test_report_with_failed_finish_message_argument(unused_tcp_port): @@ -164,13 +157,12 @@ def test_report_with_failed_finish_message_argument(unused_tcp_port): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - lines = [] - with mock_zmq_thread(unused_tcp_port, lines): + with MockZMQServer(unused_tcp_port) as mock_server: reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) reporter.report(Finish().with_error("massive_failure")) - assert len(lines) == 1 + assert len(mock_server.messages) == 1 def test_report_inconsistent_events(unused_tcp_port): @@ -197,15 +189,15 @@ def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port): host = "localhost" url = f"tcp://{host}:{unused_tcp_port}" - lines = [] - signal_queue = queue.Queue() - with mock_zmq_thread(unused_tcp_port, lines, signal_queue): - reporter = Event(evaluator_url=url, ack_timeout=1, max_retries=1) + with MockZMQServer(unused_tcp_port) as mock_server: + reporter = Event( + evaluator_url=url, ack_timeout=2, max_retries=1, finished_event_timeout=2 + ) fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - signal_queue.put(1) # prevent router to receive messages + mock_server.signal(1) # prevent router to receive messages reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) @@ -214,7 +206,7 @@ def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port): if reporter._event_publisher_thread.is_alive(): reporter._event_publisher_thread.join() assert reporter._done.is_set() - assert len(lines) == 0, "expected 0 Job running messages" + assert len(mock_server.messages) == 0, "expected 0 Job running messages" def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port): @@ -226,25 +218,23 @@ def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port): host = "localhost" url = f"tcp://{host}:{unused_tcp_port}" - lines = [] - signal_queue = queue.Queue() - with mock_zmq_thread(unused_tcp_port, lines, signal_queue): + with MockZMQServer(unused_tcp_port) as mock_server: reporter = Event(evaluator_url=url, ack_timeout=1, max_retries=1) fmstep1 = ForwardModelStep( {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) - signal_queue.put(1) # prevent router to receive messages + mock_server.signal(1) # prevent router to receive messages reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) - signal_queue.put(0) # enable router to receive messages + mock_server.signal(0) # enable router to receive messages reporter.report(Finish()) if reporter._event_publisher_thread.is_alive(): reporter._event_publisher_thread.join() assert reporter._done.is_set() - assert len(lines) == 3, "expected 3 Job running messages" + assert len(mock_server.messages) == 3, "expected 3 Job running messages" # REFACTOR maybe we don't this anymore diff --git a/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py b/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py index a39632d0800..d1f16819930 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py +++ b/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py @@ -23,7 +23,7 @@ from _ert.forward_model_runner.reporting import Event, Interactive from _ert.forward_model_runner.reporting.message import Finish, Init from _ert.threading import ErtThread -from tests.ert.utils import mock_zmq_thread, wait_until +from tests.ert.utils import MockZMQServer, wait_until from .test_event_reporter import _wait_until @@ -316,7 +316,7 @@ def create_jobs_file_after_lock(): (tmp_path / JOBS_FILE).write_text(jobs_json) lock.release() - with mock_zmq_thread(unused_tcp_port, []): + with MockZMQServer(unused_tcp_port): thread = ErtThread(target=create_jobs_file_after_lock) thread.start() main(args=["script.py", str(tmp_path)]) @@ -362,7 +362,7 @@ def test_job_dispatch_kills_itself_after_unsuccessful_job(unused_tcp_port): ] mock_getpgid.return_value = 17 - with mock_zmq_thread(port, []): + with MockZMQServer(port): main(["script.py"]) mock_killpg.assert_called_with(17, signal.SIGKILL) diff --git a/tests/ert/utils.py b/tests/ert/utils.py index 19a9aac74a6..1665ac4e0a5 100644 --- a/tests/ert/utils.py +++ b/tests/ert/utils.py @@ -2,7 +2,6 @@ import asyncio import contextlib -import queue import time from pathlib import Path from typing import TYPE_CHECKING @@ -62,101 +61,68 @@ def wait_until(func, interval=0.5, timeout=30): ) -@contextlib.asynccontextmanager -async def mock_zmq_task(port, messages, signal_queue=None): - async def mock_zmq_server(messages, port, signal_queue=None): - async def _handler(router_socket): - signal_value = 0 - while True: - dealer, __, frame = await router_socket.recv_multipart() - if signal_queue: - with contextlib.suppress(TimeoutError): - signal_value = await asyncio.wait_for( - signal_queue.get(), timeout=0.1 - ) - - print(f"{dealer=} {frame=} {signal_value=}") - frame = frame.decode("utf-8") - if frame in [CONNECT_MSG, DISCONNECT_MSG] or signal_value == 0: - await router_socket.send_multipart([dealer, b"", ACK_MSG]) - if frame not in [CONNECT_MSG, DISCONNECT_MSG] and signal_value != 1: - messages.append(frame) +class MockZMQServer: + def __init__(self, port, signal=0): + self.port = port + self.messages = [] + self.value = signal + self.loop = None + self.server_task = None + self.handler_task = None + + def start_event_loop(self): + asyncio.set_event_loop(self.loop) + self.loop.run_until_complete(self.mock_zmq_server()) + + def __enter__(self): + self.loop = asyncio.new_event_loop() + self.thread = ErtThread(target=self.start_event_loop) + self.thread.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.handler_task and not self.handler_task.done(): + self.loop.call_soon_threadsafe(self.handler_task.cancel) + self.thread.join() + self.loop.close() + + async def __aenter__(self): + self.server_task = asyncio.create_task(self.mock_zmq_server()) + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if not self.server_task.done(): + self.server_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self.server_task + async def mock_zmq_server(self): zmq_context = zmq.asyncio.Context() - router_socket = zmq_context.socket(zmq.ROUTER) - router_socket.bind(f"tcp://*:{port}") + self.router_socket = zmq_context.socket(zmq.ROUTER) + self.router_socket.bind(f"tcp://*:{self.port}") + self.handler_task = asyncio.create_task(self._handler()) try: - await _handler(router_socket) + await self.handler_task finally: - router_socket.close() + self.router_socket.close() zmq_context.term() - # Create the server task - server_task = asyncio.create_task(mock_zmq_server(messages, port, signal_queue)) - - try: - yield - finally: - print(f"these are the final {messages=}") - if not server_task.done(): - server_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await server_task - - -@contextlib.contextmanager -def mock_zmq_thread(port, messages, signal_queue=None): - loop = None - handler_task = None - - def mock_zmq_server(messages, port, signal_queue=None): - nonlocal loop, handler_task - loop = asyncio.new_event_loop() - - async def _handler(router_socket): - nonlocal messages, signal_queue - signal_value = 0 - while True: - try: - dealer, __, frame = await router_socket.recv_multipart() - if signal_queue: - with contextlib.suppress(queue.Empty): - signal_value = signal_queue.get(timeout=0.1) - - print(f"{dealer=} {frame=} {signal_value=}") - frame = frame.decode("utf-8") - if frame in [CONNECT_MSG, DISCONNECT_MSG] or signal_value == 0: - await router_socket.send_multipart([dealer, b"", ACK_MSG]) - if frame not in [CONNECT_MSG, DISCONNECT_MSG] and signal_value != 1: - messages.append(frame) - - except asyncio.CancelledError: - break - - async def _run_server(): - nonlocal handler_task - zmq_context = zmq.asyncio.Context() # type: ignore - router_socket = zmq_context.socket(zmq.ROUTER) - router_socket.bind(f"tcp://*:{port}") - handler_task = asyncio.create_task(_handler(router_socket)) - await handler_task - router_socket.close() - - loop.run_until_complete(_run_server()) - loop.close() + def signal(self, value): + self.value = value - mock_zmq_thread = ErtThread( - target=lambda: mock_zmq_server(messages, port, signal_queue), - ) - mock_zmq_thread.start() - try: - yield - finally: - print(f"these are the final {messages=}") - if handler_task and not handler_task.done(): - loop.call_soon_threadsafe(handler_task.cancel) - mock_zmq_thread.join() + async def _handler(self): + while True: + try: + dealer, __, frame = await self.router_socket.recv_multipart() + print(f"{dealer=} {frame=} {self.value=}") + frame = frame.decode("utf-8") + if frame in [CONNECT_MSG, DISCONNECT_MSG] or self.value == 0: + await self.router_socket.send_multipart([dealer, b"", ACK_MSG]) + if frame not in [CONNECT_MSG, DISCONNECT_MSG] and self.value != 1: + self.messages.append(frame) + except asyncio.CancelledError: + break async def poll(driver: Driver, expected: set[int], *, started=None, finished=None):