From f84bf873e4c8dbd32f003ef5eb0ebbbf20234503 Mon Sep 17 00:00:00 2001 From: Alex Carney Date: Mon, 18 Nov 2024 20:00:49 +0000 Subject: [PATCH] lsp-devtool: upgrade to pygls 2.0a2 --- lib/lsp-devtools/changes/xxx.misc.rst | 1 + lib/lsp-devtools/hatch.toml | 3 + lib/lsp-devtools/lsp_devtools/agent/agent.py | 12 +-- lib/lsp-devtools/lsp_devtools/agent/client.py | 89 +++---------------- lib/lsp-devtools/lsp_devtools/agent/server.py | 25 ++++-- .../lsp_devtools/client/editor/text_editor.py | 2 +- lib/lsp-devtools/pyproject.toml | 2 +- lib/lsp-devtools/tests/servers/simple.py | 9 +- 8 files changed, 47 insertions(+), 96 deletions(-) create mode 100644 lib/lsp-devtools/changes/xxx.misc.rst diff --git a/lib/lsp-devtools/changes/xxx.misc.rst b/lib/lsp-devtools/changes/xxx.misc.rst new file mode 100644 index 0000000..36caf8c --- /dev/null +++ b/lib/lsp-devtools/changes/xxx.misc.rst @@ -0,0 +1 @@ +Migrate to pygls `v2.0a2` diff --git a/lib/lsp-devtools/hatch.toml b/lib/lsp-devtools/hatch.toml index 29ea2b4..bd5ec2b 100644 --- a/lib/lsp-devtools/hatch.toml +++ b/lib/lsp-devtools/hatch.toml @@ -11,6 +11,9 @@ packages = ["lsp_devtools"] [envs.hatch-test] extra-dependencies = ["pytest-asyncio"] +[envs.hatch-test.env-vars] +UV_PRERELEASE="allow" + [envs.hatch-static-analysis] config-path = "ruff_defaults.toml" dependencies = ["ruff==0.5.2"] diff --git a/lib/lsp-devtools/lsp_devtools/agent/agent.py b/lib/lsp-devtools/lsp_devtools/agent/agent.py index 4d0e26f..817eec9 100644 --- a/lib/lsp-devtools/lsp_devtools/agent/agent.py +++ b/lib/lsp-devtools/lsp_devtools/agent/agent.py @@ -21,6 +21,10 @@ from typing import Callable from typing import Union + from pygls.io_ import AsyncReader + from pygls.io_ import AsyncWriter + from pygls.io_ import Writer + MessageHandler = Callable[[bytes], Union[None, Coroutine[Any, Any, None]]] UTC = timezone.utc @@ -74,7 +78,7 @@ def parse_rpc_message(data: bytes) -> RPCMessage: return RPCMessage(headers, body) -async def aio_readline(reader: asyncio.StreamReader, message_handler: MessageHandler): +async def aio_readline(reader: AsyncReader, message_handler: MessageHandler): CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$") # Initialize message buffer @@ -222,12 +226,10 @@ async def stop(self): except TimeoutError: self.server.kill() - args = {} - args["msg"] = "lsp-devtools agent is stopping." - # Cancel the tasks connecting client to server for task in self._tasks: - task.cancel(**args) + logger.debug("cancelling: %s", task) + task.cancel(msg="lsp-devtools agent is stopping.") if self.writer: self.writer.close() diff --git a/lib/lsp-devtools/lsp_devtools/agent/client.py b/lib/lsp-devtools/lsp_devtools/agent/client.py index 87e9d4c..571819f 100644 --- a/lib/lsp-devtools/lsp_devtools/agent/client.py +++ b/lib/lsp-devtools/lsp_devtools/agent/client.py @@ -1,11 +1,11 @@ from __future__ import annotations import asyncio +import inspect import typing import stamina from pygls.client import JsonRPCClient -from pygls.client import aio_readline from pygls.protocol import default_converter from lsp_devtools.agent.protocol import AgentProtocol @@ -13,25 +13,6 @@ if typing.TYPE_CHECKING: from typing import Any -# from websockets.client import WebSocketClientProtocol - - -# class WebSocketClientTransportAdapter: -# """Protocol adapter for the WebSocket client interface.""" - -# def __init__(self, ws: WebSocketClientProtocol, loop: asyncio.AbstractEventLoop): -# self._ws = ws -# self._loop = loop - -# def close(self) -> None: -# """Stop the WebSocket server.""" -# print("-- CLOSING --") -# self._loop.create_task(self._ws.close()) - -# def write(self, data: Any) -> None: -# """Create a task to write specified data into a WebSocket.""" -# asyncio.ensure_future(self._ws.send(data)) - class AgentClient(JsonRPCClient): """Client for connecting to an AgentServer instance.""" @@ -53,7 +34,6 @@ def _report_server_error(self, error, source): def feature(self, feature_name: str, options: Any | None = None): return self.protocol.fm.feature(feature_name, options) - # TODO: Upstream this... or at least something equivalent. async def start_tcp(self, host: str, port: int): # The user might not have started the server app immediately and since the # agent will live as long as the wrapper language server we may as well @@ -67,71 +47,22 @@ async def start_tcp(self, host: str, port: int): ) async for attempt in retries: with attempt: - reader, writer = await asyncio.open_connection(host, port) - - self.protocol.connection_made(writer) # type: ignore[arg-type] - connection = asyncio.create_task( - aio_readline(self._stop_event, reader, self.protocol.data_received) - ) - self.connected = True - self._async_tasks.append(connection) + await super().start_tcp(host, port) + self.connected = True def forward_message(self, message: bytes): """Forward the given message to the server instance.""" - if not self.connected: + if not self.connected or self.protocol.writer is None: self._buffer.append(message) return - if self.protocol.transport is None: - return - # Send any buffered messages while len(self._buffer) > 0: - self.protocol.transport.write(self._buffer.pop(0)) - - self.protocol.transport.write(message) - - # TODO: Upstream this... or at least something equivalent. - # def start_ws(self, host: str, port: int): - # self.protocol._send_only_body = True # Don't send headers within the payload - - # async def client_connection(host: str, port: int): - # """Create and run a client connection.""" - - # self._client = await websockets.connect( # type: ignore - # f"ws://{host}:{port}" - # ) - # loop = asyncio.get_running_loop() - # self.protocol.transport = WebSocketClientTransportAdapter( - # self._client, loop - # ) - # message = None - - # try: - # while not self._stop_event.is_set(): - # try: - # message = await asyncio.wait_for( - # self._client.recv(), timeout=0.5 - # ) - # self.protocol._procedure_handler( - # json.loads( - # message, - # object_hook=self.protocol._deserialize_message - # ) - # ) - # except JSONDecodeError: - # print(message or "-- message not found --") - # raise - # except TimeoutError: - # pass - # except Exception: - # raise - - # finally: - # await self._client.close() + res = self.protocol.writer.write(self._buffer.pop(0)) + if inspect.isawaitable(res): + asyncio.ensure_future(res) - # try: - # asyncio.run(client_connection(host, port)) - # except KeyboardInterrupt: - # pass + res = self.protocol.writer.write(message) + if inspect.isawaitable(res): + asyncio.ensure_future(res) diff --git a/lib/lsp-devtools/lsp_devtools/agent/server.py b/lib/lsp-devtools/lsp_devtools/agent/server.py index f74a47a..df4893a 100644 --- a/lib/lsp-devtools/lsp_devtools/agent/server.py +++ b/lib/lsp-devtools/lsp_devtools/agent/server.py @@ -1,12 +1,13 @@ from __future__ import annotations import asyncio +import json import logging import traceback import typing from pygls.protocol import default_converter -from pygls.server import Server +from pygls.server import JsonRPCServer from lsp_devtools.agent.agent import aio_readline from lsp_devtools.agent.protocol import AgentProtocol @@ -18,7 +19,7 @@ from lsp_devtools.agent.agent import MessageHandler -class AgentServer(Server): +class AgentServer(JsonRPCServer): """A pygls server that accepts connections from agents allowing them to send their collected messages.""" @@ -40,25 +41,33 @@ def __init__( super().__init__(*args, **kwargs) self.logger = logger or logging.getLogger(__name__) - self.handler = handler or self.lsp.data_received + self.handler = handler or self._default_handler self.db: Database | None = None self._client_buffer: list[str] = [] self._server_buffer: list[str] = [] self._tcp_server: asyncio.Task | None = None - def _report_server_error(self, exc: Exception, source): + def _default_handler(self, data: bytes): + message = self.protocol.structure_message(json.loads(data)) + self.protocol.handle_message(message) + + def _report_server_error(self, error: Exception, source): """Report internal server errors.""" - tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) - self.logger.error("%s: %s", type(exc).__name__, exc) + tb = "".join( + traceback.format_exception(type(error), error, error.__traceback__) + ) + self.logger.error("%s: %s", type(error).__name__, error) self.logger.debug("%s", tb) def feature(self, feature_name: str, options: Any | None = None): return self.lsp.fm.feature(feature_name, options) async def start_tcp(self, host: str, port: int) -> None: # type: ignore[override] - async def handle_client(reader, writer): - self.lsp.connection_made(writer) + async def handle_client( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ): + self.protocol.set_writer(writer) try: await aio_readline(reader, self.handler) diff --git a/lib/lsp-devtools/lsp_devtools/client/editor/text_editor.py b/lib/lsp-devtools/lsp_devtools/client/editor/text_editor.py index ee051c0..38c32ff 100644 --- a/lib/lsp-devtools/lsp_devtools/client/editor/text_editor.py +++ b/lib/lsp-devtools/lsp_devtools/client/editor/text_editor.py @@ -103,7 +103,7 @@ def edit(self, edit): version=self.version, uri=self.uri ), content_changes=[ - types.TextDocumentContentChangeEvent_Type1( + types.TextDocumentContentChangePartial( text=edit.text, range=types.Range( start=types.Position(line=start_line, character=start_col), diff --git a/lib/lsp-devtools/pyproject.toml b/lib/lsp-devtools/pyproject.toml index ed0c787..2452915 100644 --- a/lib/lsp-devtools/pyproject.toml +++ b/lib/lsp-devtools/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ dependencies = [ "aiosqlite", "platformdirs", - "pygls>=1.1.0,<2", + "pygls>=2.0a2", "stamina", "textual>=0.41.0", ] diff --git a/lib/lsp-devtools/tests/servers/simple.py b/lib/lsp-devtools/tests/servers/simple.py index 6bda60d..582d9cc 100644 --- a/lib/lsp-devtools/tests/servers/simple.py +++ b/lib/lsp-devtools/tests/servers/simple.py @@ -1,14 +1,19 @@ """A very simple language server.""" from lsprotocol import types -from pygls.server import LanguageServer +from pygls.lsp.server import LanguageServer server = LanguageServer("simple-server", "v1") @server.feature(types.INITIALIZED) def _(ls: LanguageServer, params: types.InitializedParams): - ls.show_message("Hello, world!") + ls.window_show_message( + types.ShowMessageParams( + message="Hello, world!", + type=types.MessageType.Log, + ) + ) if __name__ == "__main__":