From 02474edeac3c57edd8cf81a705261e053b4b689c Mon Sep 17 00:00:00 2001 From: Lukas Plank Date: Thu, 24 Oct 2024 16:33:16 +0200 Subject: [PATCH] feat: implement model_bool hook for controlling model truthiness Closes #110, closes #112. --- rdfproxy/mapper.py | 7 +++++-- rdfproxy/utils/_types.py | 13 ++++++++++++- rdfproxy/utils/utils.py | 39 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/rdfproxy/mapper.py b/rdfproxy/mapper.py index a7ba2c3..001a349 100644 --- a/rdfproxy/mapper.py +++ b/rdfproxy/mapper.py @@ -4,13 +4,14 @@ from typing import Any, Generic, get_args from pydantic import BaseModel -from rdfproxy.utils._types import _TModelInstance +from rdfproxy.utils._types import ModelBoolPredicate, _TModelInstance from rdfproxy.utils.utils import ( _collect_values_from_bindings, _get_group_by, _get_key_from_metadata, _is_list_basemodel_type, _is_list_type, + get_model_bool_predicate, ) @@ -29,10 +30,12 @@ def get_models(self) -> list[_TModelInstance]: def _get_unique_models(self, model, bindings): """Call the mapping logic and collect unique and non-empty models.""" models = [] + model_bool_predicate: ModelBoolPredicate = get_model_bool_predicate(model) + for _bindings in bindings: _model = model(**dict(self._generate_binding_pairs(model, **_bindings))) - if any(_model.model_dump().values()) and (_model not in models): + if model_bool_predicate(_model) and (_model not in models): models.append(_model) return models diff --git a/rdfproxy/utils/_types.py b/rdfproxy/utils/_types.py index 68df506..07c02d6 100644 --- a/rdfproxy/utils/_types.py +++ b/rdfproxy/utils/_types.py @@ -1,6 +1,7 @@ """Type definitions for rdfproxy.""" -from typing import TypeVar +from collections.abc import Iterable +from typing import Protocol, TypeAlias, TypeVar, runtime_checkable from pydantic import BaseModel @@ -27,3 +28,13 @@ class Person(BaseModel): """ ... + + +@runtime_checkable +class ModelBoolPredicate(Protocol): + """Type for model_bool predicate functions.""" + + def __call__(self, model: BaseModel) -> bool: ... + + +_TModelBoolValue: TypeAlias = ModelBoolPredicate | str | Iterable[str] diff --git a/rdfproxy/utils/utils.py b/rdfproxy/utils/utils.py index 1eb5c8b..0875a02 100644 --- a/rdfproxy/utils/utils.py +++ b/rdfproxy/utils/utils.py @@ -9,7 +9,7 @@ MissingModelConfigException, UnboundGroupingKeyException, ) -from rdfproxy.utils._types import SPARQLBinding +from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue def _is_type(obj: type | None, _type: type) -> bool: @@ -70,3 +70,40 @@ def _get_group_by(model: type[BaseModel], kwargs: dict) -> str: f"Applicable grouping keys: {', '.join(kwargs.keys())}." ) return group_by + + +def default_model_bool_predicate(model: BaseModel) -> bool: + """Default predicate for determining model truthiness. + + Adheres to rdfproxy.utils._types.ModelBoolPredicate. + """ + return any(dict(model).values()) + + +def _get_model_bool_predicate_from_config_value( + model_bool_value: _TModelBoolValue, +) -> ModelBoolPredicate: + """Get a model_bool predicate function given the value of the model_bool config setting.""" + match model_bool_value: + case ModelBoolPredicate(): + return model_bool_value + case str(): + return lambda model: bool(dict(model)[model_bool_value]) + case Iterable(): + return lambda model: all(map(lambda k: dict(model)[k], model_bool_value)) + case _: + raise TypeError( + "Argument for 'model_bool' must be of type ModelBoolPredicate | str | Iterable[str].\n" + f"Received {type(model_bool_value)}" + ) + + +def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate: + """Get the applicable model_bool predicate function given a model.""" + model_bool_predicate: ModelBoolPredicate = ( + default_model_bool_predicate + if (model_bool_value := model.model_config.get("model_bool", None)) is None + else _get_model_bool_predicate_from_config_value(model_bool_value) + ) + + return model_bool_predicate