Skip to content

Commit

Permalink
pytest-lsp: Add spec compliance check for workspace/configuration
Browse files Browse the repository at this point in the history
The `LanguageClientProtocol` class is now also able to perform
complicance checks on incoming requests from the server under test.
The first check implemented is making sure that the client has support
for `workspace/configuration` requests.
  • Loading branch information
alcarney committed Oct 23, 2023
1 parent c298616 commit f10277b
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 43 deletions.
123 changes: 96 additions & 27 deletions lib/pytest-lsp/pytest_lsp/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,49 +18,71 @@
from typing import Set
from typing import Union

from lsprotocol.types import COMPLETION_ITEM_RESOLVE
from lsprotocol.types import TEXT_DOCUMENT_COMPLETION
from lsprotocol.types import TEXT_DOCUMENT_DOCUMENT_LINK
from lsprotocol.types import ClientCapabilities
from lsprotocol.types import CompletionItem
from lsprotocol.types import CompletionList
from lsprotocol.types import DocumentLink
from lsprotocol.types import InsertTextFormat
from lsprotocol.types import MarkupContent
from lsprotocol import types
from pygls.capabilities import get_capability

logger = logging.getLogger(__name__)
ResultChecker = Callable[[ClientCapabilities, Any], None]
ParamsChecker = Callable[[types.ClientCapabilities, Any], None]
ResultChecker = Callable[[types.ClientCapabilities, Any], None]

PARAMS_CHECKS: Dict[str, ParamsChecker] = {}
RESULT_CHECKS: Dict[str, ResultChecker] = {}


class LspSpecificationWarning(UserWarning):
"""Warning raised when encountering results that fall outside the spec."""


def check_result_for(maybe_fn: Optional[ResultChecker] = None, *, method: str):
def check_result_for(*, method: str) -> Callable[[ResultChecker], ResultChecker]:
"""Define a result check."""

def defcheck(fn: ResultChecker):
if (existing := RESULT_CHECKS.get(method, None)) is not None:
raise ValueError(f"{fn!r} conflicts with existing check {existing!r}")

RESULT_CHECKS[method] = fn
return fn

if maybe_fn:
return defcheck(maybe_fn)
return defcheck


def check_params_of(*, method: str) -> Callable[[ParamsChecker], ParamsChecker]:
"""Define a params check."""

def defcheck(fn: ParamsChecker):
if (existing := PARAMS_CHECKS.get(method, None)) is not None:
raise ValueError(f"{fn!r} conflicts with existing check {existing!r}")

PARAMS_CHECKS[method] = fn
return fn

return defcheck


def check_result_against_client_capabilities(
capabilities: Optional[ClientCapabilities], method: str, result: Any
capabilities: Optional[types.ClientCapabilities], method: str, result: Any
):
"""Check that the given result respects the client's declared capabilities."""
"""Check that the given result respects the client's declared capabilities.
This will emit an ``LspSpecificationWarning`` if any issues are detected.
Parameters
----------
capabilities
The client's capabilities
method
The method name to validate the result of
result
The result to validate
"""

if capabilities is None:
raise RuntimeError("Client has not been initialized")

# Only run checks if the user provided some capabilities for the client.
if capabilities == ClientCapabilities():
if capabilities == types.ClientCapabilities():
return

