From 43f445bdf2fd1515c9b036bdbce719647dbfd039 Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sat, 20 Jan 2024 18:47:29 +0000 Subject: [PATCH 01/20] Cleanup --- tavern/_plugins/mqtt/client.py | 22 +++++++++++------ tavern/_plugins/mqtt/response.py | 14 +++++++---- tavern/_plugins/rest/request.py | 11 ++++----- tavern/response.py | 41 ++++++++++++++++---------------- 4 files changed, 48 insertions(+), 40 deletions(-) diff --git a/tavern/_plugins/mqtt/client.py b/tavern/_plugins/mqtt/client.py index c07b0a0c..06095728 100644 --- a/tavern/_plugins/mqtt/client.py +++ b/tavern/_plugins/mqtt/client.py @@ -5,7 +5,7 @@ import threading import time from queue import Empty, Full, Queue -from typing import Dict, List, Mapping, MutableMapping, Optional +from typing import Any, Dict, List, Mapping, MutableMapping, Optional import paho.mqtt.client as paho @@ -289,7 +289,9 @@ def __init__(self, **kwargs) -> None: self._client.user_data_set(self._userdata) @staticmethod - def _on_message(client, userdata, message) -> None: + def _on_message( + client, userdata: Mapping[str, Any], message: paho.MQTTMessage + ) -> None: """Add any messages received to the queue Todo: @@ -311,7 +313,7 @@ def _on_message(client, userdata, message) -> None: logger.exception("message queue full") @staticmethod - def _on_connect(client, userdata, flags, rc) -> None: + def _on_connect(client, userdata, flags, rc: int) -> None: logger.debug( "Client '%s' connected to the broker with result code '%s'", client._client_id.decode(), @@ -319,7 +321,7 @@ def _on_connect(client, userdata, flags, rc) -> None: ) @staticmethod - def _on_disconnect(client, userdata, rc) -> None: + def _on_disconnect(client, userdata, rc: int) -> None: if rc == paho.CONNACK_ACCEPTED: logger.debug( "Client '%s' successfully disconnected from the broker with result code '%s'", @@ -351,8 +353,8 @@ def message_received(self, topic: str, timeout: int = 1): """Check that a message is in the message queue Args: - topic (str): topic to fetch message for - timeout (int): How long to wait before signalling that the message + topic: topic to fetch message for + timeout: How long to wait before signalling that the message was not received. Returns: @@ -376,7 +378,13 @@ def message_received(self, topic: str, timeout: int = 1): return msg - def publish(self, topic, payload=None, qos=None, retain=None): + def publish( + self, + topic: str, + payload: Any, + qos: Optional[int], + retain: Optional[bool] = False, + ): """publish message using paho library""" self._wait_for_subscriptions() diff --git a/tavern/_plugins/mqtt/response.py b/tavern/_plugins/mqtt/response.py index afc60ffb..020d0945 100644 --- a/tavern/_plugins/mqtt/response.py +++ b/tavern/_plugins/mqtt/response.py @@ -8,6 +8,7 @@ from dataclasses import dataclass from typing import Dict, List, Mapping, Optional, Tuple, Union +import requests from paho.mqtt.client import MQTTMessage from tavern._core import exceptions @@ -18,6 +19,7 @@ from tavern._core.strict_util import StrictSetting from tavern.response import BaseResponse +from ..._core.pytest.config import TestConfig from .client import MQTTClient logger = logging.getLogger(__name__) @@ -26,7 +28,9 @@ class MQTTResponse(BaseResponse): - def __init__(self, client: MQTTClient, name, expected, test_block_config) -> None: + def __init__( + self, client: MQTTClient, name: str, expected, test_block_config: TestConfig + ) -> None: super().__init__(name, expected, test_block_config) self._client = client @@ -39,7 +43,7 @@ def __str__(self): else: return "" - def verify(self, response) -> dict: + def verify(self, response: requests.Response) -> Mapping: """Ensure mqtt message has arrived Args: @@ -53,11 +57,11 @@ def verify(self, response) -> dict: finally: self._client.unsubscribe_all() - def _await_response(self) -> dict: + def _await_response(self) -> Mapping: """Actually wait for response Returns: - dict: things to save to variables for the rest of this test + things to save to variables for the rest of this test """ # Get into class with metadata attached @@ -145,7 +149,7 @@ def _await_messages_on_topic( expected: expected response for this block Returns: - tuple(msg, list): The correct message (if any) and warnings from processing the message + The correct message (if any) and warnings from processing the message """ timeout = max(m.get("timeout", _default_timeout) for m in expected) diff --git a/tavern/_plugins/rest/request.py b/tavern/_plugins/rest/request.py index 8d81d773..9504e1a0 100644 --- a/tavern/_plugins/rest/request.py +++ b/tavern/_plugins/rest/request.py @@ -355,7 +355,7 @@ def __init__( Args: session: existing session rspec: test spec - test_block_config : Any configuration for this the block of + test_block_config: Any configuration for this the block of tests Raises: @@ -394,7 +394,7 @@ def __init__( ) # Used further down, but pop it asap to avoid unwanted side effects - file_body = request_args.pop("file_body", None) + file_body: Optional[str] = request_args.pop("file_body", None) # If there was a 'cookies' key, set it in the request expected_cookies = _read_expected_cookies(session, rspec, test_block_config) @@ -441,14 +441,11 @@ def prepared_request(): self._prepared = prepared_request - def run(self): + def run(self) -> requests.Response: """Runs the prepared request and times it - Todo: - time it - Returns: - requests.Response: response object + response object """ attach_yaml( diff --git a/tavern/response.py b/tavern/response.py index b9ffb86f..f1911570 100644 --- a/tavern/response.py +++ b/tavern/response.py @@ -1,3 +1,4 @@ +import dataclasses import logging import traceback from abc import abstractmethod @@ -20,27 +21,23 @@ def indent_err_text(err: str) -> str: return indent(err, " " * 4) +@dataclasses.dataclass(repr=True) class BaseResponse: - def __init__(self, name: str, expected, test_block_config: TestConfig) -> None: - # Stage name - self.name = name + name: str + test_block_config: TestConfig + expected: Any - # all errors in this response - self.errors: List[str] = [] + validate_functions: List[Any] = dataclasses.field(init=False, default_factory=list) + response: Optional[Any] = None + errors: List[str] = dataclasses.field(init=False, default_factory=list) - self.validate_functions: List = [] - self._check_for_validate_functions(expected) - - self.test_block_config = test_block_config - - self.expected = expected - - self.response: Optional[Any] = None + def __post_init__(self): + self._check_for_validate_functions(self.expected) def _str_errors(self) -> str: return "- " + "\n- ".join(self.errors) - def _adderr(self, msg, *args, e=None) -> None: + def _adderr(self, msg: str, *args, e=None) -> None: if e: logger.exception(msg, *args) else: @@ -151,14 +148,14 @@ def check_deprecated_validate(name): # Could put in an isinstance check here check_deprecated_validate("json") - def _maybe_run_validate_functions(self, response) -> None: + def _maybe_run_validate_functions(self, response: Any) -> None: """Run validation functions if available Note: 'response' will be different depending on where this is called from Args: - response (object): Response type. This could be whatever the response type/plugin uses. + response: Response type. This could be whatever the response type/plugin uses. """ logger.debug( "Calling ext function from '%s' with response '%s'", type(self), response @@ -177,7 +174,7 @@ def _maybe_run_validate_functions(self, response) -> None: def maybe_get_save_values_from_ext( self, response: Any, read_save_from: Mapping - ) -> dict: + ) -> Mapping: """If there is an $ext function in the save block, call it and save the response Args: @@ -227,12 +224,14 @@ def maybe_get_save_values_from_save_block( save_from: Optional[Mapping], *, outer_save_block: Optional[Mapping] = None, - ) -> dict: + ) -> Mapping: """Save a value from a specific block in the response. See docs for maybe_get_save_values_from_given_block for more info - Keyword Args: + Args: + key: Name of key being used to save, for debugging + save_from: An element of the response from which values are being saved outer_save_block: Read things to save from this block instead of self.expected """ @@ -254,7 +253,7 @@ def maybe_get_save_values_from_given_block( key: str, save_from: Optional[Mapping], to_save: Mapping, - ) -> dict: + ) -> Mapping: """Save a value from a specific block in the response. This is different from maybe_get_save_values_from_ext - depends on the kind of response @@ -265,7 +264,7 @@ def maybe_get_save_values_from_given_block( to_save: block containing information about things to save Returns: - dict: dictionary of save_name: value, where save_name is the key we + mapping of save_name: value, where save_name is the key we wanted to save this value as """ From 91f9c966f5e243bdaea58e2644042b321b311a99 Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sun, 21 Jan 2024 16:54:39 +0000 Subject: [PATCH 02/20] type infered --- tavern/_core/dict_util.py | 10 ++++++--- tavern/_core/extfunctions.py | 17 ++++++++------ tavern/_core/general.py | 6 ++--- tavern/_core/jmesutils.py | 2 +- tavern/_core/loader.py | 36 ++++++++++++++++-------------- tavern/_core/plugins.py | 2 +- tavern/_core/pytest/config.py | 2 +- tavern/_core/pytest/error.py | 6 ++--- tavern/_core/pytest/file.py | 32 +++++++++++++++++--------- tavern/_core/pytest/hooks.py | 8 ++++++- tavern/_core/pytest/item.py | 8 +++---- tavern/_core/pytest/newhooks.py | 29 +++++++++++++----------- tavern/_core/pytest/util.py | 14 +++++++++--- tavern/_core/report.py | 10 ++++++--- tavern/_core/run.py | 6 ++--- tavern/_core/schema/extensions.py | 4 +++- tavern/_core/schema/files.py | 2 +- tavern/_core/schema/jsonschema.py | 2 +- tavern/_core/stage_lines.py | 28 ++++++++++++++--------- tavern/_core/strict_util.py | 2 +- tavern/_core/testhelpers.py | 8 +++---- tavern/_core/tincture.py | 6 ++--- tavern/_plugins/grpc/client.py | 12 +++++----- tavern/_plugins/grpc/protos.py | 6 ++--- tavern/_plugins/grpc/request.py | 10 ++++----- tavern/_plugins/grpc/response.py | 6 ++--- tavern/_plugins/grpc/tavernhook.py | 8 ++++--- tavern/_plugins/mqtt/client.py | 13 ++++++----- tavern/_plugins/mqtt/request.py | 2 +- tavern/_plugins/mqtt/response.py | 12 ++++++---- tavern/_plugins/mqtt/tavernhook.py | 13 +++++++---- tavern/_plugins/rest/files.py | 2 +- tavern/_plugins/rest/request.py | 2 +- tavern/_plugins/rest/response.py | 7 ++++-- tavern/_plugins/rest/tavernhook.py | 8 +++++-- tavern/core.py | 3 ++- tavern/helpers.py | 22 +++++++++--------- tavern/request.py | 2 +- tavern/response.py | 4 ++-- 39 files changed, 222 insertions(+), 150 deletions(-) diff --git a/tavern/_core/dict_util.py b/tavern/_core/dict_util.py index a47f5939..0dd8dc76 100644 --- a/tavern/_core/dict_util.py +++ b/tavern/_core/dict_util.py @@ -4,6 +4,7 @@ import os import re import string +import typing from typing import Any, Dict, List, Mapping, Union import box @@ -22,7 +23,7 @@ from .formatted_str import FormattedString from .strict_util import StrictSetting, StrictSettingKinds, extract_strict_setting -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def _check_and_format_values(to_format, box_vars: Mapping[str, Any]) -> str: @@ -92,13 +93,16 @@ def _attempt_find_include(to_format: str, box_vars: box.Box): return formatter.convert_field(would_replace, conversion) # type: ignore +T = typing.TypeVar("T") + + def format_keys( - val, + val: T, variables: Mapping, *, no_double_format: bool = True, dangerously_ignore_string_format_errors: bool = False, -): +) -> T: """recursively format a dictionary with the given values Args: diff --git a/tavern/_core/extfunctions.py b/tavern/_core/extfunctions.py index ce2b972b..3018931a 100644 --- a/tavern/_core/extfunctions.py +++ b/tavern/_core/extfunctions.py @@ -1,7 +1,7 @@ import functools import importlib import logging -from typing import Any, List, Mapping, Optional +from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple from tavern._core import exceptions @@ -16,7 +16,7 @@ def is_ext_function(block: Any) -> bool: block: Any object Returns: - bool: If it is an ext function style dict + If it is an ext function style dict """ return isinstance(block, dict) and block.get("$ext", None) is not None @@ -29,17 +29,20 @@ def get_pykwalify_logger(module: Optional[str]) -> logging.Logger: trying to get the root logger which won't log correctly Args: - module (string): name of module to get logger for + module: name of module to get logger for + Returns: + logger for given module """ return logging.getLogger(module) def _getlogger() -> logging.Logger: + """Get logger for this module""" return get_pykwalify_logger("tavern._core.extfunctions") -def import_ext_function(entrypoint: str): +def import_ext_function(entrypoint: str) -> Callable: """Given a function name in the form of a setuptools entry point, try to dynamically load and return it @@ -79,7 +82,7 @@ def import_ext_function(entrypoint: str): return function -def get_wrapped_response_function(ext: Mapping): +def get_wrapped_response_function(ext: Mapping) -> Callable: """Wraps a ext function with arguments given in the test file This is similar to functools.wrap, but this makes sure that 'response' is @@ -106,7 +109,7 @@ def inner(response): return inner -def get_wrapped_create_function(ext: Mapping): +def get_wrapped_create_function(ext: Mapping) -> Callable: """Same as get_wrapped_response_function, but don't require a response""" func, args, kwargs = _get_ext_values(ext) @@ -122,7 +125,7 @@ def inner(): return inner -def _get_ext_values(ext: Mapping): +def _get_ext_values(ext: Mapping) -> Tuple[Callable, Iterable, Mapping]: if not isinstance(ext, Mapping): raise exceptions.InvalidExtFunctionError( f"ext block should be a dict, but it was a {type(ext)}" diff --git a/tavern/_core/general.py b/tavern/_core/general.py index 51984c90..25658cb6 100644 --- a/tavern/_core/general.py +++ b/tavern/_core/general.py @@ -1,15 +1,15 @@ import logging import os -from typing import List +from typing import List, Union from tavern._core.loader import load_single_document_yaml from .dict_util import deep_dict_merge -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def load_global_config(global_cfg_paths: List[os.PathLike]) -> dict: +def load_global_config(global_cfg_paths: List[Union[str, os.PathLike]]) -> dict: """Given a list of file paths to global config files, load each of them and return the joined dictionary. diff --git a/tavern/_core/jmesutils.py b/tavern/_core/jmesutils.py index 7940b8b1..35055548 100644 --- a/tavern/_core/jmesutils.py +++ b/tavern/_core/jmesutils.py @@ -61,7 +61,7 @@ def safe_length(var: Sized) -> int: return -1 -def validate_comparison(each_comparison): +def validate_comparison(each_comparison: Dict[Any, Any]): if extra := set(each_comparison.keys()) - {"jmespath", "operator", "expected"}: raise exceptions.BadSchemaError( "Invalid keys given to JMES validation function (got extra keys: {})".format( diff --git a/tavern/_core/loader.py b/tavern/_core/loader.py index a4ceccc7..2a0e53b5 100644 --- a/tavern/_core/loader.py +++ b/tavern/_core/loader.py @@ -10,8 +10,10 @@ import pytest import yaml +from _pytest.python_api import ApproxBase from yaml.composer import Composer from yaml.constructor import SafeConstructor +from yaml.nodes import Node, ScalarNode from yaml.parser import Parser from yaml.reader import Reader from yaml.resolver import Resolver @@ -21,17 +23,17 @@ from tavern._core.exceptions import BadSchemaError from tavern._core.strtobool import strtobool -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def makeuuid(loader, node): +def makeuuid(loader, node) -> str: return str(uuid.uuid4()) class RememberComposer(Composer): """A composer that doesn't forget anchors across documents""" - def compose_document(self): + def compose_document(self) -> Optional[Node]: # Drop the DOCUMENT-START event. self.get_event() @@ -140,7 +142,7 @@ def _get_include_dirs(loader): return chain(loader_list, IncludeLoader.env_path_list) -def find_include(loader, node): +def find_include(loader, node) -> str: """Locate an include file and return the abs path.""" for directory in _get_include_dirs(loader): filename = os.path.abspath( @@ -190,14 +192,14 @@ def constructor(_): raise NotImplementedError @classmethod - def from_yaml(cls, loader, node): + def from_yaml(cls, loader, node) -> "TypeSentinel": return cls() - def __str__(self): + def __str__(self) -> str: return f"" @classmethod - def to_yaml(cls, dumper, data): + def to_yaml(cls, dumper, data) -> ScalarNode: node = yaml.nodes.ScalarNode(cls.yaml_tag, "", style=cls.yaml_flow_style) return node @@ -242,7 +244,7 @@ class RegexSentinel(TypeSentinel): constructor = str compiled: re.Pattern - def __str__(self): + def __str__(self) -> str: return f"" @property @@ -254,28 +256,28 @@ def passes(self, string): raise NotImplementedError @classmethod - def from_yaml(cls, loader, node): + def from_yaml(cls, loader, node) -> "RegexSentinel": return cls(re.compile(node.value)) class _RegexMatchSentinel(RegexSentinel): yaml_tag = "!re_match" - def passes(self, string): + def passes(self, string) -> bool: return self.compiled.match(string) is not None class _RegexFullMatchSentinel(RegexSentinel): yaml_tag = "!re_fullmatch" - def passes(self, string): + def passes(self, string) -> bool: return self.compiled.fullmatch(string) is not None class _RegexSearchSentinel(RegexSentinel): yaml_tag = "!re_search" - def passes(self, string): + def passes(self, string) -> bool: return self.compiled.search(string) is not None @@ -321,7 +323,7 @@ class TypeConvertToken(yaml.YAMLObject): def constructor(_): raise NotImplementedError - def __init__(self, value): + def __init__(self, value) -> None: self.value = value @classmethod @@ -338,7 +340,7 @@ def from_yaml(cls, loader, node): return converted @classmethod - def to_yaml(cls, dumper, data): + def to_yaml(cls, dumper, data) -> ScalarNode: return yaml.nodes.ScalarNode( cls.yaml_tag, data.value, style=cls.yaml_flow_style ) @@ -357,7 +359,7 @@ class FloatToken(TypeConvertToken): class StrToBoolConstructor: """Using `bool` as a constructor directly will evaluate all strings to `True`.""" - def __new__(cls, s): + def __new__(cls, s: str) -> bool: return strtobool(s) @@ -407,7 +409,7 @@ class ApproxSentinel(yaml.YAMLObject, ApproxScalar): # type:ignore yaml_loader = IncludeLoader @classmethod - def from_yaml(cls, loader, node): + def from_yaml(cls, loader, node) -> ApproxBase: try: val = float(node.value) except (ValueError, TypeError) as e: @@ -420,7 +422,7 @@ def from_yaml(cls, loader, node): return pytest.approx(val) @classmethod - def to_yaml(cls, dumper, data): + def to_yaml(cls, dumper, data) -> ScalarNode: return yaml.nodes.ScalarNode( "!approx", str(data.expected), style=cls.yaml_flow_style ) diff --git a/tavern/_core/plugins.py b/tavern/_core/plugins.py index b296be78..ff724797 100644 --- a/tavern/_core/plugins.py +++ b/tavern/_core/plugins.py @@ -16,7 +16,7 @@ from tavern.request import BaseRequest from tavern.response import BaseResponse -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class PluginHelperBase: diff --git a/tavern/_core/pytest/config.py b/tavern/_core/pytest/config.py index 2e33620e..51b0b65e 100644 --- a/tavern/_core/pytest/config.py +++ b/tavern/_core/pytest/config.py @@ -6,7 +6,7 @@ from tavern._core.strict_util import StrictLevel -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) @dataclasses.dataclass(frozen=True) diff --git a/tavern/_core/pytest/error.py b/tavern/_core/pytest/error.py index 0c7b65be..e7b43188 100644 --- a/tavern/_core/pytest/error.py +++ b/tavern/_core/pytest/error.py @@ -5,7 +5,7 @@ from typing import List, Mapping, Optional import yaml -from _pytest._code.code import FormattedExcinfo +from _pytest._code.code import FormattedExcinfo, TerminalRepr from _pytest._io import TerminalWriter from tavern._core import exceptions @@ -18,10 +18,10 @@ start_mark, ) -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -class ReprdError: +class ReprdError(TerminalRepr): def __init__(self, exce, item) -> None: self.exce = exce self.item = item diff --git a/tavern/_core/pytest/file.py b/tavern/_core/pytest/file.py index 49bbbf61..7fbaf509 100644 --- a/tavern/_core/pytest/file.py +++ b/tavern/_core/pytest/file.py @@ -2,11 +2,12 @@ import functools import itertools import logging -from typing import Dict, Iterator, List, Mapping +from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Tuple, Union import pytest import yaml from box import Box +from pytest import Mark from tavern._core import exceptions from tavern._core.dict_util import deep_dict_merge, format_keys, get_tavern_box @@ -17,20 +18,24 @@ from .item import YamlItem from .util import load_global_cfg -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -_format_without_inner = functools.partial(format_keys, no_double_format=False) +_format_without_inner: Callable[[Any, Mapping], Any] = functools.partial( + format_keys, no_double_format=False +) -def _format_test_marks(original_marks, fmt_vars, test_name): +def _format_test_marks( + original_marks: Iterable[Union[str, dict]], fmt_vars: Mapping, test_name: str +) -> Tuple[List[Mark], List[Mapping]]: """Given the 'raw' marks from the test and any available format variables, generate new marks for this test Args: - original_marks (list): Raw string from test - should correspond to either a + original_marks: Raw string from test - should correspond to either a pytest builtin mark or a custom user mark - fmt_vars (dict): dictionary containing available format variables - test_name (str): Name of test (for error logging) + fmt_vars: dictionary containing available format variables + test_name: Name of test (for error logging) Returns: tuple: first element is normal pytest mark objects, second element is all @@ -86,12 +91,17 @@ def _format_test_marks(original_marks, fmt_vars, test_name): return pytest_marks, formatted_marks -def _generate_parametrized_test_items(keys: List, vals_combination): +def _generate_parametrized_test_items( + keys: Iterable[Union[str, List, Tuple]], vals_combination: Iterable[Tuple[str, str]] +) -> Tuple[Mapping[str, Any], str]: """Generate test name from given key(s)/value(s) combination Args: keys: list of keys to format name with - vals_combination (tuple(str)): this combination of values for the key + vals_combination this combination of values for the key + + Returns: + tuple of the variables for the stage and the generated stage name """ flattened_values = [] variables = {} @@ -205,7 +215,7 @@ def unwrap_map(value): "Invalid match between numbers of keys and number of values in parametrize mark" ) from e - keys = [i["parametrize"]["key"] for i in parametrize_marks] + keys: List[str] = [i["parametrize"]["key"] for i in parametrize_marks] for vals_combination in combined: logger.debug("Generating test for %s/%s", keys, vals_combination) @@ -346,7 +356,7 @@ def collect(self) -> Iterator[YamlItem]: try: # Convert to a list so we can catch parser exceptions - all_tests = list( + all_tests: Iterable[dict] = list( yaml.load_all( self.path.open(encoding="utf-8"), Loader=IncludeLoader, # type:ignore diff --git a/tavern/_core/pytest/hooks.py b/tavern/_core/pytest/hooks.py index 9a83ee9e..5fc962c3 100644 --- a/tavern/_core/pytest/hooks.py +++ b/tavern/_core/pytest/hooks.py @@ -3,7 +3,13 @@ import os import pathlib import re +import typing from textwrap import dedent +from typing import Optional + +if typing.TYPE_CHECKING: + from .file import YamlFile + import pytest import yaml @@ -18,7 +24,7 @@ def pytest_addoption(parser: pytest.Parser) -> None: add_ini_options(parser) -def pytest_collect_file(parent, path: os.PathLike): +def pytest_collect_file(parent, path: os.PathLike) -> Optional["YamlFile"]: """On collecting files, get any files that end in .tavern.yaml or .tavern.yml as tavern test files """ diff --git a/tavern/_core/pytest/item.py b/tavern/_core/pytest/item.py index 2c073e8a..18106c85 100644 --- a/tavern/_core/pytest/item.py +++ b/tavern/_core/pytest/item.py @@ -1,11 +1,11 @@ import logging import pathlib -from typing import MutableMapping, Optional, Tuple +from typing import MutableMapping, Optional, Tuple, Union import attr import pytest import yaml -from _pytest._code.code import ExceptionInfo +from _pytest._code.code import ExceptionInfo, TerminalRepr from _pytest.nodes import Node from tavern._core import exceptions @@ -20,7 +20,7 @@ from .config import TestConfig from .util import load_global_cfg -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class YamlItem(pytest.Item): @@ -253,7 +253,7 @@ def runtest(self) -> None: def repr_failure( self, excinfo: ExceptionInfo[BaseException], style: Optional[str] = None - ): + ) -> Union[TerminalRepr, str, ReprdError]: """called when self.runtest() raises an exception. By default, will raise a custom formatted traceback if it's a tavern error. if not, will use the default diff --git a/tavern/_core/pytest/newhooks.py b/tavern/_core/pytest/newhooks.py index 8fd494b2..8f994b07 100644 --- a/tavern/_core/pytest/newhooks.py +++ b/tavern/_core/pytest/newhooks.py @@ -1,9 +1,12 @@ import logging +from typing import Any, Dict, MutableMapping -logger = logging.getLogger(__name__) +from tavern._core.pytest.config import TestConfig +logger: logging.Logger = logging.getLogger(__name__) -def pytest_tavern_beta_before_every_test_run(test_dict, variables) -> None: + +def pytest_tavern_beta_before_every_test_run(test_dict: Dict, variables: Dict) -> None: """Called: - directly after fixtures are loaded for a test @@ -15,23 +18,23 @@ def pytest_tavern_beta_before_every_test_run(test_dict, variables) -> None: Modify the test in-place if you want to do something to it. Args: - test_dict (dict): Test to run - variables (dict): Available variables + test_dict: Test to run + variables: Available variables """ -def pytest_tavern_beta_after_every_test_run(test_dict, variables) -> None: +def pytest_tavern_beta_after_every_test_run(test_dict: Dict, variables: Dict) -> None: """Called: - After test run Args: - test_dict (dict): Test to run - variables (dict): Available variables + test_dict: Test to run + variables: Available variables """ -def pytest_tavern_beta_after_every_response(expected, response) -> None: +def pytest_tavern_beta_after_every_response(expected: Any, response: Dict) -> None: """Called after every _response_ - including MQTT/HTTP/etc Note: @@ -39,24 +42,24 @@ def pytest_tavern_beta_after_every_response(expected, response) -> None: - MQTT responses will call this hook multiple times if multiple messages are received Args: - response (object): Response object. - expected (dict): Response block in stage + response: Response object. + expected: Response block in stage """ -def pytest_tavern_beta_before_every_request(request_args) -> None: +def pytest_tavern_beta_before_every_request(request_args: MutableMapping) -> None: """Called just before every request - including MQTT/HTTP/etc Note: - The request object type depends on what plugin you're using, and which kind of request it is! Args: - request_args (dict): Arguments passed to the request function. By default, this is Session.request for + request_args: Arguments passed to the request function. By default, this is Session.request for HTTP and Client.publish for MQTT """ -def call_hook(test_block_config, hookname, **kwargs) -> None: +def call_hook(test_block_config: TestConfig, hookname: str, **kwargs) -> None: """Utility to call the hooks""" try: hook = getattr(test_block_config.tavern_internal.pytest_hook_caller, hookname) diff --git a/tavern/_core/pytest/util.py b/tavern/_core/pytest/util.py index 30343114..b0595ae9 100644 --- a/tavern/_core/pytest/util.py +++ b/tavern/_core/pytest/util.py @@ -1,6 +1,7 @@ import logging from functools import lru_cache -from typing import Any, Dict +from pathlib import Path +from typing import Any, Dict, Optional, TypeVar, Union import pytest @@ -9,7 +10,7 @@ from tavern._core.pytest.config import TavernInternalConfig, TestConfig from tavern._core.strict_util import StrictLevel -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def add_parser_options(parser_addoption, with_defaults: bool = True) -> None: @@ -199,7 +200,14 @@ def _load_global_follow_redirects(pytest_config: pytest.Config): return get_option_generic(pytest_config, "tavern-always-follow-redirects", False) -def get_option_generic(pytest_config: pytest.Config, flag: str, default): +T = TypeVar("T", bound=Optional[Union[str, list, list[Path], list[str], bool]]) + + +def get_option_generic( + pytest_config: pytest.Config, + flag: str, + default: T, +) -> T: """Get a configuration option or return the default Priority order is cmdline, then ini, then default""" diff --git a/tavern/_core/report.py b/tavern/_core/report.py index 432d9abe..0962029b 100644 --- a/tavern/_core/report.py +++ b/tavern/_core/report.py @@ -1,5 +1,6 @@ import logging from textwrap import dedent +from typing import Dict, Iterable, TypeVar, Union import yaml @@ -24,10 +25,13 @@ def call(step_func): from tavern._core.formatted_str import FormattedString from tavern._core.stage_lines import get_stage_lines, read_relevant_lines -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def prepare_yaml(val): +T = TypeVar("T", bound=Union[Dict, Iterable, str]) + + +def prepare_yaml(val: T) -> T: """Sanitises the formatted string into a format safe for dumping""" formatted = val @@ -54,7 +58,7 @@ def attach_stage_content(stage) -> None: attach_text(joined, "stage_yaml", yaml_type) -def attach_yaml(payload, name): +def attach_yaml(payload, name) -> None: prepared = prepare_yaml(payload) dumped = yaml.safe_dump(prepared) return attach_text(dumped, name, yaml_type) diff --git a/tavern/_core/run.py b/tavern/_core/run.py index be6bb984..83229bbd 100644 --- a/tavern/_core/run.py +++ b/tavern/_core/run.py @@ -27,7 +27,7 @@ from .testhelpers import delay, retry from .tincture import Tinctures, get_stage_tinctures -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def _resolve_test_stages(test_spec: Mapping, available_stages: Mapping): @@ -276,7 +276,7 @@ class _TestRunner: test_block_config: TestConfig test_spec: Mapping - def run_stage(self, idx: int, stage, *, is_final: bool = False): + def run_stage(self, idx: int, stage, *, is_final: bool = False) -> None: tinctures = get_stage_tinctures(stage, self.test_spec) stage_config = self.test_block_config.with_strictness( @@ -305,7 +305,7 @@ def run_stage(self, idx: int, stage, *, is_final: bool = False): def wrapped_run_stage( self, stage: dict, stage_config: TestConfig, tinctures: Tinctures - ): + ) -> None: """Run one stage from the test Args: diff --git a/tavern/_core/schema/extensions.py b/tavern/_core/schema/extensions.py index 6d178e53..62da2088 100644 --- a/tavern/_core/schema/extensions.py +++ b/tavern/_core/schema/extensions.py @@ -132,7 +132,9 @@ def check_usefixtures(value, rule_obj, path) -> bool: return True -def validate_grpc_status_is_valid_or_list_of_names(value: "GRPCCode", rule_obj, path): +def validate_grpc_status_is_valid_or_list_of_names( + value: "GRPCCode", rule_obj, path +) -> bool: """Validate GRPC statuses https://github.com/grpc/grpc/blob/master/doc/statuscodes.md""" # pylint: disable=unused-argument err_msg = ( diff --git a/tavern/_core/schema/files.py b/tavern/_core/schema/files.py index 8f801d6d..1853de3b 100644 --- a/tavern/_core/schema/files.py +++ b/tavern/_core/schema/files.py @@ -14,7 +14,7 @@ from tavern._core.plugins import load_plugins from tavern._core.schema.jsonschema import verify_jsonschema -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class SchemaCache: diff --git a/tavern/_core/schema/jsonschema.py b/tavern/_core/schema/jsonschema.py index 022f8af8..c439a732 100644 --- a/tavern/_core/schema/jsonschema.py +++ b/tavern/_core/schema/jsonschema.py @@ -35,7 +35,7 @@ read_relevant_lines, ) -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def is_str_or_bytes_or_token(checker, instance): diff --git a/tavern/_core/stage_lines.py b/tavern/_core/stage_lines.py index 0f943a54..31359438 100644 --- a/tavern/_core/stage_lines.py +++ b/tavern/_core/stage_lines.py @@ -1,9 +1,20 @@ import logging +from typing import Optional, Protocol -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def get_stage_lines(stage): +class EmptyBlock: + line: int = 0 + name: Optional[str] = None + + +class _WithMarks(Protocol): + start_mark: EmptyBlock + end_mark: EmptyBlock + + +def get_stage_lines(stage: _WithMarks): first_line = start_mark(stage).line - 1 last_line = end_mark(stage).line line_start = first_line + 1 @@ -11,7 +22,7 @@ def get_stage_lines(stage): return first_line, last_line, line_start -def read_relevant_lines(yaml_block, first_line, last_line): +def read_relevant_lines(yaml_block: _WithMarks, first_line: int, last_line: int): """Get lines between start and end mark""" filename = get_stage_filename(yaml_block) @@ -26,23 +37,18 @@ def read_relevant_lines(yaml_block, first_line, last_line): yield line.split("#", 1)[0].rstrip() -def get_stage_filename(yaml_block): +def get_stage_filename(yaml_block: _WithMarks) -> str: return start_mark(yaml_block).name -class EmptyBlock: - line = 0 - name = None - - -def start_mark(yaml_block): +def start_mark(yaml_block: _WithMarks): try: return yaml_block.start_mark except AttributeError: return EmptyBlock -def end_mark(yaml_block): +def end_mark(yaml_block: _WithMarks): try: return yaml_block.end_mark except AttributeError: diff --git a/tavern/_core/strict_util.py b/tavern/_core/strict_util.py index 76d090d8..bc7c63ee 100644 --- a/tavern/_core/strict_util.py +++ b/tavern/_core/strict_util.py @@ -7,7 +7,7 @@ from tavern._core import exceptions from tavern._core.strtobool import strtobool -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class StrictSetting(enum.Enum): diff --git a/tavern/_core/testhelpers.py b/tavern/_core/testhelpers.py index 2c72b49f..062ae4e5 100644 --- a/tavern/_core/testhelpers.py +++ b/tavern/_core/testhelpers.py @@ -1,13 +1,13 @@ import logging import time from functools import wraps -from typing import Mapping +from typing import Callable, Mapping from tavern._core import exceptions from tavern._core.dict_util import format_keys from tavern._core.pytest.config import TestConfig -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def delay(stage: Mapping, when: str, variables: Mapping) -> None: @@ -28,7 +28,7 @@ def delay(stage: Mapping, when: str, variables: Mapping) -> None: time.sleep(length) -def retry(stage: Mapping, test_block_config: TestConfig): +def retry(stage: Mapping, test_block_config: TestConfig) -> Callable: """Look for retry and try to repeat the stage `retry` times. Args: @@ -101,7 +101,7 @@ def wrapped(*args, **kwargs): return retry_wrapper -def maybe_format_max_retries(max_retries, test_block_config: TestConfig) -> int: +def maybe_format_max_retries(max_retries: int, test_block_config: TestConfig) -> int: """Possibly handle max_retries validation""" # Probably a format variable, or just invalid (in which case it will fail further down) diff --git a/tavern/_core/tincture.py b/tavern/_core/tincture.py index f1fe746f..e99fe2eb 100644 --- a/tavern/_core/tincture.py +++ b/tavern/_core/tincture.py @@ -6,15 +6,15 @@ from tavern._core import exceptions from tavern._core.extfunctions import get_wrapped_response_function -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class Tinctures: - def __init__(self, tinctures: List[Any]): + def __init__(self, tinctures: List[Any]) -> None: self._tinctures = tinctures self._needs_response: List[Any] = [] - def start_tinctures(self, stage: collections.abc.Mapping): + def start_tinctures(self, stage: collections.abc.Mapping) -> None: results = [t(stage) for t in self._tinctures] self._needs_response = [] diff --git a/tavern/_plugins/grpc/client.py b/tavern/_plugins/grpc/client.py index 04d7b428..85b433d3 100644 --- a/tavern/_plugins/grpc/client.py +++ b/tavern/_plugins/grpc/client.py @@ -22,7 +22,7 @@ from tavern._core.dict_util import check_expected_keys from tavern._plugins.grpc.protos import _generate_proto_import, _import_grpc_module -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -39,7 +39,7 @@ class _ChannelVals: class GRPCClient: - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: logger.debug("Initialising GRPC client with %s", kwargs) expected_blocks = { "connect": {"host", "port", "options", "timeout", "secure"}, @@ -99,7 +99,7 @@ def __init__(self, **kwargs): def _register_file_descriptor( self, service_proto: grpc_reflection.v1alpha.reflection_pb2.FileDescriptorResponse, - ): + ) -> None: for file_descriptor_proto in service_proto.file_descriptor_proto: descriptor = descriptor_pb2.FileDescriptorProto() descriptor.ParseFromString(file_descriptor_proto) @@ -107,7 +107,7 @@ def _register_file_descriptor( def _get_reflection_info( self, channel, service_name: Optional[str] = None, file_by_filename=None - ): + ) -> None: logger.debug( "Getting GRPC protobuf for service %s from reflection", service_name ) @@ -239,7 +239,7 @@ def _make_call_request( return self._get_grpc_service(channel, service, method) - def __enter__(self): + def __enter__(self) -> None: logger.debug("Connecting to GRPC") def call( @@ -282,7 +282,7 @@ def call( request, metadata=self._metadata, timeout=timeout ) - def __exit__(self, *args): + def __exit__(self, *args) -> None: logger.debug("Disconnecting from GRPC") for v in self.channels.values(): v.close() diff --git a/tavern/_plugins/grpc/protos.py b/tavern/_plugins/grpc/protos.py index b70aa238..e34503b2 100644 --- a/tavern/_plugins/grpc/protos.py +++ b/tavern/_plugins/grpc/protos.py @@ -13,7 +13,7 @@ from tavern._core import exceptions -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) @functools.lru_cache @@ -31,7 +31,7 @@ def find_protoc() -> str: @functools.lru_cache -def _generate_proto_import(source: str): +def _generate_proto_import(source: str) -> None: """Invokes the Protocol Compiler to generate a _pb2.py from the given .proto file. Does nothing if the output already exists and is newer than the input. @@ -101,7 +101,7 @@ def sanitise(s): _import_grpc_module(output) -def _import_grpc_module(python_module_name: str): +def _import_grpc_module(python_module_name: str) -> None: """takes an expected python module name and tries to import the relevant file, adding service to the symbol database. """ diff --git a/tavern/_plugins/grpc/request.py b/tavern/_plugins/grpc/request.py index 6fe311ba..727e565a 100644 --- a/tavern/_plugins/grpc/request.py +++ b/tavern/_plugins/grpc/request.py @@ -3,7 +3,7 @@ import json import logging import warnings -from typing import Mapping, Union +from typing import Dict, Mapping, Union import grpc from box import Box @@ -14,10 +14,10 @@ from tavern._plugins.grpc.client import GRPCClient from tavern.request import BaseRequest -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def get_grpc_args(rspec, test_block_config): +def get_grpc_args(rspec: Dict, test_block_config: TestConfig) -> Dict: """Format GRPC request args""" fspec = format_keys(rspec, test_block_config.variables) @@ -51,7 +51,7 @@ class GRPCRequest(BaseRequest): def __init__( self, client: GRPCClient, request_spec: Mapping, test_block_config: TestConfig - ): + ) -> None: if not self._warned: warnings.warn( "Tavern gRPC support is experimental and will be updated in a future release.", @@ -87,5 +87,5 @@ def run(self) -> WrappedFuture: raise exceptions.GRPCRequestException from e @property - def request_vars(self): + def request_vars(self) -> Box: return Box(self._original_request_vars) diff --git a/tavern/_plugins/grpc/response.py b/tavern/_plugins/grpc/response.py index 8fbc5291..436c3bd2 100644 --- a/tavern/_plugins/grpc/response.py +++ b/tavern/_plugins/grpc/response.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from tavern._plugins.grpc.request import WrappedFuture -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) GRPCCode = Union[str, int, List[str], List[int]] @@ -48,7 +48,7 @@ def __init__( name: str, expected: Union[_GRPCExpected, Mapping], test_block_config: TestConfig, - ): + ) -> None: check_expected_keys({"body", "status", "details"}, expected) super().__init__(name, expected, test_block_config) @@ -60,7 +60,7 @@ def __str__(self): else: return "" - def _validate_block(self, blockname: str, block: Mapping): + def _validate_block(self, blockname: str, block: Mapping) -> None: """Validate a block of the response Args: diff --git a/tavern/_plugins/grpc/tavernhook.py b/tavern/_plugins/grpc/tavernhook.py index 56a4f3ff..29ec38df 100644 --- a/tavern/_plugins/grpc/tavernhook.py +++ b/tavern/_plugins/grpc/tavernhook.py @@ -10,7 +10,7 @@ from .request import GRPCRequest from .response import GRPCResponse -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) session_type = GRPCClient @@ -19,7 +19,9 @@ request_block_name = "grpc_request" -def get_expected_from_request(response_block, test_block_config: TestConfig, session): +def get_expected_from_request( + response_block, test_block_config: TestConfig, session: GRPCClient +): f_expected = format_keys(response_block, test_block_config.variables) expected = f_expected @@ -29,6 +31,6 @@ def get_expected_from_request(response_block, test_block_config: TestConfig, ses verifier_type = GRPCResponse response_block_name = "grpc_response" -schema_path = join(abspath(dirname(__file__)), "jsonschema.yaml") +schema_path: str = join(abspath(dirname(__file__)), "jsonschema.yaml") with open(schema_path) as schema_file: schema = yaml.load(schema_file, Loader=yaml.SafeLoader) diff --git a/tavern/_plugins/mqtt/client.py b/tavern/_plugins/mqtt/client.py index 06095728..23583515 100644 --- a/tavern/_plugins/mqtt/client.py +++ b/tavern/_plugins/mqtt/client.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Mapping, MutableMapping, Optional import paho.mqtt.client as paho +from paho.mqtt.client import MQTTMessageInfo from tavern._core import exceptions from tavern._core.dict_util import check_expected_keys @@ -33,7 +34,7 @@ 15: "MQTT_ERR_QUEUE_SIZE", } -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) @dataclasses.dataclass @@ -88,7 +89,7 @@ def _handle_ssl_context_args( def _check_and_update_common_tls_args( tls_args: MutableMapping, check_file_keys: List[str] -): +) -> None: """Checks common args between ssl/tls args""" # could be moved to schema validation stage @@ -349,7 +350,9 @@ def _on_socket_open(client, userdata, socket) -> None: def _on_socket_close(client, userdata, socket) -> None: logger.debug("MQTT socket closed") - def message_received(self, topic: str, timeout: int = 1): + def message_received( + self, topic: str, timeout: int = 1 + ) -> Optional[paho.MQTTMessage]: """Check that a message is in the message queue Args: @@ -358,7 +361,7 @@ def message_received(self, topic: str, timeout: int = 1): was not received. Returns: - paho.MQTTMessage: whether the message was received within the timeout + the message, if one was received, otherwise None Todo: Allow regexes for topic names? Better validation for mqtt payloads @@ -384,7 +387,7 @@ def publish( payload: Any, qos: Optional[int], retain: Optional[bool] = False, - ): + ) -> MQTTMessageInfo: """publish message using paho library""" self._wait_for_subscriptions() diff --git a/tavern/_plugins/mqtt/request.py b/tavern/_plugins/mqtt/request.py index 0a9de87a..6f1705fd 100644 --- a/tavern/_plugins/mqtt/request.py +++ b/tavern/_plugins/mqtt/request.py @@ -13,7 +13,7 @@ from tavern._plugins.mqtt.client import MQTTClient from tavern.request import BaseRequest -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def get_publish_args(rspec: Dict, test_block_config: TestConfig) -> dict: diff --git a/tavern/_plugins/mqtt/response.py b/tavern/_plugins/mqtt/response.py index 020d0945..fe6e94f0 100644 --- a/tavern/_plugins/mqtt/response.py +++ b/tavern/_plugins/mqtt/response.py @@ -14,22 +14,26 @@ from tavern._core import exceptions from tavern._core.dict_util import check_keys_match_recursive from tavern._core.loader import ANYTHING +from tavern._core.pytest.config import TestConfig from tavern._core.pytest.newhooks import call_hook from tavern._core.report import attach_yaml from tavern._core.strict_util import StrictSetting from tavern.response import BaseResponse -from ..._core.pytest.config import TestConfig from .client import MQTTClient -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) _default_timeout = 1 class MQTTResponse(BaseResponse): def __init__( - self, client: MQTTClient, name: str, expected, test_block_config: TestConfig + self, + client: MQTTClient, + name: str, + expected: TestConfig, + test_block_config: TestConfig, ) -> None: super().__init__(name, expected, test_block_config) @@ -326,7 +330,7 @@ def _get_payload_vals(expected: Mapping) -> Tuple[Optional[Union[str, dict]], bo """Gets the payload from the 'expected' block Returns: - tuple: First element is the expected payload, second element is whether it's + First element is the expected payload, second element is whether it's expected to be json or not """ # TODO move this check to initialisation/schema checking diff --git a/tavern/_plugins/mqtt/tavernhook.py b/tavern/_plugins/mqtt/tavernhook.py index 9ff57245..385f573a 100644 --- a/tavern/_plugins/mqtt/tavernhook.py +++ b/tavern/_plugins/mqtt/tavernhook.py @@ -1,16 +1,17 @@ import logging from os.path import abspath, dirname, join -from typing import Dict, Optional +from typing import Dict, Iterable, Optional, Union import yaml from tavern._core.dict_util import format_keys +from tavern._core.pytest.config import TestConfig from .client import MQTTClient from .request import MQTTRequest from .response import MQTTResponse -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) session_type = MQTTClient @@ -18,7 +19,11 @@ request_block_name = "mqtt_publish" -def get_expected_from_request(response_block, test_block_config, session): +def get_expected_from_request( + response_block: Union[Dict, Iterable[Dict]], + test_block_config: TestConfig, + session: MQTTClient, +): expected: Optional[Dict] = None # mqtt response is not required @@ -40,6 +45,6 @@ def get_expected_from_request(response_block, test_block_config, session): verifier_type = MQTTResponse response_block_name = "mqtt_response" -schema_path = join(abspath(dirname(__file__)), "jsonschema.yaml") +schema_path: str = join(abspath(dirname(__file__)), "jsonschema.yaml") with open(schema_path, encoding="utf-8") as schema_file: schema = yaml.load(schema_file, Loader=yaml.SafeLoader) diff --git a/tavern/_plugins/rest/files.py b/tavern/_plugins/rest/files.py index 3f72db8a..8aba5b9e 100644 --- a/tavern/_plugins/rest/files.py +++ b/tavern/_plugins/rest/files.py @@ -9,7 +9,7 @@ from tavern._core.dict_util import format_keys from tavern._core.pytest.config import TestConfig -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) @dataclasses.dataclass diff --git a/tavern/_plugins/rest/request.py b/tavern/_plugins/rest/request.py index 9504e1a0..cc127b0e 100644 --- a/tavern/_plugins/rest/request.py +++ b/tavern/_plugins/rest/request.py @@ -21,7 +21,7 @@ from tavern._plugins.rest.files import get_file_arguments, guess_filespec from tavern.request import BaseRequest -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def get_request_args(rspec: MutableMapping, test_block_config: TestConfig) -> dict: diff --git a/tavern/_plugins/rest/response.py b/tavern/_plugins/rest/response.py index 83b1062d..0466e645 100644 --- a/tavern/_plugins/rest/response.py +++ b/tavern/_plugins/rest/response.py @@ -9,15 +9,18 @@ from tavern._core import exceptions from tavern._core.dict_util import deep_dict_merge +from tavern._core.pytest.config import TestConfig from tavern._core.pytest.newhooks import call_hook from tavern._core.report import attach_yaml from tavern.response import BaseResponse, indent_err_text -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class RestResponse(BaseResponse): - def __init__(self, session, name: str, expected, test_block_config) -> None: + def __init__( + self, session, name: str, expected, test_block_config: TestConfig + ) -> None: defaults = {"status_code": 200} super().__init__(name, deep_dict_merge(defaults, expected), test_block_config) diff --git a/tavern/_plugins/rest/tavernhook.py b/tavern/_plugins/rest/tavernhook.py index 208e32d3..3351559a 100644 --- a/tavern/_plugins/rest/tavernhook.py +++ b/tavern/_plugins/rest/tavernhook.py @@ -1,4 +1,5 @@ import logging +from typing import Mapping import requests @@ -6,10 +7,11 @@ from tavern._core.dict_util import format_keys from tavern._core.plugins import PluginHelperBase +from ..._core.pytest.config import TestConfig from .request import RestRequest from .response import RestResponse -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class TavernRestPlugin(PluginHelperBase): @@ -19,7 +21,9 @@ class TavernRestPlugin(PluginHelperBase): request_block_name = "request" @staticmethod - def get_expected_from_request(response_block, test_block_config, session): + def get_expected_from_request( + response_block: Mapping, test_block_config: TestConfig, session + ): if response_block is None: raise exceptions.MissingSettingsError( "no response block specified for HTTP test stage" diff --git a/tavern/core.py b/tavern/core.py index da20f273..dcf28251 100644 --- a/tavern/core.py +++ b/tavern/core.py @@ -3,6 +3,7 @@ from typing import Union import pytest +from _pytest.config import ExitCode from tavern._core import exceptions from tavern._core.schema.files import wrapfile @@ -56,7 +57,7 @@ def run( tavern_grpc_backend=None, tavern_strict=None, pytest_args=None, -): +) -> Union[ExitCode, int]: """Run all tests contained in a file using pytest.main() Args: diff --git a/tavern/helpers.py b/tavern/helpers.py index b4204b65..aabe2c7a 100644 --- a/tavern/helpers.py +++ b/tavern/helpers.py @@ -2,7 +2,7 @@ import json import logging import re -from typing import Dict, List, Optional +from typing import Dict, Iterable, Mapping, Optional import jmespath import jwt @@ -14,7 +14,7 @@ from tavern._core.jmesutils import actual_validation, validate_comparison from tavern._core.schema.files import verify_pykwalify -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def check_exception_raised( @@ -68,7 +68,9 @@ def check_exception_raised( ) from e -def validate_jwt(response, jwt_key, **kwargs) -> Dict[str, Box]: +def validate_jwt( + response: requests.Response, jwt_key: str, **kwargs +) -> Mapping[str, Box]: """Make sure a jwt is valid This uses the pyjwt library to decode the jwt, so any keyword args needed @@ -80,12 +82,12 @@ def validate_jwt(response, jwt_key, **kwargs) -> Dict[str, Box]: it wraps this in a Box so it can also be used for future formatting Args: - response (Response): requests.Response object - jwt_key (str): key of jwt in body of request + response: requests.Response object + jwt_key: key of jwt in body of request **kwargs: Any extra arguments to pass to jwt.decode Returns: - dict: dictionary of jwt: boxed jwt claims + mapping of jwt: boxed jwt claims """ token = response.json()[jwt_key] @@ -96,12 +98,12 @@ def validate_jwt(response, jwt_key, **kwargs) -> Dict[str, Box]: return {"jwt": Box(decoded)} -def validate_pykwalify(response, schema) -> None: +def validate_pykwalify(response: requests.Response, schema: Dict) -> None: """Make sure the response matches a given schema Args: - response (Response): reqeusts.Response object - schema (dict): Schema for response + response: reqeusts Response object + schema: Schema for response """ try: to_verify = response.json() @@ -171,7 +173,7 @@ def validate_regex( return {"regex": Box(match.groupdict())} -def validate_content(response: requests.Response, comparisons: List[str]) -> None: +def validate_content(response: requests.Response, comparisons: Iterable[Dict]) -> None: """Asserts expected value with actual value using JMES path expression Args: diff --git a/tavern/request.py b/tavern/request.py index a049c8eb..abfa7ded 100644 --- a/tavern/request.py +++ b/tavern/request.py @@ -6,7 +6,7 @@ from tavern._core.pytest.config import TestConfig -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class BaseRequest: diff --git a/tavern/response.py b/tavern/response.py index f1911570..4653ee45 100644 --- a/tavern/response.py +++ b/tavern/response.py @@ -12,7 +12,7 @@ from tavern._core.pytest.config import TestConfig from tavern._core.strict_util import StrictOption -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def indent_err_text(err: str) -> str: @@ -31,7 +31,7 @@ class BaseResponse: response: Optional[Any] = None errors: List[str] = dataclasses.field(init=False, default_factory=list) - def __post_init__(self): + def __post_init__(self) -> None: self._check_for_validate_functions(self.expected) def _str_errors(self) -> str: From e4f2dbd6130ceb67692b31f96b467d201634b189 Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sun, 21 Jan 2024 17:42:38 +0000 Subject: [PATCH 03/20] fix some issues --- tavern/_core/dict_util.py | 4 +++- tavern/_core/pytest/file.py | 5 ++++- tavern/_core/report.py | 17 ++++++----------- tavern/_core/testhelpers.py | 2 +- tavern/_plugins/grpc/request.py | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tavern/_core/dict_util.py b/tavern/_core/dict_util.py index 0dd8dc76..46984a7a 100644 --- a/tavern/_core/dict_util.py +++ b/tavern/_core/dict_util.py @@ -134,8 +134,10 @@ def format_keys( # formatted = {key: format_keys(val[key], box_vars) for key in val} for key in val: formatted[key] = format_keys_(val[key], box_vars) + + return formatted elif isinstance(val, (list, tuple)): - formatted = [format_keys_(item, box_vars) for item in val] # type: ignore + return [format_keys_(item, box_vars) for item in val] # type: ignore elif isinstance(formatted, FormattedString): logger.debug("Already formatted %s, not double-formatting", formatted) elif isinstance(val, str): diff --git a/tavern/_core/pytest/file.py b/tavern/_core/pytest/file.py index 7fbaf509..45fa426b 100644 --- a/tavern/_core/pytest/file.py +++ b/tavern/_core/pytest/file.py @@ -2,6 +2,7 @@ import functools import itertools import logging +import typing from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Tuple, Union import pytest @@ -20,7 +21,9 @@ logger: logging.Logger = logging.getLogger(__name__) -_format_without_inner: Callable[[Any, Mapping], Any] = functools.partial( +T = typing.TypeVar("T") + +_format_without_inner: Callable[[T, Mapping], T] = functools.partial( format_keys, no_double_format=False ) diff --git a/tavern/_core/report.py b/tavern/_core/report.py index 0962029b..17991eb6 100644 --- a/tavern/_core/report.py +++ b/tavern/_core/report.py @@ -1,6 +1,6 @@ import logging from textwrap import dedent -from typing import Dict, Iterable, TypeVar, Union +from typing import Dict, List, Set, Tuple, Union import yaml @@ -28,13 +28,8 @@ def call(step_func): logger: logging.Logger = logging.getLogger(__name__) -T = TypeVar("T", bound=Union[Dict, Iterable, str]) - - -def prepare_yaml(val: T) -> T: +def prepare_yaml(val: Union[Dict, Set, List, Tuple, str]) -> Union[Dict, List, str]: """Sanitises the formatted string into a format safe for dumping""" - formatted = val - if isinstance(val, dict): formatted = {} # formatted = {key: format_keys(val[key], box_vars) for key in val} @@ -43,11 +38,11 @@ def prepare_yaml(val: T) -> T: key = str(key) formatted[key] = prepare_yaml(val[key]) elif isinstance(val, (list, tuple, set)): - formatted = [prepare_yaml(item) for item in val] - elif isinstance(formatted, FormattedString): - return str(formatted) + return [prepare_yaml(item) for item in val] + elif isinstance(val, FormattedString): + return str(val) - return formatted + return val def attach_stage_content(stage) -> None: diff --git a/tavern/_core/testhelpers.py b/tavern/_core/testhelpers.py index 062ae4e5..c90c467d 100644 --- a/tavern/_core/testhelpers.py +++ b/tavern/_core/testhelpers.py @@ -38,7 +38,7 @@ def retry(stage: Mapping, test_block_config: TestConfig) -> Callable: if "max_retries" in stage: max_retries = maybe_format_max_retries( - stage.get("max_retries"), test_block_config + int(stage.get("max_retries")), test_block_config ) else: max_retries = 0 diff --git a/tavern/_plugins/grpc/request.py b/tavern/_plugins/grpc/request.py index 727e565a..706f0bb5 100644 --- a/tavern/_plugins/grpc/request.py +++ b/tavern/_plugins/grpc/request.py @@ -3,7 +3,7 @@ import json import logging import warnings -from typing import Dict, Mapping, Union +from typing import Dict, Union import grpc from box import Box @@ -50,7 +50,7 @@ class GRPCRequest(BaseRequest): _warned = False def __init__( - self, client: GRPCClient, request_spec: Mapping, test_block_config: TestConfig + self, client: GRPCClient, request_spec: Dict, test_block_config: TestConfig ) -> None: if not self._warned: warnings.warn( From 95683f3369858e7fc47970ba918ce4e62320b66a Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sun, 21 Jan 2024 17:50:55 +0000 Subject: [PATCH 04/20] fix some issues --- tavern/_core/loader.py | 4 ++-- tavern/_core/pytest/util.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tavern/_core/loader.py b/tavern/_core/loader.py index 2a0e53b5..fdb23eb6 100644 --- a/tavern/_core/loader.py +++ b/tavern/_core/loader.py @@ -6,7 +6,7 @@ import uuid from abc import abstractmethod from itertools import chain -from typing import List, Optional +from typing import List, Optional, Union import pytest import yaml @@ -432,7 +432,7 @@ def to_yaml(cls, dumper, data) -> ScalarNode: yaml.dumper.Dumper.add_representer(ApproxScalar, ApproxSentinel.to_yaml) -def load_single_document_yaml(filename: os.PathLike) -> dict: +def load_single_document_yaml(filename: Union[str, os.PathLike]) -> dict: """ Load a yaml file and expect only one document diff --git a/tavern/_core/pytest/util.py b/tavern/_core/pytest/util.py index b0595ae9..6fe6e2e0 100644 --- a/tavern/_core/pytest/util.py +++ b/tavern/_core/pytest/util.py @@ -1,7 +1,7 @@ import logging from functools import lru_cache from pathlib import Path -from typing import Any, Dict, Optional, TypeVar, Union +from typing import Any, Dict, List, Optional, TypeVar, Union import pytest @@ -200,7 +200,7 @@ def _load_global_follow_redirects(pytest_config: pytest.Config): return get_option_generic(pytest_config, "tavern-always-follow-redirects", False) -T = TypeVar("T", bound=Optional[Union[str, list, list[Path], list[str], bool]]) +T = TypeVar("T", bound=Optional[Union[str, List, List[Path], List[str], bool]]) def get_option_generic( From 87d8ce1fc66bcd811f6f4947f19714896f7e065d Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sun, 21 Jan 2024 17:54:38 +0000 Subject: [PATCH 05/20] annotation --- tavern/_plugins/mqtt/response.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tavern/_plugins/mqtt/response.py b/tavern/_plugins/mqtt/response.py index fe6e94f0..184e217a 100644 --- a/tavern/_plugins/mqtt/response.py +++ b/tavern/_plugins/mqtt/response.py @@ -110,7 +110,7 @@ def _await_response(self) -> Mapping: failures=self.errors, ) - saved = {} + saved: Dict = {} for msg in correct_messages: # Check saving things from the payload and from json From 2d1e49c7d03adff6f10a93558976b3191234562b Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sun, 21 Jan 2024 19:19:05 +0000 Subject: [PATCH 06/20] More cleanup --- scripts/smoke.bash | 4 ++-- tavern/_core/dict_util.py | 15 +++++---------- tavern/_core/extfunctions.py | 2 +- tavern/_core/pytest/file.py | 10 +++++----- tavern/_core/pytest/util.py | 4 ++-- tavern/_core/report.py | 2 +- tavern/_core/stage_lines.py | 17 +++++++++++------ tavern/_core/testhelpers.py | 8 +++++--- tavern/_plugins/rest/request.py | 4 ++-- tavern/response.py | 2 +- 10 files changed, 35 insertions(+), 33 deletions(-) diff --git a/scripts/smoke.bash b/scripts/smoke.bash index aa9762f6..bcbd5b8e 100755 --- a/scripts/smoke.bash +++ b/scripts/smoke.bash @@ -6,10 +6,10 @@ pre-commit run ruff --all-files || true pre-commit run ruff-format --all-files || true tox --parallel -c tox.ini \ - -e py3check + -e py3mypy tox --parallel -c tox.ini \ - -e py3mypy + -e py3check tox --parallel -c tox.ini \ -e py3 diff --git a/tavern/_core/dict_util.py b/tavern/_core/dict_util.py index 46984a7a..b17cffa4 100644 --- a/tavern/_core/dict_util.py +++ b/tavern/_core/dict_util.py @@ -5,7 +5,7 @@ import re import string import typing -from typing import Any, Dict, List, Mapping, Union +from typing import Any, Dict, List, Mapping, Tuple, Union import box import jmespath @@ -93,7 +93,7 @@ def _attempt_find_include(to_format: str, box_vars: box.Box): return formatter.convert_field(would_replace, conversion) # type: ignore -T = typing.TypeVar("T") +T = typing.TypeVar("T", str, Dict, List, Tuple) def format_keys( @@ -106,7 +106,7 @@ def format_keys( """recursively format a dictionary with the given values Args: - val: Input dictionary to format + val: Input thing to format variables: Dictionary of keys to format it with no_double_format: Whether to use the 'inner formatted string' class to avoid double formatting This is required if passing something via pytest-xdist, such as markers: @@ -130,14 +130,9 @@ def format_keys( box_vars = variables if isinstance(val, dict): - formatted = {} - # formatted = {key: format_keys(val[key], box_vars) for key in val} - for key in val: - formatted[key] = format_keys_(val[key], box_vars) - - return formatted + return {key: format_keys_(val[key], box_vars) for key in formatted} elif isinstance(val, (list, tuple)): - return [format_keys_(item, box_vars) for item in val] # type: ignore + return [format_keys_(item, box_vars) for item in val] elif isinstance(formatted, FormattedString): logger.debug("Already formatted %s, not double-formatting", formatted) elif isinstance(val, str): diff --git a/tavern/_core/extfunctions.py b/tavern/_core/extfunctions.py index 3018931a..c81404f9 100644 --- a/tavern/_core/extfunctions.py +++ b/tavern/_core/extfunctions.py @@ -93,7 +93,7 @@ def get_wrapped_response_function(ext: Mapping) -> Callable: extra_kwargs to pass Returns: - function: Wrapped function + Wrapped function """ func, args, kwargs = _get_ext_values(ext) diff --git a/tavern/_core/pytest/file.py b/tavern/_core/pytest/file.py index 45fa426b..9ec5c23a 100644 --- a/tavern/_core/pytest/file.py +++ b/tavern/_core/pytest/file.py @@ -60,8 +60,8 @@ def _format_test_marks( """ - pytest_marks = [] - formatted_marks = [] + pytest_marks: List[Mark] = [] + formatted_marks: List[Mapping] = [] for m in original_marks: if isinstance(m, str): @@ -107,7 +107,7 @@ def _generate_parametrized_test_items( tuple of the variables for the stage and the generated stage name """ flattened_values = [] - variables = {} + variables: Dict[str, Any] = {} # combination of keys and the values they correspond to for pair in zip(keys, vals_combination): @@ -116,7 +116,7 @@ def _generate_parametrized_test_items( # very weird looking if isinstance(key, str): variables[key] = value - flattened_values += [value] + flattened_values.append(value) else: if not isinstance(value, (list, tuple)): value = [value] @@ -130,7 +130,7 @@ def _generate_parametrized_test_items( for subkey, subvalue in zip(key, value): variables[subkey] = subvalue - flattened_values += [subvalue] + flattened_values.append(subvalue) def maybe_load_ext(v): key, value = v diff --git a/tavern/_core/pytest/util.py b/tavern/_core/pytest/util.py index 6fe6e2e0..d8f5cc3a 100644 --- a/tavern/_core/pytest/util.py +++ b/tavern/_core/pytest/util.py @@ -152,11 +152,11 @@ def _load_global_cfg(pytest_config: pytest.Config) -> TestConfig: all_paths = ini_global_cfg_paths + cmdline_global_cfg_paths global_cfg_dict = load_global_config(all_paths) + variables: Dict = {} try: loaded_variables = global_cfg_dict["variables"] except KeyError: logger.debug("Nothing to format in global config files") - variables = {} else: tavern_box = get_tavern_box() variables = format_keys(loaded_variables, tavern_box) @@ -195,7 +195,7 @@ def _load_global_strictness(pytest_config: pytest.Config) -> StrictLevel: return StrictLevel.from_options(options) -def _load_global_follow_redirects(pytest_config: pytest.Config): +def _load_global_follow_redirects(pytest_config: pytest.Config) -> bool: """Load the global 'follow redirects' setting""" return get_option_generic(pytest_config, "tavern-always-follow-redirects", False) diff --git a/tavern/_core/report.py b/tavern/_core/report.py index 17991eb6..e7d9f244 100644 --- a/tavern/_core/report.py +++ b/tavern/_core/report.py @@ -45,7 +45,7 @@ def prepare_yaml(val: Union[Dict, Set, List, Tuple, str]) -> Union[Dict, List, s return val -def attach_stage_content(stage) -> None: +def attach_stage_content(stage: Dict) -> None: first_line, last_line, _ = get_stage_lines(stage) code_lines = list(read_relevant_lines(stage, first_line, last_line)) diff --git a/tavern/_core/stage_lines.py b/tavern/_core/stage_lines.py index 31359438..6a27cd57 100644 --- a/tavern/_core/stage_lines.py +++ b/tavern/_core/stage_lines.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Protocol +from typing import Dict, Optional, Protocol, Type, Union logger: logging.Logger = logging.getLogger(__name__) @@ -10,11 +10,16 @@ class EmptyBlock: class _WithMarks(Protocol): + """Things loaded by pyyaml have these""" + start_mark: EmptyBlock end_mark: EmptyBlock -def get_stage_lines(stage: _WithMarks): +PyYamlDict = Union[Dict, _WithMarks] + + +def get_stage_lines(stage: PyYamlDict): first_line = start_mark(stage).line - 1 last_line = end_mark(stage).line line_start = first_line + 1 @@ -22,7 +27,7 @@ def get_stage_lines(stage: _WithMarks): return first_line, last_line, line_start -def read_relevant_lines(yaml_block: _WithMarks, first_line: int, last_line: int): +def read_relevant_lines(yaml_block: PyYamlDict, first_line: int, last_line: int): """Get lines between start and end mark""" filename = get_stage_filename(yaml_block) @@ -37,18 +42,18 @@ def read_relevant_lines(yaml_block: _WithMarks, first_line: int, last_line: int) yield line.split("#", 1)[0].rstrip() -def get_stage_filename(yaml_block: _WithMarks) -> str: +def get_stage_filename(yaml_block: PyYamlDict) -> Optional[str]: return start_mark(yaml_block).name -def start_mark(yaml_block: _WithMarks): +def start_mark(yaml_block: PyYamlDict) -> Union[Type[EmptyBlock], EmptyBlock]: try: return yaml_block.start_mark except AttributeError: return EmptyBlock -def end_mark(yaml_block: _WithMarks): +def end_mark(yaml_block: PyYamlDict) -> Union[Type[EmptyBlock], EmptyBlock]: try: return yaml_block.end_mark except AttributeError: diff --git a/tavern/_core/testhelpers.py b/tavern/_core/testhelpers.py index c90c467d..a30a4d5f 100644 --- a/tavern/_core/testhelpers.py +++ b/tavern/_core/testhelpers.py @@ -1,7 +1,7 @@ import logging import time from functools import wraps -from typing import Callable, Mapping +from typing import Callable, Mapping, Union from tavern._core import exceptions from tavern._core.dict_util import format_keys @@ -101,11 +101,13 @@ def wrapped(*args, **kwargs): return retry_wrapper -def maybe_format_max_retries(max_retries: int, test_block_config: TestConfig) -> int: +def maybe_format_max_retries( + max_retries: Union[str, int], test_block_config: TestConfig +) -> int: """Possibly handle max_retries validation""" # Probably a format variable, or just invalid (in which case it will fail further down) - max_retries = format_keys(max_retries, test_block_config.variables) + max_retries = format_keys(str(max_retries), test_block_config.variables) # Missing type token will mean that max_retries is still a string and will fail here # Could auto convert here as well, but keep it consistent and just fail diff --git a/tavern/_plugins/rest/request.py b/tavern/_plugins/rest/request.py index cc127b0e..a3282220 100644 --- a/tavern/_plugins/rest/request.py +++ b/tavern/_plugins/rest/request.py @@ -4,7 +4,7 @@ import warnings from contextlib import ExitStack from itertools import filterfalse, tee -from typing import ClassVar, List, Mapping, MutableMapping, Optional +from typing import ClassVar, Dict, List, Mapping, Optional from urllib.parse import quote_plus import requests @@ -24,7 +24,7 @@ logger: logging.Logger = logging.getLogger(__name__) -def get_request_args(rspec: MutableMapping, test_block_config: TestConfig) -> dict: +def get_request_args(rspec: Dict, test_block_config: TestConfig) -> dict: """Format the test spec given values inthe global config Todo: diff --git a/tavern/response.py b/tavern/response.py index 4653ee45..ad14d246 100644 --- a/tavern/response.py +++ b/tavern/response.py @@ -24,8 +24,8 @@ def indent_err_text(err: str) -> str: @dataclasses.dataclass(repr=True) class BaseResponse: name: str - test_block_config: TestConfig expected: Any + test_block_config: TestConfig validate_functions: List[Any] = dataclasses.field(init=False, default_factory=list) response: Optional[Any] = None From c4ab4d75c8c199a7daa566696336990899413e18 Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sun, 21 Jan 2024 19:31:54 +0000 Subject: [PATCH 07/20] More anotations --- tavern/_core/dict_util.py | 4 +- tavern/_core/loader.py | 8 ++-- tavern/_core/pytest/error.py | 8 ++-- tavern/_core/pytest/file.py | 72 +++++++++++++++--------------- tavern/_core/pytest/util.py | 14 +++--- tavern/_core/stage_lines.py | 6 +-- tavern/_core/testhelpers.py | 6 +-- tavern/_plugins/grpc/response.py | 4 +- tavern/_plugins/rest/response.py | 2 +- tavern/_plugins/rest/tavernhook.py | 4 +- tavern/response.py | 2 +- 11 files changed, 64 insertions(+), 66 deletions(-) diff --git a/tavern/_core/dict_util.py b/tavern/_core/dict_util.py index b17cffa4..ece4c3e1 100644 --- a/tavern/_core/dict_util.py +++ b/tavern/_core/dict_util.py @@ -131,7 +131,9 @@ def format_keys( if isinstance(val, dict): return {key: format_keys_(val[key], box_vars) for key in formatted} - elif isinstance(val, (list, tuple)): + elif isinstance(val, (tuple)): + return tuple(format_keys_(item, box_vars) for item in val) + elif isinstance(val, (list)): return [format_keys_(item, box_vars) for item in val] elif isinstance(formatted, FormattedString): logger.debug("Already formatted %s, not double-formatting", formatted) diff --git a/tavern/_core/loader.py b/tavern/_core/loader.py index fdb23eb6..0a583ab9 100644 --- a/tavern/_core/loader.py +++ b/tavern/_core/loader.py @@ -35,13 +35,13 @@ class RememberComposer(Composer): def compose_document(self) -> Optional[Node]: # Drop the DOCUMENT-START event. - self.get_event() + self.get_event() # type:ignore # Compose the root node. - node = self.compose_node(None, None) + node = self.compose_node(None, None) # type:ignore # Drop the DOCUMENT-END event. - self.get_event() + self.get_event() # type:ignore # If we don't drop the anchors here, then we can keep anchors across # documents. @@ -359,7 +359,7 @@ class FloatToken(TypeConvertToken): class StrToBoolConstructor: """Using `bool` as a constructor directly will evaluate all strings to `True`.""" - def __new__(cls, s: str) -> bool: + def __new__(cls, s: str) -> bool: # type:ignore return strtobool(s) diff --git a/tavern/_core/pytest/error.py b/tavern/_core/pytest/error.py index e7b43188..1275f610 100644 --- a/tavern/_core/pytest/error.py +++ b/tavern/_core/pytest/error.py @@ -2,7 +2,7 @@ import logging import re from io import StringIO -from typing import List, Mapping, Optional +from typing import Dict, List, Optional import yaml from _pytest._code.code import FormattedExcinfo, TerminalRepr @@ -137,7 +137,7 @@ def _print_test_stage( else: tw.line(line, white=True) - def _print_formatted_stage(self, tw: TerminalWriter, stage: Mapping) -> None: + def _print_formatted_stage(self, tw: TerminalWriter, stage: Dict) -> None: """Print the 'formatted' stage that Tavern will actually use to send the request/process the response @@ -155,10 +155,10 @@ def _print_formatted_stage(self, tw: TerminalWriter, stage: Mapping) -> None: ) # Replace formatted strings with strings for dumping - formatted_stage = prepare_yaml(formatted_stage) + prepared_stage = prepare_yaml(formatted_stage) # Dump formatted stage to YAML format - formatted_lines = yaml.dump(formatted_stage, default_flow_style=False).split( + formatted_lines = yaml.dump(prepared_stage, default_flow_style=False).split( "\n" ) diff --git a/tavern/_core/pytest/file.py b/tavern/_core/pytest/file.py index 9ec5c23a..17b8a05d 100644 --- a/tavern/_core/pytest/file.py +++ b/tavern/_core/pytest/file.py @@ -94,6 +94,40 @@ def _format_test_marks( return pytest_marks, formatted_marks +def _maybe_load_ext(pair): + """Try to load ext values""" + key, value = pair + + if is_ext_function(value): + # If it is an ext function, load the new (or supplemental) value[s] + ext = value.pop("$ext") + f = get_wrapped_create_function(ext) + new_value = f() + + if len(value) == 0: + # Use only this new value + return key, new_value + elif isinstance(new_value, dict): + # Merge with some existing data. At this point 'value' is known to be a dict. + return key, deep_dict_merge(value, f()) + else: + # For example, if it's defined like + # + # - testkey: testval + # $ext: + # function: mod:func + # + # and 'mod:func' returns a string, it's impossible to 'merge' with the existing data. + logger.error("Values still in 'val': %s", value) + raise exceptions.BadSchemaError( + "There were extra key/value pairs in the 'val' for this parametrize mark, but the ext function {} returned '{}' (of type {}) that was not a dictionary. It is impossible to merge these values.".format( + ext, new_value, type(new_value) + ) + ) + + return key, value + + def _generate_parametrized_test_items( keys: Iterable[Union[str, List, Tuple]], vals_combination: Iterable[Tuple[str, str]] ) -> Tuple[Mapping[str, Any], str]: @@ -106,7 +140,7 @@ def _generate_parametrized_test_items( Returns: tuple of the variables for the stage and the generated stage name """ - flattened_values = [] + flattened_values: List[Iterable[str]] = [] variables: Dict[str, Any] = {} # combination of keys and the values they correspond to @@ -124,7 +158,7 @@ def _generate_parametrized_test_items( if len(value) != len(key): raise exceptions.BadSchemaError( "Invalid match between numbers of keys and number of values in parametrize mark ({} keys, {} values)".format( - (key), (value) + key, value ) ) @@ -132,39 +166,7 @@ def _generate_parametrized_test_items( variables[subkey] = subvalue flattened_values.append(subvalue) - def maybe_load_ext(v): - key, value = v - - if is_ext_function(value): - # If it is an ext function, load the new (or supplemental) value[s] - ext = value.pop("$ext") - f = get_wrapped_create_function(ext) - new_value = f() - - if len(value) == 0: - # Use only this new value - return key, new_value - elif isinstance(new_value, dict): - # Merge with some existing data. At this point 'value' is known to be a dict. - return key, deep_dict_merge(value, f()) - else: - # For example, if it's defined like - # - # - testkey: testval - # $ext: - # function: mod:func - # - # and 'mod:func' returns a string, it's impossible to 'merge' with the existing data. - logger.error("Values still in 'val': %s", value) - raise exceptions.BadSchemaError( - "There were extra key/value pairs in the 'val' for this parametrize mark, but the ext function {} returned '{}' (of type {}) that was not a dictionary. It is impossible to merge these values.".format( - ext, new_value, type(new_value) - ) - ) - - return key, value - - variables = dict(map(maybe_load_ext, variables.items())) + variables = dict(map(_maybe_load_ext, variables.items())) logger.debug("Variables for this combination: %s", variables) logger.debug("Values for this combination: %s", flattened_values) diff --git a/tavern/_core/pytest/util.py b/tavern/_core/pytest/util.py index d8f5cc3a..03416d2c 100644 --- a/tavern/_core/pytest/util.py +++ b/tavern/_core/pytest/util.py @@ -177,20 +177,16 @@ def _load_global_cfg(pytest_config: pytest.Config) -> TestConfig: def _load_global_backends(pytest_config: pytest.Config) -> Dict[str, Any]: """Load which backend should be used""" - backend_settings = {} - - for b in TestConfig.backends(): - backend_settings[b] = get_option_generic( - pytest_config, f"tavern-{b}-backend", None - ) - - return backend_settings + return { + b: get_option_generic(pytest_config, f"tavern-{b}-backend", None) + for b in TestConfig.backends() + } def _load_global_strictness(pytest_config: pytest.Config) -> StrictLevel: """Load the global 'strictness' setting""" - options = get_option_generic(pytest_config, "tavern-strict", []) + options: List = get_option_generic(pytest_config, "tavern-strict", []) return StrictLevel.from_options(options) diff --git a/tavern/_core/stage_lines.py b/tavern/_core/stage_lines.py index 6a27cd57..d99c5972 100644 --- a/tavern/_core/stage_lines.py +++ b/tavern/_core/stage_lines.py @@ -16,7 +16,7 @@ class _WithMarks(Protocol): end_mark: EmptyBlock -PyYamlDict = Union[Dict, _WithMarks] +PyYamlDict = Union[_WithMarks, Dict] def get_stage_lines(stage: PyYamlDict): @@ -48,13 +48,13 @@ def get_stage_filename(yaml_block: PyYamlDict) -> Optional[str]: def start_mark(yaml_block: PyYamlDict) -> Union[Type[EmptyBlock], EmptyBlock]: try: - return yaml_block.start_mark + return yaml_block.start_mark # type:ignore except AttributeError: return EmptyBlock def end_mark(yaml_block: PyYamlDict) -> Union[Type[EmptyBlock], EmptyBlock]: try: - return yaml_block.end_mark + return yaml_block.end_mark # type:ignore except AttributeError: return EmptyBlock diff --git a/tavern/_core/testhelpers.py b/tavern/_core/testhelpers.py index a30a4d5f..cdaf605c 100644 --- a/tavern/_core/testhelpers.py +++ b/tavern/_core/testhelpers.py @@ -36,10 +36,8 @@ def retry(stage: Mapping, test_block_config: TestConfig) -> Callable: test_block_config: Configuration for current test """ - if "max_retries" in stage: - max_retries = maybe_format_max_retries( - int(stage.get("max_retries")), test_block_config - ) + if r := stage.get("max_retries", None): + max_retries = maybe_format_max_retries(int(r), test_block_config) else: max_retries = 0 diff --git a/tavern/_plugins/grpc/response.py b/tavern/_plugins/grpc/response.py index 436c3bd2..b9c8849d 100644 --- a/tavern/_plugins/grpc/response.py +++ b/tavern/_plugins/grpc/response.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, Any, List, Mapping, TypedDict, Union +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, TypedDict, Union import proto.message from google.protobuf import json_format @@ -92,7 +92,7 @@ def verify(self, response: "WrappedFuture") -> Mapping: logger.debug(f"grpc details: {grpc_response.details()}") # Get any keys to save - saved = {} + saved: Dict[str, Any] = {} verify_status = [StatusCode.OK.name] if status := self.expected.get("status", None): verify_status = _to_grpc_name(status) # type: ignore diff --git a/tavern/_plugins/rest/response.py b/tavern/_plugins/rest/response.py index 0466e645..9f354788 100644 --- a/tavern/_plugins/rest/response.py +++ b/tavern/_plugins/rest/response.py @@ -177,7 +177,7 @@ def verify(self, response: requests.Response) -> dict: self._maybe_run_validate_functions(response) # Get any keys to save - saved = {} + saved: Dict = {} saved.update(self.maybe_get_save_values_from_save_block("json", body)) saved.update( diff --git a/tavern/_plugins/rest/tavernhook.py b/tavern/_plugins/rest/tavernhook.py index 3351559a..4245bf46 100644 --- a/tavern/_plugins/rest/tavernhook.py +++ b/tavern/_plugins/rest/tavernhook.py @@ -1,5 +1,5 @@ import logging -from typing import Mapping +from typing import Dict import requests @@ -22,7 +22,7 @@ class TavernRestPlugin(PluginHelperBase): @staticmethod def get_expected_from_request( - response_block: Mapping, test_block_config: TestConfig, session + response_block: Dict, test_block_config: TestConfig, session ): if response_block is None: raise exceptions.MissingSettingsError( diff --git a/tavern/response.py b/tavern/response.py index ad14d246..563f468c 100644 --- a/tavern/response.py +++ b/tavern/response.py @@ -200,7 +200,7 @@ def maybe_get_save_values_from_ext( except Exception as e: self._adderr( "Error calling save function '%s':\n%s", - wrapped.func, + wrapped.func, # type:ignore indent_err_text(traceback.format_exc()), e=e, ) From 310dbb15cb4e70f72edf5059cf8af4bd7c188ddc Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sun, 21 Jan 2024 19:38:23 +0000 Subject: [PATCH 08/20] More anotations --- tavern/_core/report.py | 12 ++++++----- tavern/_core/testhelpers.py | 2 +- tavern/_plugins/mqtt/client.py | 2 +- tavern/_plugins/rest/tavernhook.py | 2 +- tavern/core.py | 32 +++++++++++++----------------- tests/unit/test_mqtt.py | 2 +- 6 files changed, 25 insertions(+), 27 deletions(-) diff --git a/tavern/_core/report.py b/tavern/_core/report.py index e7d9f244..2bfba54b 100644 --- a/tavern/_core/report.py +++ b/tavern/_core/report.py @@ -31,12 +31,14 @@ def call(step_func): def prepare_yaml(val: Union[Dict, Set, List, Tuple, str]) -> Union[Dict, List, str]: """Sanitises the formatted string into a format safe for dumping""" if isinstance(val, dict): - formatted = {} + prepared = {} # formatted = {key: format_keys(val[key], box_vars) for key in val} for key in val: if isinstance(key, FormattedString): key = str(key) - formatted[key] = prepare_yaml(val[key]) + prepared[key] = prepare_yaml(val[key]) + + return prepared elif isinstance(val, (list, tuple, set)): return [prepare_yaml(item) for item in val] elif isinstance(val, FormattedString): @@ -53,15 +55,15 @@ def attach_stage_content(stage: Dict) -> None: attach_text(joined, "stage_yaml", yaml_type) -def attach_yaml(payload, name) -> None: +def attach_yaml(payload, name: str) -> None: prepared = prepare_yaml(payload) dumped = yaml.safe_dump(prepared) return attach_text(dumped, name, yaml_type) -def attach_text(payload, name, attachment_type=None) -> None: +def attach_text(payload, name: str, attachment_type=None) -> None: return attach(payload, name=name, attachment_type=attachment_type) -def wrap_step(allure_name, partial): +def wrap_step(allure_name: str, partial): return step(allure_name)(partial) diff --git a/tavern/_core/testhelpers.py b/tavern/_core/testhelpers.py index cdaf605c..ddbd7294 100644 --- a/tavern/_core/testhelpers.py +++ b/tavern/_core/testhelpers.py @@ -105,7 +105,7 @@ def maybe_format_max_retries( """Possibly handle max_retries validation""" # Probably a format variable, or just invalid (in which case it will fail further down) - max_retries = format_keys(str(max_retries), test_block_config.variables) + max_retries = int(format_keys(str(max_retries), test_block_config.variables)) # Missing type token will mean that max_retries is still a string and will fail here # Could auto convert here as well, but keep it consistent and just fail diff --git a/tavern/_plugins/mqtt/client.py b/tavern/_plugins/mqtt/client.py index 23583515..4c30ea4e 100644 --- a/tavern/_plugins/mqtt/client.py +++ b/tavern/_plugins/mqtt/client.py @@ -385,7 +385,7 @@ def publish( self, topic: str, payload: Any, - qos: Optional[int], + qos: Optional[int] = None, retain: Optional[bool] = False, ) -> MQTTMessageInfo: """publish message using paho library""" diff --git a/tavern/_plugins/rest/tavernhook.py b/tavern/_plugins/rest/tavernhook.py index 4245bf46..3a08bf6f 100644 --- a/tavern/_plugins/rest/tavernhook.py +++ b/tavern/_plugins/rest/tavernhook.py @@ -6,8 +6,8 @@ from tavern._core import exceptions from tavern._core.dict_util import format_keys from tavern._core.plugins import PluginHelperBase +from tavern._core.pytest.config import TestConfig -from ..._core.pytest.config import TestConfig from .request import RestRequest from .response import RestResponse diff --git a/tavern/core.py b/tavern/core.py index dcf28251..addc3456 100644 --- a/tavern/core.py +++ b/tavern/core.py @@ -17,8 +17,7 @@ def _get_or_wrap_global_cfg( Args: stack: context stack for wrapping file if a dictionary is given - tavern_global_cfg: Dictionary or string. It should be a - path to a file or a dictionary with configuration. + tavern_global_cfg: path to a file or a dictionary with configuration. Returns: path to global config file @@ -26,9 +25,6 @@ def _get_or_wrap_global_cfg( Raises: InvalidSettingsError: If global config was not of the right type or a given path does not exist - - Todo: - Once python 2 is dropped, allow this to take a 'path like object' """ if isinstance(tavern_global_cfg, str): @@ -49,29 +45,29 @@ def _get_or_wrap_global_cfg( return global_filename -def run( +def run( # type:ignore in_file: str, - tavern_global_cfg=None, - tavern_mqtt_backend=None, - tavern_http_backend=None, - tavern_grpc_backend=None, - tavern_strict=None, - pytest_args=None, + tavern_global_cfg: Union[dict, str, None] = None, + tavern_mqtt_backend: Union[str, None] = None, + tavern_http_backend: Union[str, None] = None, + tavern_grpc_backend: Union[str, None] = None, + tavern_strict: Union[bool, None] = None, + pytest_args: Union[list, None] = None, ) -> Union[ExitCode, int]: """Run all tests contained in a file using pytest.main() Args: in_file: file to run tests on - tavern_global_cfg (str, dict): Extra global config - tavern_mqtt_backend (str, optional): name of MQTT plugin to use. If not + tavern_global_cfg: Extra global config + tavern_mqtt_backend: name of MQTT plugin to use. If not specified, uses tavern-mqtt - tavern_http_backend (str, optional): name of HTTP plugin to use. If not + tavern_http_backend: name of HTTP plugin to use. If not specified, use tavern-http - tavern_grpc_backend (str, optional): name of GRPC plugin to use. If not + tavern_grpc_backend: name of GRPC plugin to use. If not specified, use tavern-grpc - tavern_strict (bool, optional): Strictness of checking for responses. + tavern_strict: Strictness of checking for responses. See documentation for details - pytest_args (list, optional): List of extra arguments to pass directly + pytest_args: List of extra arguments to pass directly to Pytest as if they were command line arguments Returns: diff --git a/tests/unit/test_mqtt.py b/tests/unit/test_mqtt.py index 1d55b06f..9e8a7a04 100644 --- a/tests/unit/test_mqtt.py +++ b/tests/unit/test_mqtt.py @@ -71,7 +71,7 @@ def test_context_connection_success(self, fake_client): with fake_client as x: assert fake_client == x - def test_assert_message_published(self, fake_client): + def test_assert_message_published(self, fake_client: MQTTClient): """If it couldn't immediately publish the message, error out""" class FakeMessage: From d23f6e69eaecbf10948c3347f103c28b88769aea Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sun, 21 Jan 2024 19:52:22 +0000 Subject: [PATCH 09/20] fix errors form missing args --- tavern/_core/dict_util.py | 23 ++++++++++++----------- tavern/_core/testhelpers.py | 2 +- tavern/_plugins/mqtt/client.py | 4 ++-- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/tavern/_core/dict_util.py b/tavern/_core/dict_util.py index ece4c3e1..9a7e546c 100644 --- a/tavern/_core/dict_util.py +++ b/tavern/_core/dict_util.py @@ -117,8 +117,6 @@ def format_keys( Returns: recursively formatted values """ - formatted = val - format_keys_ = functools.partial( format_keys, dangerously_ignore_string_format_errors=dangerously_ignore_string_format_errors, @@ -130,14 +128,15 @@ def format_keys( box_vars = variables if isinstance(val, dict): - return {key: format_keys_(val[key], box_vars) for key in formatted} - elif isinstance(val, (tuple)): + return {key: format_keys_(val[key], box_vars) for key in val} + elif isinstance(val, tuple): return tuple(format_keys_(item, box_vars) for item in val) - elif isinstance(val, (list)): + elif isinstance(val, list): return [format_keys_(item, box_vars) for item in val] - elif isinstance(formatted, FormattedString): - logger.debug("Already formatted %s, not double-formatting", formatted) + elif isinstance(val, FormattedString): + logger.debug("Already formatted %s, not double-formatting", val) elif isinstance(val, str): + formatted = val try: formatted = _check_and_format_values(val, box_vars) except exceptions.MissingFormatError: @@ -146,17 +145,19 @@ def format_keys( if no_double_format: formatted = FormattedString(formatted) # type: ignore + + return formatted elif isinstance(val, TypeConvertToken): logger.debug("Got type convert token '%s'", val) if isinstance(val, ForceIncludeToken): - formatted = _attempt_find_include(val.value, box_vars) + return _attempt_find_include(val.value, box_vars) else: value = format_keys_(val.value, box_vars) - formatted = val.constructor(value) + return val.constructor(value) else: - logger.debug("Not formatting something of type '%s'", type(formatted)) + logger.debug("Not formatting something of type '%s'", type(val)) - return formatted + return val def recurse_access_key(data, query: str): diff --git a/tavern/_core/testhelpers.py b/tavern/_core/testhelpers.py index ddbd7294..db132b3d 100644 --- a/tavern/_core/testhelpers.py +++ b/tavern/_core/testhelpers.py @@ -105,7 +105,7 @@ def maybe_format_max_retries( """Possibly handle max_retries validation""" # Probably a format variable, or just invalid (in which case it will fail further down) - max_retries = int(format_keys(str(max_retries), test_block_config.variables)) + max_retries = int(format_keys(max_retries, test_block_config.variables)) # type:ignore # Missing type token will mean that max_retries is still a string and will fail here # Could auto convert here as well, but keep it consistent and just fail diff --git a/tavern/_plugins/mqtt/client.py b/tavern/_plugins/mqtt/client.py index 4c30ea4e..7779c70e 100644 --- a/tavern/_plugins/mqtt/client.py +++ b/tavern/_plugins/mqtt/client.py @@ -384,7 +384,7 @@ def message_received( def publish( self, topic: str, - payload: Any, + payload: Optional[Any] = None, qos: Optional[int] = None, retain: Optional[bool] = False, ) -> MQTTMessageInfo: @@ -400,7 +400,7 @@ def publish( kwargs["retain"] = retain msg = self._client.publish(topic, payload, **kwargs) - if not msg.is_published: + if not msg.is_published(): raise exceptions.MQTTError( "err {:s}: {:s}".format( _err_vals.get(msg.rc, "unknown"), paho.error_string(msg.rc) From 0e2c5386cfefea2fcf378338a141011d5fb989fa Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sun, 21 Jan 2024 20:05:07 +0000 Subject: [PATCH 10/20] Published --- tavern/_core/testhelpers.py | 2 +- tavern/_plugins/mqtt/client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tavern/_core/testhelpers.py b/tavern/_core/testhelpers.py index db132b3d..3c294a4c 100644 --- a/tavern/_core/testhelpers.py +++ b/tavern/_core/testhelpers.py @@ -37,7 +37,7 @@ def retry(stage: Mapping, test_block_config: TestConfig) -> Callable: """ if r := stage.get("max_retries", None): - max_retries = maybe_format_max_retries(int(r), test_block_config) + max_retries = maybe_format_max_retries(r, test_block_config) else: max_retries = 0 diff --git a/tavern/_plugins/mqtt/client.py b/tavern/_plugins/mqtt/client.py index 7779c70e..58362136 100644 --- a/tavern/_plugins/mqtt/client.py +++ b/tavern/_plugins/mqtt/client.py @@ -400,7 +400,7 @@ def publish( kwargs["retain"] = retain msg = self._client.publish(topic, payload, **kwargs) - if not msg.is_published(): + if msg.is_published() or True: # FIXME raise exceptions.MQTTError( "err {:s}: {:s}".format( _err_vals.get(msg.rc, "unknown"), paho.error_string(msg.rc) From effc280d857d60de0644f904067c0775ba86bb94 Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Tue, 23 Jan 2024 08:55:34 +0000 Subject: [PATCH 11/20] More types --- tavern/_core/dict_util.py | 47 ++++++++++++++++++++++--------------- tavern/_core/stage_lines.py | 26 ++++++++++++-------- 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/tavern/_core/dict_util.py b/tavern/_core/dict_util.py index 9a7e546c..5304a85f 100644 --- a/tavern/_core/dict_util.py +++ b/tavern/_core/dict_util.py @@ -5,7 +5,8 @@ import re import string import typing -from typing import Any, Dict, List, Mapping, Tuple, Union +from collections.abc import Collection +from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union import box import jmespath @@ -160,7 +161,7 @@ def format_keys( return val -def recurse_access_key(data, query: str): +def recurse_access_key(data: Union[List, Mapping], query: str) -> Any: """ Search for something in the given data using the given query. @@ -172,11 +173,11 @@ def recurse_access_key(data, query: str): 'c' Args: - data (dict, list): Data to search in - query (str): Query to run + data: Data to search in + query: Query to run Returns: - object: Whatever was found by the search + Whatever was found by the search """ try: @@ -199,7 +200,9 @@ def recurse_access_key(data, query: str): return from_jmespath -def _deprecated_recurse_access_key(current_val, keys): +def _deprecated_recurse_access_key( + current_val: Union[List, Mapping], keys: List +) -> Any: """Given a list of keys and a dictionary, recursively access the dicionary using the keys until we find the key its looking for @@ -213,15 +216,15 @@ def _deprecated_recurse_access_key(current_val, keys): 'c' Args: - current_val (dict): current dictionary we have recursed into - keys (list): list of str/int of subkeys + current_val: current dictionary we have recursed into + keys: list of str/int of subkeys Raises: IndexError: list index not found in data KeyError: dict key not found in data Returns: - str or dict: value of subkey in dict + value of subkey in dict """ logger.debug("Recursively searching for '%s' in '%s'", keys, current_val) @@ -270,12 +273,12 @@ def deep_dict_merge(initial_dct: Dict, merge_dct: Mapping) -> dict: return dct -def check_expected_keys(expected, actual) -> None: +def check_expected_keys(expected: Collection, actual: Collection) -> None: """Check that a set of expected keys is a superset of the actual keys Args: - expected (list, set, dict): keys we expect - actual (list, set, dict): keys we have got from the input + expected: keys we expect + actual: keys we have got from the input Raises: UnexpectedKeysError: If not actual <= expected @@ -293,7 +296,7 @@ def check_expected_keys(expected, actual) -> None: raise exceptions.UnexpectedKeysError(msg) -def yield_keyvals(block): +def yield_keyvals(block: Union[List, Dict]) -> Iterator[Tuple[List, str, str]]: """Return indexes, keys and expected values for matching recursive keys Given a list or dict, return a 3-tuple of the 'split' key (key split on @@ -325,10 +328,10 @@ def yield_keyvals(block): (['2'], '2', 'c') Args: - block (dict, list): input matches + block: input matches Yields: - (list, str, str): key split on dots, key, expected value + iterable of (key split on dots, key, expected value) """ if isinstance(block, dict): for joined_key, expected_val in block.items(): @@ -340,9 +343,12 @@ def yield_keyvals(block): yield [sidx], sidx, val +Checked = typing.TypeVar("Checked", Dict, Collection, str) + + def check_keys_match_recursive( - expected_val: Any, - actual_val: Any, + expected_val: Checked, + actual_val: Checked, keys: List[Union[str, int]], strict: StrictSettingKinds = True, ) -> None: @@ -447,8 +453,8 @@ def _format_err(which): raise exceptions.KeyMismatchError(msg) from e if isinstance(expected_val, dict): - akeys = set(actual_val.keys()) ekeys = set(expected_val.keys()) + akeys = set(actual_val.keys()) # type:ignore if akeys != ekeys: extra_actual_keys = akeys - ekeys @@ -485,7 +491,10 @@ def _format_err(which): for key in to_recurse: try: check_keys_match_recursive( - expected_val[key], actual_val[key], keys + [key], strict + expected_val[key], + actual_val[key], # type:ignore + keys + [key], + strict, ) except KeyError: logger.debug( diff --git a/tavern/_core/stage_lines.py b/tavern/_core/stage_lines.py index d99c5972..18023155 100644 --- a/tavern/_core/stage_lines.py +++ b/tavern/_core/stage_lines.py @@ -1,10 +1,14 @@ +import dataclasses import logging -from typing import Dict, Optional, Protocol, Type, Union +from typing import Dict, Iterable, Optional, Protocol, Tuple, Type, Union logger: logging.Logger = logging.getLogger(__name__) -class EmptyBlock: +@dataclasses.dataclass +class YamlMark: + """A pyyaml mark""" + line: int = 0 name: Optional[str] = None @@ -12,14 +16,14 @@ class EmptyBlock: class _WithMarks(Protocol): """Things loaded by pyyaml have these""" - start_mark: EmptyBlock - end_mark: EmptyBlock + start_mark: YamlMark + end_mark: YamlMark PyYamlDict = Union[_WithMarks, Dict] -def get_stage_lines(stage: PyYamlDict): +def get_stage_lines(stage: PyYamlDict) -> Tuple[int, int, int]: first_line = start_mark(stage).line - 1 last_line = end_mark(stage).line line_start = first_line + 1 @@ -27,7 +31,9 @@ def get_stage_lines(stage: PyYamlDict): return first_line, last_line, line_start -def read_relevant_lines(yaml_block: PyYamlDict, first_line: int, last_line: int): +def read_relevant_lines( + yaml_block: PyYamlDict, first_line: int, last_line: int +) -> Iterable[str]: """Get lines between start and end mark""" filename = get_stage_filename(yaml_block) @@ -46,15 +52,15 @@ def get_stage_filename(yaml_block: PyYamlDict) -> Optional[str]: return start_mark(yaml_block).name -def start_mark(yaml_block: PyYamlDict) -> Union[Type[EmptyBlock], EmptyBlock]: +def start_mark(yaml_block: PyYamlDict) -> Union[Type[YamlMark], YamlMark]: try: return yaml_block.start_mark # type:ignore except AttributeError: - return EmptyBlock + return YamlMark() -def end_mark(yaml_block: PyYamlDict) -> Union[Type[EmptyBlock], EmptyBlock]: +def end_mark(yaml_block: PyYamlDict) -> Union[Type[YamlMark], YamlMark]: try: return yaml_block.end_mark # type:ignore except AttributeError: - return EmptyBlock + return YamlMark() From 8aec025e18b4668df700718727cb977ee9145863 Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Tue, 23 Jan 2024 14:01:20 +0000 Subject: [PATCH 12/20] More typing --- tavern/_core/pytest/error.py | 18 ++++++++++++------ tavern/_core/pytest/file.py | 2 +- tavern/_core/pytest/item.py | 5 +++-- tavern/_core/stage_lines.py | 12 ++++++++++-- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/tavern/_core/pytest/error.py b/tavern/_core/pytest/error.py index 1275f610..7fe619d5 100644 --- a/tavern/_core/pytest/error.py +++ b/tavern/_core/pytest/error.py @@ -1,8 +1,10 @@ +import dataclasses import json import logging import re +import typing from io import StringIO -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import yaml from _pytest._code.code import FormattedExcinfo, TerminalRepr @@ -10,6 +12,10 @@ from tavern._core import exceptions from tavern._core.dict_util import format_keys + +if typing.TYPE_CHECKING: + from tavern._core.pytest.item import YamlItem + from tavern._core.report import prepare_yaml from tavern._core.stage_lines import ( end_mark, @@ -21,19 +27,19 @@ logger: logging.Logger = logging.getLogger(__name__) +@dataclasses.dataclass class ReprdError(TerminalRepr): - def __init__(self, exce, item) -> None: - self.exce = exce - self.item = item + exce: Any + item: "YamlItem" - def _get_available_format_keys(self): + def _get_available_format_keys(self) -> Dict: """Try to get the format variables for the stage If we can't get the variable for this specific stage, just return the global config which will at least have some format variables Returns: - dict: variables for formatting test + variables for formatting test """ try: keys = self.exce._excinfo[1].test_block_config.variables diff --git a/tavern/_core/pytest/file.py b/tavern/_core/pytest/file.py index 17b8a05d..c0f690a1 100644 --- a/tavern/_core/pytest/file.py +++ b/tavern/_core/pytest/file.py @@ -41,7 +41,7 @@ def _format_test_marks( test_name: Name of test (for error logging) Returns: - tuple: first element is normal pytest mark objects, second element is all + first element is normal pytest mark objects, second element is all marks which were formatted (no matter their content) Todo: diff --git a/tavern/_core/pytest/item.py b/tavern/_core/pytest/item.py index 18106c85..6e14984a 100644 --- a/tavern/_core/pytest/item.py +++ b/tavern/_core/pytest/item.py @@ -33,11 +33,14 @@ class YamlItem(pytest.Item): Attributes: path: filename that this test came from spec: The whole dictionary of the test + global_cfg: configuration for test """ # See https://github.com/taverntesting/tavern/issues/825 _patched_yaml = False + global_cfg: TestConfig + def __init__( self, *, name: str, parent, spec: MutableMapping, path: pathlib.Path, **kwargs ) -> None: @@ -48,8 +51,6 @@ def __init__( self.path = path self.spec = spec - self.global_cfg: Optional[TestConfig] = None - if not YamlItem._patched_yaml: yaml.parser.Parser.process_empty_scalar = ( # type:ignore error_on_empty_scalar diff --git a/tavern/_core/stage_lines.py b/tavern/_core/stage_lines.py index 18023155..7c6f9722 100644 --- a/tavern/_core/stage_lines.py +++ b/tavern/_core/stage_lines.py @@ -1,6 +1,14 @@ import dataclasses import logging -from typing import Dict, Iterable, Optional, Protocol, Tuple, Type, Union +from typing import ( + Iterable, + Mapping, + Optional, + Protocol, + Tuple, + Type, + Union, +) logger: logging.Logger = logging.getLogger(__name__) @@ -20,7 +28,7 @@ class _WithMarks(Protocol): end_mark: YamlMark -PyYamlDict = Union[_WithMarks, Dict] +PyYamlDict = Union[_WithMarks, Mapping] def get_stage_lines(stage: PyYamlDict) -> Tuple[int, int, int]: From 025e036814e91e911eac72c705e01bce3155efe4 Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Tue, 23 Jan 2024 14:08:46 +0000 Subject: [PATCH 13/20] fix a couple of issues --- .pre-commit-config.yaml | 2 +- tavern/_core/dict_util.py | 6 ++++++ tavern/_core/loader.py | 2 -- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ee74939a..6079aece 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: hooks: - id: pycln - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.1.11" + rev: "v0.1.14" hooks: - id: ruff-format - id: ruff diff --git a/tavern/_core/dict_util.py b/tavern/_core/dict_util.py index 5304a85f..a5fe3f0d 100644 --- a/tavern/_core/dict_util.py +++ b/tavern/_core/dict_util.py @@ -115,6 +115,9 @@ def format_keys( dangerously_ignore_string_format_errors: whether to ignore any string formatting errors. This will result in broken output, only use for debugging purposes. + Raises: + MissingFormatError: if a format variable was not found in variables + Returns: recursively formatted values """ @@ -176,6 +179,9 @@ def recurse_access_key(data: Union[List, Mapping], query: str) -> Any: data: Data to search in query: Query to run + Raises: + JMESError: if there was an error parsing the query + Returns: Whatever was found by the search """ diff --git a/tavern/_core/loader.py b/tavern/_core/loader.py index 0a583ab9..690a1ee1 100644 --- a/tavern/_core/loader.py +++ b/tavern/_core/loader.py @@ -108,8 +108,6 @@ class IncludeLoader( between documents""" def __init__(self, stream): - """Initialise Loader.""" - try: self._root = os.path.split(stream.name)[0] except AttributeError: From 8d3e812654e3f1862eef0b31bf5f9eba74fc34bd Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Tue, 23 Jan 2024 15:37:54 +0000 Subject: [PATCH 14/20] Dont run mqtt tests separately - takes ages --- tox-integration.ini | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tox-integration.ini b/tox-integration.ini index 9a7b7abc..eab3b6a9 100644 --- a/tox-integration.ini +++ b/tox-integration.ini @@ -54,9 +54,4 @@ commands = components: python -c "from tavern.core import run; exit(run('test_ping.tavern.yaml', pytest_args=[ ]))" components: python -c "from tavern.core import run; exit(run('test_hello.tavern.yaml', pytest_args=[ ]))" - mqtt: tavern-ci --stdout test_mqtt.tavern.yaml --cov tavern - mqtt: python -c "from tavern.core import run; exit(run('test_mqtt.tavern.yaml', pytest_args=['--cov-append']))" - mqtt: tavern-ci --stdout test_mqtt_failures.tavern.yaml - mqtt: python -c "from tavern.core import run; exit(run('test_mqtt_failures.tavern.yaml', pytest_args=[ ]))" - docker compose stop From d85a0ac38d7b8296eae2edf7357d282a695f090a Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Tue, 23 Jan 2024 15:38:12 +0000 Subject: [PATCH 15/20] More annotations --- tavern/_core/pytest/newhooks.py | 4 ++-- tavern/_core/run.py | 2 +- tavern/_plugins/mqtt/client.py | 4 ++-- tavern/_plugins/mqtt/response.py | 19 ++++++++++--------- tavern/_plugins/rest/response.py | 15 +++++++-------- tavern/response.py | 6 +++--- 6 files changed, 25 insertions(+), 25 deletions(-) diff --git a/tavern/_core/pytest/newhooks.py b/tavern/_core/pytest/newhooks.py index 8f994b07..b76955af 100644 --- a/tavern/_core/pytest/newhooks.py +++ b/tavern/_core/pytest/newhooks.py @@ -34,7 +34,7 @@ def pytest_tavern_beta_after_every_test_run(test_dict: Dict, variables: Dict) -> """ -def pytest_tavern_beta_after_every_response(expected: Any, response: Dict) -> None: +def pytest_tavern_beta_after_every_response(expected: Any, response: Any) -> None: """Called after every _response_ - including MQTT/HTTP/etc Note: @@ -42,8 +42,8 @@ def pytest_tavern_beta_after_every_response(expected: Any, response: Dict) -> No - MQTT responses will call this hook multiple times if multiple messages are received Args: - response: Response object. expected: Response block in stage + response: Response object. """ diff --git a/tavern/_core/run.py b/tavern/_core/run.py index 83229bbd..91474d38 100644 --- a/tavern/_core/run.py +++ b/tavern/_core/run.py @@ -67,10 +67,10 @@ def _get_included_stages( for use in this test Args: - available_stages: List of stages which already exist tavern_box: Available parameters for fomatting at this point test_block_config: Current test config dictionary test_spec: Specification for current test + available_stages: List of stages which already exist Returns: Fully resolved stages diff --git a/tavern/_plugins/mqtt/client.py b/tavern/_plugins/mqtt/client.py index f8e6b54a..7ac4a5a9 100644 --- a/tavern/_plugins/mqtt/client.py +++ b/tavern/_plugins/mqtt/client.py @@ -5,7 +5,7 @@ import threading import time from queue import Empty, Full, Queue -from typing import Any, Dict, List, Mapping, MutableMapping, Optional +from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Union import paho.mqtt.client as paho from paho.mqtt.client import MQTTMessageInfo @@ -351,7 +351,7 @@ def _on_socket_close(client, userdata, socket) -> None: logger.debug("MQTT socket closed") def message_received( - self, topic: str, timeout: int = 1 + self, topic: str, timeout: Union[float, int] = 1 ) -> Optional[paho.MQTTMessage]: """Check that a message is in the message queue diff --git a/tavern/_plugins/mqtt/response.py b/tavern/_plugins/mqtt/response.py index 184e217a..305e8496 100644 --- a/tavern/_plugins/mqtt/response.py +++ b/tavern/_plugins/mqtt/response.py @@ -8,7 +8,6 @@ from dataclasses import dataclass from typing import Dict, List, Mapping, Optional, Tuple, Union -import requests from paho.mqtt.client import MQTTMessage from tavern._core import exceptions @@ -17,7 +16,7 @@ from tavern._core.pytest.config import TestConfig from tavern._core.pytest.newhooks import call_hook from tavern._core.report import attach_yaml -from tavern._core.strict_util import StrictSetting +from tavern._core.strict_util import StrictOption from tavern.response import BaseResponse from .client import MQTTClient @@ -28,6 +27,8 @@ class MQTTResponse(BaseResponse): + response: MQTTMessage + def __init__( self, client: MQTTClient, @@ -39,15 +40,15 @@ def __init__( self._client = client - self.received_messages = [] # type: ignore + self.received_messages: List = [] - def __str__(self): + def __str__(self) -> str: if self.response: - return self.response.payload + return self.response.payload.decode("utf-8") else: return "" - def verify(self, response: requests.Response) -> Mapping: + def verify(self, response: MQTTMessage) -> Mapping: """Ensure mqtt message has arrived Args: @@ -236,12 +237,12 @@ def _await_messages_on_topic( class _ReturnedMessage: """An actual message returned from the API and it's matching 'expected' block.""" - expected: dict + expected: Mapping msg: MQTTMessage class _MessageVerifier: - def __init__(self, test_block_config, expected) -> None: + def __init__(self, test_block_config: TestConfig, expected: Mapping) -> None: self.expires = time.time() + expected.get("timeout", _default_timeout) self.expected = expected @@ -250,7 +251,7 @@ def __init__(self, test_block_config, expected) -> None: ) test_strictness = test_block_config.strict - self.block_strictness: StrictSetting = test_strictness.option_for("json") + self.block_strictness: StrictOption = test_strictness.option_for("json") # Any warnings to do with the request # eg, if a message was received but it didn't match, message had payload, etc. diff --git a/tavern/_plugins/rest/response.py b/tavern/_plugins/rest/response.py index 9f354788..598fe6cd 100644 --- a/tavern/_plugins/rest/response.py +++ b/tavern/_plugins/rest/response.py @@ -1,7 +1,7 @@ import contextlib import json import logging -from typing import Dict, Mapping, Optional +from typing import Any, Dict, List, Mapping, Union from urllib.parse import parse_qs, urlparse import requests @@ -18,6 +18,8 @@ class RestResponse(BaseResponse): + response: requests.Response + def __init__( self, session, name: str, expected, test_block_config: TestConfig ) -> None: @@ -25,8 +27,6 @@ def __init__( super().__init__(name, deep_dict_merge(defaults, expected), test_block_config) - self.status_code: Optional[int] = None - def check_code(code: int) -> None: if int(code) not in _codes: logger.warning("Unexpected status code '%s'", code) @@ -47,7 +47,7 @@ def __str__(self) -> str: else: return "" - def _verbose_log_response(self, response) -> None: + def _verbose_log_response(self, response: requests.Response) -> None: """Verbosely log the response object, with query params etc.""" logger.info("Response: '%s'", response) @@ -78,7 +78,7 @@ def log_dict_block(block, name): logger.debug("Redirect location: %s", to_path) log_dict_block(redirect_query_params, "Redirect URL query parameters") - def _get_redirect_query_params(self, response) -> Dict[str, str]: + def _get_redirect_query_params(self, response: requests.Response) -> Dict[str, str]: """If there was a redirect header, get any query parameters from it""" try: @@ -98,7 +98,7 @@ def _get_redirect_query_params(self, response) -> Dict[str, str]: return redirect_query_params - def _check_status_code(self, status_code, body) -> None: + def _check_status_code(self, status_code: Union[int, List[int]], body: Any) -> None: expected_code = self.expected["status_code"] if (isinstance(expected_code, int) and status_code == expected_code) or ( @@ -108,7 +108,7 @@ def _check_status_code(self, status_code, body) -> None: "Status code '%s' matched expected '%s'", status_code, expected_code ) return - elif 400 <= status_code < 500: + elif isinstance(status_code, int) and 400 <= status_code < 500: # special case if there was a bad request. This assumes that the # response would contain some kind of information as to why this # request was rejected. @@ -147,7 +147,6 @@ def verify(self, response: requests.Response) -> dict: ) self.response = response - self.status_code = response.status_code # Get things to use from the response try: diff --git a/tavern/response.py b/tavern/response.py index 563f468c..0b90246a 100644 --- a/tavern/response.py +++ b/tavern/response.py @@ -21,14 +21,14 @@ def indent_err_text(err: str) -> str: return indent(err, " " * 4) -@dataclasses.dataclass(repr=True) +@dataclasses.dataclass class BaseResponse: name: str expected: Any test_block_config: TestConfig + response: Optional[Any] = None validate_functions: List[Any] = dataclasses.field(init=False, default_factory=list) - response: Optional[Any] = None errors: List[str] = dataclasses.field(init=False, default_factory=list) def __post_init__(self) -> None: @@ -66,10 +66,10 @@ def recurse_check_key_match( Optionally use a validation library too Args: - strict: strictness setting for this block expected_block: expected data block: actual data blockname: 'name' of this block (params, mqtt, etc) for error messages + strict: strictness setting for this block """ if expected_block is None: From 1dcab85ac57806494f326cf60bc24435a29d50ed Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Wed, 24 Jan 2024 15:07:01 +0000 Subject: [PATCH 16/20] annotationS 2 --- tavern/_core/dict_util.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tavern/_core/dict_util.py b/tavern/_core/dict_util.py index a5fe3f0d..f65eaeee 100644 --- a/tavern/_core/dict_util.py +++ b/tavern/_core/dict_util.py @@ -6,7 +6,7 @@ import string import typing from collections.abc import Collection -from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union import box import jmespath @@ -27,7 +27,7 @@ logger: logging.Logger = logging.getLogger(__name__) -def _check_and_format_values(to_format, box_vars: Mapping[str, Any]) -> str: +def _check_and_format_values(to_format: str, box_vars: Mapping[str, Any]) -> str: formatter = string.Formatter() would_format = formatter.parse(to_format) @@ -57,7 +57,7 @@ def _check_and_format_values(to_format, box_vars: Mapping[str, Any]) -> str: return to_format.format(**box_vars) -def _attempt_find_include(to_format: str, box_vars: box.Box): +def _attempt_find_include(to_format: str, box_vars: box.Box) -> Optional[str]: formatter = string.Formatter() would_format = list(formatter.parse(to_format)) @@ -91,7 +91,10 @@ def _attempt_find_include(to_format: str, box_vars: box.Box): would_replace = formatter.get_field(field_name, [], box_vars)[0] - return formatter.convert_field(would_replace, conversion) # type: ignore + if conversion is None: + return would_replace + + return formatter.convert_field(would_replace, conversion) T = typing.TypeVar("T", str, Dict, List, Tuple) @@ -99,7 +102,7 @@ def _attempt_find_include(to_format: str, box_vars: box.Box): def format_keys( val: T, - variables: Mapping, + variables: Union[Mapping, Box], *, no_double_format: bool = True, dangerously_ignore_string_format_errors: bool = False, From fad2bc42764076d3bef03a2f2a0e220cb7caad41 Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sat, 27 Jan 2024 14:01:10 +0000 Subject: [PATCH 17/20] cleanup docstrings --- tavern/_core/extfunctions.py | 2 +- tavern/_core/plugins.py | 2 +- tavern/_core/schema/files.py | 2 +- tavern/_plugins/rest/request.py | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tavern/_core/extfunctions.py b/tavern/_core/extfunctions.py index c81404f9..66b5347d 100644 --- a/tavern/_core/extfunctions.py +++ b/tavern/_core/extfunctions.py @@ -51,7 +51,7 @@ def import_ext_function(entrypoint: str) -> Callable: module.submodule:function Returns: - function: function loaded from entrypoint + function loaded from entrypoint Raises: InvalidExtFunctionError: If the module or function did not exist diff --git a/tavern/_core/plugins.py b/tavern/_core/plugins.py index ff724797..e447c4fe 100644 --- a/tavern/_core/plugins.py +++ b/tavern/_core/plugins.py @@ -57,7 +57,7 @@ def is_valid_reqresp_plugin(ext: stevedore.extension.Extension) -> bool: ext: class or module plugin object Returns: - bool: Whether the plugin has everything we need to use it + Whether the plugin has everything we need to use it """ required = [ # MQTTClient, requests.Session diff --git a/tavern/_core/schema/files.py b/tavern/_core/schema/files.py index 1853de3b..a2661c37 100644 --- a/tavern/_core/schema/files.py +++ b/tavern/_core/schema/files.py @@ -66,7 +66,7 @@ def __call__(self, schema_filename, with_plugins): with_plugins (bool): Whether to load plugin schema into this schema as well Returns: - dict: loaded schema + loaded schema """ if with_plugins: diff --git a/tavern/_plugins/rest/request.py b/tavern/_plugins/rest/request.py index a3282220..06e40fbe 100644 --- a/tavern/_plugins/rest/request.py +++ b/tavern/_plugins/rest/request.py @@ -4,7 +4,7 @@ import warnings from contextlib import ExitStack from itertools import filterfalse, tee -from typing import ClassVar, Dict, List, Mapping, Optional +from typing import Callable, ClassVar, Dict, List, Mapping, Optional from urllib.parse import quote_plus import requests @@ -235,7 +235,7 @@ def _check_allow_redirects(rspec: dict, test_block_config: TestConfig): test_block_config: config available for test Returns: - bool: Whether to allow redirects for this stage or not + Whether to allow redirects for this stage or not """ # By default, don't follow redirects allow_redirects = False @@ -439,7 +439,7 @@ def prepared_request(): return session.request(**self._request_args) - self._prepared = prepared_request + self._prepared: Callable[[], requests.Response] = prepared_request def run(self) -> requests.Response: """Runs the prepared request and times it From 949adfa44c2fc87c8d422b4be250c91817bc1f62 Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sat, 27 Jan 2024 14:09:12 +0000 Subject: [PATCH 18/20] annotations --- tavern/_core/schema/extensions.py | 35 ++++++++++++++++++------------ tavern/_plugins/mqtt/tavernhook.py | 2 +- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/tavern/_core/schema/extensions.py b/tavern/_core/schema/extensions.py index 62da2088..44bc7075 100644 --- a/tavern/_core/schema/extensions.py +++ b/tavern/_core/schema/extensions.py @@ -1,6 +1,6 @@ import os import re -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Tuple, Type, Union from pykwalify.types import is_bool, is_float, is_int @@ -12,7 +12,13 @@ is_ext_function, ) from tavern._core.general import valid_http_methods -from tavern._core.loader import ApproxScalar, BoolToken, FloatToken, IntToken +from tavern._core.loader import ( + ApproxScalar, + BoolToken, + FloatToken, + IntToken, + TypeConvertToken, +) from tavern._core.strict_util import StrictLevel if TYPE_CHECKING: @@ -21,7 +27,9 @@ # To extend pykwalify's type validation, extend its internal functions # These return boolean values -def validate_type_and_token(validate_type, token): +def validate_type_and_token( + validate_type: Callable[[Any], bool], token: Type[TypeConvertToken] +): def validate(value): return validate_type(value) or isinstance(value, token) @@ -34,7 +42,7 @@ def validate(value): # These plug into the pykwalify extension function API -def validator_like(validate, description): +def validator_like(validate: Callable[[Any], bool], description: str): def validator(value, rule_obj, path): if validate(value): return True @@ -52,7 +60,7 @@ def validator(value, rule_obj, path): bool_variable = validator_like(is_bool_like, "bool-like") -def _validate_one_extension(input_value) -> None: +def _validate_one_extension(input_value: Mapping) -> None: expected_keys = {"function", "extra_args", "extra_kwargs"} extra = set(input_value) - expected_keys @@ -105,7 +113,7 @@ def validate_extensions(value, rule_obj, path) -> bool: return True -def validate_status_code_is_int_or_list_of_ints(value, rule_obj, path) -> bool: +def validate_status_code_is_int_or_list_of_ints(value: Mapping, rule_obj, path) -> bool: err_msg = "status_code has to be an integer or a list of integers (got {})".format( value ) @@ -120,7 +128,7 @@ def validate_status_code_is_int_or_list_of_ints(value, rule_obj, path) -> bool: return True -def check_usefixtures(value, rule_obj, path) -> bool: +def check_usefixtures(value: Mapping, rule_obj, path) -> bool: err_msg = "'usefixtures' has to be a list with at least one item" if not isinstance(value, (list, tuple)): @@ -171,11 +179,10 @@ def to_grpc_status(value: Union[str, int]): return None -def verify_oneof_id_name(value, rule_obj, path) -> bool: +def verify_oneof_id_name(value: Mapping, rule_obj, path) -> bool: """Checks that if 'name' is not present, 'id' is""" - name = value.get("name") - if not name: + if not (name := value.get("name")): if name == "": raise BadSchemaError("Name cannot be empty") @@ -246,7 +253,7 @@ def check_parametrize_marks(value, rule_obj, path) -> bool: return True -def validate_data_key(value, rule_obj, path) -> bool: +def validate_data_key(value, rule_obj, path: str) -> bool: """Validate the 'data' key in a http request From requests docs: @@ -326,7 +333,7 @@ def validate_json_with_ext(value, rule_obj, path) -> bool: return True -def check_strict_key(value, rule_obj, path) -> bool: +def check_strict_key(value: Union[List, bool], rule_obj, path) -> bool: """Make sure the 'strict' key is either a bool or a list""" if not isinstance(value, list) and not is_bool_like(value): @@ -343,7 +350,7 @@ def check_strict_key(value, rule_obj, path) -> bool: return True -def validate_timeout_tuple_or_float(value, rule_obj, path) -> bool: +def validate_timeout_tuple_or_float(value: Union[List, Tuple], rule_obj, path) -> bool: """Make sure timeout is a float/int or a tuple of floats/ints""" err_msg = "'timeout' must be either a float/int or a 2-tuple of floats/ints - got '{}' (type {})".format( @@ -398,7 +405,7 @@ def validate_cert_tuple_or_str(value, rule_obj, path) -> bool: return True -def validate_file_spec(value, rule_obj, path) -> bool: +def validate_file_spec(value: Dict, rule_obj, path) -> bool: """Validate file upload arguments""" logger = get_pykwalify_logger("tavern.schema.extensions") diff --git a/tavern/_plugins/mqtt/tavernhook.py b/tavern/_plugins/mqtt/tavernhook.py index 385f573a..bc37d267 100644 --- a/tavern/_plugins/mqtt/tavernhook.py +++ b/tavern/_plugins/mqtt/tavernhook.py @@ -23,7 +23,7 @@ def get_expected_from_request( response_block: Union[Dict, Iterable[Dict]], test_block_config: TestConfig, session: MQTTClient, -): +) -> Optional[Dict]: expected: Optional[Dict] = None # mqtt response is not required From 875e522925eedadb7dae6ff0841692689df87896 Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sat, 27 Jan 2024 14:12:47 +0000 Subject: [PATCH 19/20] annotations --- tavern/_core/tincture.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tavern/_core/tincture.py b/tavern/_core/tincture.py index e99fe2eb..21280239 100644 --- a/tavern/_core/tincture.py +++ b/tavern/_core/tincture.py @@ -1,7 +1,8 @@ import collections.abc +import dataclasses import inspect import logging -from typing import Any, List +from typing import Any, Generator, List from tavern._core import exceptions from tavern._core.extfunctions import get_wrapped_response_function @@ -9,19 +10,18 @@ logger: logging.Logger = logging.getLogger(__name__) +@dataclasses.dataclass class Tinctures: - def __init__(self, tinctures: List[Any]) -> None: - self._tinctures = tinctures - self._needs_response: List[Any] = [] + tinctures: List[Any] + needs_response: List[Generator] = dataclasses.field(default_factory=list) def start_tinctures(self, stage: collections.abc.Mapping) -> None: - results = [t(stage) for t in self._tinctures] - self._needs_response = [] + results = [t(stage) for t in self.tinctures] for r in results: if inspect.isgenerator(r): # Store generator and start it - self._needs_response.append(r) + self.needs_response.append(r) next(r) def end_tinctures(self, expected: collections.abc.Mapping, response) -> None: @@ -29,14 +29,15 @@ def end_tinctures(self, expected: collections.abc.Mapping, response) -> None: Send the response object to any tinctures that want it Args: - response: The response from 'run' for the stage + expected: 'expected' from initial test - type varies depending on backend + response: The response from 'run' for the stage - type varies depending on backend """ - if self._needs_response is None: + if self.needs_response is None: raise RuntimeError( "should not be called before accumulating tinctures which need a response" ) - for n in self._needs_response: + for n in self.needs_response: try: n.send((expected, response)) except StopIteration: From 963caa7eb63435efffb2942f2dc72409b58d8411 Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sat, 27 Jan 2024 14:33:56 +0000 Subject: [PATCH 20/20] Return self in grpc client to be consistent --- tavern/_plugins/grpc/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tavern/_plugins/grpc/client.py b/tavern/_plugins/grpc/client.py index 85b433d3..8e51b110 100644 --- a/tavern/_plugins/grpc/client.py +++ b/tavern/_plugins/grpc/client.py @@ -239,8 +239,9 @@ def _make_call_request( return self._get_grpc_service(channel, service, method) - def __enter__(self) -> None: + def __enter__(self) -> "GRPCClient": logger.debug("Connecting to GRPC") + return self def call( self,