Skip to content

Commit

Permalink
feat: QueryConstructor rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
lu-pl committed Dec 19, 2024
1 parent 2aa9604 commit 34f3ddd
Show file tree
Hide file tree
Showing 7 changed files with 380 additions and 260 deletions.
33 changes: 12 additions & 21 deletions rdfproxy/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@
import math
from typing import Generic

from rdfproxy.constructor import QueryConstructor
from rdfproxy.mapper import ModelBindingsMapper
from rdfproxy.sparql_strategies import HttpxStrategy, SPARQLStrategy
from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils.models import Page, QueryParameters
from rdfproxy.utils.sparql_utils import (
calculate_offset,
construct_count_query,
construct_items_query,
)


class SPARQLModelAdapter(Generic[_TModelInstance]):
Expand All @@ -24,10 +20,12 @@ class SPARQLModelAdapter(Generic[_TModelInstance]):
SPARQLModelAdapter.query returns a Page model object with a default pagination size of 100 results.
SPARQL bindings are implicitly assigned to model fields of the same name,
explicit SPARQL binding to model field allocation is available with typing.Annotated and rdfproxy.SPARQLBinding.
explicit SPARQL binding to model field allocation is available with rdfproxy.SPARQLBinding.
Result grouping is controlled through the model,
i.e. grouping is triggered when a field of list[pydantic.BaseModel] is encountered.
See https://github.com/acdh-oeaw/rdfproxy/tree/main/examples for examples.
"""

def __init__(
Expand All @@ -44,20 +42,21 @@ def __init__(

def query(self, query_parameters: QueryParameters) -> Page[_TModelInstance]:
"""Run a query against an endpoint and return a Page model object."""
count_query: str = construct_count_query(query=self._query, model=self._model)
items_query: str = construct_items_query(
query_constructor = QueryConstructor(
query=self._query,
query_parameters=query_parameters,
model=self._model,
limit=query_parameters.size,
offset=calculate_offset(query_parameters.page, query_parameters.size),
)

items_query_bindings: Iterator[dict] = self.sparql_strategy.query(items_query)
count_query = query_constructor.get_count_query()
items_query = query_constructor.get_items_query()

items_query_bindings: Iterator[dict] = self.sparql_strategy.query(items_query)
mapper = ModelBindingsMapper(self._model, *items_query_bindings)

items: list[_TModelInstance] = mapper.get_models()
total: int = self._get_count(count_query)

count_query_bindings: Iterator[dict] = self.sparql_strategy.query(count_query)
total: int = int(next(count_query_bindings)["cnt"])
pages: int = math.ceil(total / query_parameters.size)

return Page(
Expand All @@ -67,11 +66,3 @@ def query(self, query_parameters: QueryParameters) -> Page[_TModelInstance]:
total=total,
pages=pages,
)

def _get_count(self, query: str) -> int:
"""Run a count query and return the count result.
Helper for SPARQLModelAdapter.query.
"""
result: Iterator[dict] = self.sparql_strategy.query(query)
return int(next(result)["cnt"])
124 changes: 124 additions & 0 deletions rdfproxy/constructor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils.models import QueryParameters
from rdfproxy.utils.sparql_utils import (
add_solution_modifier,
get_query_projection,
inject_into_query,
remove_sparql_prefixes,
replace_query_select_clause,
)
from rdfproxy.utils.utils import (
FieldsBindingsMap,
QueryConstructorComponent as component,
compose_left,
)


class QueryConstructor:
"""The class encapsulates dynamic SPARQL query modification logic
for implementing purely SPARQL-based, deterministic pagination.
Public methods get_items_query and get_count_query are used in rdfproxy.SPARQLModelAdapter
to construct queries for retrieving arguments for Page object instantiation.
"""

def __init__(
self,
query: str,
query_parameters: QueryParameters,
model: type[_TModelInstance],
) -> None:
self.query = query
self.query_parameters = query_parameters
self.model = model

self.bindings_map = FieldsBindingsMap(model)
self.group_by: str | None = self.bindings_map.get(
model.model_config.get("group_by")
)

def get_items_query(self) -> str:
"""Construct a SPARQL items query for use in rdfproxy.SPARQLModelAdapter."""
if self.group_by is None:
return self._get_ungrouped_items_query()
return self._get_grouped_items_query()

def get_count_query(self) -> str:
"""Construct a SPARQL count query for use in rdfproxy.SPARQLModelAdapter"""
if self.group_by is None:
select_clause = "select (count(*) as ?cnt)"
else:
select_clause = f"select (count(distinct ?{self.group_by}) as ?cnt)"

return replace_query_select_clause(self.query, select_clause)

@staticmethod
def _calculate_offset(page: int, size: int) -> int:
"""Calculate the offset value for paginated SPARQL templates."""
match page:
case 1:
return 0
case 2:
return size
case _:
return size * (page - 1)

def _get_grouped_items_query(self) -> str:
"""Construct a SPARQL items query for grouped models."""
filter_clause: str | None = self._compute_filter_clause()
select_clause: str = self._compute_select_clause()
order_by_value: str = self._compute_order_by_value()
limit, offset = self._compute_limit_offset()

subquery = compose_left(
remove_sparql_prefixes,
component(replace_query_select_clause, repl=select_clause),
component(inject_into_query, injectant=filter_clause),
component(
add_solution_modifier,
order_by=order_by_value,
limit=limit,
offset=offset,
),
)(self.query)

return inject_into_query(self.query, subquery)

def _get_ungrouped_items_query(self) -> str:
"""Construct a SPARQL items query for ungrouped models."""
filter_clause: str | None = self._compute_filter_clause()
order_by_value: str = self._compute_order_by_value()
limit, offset = self._compute_limit_offset()

return compose_left(
component(inject_into_query, injectant=filter_clause),
component(
add_solution_modifier,
order_by=order_by_value,
limit=limit,
offset=offset,
),
)(self.query)

def _compute_limit_offset(self) -> tuple[int, int]:
"""Calculate limit and offset values for SPARQL-based pagination."""
limit = self.query_parameters.size
offset = self._calculate_offset(
self.query_parameters.page, self.query_parameters.size
)

return limit, offset

def _compute_filter_clause(self) -> str | None:
"""Stub: Always None for now."""
return None

def _compute_select_clause(self):
"""Stub: Static SELECT clause for now."""
return f"select distinct ?{self.group_by}"

def _compute_order_by_value(self):
"""Stub: Only basic logic for now."""
if self.group_by is None:
return get_query_projection(self.query)[0]
return f"{self.group_by}"
3 changes: 1 addition & 2 deletions rdfproxy/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pydantic import BaseModel
from rdfproxy.utils._types import ModelBoolPredicate, _TModelInstance
from rdfproxy.utils.utils import (
from rdfproxy.utils.mapper_utils import (
_collect_values_from_bindings,
_get_group_by,
_get_key_from_metadata,
Expand Down Expand Up @@ -65,7 +65,6 @@ def _generate_binding_pairs(
and (x[self._contexts[0]] == kwargs[self._contexts[0]]),
self.bindings,
)

value = self._get_unique_models(group_model, applicable_bindings)

elif _is_list_type(v.annotation):
Expand Down
2 changes: 1 addition & 1 deletion rdfproxy/sparql_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, endpoint: str):

@abc.abstractmethod
def query(self, sparql_query: str) -> Iterator[dict[str, str]]:
raise NotImplementedError
raise NotImplementedError # pragma: no cover

@staticmethod
def _get_bindings_from_bindings_dict(bindings_dict: dict) -> Iterator[dict]:
Expand Down
127 changes: 127 additions & 0 deletions rdfproxy/utils/mapper_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from collections.abc import Callable, Iterable
from typing import Any, TypeGuard, get_args, get_origin

from pydantic import BaseModel
from pydantic.fields import FieldInfo
from rdfproxy.utils._exceptions import (
InvalidGroupingKeyException,
MissingModelConfigException,
)
from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue


def _is_type(obj: type | None, _type: type) -> bool:
"""Check if an obj is type _type or a GenericAlias with origin _type."""
return (obj is _type) or (get_origin(obj) is _type)


def _is_list_type(obj: type | None) -> bool:
"""Check if obj is a list type."""
return _is_type(obj, list)


def _is_list_basemodel_type(obj: type | None) -> bool:
"""Check if a type is list[pydantic.BaseModel]."""
return (get_origin(obj) is list) and all(
issubclass(cls, BaseModel) for cls in get_args(obj)
)


def _collect_values_from_bindings(
binding_name: str,
bindings: Iterable[dict],
predicate: Callable[[Any], bool] = lambda x: x is not None,
) -> list:
"""Scan bindings for a key binding_name and collect unique predicate-compliant values.
Note that element order is important for testing, so a set cast won't do.
"""
values = dict.fromkeys(
value
for binding in bindings
if predicate(value := binding.get(binding_name, None))
)
return list(values)


def _get_key_from_metadata(v: FieldInfo, *, default: Any) -> str | Any:
"""Try to get a SPARQLBinding object from a field's metadata attribute.
Helper for _generate_binding_pairs.
"""
return next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata), default)


def _get_applicable_grouping_keys(model: type[_TModelInstance]) -> list[str]:
return [k for k, v in model.model_fields.items() if not _is_list_type(v.annotation)]


def _get_group_by(model: type[_TModelInstance]) -> str:
"""Get the name of a grouping key from a model Config class."""
try:
group_by = model.model_config["group_by"] # type: ignore
except KeyError as e:
raise MissingModelConfigException(
"Model config with 'group_by' value required "
"for field-based grouping behavior."
) from e
else:
applicable_keys = _get_applicable_grouping_keys(model=model)

if group_by not in applicable_keys:
raise InvalidGroupingKeyException(
f"Invalid grouping key '{group_by}'. "
f"Applicable grouping keys: {', '.join(applicable_keys)}."
)

if meta := model.model_fields[group_by].metadata:
if binding := next(
filter(lambda entry: isinstance(entry, SPARQLBinding), meta), None
):
return binding
return group_by


def default_model_bool_predicate(model: BaseModel) -> bool:
"""Default predicate for determining model truthiness.
Adheres to rdfproxy.utils._types.ModelBoolPredicate.
"""
return any(dict(model).values())


def _is_iterable_of_str(iterable: Iterable) -> TypeGuard[Iterable[str]]:
return (not isinstance(iterable, str)) and all(
map(lambda i: isinstance(i, str), iterable)
)


def _get_model_bool_predicate_from_config_value(
model_bool_value: _TModelBoolValue,
) -> ModelBoolPredicate:
"""Get a model_bool predicate function given the value of the model_bool config setting."""
match model_bool_value:
case ModelBoolPredicate():
return model_bool_value
case str():
return lambda model: bool(dict(model)[model_bool_value])
case model_bool_value if _is_iterable_of_str(model_bool_value):
return lambda model: all(map(lambda k: dict(model)[k], model_bool_value))
case _: # pragma: no cover
raise TypeError(
"Argument for 'model_bool' must be of type ModelBoolPredicate | str | Iterable[str].\n"
f"Received {type(model_bool_value)}"
)


def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate:
"""Get the applicable model_bool predicate function given a model."""
if (model_bool_value := model.model_config.get("model_bool", None)) is None:
model_bool_predicate = default_model_bool_predicate
else:
model_bool_predicate = _get_model_bool_predicate_from_config_value(
model_bool_value
)

return model_bool_predicate
Loading

0 comments on commit 34f3ddd

Please sign in to comment.