From 07cfe935c57a1c7e3a40e5c7284e2b27b2c708f8 Mon Sep 17 00:00:00 2001 From: Lukas Plank Date: Sat, 30 Nov 2024 19:36:54 +0100 Subject: [PATCH] fix(types): correct type annotations flagged by beartype beartype revealed several incorrect typed annotations which lead to failing tests. --- rdfproxy/utils/_types.py | 1 + rdfproxy/utils/utils.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/rdfproxy/utils/_types.py b/rdfproxy/utils/_types.py index 93b59d1..0e076b2 100644 --- a/rdfproxy/utils/_types.py +++ b/rdfproxy/utils/_types.py @@ -9,6 +9,7 @@ _TModelInstance = TypeVar("_TModelInstance", bound=BaseModel) +@runtime_checkable class ItemsQueryConstructor(Protocol): def __call__(self, query: str, limit: int, offset: int) -> str: ... diff --git a/rdfproxy/utils/utils.py b/rdfproxy/utils/utils.py index 0615d7d..ae15a12 100644 --- a/rdfproxy/utils/utils.py +++ b/rdfproxy/utils/utils.py @@ -9,20 +9,25 @@ MissingModelConfigException, UnboundGroupingKeyException, ) -from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue +from rdfproxy.utils._types import ( + ModelBoolPredicate, + SPARQLBinding, + _TModelBoolValue, + _TModelInstance, +) -def _is_type(obj: type | None, _type: type) -> bool: +def _is_type(obj: Any, _type: type) -> bool: """Check if an obj is type _type or a GenericAlias with origin _type.""" return (obj is _type) or (get_origin(obj) is _type) -def _is_list_type(obj: type | None) -> bool: +def _is_list_type(obj: Any) -> bool: """Check if obj is a list type.""" return _is_type(obj, list) -def _is_list_basemodel_type(obj: type | None) -> bool: +def _is_list_basemodel_type(obj: Any) -> bool: """Check if a type is list[pydantic.BaseModel].""" return (get_origin(obj) is list) and all( issubclass(cls, BaseModel) for cls in get_args(obj) @@ -104,7 +109,7 @@ def _get_model_bool_predicate_from_config_value( ) -def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate: +def get_model_bool_predicate(model: type[_TModelInstance]) -> ModelBoolPredicate: """Get the applicable model_bool predicate function given a model.""" if (model_bool_value := model.model_config.get("model_bool", None)) is None: model_bool_predicate = default_model_bool_predicate