diff --git a/rdfproxy/adapter.py b/rdfproxy/adapter.py
index ad98717..ebe7da3 100644
--- a/rdfproxy/adapter.py
+++ b/rdfproxy/adapter.py
@@ -4,15 +4,11 @@
 import math
 from typing import Generic
 
+from rdfproxy.constructor import QueryConstructor
 from rdfproxy.mapper import ModelBindingsMapper
 from rdfproxy.sparql_strategies import HttpxStrategy, SPARQLStrategy
 from rdfproxy.utils._types import _TModelInstance
 from rdfproxy.utils.models import Page, QueryParameters
-from rdfproxy.utils.sparql_utils import (
-    calculate_offset,
-    construct_count_query,
-    construct_items_query,
-)
 
 
 class SPARQLModelAdapter(Generic[_TModelInstance]):
@@ -24,10 +20,12 @@ class SPARQLModelAdapter(Generic[_TModelInstance]):
     SPARQLModelAdapter.query returns a Page model object
     with a default pagination size of 100 results.
 
     SPARQL bindings are implicitly assigned to model fields of the same name,
-    explicit SPARQL binding to model field allocation is available with typing.Annotated and rdfproxy.SPARQLBinding.
+    explicit SPARQL binding to model field allocation is available with rdfproxy.SPARQLBinding.
 
     Result grouping is controlled through the model,
     i.e. grouping is triggered when a field of list[pydantic.BaseModel] is encountered.
+
+    See https://github.com/acdh-oeaw/rdfproxy/tree/main/examples for examples.
     """
 
     def __init__(
@@ -44,20 +42,21 @@ def __init__(
 
     def query(self, query_parameters: QueryParameters) -> Page[_TModelInstance]:
         """Run a query against an endpoint and return a Page model object."""
-        count_query: str = construct_count_query(query=self._query, model=self._model)
-        items_query: str = construct_items_query(
+        query_constructor = QueryConstructor(
             query=self._query,
+            query_parameters=query_parameters,
             model=self._model,
-            limit=query_parameters.size,
-            offset=calculate_offset(query_parameters.page, query_parameters.size),
         )
-        items_query_bindings: Iterator[dict] = self.sparql_strategy.query(items_query)
+        count_query = query_constructor.get_count_query()
+        items_query = query_constructor.get_items_query()
 
+        items_query_bindings: Iterator[dict] = self.sparql_strategy.query(items_query)
         mapper = ModelBindingsMapper(self._model, *items_query_bindings)
         items: list[_TModelInstance] = mapper.get_models()
-        total: int = self._get_count(count_query)
+
+        count_query_bindings: Iterator[dict] = self.sparql_strategy.query(count_query)
+        total: int = int(next(count_query_bindings)["cnt"])
 
         pages: int = math.ceil(total / query_parameters.size)
         return Page(
@@ -67,11 +66,3 @@ def query(self, query_parameters: QueryParameters) -> Page[_TModelInstance]:
             total=total,
             pages=pages,
         )
-
-    def _get_count(self, query: str) -> int:
-        """Run a count query and return the count result.
-
-        Helper for SPARQLModelAdapter.query.
-        """
-        result: Iterator[dict] = self.sparql_strategy.query(query)
-        return int(next(result)["cnt"])
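The refactored flow delegates all query construction to QueryConstructor and reads the total directly from the ?cnt binding of the count query. A minimal usage sketch; the endpoint, model, and constructor parameter names are assumptions, since __init__ is elided in this hunk:

    from rdfproxy import SPARQLModelAdapter
    from rdfproxy.utils.models import QueryParameters

    adapter = SPARQLModelAdapter(
        target="https://example.org/sparql",  # hypothetical endpoint
        query="select ?name where { ?s ?p ?name }",
        model=SomeModel,  # hypothetical pydantic model
    )
    page = adapter.query(QueryParameters(page=1, size=100))  # -> Page[SomeModel]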
- """ - result: Iterator[dict] = self.sparql_strategy.query(query) - return int(next(result)["cnt"]) diff --git a/rdfproxy/constructor.py b/rdfproxy/constructor.py new file mode 100644 index 0000000..8054bd3 --- /dev/null +++ b/rdfproxy/constructor.py @@ -0,0 +1,124 @@ +from rdfproxy.utils._types import _TModelInstance +from rdfproxy.utils.models import QueryParameters +from rdfproxy.utils.sparql_utils import ( + add_solution_modifier, + get_query_projection, + inject_into_query, + remove_sparql_prefixes, + replace_query_select_clause, +) +from rdfproxy.utils.utils import ( + FieldsBindingsMap, + QueryConstructorComponent as component, + compose_left, +) + + +class QueryConstructor: + """The class encapsulates dynamic SPARQL query modification logic + for implementing purely SPARQL-based, deterministic pagination. + + Public methods get_items_query and get_count_query are used in rdfproxy.SPARQLModelAdapter + to construct queries for retrieving arguments for Page object instantiation. + """ + + def __init__( + self, + query: str, + query_parameters: QueryParameters, + model: type[_TModelInstance], + ) -> None: + self.query = query + self.query_parameters = query_parameters + self.model = model + + self.bindings_map = FieldsBindingsMap(model) + self.group_by: str | None = self.bindings_map.get( + model.model_config.get("group_by") + ) + + def get_items_query(self) -> str: + """Construct a SPARQL items query for use in rdfproxy.SPARQLModelAdapter.""" + if self.group_by is None: + return self._get_ungrouped_items_query() + return self._get_grouped_items_query() + + def get_count_query(self) -> str: + """Construct a SPARQL count query for use in rdfproxy.SPARQLModelAdapter""" + if self.group_by is None: + select_clause = "select (count(*) as ?cnt)" + else: + select_clause = f"select (count(distinct ?{self.group_by}) as ?cnt)" + + return replace_query_select_clause(self.query, select_clause) + + @staticmethod + def _calculate_offset(page: int, size: int) -> int: + """Calculate the offset value for paginated SPARQL templates.""" + match page: + case 1: + return 0 + case 2: + return size + case _: + return size * (page - 1) + + def _get_grouped_items_query(self) -> str: + """Construct a SPARQL items query for grouped models.""" + filter_clause: str | None = self._compute_filter_clause() + select_clause: str = self._compute_select_clause() + order_by_value: str = self._compute_order_by_value() + limit, offset = self._compute_limit_offset() + + subquery = compose_left( + remove_sparql_prefixes, + component(replace_query_select_clause, repl=select_clause), + component(inject_into_query, injectant=filter_clause), + component( + add_solution_modifier, + order_by=order_by_value, + limit=limit, + offset=offset, + ), + )(self.query) + + return inject_into_query(self.query, subquery) + + def _get_ungrouped_items_query(self) -> str: + """Construct a SPARQL items query for ungrouped models.""" + filter_clause: str | None = self._compute_filter_clause() + order_by_value: str = self._compute_order_by_value() + limit, offset = self._compute_limit_offset() + + return compose_left( + component(inject_into_query, injectant=filter_clause), + component( + add_solution_modifier, + order_by=order_by_value, + limit=limit, + offset=offset, + ), + )(self.query) + + def _compute_limit_offset(self) -> tuple[int, int]: + """Calculate limit and offset values for SPARQL-based pagination.""" + limit = self.query_parameters.size + offset = self._calculate_offset( + self.query_parameters.page, self.query_parameters.size + ) + + 
+
+    @staticmethod
+    def _calculate_offset(page: int, size: int) -> int:
+        """Calculate the offset value for paginated SPARQL templates."""
+        match page:
+            case 1:
+                return 0
+            case 2:
+                return size
+            case _:
+                return size * (page - 1)
+
+    def _get_grouped_items_query(self) -> str:
+        """Construct a SPARQL items query for grouped models."""
+        filter_clause: str | None = self._compute_filter_clause()
+        select_clause: str = self._compute_select_clause()
+        order_by_value: str = self._compute_order_by_value()
+        limit, offset = self._compute_limit_offset()
+
+        subquery = compose_left(
+            remove_sparql_prefixes,
+            component(replace_query_select_clause, repl=select_clause),
+            component(inject_into_query, injectant=filter_clause),
+            component(
+                add_solution_modifier,
+                order_by=order_by_value,
+                limit=limit,
+                offset=offset,
+            ),
+        )(self.query)
+
+        return inject_into_query(self.query, subquery)
+
+    def _get_ungrouped_items_query(self) -> str:
+        """Construct a SPARQL items query for ungrouped models."""
+        filter_clause: str | None = self._compute_filter_clause()
+        order_by_value: str = self._compute_order_by_value()
+        limit, offset = self._compute_limit_offset()
+
+        return compose_left(
+            component(inject_into_query, injectant=filter_clause),
+            component(
+                add_solution_modifier,
+                order_by=order_by_value,
+                limit=limit,
+                offset=offset,
+            ),
+        )(self.query)
+
+    def _compute_limit_offset(self) -> tuple[int, int]:
+        """Calculate limit and offset values for SPARQL-based pagination."""
+        limit = self.query_parameters.size
+        offset = self._calculate_offset(
+            self.query_parameters.page, self.query_parameters.size
+        )
+
+        return limit, offset
+
+    def _compute_filter_clause(self) -> str | None:
+        """Stub: Always None for now."""
+        return None
+
+    def _compute_select_clause(self) -> str:
+        """Stub: Static SELECT clause for now."""
+        return f"select distinct ?{self.group_by}"
+
+    def _compute_order_by_value(self) -> str:
+        """Stub: Only basic logic for now."""
+        if self.group_by is None:
+            return get_query_projection(self.query)[0]
+        return f"{self.group_by}"
diff --git a/rdfproxy/mapper.py b/rdfproxy/mapper.py
index da10054..7c3d56d 100644
--- a/rdfproxy/mapper.py
+++ b/rdfproxy/mapper.py
@@ -5,7 +5,7 @@
 from pydantic import BaseModel
 
 from rdfproxy.utils._types import ModelBoolPredicate, _TModelInstance
-from rdfproxy.utils.utils import (
+from rdfproxy.utils.mapper_utils import (
     _collect_values_from_bindings,
     _get_group_by,
     _get_key_from_metadata,
@@ -65,7 +65,6 @@ def _generate_binding_pairs(
                 and (x[self._contexts[0]] == kwargs[self._contexts[0]]),
                 self.bindings,
             )
-
             value = self._get_unique_models(group_model, applicable_bindings)
         elif _is_list_type(v.annotation):
diff --git a/rdfproxy/sparql_strategies.py b/rdfproxy/sparql_strategies.py
index 6b61860..da26d41 100644
--- a/rdfproxy/sparql_strategies.py
+++ b/rdfproxy/sparql_strategies.py
@@ -2,6 +2,7 @@
 import abc
 from collections.abc import Iterator
+from typing import cast
 
 from SPARQLWrapper import JSON, QueryResult, SPARQLWrapper
 import httpx
@@ -13,7 +14,7 @@ def __init__(self, endpoint: str):
 
     @abc.abstractmethod
     def query(self, sparql_query: str) -> Iterator[dict[str, str]]:
-        raise NotImplementedError
+        raise NotImplementedError  # pragma: no cover
 
     @staticmethod
     def _get_bindings_from_bindings_dict(bindings_dict: dict) -> Iterator[dict]:
@@ -35,7 +36,9 @@ def query(self, sparql_query: str) -> Iterator[dict[str, str]]:
         self._sparql_wrapper.setQuery(sparql_query)
         result: QueryResult = self._sparql_wrapper.query()
 
-        return self._get_bindings_from_bindings_dict(result.convert())
+        # SPARQLWrapper.Wrapper.convert is not overloaded properly and needs casting
+        # https://github.com/RDFLib/sparqlwrapper/blob/master/SPARQLWrapper/Wrapper.py#L1135
+        return self._get_bindings_from_bindings_dict(cast(dict, result.convert()))
 
 
 class HttpxStrategy(SPARQLStrategy):
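Because SPARQLStrategy only demands a query method yielding plain binding dicts, alternative transports are easy to sketch. A minimal httpx-based subclass, assuming the base __init__ (not shown in this hunk) stores the endpoint as self.endpoint and that the endpoint speaks the standard SPARQL protocol with application/sparql-results+json results:

    class PlainHttpxStrategy(SPARQLStrategy):  # hypothetical subclass
        def query(self, sparql_query: str) -> Iterator[dict[str, str]]:
            response = httpx.post(
                self.endpoint,
                data={"query": sparql_query},
                headers={"Accept": "application/sparql-results+json"},
            )
            response.raise_for_status()
            # reuse the shared bindings extraction from the base class
            return self._get_bindings_from_bindings_dict(response.json())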
diff --git a/rdfproxy/utils/mapper_utils.py b/rdfproxy/utils/mapper_utils.py
new file mode 100644
index 0000000..acce14e
--- /dev/null
+++ b/rdfproxy/utils/mapper_utils.py
@@ -0,0 +1,127 @@
+from collections.abc import Callable, Iterable
+from typing import Any, TypeGuard, get_args, get_origin
+
+from pydantic import BaseModel
+from pydantic.fields import FieldInfo
+from rdfproxy.utils._exceptions import (
+    InvalidGroupingKeyException,
+    MissingModelConfigException,
+)
+from rdfproxy.utils._types import _TModelInstance
+from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue
+
+
+def _is_type(obj: type | None, _type: type) -> bool:
+    """Check if an obj is type _type or a GenericAlias with origin _type."""
+    return (obj is _type) or (get_origin(obj) is _type)
+
+
+def _is_list_type(obj: type | None) -> bool:
+    """Check if obj is a list type."""
+    return _is_type(obj, list)
+
+
+def _is_list_basemodel_type(obj: type | None) -> bool:
+    """Check if a type is list[pydantic.BaseModel]."""
+    return (get_origin(obj) is list) and all(
+        issubclass(cls, BaseModel) for cls in get_args(obj)
+    )
+
+
+def _collect_values_from_bindings(
+    binding_name: str,
+    bindings: Iterable[dict],
+    predicate: Callable[[Any], bool] = lambda x: x is not None,
+) -> list:
+    """Scan bindings for a key binding_name and collect unique predicate-compliant values.
+
+    Note that element order is important for testing, so a set cast won't do.
+    """
+    values = dict.fromkeys(
+        value
+        for binding in bindings
+        if predicate(value := binding.get(binding_name, None))
+    )
+    return list(values)
+
+
+def _get_key_from_metadata(v: FieldInfo, *, default: Any) -> str | Any:
+    """Try to get a SPARQLBinding object from a field's metadata attribute.
+
+    Helper for _generate_binding_pairs.
+    """
+    return next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata), default)
+
+
+def _get_applicable_grouping_keys(model: type[_TModelInstance]) -> list[str]:
+    return [k for k, v in model.model_fields.items() if not _is_list_type(v.annotation)]
+
+
+def _get_group_by(model: type[_TModelInstance]) -> str:
+    """Get the name of a grouping key from a model Config class."""
+    try:
+        group_by = model.model_config["group_by"]  # type: ignore
+    except KeyError as e:
+        raise MissingModelConfigException(
+            "Model config with 'group_by' value required "
+            "for field-based grouping behavior."
+        ) from e
+    else:
+        applicable_keys = _get_applicable_grouping_keys(model=model)
+
+        if group_by not in applicable_keys:
+            raise InvalidGroupingKeyException(
+                f"Invalid grouping key '{group_by}'. "
+                f"Applicable grouping keys: {', '.join(applicable_keys)}."
+            )
+
+        if meta := model.model_fields[group_by].metadata:
+            if binding := next(
+                filter(lambda entry: isinstance(entry, SPARQLBinding), meta), None
+            ):
+                return binding
+        return group_by
+
+
+def default_model_bool_predicate(model: BaseModel) -> bool:
+    """Default predicate for determining model truthiness.
+
+    Adheres to rdfproxy.utils._types.ModelBoolPredicate.
+    """
+    return any(dict(model).values())
+
+
+def _is_iterable_of_str(iterable: Iterable) -> TypeGuard[Iterable[str]]:
+    return (not isinstance(iterable, str)) and all(
+        map(lambda i: isinstance(i, str), iterable)
+    )
+
+
+def _get_model_bool_predicate_from_config_value(
+    model_bool_value: _TModelBoolValue,
+) -> ModelBoolPredicate:
+    """Get a model_bool predicate function given the value of the model_bool config setting."""
+    match model_bool_value:
+        case ModelBoolPredicate():
+            return model_bool_value
+        case str():
+            return lambda model: bool(dict(model)[model_bool_value])
+        case model_bool_value if _is_iterable_of_str(model_bool_value):
+            return lambda model: all(map(lambda k: dict(model)[k], model_bool_value))
+        case _:  # pragma: no cover
+            raise TypeError(
+                "Argument for 'model_bool' must be of type ModelBoolPredicate | str | Iterable[str].\n"
+                f"Received {type(model_bool_value)}"
+            )
+
+
+def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate:
+    """Get the applicable model_bool predicate function given a model."""
+    if (model_bool_value := model.model_config.get("model_bool", None)) is None:
+        model_bool_predicate = default_model_bool_predicate
+    else:
+        model_bool_predicate = _get_model_bool_predicate_from_config_value(
+            model_bool_value
+        )
+
+    return model_bool_predicate
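The grouping and truthiness helpers above are driven entirely by model config values. A hypothetical model pair illustrating both settings (names invented; rdfproxy reads the extra ConfigDict keys at runtime):

    from pydantic import BaseModel, ConfigDict

    class Work(BaseModel):
        title: str

    class Author(BaseModel):
        model_config = ConfigDict(group_by="name", model_bool="name")

        name: str
        works: list[Work]

_get_group_by validates that "name" is a non-list field and resolves a SPARQLBinding alias if one is set; get_model_bool_predicate turns the "name" string into a predicate checking that field's truthiness.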
diff --git a/rdfproxy/utils/sparql_utils.py b/rdfproxy/utils/sparql_utils.py
index a064898..4eafea3 100644
--- a/rdfproxy/utils/sparql_utils.py
+++ b/rdfproxy/utils/sparql_utils.py
@@ -1,19 +1,13 @@
 """Functionality for dynamic SPARQL query modification."""
 
-from collections.abc import Iterator
-from contextlib import contextmanager
-from functools import partial
+from itertools import chain
 import re
-from typing import cast
+from typing import overload
 
-from SPARQLWrapper import QueryResult, SPARQLWrapper
+from rdflib import Variable
+from rdflib.plugins.sparql.parser import parseQuery
+from rdflib.plugins.sparql.parserutils import CompValue, ParseResults
 from rdfproxy.utils._exceptions import QueryConstructionException
-from rdfproxy.utils._types import ItemsQueryConstructor, SPARQLBinding, _TModelInstance
-
-
-def construct_ungrouped_pagination_query(query: str, limit: int, offset: int) -> str:
-    """Construct an ungrouped pagination query."""
-    return f"{query} limit {limit} offset {offset}"
 
 
 def replace_query_select_clause(query: str, repl: str) -> str:
@@ -23,7 +17,7 @@ def replace_query_select_clause(query: str, repl: str) -> str:
     )
 
     if re.search(pattern=pattern, string=query) is None:
-        raise Exception("Unable to obtain SELECT clause.")
+        raise QueryConstructionException("Unable to obtain SELECT clause.")
 
     modified_query = re.sub(
         pattern=pattern,
@@ -35,148 +29,90 @@ def replace_query_select_clause(query: str, repl: str) -> str:
     return modified_query
 
 
-def _remove_sparql_prefixes(query: str) -> str:
+def remove_sparql_prefixes(query: str) -> str:
     """Remove SPARQL prefixes from a query.
 
     This is needed for subquery injection, because subqueries cannot have prefixes.
-    Note that this is not generic, all prefixes are simply ut from the subquery
-    and do not get appended to the outer query prefixes.
+    Note that this is not generic, all prefixes are simply cut from the subquery
+    and are not resolved against the outer query prefixes.
     """
-    prefix_pattern = re.compile(r"PREFIX\s+\w*:\s?<[^>]+>\s*", flags=re.I)
+    prefix_pattern = re.compile(r"PREFIX\s+\w*:\s?<[^>]+>\s*", flags=re.IGNORECASE)
    cleaned_query = re.sub(prefix_pattern, "", query).strip()
     return cleaned_query
 
 
-def inject_subquery(query: str, subquery: str) -> str:
-    """Inject a SPARQL query with a subquery."""
+def inject_into_query(query: str, injectant: str) -> str:
+    """Inject some injectant (e.g. a subquery or filter clause) into a query."""
     if (tail := re.search(r"}[^}]*\Z", query)) is None:
-        raise QueryConstructionException("Unable to inject subquery.")
+        raise QueryConstructionException(
+            "Unable to inject subquery."
+        )  # pragma: no cover ; this will be unreachable once query checking runs
 
     tail_index: int = tail.start()
-    injected: str = f"{query[:tail_index]} {{{_remove_sparql_prefixes(subquery)}}} {query[tail_index:]}"
-    return injected
+    injected_query: str = f"{query[:tail_index]} {{{injectant}}} {query[tail_index:]}"
+    return injected_query
 
 
-def construct_grouped_pagination_query(
-    query: str, group_by_value: str, limit: int, offset: int
+def add_solution_modifier(
+    query: str,
+    *,
+    order_by: str | None = None,
+    limit: int | None = None,
+    offset: int | None = None,
 ) -> str:
-    """Construct a grouped pagination query."""
-    _subquery_base: str = replace_query_select_clause(
-        query=query, repl=f"select distinct ?{group_by_value}"
-    )
-    subquery: str = construct_ungrouped_pagination_query(
-        query=_subquery_base, limit=limit, offset=offset
-    )
-
-    grouped_pagination_query: str = inject_subquery(query=query, subquery=subquery)
-    return grouped_pagination_query
+    """Add optional solution modifiers in SPARQL-conformant order to a query."""
+    modifiers = []
+
+    if order_by is not None:
+        modifiers.append(f"order by ?{order_by}")
+    if limit is not None:
+        modifiers.append(f"limit {limit}")
+    if offset is not None:
+        modifiers.append(f"offset {offset}")
+
+    return f"{query} {' '.join(modifiers)}".strip()
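The modifier ordering is easiest to see on a toy query:

    add_solution_modifier(
        "select ?s where { ?s ?p ?o }",
        order_by="s",
        limit=10,
        offset=20,
    )
    # -> 'select ?s where { ?s ?p ?o } order by ?s limit 10 offset 20'

With all three keyword arguments left as None, the query is returned unchanged apart from whitespace stripping.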
 
 
-def get_items_query_constructor(
-    model: type[_TModelInstance],
-) -> ItemsQueryConstructor:
-    """Get the applicable query constructor function given a model class."""
-
-    if (group_by_value := model.model_config.get("group_by", None)) is None:
-        return construct_ungrouped_pagination_query
-    elif meta := model.model_fields[group_by_value].metadata:
-        group_by_value = next(
-            filter(lambda x: isinstance(x, SPARQLBinding), meta), group_by_value
-        )
-
-    return partial(construct_grouped_pagination_query, group_by_value=group_by_value)
-
-
-def construct_items_query(
-    query: str, model: type[_TModelInstance], limit: int, offset: int
-) -> str:
-    """Construct a grouped pagination query."""
-    items_query_constructor: ItemsQueryConstructor = get_items_query_constructor(
-        model=model
-    )
-    return items_query_constructor(query=query, limit=limit, offset=offset)
-
-
-def construct_count_query(query: str, model: type[_TModelInstance]) -> str:
-    """Construct a generic count query from a SELECT query."""
-    try:
-        group_by: str = model.model_config["group_by"]
-        group_by_binding = next(
-            filter(
-                lambda x: isinstance(x, SPARQLBinding),
-                model.model_fields[group_by].metadata,
-            ),
-            group_by,
-        )
-        count_query = construct_grouped_count_query(query, group_by_binding)
-    except KeyError:
-        count_query = replace_query_select_clause(query, "select (count(*) as ?cnt)")
-
-    return count_query
-
-
-def calculate_offset(page: int, size: int) -> int:
-    """Calculate offset value for paginated SPARQL templates."""
-    match page:
-        case 1:
-            return 0
-        case 2:
-            return size
-        case _:
-            return size * (page - 1)
-
-
-def construct_grouped_count_query(query: str, group_by) -> str:
-    grouped_count_query = replace_query_select_clause(
-        query, f"select (count(distinct ?{group_by}) as ?cnt)"
-    )
-    return grouped_count_query
-
-
-def _get_bindings_from_bindings_dict(bindings_dict: dict) -> Iterator[dict]:
-    bindings = map(
-        lambda binding: {k: v["value"] for k, v in binding.items()},
-        bindings_dict["results"]["bindings"],
-    )
-    return bindings
-
-
-def get_bindings_from_query_result(query_result: QueryResult) -> Iterator[dict]:
-    """Extract just the bindings from a SPARQLWrapper.QueryResult."""
-    if (result_format := query_result.requestedFormat) != "json":
-        raise Exception(
-            "Only QueryResult objects with JSON format are currently supported. "
-            f"Received object with requestedFormat '{result_format}'."
-        )
-
-    query_json: dict = cast(dict, query_result.convert())
-    bindings = _get_bindings_from_bindings_dict(query_json)
-
-    return bindings
-
-
-@contextmanager
-def temporary_query_override(sparql_wrapper: SPARQLWrapper):
-    """Context manager that allows to contextually overwrite a query in a SPARQLWrapper object."""
-    _query_cache = sparql_wrapper.queryString
-
-    try:
-        yield sparql_wrapper
-    finally:
-        sparql_wrapper.setQuery(_query_cache)
-
-
-def query_with_wrapper(query: str, sparql_wrapper: SPARQLWrapper) -> Iterator[dict]:
-    """Execute a SPARQL query using a predefined sparql_wrapper object.
-
-    The query attribute of the wrapper object is temporarily overridden
-    and gets restored after query execution.
-    """
-    with temporary_query_override(sparql_wrapper=sparql_wrapper):
-        sparql_wrapper.setQuery(query)
-        result: QueryResult = sparql_wrapper.query()
-
-    bindings: Iterator[dict] = get_bindings_from_query_result(result)
-    return bindings
+
+
+@overload
+def _compvalue_to_dict(comp_value: dict | CompValue) -> dict: ...
+
+
+@overload
+def _compvalue_to_dict(comp_value: list | ParseResults) -> list: ...
+
+
+def _compvalue_to_dict(comp_value: CompValue):
+    """Convert a CompValue parsing object into a Python dict/list representation.
+
+    Helper for get_query_projection.
+    """
+    if isinstance(comp_value, dict):
+        return {key: _compvalue_to_dict(value) for key, value in comp_value.items()}
+    elif isinstance(comp_value, list | ParseResults):
+        return [_compvalue_to_dict(item) for item in comp_value]
+    else:
+        return comp_value
+
+
+def get_query_projection(query: str) -> list[Variable]:
+    """Parse a SPARQL SELECT query and extract the ordered bindings projection.
+
+    The first case handles explicit/literal binding projections.
+    The second case handles implicit/* binding projections.
+    The third case handles implicit/* binding projections with VALUES.
+    """
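A quick sketch of the first two projection cases (assuming rdflib's parse tree shapes as matched below; Variable reprs abbreviated):

    get_query_projection("select ?s ?p where { ?s ?p ?o }")
    # -> [Variable('s'), Variable('p')]                    explicit projection

    get_query_projection("select * where { ?s ?p ?o }")
    # -> [Variable('s'), Variable('p'), Variable('o')]     implicit projection, in triple order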
""" - with temporary_query_override(sparql_wrapper=sparql_wrapper): - sparql_wrapper.setQuery(query) - result: QueryResult = sparql_wrapper.query() - - bindings: Iterator[dict] = get_bindings_from_query_result(result) - return bindings + _parse_result: CompValue = parseQuery(query)[1] + parsed_query: dict = _compvalue_to_dict(_parse_result) + + match parsed_query: + case {"projection": projection}: + return [i["var"] for i in projection] + case {"where": {"part": [{"triples": triples}]}}: + projection = dict.fromkeys( + i for i in chain.from_iterable(triples) if isinstance(i, Variable) + ) + return list(projection) + case {"where": {"part": [{"var": var}]}}: + return var + case _: # pragma: no cover + raise Exception("Unable to obtain query projection.") diff --git a/rdfproxy/utils/utils.py b/rdfproxy/utils/utils.py index e7e4c12..279d19c 100644 --- a/rdfproxy/utils/utils.py +++ b/rdfproxy/utils/utils.py @@ -1,129 +1,72 @@ """SPARQL/FastAPI utils.""" -from collections.abc import Callable, Iterable -from typing import Any, TypeGuard, get_args, get_origin - -from pydantic import BaseModel -from pydantic.fields import FieldInfo -from rdfproxy.utils._exceptions import ( - InvalidGroupingKeyException, - MissingModelConfigException, -) -from rdfproxy.utils._types import _TModelInstance -from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue - - -def _is_type(obj: type | None, _type: type) -> bool: - """Check if an obj is type _type or a GenericAlias with origin _type.""" - return (obj is _type) or (get_origin(obj) is _type) +from collections import UserDict +from collections.abc import Callable +from functools import partial +from typing import TypeVar +from rdfproxy.utils._types import _TModelInstance +from rdfproxy.utils._types import SPARQLBinding -def _is_list_type(obj: type | None) -> bool: - """Check if obj is a list type.""" - return _is_type(obj, list) +T = TypeVar("T") -def _is_list_basemodel_type(obj: type | None) -> bool: - """Check if a type is list[pydantic.BaseModel].""" - return (get_origin(obj) is list) and all( - issubclass(cls, BaseModel) for cls in get_args(obj) - ) +class FieldsBindingsMap(UserDict): + """Mapping for resolving SPARQLBinding aliases. -def _collect_values_from_bindings( - binding_name: str, - bindings: Iterable[dict], - predicate: Callable[[Any], bool] = lambda x: x is not None, -) -> list: - """Scan bindings for a key binding_name and collect unique predicate-compliant values. + Model field names are mapped to SPARQLBinding names. + The FieldsBindingsMap.reverse allows reverse lookup + (i.e. from SPARQLBindings to model fields). - Note that element order is important for testing, so a set cast won't do. + Note: It might be useful to recursively resolve aliases for nested models. """ - values = dict.fromkeys( - value - for binding in bindings - if predicate(value := binding.get(binding_name, None)) - ) - return list(values) + def __init__(self, model: type[_TModelInstance]) -> None: + self.data = self._get_field_binding_mapping(model) + self._reversed = {v: k for k, v in self.data.items()} -def _get_key_from_metadata(v: FieldInfo, *, default: Any) -> str | Any: - """Try to get a SPARQLBinding object from a field's metadata attribute. + @property + def reverse(self) -> dict[str, str]: + """Reverse lookup map from SPARQL bindings to model fields.""" + return self._reversed - Helper for _generate_binding_pairs. 
- """ - return next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata), default) + @staticmethod + def _get_field_binding_mapping(model: type[_TModelInstance]) -> dict[str, str]: + """Resolve model fields against rdfproxy.SPARQLBindings.""" + return { + k: next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata), k) + for k, v in model.model_fields.items() + } -def _get_applicable_grouping_keys(model: type[_TModelInstance]) -> list[str]: - return [k for k, v in model.model_fields.items() if not _is_list_type(v.annotation)] +def compose_left(*fns: Callable[[T], T]) -> Callable[[T], T]: + """Left associative compose.""" + def _left_wrapper(*fns): + fn, *rest_fns = fns -def _get_group_by(model: type[_TModelInstance]) -> str: - """Get the name of a grouping key from a model Config class.""" - try: - group_by = model.model_config["group_by"] # type: ignore - except KeyError as e: - raise MissingModelConfigException( - "Model config with 'group_by' value required " - "for field-based grouping behavior." - ) from e - else: - applicable_keys = _get_applicable_grouping_keys(model=model) + if rest_fns: + return lambda *args, **kwargs: fn(_left_wrapper(*rest_fns)(*args, **kwargs)) + return fn - if group_by not in applicable_keys: - raise InvalidGroupingKeyException( - f"Invalid grouping key '{group_by}'. " - f"Applicable grouping keys: {', '.join(applicable_keys)}." - ) + return _left_wrapper(*reversed(fns)) - if meta := model.model_fields[group_by].metadata: - if binding := next( - filter(lambda entry: isinstance(entry, SPARQLBinding), meta), None - ): - return binding - return group_by +class QueryConstructorComponent: + """Query modification component factory. -def default_model_bool_predicate(model: BaseModel) -> bool: - """Default predicate for determining model truthiness. + Components either call the wrapped function with non-None value kwargs applied + or (if all kwargs values are None) fall back to the identity function. - Adheres to rdfproxy.utils._types.ModelBoolPredicate. + QueryConstructorComponents are used in QueryConstructor for query modification compose chains. 
""" - return any(dict(model).values()) - - -def _is_iterable_of_str(iterable: Iterable) -> TypeGuard[Iterable[str]]: - return (not isinstance(iterable, str)) and all( - map(lambda i: isinstance(i, str), iterable) - ) - - -def _get_model_bool_predicate_from_config_value( - model_bool_value: _TModelBoolValue, -) -> ModelBoolPredicate: - """Get a model_bool predicate function given the value of the model_bool config setting.""" - match model_bool_value: - case ModelBoolPredicate(): - return model_bool_value - case str(): - return lambda model: bool(dict(model)[model_bool_value]) - case model_bool_value if _is_iterable_of_str(model_bool_value): - return lambda model: all(map(lambda k: dict(model)[k], model_bool_value)) - case _: - raise TypeError( - "Argument for 'model_bool' must be of type ModelBoolPredicate | str | Iterable[str].\n" - f"Received {type(model_bool_value)}" - ) - - -def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate: - """Get the applicable model_bool predicate function given a model.""" - if (model_bool_value := model.model_config.get("model_bool", None)) is None: - model_bool_predicate = default_model_bool_predicate - else: - model_bool_predicate = _get_model_bool_predicate_from_config_value( - model_bool_value - ) - - return model_bool_predicate + + def __init__(self, f: Callable[..., str], **kwargs) -> None: + self.f = f + self.kwargs = kwargs + + def __call__(self, query) -> str: + if tkwargs := {k: v for k, v in self.kwargs.items() if v is not None}: + return partial(self.f, **tkwargs)(query) + return query