From 2c012eacf95a3f16cdb6e55a97f5108d4136f59d Mon Sep 17 00:00:00 2001 From: Rohan Weeden Date: Fri, 24 Jan 2025 17:30:54 -0500 Subject: [PATCH] Add mypy linter step --- .github/workflows/lint.yml | 32 ++++++++++++- mandible/metadata_mapper/builder.py | 2 +- .../metadata_mapper/directive/reformatted.py | 3 +- mandible/metadata_mapper/exception.py | 7 ++- mandible/metadata_mapper/format/__init__.py | 4 +- mandible/metadata_mapper/format/format.py | 39 +++++++++------ mandible/metadata_mapper/format/h5.py | 4 +- mandible/metadata_mapper/format/xml.py | 33 ++++++++++--- mandible/metadata_mapper/mapper.py | 48 ++++++++++++------- mandible/metadata_mapper/source_provider.py | 2 +- mandible/metadata_mapper/storage/__init__.py | 4 +- mandible/metadata_mapper/storage/cmr_query.py | 2 +- mandible/metadata_mapper/storage/storage.py | 3 +- .../integration_tests/test_metadata_mapper.py | 20 ++++++++ tests/test_format.py | 2 + 15 files changed, 150 insertions(+), 55 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 9913b44..9ef526d 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -6,7 +6,7 @@ on: pull_request: jobs: - lint: + flake8: runs-on: ubuntu-latest steps: @@ -17,4 +17,32 @@ jobs: - uses: TrueBrain/actions-flake8@v2 with: flake8_version: 6.0.0 - plugins: flake8-isort==6.1.1 flake8-quotes==3.4.0 flake8-commas==4.0.0 + plugins: flake8-isort==6.1.1 flake8-quotes==3.4.0 flake8-commas==4.0.0 + + mypy: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.9 + + - run: | + pip install \ + mypy==1.14.1 \ + boto3-stubs \ + h5py==3.6.0 \ + jsonpath_ng==1.4.1 \ + s3fs==0.4.2 + + - run: | + mypy \ + --non-interactive \ + --install-types \ + --check-untyped-defs \ + --disable-error-code=import-untyped \ + --strict-equality \ + --warn-redundant-casts \ + --warn-unused-ignores \ + mandible diff --git a/mandible/metadata_mapper/builder.py 
b/mandible/metadata_mapper/builder.py index 90a6d58..509b645 100644 --- a/mandible/metadata_mapper/builder.py +++ b/mandible/metadata_mapper/builder.py @@ -103,7 +103,7 @@ def mapped( directive_name = Mapped.directive_name assert directive_name is not None - params = { + params: dict[str, Any] = { "source": source, "key": key, } diff --git a/mandible/metadata_mapper/directive/reformatted.py b/mandible/metadata_mapper/directive/reformatted.py index 593b9f2..49cd7f0 100644 --- a/mandible/metadata_mapper/directive/reformatted.py +++ b/mandible/metadata_mapper/directive/reformatted.py @@ -4,8 +4,9 @@ from mandible.metadata_mapper.exception import MetadataMapperError from mandible.metadata_mapper.format import FORMAT_REGISTRY +from mandible.metadata_mapper.types import Key -from .directive import Key, TemplateDirective, get_key +from .directive import TemplateDirective, get_key @dataclass diff --git a/mandible/metadata_mapper/exception.py b/mandible/metadata_mapper/exception.py index f4d9ccc..5878b35 100644 --- a/mandible/metadata_mapper/exception.py +++ b/mandible/metadata_mapper/exception.py @@ -1,3 +1,6 @@ +from typing import Optional + + class MetadataMapperError(Exception): """A generic error raised by the MetadataMapper""" @@ -8,7 +11,7 @@ def __init__(self, msg: str): class TemplateError(MetadataMapperError): """An error that occurred while processing the metadata template.""" - def __init__(self, msg: str, debug_path: str = None): + def __init__(self, msg: str, debug_path: Optional[str] = None): super().__init__(msg) self.debug_path = debug_path @@ -26,7 +29,7 @@ class ContextValueError(MetadataMapperError): def __init__( self, msg: str, - source_name: str = None, + source_name: Optional[str] = None, ): super().__init__(msg) self.source_name = source_name diff --git a/mandible/metadata_mapper/format/__init__.py b/mandible/metadata_mapper/format/__init__.py index b57a46c..0f3899d 100644 --- a/mandible/metadata_mapper/format/__init__.py +++ 
b/mandible/metadata_mapper/format/__init__.py @@ -11,12 +11,12 @@ try: from .h5 import H5 except ImportError: - from .format import H5 + from .format import H5 # type: ignore try: from .xml import Xml except ImportError: - from .format import Xml + from .format import Xml # type: ignore __all__ = ( diff --git a/mandible/metadata_mapper/format/format.py b/mandible/metadata_mapper/format/format.py index 9bf54bc..168853d 100644 --- a/mandible/metadata_mapper/format/format.py +++ b/mandible/metadata_mapper/format/format.py @@ -1,11 +1,12 @@ import contextlib +import inspect import json import re import zipfile from abc import ABC, abstractmethod -from collections.abc import Iterable +from collections.abc import Generator, Iterable from dataclasses import dataclass -from typing import IO, Any, TypeVar +from typing import IO, Any, Generic, TypeVar from mandible import jsonpath from mandible.metadata_mapper.key import RAISE_EXCEPTION, Key @@ -50,7 +51,7 @@ def get_value(self, file: IO[bytes], key: Key) -> Any: @dataclass -class FileFormat(Format, ABC, register=False): +class FileFormat(Format, Generic[T], ABC, register=False): """A Format for querying files from a standard data file. 
Simple, single format data types such as 'json' that can be queried @@ -76,7 +77,7 @@ def get_value(self, file: IO[bytes], key: Key) -> Any: with self.parse_data(file) as data: return self._eval_key_wrapper(data, key) - def _eval_key_wrapper(self, data, key: Key) -> Any: + def _eval_key_wrapper(self, data: T, key: Key) -> Any: try: return self.eval_key(data, key) except KeyError as e: @@ -116,7 +117,7 @@ def eval_key(data: T, key: Key) -> Any: @dataclass -class _PlaceholderBase(FileFormat, register=False): +class _PlaceholderBase(FileFormat[None], register=False): """ Base class for defining placeholder implementations for classes that require extra dependencies to be installed @@ -128,12 +129,14 @@ def __init__(self, dep: str): ) @staticmethod - def parse_data(file: IO[bytes]) -> contextlib.AbstractContextManager[T]: - pass + def parse_data(file: IO[bytes]) -> contextlib.AbstractContextManager[None]: + # __init__ always raises + raise RuntimeError("Unreachable!") @staticmethod - def eval_key(data: T, key: Key): - pass + def eval_key(data: None, key: Key): + # __init__ always raises + raise RuntimeError("Unreachable!") @dataclass @@ -151,10 +154,10 @@ def __init__(self): # Define formats that don't require extra dependencies @dataclass -class Json(FileFormat): +class Json(FileFormat[dict]): @staticmethod @contextlib.contextmanager - def parse_data(file: IO[bytes]) -> dict: + def parse_data(file: IO[bytes]) -> Generator[dict]: yield json.load(file) @staticmethod @@ -237,20 +240,26 @@ def _matches_filters(self, zipinfo: zipfile.ZipInfo) -> bool: return True +ZIP_INFO_ATTRS = [ + name + for name, _ in inspect.getmembers(zipfile.ZipInfo, inspect.isdatadescriptor) + if not name.startswith("_") +] + + @dataclass -class ZipInfo(FileFormat): +class ZipInfo(FileFormat[dict]): """Query Zip headers and directory information.""" @staticmethod @contextlib.contextmanager - def parse_data(file: IO[bytes]) -> dict: + def parse_data(file: IO[bytes]) -> Generator[dict]: with 
zipfile.ZipFile(file, "r") as zf: yield { "infolist": [ { k: getattr(info, k) - for k in info.__slots__ - if not k.startswith("_") + for k in ZIP_INFO_ATTRS } for info in zf.infolist() ], diff --git a/mandible/metadata_mapper/format/h5.py b/mandible/metadata_mapper/format/h5.py index ac2365b..6ebce94 100644 --- a/mandible/metadata_mapper/format/h5.py +++ b/mandible/metadata_mapper/format/h5.py @@ -11,13 +11,13 @@ @dataclass -class H5(FileFormat): +class H5(FileFormat[Any]): @staticmethod def parse_data(file: IO[bytes]) -> contextlib.AbstractContextManager[Any]: return h5py.File(file, "r") @staticmethod - def eval_key(data, key: Key) -> Any: + def eval_key(data: Any, key: Key) -> Any: return normalize(data[key.key][()]) diff --git a/mandible/metadata_mapper/format/xml.py b/mandible/metadata_mapper/format/xml.py index 1dbee9f..3e09af2 100644 --- a/mandible/metadata_mapper/format/xml.py +++ b/mandible/metadata_mapper/format/xml.py @@ -1,6 +1,7 @@ import contextlib +from collections.abc import Generator from dataclasses import dataclass -from typing import IO, Any +from typing import IO, Any, Union from lxml import etree @@ -10,16 +11,34 @@ @dataclass -class Xml(FileFormat): +class Xml(FileFormat[etree._ElementTree]): @staticmethod @contextlib.contextmanager - def parse_data(file: IO[bytes]) -> Any: + def parse_data(file: IO[bytes]) -> Generator[etree._ElementTree]: yield etree.parse(file) @staticmethod - def eval_key(data: etree.ElementTree, key: Key) -> Any: + def eval_key(data: etree._ElementTree, key: Key) -> Any: nsmap = data.getroot().nsmap - elements = data.xpath(key.key, namespaces=nsmap) - values = [element.text for element in elements] + xpath_result = data.xpath( + key.key, + # Lxml type stubs don't handle None key for default namespaces + namespaces=nsmap, # type: ignore + ) + if isinstance(xpath_result, list): + values = [convert_result(item) for item in xpath_result] - return key.resolve_list_match(values) + return
key.resolve_list_match(values) + + # Xpath supports functions such as `count` that can result in + # `data.xpath` returning something other than a list of matches. + return xpath_result + + +def convert_result(result: Union[etree._Element, int, str, bytes, tuple]): + if isinstance(result, etree._Element): + return result.text + if isinstance(result, (int, str, bytes)): + return result + + raise TypeError(f"Unsupported type {repr(result.__class__.__name__)}") diff --git a/mandible/metadata_mapper/mapper.py b/mandible/metadata_mapper/mapper.py index a84dc62..d198d0a 100644 --- a/mandible/metadata_mapper/mapper.py +++ b/mandible/metadata_mapper/mapper.py @@ -16,7 +16,7 @@ class MetadataMapper: def __init__( self, template: Template, - source_provider: SourceProvider = None, + source_provider: Optional[SourceProvider] = None, *, directive_marker: str = "@", ): @@ -72,15 +72,16 @@ def get_metadata(self, context: Context) -> Template: def _prepare_directives(self, context: Context, sources: dict[str, Source]): for value, debug_path in _walk_values(self.template): if isinstance(value, dict): - directive_name = self._get_directive_name(value, debug_path) - if directive_name is None: + directive_config = self._get_directive_name(value, debug_path) + if directive_config is None: continue + directive_name, directive_body = directive_config directive = self._get_directive( directive_name, context, sources, - value[directive_name], + directive_body, f"{debug_path}.{directive_name}", ) directive.prepare() @@ -91,13 +92,14 @@ def _replace_template( template: Template, sources: dict[str, Source], debug_path: str = "$", - ): + ) -> Template: if isinstance(template, dict): - directive_name = self._get_directive_name( + directive_config = self._get_directive_name( template, debug_path, ) - if directive_name is not None: + if directive_config is not None: + directive_name, directive_body = directive_config debug_path = f"{debug_path}.{directive_name}" directive = self._get_directive( 
directive_name, @@ -110,7 +112,7 @@ def _replace_template( sources, debug_path=f"{debug_path}.{k}", ) - for k, v in template[directive_name].items() + for k, v in directive_body.items() }, debug_path, ) @@ -146,31 +148,41 @@ def _replace_template( def _get_directive_name( self, - value: dict, + value: dict[str, Template], debug_path: str, - ) -> Optional[str]: - directive_names = [ - key for key in value - if key.startswith(self.directive_marker) + ) -> Optional[tuple[str, dict[str, Template]]]: + directive_configs = [ + (k, v) + for (k, v) in value.items() + if k.startswith(self.directive_marker) ] - if not directive_names: + if not directive_configs: return None - if len(directive_names) > 1: + if len(directive_configs) > 1: raise TemplateError( "multiple directives found in config: " - f"{', '.join(repr(d) for d in directive_names)}", + f"{', '.join(repr(k) for k, v in directive_configs)}", debug_path, ) - return directive_names[0] + directive_name, directive_config = directive_configs[0] + + if not isinstance(directive_config, dict): + raise TemplateError( + "directive body should be type 'dict' not " + f"{repr(directive_config.__class__.__name__)}", + f"{debug_path}.{directive_name}", + ) + + return directive_name, directive_config def _get_directive( self, directive_name: str, context: Context, sources: dict[str, Source], - config: dict, + config: dict[str, Template], debug_path: str, ) -> TemplateDirective: cls = DIRECTIVE_REGISTRY.get(directive_name[len(self.directive_marker):]) diff --git a/mandible/metadata_mapper/source_provider.py b/mandible/metadata_mapper/source_provider.py index b3a2a47..cb76c0d 100644 --- a/mandible/metadata_mapper/source_provider.py +++ b/mandible/metadata_mapper/source_provider.py @@ -11,7 +11,7 @@ T = TypeVar("T") -REGISTRY_TYPE_MAP = { +REGISTRY_TYPE_MAP: dict[str, dict[str, Any]] = { "Format": FORMAT_REGISTRY, "Source": SOURCE_REGISTRY, "Storage": STORAGE_REGISTRY, diff --git a/mandible/metadata_mapper/storage/__init__.py 
b/mandible/metadata_mapper/storage/__init__.py index 298866a..cee45fe 100644 --- a/mandible/metadata_mapper/storage/__init__.py +++ b/mandible/metadata_mapper/storage/__init__.py @@ -11,12 +11,12 @@ try: from .cmr_query import CmrQuery except ImportError: - from .storage import CmrQuery + from .storage import CmrQuery # type: ignore try: from .http_request import HttpRequest except ImportError: - from .storage import HttpRequest + from .storage import HttpRequest # type: ignore __all__ = ( diff --git a/mandible/metadata_mapper/storage/cmr_query.py b/mandible/metadata_mapper/storage/cmr_query.py index 9f80605..d5d95bf 100644 --- a/mandible/metadata_mapper/storage/cmr_query.py +++ b/mandible/metadata_mapper/storage/cmr_query.py @@ -18,7 +18,7 @@ class CmrQuery(HttpRequest): format: str = "" token: Optional[str] = None - def __post_init__(self, url: str): + def __post_init__(self, url: Optional[str]): if url: raise ValueError( "do not set 'url' directly, use 'base_url' and 'path' instead", diff --git a/mandible/metadata_mapper/storage/storage.py b/mandible/metadata_mapper/storage/storage.py index 3feb5b6..1ea69e4 100644 --- a/mandible/metadata_mapper/storage/storage.py +++ b/mandible/metadata_mapper/storage/storage.py @@ -47,7 +47,8 @@ def __init__(self, dep: str): ) def open_file(self, context: Context) -> IO[bytes]: - pass + # __init__ always raises + raise RuntimeError("Unreachable!") @dataclass diff --git a/tests/integration_tests/test_metadata_mapper.py b/tests/integration_tests/test_metadata_mapper.py index ca01498..94521a5 100644 --- a/tests/integration_tests/test_metadata_mapper.py +++ b/tests/integration_tests/test_metadata_mapper.py @@ -287,6 +287,26 @@ def test_invalid_directive(context): mapper.get_metadata(context) +def test_invalid_directive_config_type(context): + mapper = MetadataMapper( + template={ + "foo": { + "@mapped": 100, + }, + }, + source_provider=ConfigSourceProvider({}), + ) + + with pytest.raises( + MetadataMapperError, + match=( + r"failed 
to process template at \$\.foo\.@mapped: " + "directive body should be type 'dict' not 'int'" + ), + ): + mapper.get_metadata(context) + + def test_multiple_directives(context): mapper = MetadataMapper( template={ diff --git a/tests/test_format.py b/tests/test_format.py index 58a6db4..1e69d31 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -310,6 +310,7 @@ def test_xml(): Key("./nested/qux"), Key("./list/v", return_list=True), Key("./list/v", return_first=True), + Key("count(./list/v)"), ], ) == { Key("/root/foo"): "foo value", @@ -318,6 +319,7 @@ def test_xml(): Key("./nested/qux"): "qux nested value", Key("./list/v", return_list=True): ["list", "value"], Key("./list/v", return_first=True): "list", + Key("count(./list/v)"): 2, }