result_checker = RESULT_CHECKS.get(method, None)
Expand All @@ -73,8 +95,43 @@ def check_result_against_client_capabilities(
warnings.warn(str(e), LspSpecificationWarning, stacklevel=4)


def check_params_against_client_capabilities(
capabilities: Optional[types.ClientCapabilities], method: str, params: Any
):
"""Check that the given params respect the client's declared capabilities.
This will emit an ``LspSpecificationWarning`` if any issues are detected.
Parameters
----------
capabilities
The client's capabilities
method
The method name to validate the result of
params
The params to validate
"""
if capabilities is None:
raise RuntimeError("Client has not been initialized")

# Only run checks if the user provided some capabilities for the client.
if capabilities == types.ClientCapabilities():
return

params_checker = PARAMS_CHECKS.get(method, None)
if params_checker is None:
return

try:
params_checker(capabilities, params)
except AssertionError as e:
warnings.warn(str(e), LspSpecificationWarning, stacklevel=2)


def check_completion_item(
item: CompletionItem,
item: types.CompletionItem,
commit_characters_support: bool,
documentation_formats: Set[str],
snippet_support: bool,
Expand All @@ -84,19 +141,19 @@ def check_completion_item(
if item.commit_characters:
assert commit_characters_support, "Client does not support commit characters"

if isinstance(item.documentation, MarkupContent):
if isinstance(item.documentation, types.MarkupContent):
kind = item.documentation.kind
message = f"Client does not support documentation format '{kind}'"
assert kind in documentation_formats, message

if item.insert_text_format == InsertTextFormat.Snippet:
if item.insert_text_format == types.InsertTextFormat.Snippet:
assert snippet_support, "Client does not support snippets."


@check_result_for(method=TEXT_DOCUMENT_COMPLETION)
@check_result_for(method=types.TEXT_DOCUMENT_COMPLETION)
def completion_items(
capabilities: ClientCapabilities,
result: Union[CompletionList, List[CompletionItem], None],
capabilities: types.ClientCapabilities,
result: Union[types.CompletionList, List[types.CompletionItem], None],
):
"""Ensure that the completion items returned from the server are compliant with the
spec and the client's declared capabilities."""
Expand All @@ -122,7 +179,7 @@ def completion_items(
False,
)

if isinstance(result, CompletionList):
if isinstance(result, types.CompletionList):
items = result.items
else:
items = result
Expand All @@ -136,8 +193,10 @@ def completion_items(
)


@check_result_for(method=COMPLETION_ITEM_RESOLVE)
def completion_item_resolve(capabilities: ClientCapabilities, item: CompletionItem):
@check_result_for(method=types.COMPLETION_ITEM_RESOLVE)
def completion_item_resolve(
capabilities: types.ClientCapabilities, item: types.CompletionItem
):
"""Ensure that the completion item returned from the server is compliant with the
spec and the client's declared capbabilities."""

Expand Down Expand Up @@ -167,9 +226,9 @@ def completion_item_resolve(capabilities: ClientCapabilities, item: CompletionIt
)


@check_result_for(method=TEXT_DOCUMENT_DOCUMENT_LINK)
@check_result_for(method=types.TEXT_DOCUMENT_DOCUMENT_LINK)
def document_links(
capabilities: ClientCapabilities, result: Optional[List[DocumentLink]]
capabilities: types.ClientCapabilities, result: Optional[List[types.DocumentLink]]
):
"""Ensure that the document links returned from the server are compliant with the
Spec and the client's declared capabilities."""
Expand All @@ -184,3 +243,13 @@ def document_links(
for item in result:
if item.tooltip:
assert tooltip_support, "Client does not support tooltips."


@check_params_of(method=types.WORKSPACE_CONFIGURATION)
def workspace_configuration(
capabilities: types.ClientCapabilities,
params: types.WorkspaceConfigurationParams,
):
"""Ensure that the client has support for ``workspace/configuration`` requests."""
is_supported = get_capability(capabilities, "workspace.configuration", False)
assert is_supported, "Client does not support 'workspace/configuration'"
95 changes: 79 additions & 16 deletions lib/pytest-lsp/pytest_lsp/protocol.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,103 @@
from __future__ import annotations

import asyncio
import logging
import typing
from concurrent.futures import Future

from lsprotocol.types import CANCEL_REQUEST
from pygls.exceptions import JsonRpcMethodNotFound
from pygls.protocol import LanguageServerProtocol

from .checks import check_params_against_client_capabilities
from .checks import check_result_against_client_capabilities

if typing.TYPE_CHECKING:
from .client import LanguageClient


logger = logging.getLogger(__name__)


class LanguageClientProtocol(LanguageServerProtocol):
"""An extended protocol class with extra methods that are useful for testing."""
"""An extended protocol class adding functionality useful for testing."""

_server: LanguageClient # type: ignore[assignment]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self._notification_futures = {}

def _handle_request(self, msg_id, method_name, params):
"""Wrap pygls' handle_request implementation. This will
- Check if the request from the server is compatible with the client's stated
capabilities.
"""
check_params_against_client_capabilities(
self._server.capabilities, method_name, params
)
return super()._handle_request(msg_id, method_name, params)

def _handle_notification(self, method_name, params):
if method_name == CANCEL_REQUEST:
self._handle_cancel_notification(params.id)
return
"""Wrap pygls' handle_notification implementation. This will
- Notify a future waiting on the notification, if applicable.
- Check the params to see if they are compatible with the client's stated
capabilities.
"""
future = self._notification_futures.pop(method_name, None)
if future:
future.set_result(params)

try:
handler = self._get_handler(method_name)
self._execute_notification(handler, params)
except (KeyError, JsonRpcMethodNotFound):
logger.warning("Ignoring notification for unknown method '%s'", method_name)
except Exception:
logger.exception(
"Failed to handle notification '%s': %s", method_name, params
)
super()._handle_notification(method_name, params)

async def send_request_async(self, method, params=None):
"""Wrap pygls' ``send_request_async`` implementation. This will
- Check the result to see if it's compatible with the client's stated
capabilities
Parameters
----------
method
The method name of the request to send
params
The associated parameters to go with the request
Returns
-------
Any
The response's result
"""
result = await super().send_request_async(method, params)
check_result_against_client_capabilities(
self._server.capabilities, method, result # type: ignore
)

return result

def wait_for_notification(self, method: str, callback=None):
def wait_for_notification(self, method: str, callback=None) -> Future:
"""Wait for a notification message with the given ``method``.
Parameters
----------
method
The method name to wait for
callback
If given, ``callback`` will be called with the notification message's
``params`` when recevied
Returns
-------
Future
A future that will resolve when the requested notification message is
recevied.
"""
future: Future = Future()
if callback:

Expand All @@ -60,5 +111,17 @@ def wrapper(future: Future):
return future

def wait_for_notification_async(self, method: str):
"""Wait for a notification message with the given ``method``.
Parameters
----------
method
The method name to wait for
Returns
-------
Any
The notification message's ``params``
"""
future = self.wait_for_notification(method)
return asyncio.wrap_future(future)
43 changes: 43 additions & 0 deletions lib/pytest-lsp/tests/test_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Any

import pytest
from lsprotocol import types

from pytest_lsp import checks


@pytest.mark.parametrize(
"capabilities,method,params,expected",
[
(
types.ClientCapabilities(
workspace=types.WorkspaceClientCapabilities(configuration=False)
),
types.WORKSPACE_CONFIGURATION,
types.WorkspaceConfigurationParams(items=[]),
"does not support 'workspace/configuration'",
),
],
)
def test_params_check_warning(
capabilities: types.ClientCapabilities, method: str, params: Any, expected: str
):
"""Ensure that parameter checks work as expected.
Parameters
----------
capabilities
The client's capabilities
method
The method name to check
params
The params to check
expected
The expected warning message
"""

with pytest.warns(checks.LspSpecificationWarning, match=expected):
checks.check_params_against_client_capabilities(capabilities, method, params)

0 comments on commit f10277b

Please sign in to comment.