Skip to content

Commit

Permalink
feat: redesign SPARQLModelAdapter
Browse files Browse the repository at this point in the history
Major redesign of the SPARQLModelAdapter class.

Add a code path in ModelBindingsMapper for handling list type
fields (other than list[BaseModel]) that triggers array collection behavior.

Move ModelBindingsMapper into its own module.

Closes #57, #81.
  • Loading branch information
lu-pl committed Sep 19, 2024
1 parent 0394a54 commit afc689b
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 301 deletions.
1 change: 1 addition & 0 deletions rdfproxy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from rdfproxy.adapter import SPARQLModelAdapter # noqa: F401
from rdfproxy.mapper import ModelBindingsMapper # noqa: F401
from rdfproxy.utils._types import SPARQLBinding # noqa: F401
from rdfproxy.utils.models import Page # noqa: F401
195 changes: 36 additions & 159 deletions rdfproxy/adapter.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,34 @@
"""SPARQLModelAdapter class for SPARQL query result set to Pydantic model conversions."""

from collections import defaultdict
from collections.abc import Iterator
import math
from typing import Any, Generic, overload
from typing import Generic

from SPARQLWrapper import JSON, QueryResult, SPARQLWrapper
from rdfproxy.utils._exceptions import (
InterdependentParametersException,
UndefinedBindingException,
)
from SPARQLWrapper import JSON, SPARQLWrapper
from rdfproxy.mapper import ModelBindingsMapper
from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils.models import Page
from rdfproxy.utils.sparql.sparql_templates import ungrouped_pagination_base_query
from rdfproxy.utils.sparql.sparql_utils import (
from rdfproxy.utils.sparql_utils import (
calculate_offset,
construct_count_query,
construct_grouped_count_query,
construct_grouped_pagination_query,
query_with_wrapper,
temporary_query_override,
)
from rdfproxy.utils.utils import (
get_bindings_from_query_result,
instantiate_model_from_kwargs,
ungrouped_pagination_base_query,
)


class SPARQLModelAdapter(Generic[_TModelInstance]):
"""Adapter/Mapper for SPARQL query result set to Pydantic model conversions.
The rdfproxy.SPARQLModelAdapter class allows to run a query against an endpoint,
map a flat SPARQL query result set to a potentially nested Pydantic model and
optionally paginate and/or group the results by a SPARQL binding.
The rdfproxy.SPARQLModelAdapter class allows to run a query against an endpoint
and map a flat SPARQL query result set to a potentially nested Pydantic model.
SPARQLModelAdapter.query returns a Page model object with a default pagination size of 100 results.
SPARQL bindings are implicitly assigned to model fields of the same name,
explicit SPARQL binding to model field allocation is available with typing.Annotated and rdfproxy.SPARQLBinding.
Result grouping is controlled through the model,
i.e. grouping is triggered when a field of list[pydantic.BaseModel] is encountered.
"""

def __init__(
Expand All @@ -45,156 +41,37 @@ def __init__(
SPARQLWrapper(target) if isinstance(target, str) else target
)
self.sparql_wrapper.setReturnFormat(JSON)
self.sparql_wrapper.setQuery(query)

@overload
def query(self) -> list[_TModelInstance]: ...

@overload
def query(
self,
*,
group_by: str,
) -> dict[str, list[_TModelInstance]]: ...

@overload
def query(
self,
*,
page: int,
size: int,
) -> Page[_TModelInstance]: ...

@overload
def query(
self,
*,
page: int,
size: int,
group_by: str,
) -> Page[_TModelInstance]: ...

def query(
self,
*,
page: int | None = None,
size: int | None = None,
group_by: str | None = None,
) -> (
list[_TModelInstance] | dict[str, list[_TModelInstance]] | Page[_TModelInstance]
):
"""Run query against endpoint and map the SPARQL query result set to a Pydantic model.
Optional pagination and/or grouping by a SPARQL binding is avaible by
supplying the group_by and/or page/size parameters.
"""
match page, size, group_by:
case None, None, None:
return self._query_collect_models()
case int(), int(), None:
return self._query_paginate_ungrouped(page=page, size=size)
case None, None, str():
return self._query_group_by(group_by=group_by)
case int(), int(), str():
return self._query_paginate_grouped(
page=page, size=size, group_by=group_by
)
case (None, int(), Any()) | (int(), None, Any()):
raise InterdependentParametersException(
"Parameters 'page' and 'size' are mutually dependent."
)
case _:
raise Exception("This should never happen.")

def _query_generate_model_bindings_mapping(
self, query: str | None = None
) -> Iterator[tuple[_TModelInstance, dict[str, Any]]]:
"""Run query, construct model instances and generate a model-bindings mapping.
The query parameter defaults to the initially defined query and
is run against the endpoint defined in the SPARQLModelAdapter instance.
Note: The coupling of model instances with flat SPARQL results
allows for easier and more efficient grouping operations (see grouping functionality).
"""
if query is None:
query_result: QueryResult = self.sparql_wrapper.query()
else:
with temporary_query_override(self.sparql_wrapper):
self.sparql_wrapper.setQuery(query)
query_result: QueryResult = self.sparql_wrapper.query()

_bindings = get_bindings_from_query_result(query_result)

for bindings in _bindings:
model = instantiate_model_from_kwargs(self._model, **bindings)
yield model, bindings

def _query_collect_models(self, query: str | None = None) -> list[_TModelInstance]:
"""Run query against endpoint and collect model instances."""
return [
model
for model, _ in self._query_generate_model_bindings_mapping(query=query)
]

def _query_group_by(
self, group_by: str, query: str | None = None
) -> dict[str, list[_TModelInstance]]:
"""Run query against endpoint and group results by a SPARQL binding."""
group = defaultdict(list)

for model, bindings in self._query_generate_model_bindings_mapping(query):
try:
key = bindings[group_by]
except KeyError:
raise UndefinedBindingException(
f"SPARQL binding '{group_by}' requested for grouping "
f"not in query projection '{bindings}'."
)

group[str(key)].append(model)

return group

def _get_count(self, query: str) -> int:
"""Construct a count query from the initialized query, run it and return the count result."""
result = query_with_wrapper(query=query, sparql_wrapper=self.sparql_wrapper)
return int(next(result)["cnt"])

def _query_paginate_ungrouped(self, page: int, size: int) -> Page[_TModelInstance]:
"""Run query with pagination according to page and size.
The internal query is dynamically modified according to page (offset)/size (limit)
and run with SPARQLModelAdapter._query_collect_models.
"""
paginated_query = ungrouped_pagination_base_query.substitute(
page: int = 1,
size: int = 100,
) -> Page[_TModelInstance]:
"""Run a query against an endpoint and return a Page model object."""
count_query: str = construct_count_query(self._query)
items_query: str = ungrouped_pagination_base_query.substitute(
query=self._query, offset=calculate_offset(page, size), limit=size
)
count_query = construct_count_query(self._query)

items = self._query_collect_models(query=paginated_query)
total = self._get_count(count_query)
pages = math.ceil(total / size)
items_query_bindings: Iterator[dict] = query_with_wrapper(
query=items_query, sparql_wrapper=self.sparql_wrapper
)

mapper = ModelBindingsMapper(self._model, *items_query_bindings)

items: list[_TModelInstance] = mapper.get_models()
total: int = self._get_count(count_query)
pages: int = math.ceil(total / size)

return Page(items=items, page=page, size=size, total=total, pages=pages)

def _query_paginate_grouped(
self, page: int, size: int, group_by: str
) -> Page[_TModelInstance]:
"""Run query with pagination according to page/size and group result by a SPARQL binding.
def _get_count(self, query: str) -> int:
"""Run a count query and return the count result.
The internal query is dynamically modified according to page (offset)/size (limit)
and run with SPARQLModelAdapter._query_group_by.
Helper for SPARQLModelAdapter.query.
"""
grouped_paginated_query = construct_grouped_pagination_query(
query=self._query, page=page, size=size, group_by=group_by
)
grouped_count_query = construct_grouped_count_query(
query=self._query, group_by=group_by
result: Iterator[dict] = query_with_wrapper(
query=query, sparql_wrapper=self.sparql_wrapper
)

items = self._query_group_by(group_by=group_by, query=grouped_paginated_query)
total = self._get_count(grouped_count_query)
pages = math.ceil(total / size)

return Page(items=items, page=page, size=size, total=total, pages=pages)
return int(next(result)["cnt"])
87 changes: 87 additions & 0 deletions rdfproxy/mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""ModelBindingsMapper: Functionality for mapping binding maps to a Pydantic model."""

from collections.abc import Iterator
from typing import Any, Generic, get_args

from pydantic import BaseModel
from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils.utils import (
_collect_values_from_bindings,
_get_group_by,
_get_key_from_metadata,
_list_basemodel_p,
_list_p,
)


class ModelBindingsMapper(Generic[_TModelInstance]):
"""Utility class for mapping flat bindings to a (potentially nested) Pydantic model."""

def __init__(self, model: type[_TModelInstance], *bindings: dict):
self.model = model
self.bindings = bindings
self._contexts = []

def get_models(self) -> list[_TModelInstance]:
"""Generate a list of (potentially nested) Pydantic models based on (flat) bindings."""
return self._get_unique_models(self.model, self.bindings)

def _get_unique_models(self, model, bindings):
"""Call the mapping logic and collect unique and non-empty models."""
models = []
for _bindings in bindings:
_model = model(**dict(self._generate_binding_pairs(model, **_bindings)))

if any(_model.model_dump().values()) and (_model not in models):
models.append(_model)

return models

def _get_group_by(self, model, kwargs) -> str:
"""Get the group_by value from a model and register it in self._contexts."""
group_by: str = _get_group_by(model, kwargs)

if group_by not in self._contexts:
self._contexts.append(group_by)

return group_by

def _generate_binding_pairs(
self,
model: type[BaseModel],
**kwargs,
) -> Iterator[tuple[str, Any]]:
"""Generate an Iterator[tuple] projection of the bindings needed for model instantation."""
for k, v in model.model_fields.items():
if _list_basemodel_p(v.annotation):
group_by: str = self._get_group_by(model, kwargs)
group_model, *_ = get_args(v.annotation)

applicable_bindings = filter(
lambda x: (x[group_by] == kwargs[group_by])
and (x[self._contexts[0]] == kwargs[self._contexts[0]]),
self.bindings,
)

value = self._get_unique_models(group_model, applicable_bindings)

elif _list_p(v.annotation):
group_by: str = self._get_group_by(model, kwargs)
applicable_bindings = filter(
lambda x: x[group_by] == kwargs[group_by],
self.bindings,
)

binding_key: str = _get_key_from_metadata(v) or k
value = _collect_values_from_bindings(binding_key, applicable_bindings)

elif isinstance(v.annotation, type(BaseModel)):
nested_model = v.annotation
value = nested_model(
**dict(self._generate_binding_pairs(nested_model, **kwargs))
)
else:
binding_key: str = _get_key_from_metadata(v) or k
value = kwargs.get(binding_key, v.default)

yield k, value
8 changes: 8 additions & 0 deletions rdfproxy/utils/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,11 @@ class UndefinedBindingException(KeyError):

class InterdependentParametersException(Exception):
"""Exceptiono for indicating that two or more parameters are interdependent."""


class MissingModelConfigException(Exception):
"""Exception for indicating that an expected Config class is missing in a Pydantic model definition."""


class UnboundGroupingKeyException(Exception):
"""Exception for indicating that no SPARQL binding corresponds to the requested grouping key."""
11 changes: 1 addition & 10 deletions rdfproxy/utils/_types.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,13 @@
"""Type definitions for rdfproxy."""

from collections.abc import Iterable
from typing import Protocol, TypeVar, runtime_checkable
from typing import TypeVar

from SPARQLWrapper import QueryResult
from pydantic import BaseModel


_TModelInstance = TypeVar("_TModelInstance", bound=BaseModel)


@runtime_checkable
class _TModelConstructorCallable(Protocol[_TModelInstance]):
"""Callback protocol for model constructor callables."""

def __call__(self, query_result: QueryResult) -> Iterable[_TModelInstance]: ...


class SPARQLBinding(str):
"""SPARQLBinding type for explicit SPARQL binding to model field allocation.
Expand Down
10 changes: 0 additions & 10 deletions rdfproxy/utils/sparql/sparql_templates.py

This file was deleted.

Loading

0 comments on commit afc689b

Please sign in to comment.