From b85c80cd162e7033eb3383e00998b4b75278574a Mon Sep 17 00:00:00 2001 From: Andrei Neagu Date: Thu, 21 Mar 2024 15:20:55 +0100 Subject: [PATCH] refactor tests --- tests/test_activity_monitor.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/test_activity_monitor.py b/tests/test_activity_monitor.py index 7c08359..9b3153d 100644 --- a/tests/test_activity_monitor.py +++ b/tests/test_activity_monitor.py @@ -8,7 +8,7 @@ import threading import pytest_asyncio - +from queue import Queue from typing import Callable, TYPE_CHECKING from pytest_mock import MockFixture from tenacity import AsyncRetrying @@ -132,15 +132,24 @@ async def server_url() -> str: async def tornado_server(server_url: str) -> None: app = await activity_monitor.make_app() - def _start_tornado(): + stop_queue = Queue() + + def _run_server_worker(): http_server = tornado.httpserver.HTTPServer(app) http_server.listen(8899) - tornado.ioloop.IOLoop.current().start() + current_io_loop = tornado.ioloop.IOLoop.current() + + def _queue_stopper() -> None: + stop_queue.get() + current_io_loop.stop() - def _stop_tornado(): - tornado.ioloop.IOLoop.current().stop() + stopping_thread = threading.Thread(target=_queue_stopper, daemon=True) + stopping_thread.start() - thread = threading.Thread(target=lambda: _start_tornado(), daemon=True) + current_io_loop.start() + stopping_thread.join() + + thread = threading.Thread(target=_run_server_worker, daemon=True) thread.start() # ensure server is running @@ -153,7 +162,11 @@ def _stop_tornado(): yield None - _stop_tornado() + stop_queue.put(None) + thread.join(timeout=1) + + with pytest.raises(requests.exceptions.ReadTimeout): + requests.get(f"{server_url}/", timeout=1) @pytest.fixture @@ -162,7 +175,9 @@ def mock_check_interval(mocker: MockFixture) -> None: @pytest.mark.asyncio -async def test_tornado_server_ok(mock_check_interval: None, tornado_server: None, server_url:str): +async def test_tornado_server_ok( + mock_check_interval: None, tornado_server: None, server_url: str +): result = requests.get(f"{server_url}/", timeout=5) assert result.status_code == 200