Skip to content

Commit

Permalink
Create a mock zmq server class to be used in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 13, 2024
1 parent 73624e2 commit bd2e161
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 147 deletions.
8 changes: 7 additions & 1 deletion src/_ert/forward_model_runner/reporting/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
token=None,
ack_timeout=None,
max_retries=None,
finished_event_timeout=None,
):
self._evaluator_url = evaluator_url
self._token = token
Expand All @@ -78,6 +79,7 @@ def __init__(
self._done = threading.Event()
self._ack_timeout = ack_timeout
self._max_retries = max_retries
self._finished_event_timeout = finished_event_timeout or 60

def stop(self):
self._event_queue.put(Event._sentinel)
Expand All @@ -102,7 +104,11 @@ async def publisher():
event = self._event_queue.get()
if event is self._sentinel:
break
if start_time and (time.time() - start_time) > 60:
if (
start_time
and (time.time() - start_time)
> self._finished_event_timeout
):
break
await client._send(event_to_json(event), self._max_retries)
event = None
Expand Down
22 changes: 8 additions & 14 deletions tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import asyncio

import pytest

from _ert.forward_model_runner.client import Client, ClientConnectionError
from tests.ert.utils import mock_zmq_task
from tests.ert.utils import MockZMQServer


@pytest.mark.integration_test
Expand All @@ -20,35 +18,31 @@ async def test_invalid_server():
async def test_successful_sending(unused_tcp_port):
host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
messages = []
messages_c1 = ["test_1", "test_2", "test_3"]
async with mock_zmq_task(unused_tcp_port, messages), Client(url) as c1:
async with MockZMQServer(unused_tcp_port) as mock_server, Client(url) as c1:
for message in messages_c1:
await c1._send(message)

for msg in messages_c1:
assert msg in messages
assert msg in mock_server.messages


async def test_retry(unused_tcp_port):
host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
messages = []
signal_queue = asyncio.Queue()
await signal_queue.put(2)
client_connection_error_set = False
messages_c1 = ["test_1", "test_2", "test_3"]
async with (
mock_zmq_task(unused_tcp_port, messages, signal_queue),
MockZMQServer(unused_tcp_port, signal=2) as mock_server,
Client(url, ack_timeout=0.5) as c1,
):
for message in messages_c1:
try:
await c1._send(message, retries=1)
except ClientConnectionError:
client_connection_error_set = True
await signal_queue.put(0)
mock_server.signal(0)
assert client_connection_error_set
assert messages.count("test_1") == 2
assert messages.count("test_2") == 1
assert messages.count("test_3") == 1
assert mock_server.messages.count("test_1") == 2
assert mock_server.messages.count("test_2") == 1
assert mock_server.messages.count("test_3") == 1
72 changes: 31 additions & 41 deletions tests/ert/unit_tests/forward_model_runner/test_event_reporter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import queue
import time

import pytest
Expand All @@ -22,7 +21,7 @@
Start,
)
from _ert.forward_model_runner.reporting.statemachine import TransitionError
from tests.ert.utils import mock_zmq_thread
from tests.ert.utils import MockZMQServer


def _wait_until(condition, timeout, fail_msg):
Expand All @@ -39,14 +38,13 @@ def test_report_with_successful_start_message_argument(unused_tcp_port):
fmstep1 = ForwardModelStep(
{"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0
)
lines = []
with mock_zmq_thread(unused_tcp_port, lines):
with MockZMQServer(unused_tcp_port) as mock_server:
reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0))
reporter.report(Start(fmstep1))
reporter.report(Finish())

assert len(lines) == 1
event = event_from_json(lines[0])
assert len(mock_server.messages) == 1
event = event_from_json(mock_server.messages[0])
assert type(event) is ForwardModelStepStart
assert event.ensemble == "ens_id"
assert event.real == "0"
Expand All @@ -64,37 +62,35 @@ def test_report_with_failed_start_message_argument(unused_tcp_port):
{"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0
)

lines = []
with mock_zmq_thread(unused_tcp_port, lines):
with MockZMQServer(unused_tcp_port) as mock_server:
reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0))

msg = Start(fmstep1).with_error("massive_failure")

reporter.report(msg)
reporter.report(Finish())

assert len(lines) == 2
event = event_from_json(lines[1])
assert len(mock_server.messages) == 2
event = event_from_json(mock_server.messages[1])
assert type(event) is ForwardModelStepFailure
assert event.error_msg == "massive_failure"


def test_report_with_successful_exit_message_argument(unused_tcp_port):
async def test_report_with_successful_exit_message_argument(unused_tcp_port):
host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
reporter = Event(evaluator_url=url)
fmstep1 = ForwardModelStep(
{"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0
)

lines = []
with mock_zmq_thread(unused_tcp_port, lines):
with MockZMQServer(unused_tcp_port) as mock_server:
reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0))
reporter.report(Exited(fmstep1, 0))
reporter.report(Finish().with_error("failed"))

assert len(lines) == 1
event = event_from_json(lines[0])
assert len(mock_server.messages) == 1
event = event_from_json(mock_server.messages[0])
assert type(event) is ForwardModelStepSuccess


Expand All @@ -106,14 +102,13 @@ def test_report_with_failed_exit_message_argument(unused_tcp_port):
{"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0
)

