Skip to content

Commit

Permalink
Add test that monitor can't connect to a secured server
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 30, 2024
1 parent 855c517 commit 913bc39
Showing 1 changed file with 55 additions and 1 deletion.
56 changes: 55 additions & 1 deletion tests/ert/unit_tests/ensemble_evaluator/test_monitor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
from contextlib import suppress

import pytest
import zmq
Expand All @@ -16,9 +17,13 @@
from ert.ensemble_evaluator.config import EvaluatorConnectionInfo


async def async_zmq_server(port, handler):
async def async_zmq_server(port, handler, secret_key: bytes | None = None):
zmq_context = zmq.asyncio.Context()
router_socket = zmq_context.socket(zmq.ROUTER)
if secret_key is not None:
router_socket.curve_secretkey = secret_key
router_socket.curve_publickey = zmq.curve_public(secret_key)
router_socket.curve_server = True
router_socket.setsockopt(zmq.LINGER, 0)
router_socket.bind(f"tcp://*:{port}")
await handler(router_socket)
Expand Down Expand Up @@ -121,6 +126,55 @@ async def mock_event_handler(router_socket):
await websocket_server_task


@pytest.mark.parametrize(
"correct_server_key",
[
pytest.param(True),
pytest.param(False),
],
)
async def test_that_monitor_cannot_connect_with_wrong_server_key(
correct_server_key, monkeypatch, unused_tcp_port
):
public_key, secret_key = zmq.curve_keypair()
ee_con_info = EvaluatorConnectionInfo(
f"tcp://127.0.0.1:{unused_tcp_port}",
public_key.decode("utf-8") if correct_server_key else None,
)

monkeypatch.setattr(Monitor, "DEFAULT_MAX_RETRIES", 0)
monkeypatch.setattr(Monitor, "DEFAULT_ACK_TIMEOUT", 1)

connected = False

async def mock_event_handler(router_socket):
nonlocal connected
while True:
dealer, _, frame = await router_socket.recv_multipart()
await router_socket.send_multipart([dealer, b"", ACK_MSG])
if frame == CONNECT_MSG:
connected = True
elif frame == DISCONNECT_MSG:
connected = False
return

websocket_server_task = asyncio.create_task(
async_zmq_server(unused_tcp_port, mock_event_handler, secret_key=secret_key)
)
if correct_server_key:
async with Monitor(ee_con_info):
assert connected
assert connected is False
else:
with pytest.raises(ClientConnectionError):
async with Monitor(ee_con_info):
pass
assert connected is False
websocket_server_task.cancel()
with suppress(asyncio.CancelledError):
await websocket_server_task


async def test_that_monitor_track_can_exit_without_terminated_event_from_evaluator(
unused_tcp_port, caplog
):
Expand Down

0 comments on commit 913bc39

Please sign in to comment.