Skip to content

Commit

Permalink
Add mypy linter step
Browse files Browse the repository at this point in the history
  • Loading branch information
reweeden committed Jan 31, 2025
1 parent 188564b commit 2c012ea
Show file tree
Hide file tree
Showing 15 changed files with 150 additions and 55 deletions.
32 changes: 30 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
pull_request:

jobs:
lint:
flake8:
runs-on: ubuntu-latest

steps:
Expand All @@ -17,4 +17,32 @@ jobs:
- uses: TrueBrain/actions-flake8@v2
with:
flake8_version: 6.0.0
plugins: flake8-isort==6.1.1 flake8-quotes==3.4.0 flake8-commas==4.0.0
plugins: flake8-isort==6.1.1 flake8-quotes==3.4.0 flake8-commas==4.0.0

mypy:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.9

- run: |
pip install \
mypy==1.14.1 \
boto3-stubs \
h5py==3.6.0 \
jsonpath_ng==1.4.1 \
s3fs==0.4.2
- run: |
mypy \
--non-interactive \
--install-types \
--check-untyped-defs \
--disable-error-code=import-untyped \
--strict-equality \
--warn-redundant-casts \
--warn-unused-ignores \
mandible
2 changes: 1 addition & 1 deletion mandible/metadata_mapper/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def mapped(
directive_name = Mapped.directive_name
assert directive_name is not None

params = {
params: dict[str, Any] = {
"source": source,
"key": key,
}
Expand Down
3 changes: 2 additions & 1 deletion mandible/metadata_mapper/directive/reformatted.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

from mandible.metadata_mapper.exception import MetadataMapperError
from mandible.metadata_mapper.format import FORMAT_REGISTRY
from mandible.metadata_mapper.types import Key

from .directive import Key, TemplateDirective, get_key
from .directive import TemplateDirective, get_key


@dataclass
Expand Down
7 changes: 5 additions & 2 deletions mandible/metadata_mapper/exception.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Optional


class MetadataMapperError(Exception):
"""A generic error raised by the MetadataMapper"""

Expand All @@ -8,7 +11,7 @@ def __init__(self, msg: str):
class TemplateError(MetadataMapperError):
"""An error that occurred while processing the metadata template."""

def __init__(self, msg: str, debug_path: str = None):
def __init__(self, msg: str, debug_path: Optional[str] = None):
super().__init__(msg)
self.debug_path = debug_path

Expand All @@ -26,7 +29,7 @@ class ContextValueError(MetadataMapperError):
def __init__(
self,
msg: str,
source_name: str = None,
source_name: Optional[str] = None,
):
super().__init__(msg)
self.source_name = source_name
Expand Down
4 changes: 2 additions & 2 deletions mandible/metadata_mapper/format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
try:
from .h5 import H5
except ImportError:
from .format import H5
from .format import H5 # type: ignore

try:
from .xml import Xml
except ImportError:
from .format import Xml
from .format import Xml # type: ignore


__all__ = (
Expand Down
39 changes: 24 additions & 15 deletions mandible/metadata_mapper/format/format.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import contextlib
import inspect
import json
import re
import zipfile
from abc import ABC, abstractmethod
from collections.abc import Iterable
from collections.abc import Generator, Iterable
from dataclasses import dataclass
from typing import IO, Any, TypeVar
from typing import IO, Any, Generic, TypeVar

from mandible import jsonpath
from mandible.metadata_mapper.key import RAISE_EXCEPTION, Key
Expand Down Expand Up @@ -50,7 +51,7 @@ def get_value(self, file: IO[bytes], key: Key) -> Any:


@dataclass
class FileFormat(Format, ABC, register=False):
class FileFormat(Format, Generic[T], ABC, register=False):
"""A Format for querying files from a standard data file.
Simple, single format data types such as 'json' that can be queried
Expand All @@ -76,7 +77,7 @@ def get_value(self, file: IO[bytes], key: Key) -> Any:
with self.parse_data(file) as data:
return self._eval_key_wrapper(data, key)

def _eval_key_wrapper(self, data, key: Key) -> Any:
def _eval_key_wrapper(self, data: T, key: Key) -> Any:
try:
return self.eval_key(data, key)
except KeyError as e:
Expand Down Expand Up @@ -116,7 +117,7 @@ def eval_key(data: T, key: Key) -> Any:


@dataclass
class _PlaceholderBase(FileFormat, register=False):
class _PlaceholderBase(FileFormat[None], register=False):
"""
Base class for defining placeholder implementations for classes that
require extra dependencies to be installed
Expand All @@ -128,12 +129,14 @@ def __init__(self, dep: str):
)

@staticmethod
def parse_data(file: IO[bytes]) -> contextlib.AbstractContextManager[T]:
pass
def parse_data(file: IO[bytes]) -> contextlib.AbstractContextManager[None]:
# __init__ always raises
raise RuntimeError("Unreachable!")

@staticmethod
def eval_key(data: T, key: Key):
pass
def eval_key(data: None, key: Key):
# __init__ always raises
raise RuntimeError("Unreachable!")


@dataclass
Expand All @@ -151,10 +154,10 @@ def __init__(self):
# Define formats that don't require extra dependencies

@dataclass
class Json(FileFormat):
class Json(FileFormat[dict]):
@staticmethod
@contextlib.contextmanager
def parse_data(file: IO[bytes]) -> dict:
def parse_data(file: IO[bytes]) -> Generator[dict]:
yield json.load(file)

@staticmethod
Expand Down Expand Up @@ -237,20 +240,26 @@ def _matches_filters(self, zipinfo: zipfile.ZipInfo) -> bool:
return True


ZIP_INFO_ATTRS = [
name
for name, _ in inspect.getmembers(zipfile.ZipInfo, inspect.isdatadescriptor)
if not name.startswith("_")
]


@dataclass
class ZipInfo(FileFormat):
class ZipInfo(FileFormat[dict]):
"""Query Zip headers and directory information."""

@staticmethod
@contextlib.contextmanager
def parse_data(file: IO[bytes]) -> dict:
def parse_data(file: IO[bytes]) -> Generator[dict]:
with zipfile.ZipFile(file, "r") as zf:
yield {
"infolist": [
{
k: getattr(info, k)
for k in info.__slots__
if not k.startswith("_")
for k in ZIP_INFO_ATTRS
}
for info in zf.infolist()
],
Expand Down
4 changes: 2 additions & 2 deletions mandible/metadata_mapper/format/h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@


@dataclass
class H5(FileFormat):
class H5(FileFormat[Any]):
@staticmethod
def parse_data(file: IO[bytes]) -> contextlib.AbstractContextManager[Any]:
return h5py.File(file, "r")

@staticmethod
def eval_key(data, key: Key) -> Any:
def eval_key(data: Any, key: Key) -> Any:
return normalize(data[key.key][()])


Expand Down
33 changes: 26 additions & 7 deletions mandible/metadata_mapper/format/xml.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
from collections.abc import Generator, Iterable
from dataclasses import dataclass
from typing import IO, Any
from typing import IO, Any, Union

from lxml import etree

Expand All @@ -10,16 +11,34 @@


@dataclass
class Xml(FileFormat):
class Xml(FileFormat[etree._ElementTree]):
@staticmethod
@contextlib.contextmanager
def parse_data(file: IO[bytes]) -> Any:
def parse_data(file: IO[bytes]) -> Generator[etree._ElementTree]:
yield etree.parse(file)

@staticmethod
def eval_key(data: etree.ElementTree, key: Key) -> Any:
def eval_key(data: etree._ElementTree, key: Key) -> Any:
nsmap = data.getroot().nsmap
elements = data.xpath(key.key, namespaces=nsmap)
values = [element.text for element in elements]
xpath_result = data.xpath(
key.key,
# Lxml type stubs don't handle None key for default namespaces
namespaces=nsmap, # type: ignore
)
if isinstance(xpath_result, Iterable):
values = [convert_result(item) for item in xpath_result]

return key.resolve_list_match(values)
return key.resolve_list_match(values)

# Xpath supports functions such as `count` that can result in
# `data.xpath` returning something other than a list of matches.
return xpath_result


def convert_result(result: Union[etree._Element, int, str, bytes, tuple]):
if isinstance(result, etree._Element):
return result.text
if isinstance(result, (int, str, bytes)):
return result

raise TypeError(f"Unsupported type {repr(result.__class__.__name__)}")
48 changes: 30 additions & 18 deletions mandible/metadata_mapper/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class MetadataMapper:
def __init__(
self,
template: Template,
source_provider: SourceProvider = None,
source_provider: Optional[SourceProvider] = None,
*,
directive_marker: str = "@",
):
Expand Down Expand Up @@ -72,15 +72,16 @@ def get_metadata(self, context: Context) -> Template:
def _prepare_directives(self, context: Context, sources: dict[str, Source]):
for value, debug_path in _walk_values(self.template):
if isinstance(value, dict):
directive_name = self._get_directive_name(value, debug_path)
if directive_name is None:
directive_config = self._get_directive_name(value, debug_path)
if directive_config is None:
continue

directive_name, directive_body = directive_config
directive = self._get_directive(
directive_name,
context,
sources,
value[directive_name],
directive_body,
f"{debug_path}.{directive_name}",
)
directive.prepare()
Expand All @@ -91,13 +92,14 @@ def _replace_template(
template: Template,
sources: dict[str, Source],
debug_path: str = "$",
):
) -> Template:
if isinstance(template, dict):
directive_name = self._get_directive_name(
directive_config = self._get_directive_name(
template,
debug_path,
)
if directive_name is not None:
if directive_config is not None:
directive_name, directive_body = directive_config
debug_path = f"{debug_path}.{directive_name}"
directive = self._get_directive(
directive_name,
Expand All @@ -110,7 +112,7 @@ def _replace_template(
sources,
debug_path=f"{debug_path}.{k}",
)
for k, v in template[directive_name].items()
for k, v in directive_body.items()
},
debug_path,
)
Expand Down Expand Up @@ -146,31 +148,41 @@ def _replace_template(

def _get_directive_name(
self,
value: dict,
value: dict[str, Template],
debug_path: str,
) -> Optional[str]:
directive_names = [
key for key in value
if key.startswith(self.directive_marker)
) -> Optional[tuple[str, dict[str, Template]]]:
directive_configs = [
(k, v)
for (k, v) in value.items()
if k.startswith(self.directive_marker)
]
if not directive_names:
if not directive_configs:
return None

if len(directive_names) > 1:
if len(directive_configs) > 1:
raise TemplateError(
"multiple directives found in config: "
f"{', '.join(repr(d) for d in directive_names)}",
f"{', '.join(repr(k) for k, v in directive_configs)}",
debug_path,
)

return directive_names[0]
directive_name, directive_config = directive_configs[0]

if not isinstance(directive_config, dict):
raise TemplateError(
"directive body should be type 'dict' not "
f"{repr(directive_config.__class__.__name__)}",
f"{debug_path}.{directive_name}",
)

return directive_name, directive_config

def _get_directive(
self,
directive_name: str,
context: Context,
sources: dict[str, Source],
config: dict,
config: dict[str, Template],
debug_path: str,
) -> TemplateDirective:
cls = DIRECTIVE_REGISTRY.get(directive_name[len(self.directive_marker):])
Expand Down
2 changes: 1 addition & 1 deletion mandible/metadata_mapper/source_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

T = TypeVar("T")

REGISTRY_TYPE_MAP = {
REGISTRY_TYPE_MAP: dict[str, dict[str, Any]] = {
"Format": FORMAT_REGISTRY,
"Source": SOURCE_REGISTRY,
"Storage": STORAGE_REGISTRY,
Expand Down
Loading

0 comments on commit 2c012ea

Please sign in to comment.