Skip to content

Commit

Permalink
fix(types): correct type annotations flagged by beartype
Browse files Browse the repository at this point in the history
beartype revealed several incorrect typed annotations which lead to
failing tests.
  • Loading branch information
lu-pl committed Nov 30, 2024
1 parent 3595245 commit 07cfe93
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
1 change: 1 addition & 0 deletions rdfproxy/utils/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
_TModelInstance = TypeVar("_TModelInstance", bound=BaseModel)


@runtime_checkable
class ItemsQueryConstructor(Protocol):
def __call__(self, query: str, limit: int, offset: int) -> str: ...

Expand Down
15 changes: 10 additions & 5 deletions rdfproxy/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,25 @@
MissingModelConfigException,
UnboundGroupingKeyException,
)
from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue
from rdfproxy.utils._types import (
ModelBoolPredicate,
SPARQLBinding,
_TModelBoolValue,
_TModelInstance,
)


def _is_type(obj: type | None, _type: type) -> bool:
def _is_type(obj: Any, _type: type) -> bool:
"""Check if an obj is type _type or a GenericAlias with origin _type."""
return (obj is _type) or (get_origin(obj) is _type)


def _is_list_type(obj: type | None) -> bool:
def _is_list_type(obj: Any) -> bool:
"""Check if obj is a list type."""
return _is_type(obj, list)


def _is_list_basemodel_type(obj: type | None) -> bool:
def _is_list_basemodel_type(obj: Any) -> bool:
"""Check if a type is list[pydantic.BaseModel]."""
return (get_origin(obj) is list) and all(
issubclass(cls, BaseModel) for cls in get_args(obj)
Expand Down Expand Up @@ -104,7 +109,7 @@ def _get_model_bool_predicate_from_config_value(
)


def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate:
def get_model_bool_predicate(model: type[_TModelInstance]) -> ModelBoolPredicate:
"""Get the applicable model_bool predicate function given a model."""
if (model_bool_value := model.model_config.get("model_bool", None)) is None:
model_bool_predicate = default_model_bool_predicate
Expand Down

0 comments on commit 07cfe93

Please sign in to comment.