Skip to content

Commit

Permalink
pytest-lsp: Typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
alcarney committed Oct 6, 2023
1 parent a89cef7 commit aa5feff
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 49 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ repos:
args: [--explicit-package-bases,--check-untyped-defs]
additional_dependencies:
- importlib-resources
- platformdirs
- pygls
- pytest
- pytest-asyncio
- types-appdirs
- websockets
files: 'lib/pytest-lsp/pytest_lsp/.*\.py'

- id: mypy
Expand Down
85 changes: 39 additions & 46 deletions lib/pytest-lsp/pytest_lsp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,12 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Type
from typing import Union

from lsprotocol import types
from lsprotocol.converters import get_converter
from lsprotocol.types import TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS
from lsprotocol.types import WINDOW_LOG_MESSAGE
from lsprotocol.types import WINDOW_SHOW_DOCUMENT
from lsprotocol.types import WINDOW_SHOW_MESSAGE
from lsprotocol.types import ClientCapabilities
from lsprotocol.types import Diagnostic
from lsprotocol.types import InitializedParams
from lsprotocol.types import InitializeParams
from lsprotocol.types import InitializeResult
from lsprotocol.types import LogMessageParams
from lsprotocol.types import PublishDiagnosticsParams
from lsprotocol.types import ShowDocumentParams
from lsprotocol.types import ShowDocumentResult
from lsprotocol.types import ShowMessageParams
from pygls.exceptions import JsonRpcException
from pygls.exceptions import PyglsError
from pygls.lsp.client import BaseLanguageClient
from pygls.protocol import default_converter

Expand All @@ -42,30 +31,28 @@
class LanguageClient(BaseLanguageClient):
"""Used to drive language servers under test."""

def __init__(
self,
protocol_cls: Type[LanguageClientProtocol] = LanguageClientProtocol,
*args,
**kwargs,
):
super().__init__(
"pytest-lsp-client", __version__, protocol_cls=protocol_cls, *args, **kwargs
)
protocol: LanguageClientProtocol

def __init__(self, *args, **kwargs):
if "protocol_cls" not in kwargs:
kwargs["protocol_cls"] = LanguageClientProtocol

self.capabilities: Optional[ClientCapabilities] = None
super().__init__("pytest-lsp-client", __version__, *args, **kwargs)

self.capabilities: Optional[types.ClientCapabilities] = None
"""The client's capabilities."""

self.shown_documents: List[ShowDocumentParams] = []
self.shown_documents: List[types.ShowDocumentParams] = []
"""Used to keep track of the documents requested to be shown via a
``window/showDocument`` request."""

self.messages: List[ShowMessageParams] = []
self.messages: List[types.ShowMessageParams] = []
"""Holds any received ``window/showMessage`` requests."""

self.log_messages: List[LogMessageParams] = []
self.log_messages: List[types.LogMessageParams] = []
"""Holds any received ``window/logMessage`` requests."""

self.diagnostics: Dict[str, List[Diagnostic]] = {}
self.diagnostics: Dict[str, List[types.Diagnostic]] = {}
"""Used to hold any recieved diagnostics."""

self.error: Optional[Exception] = None
Expand All @@ -86,30 +73,34 @@ async def server_exit(self, server: asyncio.subprocess.Process):

stderr = ""
if server.stderr is not None:
stderr = await server.stderr.read()
stderr = stderr.decode("utf8")
stderr_bytes = await server.stderr.read()
stderr = stderr_bytes.decode("utf8")

loop = asyncio.get_running_loop()
loop.call_soon(
cancel_all_tasks,
f"Server process exited with return code: {server.returncode}\n{stderr}",
)

def report_server_error(self, error: Exception, source: Type[Exception]):
def report_server_error(
self, error: Exception, source: Union[PyglsError, JsonRpcException]
):
"""Called when the server does something unexpected, e.g. sending malformed
JSON."""
self.error = error
tb = "".join(traceback.format_exc())

message = f"{source.__name__}: {error}\n{tb}"
message = f"{source.__name__}: {error}\n{tb}" # type: ignore

loop = asyncio.get_running_loop()
loop.call_soon(cancel_all_tasks, message)

if self._stop_event:
self._stop_event.set()

async def initialize_session(self, params: InitializeParams) -> InitializeResult:
async def initialize_session(
self, params: types.InitializeParams
) -> types.InitializeResult:
"""Make an ``initialize`` request to a lanaguage server.
It will also automatically send an ``initialized`` notification once
Expand All @@ -135,7 +126,7 @@ async def initialize_session(self, params: InitializeParams) -> InitializeResult
params.process_id = os.getpid()

response = await self.initialize_async(params)
self.initialized(InitializedParams())
self.initialized(types.InitializedParams())

return response

Expand Down Expand Up @@ -186,32 +177,34 @@ def make_test_lsp_client() -> LanguageClient:
converter_factory=default_converter,
)

@client.feature(TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS)
def publish_diagnostics(client: LanguageClient, params: PublishDiagnosticsParams):
@client.feature(types.TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS)
def publish_diagnostics(
client: LanguageClient, params: types.PublishDiagnosticsParams
):
client.diagnostics[params.uri] = params.diagnostics

@client.feature(WINDOW_LOG_MESSAGE)
def log_message(client: LanguageClient, params: LogMessageParams):
@client.feature(types.WINDOW_LOG_MESSAGE)
def log_message(client: LanguageClient, params: types.LogMessageParams):
client.log_messages.append(params)

levels = [logger.error, logger.warning, logger.info, logger.debug]
levels[params.type.value - 1](params.message)

@client.feature(WINDOW_SHOW_MESSAGE)
@client.feature(types.WINDOW_SHOW_MESSAGE)
def show_message(client: LanguageClient, params):
client.messages.append(params)

@client.feature(WINDOW_SHOW_DOCUMENT)
@client.feature(types.WINDOW_SHOW_DOCUMENT)
def show_document(
client: LanguageClient, params: ShowDocumentParams
) -> ShowDocumentResult:
client: LanguageClient, params: types.ShowDocumentParams
) -> types.ShowDocumentResult:
client.shown_documents.append(params)
return ShowDocumentResult(success=True)
return types.ShowDocumentResult(success=True)

return client


def client_capabilities(client_spec: str) -> ClientCapabilities:
def client_capabilities(client_spec: str) -> types.ClientCapabilities:
"""Find the capabilities that correspond to the given client spec.
Parameters
Expand Down Expand Up @@ -241,4 +234,4 @@ def client_capabilities(client_spec: str) -> ClientCapabilities:

converter = get_converter()
capabilities = json.loads(filename.read_text())
return converter.structure(capabilities, ClientCapabilities)
return converter.structure(capabilities, types.ClientCapabilities)
3 changes: 2 additions & 1 deletion lib/pytest-lsp/pytest_lsp/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import textwrap
import typing
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
Expand Down Expand Up @@ -104,7 +105,7 @@ def get_fixture_arguments(
dict
The set of arguments to pass to the user's fixture function
"""
kwargs = {}
kwargs: Dict[str, Any] = {}
required_parameters = set(inspect.signature(fn).parameters.keys())

# Inject the 'request' fixture if requested
Expand Down
2 changes: 1 addition & 1 deletion lib/pytest-lsp/pytest_lsp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _handle_notification(self, method_name, params):
async def send_request_async(self, method, params=None):
result = await super().send_request_async(method, params)
check_result_against_client_capabilities(
self._server.capabilities, method, result
self._server.capabilities, method, result # type: ignore
)

return result
Expand Down

0 comments on commit aa5feff

Please sign in to comment.