From 6facf2698fd268a4e3f43bcfc5d304ead93a2544 Mon Sep 17 00:00:00 2001 From: Alex Carney Date: Wed, 6 Nov 2024 19:20:31 +0000 Subject: [PATCH] lsp-devtools: pyupgrade --py39-plus --- .../lsp_devtools/agent/__init__.py | 7 ++-- lib/lsp-devtools/lsp_devtools/agent/agent.py | 25 ++++++--------- lib/lsp-devtools/lsp_devtools/agent/client.py | 6 ++-- lib/lsp-devtools/lsp_devtools/agent/server.py | 16 ++++------ .../lsp_devtools/client/__init__.py | 9 +++--- .../lsp_devtools/client/editor/text_editor.py | 3 +- lib/lsp-devtools/lsp_devtools/database.py | 18 +++-------- .../lsp_devtools/handlers/__init__.py | 13 ++++---- lib/lsp-devtools/lsp_devtools/handlers/sql.py | 7 +--- .../lsp_devtools/inspector/__init__.py | 21 ++++++------ .../lsp_devtools/record/__init__.py | 18 +++++------ .../lsp_devtools/record/filters.py | 28 ++++++++-------- .../lsp_devtools/record/formatters.py | 32 ++++++++----------- .../lsp_devtools/record/visualize.py | 10 +++--- lib/lsp-devtools/pyproject.toml | 3 +- lib/lsp-devtools/tests/record/test_filters.py | 12 +++---- lib/lsp-devtools/tests/record/test_record.py | 4 +-- 17 files changed, 102 insertions(+), 130 deletions(-) diff --git a/lib/lsp-devtools/lsp_devtools/agent/__init__.py b/lib/lsp-devtools/lsp_devtools/agent/__init__.py index 14556dd..82c0cd7 100644 --- a/lib/lsp-devtools/lsp_devtools/agent/__init__.py +++ b/lib/lsp-devtools/lsp_devtools/agent/__init__.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import argparse import asyncio import subprocess import sys -from typing import List from .agent import Agent from .agent import RPCMessage @@ -31,7 +32,7 @@ async def forward_stderr(server: asyncio.subprocess.Process): sys.stderr.buffer.write(line) -async def main(args, extra: List[str]): +async def main(args, extra: list[str]): if extra is None: print("Missing server start command", file=sys.stderr) return 1 @@ -54,7 +55,7 @@ async def main(args, extra: List[str]): ) -def run_agent(args, extra: List[str]): +def run_agent(args, extra: list[str]): asyncio.run(main(args, extra)) diff --git a/lib/lsp-devtools/lsp_devtools/agent/agent.py b/lib/lsp-devtools/lsp_devtools/agent/agent.py index 2c9ea1d..4d0e26f 100644 --- a/lib/lsp-devtools/lsp_devtools/agent/agent.py +++ b/lib/lsp-devtools/lsp_devtools/agent/agent.py @@ -15,14 +15,10 @@ import attrs if typing.TYPE_CHECKING: + from collections.abc import Coroutine from typing import Any from typing import BinaryIO from typing import Callable - from typing import Coroutine - from typing import Dict - from typing import Optional - from typing import Set - from typing import Tuple from typing import Union MessageHandler = Callable[[bytes], Union[None, Coroutine[Any, Any, None]]] @@ -35,9 +31,9 @@ class RPCMessage: """A Json-RPC message.""" - headers: Dict[str, str] + headers: dict[str, str] - body: Dict[str, Any] + body: dict[str, Any] def __getitem__(self, key: str): return self.headers[key] @@ -46,8 +42,8 @@ def __getitem__(self, key: str): def parse_rpc_message(data: bytes) -> RPCMessage: """Parse a JSON-RPC message from the given set of bytes.""" - headers: Dict[str, str] = {} - body: Optional[Dict[str, Any]] = None + headers: dict[str, str] = {} + body: dict[str, Any] | None = None headers_complete = False for line in data.split(b"\r\n"): @@ -118,7 +114,7 @@ async def aio_readline(reader: asyncio.StreamReader, message_handler: MessageHan async def get_streams( stdin, stdout -) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: +) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: """Convert blocking stdin/stdout streams into async streams.""" loop = asyncio.get_running_loop() @@ -150,9 +146,9 @@ def __init__( self.handler = handler self.session_id = str(uuid4()) - self._tasks: Set[asyncio.Task] = set() - self.reader: Optional[asyncio.StreamReader] = None - self.writer: Optional[asyncio.StreamWriter] = None + self._tasks: set[asyncio.Task] = set() + self.reader: asyncio.StreamReader | None = None + self.writer: asyncio.StreamWriter | None = None async def start(self): # Get async versions of stdin/stdout @@ -227,8 +223,7 @@ async def stop(self): self.server.kill() args = {} - if sys.version_info >= (3, 9): - args["msg"] = "lsp-devtools agent is stopping." + args["msg"] = "lsp-devtools agent is stopping." # Cancel the tasks connecting client to server for task in self._tasks: diff --git a/lib/lsp-devtools/lsp_devtools/agent/client.py b/lib/lsp-devtools/lsp_devtools/agent/client.py index 1b65cfc..87e9d4c 100644 --- a/lib/lsp-devtools/lsp_devtools/agent/client.py +++ b/lib/lsp-devtools/lsp_devtools/agent/client.py @@ -12,8 +12,6 @@ if typing.TYPE_CHECKING: from typing import Any - from typing import List - from typing import Optional # from websockets.client import WebSocketClientProtocol @@ -45,14 +43,14 @@ def __init__(self): protocol_cls=AgentProtocol, converter_factory=default_converter ) self.connected = False - self._buffer: List[bytes] = [] + self._buffer: list[bytes] = [] def _report_server_error(self, error, source): # Bail on error # TODO: Report the actual error somehow self._stop_event.set() - def feature(self, feature_name: str, options: Optional[Any] = None): + def feature(self, feature_name: str, options: Any | None = None): return self.protocol.fm.feature(feature_name, options) # TODO: Upstream this... or at least something equivalent. diff --git a/lib/lsp-devtools/lsp_devtools/agent/server.py b/lib/lsp-devtools/lsp_devtools/agent/server.py index 17805d7..f74a47a 100644 --- a/lib/lsp-devtools/lsp_devtools/agent/server.py +++ b/lib/lsp-devtools/lsp_devtools/agent/server.py @@ -14,8 +14,6 @@ if typing.TYPE_CHECKING: from typing import Any - from typing import List - from typing import Optional from lsp_devtools.agent.agent import MessageHandler @@ -29,8 +27,8 @@ class AgentServer(Server): def __init__( self, *args, - logger: Optional[logging.Logger] = None, - handler: Optional[MessageHandler] = None, + logger: logging.Logger | None = None, + handler: MessageHandler | None = None, **kwargs, ): if "protocol_cls" not in kwargs: @@ -43,11 +41,11 @@ def __init__( self.logger = logger or logging.getLogger(__name__) self.handler = handler or self.lsp.data_received - self.db: Optional[Database] = None + self.db: Database | None = None - self._client_buffer: List[str] = [] - self._server_buffer: List[str] = [] - self._tcp_server: Optional[asyncio.Task] = None + self._client_buffer: list[str] = [] + self._server_buffer: list[str] = [] + self._tcp_server: asyncio.Task | None = None def _report_server_error(self, exc: Exception, source): """Report internal server errors.""" @@ -55,7 +53,7 @@ def _report_server_error(self, exc: Exception, source): self.logger.error("%s: %s", type(exc).__name__, exc) self.logger.debug("%s", tb) - def feature(self, feature_name: str, options: Optional[Any] = None): + def feature(self, feature_name: str, options: Any | None = None): return self.lsp.fm.feature(feature_name, options) async def start_tcp(self, host: str, port: int) -> None: # type: ignore[override] diff --git a/lib/lsp-devtools/lsp_devtools/client/__init__.py b/lib/lsp-devtools/lsp_devtools/client/__init__.py index 42ba9c0..ff7d951 100644 --- a/lib/lsp-devtools/lsp_devtools/client/__init__.py +++ b/lib/lsp-devtools/lsp_devtools/client/__init__.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import argparse import asyncio import logging import os import pathlib -from typing import List from uuid import uuid4 import platformdirs @@ -54,7 +55,7 @@ class LSPClient(App): ] def __init__( - self, db: Database, server_command: List[str], session: str, *args, **kwargs + self, db: Database, server_command: list[str], session: str, *args, **kwargs ): super().__init__(*args, **kwargs) @@ -65,7 +66,7 @@ def __init__( self.server_command = server_command self.lsp_client = LanguageClient() - self._async_tasks: List[asyncio.Task] = [] + self._async_tasks: list[asyncio.Task] = [] def compose(self) -> ComposeResult: message_viewer = MessageViewer("") @@ -140,7 +141,7 @@ async def action_quit(self): await super().action_quit() -def client(args, extra: List[str]): +def client(args, extra: list[str]): if len(extra) == 0: raise ValueError("Missing server command.") diff --git a/lib/lsp-devtools/lsp_devtools/client/editor/text_editor.py b/lib/lsp-devtools/lsp_devtools/client/editor/text_editor.py index 74b256c..ee051c0 100644 --- a/lib/lsp-devtools/lsp_devtools/client/editor/text_editor.py +++ b/lib/lsp-devtools/lsp_devtools/client/editor/text_editor.py @@ -3,7 +3,6 @@ import contextlib import pathlib import typing -from typing import List from typing import Union from lsprotocol import types @@ -15,7 +14,7 @@ if typing.TYPE_CHECKING: from lsp_devtools.client.lsp import LanguageClient -CompletionResult = Union[List[types.CompletionItem], types.CompletionList, None] +CompletionResult = Union[list[types.CompletionItem], types.CompletionList, None] # TODO: Refactor to diff --git a/lib/lsp-devtools/lsp_devtools/database.py b/lib/lsp-devtools/lsp_devtools/database.py index d1aef89..14e6641 100644 --- a/lib/lsp-devtools/lsp_devtools/database.py +++ b/lib/lsp-devtools/lsp_devtools/database.py @@ -2,13 +2,10 @@ import json import logging import pathlib -import sys from contextlib import asynccontextmanager +from importlib import resources from typing import Any -from typing import Dict -from typing import List from typing import Optional -from typing import Set import aiosqlite from textual.app import App @@ -16,11 +13,6 @@ from lsp_devtools.handlers import LspMessage -if sys.version_info < (3, 9): - import importlib_resources as resources -else: - from importlib import resources # type: ignore[no-redef] - class Database: """Controls access to the backing sqlite database.""" @@ -32,7 +24,7 @@ def __init__(self, dbpath: Optional[pathlib.Path] = None): self.dbpath = dbpath or ":memory:" self.db: Optional[aiosqlite.Connection] = None self.app: Optional[App] = None - self._handlers: Dict[str, set] = {} + self._handlers: dict[str, set] = {} async def close(self): if self.db: @@ -106,8 +98,8 @@ async def get_messages( """ base_query = "SELECT rowid, * FROM protocol" - where: List[str] = [] - parameters: List[Any] = [] + where: list[str] = [] + parameters: list[Any] = [] if session: where.append("session = ?") @@ -151,7 +143,7 @@ class DatabaseLogHandler(logging.Handler): def __init__(self, db: Database, *args, **kwargs): super().__init__(*args, **kwargs) self.db = db - self._tasks: Set[asyncio.Task] = set() + self._tasks: set[asyncio.Task] = set() def emit(self, record: logging.LogRecord): body = json.loads(record.args[0]) # type: ignore diff --git a/lib/lsp-devtools/lsp_devtools/handlers/__init__.py b/lib/lsp-devtools/lsp_devtools/handlers/__init__.py index 3df7596..0380de3 100644 --- a/lib/lsp-devtools/lsp_devtools/handlers/__init__.py +++ b/lib/lsp-devtools/lsp_devtools/handlers/__init__.py @@ -8,10 +8,9 @@ import attrs if typing.TYPE_CHECKING: + from collections.abc import Mapping from typing import Any from typing import Literal - from typing import Mapping - from typing import Optional MessageSource = Literal["client", "server"] @@ -37,19 +36,19 @@ class LspMessage: source: MessageSource """Indicates if the message was sent by the client or the server.""" - id: Optional[str] + id: str | None """The ``id`` field, if it exists.""" - method: Optional[str] + method: str | None """The ``method`` field, if it exists.""" - params: Optional[Any] = attrs.field(converter=maybe_json) + params: Any | None = attrs.field(converter=maybe_json) """The ``params`` field, if it exists.""" - result: Optional[Any] = attrs.field(converter=maybe_json) + result: Any | None = attrs.field(converter=maybe_json) """The ``result`` field, if it exists.""" - error: Optional[Any] = attrs.field(converter=maybe_json) + error: Any | None = attrs.field(converter=maybe_json) """The ``error`` field, if it exists.""" @classmethod diff --git a/lib/lsp-devtools/lsp_devtools/handlers/sql.py b/lib/lsp-devtools/lsp_devtools/handlers/sql.py index b06b26a..5fdad0f 100644 --- a/lib/lsp-devtools/lsp_devtools/handlers/sql.py +++ b/lib/lsp-devtools/lsp_devtools/handlers/sql.py @@ -1,17 +1,12 @@ import json import pathlib import sqlite3 -import sys from contextlib import closing +from importlib import resources from lsp_devtools.handlers import LspHandler from lsp_devtools.handlers import LspMessage -if sys.version_info < (3, 9): - import importlib_resources as resources -else: - from importlib import resources # type: ignore[no-redef] - class SqlHandler(LspHandler): """A logging handler that sends log records to a SQL database""" diff --git a/lib/lsp-devtools/lsp_devtools/inspector/__init__.py b/lib/lsp-devtools/lsp_devtools/inspector/__init__.py index 15ba47e..414f662 100644 --- a/lib/lsp-devtools/lsp_devtools/inspector/__init__.py +++ b/lib/lsp-devtools/lsp_devtools/inspector/__init__.py @@ -1,12 +1,11 @@ +from __future__ import annotations + import argparse import asyncio import logging import pathlib +import typing from functools import partial -from typing import Any -from typing import Dict -from typing import List -from typing import Optional import platformdirs from rich.highlighter import ReprHighlighter @@ -28,6 +27,10 @@ from lsp_devtools.database import Database from lsp_devtools.handlers import LspMessage +if typing.TYPE_CHECKING: + from typing import Any + + logger = logging.getLogger(__name__) @@ -70,9 +73,9 @@ def __init__(self, db: Database, viewer: MessageViewer, session=None): self.db = db - self.rpcdata: Dict[int, LspMessage] = {} + self.rpcdata: dict[int, LspMessage] = {} self.max_row = 0 - self.session: Optional[str] = session + self.session: str | None = session self.viewer = viewer @@ -113,7 +116,7 @@ def show_object(self, event: DataTable.RowHighlighted): def _get_query_params(self): """Return the set of query parameters to use when populating the table.""" - query: Dict[str, Any] = dict(max_row=self.max_row) + query: dict[str, Any] = dict(max_row=self.max_row) if self.session is not None: query["session"] = self.session @@ -156,7 +159,7 @@ def __init__(self, db: Database, server: AgentServer, *args, **kwargs): self.server = server """Server used to manage connections to lsp servers.""" - self._async_tasks: List[asyncio.Task] = [] + self._async_tasks: list[asyncio.Task] = [] def compose(self) -> ComposeResult: yield Header() @@ -215,7 +218,7 @@ async def handle_message(db: Database, data: bytes): await db.add_message(session, timestamp, source, rpc.body) -def inspector(args, extra: List[str]): +def inspector(args, extra: list[str]): db = Database(args.dbpath) server = AgentServer(handler=partial(handle_message, db)) diff --git a/lib/lsp-devtools/lsp_devtools/record/__init__.py b/lib/lsp-devtools/lsp_devtools/record/__init__.py index 7537e8b..79e13ac 100644 --- a/lib/lsp-devtools/lsp_devtools/record/__init__.py +++ b/lib/lsp-devtools/lsp_devtools/record/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import asyncio import json @@ -5,8 +7,6 @@ import pathlib from functools import partial from logging import LogRecord -from typing import List -from typing import Optional from rich.console import Console from rich.console import ConsoleRenderable @@ -41,9 +41,9 @@ def render( self, *, record: logging.LogRecord, - traceback: Optional[Traceback], - message_renderable: "ConsoleRenderable", - ) -> "ConsoleRenderable": + traceback: Traceback | None, + message_renderable: ConsoleRenderable, + ) -> ConsoleRenderable: # Delegate most of the rendering to the base RichHandler class. res = super().render( record=record, traceback=traceback, message_renderable=message_renderable @@ -97,7 +97,7 @@ def setup_stdout_output(args, logger: logging.Logger, console: Console): logger.propagate = False -def setup_file_output(args, logger: logging.Logger, console: Optional[Console] = None): +def setup_file_output(args, logger: logging.Logger, console: Console | None = None): """Log messages to a file.""" handler = logging.FileHandler(filename=str(args.to_file)) handler.setLevel(logging.INFO) @@ -122,9 +122,7 @@ def setup_file_output(args, logger: logging.Logger, console: Optional[Console] = logger.propagate = False -def setup_sqlite_output( - args, logger: logging.Logger, console: Optional[Console] = None -): +def setup_sqlite_output(args, logger: logging.Logger, console: Console | None = None): """Log messages to SQLite.""" handler = SqlHandler(args.to_sqlite) handler.setLevel(logging.INFO) @@ -158,7 +156,7 @@ def log_message(logger: logging.Logger, message: bytes): logger.info("%s", rpc.body, extra=rpc.headers) -def start_recording(args, extra: List[str]): +def start_recording(args, extra: list[str]): logger = logging.getLogger("lsp_devtools") rpc_logger = logging.getLogger(__name__) diff --git a/lib/lsp-devtools/lsp_devtools/record/filters.py b/lib/lsp-devtools/lsp_devtools/record/filters.py index 83a0e2f..6cd7e3a 100644 --- a/lib/lsp-devtools/lsp_devtools/record/filters.py +++ b/lib/lsp-devtools/lsp_devtools/record/filters.py @@ -1,17 +1,19 @@ +from __future__ import annotations + import logging -from typing import Dict -from typing import Literal -from typing import Set -from typing import Union +import typing import attrs from .formatters import FormatString -logger = logging.getLogger(__name__) +if typing.TYPE_CHECKING: + from typing import Literal -MessageSource = Literal["client", "server", "both"] -MessageType = Literal["request", "response", "result", "error", "notification"] + MessageSource = Literal["client", "server", "both"] + MessageType = Literal["request", "response", "result", "error", "notification"] + +logger = logging.getLogger(__name__) @attrs.define @@ -21,16 +23,16 @@ class LSPFilter(logging.Filter): message_source: MessageSource = attrs.field(default="both") """Only include messages from the given source.""" - include_message_types: Set[MessageType] = attrs.field(factory=set, converter=set) + include_message_types: set[MessageType] = attrs.field(factory=set, converter=set) """Only include the given message types.""" - exclude_message_types: Set[MessageType] = attrs.field(factory=set, converter=set) + exclude_message_types: set[MessageType] = attrs.field(factory=set, converter=set) """Exclude the given message types.""" - include_methods: Set[str] = attrs.field(factory=set, converter=set) + include_methods: set[str] = attrs.field(factory=set, converter=set) """Only include messages associated with the given method.""" - exclude_methods: Set[str] = attrs.field(factory=set, converter=set) + exclude_methods: set[str] = attrs.field(factory=set, converter=set) """Exclude messages associated with the given method.""" formatter: FormatString = attrs.field( @@ -39,7 +41,7 @@ class LSPFilter(logging.Filter): ) # type: ignore """Format messages according to the given string""" - _response_method_map: Dict[Union[int, str], str] = attrs.field(factory=dict) + _response_method_map: dict[int | str, str] = attrs.field(factory=dict) """Used to determine the method for response messages""" def filter(self, record: logging.LogRecord) -> bool: @@ -95,7 +97,7 @@ def _get_message_method(self, message_type: str, message: dict) -> str: return self._response_method_map[message["id"]] -def message_matches_type(message_type: str, types: Set[MessageType]) -> bool: +def message_matches_type(message_type: str, types: set[MessageType]) -> bool: """Determine if the type of message is included in the given set of types""" if message_type == "result": diff --git a/lib/lsp-devtools/lsp_devtools/record/formatters.py b/lib/lsp-devtools/lsp_devtools/record/formatters.py index 1a628c4..ce42e8a 100644 --- a/lib/lsp-devtools/lsp_devtools/record/formatters.py +++ b/lib/lsp-devtools/lsp_devtools/record/formatters.py @@ -1,25 +1,19 @@ +from __future__ import annotations + import json import re +import typing +from functools import cache from functools import partial -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union import lsprotocol.types -try: - from functools import cache -except ImportError: - from functools import lru_cache - - cache = lru_cache(None) +if typing.TYPE_CHECKING: + from typing import Any + from typing import Callable -def format_json(obj: dict, *, indent: Union[str, int, None] = 2) -> str: +def format_json(obj: dict, *, indent: str | int | None = 2) -> str: if isinstance(obj, str): return obj @@ -34,7 +28,7 @@ def format_range(range_: dict) -> str: return f"{format_position(range_['start'])}-{format_position(range_['end'])}" -FORMATTERS: Dict[str, Callable[[Any], str]] = { +FORMATTERS: dict[str, Callable[[Any], str]] = { "position": format_position, "range": format_range, "json": format_json, @@ -73,7 +67,7 @@ def __init__(self, accessor: str, formatter: Callable[[Any], str]): def __repr__(self): return f'Value(accessor="{self.accessor}", formatter={self.formatter})' - def format(self, message: dict, accessor: Optional[str] = None) -> str: + def format(self, message: dict, accessor: str | None = None) -> str: """Convert a message to a string according to the current accessor and formatter.""" @@ -110,7 +104,7 @@ def format(self, message: dict, accessor: Optional[str] = None) -> str: @cache -def get_separator_index(separator: str) -> Tuple[str, Union[int, slice, None]]: +def get_separator_index(separator: str) -> tuple[str, int | slice | None]: if not separator: return "\n", None @@ -134,7 +128,7 @@ def get_separator(sep: str) -> str: return sep.replace("\\n", "\n").replace("\\t", "\t") -def get_index(idx: str) -> Union[int, slice, None]: +def get_index(idx: str) -> int | slice | None: try: return int(idx) except ValueError: @@ -165,7 +159,7 @@ def __init__(self, pattern: str): def _parse(self): idx = 0 - parts: List[Union[str, Value]] = [] + parts: list[str | Value] = [] for match in self.VARIABLE.finditer(self.pattern): start, end = match.span() diff --git a/lib/lsp-devtools/lsp_devtools/record/visualize.py b/lib/lsp-devtools/lsp_devtools/record/visualize.py index 4fc4daa..47f228e 100644 --- a/lib/lsp-devtools/lsp_devtools/record/visualize.py +++ b/lib/lsp-devtools/lsp_devtools/record/visualize.py @@ -9,8 +9,6 @@ from rich.style import Style if typing.TYPE_CHECKING: - from typing import List - from typing import Optional from rich.console import Console from rich.console import ConsoleOptions @@ -59,19 +57,19 @@ class PacketPipeColumn(progress.ProgressColumn): """Visualizes messages sent between client and server as "packets".""" def __init__( - self, duration: float = 1.0, table_column: Optional[Column] = None + self, duration: float = 1.0, table_column: Column | None = None ) -> None: self.client_count = 0 self.server_count = 0 - self.server_times: List[float] = [] - self.client_times: List[float] = [] + self.server_times: list[float] = [] + self.client_times: list[float] = [] # How long it should take for a packet to propogate. self.duration = duration super().__init__(table_column) - def _update_packets(self, task: progress.Task, source: str) -> List[float]: + def _update_packets(self, task: progress.Task, source: str) -> list[float]: """Update the packet positions for the given message source. Parameters diff --git a/lib/lsp-devtools/pyproject.toml b/lib/lsp-devtools/pyproject.toml index 531c50f..ed0c787 100644 --- a/lib/lsp-devtools/pyproject.toml +++ b/lib/lsp-devtools/pyproject.toml @@ -46,9 +46,8 @@ skip_covered = true sort = "Cover" [tool.pyright] -venv = ".env" include = ["lsp_devtools"] -pythonVersion = "3.8" +pythonVersion = "3.9" [tool.towncrier] filename = "CHANGES.md" diff --git a/lib/lsp-devtools/tests/record/test_filters.py b/lib/lsp-devtools/tests/record/test_filters.py index 6978f6b..5ec9131 100644 --- a/lib/lsp-devtools/tests/record/test_filters.py +++ b/lib/lsp-devtools/tests/record/test_filters.py @@ -91,7 +91,7 @@ def test_filter_message_source(filter_source: str, message_source: str, expected ), ], ) -def test_filter_included_message_types(message: dict, setup: Tuple[List[str], bool]): +def test_filter_included_message_types(message: dict, setup: tuple[list[str], bool]): """Ensure that we can filter messages by listing the types we DO want to see.""" message_types, expected = setup @@ -164,7 +164,7 @@ def test_filter_included_message_types(message: dict, setup: Tuple[List[str], bo ), ], ) -def test_filter_excluded_message_types(message: dict, setup: Tuple[List[str], bool]): +def test_filter_excluded_message_types(message: dict, setup: tuple[list[str], bool]): """Ensure that we can filter messages by listing the types we DO NOT want to see.""" message_types, expected = setup @@ -202,7 +202,7 @@ def test_filter_excluded_message_types(message: dict, setup: Tuple[List[str], bo ), ], ) -def test_filter_included_method(message: dict, setup: Tuple[List[str], bool]): +def test_filter_included_method(message: dict, setup: tuple[list[str], bool]): """Ensure that we can filter messages by listing the methods we wish to see.""" methods, expected = setup @@ -247,7 +247,7 @@ def test_filter_included_method(message: dict, setup: Tuple[List[str], bool]): ], ) def test_filter_included_method_response_message( - response: dict, setup: Tuple[List[str], str, bool] + response: dict, setup: tuple[list[str], str, bool] ): """Ensure that we can filter response message by listing the methods we wish to see.""" @@ -292,7 +292,7 @@ def test_filter_included_method_response_message( ), ], ) -def test_filter_excluded_method(message: dict, setup: Tuple[List[str], bool]): +def test_filter_excluded_method(message: dict, setup: tuple[list[str], bool]): """Ensure that we can filter messages by listing the methods we don't wish to see.""" @@ -338,7 +338,7 @@ def test_filter_excluded_method(message: dict, setup: Tuple[List[str], bool]): ], ) def test_filter_excluded_method_response_message( - response: dict, setup: Tuple[List[str], str, bool] + response: dict, setup: tuple[list[str], str, bool] ): """Ensure that we can filter response message by listing the methods we dont' wish to see.""" diff --git a/lib/lsp-devtools/tests/record/test_record.py b/lib/lsp-devtools/tests/record/test_record.py index 36078df..53e2ada 100644 --- a/lib/lsp-devtools/tests/record/test_record.py +++ b/lib/lsp-devtools/tests/record/test_record.py @@ -89,8 +89,8 @@ def test_file_output( tmp_path: pathlib.Path, record: argparse.ArgumentParser, logger: logging.Logger, - args: List[str], - messages: List[Dict[str, Any]], + args: list[str], + messages: list[dict[str, Any]], expected: str, ): """Ensure that we can log to files correctly.