diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8e6055be..71934b6d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -9,7 +9,7 @@ on: pull_request: branches: - master - - feature-2.0 + - feat/3.0-release jobs: simple-checks: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ee74939a..f328a052 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ repos: hooks: - id: pyupgrade args: - - --py38-plus + - --py311-plus files: "tavern/.*" - repo: https://github.com/rhysd/actionlint rev: v1.6.26 diff --git a/docs/source/basics.md b/docs/source/basics.md index e0c37934..fb8200f4 100644 --- a/docs/source/basics.md +++ b/docs/source/basics.md @@ -1520,7 +1520,7 @@ third block must start with 4 and the third block must start with 8, 9, "A", or ``` This is using the `!re_fullmatch` variant of the tag - this calls -[`re.fullmatch`](https://docs.python.org/3.8/library/re.html#re.fullmatch) under +[`re.fullmatch`](https://docs.python.org/3.11/library/re.html#re.fullmatch) under the hood, which means that the regex given needs to match the _entire_ part of the response that is being checked for it to pass. There is also `!re_search` which will pass if it matches _part_ of the thing being checked, or `!re_match` diff --git a/docs/source/index.md b/docs/source/index.md index 105814a8..85203a80 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -4,7 +4,7 @@ Tavern is an advanced pytest based API testing framework for HTTP, MQTT or other protocols. Note that Tavern **only** supports Python 3.4 and up. At the time of writing we -test against Python 3.8-3.10. Python 2 is now **unsupported**. +test against Python 3.11. Python 2 is now **unsupported**. ## Why Tavern diff --git a/pyproject.toml b/pyproject.toml index ef10537e..b27705c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,15 +9,13 @@ classifiers = [ "Intended Audience :: Developers", "Framework :: Pytest", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Topic :: Utilities", "Topic :: Software Development :: Testing", "License :: OSI Approved :: MIT License", ] +requires-python = ">=3.11" keywords = ["testing", "pytest"] @@ -37,8 +35,6 @@ dependencies = [ "stevedore>=4,<5", ] -requires-python = ">=3.10" - [[project.authors]] name = "Michael Boulton" @@ -122,7 +118,7 @@ paho-mqtt = "tavern._plugins.mqtt.tavernhook" grpc = "tavern._plugins.grpc.tavernhook" [tool.mypy] -python_version = 3.8 +python_version = 3.11 ignore_missing_imports = true [tool.coverage.run] @@ -172,7 +168,7 @@ ignore = [ ] select = ["E", "F", "B", "W", "I", "S", "C4", "ICN", "T20", "PLE", "RUF", "SIM105", "PL"] # Look at: UP -target-version = "py38" +target-version = "py311" extend-exclude = [ "tests/unit/tavern_grpc/test_services_pb2.py", "tests/unit/tavern_grpc/test_services_pb2.pyi", diff --git a/scripts/smoke.bash b/scripts/smoke.bash index aa9762f6..ab25f8a5 100755 --- a/scripts/smoke.bash +++ b/scripts/smoke.bash @@ -5,8 +5,7 @@ set -ex pre-commit run ruff --all-files || true pre-commit run ruff-format --all-files || true -tox --parallel -c tox.ini \ - -e py3check +tox --parallel -c tox.ini -e py3check || true tox --parallel -c tox.ini \ -e py3mypy diff --git a/tavern/_core/dict_util.py b/tavern/_core/dict_util.py index a47f5939..4c653598 100644 --- a/tavern/_core/dict_util.py +++ b/tavern/_core/dict_util.py @@ -4,7 +4,8 @@ import os import re import string -from typing import Any, Dict, List, Mapping, Union +from collections.abc import Collection, Iterable, Mapping, Sequence +from typing import Any, TypeVar import box import jmespath @@ -22,10 +23,10 @@ from .formatted_str import FormattedString from .strict_util import StrictSetting, StrictSettingKinds, extract_strict_setting -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def _check_and_format_values(to_format, box_vars: Mapping[str, Any]) -> str: +def _check_and_format_values(to_format: str, box_vars: Mapping[str, Any]) -> str: formatter = string.Formatter() would_format = formatter.parse(to_format) @@ -93,8 +94,8 @@ def _attempt_find_include(to_format: str, box_vars: box.Box): def format_keys( - val, - variables: Mapping, + val: TypeConvertToken | str | dict | list | tuple | Mapping | set, + variables: Mapping | Box, *, no_double_format: bool = True, dangerously_ignore_string_format_errors: bool = False, @@ -102,7 +103,7 @@ def format_keys( """recursively format a dictionary with the given values Args: - val: Input dictionary to format + val: Input to format variables: Dictionary of keys to format it with no_double_format: Whether to use the 'inner formatted string' class to avoid double formatting This is required if passing something via pytest-xdist, such as markers: @@ -130,7 +131,7 @@ def format_keys( # formatted = {key: format_keys(val[key], box_vars) for key in val} for key in val: formatted[key] = format_keys_(val[key], box_vars) - elif isinstance(val, (list, tuple)): + elif isinstance(val, (list, tuple, set)): formatted = [format_keys_(item, box_vars) for item in val] # type: ignore elif isinstance(formatted, FormattedString): logger.debug("Already formatted %s, not double-formatting", formatted) @@ -156,7 +157,7 @@ def format_keys( return formatted -def recurse_access_key(data, query: str): +def recurse_access_key(data: dict | list[str] | Mapping, query: str): """ Search for something in the given data using the given query. @@ -168,8 +169,8 @@ def recurse_access_key(data, query: str): 'c' Args: - data (dict, list): Data to search in - query (str): Query to run + data: Data to search in + query: Query to run Returns: object: Whatever was found by the search @@ -181,7 +182,7 @@ def recurse_access_key(data, query: str): logger.error("Error parsing JMES query") try: - _deprecated_recurse_access_key(data, query.split(".")) + _deprecated_recurse_access_key(data, query.split(".")) # type:ignore except (IndexError, KeyError): logger.debug("Nothing found searching using old method") else: @@ -195,7 +196,7 @@ def recurse_access_key(data, query: str): return from_jmespath -def _deprecated_recurse_access_key(current_val, keys): +def _deprecated_recurse_access_key(current_val: dict, keys: list[str]): """Given a list of keys and a dictionary, recursively access the dicionary using the keys until we find the key its looking for @@ -209,8 +210,8 @@ def _deprecated_recurse_access_key(current_val, keys): 'c' Args: - current_val (dict): current dictionary we have recursed into - keys (list): list of str/int of subkeys + current_val: current dictionary we have recursed into + keys: list of str/int of subkeys Raises: IndexError: list index not found in data @@ -224,7 +225,7 @@ def _deprecated_recurse_access_key(current_val, keys): if not keys: return current_val else: - current_key = keys.pop(0) + current_key: str | int = keys.pop(0) with contextlib.suppress(ValueError): current_key = int(current_key) @@ -241,7 +242,7 @@ def _deprecated_recurse_access_key(current_val, keys): raise -def deep_dict_merge(initial_dct: Dict, merge_dct: Mapping) -> dict: +def deep_dict_merge(initial_dct: dict, merge_dct: Mapping) -> dict: """Recursive dict merge. Instead of updating only top-level keys, dict_merge recurses down into dicts nested to an arbitrary depth and returns the merged dict. Keys values present in merge_dct take @@ -266,12 +267,15 @@ def deep_dict_merge(initial_dct: Dict, merge_dct: Mapping) -> dict: return dct -def check_expected_keys(expected, actual) -> None: +_CanCheck = Sequence | Mapping | set | Collection + + +def check_expected_keys(expected: _CanCheck, actual: _CanCheck) -> None: """Check that a set of expected keys is a superset of the actual keys Args: - expected (list, set, dict): keys we expect - actual (list, set, dict): keys we have got from the input + expected: keys we expect + actual: keys we have got from the input Raises: UnexpectedKeysError: If not actual <= expected @@ -289,7 +293,7 @@ def check_expected_keys(expected, actual) -> None: raise exceptions.UnexpectedKeysError(msg) -def yield_keyvals(block): +def yield_keyvals(block: _CanCheck) -> Iterable[tuple[list[str], str, str]]: """Return indexes, keys and expected values for matching recursive keys Given a list or dict, return a 3-tuple of the 'split' key (key split on @@ -321,10 +325,10 @@ def yield_keyvals(block): (['2'], '2', 'c') Args: - block (dict, list): input matches + block: input matches Yields: - (list, str, str): key split on dots, key, expected value + key split on dots, key, expected value """ if isinstance(block, dict): for joined_key, expected_val in block.items(): @@ -336,10 +340,13 @@ def yield_keyvals(block): yield [sidx], sidx, val +T = TypeVar("T", Mapping, set, Sequence, Collection) + + def check_keys_match_recursive( - expected_val: Any, - actual_val: Any, - keys: List[Union[str, int]], + expected_val: T, + actual_val: T, + keys: list[str | int], strict: StrictSettingKinds = True, ) -> None: """Utility to recursively check response values @@ -443,7 +450,7 @@ def _format_err(which): raise exceptions.KeyMismatchError(msg) from e if isinstance(expected_val, dict): - akeys = set(actual_val.keys()) + akeys = set(actual_val.keys()) # type:ignore ekeys = set(expected_val.keys()) if akeys != ekeys: @@ -481,7 +488,10 @@ def _format_err(which): for key in to_recurse: try: check_keys_match_recursive( - expected_val[key], actual_val[key], keys + [key], strict + expected_val[key], + actual_val[key], # type:ignore + keys + [key], + strict, ) except KeyError: logger.debug( diff --git a/tavern/_core/exceptions.py b/tavern/_core/exceptions.py index 9c825007..f05f1c60 100644 --- a/tavern/_core/exceptions.py +++ b/tavern/_core/exceptions.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from tavern._core.pytest.config import TestConfig @@ -15,7 +15,7 @@ class TavernException(Exception): test_block_config: config for stage """ - stage: Optional[Dict] + stage: dict | None test_block_config: Optional["TestConfig"] is_final: bool = False diff --git a/tavern/_core/extfunctions.py b/tavern/_core/extfunctions.py index ce2b972b..691238d6 100644 --- a/tavern/_core/extfunctions.py +++ b/tavern/_core/extfunctions.py @@ -1,7 +1,8 @@ import functools import importlib import logging -from typing import Any, List, Mapping, Optional +from collections.abc import Mapping +from typing import Any from tavern._core import exceptions @@ -21,7 +22,7 @@ def is_ext_function(block: Any) -> bool: return isinstance(block, dict) and block.get("$ext", None) is not None -def get_pykwalify_logger(module: Optional[str]) -> logging.Logger: +def get_pykwalify_logger(module: str | None) -> logging.Logger: """Get logger for this module Have to do it like this because the way that pykwalify load extension @@ -140,7 +141,7 @@ def _get_ext_values(ext: Mapping): return func, args, kwargs -def update_from_ext(request_args: dict, keys_to_check: List[str]) -> None: +def update_from_ext(request_args: dict, keys_to_check: list[str]) -> None: """ Updates the request_args dict with any values from external functions diff --git a/tavern/_core/general.py b/tavern/_core/general.py index 51984c90..ca2d2057 100644 --- a/tavern/_core/general.py +++ b/tavern/_core/general.py @@ -1,15 +1,14 @@ import logging import os -from typing import List from tavern._core.loader import load_single_document_yaml from .dict_util import deep_dict_merge -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def load_global_config(global_cfg_paths: List[os.PathLike]) -> dict: +def load_global_config(global_cfg_paths: list[os.PathLike]) -> dict: """Given a list of file paths to global config files, load each of them and return the joined dictionary. diff --git a/tavern/_core/jmesutils.py b/tavern/_core/jmesutils.py index 7940b8b1..74bebacf 100644 --- a/tavern/_core/jmesutils.py +++ b/tavern/_core/jmesutils.py @@ -1,6 +1,7 @@ import operator import re -from typing import Any, Dict, List, Sized +from collections.abc import Sized +from typing import Any from tavern._core import exceptions @@ -37,7 +38,7 @@ def test_type(val, mytype) -> bool: "regex": lambda x, y: regex_compare(str(x), str(y)), "type": test_type, } -TYPES: Dict[str, List[Any]] = { +TYPES: dict[str, list[Any]] = { "none": [type(None)], "number": [int, float], "int": [int], diff --git a/tavern/_core/loader.py b/tavern/_core/loader.py index a4ceccc7..d42866e3 100644 --- a/tavern/_core/loader.py +++ b/tavern/_core/loader.py @@ -6,12 +6,13 @@ import uuid from abc import abstractmethod from itertools import chain -from typing import List, Optional import pytest import yaml +from _pytest.python_api import ApproxBase from yaml.composer import Composer from yaml.constructor import SafeConstructor +from yaml.nodes import Node, ScalarNode from yaml.parser import Parser from yaml.reader import Reader from yaml.resolver import Resolver @@ -21,22 +22,25 @@ from tavern._core.exceptions import BadSchemaError from tavern._core.strtobool import strtobool -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def makeuuid(loader, node): +def makeuuid(loader, node) -> str: return str(uuid.uuid4()) class RememberComposer(Composer): """A composer that doesn't forget anchors across documents""" - def compose_document(self): + def get_event(self) -> None: + ... + + def compose_document(self) -> Node | None: # Drop the DOCUMENT-START event. self.get_event() # Compose the root node. - node = self.compose_node(None, None) + node = self.compose_node(None, None) # type:ignore # Drop the DOCUMENT-END event. self.get_event() @@ -121,7 +125,7 @@ def __init__(self, stream): Resolver.__init__(self) SourceMappingConstructor.__init__(self) - env_path_list: Optional[List] = None + env_path_list: list | None = None env_var_name = "TAVERN_INCLUDE" @@ -140,7 +144,7 @@ def _get_include_dirs(loader): return chain(loader_list, IncludeLoader.env_path_list) -def find_include(loader, node): +def find_include(loader, node) -> str: """Locate an include file and return the abs path.""" for directory in _get_include_dirs(loader): filename = os.path.abspath( @@ -190,14 +194,14 @@ def constructor(_): raise NotImplementedError @classmethod - def from_yaml(cls, loader, node): + def from_yaml(cls, loader, node) -> "TypeSentinel": return cls() - def __str__(self): + def __str__(self) -> str: return f"" @classmethod - def to_yaml(cls, dumper, data): + def to_yaml(cls, dumper, data) -> ScalarNode: node = yaml.nodes.ScalarNode(cls.yaml_tag, "", style=cls.yaml_flow_style) return node @@ -242,7 +246,7 @@ class RegexSentinel(TypeSentinel): constructor = str compiled: re.Pattern - def __str__(self): + def __str__(self) -> str: return f"" @property @@ -254,28 +258,28 @@ def passes(self, string): raise NotImplementedError @classmethod - def from_yaml(cls, loader, node): + def from_yaml(cls, loader, node) -> "RegexSentinel": return cls(re.compile(node.value)) class _RegexMatchSentinel(RegexSentinel): yaml_tag = "!re_match" - def passes(self, string): + def passes(self, string) -> bool: return self.compiled.match(string) is not None class _RegexFullMatchSentinel(RegexSentinel): yaml_tag = "!re_fullmatch" - def passes(self, string): + def passes(self, string) -> bool: return self.compiled.fullmatch(string) is not None class _RegexSearchSentinel(RegexSentinel): yaml_tag = "!re_search" - def passes(self, string): + def passes(self, string) -> bool: return self.compiled.search(string) is not None @@ -321,7 +325,7 @@ class TypeConvertToken(yaml.YAMLObject): def constructor(_): raise NotImplementedError - def __init__(self, value): + def __init__(self, value) -> None: self.value = value @classmethod @@ -338,7 +342,7 @@ def from_yaml(cls, loader, node): return converted @classmethod - def to_yaml(cls, dumper, data): + def to_yaml(cls, dumper, data) -> ScalarNode: return yaml.nodes.ScalarNode( cls.yaml_tag, data.value, style=cls.yaml_flow_style ) @@ -357,7 +361,7 @@ class FloatToken(TypeConvertToken): class StrToBoolConstructor: """Using `bool` as a constructor directly will evaluate all strings to `True`.""" - def __new__(cls, s): + def __new__(cls, s: str) -> bool: # type:ignore return strtobool(s) @@ -369,7 +373,7 @@ class BoolToken(TypeConvertToken): class StrToRawConstructor: """Used when we want to ignore brace formatting syntax""" - def __new__(cls, s): + def __new__(cls, s) -> str: # type:ignore return str(s.replace("{", "{{").replace("}", "}}")) @@ -407,7 +411,7 @@ class ApproxSentinel(yaml.YAMLObject, ApproxScalar): # type:ignore yaml_loader = IncludeLoader @classmethod - def from_yaml(cls, loader, node): + def from_yaml(cls, loader, node) -> ApproxBase: try: val = float(node.value) except (ValueError, TypeError) as e: @@ -420,7 +424,7 @@ def from_yaml(cls, loader, node): return pytest.approx(val) @classmethod - def to_yaml(cls, dumper, data): + def to_yaml(cls, dumper, data) -> ScalarNode: return yaml.nodes.ScalarNode( "!approx", str(data.expected), style=cls.yaml_flow_style ) @@ -430,7 +434,7 @@ def to_yaml(cls, dumper, data): yaml.dumper.Dumper.add_representer(ApproxScalar, ApproxSentinel.to_yaml) -def load_single_document_yaml(filename: os.PathLike) -> dict: +def load_single_document_yaml(filename: str | os.PathLike) -> dict: """ Load a yaml file and expect only one document diff --git a/tavern/_core/plugins.py b/tavern/_core/plugins.py index b296be78..6b2cf2ae 100644 --- a/tavern/_core/plugins.py +++ b/tavern/_core/plugins.py @@ -5,8 +5,9 @@ """ import dataclasses import logging +from collections.abc import Callable, Mapping from functools import partial -from typing import Any, Callable, Dict, List, Mapping, Optional, Protocol, Type +from typing import Any, Protocol import stevedore @@ -16,7 +17,7 @@ from tavern.request import BaseRequest from tavern.response import BaseResponse -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class PluginHelperBase: @@ -32,9 +33,9 @@ def plugin_load_error(mgr, entry_point, err): class _TavernPlugin(Protocol): """A tavern plugin""" - session_type: Type[Any] - request_type: Type[BaseRequest] - verifier_type: Type[BaseResponse] + session_type: type[Any] + request_type: type[BaseRequest] + verifier_type: type[BaseResponse] response_block_name: str request_block_name: str schema: Mapping @@ -90,9 +91,9 @@ class _Plugin: @dataclasses.dataclass class _PluginCache: - plugins: List[_Plugin] = dataclasses.field(default_factory=list) + plugins: list[_Plugin] = dataclasses.field(default_factory=list) - def __call__(self, config: Optional[TestConfig] = None) -> List[_Plugin]: + def __call__(self, config: TestConfig | None = None) -> list[_Plugin]: if self.plugins: return self.plugins @@ -103,7 +104,7 @@ def __call__(self, config: Optional[TestConfig] = None) -> List[_Plugin]: raise exceptions.PluginLoadError("No config to load plugins from") - def _load_plugins(self, test_block_config: TestConfig) -> List[_Plugin]: + def _load_plugins(self, test_block_config: TestConfig) -> list[_Plugin]: """Load plugins from the 'tavern' entrypoint namespace This can be a module or a class as long as it defines the right things @@ -256,7 +257,7 @@ def _foreach_response( stage: Mapping, test_block_config: TestConfig, action: Callable[[_Plugin, str], dict], -) -> Dict[str, dict]: +) -> dict[str, dict]: """Do something for each response Args: diff --git a/tavern/_core/pytest/config.py b/tavern/_core/pytest/config.py index 2e33620e..6f70a2f8 100644 --- a/tavern/_core/pytest/config.py +++ b/tavern/_core/pytest/config.py @@ -2,11 +2,11 @@ import dataclasses import logging from importlib.util import find_spec -from typing import Any, List +from typing import Any from tavern._core.strict_util import StrictLevel -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) @dataclasses.dataclass(frozen=True) @@ -52,7 +52,7 @@ def with_strictness(self, new_strict: StrictLevel) -> "TestConfig": return dataclasses.replace(self, strict=new_strict) @staticmethod - def backends() -> List[str]: + def backends() -> list[str]: available_backends = ["http"] if has_module("paho.mqtt"): diff --git a/tavern/_core/pytest/error.py b/tavern/_core/pytest/error.py index 0c7b65be..d36f9143 100644 --- a/tavern/_core/pytest/error.py +++ b/tavern/_core/pytest/error.py @@ -1,11 +1,11 @@ import json import logging import re +from collections.abc import Mapping from io import StringIO -from typing import List, Mapping, Optional import yaml -from _pytest._code.code import FormattedExcinfo +from _pytest._code.code import FormattedExcinfo, TerminalRepr from _pytest._io import TerminalWriter from tavern._core import exceptions @@ -18,10 +18,10 @@ start_mark, ) -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -class ReprdError: +class ReprdError(TerminalRepr): def __init__(self, exce, item) -> None: self.exce = exce self.item = item @@ -44,8 +44,8 @@ def _get_available_format_keys(self): return keys def _print_format_variables( - self, tw: TerminalWriter, code_lines: List[str] - ) -> List[str]: + self, tw: TerminalWriter, code_lines: list[str] + ) -> list[str]: """Print a list of the format variables and their value at this stage If the format variable is not defined, print it in red as '???' @@ -108,9 +108,9 @@ def read_formatted_vars(lines): def _print_test_stage( self, tw: TerminalWriter, - code_lines: List[str], - missing_format_vars: List[str], - line_start: Optional[int], + code_lines: list[str], + missing_format_vars: list[str], + line_start: int | None, ) -> None: """Print the direct source lines from this test stage diff --git a/tavern/_core/pytest/file.py b/tavern/_core/pytest/file.py index 89114dcf..24084253 100644 --- a/tavern/_core/pytest/file.py +++ b/tavern/_core/pytest/file.py @@ -2,7 +2,8 @@ import functools import itertools import logging -from typing import Any, Dict, Iterator, List, Mapping, Tuple +from collections.abc import Iterator, Mapping +from typing import Any import pytest import yaml @@ -18,14 +19,14 @@ from .item import YamlItem from .util import load_global_cfg -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) _format_without_inner = functools.partial(format_keys, no_double_format=False) def _format_test_marks( - original_marks: List[Any], fmt_vars: Mapping, test_name: str -) -> Tuple[List[MarkDecorator], Any]: + original_marks: list[Any], fmt_vars: Mapping, test_name: str +) -> tuple[list[MarkDecorator], Any]: """Given the 'raw' marks from the test and any available format variables, generate new marks for this test @@ -89,7 +90,7 @@ def _format_test_marks( return pytest_marks, formatted_marks -def _generate_parametrized_test_items(keys: List, vals_combination): +def _generate_parametrized_test_items(keys: list, vals_combination): """Generate test name from given key(s)/value(s) combination Args: @@ -168,9 +169,9 @@ def maybe_load_ext(v): def _get_parametrized_items( parent: pytest.File, - test_spec: Dict, - parametrize_marks: List[Dict], - pytest_marks: List[MarkDecorator], + test_spec: dict, + parametrize_marks: list[dict], + pytest_marks: list[MarkDecorator], ) -> Iterator[YamlItem]: """Return new items with new format values available based on the mark @@ -272,7 +273,7 @@ def _get_test_fmt_vars(self, test_spec: Mapping) -> dict: # skipif: {my_integer} > 2 # skipif: 'https' in '{hostname}' # skipif: '{hostname}'.contains('ignoreme') - fmt_vars: Dict = {} + fmt_vars: dict = {} global_cfg = load_global_cfg(self.config) fmt_vars.update(**global_cfg.variables) @@ -297,7 +298,7 @@ def _get_test_fmt_vars(self, test_spec: Mapping) -> dict: tavern_box.merge_update(**fmt_vars) return tavern_box - def _generate_items(self, test_spec: Dict) -> Iterator[YamlItem]: + def _generate_items(self, test_spec: dict) -> Iterator[YamlItem]: """Modify or generate tests based on test spec If there are any 'parametrize' marks, this will generate extra tests diff --git a/tavern/_core/pytest/item.py b/tavern/_core/pytest/item.py index 57b80425..023951a6 100644 --- a/tavern/_core/pytest/item.py +++ b/tavern/_core/pytest/item.py @@ -1,10 +1,11 @@ import dataclasses import logging import pathlib -from typing import Iterable, MutableMapping, Optional, Tuple +from collections.abc import Iterable, MutableMapping import pytest import yaml +from _pytest._code.code import TerminalRepr from _pytest.nodes import Node from pytest import Mark, MarkDecorator @@ -20,7 +21,7 @@ from .config import TestConfig from .util import load_global_cfg -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class YamlItem(pytest.Item): @@ -48,7 +49,7 @@ def __init__( self.path = path self.spec = spec - self.global_cfg: Optional[TestConfig] = None + self.global_cfg: TestConfig | None = None if not YamlItem._patched_yaml: yaml.parser.Parser.process_empty_scalar = ( # type:ignore @@ -260,8 +261,8 @@ def runtest(self) -> None: ) def repr_failure( - self, excinfo: pytest.ExceptionInfo[BaseException], style: Optional[str] = None - ): + self, excinfo: pytest.ExceptionInfo[BaseException], style: str | None = None + ) -> TerminalRepr | str: """called when self.runtest() raises an exception. By default, will raise a custom formatted traceback if it's a tavern error. if not, will use the default @@ -283,7 +284,7 @@ def repr_failure( attach_text(str(error), name="error_output") return error - def reportinfo(self) -> Tuple[pathlib.Path, int, str]: + def reportinfo(self) -> tuple[pathlib.Path, int, str]: return ( self.path, 0, diff --git a/tavern/_core/pytest/newhooks.py b/tavern/_core/pytest/newhooks.py index 8fd494b2..25d579c2 100644 --- a/tavern/_core/pytest/newhooks.py +++ b/tavern/_core/pytest/newhooks.py @@ -1,6 +1,6 @@ import logging -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def pytest_tavern_beta_before_every_test_run(test_dict, variables) -> None: diff --git a/tavern/_core/pytest/util.py b/tavern/_core/pytest/util.py index 30343114..0ad8a285 100644 --- a/tavern/_core/pytest/util.py +++ b/tavern/_core/pytest/util.py @@ -1,6 +1,6 @@ import logging from functools import lru_cache -from typing import Any, Dict +from typing import Any import pytest @@ -9,7 +9,7 @@ from tavern._core.pytest.config import TavernInternalConfig, TestConfig from tavern._core.strict_util import StrictLevel -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def add_parser_options(parser_addoption, with_defaults: bool = True) -> None: @@ -174,7 +174,7 @@ def _load_global_cfg(pytest_config: pytest.Config) -> TestConfig: return global_cfg -def _load_global_backends(pytest_config: pytest.Config) -> Dict[str, Any]: +def _load_global_backends(pytest_config: pytest.Config) -> dict[str, Any]: """Load which backend should be used""" backend_settings = {} diff --git a/tavern/_core/report.py b/tavern/_core/report.py index 432d9abe..f72a391f 100644 --- a/tavern/_core/report.py +++ b/tavern/_core/report.py @@ -24,7 +24,7 @@ def call(step_func): from tavern._core.formatted_str import FormattedString from tavern._core.stage_lines import get_stage_lines, read_relevant_lines -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def prepare_yaml(val): @@ -54,7 +54,7 @@ def attach_stage_content(stage) -> None: attach_text(joined, "stage_yaml", yaml_type) -def attach_yaml(payload, name): +def attach_yaml(payload, name) -> None: prepared = prepare_yaml(payload) dumped = yaml.safe_dump(prepared) return attach_text(dumped, name, yaml_type) diff --git a/tavern/_core/run.py b/tavern/_core/run.py index be6bb984..8b56042e 100644 --- a/tavern/_core/run.py +++ b/tavern/_core/run.py @@ -3,9 +3,9 @@ import functools import logging import pathlib +from collections.abc import Mapping, MutableMapping from contextlib import ExitStack from copy import deepcopy -from typing import Dict, List, Mapping, MutableMapping import box @@ -27,7 +27,7 @@ from .testhelpers import delay, retry from .tincture import Tinctures, get_stage_tinctures -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def _resolve_test_stages(test_spec: Mapping, available_stages: Mapping): @@ -60,8 +60,8 @@ def _get_included_stages( tavern_box: box.Box, test_block_config: TestConfig, test_spec: Mapping, - available_stages: List[dict], -) -> List[dict]: + available_stages: list[dict], +) -> list[dict]: """ Get any stages which were included via config files which will be available for use in this test @@ -272,11 +272,11 @@ def update_stage_options(new_option): @dataclasses.dataclass(frozen=True) class _TestRunner: default_global_strictness: StrictLevel - sessions: Dict[str, PluginHelperBase] + sessions: dict[str, PluginHelperBase] test_block_config: TestConfig test_spec: Mapping - def run_stage(self, idx: int, stage, *, is_final: bool = False): + def run_stage(self, idx: int, stage, *, is_final: bool = False) -> None: tinctures = get_stage_tinctures(stage, self.test_spec) stage_config = self.test_block_config.with_strictness( @@ -305,7 +305,7 @@ def run_stage(self, idx: int, stage, *, is_final: bool = False): def wrapped_run_stage( self, stage: dict, stage_config: TestConfig, tinctures: Tinctures - ): + ) -> None: """Run one stage from the test Args: diff --git a/tavern/_core/schema/extensions.py b/tavern/_core/schema/extensions.py index 6d178e53..881156aa 100644 --- a/tavern/_core/schema/extensions.py +++ b/tavern/_core/schema/extensions.py @@ -1,6 +1,6 @@ import os import re -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from pykwalify.types import is_bool, is_float, is_int @@ -132,7 +132,9 @@ def check_usefixtures(value, rule_obj, path) -> bool: return True -def validate_grpc_status_is_valid_or_list_of_names(value: "GRPCCode", rule_obj, path): +def validate_grpc_status_is_valid_or_list_of_names( + value: "GRPCCode", rule_obj, path +) -> bool: """Validate GRPC statuses https://github.com/grpc/grpc/blob/master/doc/statuscodes.md""" # pylint: disable=unused-argument err_msg = ( @@ -153,7 +155,7 @@ def validate_grpc_status_is_valid_or_list_of_names(value: "GRPCCode", rule_obj, return True -def to_grpc_status(value: Union[str, int]): +def to_grpc_status(value: str | int): from grpc import StatusCode if isinstance(value, str): @@ -365,7 +367,7 @@ def check_is_timeout_val(v): return True -def validate_verify_bool_or_str(value: Union[bool, str], rule_obj, path) -> bool: +def validate_verify_bool_or_str(value: bool | str, rule_obj, path) -> bool: """Make sure the 'verify' key is either a bool or a str""" if not isinstance(value, (bool, str)) and not is_bool_like(value): diff --git a/tavern/_core/schema/files.py b/tavern/_core/schema/files.py index 8f801d6d..8f481035 100644 --- a/tavern/_core/schema/files.py +++ b/tavern/_core/schema/files.py @@ -3,7 +3,7 @@ import logging import os import tempfile -from typing import Dict, Mapping +from collections.abc import Mapping import pykwalify import yaml @@ -14,14 +14,14 @@ from tavern._core.plugins import load_plugins from tavern._core.schema.jsonschema import verify_jsonschema -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class SchemaCache: """Caches loaded schemas""" def __init__(self) -> None: - self._loaded: Dict[str, dict] = {} + self._loaded: dict[str, dict] = {} def _load_base_schema(self, schema_filename): try: diff --git a/tavern/_core/schema/jsonschema.py b/tavern/_core/schema/jsonschema.py index 022f8af8..d3e88149 100644 --- a/tavern/_core/schema/jsonschema.py +++ b/tavern/_core/schema/jsonschema.py @@ -1,6 +1,6 @@ import logging import re -from typing import Mapping +from collections.abc import Mapping import jsonschema from jsonschema import Draft7Validator, ValidationError @@ -35,7 +35,7 @@ read_relevant_lines, ) -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def is_str_or_bytes_or_token(checker, instance): diff --git a/tavern/_core/stage_lines.py b/tavern/_core/stage_lines.py index 0f943a54..95c3b784 100644 --- a/tavern/_core/stage_lines.py +++ b/tavern/_core/stage_lines.py @@ -1,6 +1,6 @@ import logging -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def get_stage_lines(stage): diff --git a/tavern/_core/strict_util.py b/tavern/_core/strict_util.py index 76d090d8..a4125794 100644 --- a/tavern/_core/strict_util.py +++ b/tavern/_core/strict_util.py @@ -2,12 +2,12 @@ import enum import logging import re -from typing import List, Optional, Tuple, Union +from typing import Union from tavern._core import exceptions from tavern._core.strtobool import strtobool -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class StrictSetting(enum.Enum): @@ -24,7 +24,7 @@ class StrictSetting(enum.Enum): valid_switches = ["on", "off", "list_any_order"] -def strict_setting_factory(str_setting: Optional[str]) -> StrictSetting: +def strict_setting_factory(str_setting: str | None) -> StrictSetting: """Converts from cmdline/setting file to an enum""" if str_setting is None: return StrictSetting.UNSET @@ -100,7 +100,7 @@ class StrictLevel: ) @classmethod - def from_options(cls, options: Union[List[str], str]) -> "StrictLevel": + def from_options(cls, options: list[str] | str) -> "StrictLevel": if isinstance(options, str): options = [options] elif not isinstance(options, list): @@ -135,7 +135,7 @@ def all_off(cls) -> "StrictLevel": StrictSettingKinds = Union[None, bool, StrictSetting, StrictOption] -def extract_strict_setting(strict: StrictSettingKinds) -> Tuple[bool, StrictSetting]: +def extract_strict_setting(strict: StrictSettingKinds) -> tuple[bool, StrictSetting]: """Takes either a bool, StrictOption, or a StrictSetting and return the bool representation and StrictSetting representation""" diff --git a/tavern/_core/testhelpers.py b/tavern/_core/testhelpers.py index 2c72b49f..44e47986 100644 --- a/tavern/_core/testhelpers.py +++ b/tavern/_core/testhelpers.py @@ -1,13 +1,13 @@ import logging import time +from collections.abc import Mapping from functools import wraps -from typing import Mapping from tavern._core import exceptions from tavern._core.dict_util import format_keys from tavern._core.pytest.config import TestConfig -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def delay(stage: Mapping, when: str, variables: Mapping) -> None: diff --git a/tavern/_core/tincture.py b/tavern/_core/tincture.py index f1fe746f..eee3a792 100644 --- a/tavern/_core/tincture.py +++ b/tavern/_core/tincture.py @@ -1,20 +1,20 @@ import collections.abc import inspect import logging -from typing import Any, List +from typing import Any from tavern._core import exceptions from tavern._core.extfunctions import get_wrapped_response_function -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class Tinctures: - def __init__(self, tinctures: List[Any]): + def __init__(self, tinctures: list[Any]) -> None: self._tinctures = tinctures - self._needs_response: List[Any] = [] + self._needs_response: list[Any] = [] - def start_tinctures(self, stage: collections.abc.Mapping): + def start_tinctures(self, stage: collections.abc.Mapping) -> None: results = [t(stage) for t in self._tinctures] self._needs_response = [] diff --git a/tavern/_plugins/grpc/client.py b/tavern/_plugins/grpc/client.py index 04d7b428..8904655c 100644 --- a/tavern/_plugins/grpc/client.py +++ b/tavern/_plugins/grpc/client.py @@ -1,8 +1,8 @@ import dataclasses import logging -import typing import warnings -from typing import Any, Dict, List, Mapping, Optional, Tuple +from collections.abc import Mapping +from typing import Any import grpc import grpc_reflection @@ -22,13 +22,13 @@ from tavern._core.dict_util import check_expected_keys from tavern._plugins.grpc.protos import _generate_proto_import, _import_grpc_module -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) with warnings.catch_warnings(): warnings.simplefilter("ignore") warnings.warn("deprecated", DeprecationWarning) # noqa: B028 -_ProtoMessageType = typing.Type[proto.message.Message] +_ProtoMessageType = type[proto.message.Message] @dataclasses.dataclass @@ -39,7 +39,7 @@ class _ChannelVals: class GRPCClient: - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: logger.debug("Initialising GRPC client with %s", kwargs) expected_blocks = { "connect": {"host", "port", "options", "timeout", "secure"}, @@ -69,7 +69,7 @@ def __init__(self, **kwargs): self.timeout = int(_connect_args.get("timeout", 5)) self.secure = bool(_connect_args.get("secure", False)) - self._options: List[Tuple[str, Any]] = [] + self._options: list[tuple[str, Any]] = [] for key, value in _connect_args.pop("options", {}).items(): if not key.startswith("grpc."): raise exceptions.GRPCServiceException( @@ -77,7 +77,7 @@ def __init__(self, **kwargs): ) self._options.append((key, value)) - self.channels: Dict[str, grpc.Channel] = {} + self.channels: dict[str, grpc.Channel] = {} # Using the default symbol database is a bit undesirable because it means that things being imported from # previous tests will affect later ones which can mask bugs. But there isn't a nice way to have a # self-contained symbol database, because then you need to transitively import all dependencies of protos and @@ -99,15 +99,15 @@ def __init__(self, **kwargs): def _register_file_descriptor( self, service_proto: grpc_reflection.v1alpha.reflection_pb2.FileDescriptorResponse, - ): + ) -> None: for file_descriptor_proto in service_proto.file_descriptor_proto: descriptor = descriptor_pb2.FileDescriptorProto() descriptor.ParseFromString(file_descriptor_proto) self.sym_db.pool.Add(descriptor) def _get_reflection_info( - self, channel, service_name: Optional[str] = None, file_by_filename=None - ): + self, channel, service_name: str | None = None, file_by_filename=None + ) -> None: logger.debug( "Getting GRPC protobuf for service %s from reflection", service_name ) @@ -123,7 +123,7 @@ def _get_reflection_info( def _get_grpc_service( self, channel: grpc.Channel, service: str, method: str - ) -> Optional[_ChannelVals]: + ) -> _ChannelVals | None: full_service_name = f"{service}/{method}" try: input_type, output_type = self.get_method_types(full_service_name) @@ -143,7 +143,7 @@ def _get_grpc_service( def get_method_types( self, full_method_name: str - ) -> Tuple[_ProtoMessageType, _ProtoMessageType]: + ) -> tuple[_ProtoMessageType, _ProtoMessageType]: """Uses the builtin symbol pool to try and find the input and output types for the given method Args: @@ -167,9 +167,7 @@ def get_method_types( return input_type, output_type - def _make_call_request( - self, host: str, full_service: str - ) -> Optional[_ChannelVals]: + def _make_call_request(self, host: str, full_service: str) -> _ChannelVals | None: full_service = full_service.replace("/", ".") service_method = full_service.rsplit(".", 1) if len(service_method) != 2: @@ -239,15 +237,15 @@ def _make_call_request( return self._get_grpc_service(channel, service, method) - def __enter__(self): + def __enter__(self) -> None: logger.debug("Connecting to GRPC") def call( self, service: str, - host: Optional[str] = None, - body: Optional[Mapping] = None, - timeout: Optional[int] = None, + host: str | None = None, + body: Mapping | None = None, + timeout: int | None = None, ) -> grpc.Future: """Makes the request and returns a future with the response.""" if host is None: @@ -282,7 +280,7 @@ def call( request, metadata=self._metadata, timeout=timeout ) - def __exit__(self, *args): + def __exit__(self, *args) -> None: logger.debug("Disconnecting from GRPC") for v in self.channels.values(): v.close() diff --git a/tavern/_plugins/grpc/protos.py b/tavern/_plugins/grpc/protos.py index b70aa238..14805d12 100644 --- a/tavern/_plugins/grpc/protos.py +++ b/tavern/_plugins/grpc/protos.py @@ -9,11 +9,10 @@ import tempfile from distutils.spawn import find_executable from importlib.machinery import ModuleSpec -from typing import List from tavern._core import exceptions -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) @functools.lru_cache @@ -31,7 +30,7 @@ def find_protoc() -> str: @functools.lru_cache -def _generate_proto_import(source: str): +def _generate_proto_import(source: str) -> None: """Invokes the Protocol Compiler to generate a _pb2.py from the given .proto file. Does nothing if the output already exists and is newer than the input. @@ -101,7 +100,7 @@ def sanitise(s): _import_grpc_module(output) -def _import_grpc_module(python_module_name: str): +def _import_grpc_module(python_module_name: str) -> None: """takes an expected python module name and tries to import the relevant file, adding service to the symbol database. """ @@ -118,7 +117,7 @@ def _import_grpc_module(python_module_name: str): f"relative imports for Python grpc modules not allowed (got {python_module_name})" ) - import_specs: List[ModuleSpec] = [] + import_specs: list[ModuleSpec] = [] # Check if its already on the python path if (spec := importlib.util.find_spec(python_module_name)) is not None: diff --git a/tavern/_plugins/grpc/request.py b/tavern/_plugins/grpc/request.py index 6fe311ba..8c4ef93d 100644 --- a/tavern/_plugins/grpc/request.py +++ b/tavern/_plugins/grpc/request.py @@ -2,8 +2,7 @@ import functools import json import logging -import warnings -from typing import Mapping, Union +from collections.abc import Mapping import grpc from box import Box @@ -14,7 +13,7 @@ from tavern._plugins.grpc.client import GRPCClient from tavern.request import BaseRequest -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def get_grpc_args(rspec, test_block_config): @@ -37,7 +36,7 @@ def get_grpc_args(rspec, test_block_config): @dataclasses.dataclass class WrappedFuture: - response: Union[grpc.Call, grpc.Future] + response: grpc.Call | grpc.Future service_name: str @@ -47,19 +46,9 @@ class GRPCRequest(BaseRequest): Similar to RestRequest, publishes a single message. """ - _warned = False - def __init__( self, client: GRPCClient, request_spec: Mapping, test_block_config: TestConfig - ): - if not self._warned: - warnings.warn( - "Tavern gRPC support is experimental and will be updated in a future release.", - RuntimeWarning, - stacklevel=0, - ) - GRPCRequest._warned = True - + ) -> None: expected = {"host", "service", "body"} check_expected_keys(expected, request_spec) @@ -87,5 +76,5 @@ def run(self) -> WrappedFuture: raise exceptions.GRPCRequestException from e @property - def request_vars(self): + def request_vars(self) -> Box: return Box(self._original_request_vars) diff --git a/tavern/_plugins/grpc/response.py b/tavern/_plugins/grpc/response.py index 8fbc5291..2ec1f17e 100644 --- a/tavern/_plugins/grpc/response.py +++ b/tavern/_plugins/grpc/response.py @@ -1,5 +1,6 @@ import logging -from typing import TYPE_CHECKING, Any, List, Mapping, TypedDict, Union +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, TypedDict, Union import proto.message from google.protobuf import json_format @@ -16,13 +17,13 @@ if TYPE_CHECKING: from tavern._plugins.grpc.request import WrappedFuture -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -GRPCCode = Union[str, int, List[str], List[int]] +GRPCCode = Union[str, int, list[str], list[int]] -def _to_grpc_name(status: GRPCCode) -> Union[str, List[str]]: +def _to_grpc_name(status: GRPCCode) -> str | list[str]: if isinstance(status, list): return [_to_grpc_name(s) for s in status] # type:ignore @@ -46,9 +47,9 @@ def __init__( self, client: GRPCClient, name: str, - expected: Union[_GRPCExpected, Mapping], + expected: _GRPCExpected | Mapping, test_block_config: TestConfig, - ): + ) -> None: check_expected_keys({"body", "status", "details"}, expected) super().__init__(name, expected, test_block_config) @@ -60,7 +61,7 @@ def __str__(self): else: return "" - def _validate_block(self, blockname: str, block: Mapping): + def _validate_block(self, blockname: str, block: Mapping) -> None: """Validate a block of the response Args: diff --git a/tavern/_plugins/grpc/tavernhook.py b/tavern/_plugins/grpc/tavernhook.py index 56a4f3ff..77b87b1a 100644 --- a/tavern/_plugins/grpc/tavernhook.py +++ b/tavern/_plugins/grpc/tavernhook.py @@ -10,7 +10,7 @@ from .request import GRPCRequest from .response import GRPCResponse -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) session_type = GRPCClient @@ -29,6 +29,6 @@ def get_expected_from_request(response_block, test_block_config: TestConfig, ses verifier_type = GRPCResponse response_block_name = "grpc_response" -schema_path = join(abspath(dirname(__file__)), "jsonschema.yaml") +schema_path: str = join(abspath(dirname(__file__)), "jsonschema.yaml") with open(schema_path) as schema_file: schema = yaml.load(schema_file, Loader=yaml.SafeLoader) diff --git a/tavern/_plugins/mqtt/client.py b/tavern/_plugins/mqtt/client.py index c07b0a0c..17dd14ba 100644 --- a/tavern/_plugins/mqtt/client.py +++ b/tavern/_plugins/mqtt/client.py @@ -4,10 +4,11 @@ import ssl import threading import time +from collections.abc import Mapping, MutableMapping from queue import Empty, Full, Queue -from typing import Dict, List, Mapping, MutableMapping, Optional import paho.mqtt.client as paho +from paho.mqtt.client import MQTTMessageInfo from tavern._core import exceptions from tavern._core.dict_util import check_expected_keys @@ -33,7 +34,7 @@ 15: "MQTT_ERR_QUEUE_SIZE", } -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) @dataclasses.dataclass @@ -57,7 +58,7 @@ def check_file_exists(key, filename) -> None: def _handle_tls_args( tls_args: MutableMapping, -) -> Optional[Mapping]: +) -> Mapping | None: """Make sure TLS options are valid""" if not tls_args: @@ -74,7 +75,7 @@ def _handle_tls_args( def _handle_ssl_context_args( ssl_context_args: MutableMapping, -) -> Optional[Mapping]: +) -> Mapping | None: """Make sure SSL Context options are valid""" if not ssl_context_args: return None @@ -87,8 +88,8 @@ def _handle_ssl_context_args( def _check_and_update_common_tls_args( - tls_args: MutableMapping, check_file_keys: List[str] -): + tls_args: MutableMapping, check_file_keys: list[str] +) -> None: """Checks common args between ssl/tls args""" # could be moved to schema validation stage @@ -274,14 +275,14 @@ def __init__(self, **kwargs) -> None: self._client.tls_insecure_set(True) # Topics to subscribe to - mapping of subscription message id to subscription object - self._subscribed: Dict[int, _Subscription] = {} + self._subscribed: dict[int, _Subscription] = {} # Lock to ensure there is no race condition when subscribing self._subscribe_lock = threading.RLock() # callback self._client.on_subscribe = self._on_subscribe # Mapping of topic -> subscription id, for indexing into self._subscribed - self._subscription_mappings: Dict[str, int] = {} + self._subscription_mappings: dict[str, int] = {} self._userdata = { "_subscription_mappings": self._subscription_mappings, "_subscribed": self._subscribed, @@ -311,7 +312,7 @@ def _on_message(client, userdata, message) -> None: logger.exception("message queue full") @staticmethod - def _on_connect(client, userdata, flags, rc) -> None: + def _on_connect(client, userdata, flags, rc: int) -> None: logger.debug( "Client '%s' connected to the broker with result code '%s'", client._client_id.decode(), @@ -319,7 +320,7 @@ def _on_connect(client, userdata, flags, rc) -> None: ) @staticmethod - def _on_disconnect(client, userdata, rc) -> None: + def _on_disconnect(client, userdata, rc: int) -> None: if rc == paho.CONNACK_ACCEPTED: logger.debug( "Client '%s' successfully disconnected from the broker with result code '%s'", @@ -376,7 +377,13 @@ def message_received(self, topic: str, timeout: int = 1): return msg - def publish(self, topic, payload=None, qos=None, retain=None): + def publish( + self, + topic: str, + payload: None | bytearray | bytes | float | str = None, + qos=None, + retain=None, + ) -> MQTTMessageInfo: """publish message using paho library""" self._wait_for_subscriptions() diff --git a/tavern/_plugins/mqtt/request.py b/tavern/_plugins/mqtt/request.py index 0a9de87a..57630be5 100644 --- a/tavern/_plugins/mqtt/request.py +++ b/tavern/_plugins/mqtt/request.py @@ -1,7 +1,6 @@ import functools import json import logging -from typing import Dict from box.box import Box @@ -13,10 +12,10 @@ from tavern._plugins.mqtt.client import MQTTClient from tavern.request import BaseRequest -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def get_publish_args(rspec: Dict, test_block_config: TestConfig) -> dict: +def get_publish_args(rspec: dict, test_block_config: TestConfig) -> dict: """Format mqtt request args and update using ext functions""" fspec = format_keys(rspec, test_block_config.variables) @@ -41,7 +40,7 @@ class MQTTRequest(BaseRequest): """ def __init__( - self, client: MQTTClient, rspec: Dict, test_block_config: TestConfig + self, client: MQTTClient, rspec: dict, test_block_config: TestConfig ) -> None: expected = {"topic", "payload", "json", "qos", "retain"} diff --git a/tavern/_plugins/mqtt/response.py b/tavern/_plugins/mqtt/response.py index afc60ffb..bf103b7c 100644 --- a/tavern/_plugins/mqtt/response.py +++ b/tavern/_plugins/mqtt/response.py @@ -5,14 +5,15 @@ import json import logging import time +from collections.abc import Mapping from dataclasses import dataclass -from typing import Dict, List, Mapping, Optional, Tuple, Union from paho.mqtt.client import MQTTMessage from tavern._core import exceptions from tavern._core.dict_util import check_keys_match_recursive from tavern._core.loader import ANYTHING +from tavern._core.pytest.config import TestConfig from tavern._core.pytest.newhooks import call_hook from tavern._core.report import attach_yaml from tavern._core.strict_util import StrictSetting @@ -20,13 +21,15 @@ from .client import MQTTClient -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) _default_timeout = 1 class MQTTResponse(BaseResponse): - def __init__(self, client: MQTTClient, name, expected, test_block_config) -> None: + def __init__( + self, client: MQTTClient, name: str, expected, test_block_config: TestConfig + ) -> None: super().__init__(name, expected, test_block_config) self._client = client @@ -67,8 +70,8 @@ def _await_response(self) -> dict: m: list(v) for m, v in itertools.groupby(expected, lambda x: x["topic"]) } - correct_messages: List["_ReturnedMessage"] = [] - warnings: List[str] = [] + correct_messages: list["_ReturnedMessage"] = [] + warnings: list[str] = [] with concurrent.futures.ThreadPoolExecutor() as executor: futures = [] @@ -135,8 +138,8 @@ def _await_response(self) -> dict: return saved def _await_messages_on_topic( - self, topic: str, expected: List[Dict] - ) -> Tuple[List["_ReturnedMessage"], List[str]]: + self, topic: str, expected: list[dict] + ) -> tuple[list["_ReturnedMessage"], list[str]]: """ Waits for the specific message @@ -189,7 +192,7 @@ def _await_messages_on_topic( name="rest_response", ) - found: List[int] = [] + found: list[int] = [] for i, v in enumerate(verifiers): if v.is_valid(msg): correct_messages.append(_ReturnedMessage(v.expected, msg)) @@ -246,7 +249,7 @@ def __init__(self, test_block_config, expected) -> None: # Any warnings to do with the request # eg, if a message was received but it didn't match, message had payload, etc. - self.warnings: List[str] = [] + self.warnings: list[str] = [] def is_valid(self, msg: MQTTMessage) -> bool: if time.time() > self.expires: @@ -318,7 +321,7 @@ def addwarning(w, *args, **kwargs): return False @staticmethod - def _get_payload_vals(expected: Mapping) -> Tuple[Optional[Union[str, dict]], bool]: + def _get_payload_vals(expected: Mapping) -> tuple[str | dict | None, bool]: """Gets the payload from the 'expected' block Returns: @@ -348,7 +351,7 @@ def _get_payload_vals(expected: Mapping) -> Tuple[Optional[Union[str, dict]], bo return payload, json_payload - def popwarnings(self) -> List[str]: + def popwarnings(self) -> list[str]: popped = [] while self.warnings: popped.append(self.warnings.pop(0)) diff --git a/tavern/_plugins/mqtt/tavernhook.py b/tavern/_plugins/mqtt/tavernhook.py index 9ff57245..ac06a9bb 100644 --- a/tavern/_plugins/mqtt/tavernhook.py +++ b/tavern/_plugins/mqtt/tavernhook.py @@ -1,16 +1,16 @@ import logging from os.path import abspath, dirname, join -from typing import Dict, Optional import yaml from tavern._core.dict_util import format_keys +from tavern._core.pytest.config import TestConfig from .client import MQTTClient from .request import MQTTRequest from .response import MQTTResponse -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) session_type = MQTTClient @@ -18,8 +18,10 @@ request_block_name = "mqtt_publish" -def get_expected_from_request(response_block, test_block_config, session): - expected: Optional[Dict] = None +def get_expected_from_request( + response_block, test_block_config: TestConfig, session: MQTTClient +): + expected: dict | None = None # mqtt response is not required if response_block: @@ -40,6 +42,6 @@ def get_expected_from_request(response_block, test_block_config, session): verifier_type = MQTTResponse response_block_name = "mqtt_response" -schema_path = join(abspath(dirname(__file__)), "jsonschema.yaml") +schema_path: str = join(abspath(dirname(__file__)), "jsonschema.yaml") with open(schema_path, encoding="utf-8") as schema_file: schema = yaml.load(schema_file, Loader=yaml.SafeLoader) diff --git a/tavern/_plugins/rest/files.py b/tavern/_plugins/rest/files.py index 3f72db8a..c8ffcd5d 100644 --- a/tavern/_plugins/rest/files.py +++ b/tavern/_plugins/rest/files.py @@ -3,13 +3,13 @@ import mimetypes import os from contextlib import ExitStack -from typing import Any, List, Optional, Tuple, Union +from typing import Any from tavern._core import exceptions from tavern._core.dict_util import format_keys from tavern._core.pytest.config import TestConfig -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) @dataclasses.dataclass @@ -17,12 +17,12 @@ class _Filespec: """A description of a file for a file upload, possibly as part of a multi part upload""" path: str - content_type: Optional[str] = None - content_encoding: Optional[str] = None - form_field_name: Optional[str] = None + content_type: str | None = None + content_encoding: str | None = None + form_field_name: str | None = None -def _parse_filespec(filespec: Union[str, dict]) -> _Filespec: +def _parse_filespec(filespec: str | dict) -> _Filespec: """ Get configuration for uploading file @@ -63,8 +63,8 @@ def _parse_filespec(filespec: Union[str, dict]) -> _Filespec: def guess_filespec( - filespec: Union[str, dict], stack: ExitStack, test_block_config: TestConfig -) -> Tuple[List, Optional[str]]: + filespec: str | dict, stack: ExitStack, test_block_config: TestConfig +) -> tuple[list, str | None]: """tries to guess the content type and encoding from a file. Args: @@ -133,9 +133,9 @@ def _parse_file_mapping(file_args, stack, test_block_config) -> dict: return files_to_send -def _parse_file_list(file_args, stack, test_block_config) -> List: +def _parse_file_list(file_args, stack, test_block_config) -> list: """Parses a case where there may be multiple files uploaded as part of one form field""" - files_to_send: List[Any] = [] + files_to_send: list[Any] = [] for filespec in file_args: file_spec, form_field_name = guess_filespec(filespec, stack, test_block_config) @@ -170,7 +170,7 @@ def get_file_arguments( mapping of 'files' block to pass directly to requests """ - files_to_send: Optional[Union[dict, List]] = None + files_to_send: dict | list | None = None file_args = request_args.get("files") diff --git a/tavern/_plugins/rest/request.py b/tavern/_plugins/rest/request.py index 8d81d773..3433e3db 100644 --- a/tavern/_plugins/rest/request.py +++ b/tavern/_plugins/rest/request.py @@ -2,9 +2,10 @@ import json import logging import warnings +from collections.abc import Mapping, MutableMapping from contextlib import ExitStack from itertools import filterfalse, tee -from typing import ClassVar, List, Mapping, MutableMapping, Optional +from typing import ClassVar from urllib.parse import quote_plus import requests @@ -21,7 +22,7 @@ from tavern._plugins.rest.files import get_file_arguments, guess_filespec from tavern.request import BaseRequest -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def get_request_args(rspec: MutableMapping, test_block_config: TestConfig) -> dict: @@ -263,7 +264,7 @@ def _check_allow_redirects(rspec: dict, test_block_config: TestConfig): def _read_expected_cookies( session: requests.Session, rspec: Mapping, test_block_config: TestConfig -) -> Optional[dict]: +) -> dict | None: """ Read cookies to inject into request, ignoring others which are present @@ -333,7 +334,7 @@ def partition(pred, iterable): class RestRequest(BaseRequest): - optional_in_file: ClassVar[List[str]] = [ + optional_in_file: ClassVar[list[str]] = [ "json", "data", "params", diff --git a/tavern/_plugins/rest/response.py b/tavern/_plugins/rest/response.py index 83b1062d..f60634cd 100644 --- a/tavern/_plugins/rest/response.py +++ b/tavern/_plugins/rest/response.py @@ -1,7 +1,7 @@ import contextlib import json import logging -from typing import Dict, Mapping, Optional +from collections.abc import Mapping from urllib.parse import parse_qs, urlparse import requests @@ -9,20 +9,23 @@ from tavern._core import exceptions from tavern._core.dict_util import deep_dict_merge +from tavern._core.pytest.config import TestConfig from tavern._core.pytest.newhooks import call_hook from tavern._core.report import attach_yaml from tavern.response import BaseResponse, indent_err_text -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class RestResponse(BaseResponse): - def __init__(self, session, name: str, expected, test_block_config) -> None: + def __init__( + self, session, name: str, expected, test_block_config: TestConfig + ) -> None: defaults = {"status_code": 200} super().__init__(name, deep_dict_merge(defaults, expected), test_block_config) - self.status_code: Optional[int] = None + self.status_code: int | None = None def check_code(code: int) -> None: if int(code) not in _codes: @@ -75,7 +78,7 @@ def log_dict_block(block, name): logger.debug("Redirect location: %s", to_path) log_dict_block(redirect_query_params, "Redirect URL query parameters") - def _get_redirect_query_params(self, response) -> Dict[str, str]: + def _get_redirect_query_params(self, response) -> dict[str, str]: """If there was a redirect header, get any query parameters from it""" try: diff --git a/tavern/_plugins/rest/tavernhook.py b/tavern/_plugins/rest/tavernhook.py index 208e32d3..7f695e02 100644 --- a/tavern/_plugins/rest/tavernhook.py +++ b/tavern/_plugins/rest/tavernhook.py @@ -9,7 +9,7 @@ from .request import RestRequest from .response import RestResponse -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class TavernRestPlugin(PluginHelperBase): diff --git a/tavern/core.py b/tavern/core.py index da20f273..60d48896 100644 --- a/tavern/core.py +++ b/tavern/core.py @@ -1,23 +1,20 @@ import os from contextlib import ExitStack -from typing import Union import pytest +from pytest import ExitCode from tavern._core import exceptions from tavern._core.schema.files import wrapfile -def _get_or_wrap_global_cfg( - stack: ExitStack, tavern_global_cfg: Union[dict, str] -) -> str: +def _get_or_wrap_global_cfg(stack: ExitStack, tavern_global_cfg: dict | str) -> str: """ Try to parse global configuration from given argument. Args: stack: context stack for wrapping file if a dictionary is given - tavern_global_cfg: Dictionary or string. It should be a - path to a file or a dictionary with configuration. + tavern_global_cfg: path to a file or a dictionary with configuration. Returns: path to global config file @@ -48,29 +45,29 @@ def _get_or_wrap_global_cfg( return global_filename -def run( +def run( # type:ignore in_file: str, - tavern_global_cfg=None, - tavern_mqtt_backend=None, - tavern_http_backend=None, - tavern_grpc_backend=None, - tavern_strict=None, - pytest_args=None, -): + tavern_global_cfg: dict | str | None = None, + tavern_mqtt_backend: str | None = None, + tavern_http_backend: str | None = None, + tavern_grpc_backend: str | None = None, + tavern_strict: bool | None = None, + pytest_args: list | None = None, +) -> ExitCode | int: """Run all tests contained in a file using pytest.main() Args: in_file: file to run tests on - tavern_global_cfg (str, dict): Extra global config - tavern_mqtt_backend (str, optional): name of MQTT plugin to use. If not + tavern_global_cfg: Extra global config + tavern_mqtt_backend: name of MQTT plugin to use. If not specified, uses tavern-mqtt - tavern_http_backend (str, optional): name of HTTP plugin to use. If not + tavern_http_backend: name of HTTP plugin to use. If not specified, use tavern-http - tavern_grpc_backend (str, optional): name of GRPC plugin to use. If not + tavern_grpc_backend: name of GRPC plugin to use. If not specified, use tavern-grpc - tavern_strict (bool, optional): Strictness of checking for responses. + tavern_strict: Strictness of checking for responses. See documentation for details - pytest_args (list, optional): List of extra arguments to pass directly + pytest_args: List of extra arguments to pass directly to Pytest as if they were command line arguments Returns: diff --git a/tavern/entry.py b/tavern/entry.py index b996f94c..aaab5132 100644 --- a/tavern/entry.py +++ b/tavern/entry.py @@ -2,7 +2,6 @@ import logging.config from argparse import ArgumentParser from textwrap import dedent -from typing import Dict from .core import run @@ -49,7 +48,7 @@ def main(): log_level = "INFO" # Basic logging config that will print out useful information - log_cfg: Dict = { + log_cfg: dict = { "version": 1, "formatters": { "default": { diff --git a/tavern/helpers.py b/tavern/helpers.py index b4204b65..fdd157da 100644 --- a/tavern/helpers.py +++ b/tavern/helpers.py @@ -2,7 +2,6 @@ import json import logging import re -from typing import Dict, List, Optional import jmespath import jwt @@ -14,7 +13,7 @@ from tavern._core.jmesutils import actual_validation, validate_comparison from tavern._core.schema.files import verify_pykwalify -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def check_exception_raised( @@ -68,7 +67,7 @@ def check_exception_raised( ) from e -def validate_jwt(response, jwt_key, **kwargs) -> Dict[str, Box]: +def validate_jwt(response, jwt_key, **kwargs) -> dict[str, Box]: """Make sure a jwt is valid This uses the pyjwt library to decode the jwt, so any keyword args needed @@ -118,9 +117,9 @@ def validate_regex( response: requests.Response, expression: str, *, - header: Optional[str] = None, - in_jmespath: Optional[str] = None, -) -> Dict[str, Box]: + header: str | None = None, + in_jmespath: str | None = None, +) -> dict[str, Box]: """Make sure the response matches a regex expression Args: @@ -171,7 +170,7 @@ def validate_regex( return {"regex": Box(match.groupdict())} -def validate_content(response: requests.Response, comparisons: List[str]) -> None: +def validate_content(response: requests.Response, comparisons: list[str]) -> None: """Asserts expected value with actual value using JMES path expression Args: @@ -197,7 +196,7 @@ def validate_content(response: requests.Response, comparisons: List[str]) -> Non raise exceptions.JMESError("Error validating JMES") from e -def check_jmespath_match(parsed_response, query: str, expected: Optional[str] = None): +def check_jmespath_match(parsed_response, query: str, expected: str | None = None): """ Check that the JMES path given in 'query' is present in the given response diff --git a/tavern/request.py b/tavern/request.py index a049c8eb..abfa7ded 100644 --- a/tavern/request.py +++ b/tavern/request.py @@ -6,7 +6,7 @@ from tavern._core.pytest.config import TestConfig -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class BaseRequest: diff --git a/tavern/response.py b/tavern/response.py index b9ffb86f..4c3350aa 100644 --- a/tavern/response.py +++ b/tavern/response.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Mapping from textwrap import indent -from typing import Any, List, Optional +from typing import Any from tavern._core import exceptions from tavern._core.dict_util import check_keys_match_recursive, recurse_access_key @@ -11,7 +11,7 @@ from tavern._core.pytest.config import TestConfig from tavern._core.strict_util import StrictOption -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def indent_err_text(err: str) -> str: @@ -26,16 +26,16 @@ def __init__(self, name: str, expected, test_block_config: TestConfig) -> None: self.name = name # all errors in this response - self.errors: List[str] = [] + self.errors: list[str] = [] - self.validate_functions: List = [] + self.validate_functions: list = [] self._check_for_validate_functions(expected) self.test_block_config = test_block_config self.expected = expected - self.response: Optional[Any] = None + self.response: Any | None = None def _str_errors(self) -> str: return "- " + "\n- ".join(self.errors) @@ -58,7 +58,7 @@ def verify(self, response): def recurse_check_key_match( self, - expected_block: Optional[Mapping], + expected_block: Mapping | None, block: Mapping, blockname: str, strict: StrictOption, @@ -224,9 +224,9 @@ def maybe_get_save_values_from_ext( def maybe_get_save_values_from_save_block( self, key: str, - save_from: Optional[Mapping], + save_from: Mapping | None, *, - outer_save_block: Optional[Mapping] = None, + outer_save_block: Mapping | None = None, ) -> dict: """Save a value from a specific block in the response. @@ -252,7 +252,7 @@ def maybe_get_save_values_from_save_block( def maybe_get_save_values_from_given_block( self, key: str, - save_from: Optional[Mapping], + save_from: Mapping | None, to_save: Mapping, ) -> dict: """Save a value from a specific block in the response. diff --git a/tests/unit/tavern_grpc/test_grpc.py b/tests/unit/tavern_grpc/test_grpc.py index 5e9903e0..ab57fda3 100644 --- a/tests/unit/tavern_grpc/test_grpc.py +++ b/tests/unit/tavern_grpc/test_grpc.py @@ -2,15 +2,16 @@ import os.path import random import sys +from collections.abc import Mapping from concurrent import futures -from typing import Any, Mapping, Optional +from typing import Any import grpc import pytest -from _pytest.mark import MarkGenerator from google.protobuf import json_format from google.protobuf.empty_pb2 import Empty from grpc_reflection.v1alpha import reflection +from pytest import MarkGenerator from tavern._core.pytest.config import TestConfig from tavern._plugins.grpc.client import GRPCClient @@ -71,7 +72,7 @@ class GRPCTestSpec: method: str req: Any - resp: Optional[Any] = None + resp: Any | None = None xfail: bool = False code: GRPCCode = grpc.StatusCode.OK.value[0] service: str = "tavern.tests.v1.DummyService" diff --git a/tests/unit/test_mqtt.py b/tests/unit/test_mqtt.py index 1d55b06f..0a91cbab 100644 --- a/tests/unit/test_mqtt.py +++ b/tests/unit/test_mqtt.py @@ -1,4 +1,3 @@ -from typing import Dict from unittest.mock import MagicMock, Mock, patch import paho.mqtt.client as paho @@ -198,7 +197,7 @@ def subscribe_success(topic, *args, **kwargs): class TestExtFunctions: @pytest.fixture() - def basic_mqtt_request_args(self) -> Dict: + def basic_mqtt_request_args(self) -> dict: return { "topic": "/a/b/c", }