From 810fada9e1c475f611d1f8f298718291f240a924 Mon Sep 17 00:00:00 2001
From: Lukas Plank
Date: Fri, 18 Oct 2024 07:11:50 +0200
Subject: [PATCH 1/2] fix: adapt count query generator to correctly count
 grouped models

Fixes #99.
---
 rdfproxy/adapter.py            |  2 +-
 rdfproxy/utils/sparql_utils.py | 20 +++++++++++++++++---
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/rdfproxy/adapter.py b/rdfproxy/adapter.py
index 38f2718..0b3fe48 100644
--- a/rdfproxy/adapter.py
+++ b/rdfproxy/adapter.py
@@ -49,7 +49,7 @@ def query(
         size: int = 100,
     ) -> Page[_TModelInstance]:
         """Run a query against an endpoint and return a Page model object."""
-        count_query: str = construct_count_query(self._query)
+        count_query: str = construct_count_query(query=self._query, model=self._model)
         items_query: str = ungrouped_pagination_base_query.substitute(
             query=self._query, offset=calculate_offset(page, size), limit=size
         )
diff --git a/rdfproxy/utils/sparql_utils.py b/rdfproxy/utils/sparql_utils.py
index a401a8a..b76cb45 100644
--- a/rdfproxy/utils/sparql_utils.py
+++ b/rdfproxy/utils/sparql_utils.py
@@ -8,6 +8,7 @@ from typing import cast
 
 
 from SPARQLWrapper import QueryResult, SPARQLWrapper
+from rdfproxy.utils._types import _TModelInstance
 
 
 ungrouped_pagination_base_query: Annotated[
@@ -35,9 +36,22 @@ def replace_query_select_clause(query: str, repl: str) -> str:
     return count_query
 
 
-def construct_count_query(query: str) -> str:
-    """Construct a generic count query from a SELECT query."""
-    count_query = replace_query_select_clause(query, "select (count(*) as ?cnt)")
+def construct_count_query(query: str, model: type[_TModelInstance]) -> str:
+    """Construct a count query from a SELECT query.
+
+    If the model config of `model` defines a "group_by" key, the count query
+    is built with construct_grouped_count_query; otherwise the SELECT clause
+    of `query` is replaced with a plain count(*) projection.
+    """
+    try:
+        group_by: str = model.model_config["group_by"]
+    except KeyError:
+        count_query = replace_query_select_clause(query, "select (count(*) as ?cnt)")
+    else:
+        # Kept outside the try body so that a KeyError raised while building
+        # the grouped query is not mistaken for a missing "group_by" setting.
+        count_query = construct_grouped_count_query(query, group_by)
+
     return count_query
 
 

From e934c1521bc25dd03159342a1bbd8f79d1b3c964 Mon Sep 17 00:00:00 2001
From: Lukas Plank
Date: Fri, 18 Oct 2024 08:59:39 +0200
Subject: [PATCH 2/2] test: implement basic example test for
 construct_count_query

---
 tests/test_construct_count_query.py | 46 +++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)
 create mode 100644 tests/test_construct_count_query.py

diff --git a/tests/test_construct_count_query.py b/tests/test_construct_count_query.py
new file mode 100644
index 0000000..94f6e2d
--- /dev/null
+++ b/tests/test_construct_count_query.py
@@ -0,0 +1,46 @@
+"""Pytest entry point for rdfproxy.utils.sparql_utils.construct_count_query tests."""
+
+from pydantic import BaseModel, ConfigDict
+from rdflib import Graph
+from rdflib.plugins.sparql.processor import SPARQLResult
+from rdfproxy.utils.sparql_utils import construct_count_query
+
+
+query = """
+select ?x ?y ?z
+where {
+    values (?x ?y ?z) {
+        (1 2 3)
+        (1 22 33)
+        (2 222 333)
+    }
+}
+"""
+
+graph: Graph = Graph()
+
+
+class Dummy(BaseModel):
+    pass
+
+
+class GroupedDummy(BaseModel):
+    model_config = ConfigDict(group_by="x")
+
+
+count_query_dummy = construct_count_query(query=query, model=Dummy)
+count_query_grouped_dummy = construct_count_query(query=query, model=GroupedDummy)
+
+
+def _get_cnt_value_from_sparql_result(
+    result: SPARQLResult, count_var: str = "cnt"
+) -> int:
+    return int(result.bindings[0][count_var])
+
+
+def test_basic_construct_count_query():
+    result_dummy: SPARQLResult = graph.query(count_query_dummy)
+    result_grouped_dummy: SPARQLResult = graph.query(count_query_grouped_dummy)
+
+    assert _get_cnt_value_from_sparql_result(result_dummy) == 3
+    assert _get_cnt_value_from_sparql_result(result_grouped_dummy) == 2