From 9a566be5050322096e68ab4c9bc20ea77f74e939 Mon Sep 17 00:00:00 2001 From: Alex Carney Date: Sat, 27 Jan 2024 17:47:12 +0000 Subject: [PATCH] lsp-devtools: Ensure the agent stops once the server process exits --- lib/lsp-devtools/changes/132.fix.md | 1 + lib/lsp-devtools/lsp_devtools/agent/agent.py | 88 ++++++++++++++------ lib/lsp-devtools/tests/servers/simple.py | 14 ++++ lib/lsp-devtools/tests/test_agent.py | 81 ++++++++++++++++++ lib/lsp-devtools/tox.ini | 1 + 5 files changed, 159 insertions(+), 26 deletions(-) create mode 100644 lib/lsp-devtools/changes/132.fix.md create mode 100644 lib/lsp-devtools/tests/servers/simple.py create mode 100644 lib/lsp-devtools/tests/test_agent.py diff --git a/lib/lsp-devtools/changes/132.fix.md b/lib/lsp-devtools/changes/132.fix.md new file mode 100644 index 0000000..47ce95c --- /dev/null +++ b/lib/lsp-devtools/changes/132.fix.md @@ -0,0 +1 @@ +The `lsp-devtools agent` now watches for the when the server process exits and closes itself down also. diff --git a/lib/lsp-devtools/lsp_devtools/agent/agent.py b/lib/lsp-devtools/lsp_devtools/agent/agent.py index 8cb1f90..df984fe 100644 --- a/lib/lsp-devtools/lsp_devtools/agent/agent.py +++ b/lib/lsp-devtools/lsp_devtools/agent/agent.py @@ -1,10 +1,18 @@ +from __future__ import annotations + import asyncio import inspect import logging import re -import threading +import sys +import typing from functools import partial -from typing import BinaryIO + +if typing.TYPE_CHECKING: + from typing import BinaryIO + from typing import Optional + from typing import Set + from typing import Tuple logger = logging.getLogger("lsp_devtools.agent") @@ -22,15 +30,14 @@ async def forward_message(source: str, dest: asyncio.StreamWriter, message: byte ) -# TODO: Upstream this? -async def aio_readline(stop_event, reader, message_handler): +async def aio_readline(reader: asyncio.StreamReader, message_handler): CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$") # Initialize message buffer message = [] content_length = 0 - while not stop_event.is_set(): + while True: # Read a header line header = await reader.readline() if not header: @@ -42,7 +49,6 @@ async def aio_readline(stop_event, reader, message_handler): match = CONTENT_LENGTH_PATTERN.fullmatch(header) if match: content_length = int(match.group(1)) - logger.debug("Content length: %s", content_length) # Check if all headers have been read (as indicated by an empty line \r\n) if content_length and not header.strip(): @@ -62,7 +68,9 @@ async def aio_readline(stop_event, reader, message_handler): content_length = 0 -async def get_streams(stdin, stdout): +async def get_streams( + stdin, stdout +) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: """Convert blocking stdin/stdout streams into async streams.""" loop = asyncio.get_running_loop() @@ -87,38 +95,66 @@ def __init__( self.stdin = stdin self.stdout = stdout self.server = server - self.stop_event = threading.Event() + + self._tasks: Set[asyncio.Task] = set() + self.reader: Optional[asyncio.StreamReader] = None + self.writer: Optional[asyncio.StreamWriter] = None async def start(self): # Get async versions of stdin/stdout - reader, writer = await get_streams(self.stdin, self.stdout) + self.reader, self.writer = await get_streams(self.stdin, self.stdout) + + # Keep mypy happy + assert self.server.stdin + assert self.server.stdout # Connect stdin to the subprocess' stdin - client_to_server = aio_readline( - self.stop_event, - reader, - partial(forward_message, "client", self.server.stdin), + client_to_server = asyncio.create_task( + aio_readline( + self.reader, + partial(forward_message, "client", self.server.stdin), + ), ) + self._tasks.add(client_to_server) # Connect the subprocess' stdout to stdout - server_to_client = aio_readline( - self.stop_event, - self.server.stdout, - partial(forward_message, "server", writer), + server_to_client = asyncio.create_task( + aio_readline( + self.server.stdout, + partial(forward_message, "server", self.writer), + ), ) + self._tasks.add(server_to_client) # Run both connections concurrently. - return await asyncio.gather( + await asyncio.gather( client_to_server, server_to_client, + self._watch_server_process(), ) + async def _watch_server_process(self): + """Once the server process exits, ensure that the agent is also shutdown.""" + ret = await self.server.wait() + print(f"Server process exited with code: {ret}", file=sys.stderr) + await self.stop() + async def stop(self): - self.stop_event.set() - - try: - self.server.terminate() - ret = await self.server.wait() - print(f"Server process exited with code: {ret}") - except TimeoutError: - self.server.kill() + # Kill the server process if necessary. + if self.server.returncode is None: + try: + self.server.terminate() + await asyncio.wait_for(self.server.wait(), timeout=5) # s + except TimeoutError: + self.server.kill() + + args = {} + if sys.version_info.minor > 8: + args["msg"] = "lsp-devtools agent is stopping." + + # Cancel the tasks connecting client to server + for task in self._tasks: + task.cancel(**args) + + if self.writer: + self.writer.close() diff --git a/lib/lsp-devtools/tests/servers/simple.py b/lib/lsp-devtools/tests/servers/simple.py new file mode 100644 index 0000000..2bb2816 --- /dev/null +++ b/lib/lsp-devtools/tests/servers/simple.py @@ -0,0 +1,14 @@ +"""A very simple language server.""" +from lsprotocol import types +from pygls.server import LanguageServer + +server = LanguageServer("simple-server", "v1") + + +@server.feature(types.INITIALIZED) +def _(ls: LanguageServer, params: types.InitializedParams): + ls.show_message("Hello, world!") + + +if __name__ == "__main__": + server.start_io() diff --git a/lib/lsp-devtools/tests/test_agent.py b/lib/lsp-devtools/tests/test_agent.py new file mode 100644 index 0000000..3e88111 --- /dev/null +++ b/lib/lsp-devtools/tests/test_agent.py @@ -0,0 +1,81 @@ +import asyncio +import json +import os +import pathlib +import subprocess +import sys + +import pytest + +from lsp_devtools.agent import Agent + +SERVER_DIR = pathlib.Path(__file__).parent / "servers" + + +def format_message(obj): + content = json.dumps(obj) + message = "".join( + [ + f"Content-Length: {len(content)}\r\n", + "\r\n", + f"{content}", + ] + ) + return message.encode() + + +@pytest.mark.asyncio +async def test_agent_exits(): + """Ensure that when the client closes down the lsp session and the server process + exits, the agent does also.""" + + (stdin_read, stdin_write) = os.pipe() + (stdout_read, stdout_write) = os.pipe() + + server = await asyncio.create_subprocess_exec( + sys.executable, + str(SERVER_DIR / "simple.py"), + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + agent = Agent( + server, + os.fdopen(stdin_read, mode="rb"), + os.fdopen(stdout_write, mode="wb"), + ) + + os.write( + stdin_write, + format_message( + dict(jsonrpc="2.0", id=1, method="initialize", params=dict(capabilities={})) + ), + ) + + os.write( + stdin_write, + format_message(dict(jsonrpc="2.0", id=2, method="shutdown", params=None)), + ) + + os.write( + stdin_write, + format_message(dict(jsonrpc="2.0", method="exit", params=None)), + ) + + try: + await asyncio.wait_for( + # asyncio.gather(server.wait(), agent.start()), + agent.start(), + timeout=10, # s + ) + except asyncio.CancelledError: + pass # The agent's tasks should be cancelled + + except TimeoutError as exc: + # Make sure this timed out for the right reason. + if server.returncode is None: + raise RuntimeError("Server process did not exit") + else: + exc.add_note("lsp-devtools agent did not stop") + raise diff --git a/lib/lsp-devtools/tox.ini b/lib/lsp-devtools/tox.ini index 4b0091a..a21ba0d 100644 --- a/lib/lsp-devtools/tox.ini +++ b/lib/lsp-devtools/tox.ini @@ -11,6 +11,7 @@ wheel_build_env = .pkg deps = coverage[toml] pytest + pytest-asyncio commands_pre = coverage erase commands =