diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ee74939a..6079aece 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: hooks: - id: pycln - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.1.11" + rev: "v0.1.14" hooks: - id: ruff-format - id: ruff diff --git a/scripts/smoke.bash b/scripts/smoke.bash index aa9762f6..bcbd5b8e 100755 --- a/scripts/smoke.bash +++ b/scripts/smoke.bash @@ -6,10 +6,10 @@ pre-commit run ruff --all-files || true pre-commit run ruff-format --all-files || true tox --parallel -c tox.ini \ - -e py3check + -e py3mypy tox --parallel -c tox.ini \ - -e py3mypy + -e py3check tox --parallel -c tox.ini \ -e py3 diff --git a/tavern/_core/dict_util.py b/tavern/_core/dict_util.py index a47f5939..f65eaeee 100644 --- a/tavern/_core/dict_util.py +++ b/tavern/_core/dict_util.py @@ -4,7 +4,9 @@ import os import re import string -from typing import Any, Dict, List, Mapping, Union +import typing +from collections.abc import Collection +from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union import box import jmespath @@ -22,10 +24,10 @@ from .formatted_str import FormattedString from .strict_util import StrictSetting, StrictSettingKinds, extract_strict_setting -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def _check_and_format_values(to_format, box_vars: Mapping[str, Any]) -> str: +def _check_and_format_values(to_format: str, box_vars: Mapping[str, Any]) -> str: formatter = string.Formatter() would_format = formatter.parse(to_format) @@ -55,7 +57,7 @@ def _check_and_format_values(to_format, box_vars: Mapping[str, Any]) -> str: return to_format.format(**box_vars) -def _attempt_find_include(to_format: str, box_vars: box.Box): +def _attempt_find_include(to_format: str, box_vars: box.Box) -> Optional[str]: formatter = string.Formatter() would_format = list(formatter.parse(to_format)) @@ -89,20 +91,26 @@ def _attempt_find_include(to_format: str, box_vars: box.Box): would_replace = formatter.get_field(field_name, [], box_vars)[0] - return formatter.convert_field(would_replace, conversion) # type: ignore + if conversion is None: + return would_replace + + return formatter.convert_field(would_replace, conversion) + + +T = typing.TypeVar("T", str, Dict, List, Tuple) def format_keys( - val, - variables: Mapping, + val: T, + variables: Union[Mapping, Box], *, no_double_format: bool = True, dangerously_ignore_string_format_errors: bool = False, -): +) -> T: """recursively format a dictionary with the given values Args: - val: Input dictionary to format + val: Input thing to format variables: Dictionary of keys to format it with no_double_format: Whether to use the 'inner formatted string' class to avoid double formatting This is required if passing something via pytest-xdist, such as markers: @@ -110,11 +118,12 @@ def format_keys( dangerously_ignore_string_format_errors: whether to ignore any string formatting errors. This will result in broken output, only use for debugging purposes. + Raises: + MissingFormatError: if a format variable was not found in variables + Returns: recursively formatted values """ - formatted = val - format_keys_ = functools.partial( format_keys, dangerously_ignore_string_format_errors=dangerously_ignore_string_format_errors, @@ -126,15 +135,15 @@ def format_keys( box_vars = variables if isinstance(val, dict): - formatted = {} - # formatted = {key: format_keys(val[key], box_vars) for key in val} - for key in val: - formatted[key] = format_keys_(val[key], box_vars) - elif isinstance(val, (list, tuple)): - formatted = [format_keys_(item, box_vars) for item in val] # type: ignore - elif isinstance(formatted, FormattedString): - logger.debug("Already formatted %s, not double-formatting", formatted) + return {key: format_keys_(val[key], box_vars) for key in val} + elif isinstance(val, tuple): + return tuple(format_keys_(item, box_vars) for item in val) + elif isinstance(val, list): + return [format_keys_(item, box_vars) for item in val] + elif isinstance(val, FormattedString): + logger.debug("Already formatted %s, not double-formatting", val) elif isinstance(val, str): + formatted = val try: formatted = _check_and_format_values(val, box_vars) except exceptions.MissingFormatError: @@ -143,20 +152,22 @@ def format_keys( if no_double_format: formatted = FormattedString(formatted) # type: ignore + + return formatted elif isinstance(val, TypeConvertToken): logger.debug("Got type convert token '%s'", val) if isinstance(val, ForceIncludeToken): - formatted = _attempt_find_include(val.value, box_vars) + return _attempt_find_include(val.value, box_vars) else: value = format_keys_(val.value, box_vars) - formatted = val.constructor(value) + return val.constructor(value) else: - logger.debug("Not formatting something of type '%s'", type(formatted)) + logger.debug("Not formatting something of type '%s'", type(val)) - return formatted + return val -def recurse_access_key(data, query: str): +def recurse_access_key(data: Union[List, Mapping], query: str) -> Any: """ Search for something in the given data using the given query. @@ -168,11 +179,14 @@ def recurse_access_key(data, query: str): 'c' Args: - data (dict, list): Data to search in - query (str): Query to run + data: Data to search in + query: Query to run + + Raises: + JMESError: if there was an error parsing the query Returns: - object: Whatever was found by the search + Whatever was found by the search """ try: @@ -195,7 +209,9 @@ def recurse_access_key(data, query: str): return from_jmespath -def _deprecated_recurse_access_key(current_val, keys): +def _deprecated_recurse_access_key( + current_val: Union[List, Mapping], keys: List +) -> Any: """Given a list of keys and a dictionary, recursively access the dicionary using the keys until we find the key its looking for @@ -209,15 +225,15 @@ def _deprecated_recurse_access_key(current_val, keys): 'c' Args: - current_val (dict): current dictionary we have recursed into - keys (list): list of str/int of subkeys + current_val: current dictionary we have recursed into + keys: list of str/int of subkeys Raises: IndexError: list index not found in data KeyError: dict key not found in data Returns: - str or dict: value of subkey in dict + value of subkey in dict """ logger.debug("Recursively searching for '%s' in '%s'", keys, current_val) @@ -266,12 +282,12 @@ def deep_dict_merge(initial_dct: Dict, merge_dct: Mapping) -> dict: return dct -def check_expected_keys(expected, actual) -> None: +def check_expected_keys(expected: Collection, actual: Collection) -> None: """Check that a set of expected keys is a superset of the actual keys Args: - expected (list, set, dict): keys we expect - actual (list, set, dict): keys we have got from the input + expected: keys we expect + actual: keys we have got from the input Raises: UnexpectedKeysError: If not actual <= expected @@ -289,7 +305,7 @@ def check_expected_keys(expected, actual) -> None: raise exceptions.UnexpectedKeysError(msg) -def yield_keyvals(block): +def yield_keyvals(block: Union[List, Dict]) -> Iterator[Tuple[List, str, str]]: """Return indexes, keys and expected values for matching recursive keys Given a list or dict, return a 3-tuple of the 'split' key (key split on @@ -321,10 +337,10 @@ def yield_keyvals(block): (['2'], '2', 'c') Args: - block (dict, list): input matches + block: input matches Yields: - (list, str, str): key split on dots, key, expected value + iterable of (key split on dots, key, expected value) """ if isinstance(block, dict): for joined_key, expected_val in block.items(): @@ -336,9 +352,12 @@ def yield_keyvals(block): yield [sidx], sidx, val +Checked = typing.TypeVar("Checked", Dict, Collection, str) + + def check_keys_match_recursive( - expected_val: Any, - actual_val: Any, + expected_val: Checked, + actual_val: Checked, keys: List[Union[str, int]], strict: StrictSettingKinds = True, ) -> None: @@ -443,8 +462,8 @@ def _format_err(which): raise exceptions.KeyMismatchError(msg) from e if isinstance(expected_val, dict): - akeys = set(actual_val.keys()) ekeys = set(expected_val.keys()) + akeys = set(actual_val.keys()) # type:ignore if akeys != ekeys: extra_actual_keys = akeys - ekeys @@ -481,7 +500,10 @@ def _format_err(which): for key in to_recurse: try: check_keys_match_recursive( - expected_val[key], actual_val[key], keys + [key], strict + expected_val[key], + actual_val[key], # type:ignore + keys + [key], + strict, ) except KeyError: logger.debug( diff --git a/tavern/_core/extfunctions.py b/tavern/_core/extfunctions.py index ce2b972b..66b5347d 100644 --- a/tavern/_core/extfunctions.py +++ b/tavern/_core/extfunctions.py @@ -1,7 +1,7 @@ import functools import importlib import logging -from typing import Any, List, Mapping, Optional +from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple from tavern._core import exceptions @@ -16,7 +16,7 @@ def is_ext_function(block: Any) -> bool: block: Any object Returns: - bool: If it is an ext function style dict + If it is an ext function style dict """ return isinstance(block, dict) and block.get("$ext", None) is not None @@ -29,17 +29,20 @@ def get_pykwalify_logger(module: Optional[str]) -> logging.Logger: trying to get the root logger which won't log correctly Args: - module (string): name of module to get logger for + module: name of module to get logger for + Returns: + logger for given module """ return logging.getLogger(module) def _getlogger() -> logging.Logger: + """Get logger for this module""" return get_pykwalify_logger("tavern._core.extfunctions") -def import_ext_function(entrypoint: str): +def import_ext_function(entrypoint: str) -> Callable: """Given a function name in the form of a setuptools entry point, try to dynamically load and return it @@ -48,7 +51,7 @@ def import_ext_function(entrypoint: str): module.submodule:function Returns: - function: function loaded from entrypoint + function loaded from entrypoint Raises: InvalidExtFunctionError: If the module or function did not exist @@ -79,7 +82,7 @@ def import_ext_function(entrypoint: str): return function -def get_wrapped_response_function(ext: Mapping): +def get_wrapped_response_function(ext: Mapping) -> Callable: """Wraps a ext function with arguments given in the test file This is similar to functools.wrap, but this makes sure that 'response' is @@ -90,7 +93,7 @@ def get_wrapped_response_function(ext: Mapping): extra_kwargs to pass Returns: - function: Wrapped function + Wrapped function """ func, args, kwargs = _get_ext_values(ext) @@ -106,7 +109,7 @@ def inner(response): return inner -def get_wrapped_create_function(ext: Mapping): +def get_wrapped_create_function(ext: Mapping) -> Callable: """Same as get_wrapped_response_function, but don't require a response""" func, args, kwargs = _get_ext_values(ext) @@ -122,7 +125,7 @@ def inner(): return inner -def _get_ext_values(ext: Mapping): +def _get_ext_values(ext: Mapping) -> Tuple[Callable, Iterable, Mapping]: if not isinstance(ext, Mapping): raise exceptions.InvalidExtFunctionError( f"ext block should be a dict, but it was a {type(ext)}" diff --git a/tavern/_core/general.py b/tavern/_core/general.py index 51984c90..25658cb6 100644 --- a/tavern/_core/general.py +++ b/tavern/_core/general.py @@ -1,15 +1,15 @@ import logging import os -from typing import List +from typing import List, Union from tavern._core.loader import load_single_document_yaml from .dict_util import deep_dict_merge -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def load_global_config(global_cfg_paths: List[os.PathLike]) -> dict: +def load_global_config(global_cfg_paths: List[Union[str, os.PathLike]]) -> dict: """Given a list of file paths to global config files, load each of them and return the joined dictionary. diff --git a/tavern/_core/jmesutils.py b/tavern/_core/jmesutils.py index 7940b8b1..35055548 100644 --- a/tavern/_core/jmesutils.py +++ b/tavern/_core/jmesutils.py @@ -61,7 +61,7 @@ def safe_length(var: Sized) -> int: return -1 -def validate_comparison(each_comparison): +def validate_comparison(each_comparison: Dict[Any, Any]): if extra := set(each_comparison.keys()) - {"jmespath", "operator", "expected"}: raise exceptions.BadSchemaError( "Invalid keys given to JMES validation function (got extra keys: {})".format( diff --git a/tavern/_core/loader.py b/tavern/_core/loader.py index a4ceccc7..690a1ee1 100644 --- a/tavern/_core/loader.py +++ b/tavern/_core/loader.py @@ -6,12 +6,14 @@ import uuid from abc import abstractmethod from itertools import chain -from typing import List, Optional +from typing import List, Optional, Union import pytest import yaml +from _pytest.python_api import ApproxBase from yaml.composer import Composer from yaml.constructor import SafeConstructor +from yaml.nodes import Node, ScalarNode from yaml.parser import Parser from yaml.reader import Reader from yaml.resolver import Resolver @@ -21,25 +23,25 @@ from tavern._core.exceptions import BadSchemaError from tavern._core.strtobool import strtobool -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def makeuuid(loader, node): +def makeuuid(loader, node) -> str: return str(uuid.uuid4()) class RememberComposer(Composer): """A composer that doesn't forget anchors across documents""" - def compose_document(self): + def compose_document(self) -> Optional[Node]: # Drop the DOCUMENT-START event. - self.get_event() + self.get_event() # type:ignore # Compose the root node. - node = self.compose_node(None, None) + node = self.compose_node(None, None) # type:ignore # Drop the DOCUMENT-END event. - self.get_event() + self.get_event() # type:ignore # If we don't drop the anchors here, then we can keep anchors across # documents. @@ -106,8 +108,6 @@ class IncludeLoader( between documents""" def __init__(self, stream): - """Initialise Loader.""" - try: self._root = os.path.split(stream.name)[0] except AttributeError: @@ -140,7 +140,7 @@ def _get_include_dirs(loader): return chain(loader_list, IncludeLoader.env_path_list) -def find_include(loader, node): +def find_include(loader, node) -> str: """Locate an include file and return the abs path.""" for directory in _get_include_dirs(loader): filename = os.path.abspath( @@ -190,14 +190,14 @@ def constructor(_): raise NotImplementedError @classmethod - def from_yaml(cls, loader, node): + def from_yaml(cls, loader, node) -> "TypeSentinel": return cls() - def __str__(self): + def __str__(self) -> str: return f"" @classmethod - def to_yaml(cls, dumper, data): + def to_yaml(cls, dumper, data) -> ScalarNode: node = yaml.nodes.ScalarNode(cls.yaml_tag, "", style=cls.yaml_flow_style) return node @@ -242,7 +242,7 @@ class RegexSentinel(TypeSentinel): constructor = str compiled: re.Pattern - def __str__(self): + def __str__(self) -> str: return f"" @property @@ -254,28 +254,28 @@ def passes(self, string): raise NotImplementedError @classmethod - def from_yaml(cls, loader, node): + def from_yaml(cls, loader, node) -> "RegexSentinel": return cls(re.compile(node.value)) class _RegexMatchSentinel(RegexSentinel): yaml_tag = "!re_match" - def passes(self, string): + def passes(self, string) -> bool: return self.compiled.match(string) is not None class _RegexFullMatchSentinel(RegexSentinel): yaml_tag = "!re_fullmatch" - def passes(self, string): + def passes(self, string) -> bool: return self.compiled.fullmatch(string) is not None class _RegexSearchSentinel(RegexSentinel): yaml_tag = "!re_search" - def passes(self, string): + def passes(self, string) -> bool: return self.compiled.search(string) is not None @@ -321,7 +321,7 @@ class TypeConvertToken(yaml.YAMLObject): def constructor(_): raise NotImplementedError - def __init__(self, value): + def __init__(self, value) -> None: self.value = value @classmethod @@ -338,7 +338,7 @@ def from_yaml(cls, loader, node): return converted @classmethod - def to_yaml(cls, dumper, data): + def to_yaml(cls, dumper, data) -> ScalarNode: return yaml.nodes.ScalarNode( cls.yaml_tag, data.value, style=cls.yaml_flow_style ) @@ -357,7 +357,7 @@ class FloatToken(TypeConvertToken): class StrToBoolConstructor: """Using `bool` as a constructor directly will evaluate all strings to `True`.""" - def __new__(cls, s): + def __new__(cls, s: str) -> bool: # type:ignore return strtobool(s) @@ -407,7 +407,7 @@ class ApproxSentinel(yaml.YAMLObject, ApproxScalar): # type:ignore yaml_loader = IncludeLoader @classmethod - def from_yaml(cls, loader, node): + def from_yaml(cls, loader, node) -> ApproxBase: try: val = float(node.value) except (ValueError, TypeError) as e: @@ -420,7 +420,7 @@ def from_yaml(cls, loader, node): return pytest.approx(val) @classmethod - def to_yaml(cls, dumper, data): + def to_yaml(cls, dumper, data) -> ScalarNode: return yaml.nodes.ScalarNode( "!approx", str(data.expected), style=cls.yaml_flow_style ) @@ -430,7 +430,7 @@ def to_yaml(cls, dumper, data): yaml.dumper.Dumper.add_representer(ApproxScalar, ApproxSentinel.to_yaml) -def load_single_document_yaml(filename: os.PathLike) -> dict: +def load_single_document_yaml(filename: Union[str, os.PathLike]) -> dict: """ Load a yaml file and expect only one document diff --git a/tavern/_core/plugins.py b/tavern/_core/plugins.py index b296be78..e447c4fe 100644 --- a/tavern/_core/plugins.py +++ b/tavern/_core/plugins.py @@ -16,7 +16,7 @@ from tavern.request import BaseRequest from tavern.response import BaseResponse -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class PluginHelperBase: @@ -57,7 +57,7 @@ def is_valid_reqresp_plugin(ext: stevedore.extension.Extension) -> bool: ext: class or module plugin object Returns: - bool: Whether the plugin has everything we need to use it + Whether the plugin has everything we need to use it """ required = [ # MQTTClient, requests.Session diff --git a/tavern/_core/pytest/config.py b/tavern/_core/pytest/config.py index 2e33620e..51b0b65e 100644 --- a/tavern/_core/pytest/config.py +++ b/tavern/_core/pytest/config.py @@ -6,7 +6,7 @@ from tavern._core.strict_util import StrictLevel -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) @dataclasses.dataclass(frozen=True) diff --git a/tavern/_core/pytest/error.py b/tavern/_core/pytest/error.py index 0c7b65be..7fe619d5 100644 --- a/tavern/_core/pytest/error.py +++ b/tavern/_core/pytest/error.py @@ -1,15 +1,21 @@ +import dataclasses import json import logging import re +import typing from io import StringIO -from typing import List, Mapping, Optional +from typing import Any, Dict, List, Optional import yaml -from _pytest._code.code import FormattedExcinfo +from _pytest._code.code import FormattedExcinfo, TerminalRepr from _pytest._io import TerminalWriter from tavern._core import exceptions from tavern._core.dict_util import format_keys + +if typing.TYPE_CHECKING: + from tavern._core.pytest.item import YamlItem + from tavern._core.report import prepare_yaml from tavern._core.stage_lines import ( end_mark, @@ -18,22 +24,22 @@ start_mark, ) -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -class ReprdError: - def __init__(self, exce, item) -> None: - self.exce = exce - self.item = item +@dataclasses.dataclass +class ReprdError(TerminalRepr): + exce: Any + item: "YamlItem" - def _get_available_format_keys(self): + def _get_available_format_keys(self) -> Dict: """Try to get the format variables for the stage If we can't get the variable for this specific stage, just return the global config which will at least have some format variables Returns: - dict: variables for formatting test + variables for formatting test """ try: keys = self.exce._excinfo[1].test_block_config.variables @@ -137,7 +143,7 @@ def _print_test_stage( else: tw.line(line, white=True) - def _print_formatted_stage(self, tw: TerminalWriter, stage: Mapping) -> None: + def _print_formatted_stage(self, tw: TerminalWriter, stage: Dict) -> None: """Print the 'formatted' stage that Tavern will actually use to send the request/process the response @@ -155,10 +161,10 @@ def _print_formatted_stage(self, tw: TerminalWriter, stage: Mapping) -> None: ) # Replace formatted strings with strings for dumping - formatted_stage = prepare_yaml(formatted_stage) + prepared_stage = prepare_yaml(formatted_stage) # Dump formatted stage to YAML format - formatted_lines = yaml.dump(formatted_stage, default_flow_style=False).split( + formatted_lines = yaml.dump(prepared_stage, default_flow_style=False).split( "\n" ) diff --git a/tavern/_core/pytest/file.py b/tavern/_core/pytest/file.py index 49bbbf61..c0f690a1 100644 --- a/tavern/_core/pytest/file.py +++ b/tavern/_core/pytest/file.py @@ -2,11 +2,13 @@ import functools import itertools import logging -from typing import Dict, Iterator, List, Mapping +import typing +from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Tuple, Union import pytest import yaml from box import Box +from pytest import Mark from tavern._core import exceptions from tavern._core.dict_util import deep_dict_merge, format_keys, get_tavern_box @@ -17,23 +19,29 @@ from .item import YamlItem from .util import load_global_cfg -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -_format_without_inner = functools.partial(format_keys, no_double_format=False) +T = typing.TypeVar("T") +_format_without_inner: Callable[[T, Mapping], T] = functools.partial( + format_keys, no_double_format=False +) -def _format_test_marks(original_marks, fmt_vars, test_name): + +def _format_test_marks( + original_marks: Iterable[Union[str, dict]], fmt_vars: Mapping, test_name: str +) -> Tuple[List[Mark], List[Mapping]]: """Given the 'raw' marks from the test and any available format variables, generate new marks for this test Args: - original_marks (list): Raw string from test - should correspond to either a + original_marks: Raw string from test - should correspond to either a pytest builtin mark or a custom user mark - fmt_vars (dict): dictionary containing available format variables - test_name (str): Name of test (for error logging) + fmt_vars: dictionary containing available format variables + test_name: Name of test (for error logging) Returns: - tuple: first element is normal pytest mark objects, second element is all + first element is normal pytest mark objects, second element is all marks which were formatted (no matter their content) Todo: @@ -52,8 +60,8 @@ def _format_test_marks(original_marks, fmt_vars, test_name): """ - pytest_marks = [] - formatted_marks = [] + pytest_marks: List[Mark] = [] + formatted_marks: List[Mapping] = [] for m in original_marks: if isinstance(m, str): @@ -86,15 +94,54 @@ def _format_test_marks(original_marks, fmt_vars, test_name): return pytest_marks, formatted_marks -def _generate_parametrized_test_items(keys: List, vals_combination): +def _maybe_load_ext(pair): + """Try to load ext values""" + key, value = pair + + if is_ext_function(value): + # If it is an ext function, load the new (or supplemental) value[s] + ext = value.pop("$ext") + f = get_wrapped_create_function(ext) + new_value = f() + + if len(value) == 0: + # Use only this new value + return key, new_value + elif isinstance(new_value, dict): + # Merge with some existing data. At this point 'value' is known to be a dict. + return key, deep_dict_merge(value, f()) + else: + # For example, if it's defined like + # + # - testkey: testval + # $ext: + # function: mod:func + # + # and 'mod:func' returns a string, it's impossible to 'merge' with the existing data. + logger.error("Values still in 'val': %s", value) + raise exceptions.BadSchemaError( + "There were extra key/value pairs in the 'val' for this parametrize mark, but the ext function {} returned '{}' (of type {}) that was not a dictionary. It is impossible to merge these values.".format( + ext, new_value, type(new_value) + ) + ) + + return key, value + + +def _generate_parametrized_test_items( + keys: Iterable[Union[str, List, Tuple]], vals_combination: Iterable[Tuple[str, str]] +) -> Tuple[Mapping[str, Any], str]: """Generate test name from given key(s)/value(s) combination Args: keys: list of keys to format name with - vals_combination (tuple(str)): this combination of values for the key + vals_combination this combination of values for the key + + Returns: + tuple of the variables for the stage and the generated stage name """ - flattened_values = [] - variables = {} + flattened_values: List[Iterable[str]] = [] + variables: Dict[str, Any] = {} # combination of keys and the values they correspond to for pair in zip(keys, vals_combination): @@ -103,7 +150,7 @@ def _generate_parametrized_test_items(keys: List, vals_combination): # very weird looking if isinstance(key, str): variables[key] = value - flattened_values += [value] + flattened_values.append(value) else: if not isinstance(value, (list, tuple)): value = [value] @@ -111,47 +158,15 @@ def _generate_parametrized_test_items(keys: List, vals_combination): if len(value) != len(key): raise exceptions.BadSchemaError( "Invalid match between numbers of keys and number of values in parametrize mark ({} keys, {} values)".format( - (key), (value) + key, value ) ) for subkey, subvalue in zip(key, value): variables[subkey] = subvalue - flattened_values += [subvalue] - - def maybe_load_ext(v): - key, value = v - - if is_ext_function(value): - # If it is an ext function, load the new (or supplemental) value[s] - ext = value.pop("$ext") - f = get_wrapped_create_function(ext) - new_value = f() - - if len(value) == 0: - # Use only this new value - return key, new_value - elif isinstance(new_value, dict): - # Merge with some existing data. At this point 'value' is known to be a dict. - return key, deep_dict_merge(value, f()) - else: - # For example, if it's defined like - # - # - testkey: testval - # $ext: - # function: mod:func - # - # and 'mod:func' returns a string, it's impossible to 'merge' with the existing data. - logger.error("Values still in 'val': %s", value) - raise exceptions.BadSchemaError( - "There were extra key/value pairs in the 'val' for this parametrize mark, but the ext function {} returned '{}' (of type {}) that was not a dictionary. It is impossible to merge these values.".format( - ext, new_value, type(new_value) - ) - ) - - return key, value + flattened_values.append(subvalue) - variables = dict(map(maybe_load_ext, variables.items())) + variables = dict(map(_maybe_load_ext, variables.items())) logger.debug("Variables for this combination: %s", variables) logger.debug("Values for this combination: %s", flattened_values) @@ -205,7 +220,7 @@ def unwrap_map(value): "Invalid match between numbers of keys and number of values in parametrize mark" ) from e - keys = [i["parametrize"]["key"] for i in parametrize_marks] + keys: List[str] = [i["parametrize"]["key"] for i in parametrize_marks] for vals_combination in combined: logger.debug("Generating test for %s/%s", keys, vals_combination) @@ -346,7 +361,7 @@ def collect(self) -> Iterator[YamlItem]: try: # Convert to a list so we can catch parser exceptions - all_tests = list( + all_tests: Iterable[dict] = list( yaml.load_all( self.path.open(encoding="utf-8"), Loader=IncludeLoader, # type:ignore diff --git a/tavern/_core/pytest/hooks.py b/tavern/_core/pytest/hooks.py index 9a83ee9e..5fc962c3 100644 --- a/tavern/_core/pytest/hooks.py +++ b/tavern/_core/pytest/hooks.py @@ -3,7 +3,13 @@ import os import pathlib import re +import typing from textwrap import dedent +from typing import Optional + +if typing.TYPE_CHECKING: + from .file import YamlFile + import pytest import yaml @@ -18,7 +24,7 @@ def pytest_addoption(parser: pytest.Parser) -> None: add_ini_options(parser) -def pytest_collect_file(parent, path: os.PathLike): +def pytest_collect_file(parent, path: os.PathLike) -> Optional["YamlFile"]: """On collecting files, get any files that end in .tavern.yaml or .tavern.yml as tavern test files """ diff --git a/tavern/_core/pytest/item.py b/tavern/_core/pytest/item.py index 2c073e8a..6e14984a 100644 --- a/tavern/_core/pytest/item.py +++ b/tavern/_core/pytest/item.py @@ -1,11 +1,11 @@ import logging import pathlib -from typing import MutableMapping, Optional, Tuple +from typing import MutableMapping, Optional, Tuple, Union import attr import pytest import yaml -from _pytest._code.code import ExceptionInfo +from _pytest._code.code import ExceptionInfo, TerminalRepr from _pytest.nodes import Node from tavern._core import exceptions @@ -20,7 +20,7 @@ from .config import TestConfig from .util import load_global_cfg -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class YamlItem(pytest.Item): @@ -33,11 +33,14 @@ class YamlItem(pytest.Item): Attributes: path: filename that this test came from spec: The whole dictionary of the test + global_cfg: configuration for test """ # See https://github.com/taverntesting/tavern/issues/825 _patched_yaml = False + global_cfg: TestConfig + def __init__( self, *, name: str, parent, spec: MutableMapping, path: pathlib.Path, **kwargs ) -> None: @@ -48,8 +51,6 @@ def __init__( self.path = path self.spec = spec - self.global_cfg: Optional[TestConfig] = None - if not YamlItem._patched_yaml: yaml.parser.Parser.process_empty_scalar = ( # type:ignore error_on_empty_scalar @@ -253,7 +254,7 @@ def runtest(self) -> None: def repr_failure( self, excinfo: ExceptionInfo[BaseException], style: Optional[str] = None - ): + ) -> Union[TerminalRepr, str, ReprdError]: """called when self.runtest() raises an exception. By default, will raise a custom formatted traceback if it's a tavern error. if not, will use the default diff --git a/tavern/_core/pytest/newhooks.py b/tavern/_core/pytest/newhooks.py index 8fd494b2..b76955af 100644 --- a/tavern/_core/pytest/newhooks.py +++ b/tavern/_core/pytest/newhooks.py @@ -1,9 +1,12 @@ import logging +from typing import Any, Dict, MutableMapping -logger = logging.getLogger(__name__) +from tavern._core.pytest.config import TestConfig +logger: logging.Logger = logging.getLogger(__name__) -def pytest_tavern_beta_before_every_test_run(test_dict, variables) -> None: + +def pytest_tavern_beta_before_every_test_run(test_dict: Dict, variables: Dict) -> None: """Called: - directly after fixtures are loaded for a test @@ -15,23 +18,23 @@ def pytest_tavern_beta_before_every_test_run(test_dict, variables) -> None: Modify the test in-place if you want to do something to it. Args: - test_dict (dict): Test to run - variables (dict): Available variables + test_dict: Test to run + variables: Available variables """ -def pytest_tavern_beta_after_every_test_run(test_dict, variables) -> None: +def pytest_tavern_beta_after_every_test_run(test_dict: Dict, variables: Dict) -> None: """Called: - After test run Args: - test_dict (dict): Test to run - variables (dict): Available variables + test_dict: Test to run + variables: Available variables """ -def pytest_tavern_beta_after_every_response(expected, response) -> None: +def pytest_tavern_beta_after_every_response(expected: Any, response: Any) -> None: """Called after every _response_ - including MQTT/HTTP/etc Note: @@ -39,24 +42,24 @@ def pytest_tavern_beta_after_every_response(expected, response) -> None: - MQTT responses will call this hook multiple times if multiple messages are received Args: - response (object): Response object. - expected (dict): Response block in stage + expected: Response block in stage + response: Response object. """ -def pytest_tavern_beta_before_every_request(request_args) -> None: +def pytest_tavern_beta_before_every_request(request_args: MutableMapping) -> None: """Called just before every request - including MQTT/HTTP/etc Note: - The request object type depends on what plugin you're using, and which kind of request it is! Args: - request_args (dict): Arguments passed to the request function. By default, this is Session.request for + request_args: Arguments passed to the request function. By default, this is Session.request for HTTP and Client.publish for MQTT """ -def call_hook(test_block_config, hookname, **kwargs) -> None: +def call_hook(test_block_config: TestConfig, hookname: str, **kwargs) -> None: """Utility to call the hooks""" try: hook = getattr(test_block_config.tavern_internal.pytest_hook_caller, hookname) diff --git a/tavern/_core/pytest/util.py b/tavern/_core/pytest/util.py index 30343114..03416d2c 100644 --- a/tavern/_core/pytest/util.py +++ b/tavern/_core/pytest/util.py @@ -1,6 +1,7 @@ import logging from functools import lru_cache -from typing import Any, Dict +from pathlib import Path +from typing import Any, Dict, List, Optional, TypeVar, Union import pytest @@ -9,7 +10,7 @@ from tavern._core.pytest.config import TavernInternalConfig, TestConfig from tavern._core.strict_util import StrictLevel -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def add_parser_options(parser_addoption, with_defaults: bool = True) -> None: @@ -151,11 +152,11 @@ def _load_global_cfg(pytest_config: pytest.Config) -> TestConfig: all_paths = ini_global_cfg_paths + cmdline_global_cfg_paths global_cfg_dict = load_global_config(all_paths) + variables: Dict = {} try: loaded_variables = global_cfg_dict["variables"] except KeyError: logger.debug("Nothing to format in global config files") - variables = {} else: tavern_box = get_tavern_box() variables = format_keys(loaded_variables, tavern_box) @@ -176,30 +177,33 @@ def _load_global_cfg(pytest_config: pytest.Config) -> TestConfig: def _load_global_backends(pytest_config: pytest.Config) -> Dict[str, Any]: """Load which backend should be used""" - backend_settings = {} - - for b in TestConfig.backends(): - backend_settings[b] = get_option_generic( - pytest_config, f"tavern-{b}-backend", None - ) - - return backend_settings + return { + b: get_option_generic(pytest_config, f"tavern-{b}-backend", None) + for b in TestConfig.backends() + } def _load_global_strictness(pytest_config: pytest.Config) -> StrictLevel: """Load the global 'strictness' setting""" - options = get_option_generic(pytest_config, "tavern-strict", []) + options: List = get_option_generic(pytest_config, "tavern-strict", []) return StrictLevel.from_options(options) -def _load_global_follow_redirects(pytest_config: pytest.Config): +def _load_global_follow_redirects(pytest_config: pytest.Config) -> bool: """Load the global 'follow redirects' setting""" return get_option_generic(pytest_config, "tavern-always-follow-redirects", False) -def get_option_generic(pytest_config: pytest.Config, flag: str, default): +T = TypeVar("T", bound=Optional[Union[str, List, List[Path], List[str], bool]]) + + +def get_option_generic( + pytest_config: pytest.Config, + flag: str, + default: T, +) -> T: """Get a configuration option or return the default Priority order is cmdline, then ini, then default""" diff --git a/tavern/_core/report.py b/tavern/_core/report.py index 432d9abe..2bfba54b 100644 --- a/tavern/_core/report.py +++ b/tavern/_core/report.py @@ -1,5 +1,6 @@ import logging from textwrap import dedent +from typing import Dict, List, Set, Tuple, Union import yaml @@ -24,29 +25,29 @@ def call(step_func): from tavern._core.formatted_str import FormattedString from tavern._core.stage_lines import get_stage_lines, read_relevant_lines -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def prepare_yaml(val): +def prepare_yaml(val: Union[Dict, Set, List, Tuple, str]) -> Union[Dict, List, str]: """Sanitises the formatted string into a format safe for dumping""" - formatted = val - if isinstance(val, dict): - formatted = {} + prepared = {} # formatted = {key: format_keys(val[key], box_vars) for key in val} for key in val: if isinstance(key, FormattedString): key = str(key) - formatted[key] = prepare_yaml(val[key]) + prepared[key] = prepare_yaml(val[key]) + + return prepared elif isinstance(val, (list, tuple, set)): - formatted = [prepare_yaml(item) for item in val] - elif isinstance(formatted, FormattedString): - return str(formatted) + return [prepare_yaml(item) for item in val] + elif isinstance(val, FormattedString): + return str(val) - return formatted + return val -def attach_stage_content(stage) -> None: +def attach_stage_content(stage: Dict) -> None: first_line, last_line, _ = get_stage_lines(stage) code_lines = list(read_relevant_lines(stage, first_line, last_line)) @@ -54,15 +55,15 @@ def attach_stage_content(stage) -> None: attach_text(joined, "stage_yaml", yaml_type) -def attach_yaml(payload, name): +def attach_yaml(payload, name: str) -> None: prepared = prepare_yaml(payload) dumped = yaml.safe_dump(prepared) return attach_text(dumped, name, yaml_type) -def attach_text(payload, name, attachment_type=None) -> None: +def attach_text(payload, name: str, attachment_type=None) -> None: return attach(payload, name=name, attachment_type=attachment_type) -def wrap_step(allure_name, partial): +def wrap_step(allure_name: str, partial): return step(allure_name)(partial) diff --git a/tavern/_core/run.py b/tavern/_core/run.py index be6bb984..91474d38 100644 --- a/tavern/_core/run.py +++ b/tavern/_core/run.py @@ -27,7 +27,7 @@ from .testhelpers import delay, retry from .tincture import Tinctures, get_stage_tinctures -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def _resolve_test_stages(test_spec: Mapping, available_stages: Mapping): @@ -67,10 +67,10 @@ def _get_included_stages( for use in this test Args: - available_stages: List of stages which already exist tavern_box: Available parameters for fomatting at this point test_block_config: Current test config dictionary test_spec: Specification for current test + available_stages: List of stages which already exist Returns: Fully resolved stages @@ -276,7 +276,7 @@ class _TestRunner: test_block_config: TestConfig test_spec: Mapping - def run_stage(self, idx: int, stage, *, is_final: bool = False): + def run_stage(self, idx: int, stage, *, is_final: bool = False) -> None: tinctures = get_stage_tinctures(stage, self.test_spec) stage_config = self.test_block_config.with_strictness( @@ -305,7 +305,7 @@ def run_stage(self, idx: int, stage, *, is_final: bool = False): def wrapped_run_stage( self, stage: dict, stage_config: TestConfig, tinctures: Tinctures - ): + ) -> None: """Run one stage from the test Args: diff --git a/tavern/_core/schema/extensions.py b/tavern/_core/schema/extensions.py index 6d178e53..44bc7075 100644 --- a/tavern/_core/schema/extensions.py +++ b/tavern/_core/schema/extensions.py @@ -1,6 +1,6 @@ import os import re -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Tuple, Type, Union from pykwalify.types import is_bool, is_float, is_int @@ -12,7 +12,13 @@ is_ext_function, ) from tavern._core.general import valid_http_methods -from tavern._core.loader import ApproxScalar, BoolToken, FloatToken, IntToken +from tavern._core.loader import ( + ApproxScalar, + BoolToken, + FloatToken, + IntToken, + TypeConvertToken, +) from tavern._core.strict_util import StrictLevel if TYPE_CHECKING: @@ -21,7 +27,9 @@ # To extend pykwalify's type validation, extend its internal functions # These return boolean values -def validate_type_and_token(validate_type, token): +def validate_type_and_token( + validate_type: Callable[[Any], bool], token: Type[TypeConvertToken] +): def validate(value): return validate_type(value) or isinstance(value, token) @@ -34,7 +42,7 @@ def validate(value): # These plug into the pykwalify extension function API -def validator_like(validate, description): +def validator_like(validate: Callable[[Any], bool], description: str): def validator(value, rule_obj, path): if validate(value): return True @@ -52,7 +60,7 @@ def validator(value, rule_obj, path): bool_variable = validator_like(is_bool_like, "bool-like") -def _validate_one_extension(input_value) -> None: +def _validate_one_extension(input_value: Mapping) -> None: expected_keys = {"function", "extra_args", "extra_kwargs"} extra = set(input_value) - expected_keys @@ -105,7 +113,7 @@ def validate_extensions(value, rule_obj, path) -> bool: return True -def validate_status_code_is_int_or_list_of_ints(value, rule_obj, path) -> bool: +def validate_status_code_is_int_or_list_of_ints(value: Mapping, rule_obj, path) -> bool: err_msg = "status_code has to be an integer or a list of integers (got {})".format( value ) @@ -120,7 +128,7 @@ def validate_status_code_is_int_or_list_of_ints(value, rule_obj, path) -> bool: return True -def check_usefixtures(value, rule_obj, path) -> bool: +def check_usefixtures(value: Mapping, rule_obj, path) -> bool: err_msg = "'usefixtures' has to be a list with at least one item" if not isinstance(value, (list, tuple)): @@ -132,7 +140,9 @@ def check_usefixtures(value, rule_obj, path) -> bool: return True -def validate_grpc_status_is_valid_or_list_of_names(value: "GRPCCode", rule_obj, path): +def validate_grpc_status_is_valid_or_list_of_names( + value: "GRPCCode", rule_obj, path +) -> bool: """Validate GRPC statuses https://github.com/grpc/grpc/blob/master/doc/statuscodes.md""" # pylint: disable=unused-argument err_msg = ( @@ -169,11 +179,10 @@ def to_grpc_status(value: Union[str, int]): return None -def verify_oneof_id_name(value, rule_obj, path) -> bool: +def verify_oneof_id_name(value: Mapping, rule_obj, path) -> bool: """Checks that if 'name' is not present, 'id' is""" - name = value.get("name") - if not name: + if not (name := value.get("name")): if name == "": raise BadSchemaError("Name cannot be empty") @@ -244,7 +253,7 @@ def check_parametrize_marks(value, rule_obj, path) -> bool: return True -def validate_data_key(value, rule_obj, path) -> bool: +def validate_data_key(value, rule_obj, path: str) -> bool: """Validate the 'data' key in a http request From requests docs: @@ -324,7 +333,7 @@ def validate_json_with_ext(value, rule_obj, path) -> bool: return True -def check_strict_key(value, rule_obj, path) -> bool: +def check_strict_key(value: Union[List, bool], rule_obj, path) -> bool: """Make sure the 'strict' key is either a bool or a list""" if not isinstance(value, list) and not is_bool_like(value): @@ -341,7 +350,7 @@ def check_strict_key(value, rule_obj, path) -> bool: return True -def validate_timeout_tuple_or_float(value, rule_obj, path) -> bool: +def validate_timeout_tuple_or_float(value: Union[List, Tuple], rule_obj, path) -> bool: """Make sure timeout is a float/int or a tuple of floats/ints""" err_msg = "'timeout' must be either a float/int or a 2-tuple of floats/ints - got '{}' (type {})".format( @@ -396,7 +405,7 @@ def validate_cert_tuple_or_str(value, rule_obj, path) -> bool: return True -def validate_file_spec(value, rule_obj, path) -> bool: +def validate_file_spec(value: Dict, rule_obj, path) -> bool: """Validate file upload arguments""" logger = get_pykwalify_logger("tavern.schema.extensions") diff --git a/tavern/_core/schema/files.py b/tavern/_core/schema/files.py index 8f801d6d..a2661c37 100644 --- a/tavern/_core/schema/files.py +++ b/tavern/_core/schema/files.py @@ -14,7 +14,7 @@ from tavern._core.plugins import load_plugins from tavern._core.schema.jsonschema import verify_jsonschema -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class SchemaCache: @@ -66,7 +66,7 @@ def __call__(self, schema_filename, with_plugins): with_plugins (bool): Whether to load plugin schema into this schema as well Returns: - dict: loaded schema + loaded schema """ if with_plugins: diff --git a/tavern/_core/schema/jsonschema.py b/tavern/_core/schema/jsonschema.py index 022f8af8..c439a732 100644 --- a/tavern/_core/schema/jsonschema.py +++ b/tavern/_core/schema/jsonschema.py @@ -35,7 +35,7 @@ read_relevant_lines, ) -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def is_str_or_bytes_or_token(checker, instance): diff --git a/tavern/_core/stage_lines.py b/tavern/_core/stage_lines.py index 0f943a54..7c6f9722 100644 --- a/tavern/_core/stage_lines.py +++ b/tavern/_core/stage_lines.py @@ -1,9 +1,37 @@ +import dataclasses import logging +from typing import ( + Iterable, + Mapping, + Optional, + Protocol, + Tuple, + Type, + Union, +) -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def get_stage_lines(stage): +@dataclasses.dataclass +class YamlMark: + """A pyyaml mark""" + + line: int = 0 + name: Optional[str] = None + + +class _WithMarks(Protocol): + """Things loaded by pyyaml have these""" + + start_mark: YamlMark + end_mark: YamlMark + + +PyYamlDict = Union[_WithMarks, Mapping] + + +def get_stage_lines(stage: PyYamlDict) -> Tuple[int, int, int]: first_line = start_mark(stage).line - 1 last_line = end_mark(stage).line line_start = first_line + 1 @@ -11,7 +39,9 @@ def get_stage_lines(stage): return first_line, last_line, line_start -def read_relevant_lines(yaml_block, first_line, last_line): +def read_relevant_lines( + yaml_block: PyYamlDict, first_line: int, last_line: int +) -> Iterable[str]: """Get lines between start and end mark""" filename = get_stage_filename(yaml_block) @@ -26,24 +56,19 @@ def read_relevant_lines(yaml_block, first_line, last_line): yield line.split("#", 1)[0].rstrip() -def get_stage_filename(yaml_block): +def get_stage_filename(yaml_block: PyYamlDict) -> Optional[str]: return start_mark(yaml_block).name -class EmptyBlock: - line = 0 - name = None - - -def start_mark(yaml_block): +def start_mark(yaml_block: PyYamlDict) -> Union[Type[YamlMark], YamlMark]: try: - return yaml_block.start_mark + return yaml_block.start_mark # type:ignore except AttributeError: - return EmptyBlock + return YamlMark() -def end_mark(yaml_block): +def end_mark(yaml_block: PyYamlDict) -> Union[Type[YamlMark], YamlMark]: try: - return yaml_block.end_mark + return yaml_block.end_mark # type:ignore except AttributeError: - return EmptyBlock + return YamlMark() diff --git a/tavern/_core/strict_util.py b/tavern/_core/strict_util.py index 76d090d8..bc7c63ee 100644 --- a/tavern/_core/strict_util.py +++ b/tavern/_core/strict_util.py @@ -7,7 +7,7 @@ from tavern._core import exceptions from tavern._core.strtobool import strtobool -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class StrictSetting(enum.Enum): diff --git a/tavern/_core/testhelpers.py b/tavern/_core/testhelpers.py index 2c72b49f..3c294a4c 100644 --- a/tavern/_core/testhelpers.py +++ b/tavern/_core/testhelpers.py @@ -1,13 +1,13 @@ import logging import time from functools import wraps -from typing import Mapping +from typing import Callable, Mapping, Union from tavern._core import exceptions from tavern._core.dict_util import format_keys from tavern._core.pytest.config import TestConfig -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def delay(stage: Mapping, when: str, variables: Mapping) -> None: @@ -28,7 +28,7 @@ def delay(stage: Mapping, when: str, variables: Mapping) -> None: time.sleep(length) -def retry(stage: Mapping, test_block_config: TestConfig): +def retry(stage: Mapping, test_block_config: TestConfig) -> Callable: """Look for retry and try to repeat the stage `retry` times. Args: @@ -36,10 +36,8 @@ def retry(stage: Mapping, test_block_config: TestConfig): test_block_config: Configuration for current test """ - if "max_retries" in stage: - max_retries = maybe_format_max_retries( - stage.get("max_retries"), test_block_config - ) + if r := stage.get("max_retries", None): + max_retries = maybe_format_max_retries(r, test_block_config) else: max_retries = 0 @@ -101,11 +99,13 @@ def wrapped(*args, **kwargs): return retry_wrapper -def maybe_format_max_retries(max_retries, test_block_config: TestConfig) -> int: +def maybe_format_max_retries( + max_retries: Union[str, int], test_block_config: TestConfig +) -> int: """Possibly handle max_retries validation""" # Probably a format variable, or just invalid (in which case it will fail further down) - max_retries = format_keys(max_retries, test_block_config.variables) + max_retries = int(format_keys(max_retries, test_block_config.variables)) # type:ignore # Missing type token will mean that max_retries is still a string and will fail here # Could auto convert here as well, but keep it consistent and just fail diff --git a/tavern/_core/tincture.py b/tavern/_core/tincture.py index f1fe746f..21280239 100644 --- a/tavern/_core/tincture.py +++ b/tavern/_core/tincture.py @@ -1,27 +1,27 @@ import collections.abc +import dataclasses import inspect import logging -from typing import Any, List +from typing import Any, Generator, List from tavern._core import exceptions from tavern._core.extfunctions import get_wrapped_response_function -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) +@dataclasses.dataclass class Tinctures: - def __init__(self, tinctures: List[Any]): - self._tinctures = tinctures - self._needs_response: List[Any] = [] + tinctures: List[Any] + needs_response: List[Generator] = dataclasses.field(default_factory=list) - def start_tinctures(self, stage: collections.abc.Mapping): - results = [t(stage) for t in self._tinctures] - self._needs_response = [] + def start_tinctures(self, stage: collections.abc.Mapping) -> None: + results = [t(stage) for t in self.tinctures] for r in results: if inspect.isgenerator(r): # Store generator and start it - self._needs_response.append(r) + self.needs_response.append(r) next(r) def end_tinctures(self, expected: collections.abc.Mapping, response) -> None: @@ -29,14 +29,15 @@ def end_tinctures(self, expected: collections.abc.Mapping, response) -> None: Send the response object to any tinctures that want it Args: - response: The response from 'run' for the stage + expected: 'expected' from initial test - type varies depending on backend + response: The response from 'run' for the stage - type varies depending on backend """ - if self._needs_response is None: + if self.needs_response is None: raise RuntimeError( "should not be called before accumulating tinctures which need a response" ) - for n in self._needs_response: + for n in self.needs_response: try: n.send((expected, response)) except StopIteration: diff --git a/tavern/_plugins/grpc/client.py b/tavern/_plugins/grpc/client.py index 04d7b428..8e51b110 100644 --- a/tavern/_plugins/grpc/client.py +++ b/tavern/_plugins/grpc/client.py @@ -22,7 +22,7 @@ from tavern._core.dict_util import check_expected_keys from tavern._plugins.grpc.protos import _generate_proto_import, _import_grpc_module -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -39,7 +39,7 @@ class _ChannelVals: class GRPCClient: - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: logger.debug("Initialising GRPC client with %s", kwargs) expected_blocks = { "connect": {"host", "port", "options", "timeout", "secure"}, @@ -99,7 +99,7 @@ def __init__(self, **kwargs): def _register_file_descriptor( self, service_proto: grpc_reflection.v1alpha.reflection_pb2.FileDescriptorResponse, - ): + ) -> None: for file_descriptor_proto in service_proto.file_descriptor_proto: descriptor = descriptor_pb2.FileDescriptorProto() descriptor.ParseFromString(file_descriptor_proto) @@ -107,7 +107,7 @@ def _register_file_descriptor( def _get_reflection_info( self, channel, service_name: Optional[str] = None, file_by_filename=None - ): + ) -> None: logger.debug( "Getting GRPC protobuf for service %s from reflection", service_name ) @@ -239,8 +239,9 @@ def _make_call_request( return self._get_grpc_service(channel, service, method) - def __enter__(self): + def __enter__(self) -> "GRPCClient": logger.debug("Connecting to GRPC") + return self def call( self, @@ -282,7 +283,7 @@ def call( request, metadata=self._metadata, timeout=timeout ) - def __exit__(self, *args): + def __exit__(self, *args) -> None: logger.debug("Disconnecting from GRPC") for v in self.channels.values(): v.close() diff --git a/tavern/_plugins/grpc/protos.py b/tavern/_plugins/grpc/protos.py index b70aa238..e34503b2 100644 --- a/tavern/_plugins/grpc/protos.py +++ b/tavern/_plugins/grpc/protos.py @@ -13,7 +13,7 @@ from tavern._core import exceptions -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) @functools.lru_cache @@ -31,7 +31,7 @@ def find_protoc() -> str: @functools.lru_cache -def _generate_proto_import(source: str): +def _generate_proto_import(source: str) -> None: """Invokes the Protocol Compiler to generate a _pb2.py from the given .proto file. Does nothing if the output already exists and is newer than the input. @@ -101,7 +101,7 @@ def sanitise(s): _import_grpc_module(output) -def _import_grpc_module(python_module_name: str): +def _import_grpc_module(python_module_name: str) -> None: """takes an expected python module name and tries to import the relevant file, adding service to the symbol database. """ diff --git a/tavern/_plugins/grpc/request.py b/tavern/_plugins/grpc/request.py index 6fe311ba..706f0bb5 100644 --- a/tavern/_plugins/grpc/request.py +++ b/tavern/_plugins/grpc/request.py @@ -3,7 +3,7 @@ import json import logging import warnings -from typing import Mapping, Union +from typing import Dict, Union import grpc from box import Box @@ -14,10 +14,10 @@ from tavern._plugins.grpc.client import GRPCClient from tavern.request import BaseRequest -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def get_grpc_args(rspec, test_block_config): +def get_grpc_args(rspec: Dict, test_block_config: TestConfig) -> Dict: """Format GRPC request args""" fspec = format_keys(rspec, test_block_config.variables) @@ -50,8 +50,8 @@ class GRPCRequest(BaseRequest): _warned = False def __init__( - self, client: GRPCClient, request_spec: Mapping, test_block_config: TestConfig - ): + self, client: GRPCClient, request_spec: Dict, test_block_config: TestConfig + ) -> None: if not self._warned: warnings.warn( "Tavern gRPC support is experimental and will be updated in a future release.", @@ -87,5 +87,5 @@ def run(self) -> WrappedFuture: raise exceptions.GRPCRequestException from e @property - def request_vars(self): + def request_vars(self) -> Box: return Box(self._original_request_vars) diff --git a/tavern/_plugins/grpc/response.py b/tavern/_plugins/grpc/response.py index 8fbc5291..b9c8849d 100644 --- a/tavern/_plugins/grpc/response.py +++ b/tavern/_plugins/grpc/response.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, Any, List, Mapping, TypedDict, Union +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, TypedDict, Union import proto.message from google.protobuf import json_format @@ -16,7 +16,7 @@ if TYPE_CHECKING: from tavern._plugins.grpc.request import WrappedFuture -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) GRPCCode = Union[str, int, List[str], List[int]] @@ -48,7 +48,7 @@ def __init__( name: str, expected: Union[_GRPCExpected, Mapping], test_block_config: TestConfig, - ): + ) -> None: check_expected_keys({"body", "status", "details"}, expected) super().__init__(name, expected, test_block_config) @@ -60,7 +60,7 @@ def __str__(self): else: return "" - def _validate_block(self, blockname: str, block: Mapping): + def _validate_block(self, blockname: str, block: Mapping) -> None: """Validate a block of the response Args: @@ -92,7 +92,7 @@ def verify(self, response: "WrappedFuture") -> Mapping: logger.debug(f"grpc details: {grpc_response.details()}") # Get any keys to save - saved = {} + saved: Dict[str, Any] = {} verify_status = [StatusCode.OK.name] if status := self.expected.get("status", None): verify_status = _to_grpc_name(status) # type: ignore diff --git a/tavern/_plugins/grpc/tavernhook.py b/tavern/_plugins/grpc/tavernhook.py index 56a4f3ff..29ec38df 100644 --- a/tavern/_plugins/grpc/tavernhook.py +++ b/tavern/_plugins/grpc/tavernhook.py @@ -10,7 +10,7 @@ from .request import GRPCRequest from .response import GRPCResponse -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) session_type = GRPCClient @@ -19,7 +19,9 @@ request_block_name = "grpc_request" -def get_expected_from_request(response_block, test_block_config: TestConfig, session): +def get_expected_from_request( + response_block, test_block_config: TestConfig, session: GRPCClient +): f_expected = format_keys(response_block, test_block_config.variables) expected = f_expected @@ -29,6 +31,6 @@ def get_expected_from_request(response_block, test_block_config: TestConfig, ses verifier_type = GRPCResponse response_block_name = "grpc_response" -schema_path = join(abspath(dirname(__file__)), "jsonschema.yaml") +schema_path: str = join(abspath(dirname(__file__)), "jsonschema.yaml") with open(schema_path) as schema_file: schema = yaml.load(schema_file, Loader=yaml.SafeLoader) diff --git a/tavern/_plugins/mqtt/client.py b/tavern/_plugins/mqtt/client.py index b23a0ab1..7ac4a5a9 100644 --- a/tavern/_plugins/mqtt/client.py +++ b/tavern/_plugins/mqtt/client.py @@ -5,9 +5,10 @@ import threading import time from queue import Empty, Full, Queue -from typing import Dict, List, Mapping, MutableMapping, Optional +from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Union import paho.mqtt.client as paho +from paho.mqtt.client import MQTTMessageInfo from tavern._core import exceptions from tavern._core.dict_util import check_expected_keys @@ -33,7 +34,7 @@ 15: "MQTT_ERR_QUEUE_SIZE", } -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) @dataclasses.dataclass @@ -88,7 +89,7 @@ def _handle_ssl_context_args( def _check_and_update_common_tls_args( tls_args: MutableMapping, check_file_keys: List[str] -): +) -> None: """Checks common args between ssl/tls args""" # could be moved to schema validation stage @@ -289,7 +290,9 @@ def __init__(self, **kwargs) -> None: self._client.user_data_set(self._userdata) @staticmethod - def _on_message(client, userdata, message) -> None: + def _on_message( + client, userdata: Mapping[str, Any], message: paho.MQTTMessage + ) -> None: """Add any messages received to the queue Todo: @@ -311,7 +314,7 @@ def _on_message(client, userdata, message) -> None: logger.exception("message queue full") @staticmethod - def _on_connect(client, userdata, flags, rc) -> None: + def _on_connect(client, userdata, flags, rc: int) -> None: logger.debug( "Client '%s' connected to the broker with result code '%s'", client._client_id.decode(), @@ -319,7 +322,7 @@ def _on_connect(client, userdata, flags, rc) -> None: ) @staticmethod - def _on_disconnect(client, userdata, rc) -> None: + def _on_disconnect(client, userdata, rc: int) -> None: if rc == paho.CONNACK_ACCEPTED: logger.debug( "Client '%s' successfully disconnected from the broker with result code '%s'", @@ -347,16 +350,18 @@ def _on_socket_open(client, userdata, socket) -> None: def _on_socket_close(client, userdata, socket) -> None: logger.debug("MQTT socket closed") - def message_received(self, topic: str, timeout: int = 1): + def message_received( + self, topic: str, timeout: Union[float, int] = 1 + ) -> Optional[paho.MQTTMessage]: """Check that a message is in the message queue Args: - topic (str): topic to fetch message for - timeout (int): How long to wait before signalling that the message + topic: topic to fetch message for + timeout: How long to wait before signalling that the message was not received. Returns: - paho.MQTTMessage: whether the message was received within the timeout + the message, if one was received, otherwise None Todo: Allow regexes for topic names? Better validation for mqtt payloads @@ -377,8 +382,12 @@ def message_received(self, topic: str, timeout: int = 1): return msg def publish( - self, topic: str, payload=None, qos=None, retain=None - ) -> paho.MQTTMessageInfo: + self, + topic: str, + payload: Optional[Any] = None, + qos: Optional[int] = None, + retain: Optional[bool] = False, + ) -> MQTTMessageInfo: """publish message using paho library""" self._wait_for_subscriptions() diff --git a/tavern/_plugins/mqtt/request.py b/tavern/_plugins/mqtt/request.py index 0a9de87a..6f1705fd 100644 --- a/tavern/_plugins/mqtt/request.py +++ b/tavern/_plugins/mqtt/request.py @@ -13,7 +13,7 @@ from tavern._plugins.mqtt.client import MQTTClient from tavern.request import BaseRequest -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def get_publish_args(rspec: Dict, test_block_config: TestConfig) -> dict: diff --git a/tavern/_plugins/mqtt/response.py b/tavern/_plugins/mqtt/response.py index afc60ffb..305e8496 100644 --- a/tavern/_plugins/mqtt/response.py +++ b/tavern/_plugins/mqtt/response.py @@ -13,33 +13,42 @@ from tavern._core import exceptions from tavern._core.dict_util import check_keys_match_recursive from tavern._core.loader import ANYTHING +from tavern._core.pytest.config import TestConfig from tavern._core.pytest.newhooks import call_hook from tavern._core.report import attach_yaml -from tavern._core.strict_util import StrictSetting +from tavern._core.strict_util import StrictOption from tavern.response import BaseResponse from .client import MQTTClient -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) _default_timeout = 1 class MQTTResponse(BaseResponse): - def __init__(self, client: MQTTClient, name, expected, test_block_config) -> None: + response: MQTTMessage + + def __init__( + self, + client: MQTTClient, + name: str, + expected: TestConfig, + test_block_config: TestConfig, + ) -> None: super().__init__(name, expected, test_block_config) self._client = client - self.received_messages = [] # type: ignore + self.received_messages: List = [] - def __str__(self): + def __str__(self) -> str: if self.response: - return self.response.payload + return self.response.payload.decode("utf-8") else: return "" - def verify(self, response) -> dict: + def verify(self, response: MQTTMessage) -> Mapping: """Ensure mqtt message has arrived Args: @@ -53,11 +62,11 @@ def verify(self, response) -> dict: finally: self._client.unsubscribe_all() - def _await_response(self) -> dict: + def _await_response(self) -> Mapping: """Actually wait for response Returns: - dict: things to save to variables for the rest of this test + things to save to variables for the rest of this test """ # Get into class with metadata attached @@ -102,7 +111,7 @@ def _await_response(self) -> dict: failures=self.errors, ) - saved = {} + saved: Dict = {} for msg in correct_messages: # Check saving things from the payload and from json @@ -145,7 +154,7 @@ def _await_messages_on_topic( expected: expected response for this block Returns: - tuple(msg, list): The correct message (if any) and warnings from processing the message + The correct message (if any) and warnings from processing the message """ timeout = max(m.get("timeout", _default_timeout) for m in expected) @@ -228,12 +237,12 @@ def _await_messages_on_topic( class _ReturnedMessage: """An actual message returned from the API and it's matching 'expected' block.""" - expected: dict + expected: Mapping msg: MQTTMessage class _MessageVerifier: - def __init__(self, test_block_config, expected) -> None: + def __init__(self, test_block_config: TestConfig, expected: Mapping) -> None: self.expires = time.time() + expected.get("timeout", _default_timeout) self.expected = expected @@ -242,7 +251,7 @@ def __init__(self, test_block_config, expected) -> None: ) test_strictness = test_block_config.strict - self.block_strictness: StrictSetting = test_strictness.option_for("json") + self.block_strictness: StrictOption = test_strictness.option_for("json") # Any warnings to do with the request # eg, if a message was received but it didn't match, message had payload, etc. @@ -322,7 +331,7 @@ def _get_payload_vals(expected: Mapping) -> Tuple[Optional[Union[str, dict]], bo """Gets the payload from the 'expected' block Returns: - tuple: First element is the expected payload, second element is whether it's + First element is the expected payload, second element is whether it's expected to be json or not """ # TODO move this check to initialisation/schema checking diff --git a/tavern/_plugins/mqtt/tavernhook.py b/tavern/_plugins/mqtt/tavernhook.py index 9ff57245..bc37d267 100644 --- a/tavern/_plugins/mqtt/tavernhook.py +++ b/tavern/_plugins/mqtt/tavernhook.py @@ -1,16 +1,17 @@ import logging from os.path import abspath, dirname, join -from typing import Dict, Optional +from typing import Dict, Iterable, Optional, Union import yaml from tavern._core.dict_util import format_keys +from tavern._core.pytest.config import TestConfig from .client import MQTTClient from .request import MQTTRequest from .response import MQTTResponse -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) session_type = MQTTClient @@ -18,7 +19,11 @@ request_block_name = "mqtt_publish" -def get_expected_from_request(response_block, test_block_config, session): +def get_expected_from_request( + response_block: Union[Dict, Iterable[Dict]], + test_block_config: TestConfig, + session: MQTTClient, +) -> Optional[Dict]: expected: Optional[Dict] = None # mqtt response is not required @@ -40,6 +45,6 @@ def get_expected_from_request(response_block, test_block_config, session): verifier_type = MQTTResponse response_block_name = "mqtt_response" -schema_path = join(abspath(dirname(__file__)), "jsonschema.yaml") +schema_path: str = join(abspath(dirname(__file__)), "jsonschema.yaml") with open(schema_path, encoding="utf-8") as schema_file: schema = yaml.load(schema_file, Loader=yaml.SafeLoader) diff --git a/tavern/_plugins/rest/files.py b/tavern/_plugins/rest/files.py index 3f72db8a..8aba5b9e 100644 --- a/tavern/_plugins/rest/files.py +++ b/tavern/_plugins/rest/files.py @@ -9,7 +9,7 @@ from tavern._core.dict_util import format_keys from tavern._core.pytest.config import TestConfig -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) @dataclasses.dataclass diff --git a/tavern/_plugins/rest/request.py b/tavern/_plugins/rest/request.py index 8d81d773..06e40fbe 100644 --- a/tavern/_plugins/rest/request.py +++ b/tavern/_plugins/rest/request.py @@ -4,7 +4,7 @@ import warnings from contextlib import ExitStack from itertools import filterfalse, tee -from typing import ClassVar, List, Mapping, MutableMapping, Optional +from typing import Callable, ClassVar, Dict, List, Mapping, Optional from urllib.parse import quote_plus import requests @@ -21,10 +21,10 @@ from tavern._plugins.rest.files import get_file_arguments, guess_filespec from tavern.request import BaseRequest -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) -def get_request_args(rspec: MutableMapping, test_block_config: TestConfig) -> dict: +def get_request_args(rspec: Dict, test_block_config: TestConfig) -> dict: """Format the test spec given values inthe global config Todo: @@ -235,7 +235,7 @@ def _check_allow_redirects(rspec: dict, test_block_config: TestConfig): test_block_config: config available for test Returns: - bool: Whether to allow redirects for this stage or not + Whether to allow redirects for this stage or not """ # By default, don't follow redirects allow_redirects = False @@ -355,7 +355,7 @@ def __init__( Args: session: existing session rspec: test spec - test_block_config : Any configuration for this the block of + test_block_config: Any configuration for this the block of tests Raises: @@ -394,7 +394,7 @@ def __init__( ) # Used further down, but pop it asap to avoid unwanted side effects - file_body = request_args.pop("file_body", None) + file_body: Optional[str] = request_args.pop("file_body", None) # If there was a 'cookies' key, set it in the request expected_cookies = _read_expected_cookies(session, rspec, test_block_config) @@ -439,16 +439,13 @@ def prepared_request(): return session.request(**self._request_args) - self._prepared = prepared_request + self._prepared: Callable[[], requests.Response] = prepared_request - def run(self): + def run(self) -> requests.Response: """Runs the prepared request and times it - Todo: - time it - Returns: - requests.Response: response object + response object """ attach_yaml( diff --git a/tavern/_plugins/rest/response.py b/tavern/_plugins/rest/response.py index 83b1062d..598fe6cd 100644 --- a/tavern/_plugins/rest/response.py +++ b/tavern/_plugins/rest/response.py @@ -1,7 +1,7 @@ import contextlib import json import logging -from typing import Dict, Mapping, Optional +from typing import Any, Dict, List, Mapping, Union from urllib.parse import parse_qs, urlparse import requests @@ -9,21 +9,24 @@ from tavern._core import exceptions from tavern._core.dict_util import deep_dict_merge +from tavern._core.pytest.config import TestConfig from tavern._core.pytest.newhooks import call_hook from tavern._core.report import attach_yaml from tavern.response import BaseResponse, indent_err_text -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class RestResponse(BaseResponse): - def __init__(self, session, name: str, expected, test_block_config) -> None: + response: requests.Response + + def __init__( + self, session, name: str, expected, test_block_config: TestConfig + ) -> None: defaults = {"status_code": 200} super().__init__(name, deep_dict_merge(defaults, expected), test_block_config) - self.status_code: Optional[int] = None - def check_code(code: int) -> None: if int(code) not in _codes: logger.warning("Unexpected status code '%s'", code) @@ -44,7 +47,7 @@ def __str__(self) -> str: else: return "" - def _verbose_log_response(self, response) -> None: + def _verbose_log_response(self, response: requests.Response) -> None: """Verbosely log the response object, with query params etc.""" logger.info("Response: '%s'", response) @@ -75,7 +78,7 @@ def log_dict_block(block, name): logger.debug("Redirect location: %s", to_path) log_dict_block(redirect_query_params, "Redirect URL query parameters") - def _get_redirect_query_params(self, response) -> Dict[str, str]: + def _get_redirect_query_params(self, response: requests.Response) -> Dict[str, str]: """If there was a redirect header, get any query parameters from it""" try: @@ -95,7 +98,7 @@ def _get_redirect_query_params(self, response) -> Dict[str, str]: return redirect_query_params - def _check_status_code(self, status_code, body) -> None: + def _check_status_code(self, status_code: Union[int, List[int]], body: Any) -> None: expected_code = self.expected["status_code"] if (isinstance(expected_code, int) and status_code == expected_code) or ( @@ -105,7 +108,7 @@ def _check_status_code(self, status_code, body) -> None: "Status code '%s' matched expected '%s'", status_code, expected_code ) return - elif 400 <= status_code < 500: + elif isinstance(status_code, int) and 400 <= status_code < 500: # special case if there was a bad request. This assumes that the # response would contain some kind of information as to why this # request was rejected. @@ -144,7 +147,6 @@ def verify(self, response: requests.Response) -> dict: ) self.response = response - self.status_code = response.status_code # Get things to use from the response try: @@ -174,7 +176,7 @@ def verify(self, response: requests.Response) -> dict: self._maybe_run_validate_functions(response) # Get any keys to save - saved = {} + saved: Dict = {} saved.update(self.maybe_get_save_values_from_save_block("json", body)) saved.update( diff --git a/tavern/_plugins/rest/tavernhook.py b/tavern/_plugins/rest/tavernhook.py index 208e32d3..3a08bf6f 100644 --- a/tavern/_plugins/rest/tavernhook.py +++ b/tavern/_plugins/rest/tavernhook.py @@ -1,15 +1,17 @@ import logging +from typing import Dict import requests from tavern._core import exceptions from tavern._core.dict_util import format_keys from tavern._core.plugins import PluginHelperBase +from tavern._core.pytest.config import TestConfig from .request import RestRequest from .response import RestResponse -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class TavernRestPlugin(PluginHelperBase): @@ -19,7 +21,9 @@ class TavernRestPlugin(PluginHelperBase): request_block_name = "request" @staticmethod - def get_expected_from_request(response_block, test_block_config, session): + def get_expected_from_request( + response_block: Dict, test_block_config: TestConfig, session + ): if response_block is None: raise exceptions.MissingSettingsError( "no response block specified for HTTP test stage" diff --git a/tavern/core.py b/tavern/core.py index da20f273..addc3456 100644 --- a/tavern/core.py +++ b/tavern/core.py @@ -3,6 +3,7 @@ from typing import Union import pytest +from _pytest.config import ExitCode from tavern._core import exceptions from tavern._core.schema.files import wrapfile @@ -16,8 +17,7 @@ def _get_or_wrap_global_cfg( Args: stack: context stack for wrapping file if a dictionary is given - tavern_global_cfg: Dictionary or string. It should be a - path to a file or a dictionary with configuration. + tavern_global_cfg: path to a file or a dictionary with configuration. Returns: path to global config file @@ -25,9 +25,6 @@ def _get_or_wrap_global_cfg( Raises: InvalidSettingsError: If global config was not of the right type or a given path does not exist - - Todo: - Once python 2 is dropped, allow this to take a 'path like object' """ if isinstance(tavern_global_cfg, str): @@ -48,29 +45,29 @@ def _get_or_wrap_global_cfg( return global_filename -def run( +def run( # type:ignore in_file: str, - tavern_global_cfg=None, - tavern_mqtt_backend=None, - tavern_http_backend=None, - tavern_grpc_backend=None, - tavern_strict=None, - pytest_args=None, -): + tavern_global_cfg: Union[dict, str, None] = None, + tavern_mqtt_backend: Union[str, None] = None, + tavern_http_backend: Union[str, None] = None, + tavern_grpc_backend: Union[str, None] = None, + tavern_strict: Union[bool, None] = None, + pytest_args: Union[list, None] = None, +) -> Union[ExitCode, int]: """Run all tests contained in a file using pytest.main() Args: in_file: file to run tests on - tavern_global_cfg (str, dict): Extra global config - tavern_mqtt_backend (str, optional): name of MQTT plugin to use. If not + tavern_global_cfg: Extra global config + tavern_mqtt_backend: name of MQTT plugin to use. If not specified, uses tavern-mqtt - tavern_http_backend (str, optional): name of HTTP plugin to use. If not + tavern_http_backend: name of HTTP plugin to use. If not specified, use tavern-http - tavern_grpc_backend (str, optional): name of GRPC plugin to use. If not + tavern_grpc_backend: name of GRPC plugin to use. If not specified, use tavern-grpc - tavern_strict (bool, optional): Strictness of checking for responses. + tavern_strict: Strictness of checking for responses. See documentation for details - pytest_args (list, optional): List of extra arguments to pass directly + pytest_args: List of extra arguments to pass directly to Pytest as if they were command line arguments Returns: diff --git a/tavern/helpers.py b/tavern/helpers.py index b4204b65..aabe2c7a 100644 --- a/tavern/helpers.py +++ b/tavern/helpers.py @@ -2,7 +2,7 @@ import json import logging import re -from typing import Dict, List, Optional +from typing import Dict, Iterable, Mapping, Optional import jmespath import jwt @@ -14,7 +14,7 @@ from tavern._core.jmesutils import actual_validation, validate_comparison from tavern._core.schema.files import verify_pykwalify -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def check_exception_raised( @@ -68,7 +68,9 @@ def check_exception_raised( ) from e -def validate_jwt(response, jwt_key, **kwargs) -> Dict[str, Box]: +def validate_jwt( + response: requests.Response, jwt_key: str, **kwargs +) -> Mapping[str, Box]: """Make sure a jwt is valid This uses the pyjwt library to decode the jwt, so any keyword args needed @@ -80,12 +82,12 @@ def validate_jwt(response, jwt_key, **kwargs) -> Dict[str, Box]: it wraps this in a Box so it can also be used for future formatting Args: - response (Response): requests.Response object - jwt_key (str): key of jwt in body of request + response: requests.Response object + jwt_key: key of jwt in body of request **kwargs: Any extra arguments to pass to jwt.decode Returns: - dict: dictionary of jwt: boxed jwt claims + mapping of jwt: boxed jwt claims """ token = response.json()[jwt_key] @@ -96,12 +98,12 @@ def validate_jwt(response, jwt_key, **kwargs) -> Dict[str, Box]: return {"jwt": Box(decoded)} -def validate_pykwalify(response, schema) -> None: +def validate_pykwalify(response: requests.Response, schema: Dict) -> None: """Make sure the response matches a given schema Args: - response (Response): reqeusts.Response object - schema (dict): Schema for response + response: reqeusts Response object + schema: Schema for response """ try: to_verify = response.json() @@ -171,7 +173,7 @@ def validate_regex( return {"regex": Box(match.groupdict())} -def validate_content(response: requests.Response, comparisons: List[str]) -> None: +def validate_content(response: requests.Response, comparisons: Iterable[Dict]) -> None: """Asserts expected value with actual value using JMES path expression Args: diff --git a/tavern/request.py b/tavern/request.py index a049c8eb..abfa7ded 100644 --- a/tavern/request.py +++ b/tavern/request.py @@ -6,7 +6,7 @@ from tavern._core.pytest.config import TestConfig -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) class BaseRequest: diff --git a/tavern/response.py b/tavern/response.py index b9ffb86f..0b90246a 100644 --- a/tavern/response.py +++ b/tavern/response.py @@ -1,3 +1,4 @@ +import dataclasses import logging import traceback from abc import abstractmethod @@ -11,7 +12,7 @@ from tavern._core.pytest.config import TestConfig from tavern._core.strict_util import StrictOption -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) def indent_err_text(err: str) -> str: @@ -20,27 +21,23 @@ def indent_err_text(err: str) -> str: return indent(err, " " * 4) +@dataclasses.dataclass class BaseResponse: - def __init__(self, name: str, expected, test_block_config: TestConfig) -> None: - # Stage name - self.name = name + name: str + expected: Any + test_block_config: TestConfig + response: Optional[Any] = None - # all errors in this response - self.errors: List[str] = [] + validate_functions: List[Any] = dataclasses.field(init=False, default_factory=list) + errors: List[str] = dataclasses.field(init=False, default_factory=list) - self.validate_functions: List = [] - self._check_for_validate_functions(expected) - - self.test_block_config = test_block_config - - self.expected = expected - - self.response: Optional[Any] = None + def __post_init__(self) -> None: + self._check_for_validate_functions(self.expected) def _str_errors(self) -> str: return "- " + "\n- ".join(self.errors) - def _adderr(self, msg, *args, e=None) -> None: + def _adderr(self, msg: str, *args, e=None) -> None: if e: logger.exception(msg, *args) else: @@ -69,10 +66,10 @@ def recurse_check_key_match( Optionally use a validation library too Args: - strict: strictness setting for this block expected_block: expected data block: actual data blockname: 'name' of this block (params, mqtt, etc) for error messages + strict: strictness setting for this block """ if expected_block is None: @@ -151,14 +148,14 @@ def check_deprecated_validate(name): # Could put in an isinstance check here check_deprecated_validate("json") - def _maybe_run_validate_functions(self, response) -> None: + def _maybe_run_validate_functions(self, response: Any) -> None: """Run validation functions if available Note: 'response' will be different depending on where this is called from Args: - response (object): Response type. This could be whatever the response type/plugin uses. + response: Response type. This could be whatever the response type/plugin uses. """ logger.debug( "Calling ext function from '%s' with response '%s'", type(self), response @@ -177,7 +174,7 @@ def _maybe_run_validate_functions(self, response) -> None: def maybe_get_save_values_from_ext( self, response: Any, read_save_from: Mapping - ) -> dict: + ) -> Mapping: """If there is an $ext function in the save block, call it and save the response Args: @@ -203,7 +200,7 @@ def maybe_get_save_values_from_ext( except Exception as e: self._adderr( "Error calling save function '%s':\n%s", - wrapped.func, + wrapped.func, # type:ignore indent_err_text(traceback.format_exc()), e=e, ) @@ -227,12 +224,14 @@ def maybe_get_save_values_from_save_block( save_from: Optional[Mapping], *, outer_save_block: Optional[Mapping] = None, - ) -> dict: + ) -> Mapping: """Save a value from a specific block in the response. See docs for maybe_get_save_values_from_given_block for more info - Keyword Args: + Args: + key: Name of key being used to save, for debugging + save_from: An element of the response from which values are being saved outer_save_block: Read things to save from this block instead of self.expected """ @@ -254,7 +253,7 @@ def maybe_get_save_values_from_given_block( key: str, save_from: Optional[Mapping], to_save: Mapping, - ) -> dict: + ) -> Mapping: """Save a value from a specific block in the response. This is different from maybe_get_save_values_from_ext - depends on the kind of response @@ -265,7 +264,7 @@ def maybe_get_save_values_from_given_block( to_save: block containing information about things to save Returns: - dict: dictionary of save_name: value, where save_name is the key we + mapping of save_name: value, where save_name is the key we wanted to save this value as """ diff --git a/tests/unit/test_mqtt.py b/tests/unit/test_mqtt.py index 4b83bc2c..75ff51eb 100644 --- a/tests/unit/test_mqtt.py +++ b/tests/unit/test_mqtt.py @@ -22,7 +22,7 @@ def test_host_required(): @pytest.fixture(name="fake_client") def fix_fake_client(): - args = {"connect": {"host": "localhost"}} + args = {"connect": {"host": "localhost", "timeout": 0.6}} mqtt_client = MQTTClient(**args) @@ -87,7 +87,7 @@ def wait_for_publish(self, timeout=None): with pytest.raises(exceptions.MQTTError): fake_client.publish("abc", "123") - def test_assert_message_published_failure(self, fake_client): + def test_assert_message_published_failure(self, fake_client: MQTTClient): """If it couldn't publish the message, error out""" class FakeMessage(paho.MQTTMessageInfo): diff --git a/tox-integration.ini b/tox-integration.ini index 9a7b7abc..eab3b6a9 100644 --- a/tox-integration.ini +++ b/tox-integration.ini @@ -54,9 +54,4 @@ commands = components: python -c "from tavern.core import run; exit(run('test_ping.tavern.yaml', pytest_args=[ ]))" components: python -c "from tavern.core import run; exit(run('test_hello.tavern.yaml', pytest_args=[ ]))" - mqtt: tavern-ci --stdout test_mqtt.tavern.yaml --cov tavern - mqtt: python -c "from tavern.core import run; exit(run('test_mqtt.tavern.yaml', pytest_args=['--cov-append']))" - mqtt: tavern-ci --stdout test_mqtt_failures.tavern.yaml - mqtt: python -c "from tavern.core import run; exit(run('test_mqtt_failures.tavern.yaml', pytest_args=[ ]))" - docker compose stop