lines = []
with mock_zmq_thread(unused_tcp_port, lines):
with MockZMQServer(unused_tcp_port) as mock_server:
reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0))
reporter.report(Exited(fmstep1, 1).with_error("massive_failure"))
reporter.report(Finish())

assert len(lines) == 1
event = event_from_json(lines[0])
assert len(mock_server.messages) == 1
event = event_from_json(mock_server.messages[0])
assert type(event) is ForwardModelStepFailure
assert event.error_msg == "massive_failure"

Expand All @@ -126,14 +121,13 @@ def test_report_with_running_message_argument(unused_tcp_port):
{"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0
)

lines = []
with mock_zmq_thread(unused_tcp_port, lines):
with MockZMQServer(unused_tcp_port) as mock_server:
reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0))
reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10)))
reporter.report(Finish())

assert len(lines) == 1
event = event_from_json(lines[0])
assert len(mock_server.messages) == 1
event = event_from_json(mock_server.messages[0])
assert type(event) is ForwardModelStepRunning
assert event.max_memory_usage == 100
assert event.current_memory_usage == 10
Expand All @@ -147,13 +141,12 @@ def test_report_only_job_running_for_successful_run(unused_tcp_port):
{"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0
)

lines = []
with mock_zmq_thread(unused_tcp_port, lines):
with MockZMQServer(unused_tcp_port) as mock_server:
reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0))
reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10)))
reporter.report(Finish())

assert len(lines) == 1
assert len(mock_server.messages) == 1


def test_report_with_failed_finish_message_argument(unused_tcp_port):
Expand All @@ -164,13 +157,12 @@ def test_report_with_failed_finish_message_argument(unused_tcp_port):
{"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0
)

lines = []
with mock_zmq_thread(unused_tcp_port, lines):
with MockZMQServer(unused_tcp_port) as mock_server:
reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0))
reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10)))
reporter.report(Finish().with_error("massive_failure"))

assert len(lines) == 1
assert len(mock_server.messages) == 1


def test_report_inconsistent_events(unused_tcp_port):
Expand All @@ -197,15 +189,15 @@ def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port):

host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
lines = []
signal_queue = queue.Queue()
with mock_zmq_thread(unused_tcp_port, lines, signal_queue):
reporter = Event(evaluator_url=url, ack_timeout=1, max_retries=1)
with MockZMQServer(unused_tcp_port) as mock_server:
reporter = Event(
evaluator_url=url, ack_timeout=2, max_retries=1, finished_event_timeout=2
)
fmstep1 = ForwardModelStep(
{"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0
)

signal_queue.put(1) # prevent router to receive messages
mock_server.signal(1) # prevent router to receive messages
reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0))
reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10)))
reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10)))
Expand All @@ -214,7 +206,7 @@ def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port):
if reporter._event_publisher_thread.is_alive():
reporter._event_publisher_thread.join()
assert reporter._done.is_set()
assert len(lines) == 0, "expected 0 Job running messages"
assert len(mock_server.messages) == 0, "expected 0 Job running messages"


def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port):
Expand All @@ -226,25 +218,23 @@ def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port):

host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
lines = []
signal_queue = queue.Queue()
with mock_zmq_thread(unused_tcp_port, lines, signal_queue):
with MockZMQServer(unused_tcp_port) as mock_server:
reporter = Event(evaluator_url=url, ack_timeout=1, max_retries=1)
fmstep1 = ForwardModelStep(
{"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0
)

signal_queue.put(1) # prevent router to receive messages
mock_server.signal(1) # prevent router to receive messages
reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0))
reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10)))
reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10)))
reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10)))
signal_queue.put(0) # enable router to receive messages
mock_server.signal(0) # enable router to receive messages
reporter.report(Finish())
if reporter._event_publisher_thread.is_alive():
reporter._event_publisher_thread.join()
assert reporter._done.is_set()
assert len(lines) == 3, "expected 3 Job running messages"
assert len(mock_server.messages) == 3, "expected 3 Job running messages"


# REFACTOR maybe we don't this anymore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from _ert.forward_model_runner.reporting import Event, Interactive
from _ert.forward_model_runner.reporting.message import Finish, Init
from _ert.threading import ErtThread
from tests.ert.utils import mock_zmq_thread, wait_until
from tests.ert.utils import MockZMQServer, wait_until

from .test_event_reporter import _wait_until

Expand Down Expand Up @@ -316,7 +316,7 @@ def create_jobs_file_after_lock():
(tmp_path / JOBS_FILE).write_text(jobs_json)
lock.release()

with mock_zmq_thread(unused_tcp_port, []):
with MockZMQServer(unused_tcp_port):
thread = ErtThread(target=create_jobs_file_after_lock)
thread.start()
main(args=["script.py", str(tmp_path)])
Expand Down Expand Up @@ -362,7 +362,7 @@ def test_job_dispatch_kills_itself_after_unsuccessful_job(unused_tcp_port):
]
mock_getpgid.return_value = 17

with mock_zmq_thread(port, []):
with MockZMQServer(port):
main(["script.py"])

mock_killpg.assert_called_with(17, signal.SIGKILL)
Expand Down
Loading

0 comments on commit bd2e161

Please sign in to comment.