diff --git a/rdfproxy/utils/sparql_utils.py b/rdfproxy/utils/sparql_utils.py index 0fbf613..7090e85 100644 --- a/rdfproxy/utils/sparql_utils.py +++ b/rdfproxy/utils/sparql_utils.py @@ -3,17 +3,67 @@ from collections.abc import Iterator from contextlib import contextmanager from functools import partial +from itertools import chain import re from textwrap import indent -from typing import cast +from typing import cast, overload from SPARQLWrapper import QueryResult, SPARQLWrapper +from rdflib import Variable +from rdflib.plugins.sparql.parser import parseQuery +from rdflib.plugins.sparql.parserutils import CompValue, ParseResults from rdfproxy.utils._types import ItemsQueryConstructor, _TModelInstance -def construct_ungrouped_pagination_query(query: str, limit: int, offset: int) -> str: +@overload +def _compvalue_to_dict(comp_value: dict | CompValue) -> dict: ... + + +@overload +def _compvalue_to_dict(comp_value: list | ParseResults) -> list: ... + + +def _compvalue_to_dict(comp_value: CompValue): + """Convert a CompValue parsing object into a Python dict/list representation.""" + if isinstance(comp_value, dict): + return {key: _compvalue_to_dict(value) for key, value in comp_value.items()} + elif isinstance(comp_value, list | ParseResults): + return [_compvalue_to_dict(item) for item in comp_value] + else: + return comp_value + + +def get_query_projection(query: str) -> list[str]: + """Parse a SPARQL SELECT query and extract the ordered bindings projection. + + The first case handles explicit/literal binding projections. + The second case handles implicit/* binding projections. + The third case handles implicit/* binding projections with VALUES. + """ + _parse_result: CompValue = parseQuery(query)[1] + parsed_query: dict = _compvalue_to_dict(_parse_result) + + match parsed_query: + case {"projection": projection}: + return [i["var"] for i in projection] + case {"where": {"part": [{"triples": triples}]}}: + projection = dict.fromkeys( + i for i in chain.from_iterable(triples) if isinstance(i, Variable) + ) + return list(projection) + case {"where": {"part": [{"var": var}]}}: + return var + case _: + raise Exception("Unable to obtain query projection.") + + +def construct_ungrouped_pagination_query( + query: str, limit: int, offset: int, order_by: str | None = None +) -> str: """Construct an ungrouped pagination query.""" - return f"{query} limit {limit} offset {offset}" + order_by = get_query_projection(query)[0] if order_by is None else order_by + + return f"{query} order by ?{order_by} limit {limit} offset {offset}" def replace_query_select_clause(query: str, repl: str) -> str: