-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: rewrite query construction functionality
The change introduces a QueryConstructor class that encapsulates all SPARQL query construction functionality. This also leads to a significant cleanup of rdfproxy.utils.sparql_utils module (utils in general) and the SPARQLModelAdapter class. Query result ordering for ungrouped models is now implemented to default to the first binding of the projection as ORDER BY value. This might still be discussed in the future, but the decision seems reasonable at this point. Closes #128. Closes #134. Closes #168.
- Loading branch information
Showing
7 changed files
with
384 additions
and
261 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
from rdfproxy.utils._types import _TModelInstance | ||
from rdfproxy.utils.models import QueryParameters | ||
from rdfproxy.utils.sparql_utils import ( | ||
add_solution_modifier, | ||
get_query_projection, | ||
inject_into_query, | ||
remove_sparql_prefixes, | ||
replace_query_select_clause, | ||
) | ||
from rdfproxy.utils.utils import ( | ||
FieldsBindingsMap, | ||
QueryConstructorComponent as component, | ||
compose_left, | ||
) | ||
|
||
|
||
class QueryConstructor: | ||
"""The class encapsulates dynamic SPARQL query modification logic | ||
for implementing purely SPARQL-based, deterministic pagination. | ||
Public methods get_items_query and get_count_query are used in rdfproxy.SPARQLModelAdapter | ||
to construct queries for retrieving arguments for Page object instantiation. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
query: str, | ||
query_parameters: QueryParameters, | ||
model: type[_TModelInstance], | ||
) -> None: | ||
self.query = query | ||
self.query_parameters = query_parameters | ||
self.model = model | ||
|
||
self.bindings_map = FieldsBindingsMap(model) | ||
self.group_by: str | None = self.bindings_map.get( | ||
model.model_config.get("group_by") | ||
) | ||
|
||
def get_items_query(self) -> str: | ||
"""Construct a SPARQL items query for use in rdfproxy.SPARQLModelAdapter.""" | ||
if self.group_by is None: | ||
return self._get_ungrouped_items_query() | ||
return self._get_grouped_items_query() | ||
|
||
def get_count_query(self) -> str: | ||
"""Construct a SPARQL count query for use in rdfproxy.SPARQLModelAdapter""" | ||
if self.group_by is None: | ||
select_clause = "select (count(*) as ?cnt)" | ||
else: | ||
select_clause = f"select (count(distinct ?{self.group_by}) as ?cnt)" | ||
|
||
return replace_query_select_clause(self.query, select_clause) | ||
|
||
@staticmethod | ||
def _calculate_offset(page: int, size: int) -> int: | ||
"""Calculate the offset value for paginated SPARQL templates.""" | ||
match page: | ||
case 1: | ||
return 0 | ||
case 2: | ||
return size | ||
case _: | ||
return size * (page - 1) | ||
|
||
def _get_grouped_items_query(self) -> str: | ||
"""Construct a SPARQL items query for grouped models.""" | ||
filter_clause: str | None = self._compute_filter_clause() | ||
select_clause: str = self._compute_select_clause() | ||
order_by_value: str = self._compute_order_by_value() | ||
limit, offset = self._compute_limit_offset() | ||
|
||
subquery = compose_left( | ||
remove_sparql_prefixes, | ||
component(replace_query_select_clause, repl=select_clause), | ||
component(inject_into_query, injectant=filter_clause), | ||
component( | ||
add_solution_modifier, | ||
order_by=order_by_value, | ||
limit=limit, | ||
offset=offset, | ||
), | ||
)(self.query) | ||
|
||
return inject_into_query(self.query, subquery) | ||
|
||
def _get_ungrouped_items_query(self) -> str: | ||
"""Construct a SPARQL items query for ungrouped models.""" | ||
filter_clause: str | None = self._compute_filter_clause() | ||
order_by_value: str = self._compute_order_by_value() | ||
limit, offset = self._compute_limit_offset() | ||
|
||
return compose_left( | ||
component(inject_into_query, injectant=filter_clause), | ||
component( | ||
add_solution_modifier, | ||
order_by=order_by_value, | ||
limit=limit, | ||
offset=offset, | ||
), | ||
)(self.query) | ||
|
||
def _compute_limit_offset(self) -> tuple[int, int]: | ||
"""Calculate limit and offset values for SPARQL-based pagination.""" | ||
limit = self.query_parameters.size | ||
offset = self._calculate_offset( | ||
self.query_parameters.page, self.query_parameters.size | ||
) | ||
|
||
return limit, offset | ||
|
||
def _compute_filter_clause(self) -> str | None: | ||
"""Stub: Always None for now.""" | ||
return None | ||
|
||
def _compute_select_clause(self): | ||
"""Stub: Static SELECT clause for now.""" | ||
return f"select distinct ?{self.group_by}" | ||
|
||
def _compute_order_by_value(self): | ||
"""Stub: Only basic logic for now.""" | ||
if self.group_by is None: | ||
return get_query_projection(self.query)[0] | ||
return f"{self.group_by}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from collections.abc import Callable, Iterable | ||
from typing import Any, TypeGuard, get_args, get_origin | ||
|
||
from pydantic import BaseModel | ||
from pydantic.fields import FieldInfo | ||
from rdfproxy.utils._exceptions import ( | ||
InvalidGroupingKeyException, | ||
MissingModelConfigException, | ||
) | ||
from rdfproxy.utils._types import _TModelInstance | ||
from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue | ||
|
||
|
||
def _is_type(obj: type | None, _type: type) -> bool: | ||
"""Check if an obj is type _type or a GenericAlias with origin _type.""" | ||
return (obj is _type) or (get_origin(obj) is _type) | ||
|
||
|
||
def _is_list_type(obj: type | None) -> bool: | ||
"""Check if obj is a list type.""" | ||
return _is_type(obj, list) | ||
|
||
|
||
def _is_list_basemodel_type(obj: type | None) -> bool: | ||
"""Check if a type is list[pydantic.BaseModel].""" | ||
return (get_origin(obj) is list) and all( | ||
issubclass(cls, BaseModel) for cls in get_args(obj) | ||
) | ||
|
||
|
||
def _collect_values_from_bindings( | ||
binding_name: str, | ||
bindings: Iterable[dict], | ||
predicate: Callable[[Any], bool] = lambda x: x is not None, | ||
) -> list: | ||
"""Scan bindings for a key binding_name and collect unique predicate-compliant values. | ||
Note that element order is important for testing, so a set cast won't do. | ||
""" | ||
values = dict.fromkeys( | ||
value | ||
for binding in bindings | ||
if predicate(value := binding.get(binding_name, None)) | ||
) | ||
return list(values) | ||
|
||
|
||
def _get_key_from_metadata(v: FieldInfo, *, default: Any) -> str | Any: | ||
"""Try to get a SPARQLBinding object from a field's metadata attribute. | ||
Helper for _generate_binding_pairs. | ||
""" | ||
return next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata), default) | ||
|
||
|
||
def _get_applicable_grouping_keys(model: type[_TModelInstance]) -> list[str]: | ||
return [k for k, v in model.model_fields.items() if not _is_list_type(v.annotation)] | ||
|
||
|
||
def _get_group_by(model: type[_TModelInstance]) -> str: | ||
"""Get the name of a grouping key from a model Config class.""" | ||
try: | ||
group_by = model.model_config["group_by"] # type: ignore | ||
except KeyError as e: | ||
raise MissingModelConfigException( | ||
"Model config with 'group_by' value required " | ||
"for field-based grouping behavior." | ||
) from e | ||
else: | ||
applicable_keys = _get_applicable_grouping_keys(model=model) | ||
|
||
if group_by not in applicable_keys: | ||
raise InvalidGroupingKeyException( | ||
f"Invalid grouping key '{group_by}'. " | ||
f"Applicable grouping keys: {', '.join(applicable_keys)}." | ||
) | ||
|
||
if meta := model.model_fields[group_by].metadata: | ||
if binding := next( | ||
filter(lambda entry: isinstance(entry, SPARQLBinding), meta), None | ||
): | ||
return binding | ||
return group_by | ||
|
||
|
||
def default_model_bool_predicate(model: BaseModel) -> bool: | ||
"""Default predicate for determining model truthiness. | ||
Adheres to rdfproxy.utils._types.ModelBoolPredicate. | ||
""" | ||
return any(dict(model).values()) | ||
|
||
|
||
def _is_iterable_of_str(iterable: Iterable) -> TypeGuard[Iterable[str]]: | ||
return (not isinstance(iterable, str)) and all( | ||
map(lambda i: isinstance(i, str), iterable) | ||
) | ||
|
||
|
||
def _get_model_bool_predicate_from_config_value( | ||
model_bool_value: _TModelBoolValue, | ||
) -> ModelBoolPredicate: | ||
"""Get a model_bool predicate function given the value of the model_bool config setting.""" | ||
match model_bool_value: | ||
case ModelBoolPredicate(): | ||
return model_bool_value | ||
case str(): | ||
return lambda model: bool(dict(model)[model_bool_value]) | ||
case model_bool_value if _is_iterable_of_str(model_bool_value): | ||
return lambda model: all(map(lambda k: dict(model)[k], model_bool_value)) | ||
case _: # pragma: no cover | ||
raise TypeError( | ||
"Argument for 'model_bool' must be of type ModelBoolPredicate | str | Iterable[str].\n" | ||
f"Received {type(model_bool_value)}" | ||
) | ||
|
||
|
||
def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate: | ||
"""Get the applicable model_bool predicate function given a model.""" | ||
if (model_bool_value := model.model_config.get("model_bool", None)) is None: | ||
model_bool_predicate = default_model_bool_predicate | ||
else: | ||
model_bool_predicate = _get_model_bool_predicate_from_config_value( | ||
model_bool_value | ||
) | ||
|
||
return model_bool_predicate |
Oops, something went wrong.