diff --git a/rdfproxy/__init__.py b/rdfproxy/__init__.py
index 20bdaad..caaad1e 100644
--- a/rdfproxy/__init__.py
+++ b/rdfproxy/__init__.py
@@ -1,3 +1,4 @@
 from rdfproxy.adapter import SPARQLModelAdapter  # noqa: F401
+from rdfproxy.mapper import ModelBindingsMapper  # noqa: F401
 from rdfproxy.utils._types import SPARQLBinding  # noqa: F401
 from rdfproxy.utils.models import Page  # noqa: F401
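With this change, ModelBindingsMapper joins the package's public API alongside the existing exports, so all four names are importable from the top level:

```python
from rdfproxy import ModelBindingsMapper, Page, SPARQLBinding, SPARQLModelAdapter
```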
diff --git a/rdfproxy/adapter.py b/rdfproxy/adapter.py
index 3576c86..38f2718 100644
--- a/rdfproxy/adapter.py
+++ b/rdfproxy/adapter.py
@@ -1,38 +1,34 @@
 """SPARQLModelAdapter class for SPARQL query result set to Pydantic model conversions."""
 
-from collections import defaultdict
 from collections.abc import Iterator
 import math
-from typing import Any, Generic, overload
+from typing import Generic
 
-from SPARQLWrapper import JSON, QueryResult, SPARQLWrapper
-from rdfproxy.utils._exceptions import (
-    InterdependentParametersException,
-    UndefinedBindingException,
-)
+from SPARQLWrapper import JSON, SPARQLWrapper
+from rdfproxy.mapper import ModelBindingsMapper
 from rdfproxy.utils._types import _TModelInstance
 from rdfproxy.utils.models import Page
-from rdfproxy.utils.sparql.sparql_templates import ungrouped_pagination_base_query
-from rdfproxy.utils.sparql.sparql_utils import (
+from rdfproxy.utils.sparql_utils import (
     calculate_offset,
     construct_count_query,
-    construct_grouped_count_query,
-    construct_grouped_pagination_query,
     query_with_wrapper,
-    temporary_query_override,
-)
-from rdfproxy.utils.utils import (
-    get_bindings_from_query_result,
-    instantiate_model_from_kwargs,
+    ungrouped_pagination_base_query,
 )
 
 
 class SPARQLModelAdapter(Generic[_TModelInstance]):
     """Adapter/Mapper for SPARQL query result set to Pydantic model conversions.
 
-    The rdfproxy.SPARQLModelAdapter class allows to run a query against an endpoint,
-    map a flat SPARQL query result set to a potentially nested Pydantic model and
-    optionally paginate and/or group the results by a SPARQL binding.
+    The rdfproxy.SPARQLModelAdapter class allows running a query against an endpoint
+    and mapping a flat SPARQL query result set to a potentially nested Pydantic model.
+
+    SPARQLModelAdapter.query returns a Page model object with a default pagination size of 100 results.
+
+    SPARQL bindings are implicitly assigned to model fields of the same name;
+    explicit SPARQL binding to model field allocation is available with typing.Annotated and rdfproxy.SPARQLBinding.
+
+    Result grouping is controlled through the model,
+    i.e. grouping is triggered when a field of type list[pydantic.BaseModel] is encountered.
     """
 
     def __init__(
@@ -45,156 +41,37 @@ def __init__(
             SPARQLWrapper(target) if isinstance(target, str) else target
         )
         self.sparql_wrapper.setReturnFormat(JSON)
-        self.sparql_wrapper.setQuery(query)
-
-    @overload
-    def query(self) -> list[_TModelInstance]: ...
-
-    @overload
-    def query(
-        self,
-        *,
-        group_by: str,
-    ) -> dict[str, list[_TModelInstance]]: ...
 
-    @overload
     def query(
         self,
         *,
-        page: int,
-        size: int,
-    ) -> Page[_TModelInstance]: ...
-
-    @overload
-    def query(
-        self,
-        *,
-        page: int,
-        size: int,
-        group_by: str,
-    ) -> Page[_TModelInstance]: ...
-
-    def query(
-        self,
-        *,
-        page: int | None = None,
-        size: int | None = None,
-        group_by: str | None = None,
-    ) -> (
-        list[_TModelInstance] | dict[str, list[_TModelInstance]] | Page[_TModelInstance]
-    ):
-        """Run query against endpoint and map the SPARQL query result set to a Pydantic model.
-
-        Optional pagination and/or grouping by a SPARQL binding is avaible by
-        supplying the group_by and/or page/size parameters.
-        """
-        match page, size, group_by:
-            case None, None, None:
-                return self._query_collect_models()
-            case int(), int(), None:
-                return self._query_paginate_ungrouped(page=page, size=size)
-            case None, None, str():
-                return self._query_group_by(group_by=group_by)
-            case int(), int(), str():
-                return self._query_paginate_grouped(
-                    page=page, size=size, group_by=group_by
-                )
-            case (None, int(), Any()) | (int(), None, Any()):
-                raise InterdependentParametersException(
-                    "Parameters 'page' and 'size' are mutually dependent."
-                )
-            case _:
-                raise Exception("This should never happen.")
-
-    def _query_generate_model_bindings_mapping(
-        self, query: str | None = None
-    ) -> Iterator[tuple[_TModelInstance, dict[str, Any]]]:
-        """Run query, construct model instances and generate a model-bindings mapping.
-
-        The query parameter defaults to the initially defined query and
-        is run against the endpoint defined in the SPARQLModelAdapter instance.
-
-        Note: The coupling of model instances with flat SPARQL results
-        allows for easier and more efficient grouping operations (see grouping functionality).
-        """
-        if query is None:
-            query_result: QueryResult = self.sparql_wrapper.query()
-        else:
-            with temporary_query_override(self.sparql_wrapper):
-                self.sparql_wrapper.setQuery(query)
-                query_result: QueryResult = self.sparql_wrapper.query()
-
-        _bindings = get_bindings_from_query_result(query_result)
-
-        for bindings in _bindings:
-            model = instantiate_model_from_kwargs(self._model, **bindings)
-            yield model, bindings
-
-    def _query_collect_models(self, query: str | None = None) -> list[_TModelInstance]:
-        """Run query against endpoint and collect model instances."""
-        return [
-            model
-            for model, _ in self._query_generate_model_bindings_mapping(query=query)
-        ]
-
-    def _query_group_by(
-        self, group_by: str, query: str | None = None
-    ) -> dict[str, list[_TModelInstance]]:
-        """Run query against endpoint and group results by a SPARQL binding."""
-        group = defaultdict(list)
-
-        for model, bindings in self._query_generate_model_bindings_mapping(query):
-            try:
-                key = bindings[group_by]
-            except KeyError:
-                raise UndefinedBindingException(
-                    f"SPARQL binding '{group_by}' requested for grouping "
-                    f"not in query projection '{bindings}'."
-                )
-
-            group[str(key)].append(model)
-
-        return group
-
-    def _get_count(self, query: str) -> int:
-        """Construct a count query from the initialized query, run it and return the count result."""
-        result = query_with_wrapper(query=query, sparql_wrapper=self.sparql_wrapper)
-        return int(next(result)["cnt"])
-
-    def _query_paginate_ungrouped(self, page: int, size: int) -> Page[_TModelInstance]:
-        """Run query with pagination according to page and size.
-
-        The internal query is dynamically modified according to page (offset)/size (limit)
-        and run with SPARQLModelAdapter._query_collect_models.
-        """
-        paginated_query = ungrouped_pagination_base_query.substitute(
+        page: int = 1,
+        size: int = 100,
+    ) -> Page[_TModelInstance]:
+        """Run a query against an endpoint and return a Page model object."""
+        count_query: str = construct_count_query(self._query)
+        items_query: str = ungrouped_pagination_base_query.substitute(
             query=self._query, offset=calculate_offset(page, size), limit=size
         )
-        count_query = construct_count_query(self._query)
-        items = self._query_collect_models(query=paginated_query)
-        total = self._get_count(count_query)
-        pages = math.ceil(total / size)
+        items_query_bindings: Iterator[dict] = query_with_wrapper(
+            query=items_query, sparql_wrapper=self.sparql_wrapper
+        )
+
+        mapper = ModelBindingsMapper(self._model, *items_query_bindings)
+
+        items: list[_TModelInstance] = mapper.get_models()
+        total: int = self._get_count(count_query)
+        pages: int = math.ceil(total / size)
 
         return Page(items=items, page=page, size=size, total=total, pages=pages)
 
-    def _query_paginate_grouped(
-        self, page: int, size: int, group_by: str
-    ) -> Page[_TModelInstance]:
-        """Run query with pagination according to page/size and group result by a SPARQL binding.
+    def _get_count(self, query: str) -> int:
+        """Run a count query and return the count result.
 
-        The internal query is dynamically modified according to page (offset)/size (limit)
-        and run with SPARQLModelAdapter._query_group_by.
+        Helper for SPARQLModelAdapter.query.
         """
-        grouped_paginated_query = construct_grouped_pagination_query(
-            query=self._query, page=page, size=size, group_by=group_by
-        )
-        grouped_count_query = construct_grouped_count_query(
-            query=self._query, group_by=group_by
+        result: Iterator[dict] = query_with_wrapper(
+            query=query, sparql_wrapper=self.sparql_wrapper
        )
-
-        items = self._query_group_by(group_by=group_by, query=grouped_paginated_query)
-        total = self._get_count(grouped_count_query)
-        pages = math.ceil(total / size)
-
-        return Page(items=items, page=page, size=size, total=total, pages=pages)
+        return int(next(result)["cnt"])
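For orientation, a minimal usage sketch of the reworked interface. The endpoint, query, and model are invented, and the constructor keywords (target, query, model) are inferred from the attribute names used in the class body above, so treat them as assumptions rather than confirmed API:

```python
from pydantic import BaseModel
from rdfproxy import Page, SPARQLModelAdapter


class Person(BaseModel):
    name: str


adapter = SPARQLModelAdapter(
    target="https://example.org/sparql",  # hypothetical endpoint
    query="select ?name where { ?person <urn:example:name> ?name . }",
    model=Person,
)

# query() now always paginates; page/size default to 1/100.
result: Page[Person] = adapter.query(page=1, size=10)
```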
- """ - paginated_query = ungrouped_pagination_base_query.substitute( + page: int = 1, + size: int = 100, + ) -> Page[_TModelInstance]: + """Run a query against an endpoint and return a Page model object.""" + count_query: str = construct_count_query(self._query) + items_query: str = ungrouped_pagination_base_query.substitute( query=self._query, offset=calculate_offset(page, size), limit=size ) - count_query = construct_count_query(self._query) - items = self._query_collect_models(query=paginated_query) - total = self._get_count(count_query) - pages = math.ceil(total / size) + items_query_bindings: Iterator[dict] = query_with_wrapper( + query=items_query, sparql_wrapper=self.sparql_wrapper + ) + + mapper = ModelBindingsMapper(self._model, *items_query_bindings) + + items: list[_TModelInstance] = mapper.get_models() + total: int = self._get_count(count_query) + pages: int = math.ceil(total / size) return Page(items=items, page=page, size=size, total=total, pages=pages) - def _query_paginate_grouped( - self, page: int, size: int, group_by: str - ) -> Page[_TModelInstance]: - """Run query with pagination according to page/size and group result by a SPARQL binding. + def _get_count(self, query: str) -> int: + """Run a count query and return the count result. - The internal query is dynamically modified according to page (offset)/size (limit) - and run with SPARQLModelAdapter._query_group_by. + Helper for SPARQLModelAdapter.query. """ - grouped_paginated_query = construct_grouped_pagination_query( - query=self._query, page=page, size=size, group_by=group_by - ) - grouped_count_query = construct_grouped_count_query( - query=self._query, group_by=group_by + result: Iterator[dict] = query_with_wrapper( + query=query, sparql_wrapper=self.sparql_wrapper ) - - items = self._query_group_by(group_by=group_by, query=grouped_paginated_query) - total = self._get_count(grouped_count_query) - pages = math.ceil(total / size) - - return Page(items=items, page=page, size=size, total=total, pages=pages) + return int(next(result)["cnt"]) diff --git a/rdfproxy/mapper.py b/rdfproxy/mapper.py new file mode 100644 index 0000000..a7ba2c3 --- /dev/null +++ b/rdfproxy/mapper.py @@ -0,0 +1,87 @@ +"""ModelBindingsMapper: Functionality for mapping binding maps to a Pydantic model.""" + +from collections.abc import Iterator +from typing import Any, Generic, get_args + +from pydantic import BaseModel +from rdfproxy.utils._types import _TModelInstance +from rdfproxy.utils.utils import ( + _collect_values_from_bindings, + _get_group_by, + _get_key_from_metadata, + _is_list_basemodel_type, + _is_list_type, +) + + +class ModelBindingsMapper(Generic[_TModelInstance]): + """Utility class for mapping flat bindings to a (potentially nested) Pydantic model.""" + + def __init__(self, model: type[_TModelInstance], *bindings: dict): + self.model = model + self.bindings = bindings + self._contexts = [] + + def get_models(self) -> list[_TModelInstance]: + """Generate a list of (potentially nested) Pydantic models based on (flat) bindings.""" + return self._get_unique_models(self.model, self.bindings) + + def _get_unique_models(self, model, bindings): + """Call the mapping logic and collect unique and non-empty models.""" + models = [] + for _bindings in bindings: + _model = model(**dict(self._generate_binding_pairs(model, **_bindings))) + + if any(_model.model_dump().values()) and (_model not in models): + models.append(_model) + + return models + + def _get_group_by(self, model, kwargs) -> str: + """Get the group_by value from a model 
diff --git a/rdfproxy/utils/_exceptions.py b/rdfproxy/utils/_exceptions.py
index 2a4d219..ff80a6b 100644
--- a/rdfproxy/utils/_exceptions.py
+++ b/rdfproxy/utils/_exceptions.py
@@ -7,3 +7,11 @@ class UndefinedBindingException(KeyError):
 
 class InterdependentParametersException(Exception):
     """Exceptiono for indicating that two or more parameters are interdependent."""
+
+
+class MissingModelConfigException(Exception):
+    """Exception for indicating that an expected Config class is missing in a Pydantic model definition."""
+
+
+class UnboundGroupingKeyException(Exception):
+    """Exception for indicating that no SPARQL binding corresponds to the requested grouping key."""
diff --git a/rdfproxy/utils/_types.py b/rdfproxy/utils/_types.py
index 899327e..68df506 100644
--- a/rdfproxy/utils/_types.py
+++ b/rdfproxy/utils/_types.py
@@ -1,22 +1,13 @@
 """Type definitions for rdfproxy."""
 
-from collections.abc import Iterable
-from typing import Protocol, TypeVar, runtime_checkable
+from typing import TypeVar
 
-from SPARQLWrapper import QueryResult
 from pydantic import BaseModel
 
 
 _TModelInstance = TypeVar("_TModelInstance", bound=BaseModel)
 
 
-@runtime_checkable
-class _TModelConstructorCallable(Protocol[_TModelInstance]):
-    """Callback protocol for model constructor callables."""
-
-    def __call__(self, query_result: QueryResult) -> Iterable[_TModelInstance]: ...
-
-
 class SPARQLBinding(str):
     """SPARQLBinding type for explicit SPARQL binding to model field allocation.
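The SPARQLBinding str subtype is what enables the explicit binding-to-field allocation via typing.Annotated mentioned in the adapter docstring. A small sketch with an invented model and binding name:

```python
from typing import Annotated

from pydantic import BaseModel
from rdfproxy import SPARQLBinding


class Person(BaseModel):
    # populate 'name' from the SPARQL binding ?personName rather than ?name
    name: Annotated[str, SPARQLBinding("personName")]
```

_get_key_from_metadata (utils.py below) scans a field's metadata for such a SPARQLBinding instance and falls back to the field name if none is present.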
diff --git a/rdfproxy/utils/sparql/sparql_templates.py b/rdfproxy/utils/sparql/sparql_templates.py
deleted file mode 100644
index f9b141c..0000000
--- a/rdfproxy/utils/sparql/sparql_templates.py
+++ /dev/null
@@ -1,10 +0,0 @@
-"""SPARQL Query templates for RDFProxy paginations."""
-
-from string import Template
-
-
-ungrouped_pagination_base_query = Template("""
-$query
-limit $limit
-offset $offset
-""")
diff --git a/rdfproxy/utils/sparql/sparql_utils.py b/rdfproxy/utils/sparql_utils.py
similarity index 54%
rename from rdfproxy/utils/sparql/sparql_utils.py
rename to rdfproxy/utils/sparql_utils.py
index 146fabf..a401a8a 100644
--- a/rdfproxy/utils/sparql/sparql_utils.py
+++ b/rdfproxy/utils/sparql_utils.py
@@ -3,40 +3,20 @@
 from collections.abc import Iterator
 from contextlib import contextmanager
 import re
+from string import Template
+from typing import Annotated
+from typing import cast
 
 from SPARQLWrapper import QueryResult, SPARQLWrapper
-from rdfproxy.utils.sparql.sparql_templates import ungrouped_pagination_base_query
-from rdfproxy.utils.utils import get_bindings_from_query_result
 
 
-def remove_query_prefixes(query: str) -> str:
-    """Remove prefix definitions from a SPARQL query.
-
-    Prefix definitions need removing e.g. in injected subqueries.
-    """
-    return re.sub(
-        pattern=r"^prefix.*", repl="", string=query, flags=re.I | re.MULTILINE
-    )
-
-
-def inject_subquery(query: str, subquery: str) -> str:
-    """Inject a subquery into query."""
-
-    def _indent_query(query: str, indent: int = 2) -> str:
-        """Indent a query by n spaces according to indent parameter."""
-        indented_query = "".join(
-            [f"{' ' * indent}{line}\n" for line in query.splitlines()]
-        )
-        return indented_query
-
-    point: int = query.rfind("}")
-    partial_query: str = query[:point]
-
-    _subquery = remove_query_prefixes(subquery)
-    indented_subquery: str = _indent_query(_subquery)
-
-    new_query: str = f"{partial_query} " f"{{{indented_subquery}}}\n}}"
-    return new_query
+ungrouped_pagination_base_query: Annotated[
+    Template, "SPARQL template for query pagination."
+] = Template("""
+$query
+limit $limit
+offset $offset
+""")
 
 
 def replace_query_select_clause(query: str, repl: str) -> str:
@@ -72,26 +52,6 @@ def calculate_offset(page: int, size: int) -> int:
     return size * (page - 1)
 
 
-def construct_grouped_pagination_query(
-    query: str, page: int, size: int, group_by: str
-) -> str:
-    """Dynamically construct a query for grouped pagination.
-
-    Based on the initial query, construct a query with limit/offset according to page/size
-    and with a SELECT clause that distinctly selects the group_by variable;
-    then inject that query into the initial query as a subquery.
-    """
-    _paginated_query = ungrouped_pagination_base_query.substitute(
-        query=query, offset=calculate_offset(page, size), limit=size
-    )
-    subquery = replace_query_select_clause(
-        _paginated_query, f"select distinct ?{group_by}"
-    )
-
-    grouped_pagination_query = inject_subquery(query=query, subquery=subquery)
-    return grouped_pagination_query
-
-
 def construct_grouped_count_query(query: str, group_by) -> str:
     grouped_count_query = replace_query_select_clause(
         query, f"select (count(distinct ?{group_by}) as ?cnt)"
     )
@@ -100,6 +60,28 @@ def construct_grouped_count_query(query: str, group_by) -> str:
     return grouped_count_query
 
 
+def _get_bindings_from_bindings_dict(bindings_dict: dict) -> Iterator[dict]:
+    bindings = map(
+        lambda binding: {k: v["value"] for k, v in binding.items()},
+        bindings_dict["results"]["bindings"],
+    )
+    return bindings
+
+
+def get_bindings_from_query_result(query_result: QueryResult) -> Iterator[dict]:
+    """Extract just the bindings from a SPARQLWrapper.QueryResult."""
+    if (result_format := query_result.requestedFormat) != "json":
+        raise Exception(
+            "Only QueryResult objects with JSON format are currently supported. "
+            f"Received object with requestedFormat '{result_format}'."
+        )
+
+    query_json: dict = cast(dict, query_result.convert())
+    bindings = _get_bindings_from_bindings_dict(query_json)
+
+    return bindings
+
+
 @contextmanager
 def temporary_query_override(sparql_wrapper: SPARQLWrapper):
     """Context manager that allows to contextually overwrite a query in a SPARQLWrapper object."""
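How the inlined template drives pagination — a hypothetical substitution (the query string is invented):

```python
from rdfproxy.utils.sparql_utils import (
    calculate_offset,
    ungrouped_pagination_base_query,
)

query = "select ?s where { ?s ?p ?o . }"

# page 2 at 10 items per page -> offset 10, limit 10
print(ungrouped_pagination_base_query.substitute(
    query=query, offset=calculate_offset(2, 10), limit=10
))
# prints (with surrounding blank lines from the template):
# select ?s where { ?s ?p ?o . }
# limit 10
# offset 10
```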
- """ - _paginated_query = ungrouped_pagination_base_query.substitute( - query=query, offset=calculate_offset(page, size), limit=size - ) - subquery = replace_query_select_clause( - _paginated_query, f"select distinct ?{group_by}" - ) - - grouped_pagination_query = inject_subquery(query=query, subquery=subquery) - return grouped_pagination_query - - def construct_grouped_count_query(query: str, group_by) -> str: grouped_count_query = replace_query_select_clause( query, f"select (count(distinct ?{group_by}) as ?cnt)" @@ -100,6 +60,28 @@ def construct_grouped_count_query(query: str, group_by) -> str: return grouped_count_query +def _get_bindings_from_bindings_dict(bindings_dict: dict) -> Iterator[dict]: + bindings = map( + lambda binding: {k: v["value"] for k, v in binding.items()}, + bindings_dict["results"]["bindings"], + ) + return bindings + + +def get_bindings_from_query_result(query_result: QueryResult) -> Iterator[dict]: + """Extract just the bindings from a SPARQLWrapper.QueryResult.""" + if (result_format := query_result.requestedFormat) != "json": + raise Exception( + "Only QueryResult objects with JSON format are currently supported. " + f"Received object with requestedFormat '{result_format}'." + ) + + query_json: dict = cast(dict, query_result.convert()) + bindings = _get_bindings_from_bindings_dict(query_json) + + return bindings + + @contextmanager def temporary_query_override(sparql_wrapper: SPARQLWrapper): """Context manager that allows to contextually overwrite a query in a SPARQLWrapper object.""" diff --git a/rdfproxy/utils/utils.py b/rdfproxy/utils/utils.py index 979efaa..1eb5c8b 100644 --- a/rdfproxy/utils/utils.py +++ b/rdfproxy/utils/utils.py @@ -1,90 +1,72 @@ """SPARQL/FastAPI utils.""" -from collections.abc import Iterator -from typing import Any +from collections.abc import Callable, Iterable +from typing import Any, get_args, get_origin -from SPARQLWrapper import QueryResult from pydantic import BaseModel from pydantic.fields import FieldInfo -from rdfproxy.utils._types import SPARQLBinding, _TModelInstance +from rdfproxy.utils._exceptions import ( + MissingModelConfigException, + UnboundGroupingKeyException, +) +from rdfproxy.utils._types import SPARQLBinding -def get_bindings_from_query_result(query_result: QueryResult) -> Iterator[dict]: - """Extract just the bindings from a SPARQLWrapper.QueryResult.""" - if (result_format := query_result.requestedFormat) != "json": - raise Exception( - "Only QueryResult objects with JSON format are currently supported. " - f"Received object with requestedFormat '{result_format}'." - ) +def _is_type(obj: type | None, _type: type) -> bool: + """Check if an obj is type _type or a GenericAlias with origin _type.""" + return (obj is _type) or (get_origin(obj) is _type) - query_json: dict = query_result.convert() - bindings = map( - lambda binding: {k: v["value"] for k, v in binding.items()}, - query_json["results"]["bindings"], - ) - - return bindings - - -def instantiate_model_from_kwargs( - model: type[_TModelInstance], **kwargs -) -> _TModelInstance: - """Instantiate a (potentially nested) model from (flat) kwargs. - More a more generic version of this function see upto.init_model_from_kwargs - https://github.com/lu-pl/upto?tab=readme-ov-file#init_model_from_kwargs. 
+def _is_list_type(obj: type | None) -> bool:
+    """Check if obj is a list type."""
+    return _is_type(obj, list)
 
-    Example:
-
-        class SimpleModel(BaseModel):
-            x: int
-            y: int
 
+def _is_list_basemodel_type(obj: type | None) -> bool:
+    """Check if a type is list[pydantic.BaseModel]."""
+    return (get_origin(obj) is list) and all(
+        issubclass(cls, BaseModel) for cls in get_args(obj)
+    )
 
-        class NestedModel(BaseModel):
-            a: str
-            b: SimpleModel
 
+def _collect_values_from_bindings(
+    binding_name: str,
+    bindings: Iterable[dict],
+    predicate: Callable[[Any], bool] = lambda x: x is not None,
+) -> list:
+    """Scan bindings for a key binding_name and collect unique predicate-compliant values.
+
+    Note that element order is important for testing, so a set cast won't do.
+    """
+    values = dict.fromkeys(
+        value
+        for binding in bindings
+        if predicate(value := binding.get(binding_name, None))
+    )
+    return list(values)
 
-        class ComplexModel(BaseModel):
-            p: str
-            q: NestedModel
 
-        model = instantiate_model_from_kwargs(ComplexModel, x=1, y=2, a="a value", p="p value")
-        print(model)  # p='p value' q=NestedModel(a='a value', b=SimpleModel(x=1, y=2))
+def _get_key_from_metadata(v: FieldInfo, *, default: Any) -> str | Any:
+    """Try to get a SPARQLBinding object from a field's metadata attribute.
+
+    Helper for _generate_binding_pairs.
     """
-
-    def _get_key_from_metadata(v: FieldInfo):
-        """Try to get a SPARQLBinding object from a field's metadata attribute.
-
-        Helper for _generate_binding_pairs.
-        """
-        try:
-            value = next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata))
-            return value
-        except StopIteration:
-            return None
-
-    def _generate_binding_pairs(
-        model: type[_TModelInstance], **kwargs
-    ) -> Iterator[tuple[str, Any]]:
-        """Get the bindings needed for model instantation.
-
-        The function traverses model.model_fields
-        and constructs binding pairs by either getting values from kwargs or field defaults.
-        For model fields the recursive clause runs.
-        """
-        for k, v in model.model_fields.items():
-            if isinstance(v.annotation, type(BaseModel)):
-                value = v.annotation(
-                    **dict(_generate_binding_pairs(v.annotation, **kwargs))
-                )
-            else:
-                binding_key = _get_key_from_metadata(v) or k
-                value = kwargs.get(binding_key, v.default)
-
-            yield k, value
-
-    bindings = dict(_generate_binding_pairs(model, **kwargs))
-    return model(**bindings)
+    return next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata), default)
+
+
+def _get_group_by(model: type[BaseModel], kwargs: dict) -> str:
+    """Get the name of a grouping key from a model Config class."""
+    try:
+        group_by = model.model_config["group_by"]  # type: ignore
+    except KeyError as e:
+        raise MissingModelConfigException(
+            "Model config with 'group_by' value required "
+            "for field-based grouping behavior."
+        ) from e
+    else:
+        if group_by not in kwargs.keys():
+            raise UnboundGroupingKeyException(
+                f"Requested grouping key '{group_by}' not in SPARQL binding projection.\n"
+                f"Applicable grouping keys: {', '.join(kwargs.keys())}."
+            )
+        return group_by
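A hypothetical demonstration of the new grouping-key lookup; the models and bindings are invented, and group_by is a custom config entry rather than a standard Pydantic option:

```python
from pydantic import BaseModel, ConfigDict

from rdfproxy.utils._exceptions import MissingModelConfigException
from rdfproxy.utils.utils import _get_group_by


class Author(BaseModel):
    model_config = ConfigDict(group_by="name")

    name: str
    works: list[str]


kwargs = {"name": "Borges", "work": "Ficciones"}
assert _get_group_by(Author, kwargs) == "name"


class Ungrouped(BaseModel):
    works: list[str]


try:
    _get_group_by(Ungrouped, kwargs)  # no 'group_by' config entry
except MissingModelConfigException as e:
    print(e)
```

If the configured grouping key names a binding that is missing from the projection, UnboundGroupingKeyException is raised instead.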