Skip to content

Commit

Permalink
Add typing and improve docstrings for public functions
Browse files Browse the repository at this point in the history
  • Loading branch information
slymit committed Jun 12, 2024
1 parent 7b60588 commit ec23575
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 18 deletions.
22 changes: 16 additions & 6 deletions sa_filters/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from collections.abc import Iterable
from inspect import signature
from itertools import chain
from typing import Union, Any

from sqlalchemy import and_, or_, not_, func
from sqlalchemy.sql import Select
from sqlalchemy.orm import Query

from .exceptions import BadFilterFormat, BadSpec
from .models import Field, auto_join, get_model_from_spec, get_default_model, \
Expand Down Expand Up @@ -199,13 +202,17 @@ def get_named_models(filters):
return models


def apply_filters(query, filter_spec, do_auto_join=True):
def apply_filters(
query: Union[Select, Query],
filter_spec: Union[Iterable[dict[str, Any]], dict[str, Any]],
do_auto_join: bool = True
) -> Union[Select, Query]:
"""Apply filters to a SQLAlchemy query or Select object.
:param query:
The statement to be processed. May be one of:
a :class:`sqlalchemy.orm.Query` instance or
a :class:`sqlalchemy.sql.expression.Select` instance.
a :class:`sqlalchemy.sql.Select` object or
a :class:`sqlalchemy.orm.Query` object.
:param filter_spec:
A dict or an iterable of dicts, where each one includes
Expand All @@ -232,10 +239,13 @@ def apply_filters(query, filter_spec, do_auto_join=True):
]
}
:param do_auto_join:
Allow or not auto join.
:returns:
The :class:`sqlalchemy.orm.Query` or
the :class:`sqlalchemy.sql.expression.Select`
instance after all the filters have been applied.
The :class:`sqlalchemy.sql.Select` object or
the :class:`sqlalchemy.orm.Query` object
after all the filters have been applied.
"""
filters = build_filters(filter_spec)

Expand Down
21 changes: 15 additions & 6 deletions sa_filters/loads.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from sqlalchemy.orm import Load
from typing import Union, Any

from sqlalchemy.sql import Select
from sqlalchemy.orm import Load, Query

from .exceptions import BadLoadFormat
from .models import Field, auto_join, get_model_from_spec, get_default_model
Expand Down Expand Up @@ -44,9 +47,15 @@ def get_named_models(loads):
return models


def apply_loads(query, load_spec):
"""Apply load restrictions to a :class:`sqlalchemy.orm.Query` instance
or a :class:`sqlalchemy.sql.expression.Select` instance.
def apply_loads(
query: Union[Select, Query],
load_spec: Union[list[dict[str, Any]], dict[str, Any], list[str]]
) -> Union[Select, Query]:
"""Apply load restrictions to a :class:`sqlalchemy.sql.Select` object
or a :class:`sqlalchemy.orm.Query` object.
:param query:
The statement to be processed.
:param load_spec:
A list of dictionaries, where each item contains the fields to load
Expand All @@ -66,8 +75,8 @@ def apply_loads(query, load_spec):
load_spec = ['id', 'name']
:returns:
The :class:`sqlalchemy.orm.Query` instance or
a :class:`sqlalchemy.sql.expression.Select` instance
The :class:`sqlalchemy.sql.Select` object or
a :class:`sqlalchemy.orm.Query` object
after the load restrictions have been applied.
"""
if (
Expand Down
3 changes: 2 additions & 1 deletion sa_filters/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def apply_pagination(
page_size: Optional[int] = None,
total_results: int = 0
) -> TupleType[Union[Select, Query], Pagination]:
"""Apply pagination to a SQLAlchemy query or Select object.
"""Apply pagination to a SQLAlchemy :class:`sqlalchemy.sql.Select` object
or a :class:`sqlalchemy.orm.Query` object.
:param stmt:
The statement to be processed.
Expand Down
21 changes: 16 additions & 5 deletions sa_filters/sorting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# -*- coding: utf-8 -*-

from typing import Union, Any

from sqlalchemy.sql import Select
from sqlalchemy.orm import Query

from .exceptions import BadSortFormat
from .models import Field, auto_join, get_model_from_spec, get_default_model

Expand Down Expand Up @@ -67,9 +72,15 @@ def get_named_models(sorts):
return models


def apply_sort(query, sort_spec):
"""Apply sorting to a :class:`sqlalchemy.orm.Query` instance or
a :class:`sqlalchemy.sql.expression.Select` instance.
def apply_sort(
query: Union[Select, Query],
sort_spec: Union[list[dict[str, Any]], dict[str, Any]]
) -> Union[Select, Query]:
"""Apply sorting to a SQLAlchemy :class:`sqlalchemy.sql.Select`
object or a :class:`sqlalchemy.orm.Query` object.
:param query:
The statement to be processed.
:param sort_spec:
A list of dictionaries, where each one of them includes
Expand Down Expand Up @@ -98,8 +109,8 @@ def apply_sort(query, sort_spec):
may be omitted from the sort spec.
:returns:
The :class:`sqlalchemy.orm.Query` instance or
the :class:`sqlalchemy.sql.expression.Select` after the provided
The :class:`sqlalchemy.sql.Select` object or
the :class:`sqlalchemy.orm.Query` object after the provided
sorting has been applied.
"""
if isinstance(sort_spec, dict):
Expand Down

0 comments on commit ec23575

Please sign in to comment.