diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 03396f6..30c15ac 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -64,8 +64,7 @@ async def handle_call_tool(name: str, args: dict): # Test fixtures -@pytest.fixture -async def server_app()-> Starlette: +def make_server_app()-> Starlette: """Create test Starlette app with SSE transport""" sse = SseServerTransport("/messages/") server = TestServer() @@ -93,43 +92,46 @@ def space_around_test(): yield time.sleep(0.1) -@pytest.fixture() -def server(server_app: Starlette, server_port: int): - proc = multiprocessing.Process(target=uvicorn.run, daemon=True, kwargs={ - "app": server_app, - "host": "127.0.0.1", - "port": server_port, - "log_level": "error" - }) +def run_server(server_port: int): + app = make_server_app() + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f'starting server on {server_port}') - proc.start() + server.run() # Give server time to start while not server.started: print('waiting for server to start') time.sleep(0.5) - try: - yield - finally: - print('killing server') - # Signal the server to stop - server.should_exit = True - - # Force close the server's main socket - if hasattr(server.servers, "servers"): - for s in server.servers: - print(f'closing {s}') - s.close() - - # Wait for thread to finish - proc.terminate() - proc.join(timeout=2) - if proc.is_alive(): - print("Warning: Server thread did not exit cleanly") - # Optionally, you could add more aggressive cleanup here - import _thread - _thread.interrupt_main() +@pytest.fixture() +def server(server_port: int): + proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) + print('starting process') + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print('waiting for server to start') + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(('127.0.0.1', server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError("Server failed to start after {} attempts".format(max_attempts)) + + yield + + print('killing server') + # Signal the server to stop + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") @pytest.fixture() async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]: