diff --git a/rdfproxy/utils/sparql_utils.py b/rdfproxy/utils/sparql_utils.py index 62bcc9e..6a4e950 100644 --- a/rdfproxy/utils/sparql_utils.py +++ b/rdfproxy/utils/sparql_utils.py @@ -1,19 +1,80 @@ """Functionality for dynamic SPARQL query modifcation.""" -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from contextlib import contextmanager from functools import partial +from itertools import chain import re -from typing import cast +from typing import cast, overload from SPARQLWrapper import QueryResult, SPARQLWrapper +from rdflib import Variable +from rdflib.plugins.sparql.parser import parseQuery +from rdflib.plugins.sparql.parserutils import CompValue, ParseResults from rdfproxy.utils._exceptions import QueryConstructionException from rdfproxy.utils._types import ItemsQueryConstructor, _TModelInstance +from rdfproxy.utils.utils import _is_iterable_of_str -def construct_ungrouped_pagination_query(query: str, limit: int, offset: int) -> str: +@overload +def _compvalue_to_dict(comp_value: dict | CompValue) -> dict: ... + + +@overload +def _compvalue_to_dict(comp_value: list | ParseResults) -> list: ... + + +def _compvalue_to_dict(comp_value: CompValue): + """Convert a CompValue parsing object into a Python dict/list representation.""" + if isinstance(comp_value, dict): + return {key: _compvalue_to_dict(value) for key, value in comp_value.items()} + elif isinstance(comp_value, list | ParseResults): + return [_compvalue_to_dict(item) for item in comp_value] + else: + return comp_value + + +def get_query_projection(query: str) -> list[str]: + """Parse a SPARQL SELECT query and extract the ordered bindings projection. + + The first case handles explicit/literal binding projections. + The second case handles implicit/* binding projections. + The third case handles implicit/* binding projections with VALUES. + """ + _parse_result: CompValue = parseQuery(query)[1] + parsed_query: dict = _compvalue_to_dict(_parse_result) + + match parsed_query: + case {"projection": projection}: + return [i["var"] for i in projection] + case {"where": {"part": [{"triples": triples}]}}: + projection = dict.fromkeys( + i for i in chain.from_iterable(triples) if isinstance(i, Variable) + ) + return list(projection) + case {"where": {"part": [{"var": var}]}}: + return var + case _: + raise Exception("Unable to obtain query projection.") + + +def construct_ungrouped_pagination_query( + query: str, limit: int, offset: int, order_by: str | Iterable[str] | None = None +) -> str: """Construct an ungrouped pagination query.""" - return f"{query} limit {limit} offset {offset}" + match order_by: + case None: + order_by_variables = get_query_projection(query) + case str(): + order_by_variables = [order_by] + case order_by if _is_iterable_of_str(order_by): + order_by_variables = order_by + case _: + raise TypeError( + "order_by value must be of type str | Iterable[str] | None." + ) + + return f"{query} order by {' '.join(map(lambda x: f'?{x}', order_by_variables))} limit {limit} offset {offset}" def replace_query_select_clause(query: str, repl: str) -> str: