diff --git a/poetry.lock b/poetry.lock index f5405f9..b9320b8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -361,20 +361,6 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -[[package]] -name = "isodate" -version = "0.6.1" -description = "An ISO 8601 date/time/duration parser and formatter" -optional = false -python-versions = "*" -files = [ - {file = "isodate-0.6.1-py2.py3-none-any.whl", hash = "sha256:0751eece944162659049d35f4f549ed815792b38793f07cf73381c1c87cbed96"}, - {file = "isodate-0.6.1.tar.gz", hash = "sha256:48c5881de7e8b0a0d648cb024c8062dc84e7b840ed81e864c7614fd3c127bde9"}, -] - -[package.dependencies] -six = "*" - [[package]] name = "jinja2" version = "3.1.4" @@ -818,24 +804,24 @@ files = [ [[package]] name = "rdflib" -version = "7.0.0" +version = "7.1.1" description = "RDFLib is a Python library for working with RDF, a simple yet powerful language for representing information." optional = false -python-versions = ">=3.8.1,<4.0.0" +python-versions = "<4.0.0,>=3.8.1" files = [ - {file = "rdflib-7.0.0-py3-none-any.whl", hash = "sha256:0438920912a642c866a513de6fe8a0001bd86ef975057d6962c79ce4771687cd"}, - {file = "rdflib-7.0.0.tar.gz", hash = "sha256:9995eb8569428059b8c1affd26b25eac510d64f5043d9ce8c84e0d0036e995ae"}, + {file = "rdflib-7.1.1-py3-none-any.whl", hash = "sha256:e590fa9a2c34ba33a667818b5a84be3fb8a4d85868f8038f17912ec84f912a25"}, + {file = "rdflib-7.1.1.tar.gz", hash = "sha256:164de86bd3564558802ca983d84f6616a4a1a420c7a17a8152f5016076b2913e"}, ] [package.dependencies] -isodate = ">=0.6.0,<0.7.0" pyparsing = ">=2.1.0,<4" [package.extras] berkeleydb = ["berkeleydb (>=18.1.0,<19.0.0)"] -html = ["html5lib (>=1.0,<2.0)"] -lxml = ["lxml (>=4.3.0,<5.0.0)"] -networkx = ["networkx (>=2.0.0,<3.0.0)"] +html = ["html5rdf (>=1.2,<2)"] +lxml = ["lxml (>=4.3,<6.0)"] +networkx = ["networkx (>=2,<4)"] +orjson = ["orjson (>=3.9.14,<4)"] [[package]] name = "rich" @@ -893,17 +879,6 @@ files = [ {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, ] -[[package]] -name = "six" -version = "1.16.0" -description = "Python 2 and 3 compatibility utilities" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" -files = [ - {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, - {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, -] - [[package]] name = "sniffio" version = "1.3.1" @@ -1246,4 +1221,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "e9ac3d16b289eb2f29bfa0587a17105664b9951db114dfe9d2e9a35b1265e117" +content-hash = "0a2e465322bae2eaee949d9268fb74e79735a2dd33e1d7f7b390f2b21d797124" diff --git a/pyproject.toml b/pyproject.toml index e245608..b2ef8ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ pydantic = "^2.9.2" httpx = "^0.28.1" +rdflib = "^7.1.1" [tool.poetry.group.dev.dependencies] ruff = "^0.7.0" deptry = "^0.20.0" diff --git a/rdfproxy/adapter.py b/rdfproxy/adapter.py index ad98717..ebe7da3 100644 --- a/rdfproxy/adapter.py +++ b/rdfproxy/adapter.py @@ -4,15 +4,11 @@ import math from typing import Generic +from rdfproxy.constructor import QueryConstructor from rdfproxy.mapper import ModelBindingsMapper from rdfproxy.sparql_strategies import HttpxStrategy, SPARQLStrategy from rdfproxy.utils._types import _TModelInstance from 
rdfproxy.utils.models import Page, QueryParameters -from rdfproxy.utils.sparql_utils import ( - calculate_offset, - construct_count_query, - construct_items_query, -) class SPARQLModelAdapter(Generic[_TModelInstance]): @@ -24,10 +20,12 @@ class SPARQLModelAdapter(Generic[_TModelInstance]): SPARQLModelAdapter.query returns a Page model object with a default pagination size of 100 results. SPARQL bindings are implicitly assigned to model fields of the same name, - explicit SPARQL binding to model field allocation is available with typing.Annotated and rdfproxy.SPARQLBinding. + while explicit SPARQL binding to model field allocation is available with rdfproxy.SPARQLBinding. Result grouping is controlled through the model, i.e. grouping is triggered when a field of list[pydantic.BaseModel] is encountered. + + See https://github.com/acdh-oeaw/rdfproxy/tree/main/examples for examples. """ def __init__( @@ -44,20 +42,21 @@ def __init__( def query(self, query_parameters: QueryParameters) -> Page[_TModelInstance]: """Run a query against an endpoint and return a Page model object.""" - count_query: str = construct_count_query(query=self._query, model=self._model) - items_query: str = construct_items_query( + query_constructor = QueryConstructor( query=self._query, + query_parameters=query_parameters, model=self._model, - limit=query_parameters.size, - offset=calculate_offset(query_parameters.page, query_parameters.size), ) - items_query_bindings: Iterator[dict] = self.sparql_strategy.query(items_query) + count_query = query_constructor.get_count_query() + items_query = query_constructor.get_items_query() + items_query_bindings: Iterator[dict] = self.sparql_strategy.query(items_query) mapper = ModelBindingsMapper(self._model, *items_query_bindings) - items: list[_TModelInstance] = mapper.get_models() - total: int = self._get_count(count_query) + + count_query_bindings: Iterator[dict] = self.sparql_strategy.query(count_query) + total: int = int(next(count_query_bindings)["cnt"]) pages: int = math.ceil(total / query_parameters.size) return Page( @@ -67,11 +66,3 @@ def query(self, query_parameters: QueryParameters) -> Page[_TModelInstance]: total=total, pages=pages, ) - - def _get_count(self, query: str) -> int: - """Run a count query and return the count result. - - Helper for SPARQLModelAdapter.query. - """ - result: Iterator[dict] = self.sparql_strategy.query(query) - return int(next(result)["cnt"]) diff --git a/rdfproxy/constructor.py b/rdfproxy/constructor.py new file mode 100644 index 0000000..8054bd3 --- /dev/null +++ b/rdfproxy/constructor.py @@ -0,0 +1,124 @@ +from rdfproxy.utils._types import _TModelInstance +from rdfproxy.utils.models import QueryParameters +from rdfproxy.utils.sparql_utils import ( + add_solution_modifier, + get_query_projection, + inject_into_query, + remove_sparql_prefixes, + replace_query_select_clause, +) +from rdfproxy.utils.utils import ( + FieldsBindingsMap, + QueryConstructorComponent as component, + compose_left, +) + + +class QueryConstructor: + """The class encapsulates dynamic SPARQL query modification logic + for implementing purely SPARQL-based, deterministic pagination. + + Public methods get_items_query and get_count_query are used in rdfproxy.SPARQLModelAdapter + to construct queries for retrieving arguments for Page object instantiation. 
+ """ + + def __init__( + self, + query: str, + query_parameters: QueryParameters, + model: type[_TModelInstance], + ) -> None: + self.query = query + self.query_parameters = query_parameters + self.model = model + + self.bindings_map = FieldsBindingsMap(model) + self.group_by: str | None = self.bindings_map.get( + model.model_config.get("group_by") + ) + + def get_items_query(self) -> str: + """Construct a SPARQL items query for use in rdfproxy.SPARQLModelAdapter.""" + if self.group_by is None: + return self._get_ungrouped_items_query() + return self._get_grouped_items_query() + + def get_count_query(self) -> str: + """Construct a SPARQL count query for use in rdfproxy.SPARQLModelAdapter""" + if self.group_by is None: + select_clause = "select (count(*) as ?cnt)" + else: + select_clause = f"select (count(distinct ?{self.group_by}) as ?cnt)" + + return replace_query_select_clause(self.query, select_clause) + + @staticmethod + def _calculate_offset(page: int, size: int) -> int: + """Calculate the offset value for paginated SPARQL templates.""" + match page: + case 1: + return 0 + case 2: + return size + case _: + return size * (page - 1) + + def _get_grouped_items_query(self) -> str: + """Construct a SPARQL items query for grouped models.""" + filter_clause: str | None = self._compute_filter_clause() + select_clause: str = self._compute_select_clause() + order_by_value: str = self._compute_order_by_value() + limit, offset = self._compute_limit_offset() + + subquery = compose_left( + remove_sparql_prefixes, + component(replace_query_select_clause, repl=select_clause), + component(inject_into_query, injectant=filter_clause), + component( + add_solution_modifier, + order_by=order_by_value, + limit=limit, + offset=offset, + ), + )(self.query) + + return inject_into_query(self.query, subquery) + + def _get_ungrouped_items_query(self) -> str: + """Construct a SPARQL items query for ungrouped models.""" + filter_clause: str | None = self._compute_filter_clause() + order_by_value: str = self._compute_order_by_value() + limit, offset = self._compute_limit_offset() + + return compose_left( + component(inject_into_query, injectant=filter_clause), + component( + add_solution_modifier, + order_by=order_by_value, + limit=limit, + offset=offset, + ), + )(self.query) + + def _compute_limit_offset(self) -> tuple[int, int]: + """Calculate limit and offset values for SPARQL-based pagination.""" + limit = self.query_parameters.size + offset = self._calculate_offset( + self.query_parameters.page, self.query_parameters.size + ) + + return limit, offset + + def _compute_filter_clause(self) -> str | None: + """Stub: Always None for now.""" + return None + + def _compute_select_clause(self): + """Stub: Static SELECT clause for now.""" + return f"select distinct ?{self.group_by}" + + def _compute_order_by_value(self): + """Stub: Only basic logic for now.""" + if self.group_by is None: + return get_query_projection(self.query)[0] + return f"{self.group_by}" diff --git a/rdfproxy/mapper.py b/rdfproxy/mapper.py index da10054..7c3d56d 100644 --- a/rdfproxy/mapper.py +++ b/rdfproxy/mapper.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from rdfproxy.utils._types import ModelBoolPredicate, _TModelInstance -from rdfproxy.utils.utils import ( +from rdfproxy.utils.mapper_utils import ( _collect_values_from_bindings, _get_group_by, _get_key_from_metadata, @@ -65,7 +65,6 @@ def _generate_binding_pairs( and (x[self._contexts[0]] == kwargs[self._contexts[0]]), self.bindings, ) - value = self._get_unique_models(group_model, 
applicable_bindings) elif _is_list_type(v.annotation): diff --git a/rdfproxy/sparql_strategies.py b/rdfproxy/sparql_strategies.py index 6b61860..da26d41 100644 --- a/rdfproxy/sparql_strategies.py +++ b/rdfproxy/sparql_strategies.py @@ -2,6 +2,7 @@ import abc from collections.abc import Iterator +from typing import cast from SPARQLWrapper import JSON, QueryResult, SPARQLWrapper import httpx @@ -13,7 +14,7 @@ def __init__(self, endpoint: str): @abc.abstractmethod def query(self, sparql_query: str) -> Iterator[dict[str, str]]: - raise NotImplementedError + raise NotImplementedError # pragma: no cover @staticmethod def _get_bindings_from_bindings_dict(bindings_dict: dict) -> Iterator[dict]: @@ -35,7 +36,9 @@ def query(self, sparql_query: str) -> Iterator[dict[str, str]]: self._sparql_wrapper.setQuery(sparql_query) result: QueryResult = self._sparql_wrapper.query() - return self._get_bindings_from_bindings_dict(result.convert()) + # SPARQLWrapper.Wrapper.convert is not overloaded properly and needs casting + # https://github.com/RDFLib/sparqlwrapper/blob/master/SPARQLWrapper/Wrapper.py#L1135 + return self._get_bindings_from_bindings_dict(cast(dict, result.convert())) class HttpxStrategy(SPARQLStrategy): diff --git a/rdfproxy/utils/mapper_utils.py b/rdfproxy/utils/mapper_utils.py new file mode 100644 index 0000000..acce14e --- /dev/null +++ b/rdfproxy/utils/mapper_utils.py @@ -0,0 +1,127 @@ +from collections.abc import Callable, Iterable +from typing import Any, TypeGuard, get_args, get_origin + +from pydantic import BaseModel +from pydantic.fields import FieldInfo +from rdfproxy.utils._exceptions import ( + InvalidGroupingKeyException, + MissingModelConfigException, +) +from rdfproxy.utils._types import _TModelInstance +from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue + + +def _is_type(obj: type | None, _type: type) -> bool: + """Check if an obj is type _type or a GenericAlias with origin _type.""" + return (obj is _type) or (get_origin(obj) is _type) + + +def _is_list_type(obj: type | None) -> bool: + """Check if obj is a list type.""" + return _is_type(obj, list) + + +def _is_list_basemodel_type(obj: type | None) -> bool: + """Check if a type is list[pydantic.BaseModel].""" + return (get_origin(obj) is list) and all( + issubclass(cls, BaseModel) for cls in get_args(obj) + ) + + +def _collect_values_from_bindings( + binding_name: str, + bindings: Iterable[dict], + predicate: Callable[[Any], bool] = lambda x: x is not None, +) -> list: + """Scan bindings for a key binding_name and collect unique predicate-compliant values. + + Note that element order is important for testing, so a set cast won't do. + """ + values = dict.fromkeys( + value + for binding in bindings + if predicate(value := binding.get(binding_name, None)) + ) + return list(values) + + +def _get_key_from_metadata(v: FieldInfo, *, default: Any) -> str | Any: + """Try to get a SPARQLBinding object from a field's metadata attribute. + + Helper for _generate_binding_pairs. 
+ """ + return next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata), default) + + +def _get_applicable_grouping_keys(model: type[_TModelInstance]) -> list[str]: + return [k for k, v in model.model_fields.items() if not _is_list_type(v.annotation)] + + +def _get_group_by(model: type[_TModelInstance]) -> str: + """Get the name of a grouping key from a model Config class.""" + try: + group_by = model.model_config["group_by"] # type: ignore + except KeyError as e: + raise MissingModelConfigException( + "Model config with 'group_by' value required " + "for field-based grouping behavior." + ) from e + else: + applicable_keys = _get_applicable_grouping_keys(model=model) + + if group_by not in applicable_keys: + raise InvalidGroupingKeyException( + f"Invalid grouping key '{group_by}'. " + f"Applicable grouping keys: {', '.join(applicable_keys)}." + ) + + if meta := model.model_fields[group_by].metadata: + if binding := next( + filter(lambda entry: isinstance(entry, SPARQLBinding), meta), None + ): + return binding + return group_by + + +def default_model_bool_predicate(model: BaseModel) -> bool: + """Default predicate for determining model truthiness. + + Adheres to rdfproxy.utils._types.ModelBoolPredicate. + """ + return any(dict(model).values()) + + +def _is_iterable_of_str(iterable: Iterable) -> TypeGuard[Iterable[str]]: + return (not isinstance(iterable, str)) and all( + map(lambda i: isinstance(i, str), iterable) + ) + + +def _get_model_bool_predicate_from_config_value( + model_bool_value: _TModelBoolValue, +) -> ModelBoolPredicate: + """Get a model_bool predicate function given the value of the model_bool config setting.""" + match model_bool_value: + case ModelBoolPredicate(): + return model_bool_value + case str(): + return lambda model: bool(dict(model)[model_bool_value]) + case model_bool_value if _is_iterable_of_str(model_bool_value): + return lambda model: all(map(lambda k: dict(model)[k], model_bool_value)) + case _: # pragma: no cover + raise TypeError( + "Argument for 'model_bool' must be of type ModelBoolPredicate | str | Iterable[str].\n" + f"Received {type(model_bool_value)}" + ) + + +def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate: + """Get the applicable model_bool predicate function given a model.""" + if (model_bool_value := model.model_config.get("model_bool", None)) is None: + model_bool_predicate = default_model_bool_predicate + else: + model_bool_predicate = _get_model_bool_predicate_from_config_value( + model_bool_value + ) + + return model_bool_predicate diff --git a/rdfproxy/utils/sparql_utils.py b/rdfproxy/utils/sparql_utils.py index a064898..4eafea3 100644 --- a/rdfproxy/utils/sparql_utils.py +++ b/rdfproxy/utils/sparql_utils.py @@ -1,19 +1,13 @@ """Functionality for dynamic SPARQL query modifcation.""" -from collections.abc import Iterator -from contextlib import contextmanager -from functools import partial +from itertools import chain import re -from typing import cast +from typing import overload -from SPARQLWrapper import QueryResult, SPARQLWrapper +from rdflib import Variable +from rdflib.plugins.sparql.parser import parseQuery +from rdflib.plugins.sparql.parserutils import CompValue, ParseResults from rdfproxy.utils._exceptions import QueryConstructionException -from rdfproxy.utils._types import ItemsQueryConstructor, SPARQLBinding, _TModelInstance - - -def construct_ungrouped_pagination_query(query: str, limit: int, offset: int) -> str: - """Construct an ungrouped pagination query.""" - return f"{query} limit {limit} offset 
{offset}" def replace_query_select_clause(query: str, repl: str) -> str: @@ -23,7 +17,7 @@ def replace_query_select_clause(query: str, repl: str) -> str: ) if re.search(pattern=pattern, string=query) is None: - raise Exception("Unable to obtain SELECT clause.") + raise QueryConstructionException("Unable to obtain SELECT clause.") modified_query = re.sub( pattern=pattern, @@ -35,148 +29,90 @@ def replace_query_select_clause(query: str, repl: str) -> str: return modified_query -def _remove_sparql_prefixes(query: str) -> str: +def remove_sparql_prefixes(query: str) -> str: """Remove SPARQL prefixes from a query. This is needed for subquery injection, because subqueries cannot have prefixes. - Note that this is not generic, all prefixes are simply ut from the subquery - and do not get appended to the outer query prefixes. + Note that this is not generic, all prefixes are simply cut from the subquery + and are not resolved against the outer query prefixes. """ - prefix_pattern = re.compile(r"PREFIX\s+\w*:\s?<[^>]+>\s*", flags=re.I) + prefix_pattern = re.compile(r"PREFIX\s+\w*:\s?<[^>]+>\s*", flags=re.IGNORECASE) cleaned_query = re.sub(prefix_pattern, "", query).strip() return cleaned_query -def inject_subquery(query: str, subquery: str) -> str: - """Inject a SPARQL query with a subquery.""" +def inject_into_query(query: str, injectant: str) -> str: + """Inject some injectant (e.g. subquery or filter clause) into a query.""" if (tail := re.search(r"}[^}]*\Z", query)) is None: - raise QueryConstructionException("Unable to inject subquery.") + raise QueryConstructionException( + "Unable to inject subquery." + ) # pragma: no cover ; this will be unreachable once query checking runs tail_index: int = tail.start() - injected: str = f"{query[:tail_index]} {{{_remove_sparql_prefixes(subquery)}}} {query[tail_index:]}" - return injected + injected_query: str = f"{query[:tail_index]} {{{injectant}}} {query[tail_index:]}" + return injected_query -def construct_grouped_pagination_query( - query: str, group_by_value: str, limit: int, offset: int +def add_solution_modifier( + query: str, + *, + order_by: str | None = None, + limit: int | None = None, + offset: int | None = None, ) -> str: - """Construct a grouped pagination query.""" - _subquery_base: str = replace_query_select_clause( - query=query, repl=f"select distinct ?{group_by_value}" - ) - subquery: str = construct_ungrouped_pagination_query( - query=_subquery_base, limit=limit, offset=offset - ) - - grouped_pagination_query: str = inject_subquery(query=query, subquery=subquery) - return grouped_pagination_query + """Add optional solution modifiers in SPARQL-conformant order to a query.""" + modifiers = [] + if order_by is not None: + modifiers.append(f"order by ?{order_by}") + if limit is not None: + modifiers.append(f"limit {limit}") + if offset is not None: + modifiers.append(f"offset {offset}") -def get_items_query_constructor( - model: type[_TModelInstance], -) -> ItemsQueryConstructor: - """Get the applicable query constructor function given a model class.""" + return f"{query} {' '.join(modifiers)}".strip() - if (group_by_value := model.model_config.get("group_by", None)) is None: - return construct_ungrouped_pagination_query - elif meta := model.model_fields[group_by_value].metadata: - group_by_value = next( - filter(lambda x: isinstance(x, SPARQLBinding), meta), group_by_value - ) +@overload +def _compvalue_to_dict(comp_value: dict | CompValue) -> dict: ... 
- return partial(construct_grouped_pagination_query, group_by_value=group_by_value) +@overload +def _compvalue_to_dict(comp_value: list | ParseResults) -> list: ... -def construct_items_query( - query: str, model: type[_TModelInstance], limit: int, offset: int -) -> str: - """Construct a grouped pagination query.""" - items_query_constructor: ItemsQueryConstructor = get_items_query_constructor( - model=model - ) - return items_query_constructor(query=query, limit=limit, offset=offset) - - -def construct_count_query(query: str, model: type[_TModelInstance]) -> str: - """Construct a generic count query from a SELECT query.""" - try: - group_by: str = model.model_config["group_by"] - group_by_binding = next( - filter( - lambda x: isinstance(x, SPARQLBinding), - model.model_fields[group_by].metadata, - ), - group_by, - ) - count_query = construct_grouped_count_query(query, group_by_binding) - except KeyError: - count_query = replace_query_select_clause(query, "select (count(*) as ?cnt)") - - return count_query - - -def calculate_offset(page: int, size: int) -> int: - """Calculate offset value for paginated SPARQL templates.""" - match page: - case 1: - return 0 - case 2: - return size - case _: - return size * (page - 1) - - -def construct_grouped_count_query(query: str, group_by) -> str: - grouped_count_query = replace_query_select_clause( - query, f"select (count(distinct ?{group_by}) as ?cnt)" - ) - return grouped_count_query +def _compvalue_to_dict(comp_value: CompValue): + """Convert a CompValue parsing object into a Python dict/list representation. - -def _get_bindings_from_bindings_dict(bindings_dict: dict) -> Iterator[dict]: - bindings = map( - lambda binding: {k: v["value"] for k, v in binding.items()}, - bindings_dict["results"]["bindings"], - ) - return bindings - - -def get_bindings_from_query_result(query_result: QueryResult) -> Iterator[dict]: - """Extract just the bindings from a SPARQLWrapper.QueryResult.""" - if (result_format := query_result.requestedFormat) != "json": - raise Exception( - "Only QueryResult objects with JSON format are currently supported. " - f"Received object with requestedFormat '{result_format}'." - ) - - query_json: dict = cast(dict, query_result.convert()) - bindings = _get_bindings_from_bindings_dict(query_json) - - return bindings - - -@contextmanager -def temporary_query_override(sparql_wrapper: SPARQLWrapper): - """Context manager that allows to contextually overwrite a query in a SPARQLWrapper object.""" - _query_cache = sparql_wrapper.queryString - - try: - yield sparql_wrapper - finally: - sparql_wrapper.setQuery(_query_cache) + Helper for get_query_projection. + """ + if isinstance(comp_value, dict): + return {key: _compvalue_to_dict(value) for key, value in comp_value.items()} + elif isinstance(comp_value, list | ParseResults): + return [_compvalue_to_dict(item) for item in comp_value] + else: + return comp_value -def query_with_wrapper(query: str, sparql_wrapper: SPARQLWrapper) -> Iterator[dict]: - """Execute a SPARQL query using a predefined sparql_wrapper object. +def get_query_projection(query: str) -> list[Variable]: + """Parse a SPARQL SELECT query and extract the ordered bindings projection. - The query attribute of the wrapper object is temporarily overridden - and gets restored after query execution. + The first case handles explicit/literal binding projections. + The second case handles implicit/* binding projections. + The third case handles implicit/* binding projections with VALUES. 
""" - with temporary_query_override(sparql_wrapper=sparql_wrapper): - sparql_wrapper.setQuery(query) - result: QueryResult = sparql_wrapper.query() - - bindings: Iterator[dict] = get_bindings_from_query_result(result) - return bindings + _parse_result: CompValue = parseQuery(query)[1] + parsed_query: dict = _compvalue_to_dict(_parse_result) + + match parsed_query: + case {"projection": projection}: + return [i["var"] for i in projection] + case {"where": {"part": [{"triples": triples}]}}: + projection = dict.fromkeys( + i for i in chain.from_iterable(triples) if isinstance(i, Variable) + ) + return list(projection) + case {"where": {"part": [{"var": var}]}}: + return var + case _: # pragma: no cover + raise Exception("Unable to obtain query projection.") diff --git a/rdfproxy/utils/utils.py b/rdfproxy/utils/utils.py index e7e4c12..279d19c 100644 --- a/rdfproxy/utils/utils.py +++ b/rdfproxy/utils/utils.py @@ -1,129 +1,72 @@ """SPARQL/FastAPI utils.""" -from collections.abc import Callable, Iterable -from typing import Any, TypeGuard, get_args, get_origin - -from pydantic import BaseModel -from pydantic.fields import FieldInfo -from rdfproxy.utils._exceptions import ( - InvalidGroupingKeyException, - MissingModelConfigException, -) -from rdfproxy.utils._types import _TModelInstance -from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue - - -def _is_type(obj: type | None, _type: type) -> bool: - """Check if an obj is type _type or a GenericAlias with origin _type.""" - return (obj is _type) or (get_origin(obj) is _type) +from collections import UserDict +from collections.abc import Callable +from functools import partial +from typing import TypeVar +from rdfproxy.utils._types import _TModelInstance +from rdfproxy.utils._types import SPARQLBinding -def _is_list_type(obj: type | None) -> bool: - """Check if obj is a list type.""" - return _is_type(obj, list) +T = TypeVar("T") -def _is_list_basemodel_type(obj: type | None) -> bool: - """Check if a type is list[pydantic.BaseModel].""" - return (get_origin(obj) is list) and all( - issubclass(cls, BaseModel) for cls in get_args(obj) - ) +class FieldsBindingsMap(UserDict): + """Mapping for resolving SPARQLBinding aliases. -def _collect_values_from_bindings( - binding_name: str, - bindings: Iterable[dict], - predicate: Callable[[Any], bool] = lambda x: x is not None, -) -> list: - """Scan bindings for a key binding_name and collect unique predicate-compliant values. + Model field names are mapped to SPARQLBinding names. + The FieldsBindingsMap.reverse allows reverse lookup + (i.e. from SPARQLBindings to model fields). - Note that element order is important for testing, so a set cast won't do. + Note: It might be useful to recursively resolve aliases for nested models. """ - values = dict.fromkeys( - value - for binding in bindings - if predicate(value := binding.get(binding_name, None)) - ) - return list(values) + def __init__(self, model: type[_TModelInstance]) -> None: + self.data = self._get_field_binding_mapping(model) + self._reversed = {v: k for k, v in self.data.items()} -def _get_key_from_metadata(v: FieldInfo, *, default: Any) -> str | Any: - """Try to get a SPARQLBinding object from a field's metadata attribute. + @property + def reverse(self) -> dict[str, str]: + """Reverse lookup map from SPARQL bindings to model fields.""" + return self._reversed - Helper for _generate_binding_pairs. 
- """ - return next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata), default) + @staticmethod + def _get_field_binding_mapping(model: type[_TModelInstance]) -> dict[str, str]: + """Resolve model fields against rdfproxy.SPARQLBindings.""" + return { + k: next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata), k) + for k, v in model.model_fields.items() + } -def _get_applicable_grouping_keys(model: type[_TModelInstance]) -> list[str]: - return [k for k, v in model.model_fields.items() if not _is_list_type(v.annotation)] +def compose_left(*fns: Callable[[T], T]) -> Callable[[T], T]: + """Left associative compose.""" + def _left_wrapper(*fns): + fn, *rest_fns = fns -def _get_group_by(model: type[_TModelInstance]) -> str: - """Get the name of a grouping key from a model Config class.""" - try: - group_by = model.model_config["group_by"] # type: ignore - except KeyError as e: - raise MissingModelConfigException( - "Model config with 'group_by' value required " - "for field-based grouping behavior." - ) from e - else: - applicable_keys = _get_applicable_grouping_keys(model=model) + if rest_fns: + return lambda *args, **kwargs: fn(_left_wrapper(*rest_fns)(*args, **kwargs)) + return fn - if group_by not in applicable_keys: - raise InvalidGroupingKeyException( - f"Invalid grouping key '{group_by}'. " - f"Applicable grouping keys: {', '.join(applicable_keys)}." - ) + return _left_wrapper(*reversed(fns)) - if meta := model.model_fields[group_by].metadata: - if binding := next( - filter(lambda entry: isinstance(entry, SPARQLBinding), meta), None - ): - return binding - return group_by +class QueryConstructorComponent: + """Query modification component factory. -def default_model_bool_predicate(model: BaseModel) -> bool: - """Default predicate for determining model truthiness. + Components either call the wrapped function with non-None value kwargs applied + or (if all kwargs values are None) fall back to the identity function. - Adheres to rdfproxy.utils._types.ModelBoolPredicate. + QueryConstructorComponents are used in QueryConstructor for query modification compose chains. 
""" - return any(dict(model).values()) - - -def _is_iterable_of_str(iterable: Iterable) -> TypeGuard[Iterable[str]]: - return (not isinstance(iterable, str)) and all( - map(lambda i: isinstance(i, str), iterable) - ) - - -def _get_model_bool_predicate_from_config_value( - model_bool_value: _TModelBoolValue, -) -> ModelBoolPredicate: - """Get a model_bool predicate function given the value of the model_bool config setting.""" - match model_bool_value: - case ModelBoolPredicate(): - return model_bool_value - case str(): - return lambda model: bool(dict(model)[model_bool_value]) - case model_bool_value if _is_iterable_of_str(model_bool_value): - return lambda model: all(map(lambda k: dict(model)[k], model_bool_value)) - case _: - raise TypeError( - "Argument for 'model_bool' must be of type ModelBoolPredicate | str | Iterable[str].\n" - f"Received {type(model_bool_value)}" - ) - - -def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate: - """Get the applicable model_bool predicate function given a model.""" - if (model_bool_value := model.model_config.get("model_bool", None)) is None: - model_bool_predicate = default_model_bool_predicate - else: - model_bool_predicate = _get_model_bool_predicate_from_config_value( - model_bool_value - ) - - return model_bool_predicate + + def __init__(self, f: Callable[..., str], **kwargs) -> None: + self.f = f + self.kwargs = kwargs + + def __call__(self, query) -> str: + if tkwargs := {k: v for k, v in self.kwargs.items() if v is not None}: + return partial(self.f, **tkwargs)(query) + return query diff --git a/tests/data/models/dummy_model.py b/tests/data/models/dummy_model.py deleted file mode 100644 index 041bf4d..0000000 --- a/tests/data/models/dummy_model.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Simple dummy models e.g. 
for count query constructor testing.""" - -from pydantic import BaseModel -from rdfproxy import ConfigDict - - -class Dummy(BaseModel): - pass - - -class GroupedDummy(BaseModel): - model_config = ConfigDict(group_by="x") - - x: int diff --git a/tests/tests_adapter/test_adapter_grouped_pagination.py b/tests/tests_adapter/test_adapter_grouped_pagination.py index 7fd68a0..8ac6706 100644 --- a/tests/tests_adapter/test_adapter_grouped_pagination.py +++ b/tests/tests_adapter/test_adapter_grouped_pagination.py @@ -2,15 +2,16 @@ from typing import Annotated, Any, NamedTuple -import pytest - from pydantic import BaseModel +import pytest from rdfproxy import ( ConfigDict, + HttpxStrategy, Page, QueryParameters, SPARQLBinding, SPARQLModelAdapter, + SPARQLWrapperStrategy, ) @@ -57,28 +58,43 @@ class Parent(BaseModel): children: list[Child] -binding_adapter = SPARQLModelAdapter( - target="https://graphdb.r11.eu/repositories/RELEVEN", - query=binding_query, - model=BindingParent, -) +@pytest.fixture(params=[HttpxStrategy, SPARQLWrapperStrategy]) +def adapter(request): + return SPARQLModelAdapter( + target="https://graphdb.r11.eu/repositories/RELEVEN", + query=query, + model=Parent, + sparql_strategy=request.param, + ) + + +@pytest.fixture(params=[HttpxStrategy, SPARQLWrapperStrategy]) +def binding_adapter(request): + return SPARQLModelAdapter( + target="https://graphdb.r11.eu/repositories/RELEVEN", + query=binding_query, + model=BindingParent, + sparql_strategy=request.param, + ) -adapter = SPARQLModelAdapter( - target="https://graphdb.r11.eu/repositories/RELEVEN", - query=query, - model=Parent, -) + +@pytest.fixture(params=[HttpxStrategy, SPARQLWrapperStrategy]) +def ungrouped_adapter(request): + return SPARQLModelAdapter( + target="https://graphdb.r11.eu/repositories/RELEVEN", + query=query, + model=Child, + sparql_strategy=request.param, + ) class AdapterParameter(NamedTuple): - adapter: SPARQLModelAdapter query_parameters: dict[str, Any] expected: Page -adapter_parameters = [ +binding_adapter_parameters = [ AdapterParameter( - adapter=binding_adapter, query_parameters={"page": 1, "size": 2}, expected=Page[BindingParent]( items=[ @@ -92,7 +108,6 @@ class AdapterParameter(NamedTuple): ), ), AdapterParameter( - adapter=binding_adapter, query_parameters={"page": 2, "size": 2}, expected=Page[BindingParent]( items=[{"parent": "z", "children": []}], @@ -103,7 +118,6 @@ class AdapterParameter(NamedTuple): ), ), AdapterParameter( - adapter=binding_adapter, query_parameters={"page": 1, "size": 1}, expected=Page[BindingParent]( items=[{"parent": "x", "children": [{"name": "foo"}]}], @@ -114,22 +128,21 @@ class AdapterParameter(NamedTuple): ), ), AdapterParameter( - adapter=binding_adapter, query_parameters={"page": 2, "size": 1}, expected=Page[BindingParent]( items=[{"parent": "y", "children": []}], page=2, size=1, total=3, pages=3 ), ), AdapterParameter( - adapter=binding_adapter, query_parameters={"page": 3, "size": 1}, expected=Page[BindingParent]( items=[{"parent": "z", "children": []}], page=3, size=1, total=3, pages=3 ), ), - # +] +# +adapter_parameters = [ AdapterParameter( - adapter=adapter, query_parameters={"page": 1, "size": 2}, expected=Page[Parent]( items=[ @@ -143,7 +156,6 @@ class AdapterParameter(NamedTuple): ), ), AdapterParameter( - adapter=adapter, query_parameters={"page": 2, "size": 2}, expected=Page[Parent]( items=[{"parent": "z", "children": []}], @@ -154,7 +166,6 @@ class AdapterParameter(NamedTuple): ), ), AdapterParameter( - adapter=adapter, query_parameters={"page": 1, "size": 1}, 
expected=Page[Parent]( items=[{"parent": "x", "children": [{"name": "foo"}]}], @@ -165,14 +176,12 @@ class AdapterParameter(NamedTuple): ), ), AdapterParameter( - adapter=adapter, query_parameters={"page": 2, "size": 1}, expected=Page[Parent]( items=[{"parent": "y", "children": []}], page=2, size=1, total=3, pages=3 ), ), AdapterParameter( - adapter=adapter, query_parameters={"page": 3, "size": 1}, expected=Page[Parent]( items=[{"parent": "z", "children": []}], page=3, size=1, total=3, pages=3 @@ -180,11 +189,45 @@ class AdapterParameter(NamedTuple): ), ] +ungrouped_adapter_parameters = [ + AdapterParameter( + query_parameters={"page": 1, "size": 100}, + expected=Page[Child]( + items=[{"name": "foo"}], page=1, size=100, total=1, pages=1 + ), + ), +] + @pytest.mark.remote @pytest.mark.parametrize( - ["adapter", "query_parameters", "expected"], adapter_parameters + ["query_parameters", "expected"], + adapter_parameters, ) def test_basic_adapter_grouped_pagination(adapter, query_parameters, expected): parameters = QueryParameters(**query_parameters) assert adapter.query(parameters) == expected + + +@pytest.mark.remote +@pytest.mark.parametrize( + ["query_parameters", "expected"], + binding_adapter_parameters, +) +def test_basic_binding_adapter_grouped_pagination( + binding_adapter, query_parameters, expected +): + parameters = QueryParameters(**query_parameters) + assert binding_adapter.query(parameters) == expected + + +@pytest.mark.xfail +@pytest.mark.remote +@pytest.mark.parametrize( + ["query_parameters", "expected"], + ungrouped_adapter_parameters, +) +def test_basic_ungrouped_pagination(ungrouped_adapter, query_parameters, expected): + """This shows a possible pagination count bug that needs investigating.""" + parameters = QueryParameters(**query_parameters) + assert ungrouped_adapter.query(parameters) == expected diff --git a/tests/data/parameters/count_query_parameters.py b/tests/tests_constructor/params/count_query_parameters.py similarity index 90% rename from tests/data/parameters/count_query_parameters.py rename to tests/tests_constructor/params/count_query_parameters.py index e7aa9ad..7d93b32 100644 --- a/tests/data/parameters/count_query_parameters.py +++ b/tests/tests_constructor/params/count_query_parameters.py @@ -1,7 +1,18 @@ -from tests.data.models.dummy_model import Dummy, GroupedDummy +from pydantic import BaseModel +from rdfproxy import ConfigDict from tests.utils._types import CountQueryParameter +class Dummy(BaseModel): + pass + + +class GroupedDummy(BaseModel): + model_config = ConfigDict(group_by="x") + + x: int + + construct_count_query_parameters = [ CountQueryParameter( query=""" diff --git a/tests/tests_constructor/test_query_constructor_items_query.py b/tests/tests_constructor/test_query_constructor_items_query.py new file mode 100644 index 0000000..79ad389 --- /dev/null +++ b/tests/tests_constructor/test_query_constructor_items_query.py @@ -0,0 +1,105 @@ +"""Basic tests for the QueryConstructor class.""" + +from typing import NamedTuple + +import pytest + +from pydantic import BaseModel +from rdfproxy.constructor import QueryConstructor +from rdfproxy.utils._types import ConfigDict +from rdfproxy.utils.models import QueryParameters + + +class UngroupedModel(BaseModel): + x: int + y: int + + +class GroupedModel(BaseModel): + model_config = ConfigDict(group_by="x") + + x: int + y: list[int] + + +class Expected(NamedTuple): + count_query: str + items_query: str + + +class QueryConstructorParameters(NamedTuple): + query: str + query_parameters: QueryParameters + 
model: type[BaseModel] + + expected: Expected + + +parameters = [ + # ungrouped + QueryConstructorParameters( + query="select * where {?s ?p ?o}", + query_parameters=QueryParameters(), + model=UngroupedModel, + expected=Expected( + count_query="select (count(*) as ?cnt) where {?s ?p ?o}", + items_query="select * where {?s ?p ?o} order by ?s limit 100 offset 0", + ), + ), + QueryConstructorParameters( + query="select ?p ?o where {?s ?p ?o}", + query_parameters=QueryParameters(), + model=UngroupedModel, + expected=Expected( + count_query="select (count(*) as ?cnt) where {?s ?p ?o}", + items_query="select ?p ?o where {?s ?p ?o} order by ?p limit 100 offset 0", + ), + ), + QueryConstructorParameters( + query="select * where {?s ?p ?o}", + query_parameters=QueryParameters(page=2, size=2), + model=UngroupedModel, + expected=Expected( + count_query="select (count(*) as ?cnt) where {?s ?p ?o}", + items_query="select * where {?s ?p ?o} order by ?s limit 2 offset 2", + ), + ), + # grouped + QueryConstructorParameters( + query="select * where {?x a ?y}", + query_parameters=QueryParameters(), + model=GroupedModel, + expected=Expected( + count_query="select (count(distinct ?x) as ?cnt) where {?x a ?y}", + items_query="select * where {?x a ?y {select distinct ?x where {?x a ?y} order by ?x limit 100 offset 0} }", + ), + ), + QueryConstructorParameters( + query="select ?x ?y where {?x a ?y}", + query_parameters=QueryParameters(), + model=GroupedModel, + expected=Expected( + count_query="select (count(distinct ?x) as ?cnt) where {?x a ?y}", + items_query="select ?x ?y where {?x a ?y {select distinct ?x where {?x a ?y} order by ?x limit 100 offset 0} }", + ), + ), + QueryConstructorParameters( + query="select ?x ?y where {?x a ?y}", + query_parameters=QueryParameters(page=2, size=2), + model=GroupedModel, + expected=Expected( + count_query="select (count(distinct ?x) as ?cnt) where {?x a ?y}", + items_query="select ?x ?y where {?x a ?y {select distinct ?x where {?x a ?y} order by ?x limit 2 offset 2} }", + ), + ), +] + + +@pytest.mark.parametrize(["query", "query_parameters", "model", "expected"], parameters) +def test_query_constructor_items_query(query, query_parameters, model, expected): + constructor = QueryConstructor( + query=query, query_parameters=query_parameters, model=model + ) + + assert constructor.get_count_query() == expected.count_query + assert constructor.get_items_query() == expected.items_query diff --git a/tests/data/parameters/model_bindings_mapper_model_bool_parameters.py b/tests/tests_mapper/params/model_bindings_mapper_model_bool_parameters.py similarity index 88% rename from tests/data/parameters/model_bindings_mapper_model_bool_parameters.py rename to tests/tests_mapper/params/model_bindings_mapper_model_bool_parameters.py index 5721292..a4c23b5 100644 --- a/tests/data/parameters/model_bindings_mapper_model_bool_parameters.py +++ b/tests/tests_mapper/params/model_bindings_mapper_model_bool_parameters.py @@ -43,6 +43,13 @@ class Child5(BaseModel): child: str | None = Field(default=None, exclude=True) +class Child6(BaseModel): + model_config = ConfigDict(model_bool=["name", "child"]) + + name: str | None = None + child: str | None = None + + def _create_parent_with_child(child: type[BaseModel]) -> type[BaseModel]: model = create_model( "Parent", @@ -112,4 +119,13 @@ def _create_parent_with_child(child: type[BaseModel]) -> type[BaseModel]: {"parent": "z", "children": []}, ], ), + ModelBindingsMapperParameter( + model=_create_parent_with_child(Child6), + bindings=bindings, + 
expected=[ + {"parent": "x", "children": [{"name": "foo", "child": "c"}]}, + {"parent": "y", "children": []}, + {"parent": "z", "children": []}, + ], + ), ] diff --git a/tests/data/parameters/model_bindings_mapper_parameters.py b/tests/tests_mapper/params/model_bindings_mapper_parameters.py similarity index 94% rename from tests/data/parameters/model_bindings_mapper_parameters.py rename to tests/tests_mapper/params/model_bindings_mapper_parameters.py index 5d04fd4..8f9b70b 100644 --- a/tests/data/parameters/model_bindings_mapper_parameters.py +++ b/tests/tests_mapper/params/model_bindings_mapper_parameters.py @@ -1,12 +1,16 @@ -from tests.data.models.author_array_collection_model import Author as ArrayAuthor -from tests.data.models.author_work_title_model import Author -from tests.data.models.basic_model import ( +from tests.tests_mapper.params.models.author_array_collection_model import ( + Author as ArrayAuthor, +) +from tests.tests_mapper.params.models.author_work_title_model import Author +from tests.tests_mapper.params.models.basic_model import ( BasicComplexModel, BasicNestedModel, BasicSimpleModel, ) -from tests.data.models.grouping_model import GroupingComplexModel -from tests.data.models.nested_grouping_model import NestedGroupingComplexModel +from tests.tests_mapper.params.models.grouping_model import GroupingComplexModel +from tests.tests_mapper.params.models.nested_grouping_model import ( + NestedGroupingComplexModel, +) from tests.utils._types import ModelBindingsMapperParameter diff --git a/tests/data/models/author_array_collection_model.py b/tests/tests_mapper/params/models/author_array_collection_model.py similarity index 100% rename from tests/data/models/author_array_collection_model.py rename to tests/tests_mapper/params/models/author_array_collection_model.py diff --git a/tests/data/models/author_work_title_model.py b/tests/tests_mapper/params/models/author_work_title_model.py similarity index 100% rename from tests/data/models/author_work_title_model.py rename to tests/tests_mapper/params/models/author_work_title_model.py diff --git a/tests/data/models/basic_model.py b/tests/tests_mapper/params/models/basic_model.py similarity index 100% rename from tests/data/models/basic_model.py rename to tests/tests_mapper/params/models/basic_model.py diff --git a/tests/data/models/grouping_model.py b/tests/tests_mapper/params/models/grouping_model.py similarity index 100% rename from tests/data/models/grouping_model.py rename to tests/tests_mapper/params/models/grouping_model.py diff --git a/tests/data/models/nested_grouping_model.py b/tests/tests_mapper/params/models/nested_grouping_model.py similarity index 100% rename from tests/data/models/nested_grouping_model.py rename to tests/tests_mapper/params/models/nested_grouping_model.py diff --git a/tests/tests_mapper/test_model_bindings_mapper.py b/tests/tests_mapper/test_model_bindings_mapper.py index 4824674..52c47ed 100644 --- a/tests/tests_mapper/test_model_bindings_mapper.py +++ b/tests/tests_mapper/test_model_bindings_mapper.py @@ -1,10 +1,10 @@ """Pytest entry point for basic rdfproxy.mapper.ModelBindingsMapper.""" -from pydantic import BaseModel import pytest -from rdfproxy.mapper import ModelBindingsMapper -from tests.data.parameters.model_bindings_mapper_parameters import ( +from pydantic import BaseModel +from rdfproxy.mapper import ModelBindingsMapper +from tests.tests_mapper.params.model_bindings_mapper_parameters import ( author_array_collection_parameters, author_work_title_parameters, basic_parameters, diff --git 
a/tests/tests_mapper/test_model_bindings_mapper_model_bool.py b/tests/tests_mapper/test_model_bindings_mapper_model_bool.py index f9bde75..a14077b 100644 --- a/tests/tests_mapper/test_model_bindings_mapper_model_bool.py +++ b/tests/tests_mapper/test_model_bindings_mapper_model_bool.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from rdfproxy.mapper import ModelBindingsMapper -from tests.data.parameters.model_bindings_mapper_model_bool_parameters import ( +from tests.tests_mapper.params.model_bindings_mapper_model_bool_parameters import ( parent_child_parameters, ) diff --git a/tests/unit/test_construct_count_query.py b/tests/unit/test_construct_count_query.py deleted file mode 100644 index 94f7d10..0000000 --- a/tests/unit/test_construct_count_query.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Unit tests for rdfproxy.utils.sparql_utils.construct_count_query.""" - -import pytest - -from rdflib import Graph -from rdflib.plugins.sparql.processor import SPARQLResult -from rdfproxy.utils.sparql_utils import construct_count_query -from tests.data.parameters.count_query_parameters import ( - construct_count_query_parameters, -) - - -def _get_cnt_value_from_sparql_result( - result: SPARQLResult, count_var: str = "cnt" -) -> int: - """Get the 'cnt' binding of a count query from a SPARQLResult object.""" - return int(result.bindings[0][count_var]) - - -@pytest.mark.parametrize( - ["query", "model", "expected"], construct_count_query_parameters -) -def test_basic_construct_count_query(query, model, expected): - """Check the count of a grouped model. - - The count query constructed based on a grouped value must only count - distinct values according to the grouping specified in the model. - """ - - graph: Graph = Graph() - count_query: str = construct_count_query(query, model) - query_result: SPARQLResult = graph.query(count_query) - - cnt: int = _get_cnt_value_from_sparql_result(query_result) - - assert cnt == expected diff --git a/tests/unit/test_sad_path_get_bindings_from_query_result.py b/tests/unit/test_sad_path_get_bindings_from_query_result.py deleted file mode 100644 index 2bc655c..0000000 --- a/tests/unit/test_sad_path_get_bindings_from_query_result.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Sad path tests for rdfprox.utils.sparql_utils.get_bindings_from_query_result.""" - -from unittest import mock - -import pytest - -from rdfproxy.utils.sparql_utils import get_bindings_from_query_result - - -def test_basic_sad_path_get_bindings_from_query_result(): - with mock.patch("SPARQLWrapper.QueryResult") as mock_query_result: - mock_query_result.return_value.requestedFormat = "xml" - exception_message = ( - "Only QueryResult objects with JSON format are currently supported." 
- ) - with pytest.raises(Exception, match=exception_message): - get_bindings_from_query_result(mock_query_result) diff --git a/tests/unit/tests_sparql_utils/test_add_solution_modifier.py b/tests/unit/tests_sparql_utils/test_add_solution_modifier.py new file mode 100644 index 0000000..a276286 --- /dev/null +++ b/tests/unit/tests_sparql_utils/test_add_solution_modifier.py @@ -0,0 +1,72 @@ +from typing import NamedTuple + +import pytest +from rdfproxy.utils.sparql_utils import add_solution_modifier + + +class AddSolutionModifierParameter(NamedTuple): + query: str + parameters: dict + expected: str + + +parameters = [ + # basics + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"order_by": None, "limit": None, "offset": None}, + expected="prefix ns: select * where {?s ?p ?o }", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"order_by": None, "limit": None, "offset": 1}, + expected="prefix ns: select * where {?s ?p ?o } offset 1", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"order_by": None, "limit": 1, "offset": None}, + expected="prefix ns: select * where {?s ?p ?o } limit 1", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"order_by": "x", "limit": None, "offset": None}, + expected="prefix ns: select * where {?s ?p ?o } order by ?x", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"order_by": "x", "limit": 1, "offset": 1}, + expected="prefix ns: select * where {?s ?p ?o } order by ?x limit 1 offset 1", + ), + # order + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"offset": 1, "limit": 1, "order_by": None}, + expected="prefix ns: select * where {?s ?p ?o } limit 1 offset 1", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"offset": 1, "limit": 1, "order_by": None}, + expected="prefix ns: select * where {?s ?p ?o } limit 1 offset 1", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"offset": 1, "limit": None, "order_by": "x"}, + expected="prefix ns: select * where {?s ?p ?o } order by ?x offset 1", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"offset": None, "limit": 1, "order_by": "x"}, + expected="prefix ns: select * where {?s ?p ?o } order by ?x limit 1", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"offset": 1, "limit": 1, "order_by": "x"}, + expected="prefix ns: select * where {?s ?p ?o } order by ?x limit 1 offset 1", + ), +] + + +@pytest.mark.parametrize(["query", "parameters", "expected"], parameters) +def test_add_solution_modifier(query, parameters, expected): + modified_query = add_solution_modifier(query, **parameters) + assert modified_query == expected diff --git a/tests/unit/tests_sparql_utils/test_get_query_projection.py b/tests/unit/tests_sparql_utils/test_get_query_projection.py new file mode 100644 index 0000000..2459950 --- /dev/null +++ b/tests/unit/tests_sparql_utils/test_get_query_projection.py @@ -0,0 +1,63 @@ +"""Unit tests for sparql_utils.get_query_projection.""" + +from typing import NamedTuple + +import pytest + +from rdfproxy.utils.sparql_utils import get_query_projection + + +class QueryProjectionParameter(NamedTuple): + query: str + expected: list[str] 
+ + +parameters = [ + # explicit projection + QueryProjectionParameter( + query="select ?s ?p ?o where {?s ?p ?o}", expected=["s", "p", "o"] + ), + QueryProjectionParameter( + query=""" + PREFIX crm: + select ?s ?o where {?s ?p ?o}""", + expected=["s", "o"], + ), + # implicit projection + QueryProjectionParameter( + query="select * where {?s ?p ?o}", + expected=["s", "p", "o"], + ), + QueryProjectionParameter( + query=""" + PREFIX crm: + select * where {?s ?p ?o}""", + expected=["s", "p", "o"], + ), + # implicit projection with values clause + QueryProjectionParameter( + query=""" + select * where { + values (?s ?p ?o) + { (1 2 3) } + } + """, + expected=["s", "p", "o"], + ), + QueryProjectionParameter( + query=""" + PREFIX crm: + select * where { + values (?s ?p ?o) + { (1 2 3) } + } + """, + expected=["s", "p", "o"], + ), +] + + +@pytest.mark.parametrize(["query", "expected"], parameters) +def test_get_query_projection(query, expected): + projection = [str(binding) for binding in get_query_projection(query)] + assert projection == expected diff --git a/tests/unit/test_inject_subquery.py b/tests/unit/tests_sparql_utils/test_inject_subquery.py similarity index 95% rename from tests/unit/test_inject_subquery.py rename to tests/unit/tests_sparql_utils/test_inject_subquery.py index e95494d..a80765c 100644 --- a/tests/unit/test_inject_subquery.py +++ b/tests/unit/tests_sparql_utils/test_inject_subquery.py @@ -3,7 +3,8 @@ from typing import NamedTuple import pytest -from rdfproxy.utils.sparql_utils import inject_subquery + +from rdfproxy.utils.sparql_utils import inject_into_query, remove_sparql_prefixes class InjectSubqueryParameter(NamedTuple): @@ -118,5 +119,6 @@ class InjectSubqueryParameter(NamedTuple): @pytest.mark.parametrize(["query", "subquery", "expected"], inject_subquery_parameters) def test_inject_subquery(query, subquery, expected): - injected = inject_subquery(query=query, subquery=subquery) + injectant = remove_sparql_prefixes(subquery) + injected = inject_into_query(query=query, injectant=injectant) assert injected == expected diff --git a/tests/unit/tests_sparql_utils/test_remove_sparql_prefixes.py b/tests/unit/tests_sparql_utils/test_remove_sparql_prefixes.py new file mode 100644 index 0000000..e26e7b0 --- /dev/null +++ b/tests/unit/tests_sparql_utils/test_remove_sparql_prefixes.py @@ -0,0 +1,61 @@ +from typing import NamedTuple + +import pytest + +from rdfproxy.utils.sparql_utils import remove_sparql_prefixes +from tests.utils.utils import normalize_query + + +class SPARQLRemovePrefixParameter(NamedTuple): + query: str + expected: str + + +parameters = [ + SPARQLRemovePrefixParameter( + query=""" + select * where { ?s ?p ?o .} + """, + expected="select * where { ?s ?p ?o . }", + ), + SPARQLRemovePrefixParameter( + query=""" + prefix ns: + prefix other_ns: + select * where { ?s ?p ?o .} + """, + expected="select * where { ?s ?p ?o . }", + ), + SPARQLRemovePrefixParameter( + query=""" + prefix ns: prefix other_ns: + select * where { ?s ?p ?o .} + """, + expected="select * where { ?s ?p ?o . }", + ), + SPARQLRemovePrefixParameter( + query=""" + prefix ns: + + prefix other_ns: + select * where { ?s ?p ?o .} + """, + expected="select * where { ?s ?p ?o . }", + ), + SPARQLRemovePrefixParameter( + query=""" + prefix ns: + + prefix other_ns: + + select * where { ?s ?p ?o .} + """, + expected="select * where { ?s ?p ?o . 
}", + ), +] + + +@pytest.mark.parametrize(["query", "expected"], parameters) +def test_remove_sparql_prefixes(query, expected): + modified_query = remove_sparql_prefixes(query) + assert normalize_query(modified_query) == expected diff --git a/tests/unit/test_replace_query_select_clause.py b/tests/unit/tests_sparql_utils/test_replace_query_select_clause.py similarity index 100% rename from tests/unit/test_replace_query_select_clause.py rename to tests/unit/tests_sparql_utils/test_replace_query_select_clause.py diff --git a/tests/unit/test_sad_path_replace_query_select_clause.py b/tests/unit/tests_sparql_utils/test_sad_path_replace_query_select_clause.py similarity index 100% rename from tests/unit/test_sad_path_replace_query_select_clause.py rename to tests/unit/tests_sparql_utils/test_sad_path_replace_query_select_clause.py diff --git a/tests/unit/tests_utils/test_field_bindings_map.py b/tests/unit/tests_utils/test_field_bindings_map.py new file mode 100644 index 0000000..0364e7c --- /dev/null +++ b/tests/unit/tests_utils/test_field_bindings_map.py @@ -0,0 +1,25 @@ +"""Basic unit tests for FieldBindingsMap""" + +from typing import Annotated + +from pydantic import BaseModel +from rdfproxy.utils._types import SPARQLBinding +from rdfproxy.utils.utils import FieldsBindingsMap + + +class Point(BaseModel): + x: int + y: Annotated[int, SPARQLBinding("Y_ALIAS")] + z: Annotated[list[int], SPARQLBinding("Z_ALIAS")] + + +def test_basic_fields_bindings_map(): + mapping = FieldsBindingsMap(model=Point) + + assert mapping["x"] == "x" + assert mapping["y"] == "Y_ALIAS" + assert mapping["z"] == "Z_ALIAS" + + assert mapping.reverse["x"] == "x" + assert mapping.reverse["Y_ALIAS"] == "y" + assert mapping.reverse["Z_ALIAS"] == "z" diff --git a/tests/utils/utils.py b/tests/utils/utils.py new file mode 100644 index 0000000..deb354d --- /dev/null +++ b/tests/utils/utils.py @@ -0,0 +1,11 @@ +"""Testing utils.""" + +import re + + +def normalize_query(select_query: str) -> str: + """Normalize whitespace chars in a SPARQL query.""" + normalized_select_query = re.sub( + r"(?