class ItemsQueryConstructor(Protocol):
    """Structural type for callables that build an items query.

    Any callable accepting ``query``, ``limit`` and ``offset`` and returning
    the resulting query string satisfies this protocol.
    """

    def __call__(self, query: str, limit: int, offset: int) -> str:
        """Return the query string built from ``query``, ``limit`` and ``offset``."""
def construct_ungrouped_pagination_query(query: str, limit: int, offset: int) -> str:
    """Construct an ungrouped pagination query.

    Args:
        query: The SPARQL query to paginate.
        limit: Value for the ``limit`` solution modifier.
        offset: Value for the ``offset`` solution modifier.

    Returns:
        ``query`` followed by ``limit <limit> offset <offset>``.
    """
    return f"{query} limit {limit} offset {offset}"


def inject_subquery(
    query: str, subquery: str, indent_depth: int = 4, indent_char: str = " "
) -> str:
    """Inject a subquery right before the final closing brace of a SPARQL query.

    Also apply some basic indentation to the injected ``{ ... }`` group.

    Args:
        query: The outer query; the subquery is placed before its last ``}``.
        subquery: The query text to wrap in braces and inject.
        indent_depth: Number of ``indent_char`` repetitions per indent level.
        indent_char: The character (or string) used for indentation.

    Returns:
        The query with the subquery group injected. If ``query`` does not end
        with ``}`` (ignoring trailing whitespace), it is returned unchanged.
    """
    indent_value = indent_char * indent_depth
    indented_subquery = indent(f"\n{subquery}\n", indent_value)
    indented_subclause = indent(f"\n{{{indented_subquery}}}", indent_value)

    stripped_query = query.rstrip()
    if not stripped_query.endswith("}"):
        return query

    # Plain string slicing instead of re.sub: the subquery must be inserted
    # literally (a regex replacement template would interpret backslashes),
    # and only the final brace may be removed, not the rest of its line.
    trailing_whitespace = query[len(stripped_query):]
    head = stripped_query[:-1].rstrip(" \t")
    return f"{head}{indented_subclause}\n}}{trailing_whitespace}"


def construct_grouped_pagination_query(
    query: str, group_by_value: str, limit: int, offset: int
) -> str:
    """Construct a grouped pagination query.

    A subquery selecting the distinct ``?<group_by_value>`` bindings is derived
    from ``query``, limited and offset, and injected back into ``query``.

    Args:
        query: The SPARQL query to paginate.
        group_by_value: Name (without ``?``) of the variable to group by.
        limit: Value for the subquery's ``limit`` solution modifier.
        offset: Value for the subquery's ``offset`` solution modifier.

    Returns:
        ``query`` with the limiting subquery injected.
    """
    _subquery_base: str = replace_query_select_clause(
        query=query, repl=f"select distinct ?{group_by_value}"
    )
    subquery: str = construct_ungrouped_pagination_query(
        query=_subquery_base, limit=limit, offset=offset
    )
    return inject_subquery(query=query, subquery=subquery)


def get_items_query_constructor(
    model: type[_TModelInstance],
) -> ItemsQueryConstructor:
    """Get the applicable query constructor function given a model class.

    Args:
        model: The model class; its ``model_config`` is checked for ``group_by``.

    Returns:
        ``construct_ungrouped_pagination_query`` if no ``group_by`` value is
        configured, else ``construct_grouped_pagination_query`` with
        ``group_by_value`` pre-bound.
    """
    # Look the value up first and compare it on its own; wrapping the lookup
    # in a tuple would make the ``is None`` check always false.
    group_by_value = model.model_config.get("group_by")
    if group_by_value is None:
        return construct_ungrouped_pagination_query
    return partial(construct_grouped_pagination_query, group_by_value=group_by_value)


def construct_items_query(
    query: str, model: type[_TModelInstance], limit: int, offset: int
) -> str:
    """Construct the items query (grouped or ungrouped) for a model.

    Args:
        query: The SPARQL query to paginate.
        model: The model class used to pick the query constructor.
        limit: Page size passed on to the selected constructor.
        offset: Offset passed on to the selected constructor.

    Returns:
        The paginated items query string.
    """
    items_query_constructor: ItemsQueryConstructor = get_items_query_constructor(
        model=model
    )
    return items_query_constructor(query=query, limit=limit, offset=offset)
import pytest

from pydantic import BaseModel, ConfigDict
from rdfproxy import Page, SPARQLModelAdapter


# Inline VALUES data: three distinct ?parent bindings ('x', 'y', 'z').
query = """
select ?parent ?child ?name
where {
  values (?parent ?child ?name) {
    ('x' 'c' 'foo')
    ('y' 'd' UNDEF)
    ('y' 'e' UNDEF)
    ('z' UNDEF UNDEF)
  }
}
"""


class Child(BaseModel):
    """Nested model holding an optional ``name``."""

    name: str | None = None


class Parent(BaseModel):
    """Top-level model, configured with ``group_by="parent"``."""

    model_config = ConfigDict(group_by="parent")

    parent: str
    children: list[Child]


parent_adapter = SPARQLModelAdapter(
    target="https://graphdb.r11.eu/repositories/RELEVEN",
    query=query,
    model=Parent,
)


class AdapterParameter(NamedTuple):
    """One parametrized case: adapter, query kwargs and the expected page."""

    adapter: SPARQLModelAdapter
    query_parameters: dict[str, Any]
    expected: Page


# Every expected page in this module reports the same total.
_EXPECTED_TOTAL = 3


def _grouped_case(
    *, page: int, size: int, pages: int, items: list[dict[str, Any]]
) -> AdapterParameter:
    """Build an AdapterParameter for ``parent_adapter`` with the given page data."""
    return AdapterParameter(
        adapter=parent_adapter,
        query_parameters={"page": page, "size": size},
        expected=Page[Parent](
            items=items,
            page=page,
            size=size,
            total=_EXPECTED_TOTAL,
            pages=pages,
        ),
    )


adapter_parameters = [
    _grouped_case(
        page=1,
        size=2,
        pages=2,
        items=[
            {"parent": "x", "children": [{"name": "foo"}]},
            {"parent": "y", "children": []},
        ],
    ),
    _grouped_case(
        page=2,
        size=2,
        pages=2,
        items=[{"parent": "z", "children": []}],
    ),
    _grouped_case(
        page=1,
        size=1,
        pages=3,
        items=[{"parent": "x", "children": [{"name": "foo"}]}],
    ),
    _grouped_case(
        page=2,
        size=1,
        pages=3,
        items=[{"parent": "y", "children": []}],
    ),
    _grouped_case(
        page=3,
        size=1,
        pages=3,
        items=[{"parent": "z", "children": []}],
    ),
]


@pytest.mark.remote
@pytest.mark.parametrize(
    ["adapter", "query_parameters", "expected"], adapter_parameters
)
def test_basic_adapter_grouped_pagination(adapter, query_parameters, expected):
    """Check that the adapter returns the expected page for each case."""
    result = adapter.query(**query_parameters)
    assert result == expected