diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c906bc0..4c1202c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,10 +41,11 @@ repos: args: [--explicit-package-bases,--check-untyped-defs] additional_dependencies: - importlib-resources + - platformdirs - pygls - pytest - pytest-asyncio - - types-appdirs + - websockets files: 'lib/pytest-lsp/pytest_lsp/.*\.py' - id: mypy diff --git a/lib/pytest-lsp/pytest_lsp/client.py b/lib/pytest-lsp/pytest_lsp/client.py index 6eabb3c..7a60d18 100644 --- a/lib/pytest-lsp/pytest_lsp/client.py +++ b/lib/pytest-lsp/pytest_lsp/client.py @@ -7,23 +7,12 @@ from typing import Dict from typing import List from typing import Optional -from typing import Type +from typing import Union +from lsprotocol import types from lsprotocol.converters import get_converter -from lsprotocol.types import TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS -from lsprotocol.types import WINDOW_LOG_MESSAGE -from lsprotocol.types import WINDOW_SHOW_DOCUMENT -from lsprotocol.types import WINDOW_SHOW_MESSAGE -from lsprotocol.types import ClientCapabilities -from lsprotocol.types import Diagnostic -from lsprotocol.types import InitializedParams -from lsprotocol.types import InitializeParams -from lsprotocol.types import InitializeResult -from lsprotocol.types import LogMessageParams -from lsprotocol.types import PublishDiagnosticsParams -from lsprotocol.types import ShowDocumentParams -from lsprotocol.types import ShowDocumentResult -from lsprotocol.types import ShowMessageParams +from pygls.exceptions import JsonRpcException +from pygls.exceptions import PyglsError from pygls.lsp.client import BaseLanguageClient from pygls.protocol import default_converter @@ -42,30 +31,28 @@ class LanguageClient(BaseLanguageClient): """Used to drive language servers under test.""" - def __init__( - self, - protocol_cls: Type[LanguageClientProtocol] = LanguageClientProtocol, - *args, - **kwargs, - ): - super().__init__( - "pytest-lsp-client", __version__, protocol_cls=protocol_cls, *args, **kwargs - ) + protocol: LanguageClientProtocol + + def __init__(self, *args, **kwargs): + if "protocol_cls" not in kwargs: + kwargs["protocol_cls"] = LanguageClientProtocol - self.capabilities: Optional[ClientCapabilities] = None + super().__init__("pytest-lsp-client", __version__, *args, **kwargs) + + self.capabilities: Optional[types.ClientCapabilities] = None """The client's capabilities.""" - self.shown_documents: List[ShowDocumentParams] = [] + self.shown_documents: List[types.ShowDocumentParams] = [] """Used to keep track of the documents requested to be shown via a ``window/showDocument`` request.""" - self.messages: List[ShowMessageParams] = [] + self.messages: List[types.ShowMessageParams] = [] """Holds any received ``window/showMessage`` requests.""" - self.log_messages: List[LogMessageParams] = [] + self.log_messages: List[types.LogMessageParams] = [] """Holds any received ``window/logMessage`` requests.""" - self.diagnostics: Dict[str, List[Diagnostic]] = {} + self.diagnostics: Dict[str, List[types.Diagnostic]] = {} """Used to hold any recieved diagnostics.""" self.error: Optional[Exception] = None @@ -86,8 +73,8 @@ async def server_exit(self, server: asyncio.subprocess.Process): stderr = "" if server.stderr is not None: - stderr = await server.stderr.read() - stderr = stderr.decode("utf8") + stderr_bytes = await server.stderr.read() + stderr = stderr_bytes.decode("utf8") loop = asyncio.get_running_loop() loop.call_soon( @@ -95,13 +82,15 @@ async def server_exit(self, server: asyncio.subprocess.Process): f"Server process exited with return code: {server.returncode}\n{stderr}", ) - def report_server_error(self, error: Exception, source: Type[Exception]): + def report_server_error( + self, error: Exception, source: Union[PyglsError, JsonRpcException] + ): """Called when the server does something unexpected, e.g. sending malformed JSON.""" self.error = error tb = "".join(traceback.format_exc()) - message = f"{source.__name__}: {error}\n{tb}" + message = f"{source.__name__}: {error}\n{tb}" # type: ignore loop = asyncio.get_running_loop() loop.call_soon(cancel_all_tasks, message) @@ -109,7 +98,9 @@ def report_server_error(self, error: Exception, source: Type[Exception]): if self._stop_event: self._stop_event.set() - async def initialize_session(self, params: InitializeParams) -> InitializeResult: + async def initialize_session( + self, params: types.InitializeParams + ) -> types.InitializeResult: """Make an ``initialize`` request to a lanaguage server. It will also automatically send an ``initialized`` notification once @@ -135,7 +126,7 @@ async def initialize_session(self, params: InitializeParams) -> InitializeResult params.process_id = os.getpid() response = await self.initialize_async(params) - self.initialized(InitializedParams()) + self.initialized(types.InitializedParams()) return response @@ -186,32 +177,34 @@ def make_test_lsp_client() -> LanguageClient: converter_factory=default_converter, ) - @client.feature(TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS) - def publish_diagnostics(client: LanguageClient, params: PublishDiagnosticsParams): + @client.feature(types.TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS) + def publish_diagnostics( + client: LanguageClient, params: types.PublishDiagnosticsParams + ): client.diagnostics[params.uri] = params.diagnostics - @client.feature(WINDOW_LOG_MESSAGE) - def log_message(client: LanguageClient, params: LogMessageParams): + @client.feature(types.WINDOW_LOG_MESSAGE) + def log_message(client: LanguageClient, params: types.LogMessageParams): client.log_messages.append(params) levels = [logger.error, logger.warning, logger.info, logger.debug] levels[params.type.value - 1](params.message) - @client.feature(WINDOW_SHOW_MESSAGE) + @client.feature(types.WINDOW_SHOW_MESSAGE) def show_message(client: LanguageClient, params): client.messages.append(params) - @client.feature(WINDOW_SHOW_DOCUMENT) + @client.feature(types.WINDOW_SHOW_DOCUMENT) def show_document( - client: LanguageClient, params: ShowDocumentParams - ) -> ShowDocumentResult: + client: LanguageClient, params: types.ShowDocumentParams + ) -> types.ShowDocumentResult: client.shown_documents.append(params) - return ShowDocumentResult(success=True) + return types.ShowDocumentResult(success=True) return client -def client_capabilities(client_spec: str) -> ClientCapabilities: +def client_capabilities(client_spec: str) -> types.ClientCapabilities: """Find the capabilities that correspond to the given client spec. Parameters @@ -241,4 +234,4 @@ def client_capabilities(client_spec: str) -> ClientCapabilities: converter = get_converter() capabilities = json.loads(filename.read_text()) - return converter.structure(capabilities, ClientCapabilities) + return converter.structure(capabilities, types.ClientCapabilities) diff --git a/lib/pytest-lsp/pytest_lsp/plugin.py b/lib/pytest-lsp/pytest_lsp/plugin.py index 82328cf..af7e94c 100644 --- a/lib/pytest-lsp/pytest_lsp/plugin.py +++ b/lib/pytest-lsp/pytest_lsp/plugin.py @@ -3,6 +3,7 @@ import sys import textwrap import typing +from typing import Any from typing import Callable from typing import Dict from typing import List @@ -104,7 +105,7 @@ def get_fixture_arguments( dict The set of arguments to pass to the user's fixture function """ - kwargs = {} + kwargs: Dict[str, Any] = {} required_parameters = set(inspect.signature(fn).parameters.keys()) # Inject the 'request' fixture if requested diff --git a/lib/pytest-lsp/pytest_lsp/protocol.py b/lib/pytest-lsp/pytest_lsp/protocol.py index a521ec1..24adfb2 100644 --- a/lib/pytest-lsp/pytest_lsp/protocol.py +++ b/lib/pytest-lsp/pytest_lsp/protocol.py @@ -41,7 +41,7 @@ def _handle_notification(self, method_name, params): async def send_request_async(self, method, params=None): result = await super().send_request_async(method, params) check_result_against_client_capabilities( - self._server.capabilities, method, result + self._server.capabilities, method, result # type: ignore ) return result