From e5a9da82b306e03f05ec8ec2b3d8124478a42f69 Mon Sep 17 00:00:00 2001 From: slymit Date: Wed, 12 Jun 2024 22:41:49 +0300 Subject: [PATCH] Add typing and improve docstrings for public functions --- sa_filters/filters.py | 23 ++++++++++++++++------- sa_filters/loads.py | 21 +++++++++++++++------ sa_filters/pagination.py | 3 ++- sa_filters/sorting.py | 21 ++++++++++++++++----- 4 files changed, 49 insertions(+), 19 deletions(-) diff --git a/sa_filters/filters.py b/sa_filters/filters.py index 17f5210..d74f855 100644 --- a/sa_filters/filters.py +++ b/sa_filters/filters.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- from collections import namedtuple -from collections.abc import Iterable from inspect import signature from itertools import chain +from typing import Any, Dict, Iterable, Union from sqlalchemy import and_, or_, not_, func +from sqlalchemy.sql import Select +from sqlalchemy.orm import Query from .exceptions import BadFilterFormat, BadSpec from .models import Field, auto_join, get_model_from_spec, get_default_model, \ @@ -199,13 +201,17 @@ def get_named_models(filters): return models -def apply_filters(query, filter_spec, do_auto_join=True): +def apply_filters( + query: Union[Select, Query], + filter_spec: Union[Iterable[Dict[str, Any]], Dict[str, Any]], + do_auto_join: bool = True +) -> Union[Select, Query]: """Apply filters to a SQLAlchemy query or Select object. :param query: The statement to be processed. May be one of: - a :class:`sqlalchemy.orm.Query` instance or - a :class:`sqlalchemy.sql.expression.Select` instance. + a :class:`sqlalchemy.sql.Select` object or + a :class:`sqlalchemy.orm.Query` object. :param filter_spec: A dict or an iterable of dicts, where each one includes @@ -232,10 +238,13 @@ def apply_filters(query, filter_spec, do_auto_join=True): ] } + :param do_auto_join: + Allow or not auto join. + :returns: - The :class:`sqlalchemy.orm.Query` or - the :class:`sqlalchemy.sql.expression.Select` - instance after all the filters have been applied. + The :class:`sqlalchemy.sql.Select` object or + the :class:`sqlalchemy.orm.Query` object + after all the filters have been applied. """ filters = build_filters(filter_spec) diff --git a/sa_filters/loads.py b/sa_filters/loads.py index 42a1d63..9abf857 100644 --- a/sa_filters/loads.py +++ b/sa_filters/loads.py @@ -1,4 +1,7 @@ -from sqlalchemy.orm import Load +from typing import Any, Dict, List, Union + +from sqlalchemy.sql import Select +from sqlalchemy.orm import Load, Query from .exceptions import BadLoadFormat from .models import Field, auto_join, get_model_from_spec, get_default_model @@ -44,9 +47,15 @@ def get_named_models(loads): return models -def apply_loads(query, load_spec): - """Apply load restrictions to a :class:`sqlalchemy.orm.Query` instance - or a :class:`sqlalchemy.sql.expression.Select` instance. +def apply_loads( + query: Union[Select, Query], + load_spec: Union[List[Dict[str, Any]], Dict[str, Any], List[str]] +) -> Union[Select, Query]: + """Apply load restrictions to a :class:`sqlalchemy.sql.Select` object + or a :class:`sqlalchemy.orm.Query` object. + + :param query: + The statement to be processed. :param load_spec: A list of dictionaries, where each item contains the fields to load @@ -66,8 +75,8 @@ def apply_loads(query, load_spec): load_spec = ['id', 'name'] :returns: - The :class:`sqlalchemy.orm.Query` instance or - a :class:`sqlalchemy.sql.expression.Select` instance + The :class:`sqlalchemy.sql.Select` object or + a :class:`sqlalchemy.orm.Query` object after the load restrictions have been applied. """ if ( diff --git a/sa_filters/pagination.py b/sa_filters/pagination.py index 70433ad..0028432 100644 --- a/sa_filters/pagination.py +++ b/sa_filters/pagination.py @@ -27,7 +27,8 @@ def apply_pagination( page_size: Optional[int] = None, total_results: int = 0 ) -> TupleType[Union[Select, Query], Pagination]: - """Apply pagination to a SQLAlchemy query or Select object. + """Apply pagination to a SQLAlchemy :class:`sqlalchemy.sql.Select` object + or a :class:`sqlalchemy.orm.Query` object. :param stmt: The statement to be processed. diff --git a/sa_filters/sorting.py b/sa_filters/sorting.py index d3a3527..2bf2a16 100644 --- a/sa_filters/sorting.py +++ b/sa_filters/sorting.py @@ -1,5 +1,10 @@ # -*- coding: utf-8 -*- +from typing import Any, Dict, List, Union + +from sqlalchemy.sql import Select +from sqlalchemy.orm import Query + from .exceptions import BadSortFormat from .models import Field, auto_join, get_model_from_spec, get_default_model @@ -67,9 +72,15 @@ def get_named_models(sorts): return models -def apply_sort(query, sort_spec): - """Apply sorting to a :class:`sqlalchemy.orm.Query` instance or - a :class:`sqlalchemy.sql.expression.Select` instance. +def apply_sort( + query: Union[Select, Query], + sort_spec: Union[List[Dict[str, Any]], Dict[str, Any]] +) -> Union[Select, Query]: + """Apply sorting to a SQLAlchemy :class:`sqlalchemy.sql.Select` + object or a :class:`sqlalchemy.orm.Query` object. + + :param query: + The statement to be processed. :param sort_spec: A list of dictionaries, where each one of them includes @@ -98,8 +109,8 @@ def apply_sort(query, sort_spec): may be omitted from the sort spec. :returns: - The :class:`sqlalchemy.orm.Query` instance or - the :class:`sqlalchemy.sql.expression.Select` after the provided + The :class:`sqlalchemy.sql.Select` object or + the :class:`sqlalchemy.orm.Query` object after the provided sorting has been applied. """ if isinstance(sort_spec, dict):