From 1360f8337a882c0b3a4cf6e38975fedd346006ae Mon Sep 17 00:00:00 2001 From: Pierlou Ramade <48205215+Pierlou@users.noreply.github.com> Date: Wed, 27 Nov 2024 14:18:42 +0100 Subject: [PATCH] Add aggregators (#35) * fix: better syntax for sort * feat: upgrade postgrest and allow aggregates * feat: add aggregators * refactor: loop for parameters descriptions * feat: add tests for aggregators * fix: restore previous behaviour * fix: lint * docs: update changelog * docs: update readme * docs: add missing hint types * refactor: remove default __id side sort to allow sort with aggregation * refactor: return 400 if argument could not be parsed, stricter than before * refactor: adapt tests * fix: lint --- CHANGELOG.md | 2 +- README.md | 46 +++++++- api_tabular/app.py | 4 +- api_tabular/utils.py | 262 +++++++++++++++++++++++++++--------------- docker-compose.yml | 3 +- tests/test_api.py | 19 ++- tests/test_query.py | 36 +++++- tests/test_swagger.py | 5 +- 8 files changed, 263 insertions(+), 114 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c08e1a4..d542c9a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ ## Current (in progress) -- Nothing yet +- Handle queries with aggregators [#35](https://github.com/datagouv/api-tabular/pull/35) ## 0.2.1 (2024-11-21) diff --git a/README.md b/README.md index b03897d..16c69be 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,7 @@ curl http://localhost:8005/api/resources/aaaaaaaa-1111-bbbb-2222-cccccccccccc/da } ``` -This endpoint can be queried with the following operators as query string (replacing `column_name` with the name of an actual column): +This endpoint can be queried with the following operators as query string (replacing `column_name` with the name of an actual column), if the column type allows it (see the swagger for each column's allowed parameter): ``` # sort by column @@ -142,7 +142,26 @@ column_name__strictly_less=value # strictly greater column_name__strictly_greater=value + +# group by 
values +column_name__groupby + +# count values +column_name__count + +# mean / average +column_name__avg + +# minimum +column_name__min + +# maximum +column_name__max + +# sum +column_name__sum ``` +> NB : passing an aggregation operator (`count`, `avg`, `min`, `max`, `sum`) returns a column that is named `<column_name>__<operator>` (for instance: `?birth__groupby&score__sum` will return a list of dicts with the keys `birth` and `score__sum`). For instance: ```shell @@ -185,6 +204,31 @@ returns } ``` +With filters and aggregators (filtering is always done **before** aggregation, no matter the order in the parameters): +```shell +curl http://localhost:8005/api/resources/aaaaaaaa-1111-bbbb-2222-cccccccccccc/data/?decompte__groupby&birth__less=1996&score__avg +``` +i.e. `decompte` and average of `score` for all rows where `birth<="1996"`, grouped by `decompte`, returns +```json +{ + "data": [ + { + "decompte": 55, + "score__avg": 0.7123333333333334 + }, + { + "decompte": 27, + "score__avg": 0.6068888888888889 + }, + { + "decompte": 23, + "score__avg": 0.4603333333333334 + }, + ...
+ ] +} +``` + Pagination is made through queries with `page` and `page_size`: ```shell curl http://localhost:8005/api/resources/aaaaaaaa-1111-bbbb-2222-cccccccccccc/data/?page=2&page_size=30 diff --git a/api_tabular/app.py b/api_tabular/app.py index 4ade643..516455c 100644 --- a/api_tabular/app.py +++ b/api_tabular/app.py @@ -96,8 +96,8 @@ async def resource_data(request): try: sql_query = build_sql_query_string(query_string, page_size, offset) - except ValueError: - raise QueryException(400, None, "Invalid query string", "Malformed query") + except ValueError as e: + raise QueryException(400, None, "Invalid query string", f"Malformed query: {e}") resource = await get_resource(request.app["csession"], resource_id, ["parsing_table"]) response, total = await get_resource_data(request.app["csession"], resource, sql_query) diff --git a/api_tabular/utils.py b/api_tabular/utils.py index 145b774..c9bbf1d 100644 --- a/api_tabular/utils.py +++ b/api_tabular/utils.py @@ -1,3 +1,6 @@ +from collections import defaultdict +from typing import Optional + import tomllib import yaml from aiohttp.web_request import Request @@ -6,13 +9,37 @@ from api_tabular import config TYPE_POSSIBILITIES = { - "string": ["compare", "contains", "differs", "exact", "in", "sort"], - "float": ["compare", "differs", "exact", "in", "sort"], - "int": ["compare", "differs", "exact", "in", "sort"], - "bool": ["differs", "exact", "in", "sort"], - "date": ["compare", "contains", "differs", "exact", "in", "sort"], - "datetime": ["compare", "contains", "differs", "exact", "in", "sort"], - "json": ["contains", "exact", "in"], + "string": ["compare", "contains", "differs", "exact", "in", "sort", "groupby", "count"], + "float": [ + "compare", + "differs", + "exact", + "in", + "sort", + "groupby", + "count", + "avg", + "max", + "min", + "sum", + ], + "int": [ + "compare", + "differs", + "exact", + "in", + "sort", + "groupby", + "count", + "avg", + "max", + "min", + "sum", + ], + "bool": ["differs", "exact", "in", 
"sort", "groupby", "count"], + "date": ["compare", "contains", "differs", "exact", "in", "sort", "groupby", "count"], + "datetime": ["compare", "contains", "differs", "exact", "in", "sort", "groupby", "count"], + "json": ["contains", "differs", "exact", "in", "groupby", "count"], } MAP_TYPES = { @@ -22,6 +49,49 @@ "float": "number", } +OPERATORS_DESCRIPTIONS = { + "exact": { + "name": "{}__exact=value", + "description": "Exact match in column: {}", + }, + "differs": { + "name": "{}__differs=value", + "description": "Differs from in column: {}", + }, + "contains": { + "name": "{}__contains=value", + "description": "String contains in column: {}", + }, + "in": { + "name": "{}__in=value1,value2,...", + "description": "Value in list in column: {}", + }, + "groupby": { + "name": "{}__groupby", + "description": "Performs `group by values` operation in column: {}", + }, + "count": { + "name": "{}__count", + "description": "Performs `count values` operation in column: {}", + }, + "avg": { + "name": "{}__avg", + "description": "Performs `mean` operation in column: {}", + }, + "min": { + "name": "{}__min", + "description": "Performs `minimum` operation in column: {}", + }, + "max": { + "name": "{}__max", + "description": "Performs `maximum` operation in column: {}", + }, + "sum": { + "name": "{}__sum", + "description": "Performs `sum` operation in column: {}", + }, +} + async def get_app_version() -> str: """Parse pyproject.toml and return the version or an error.""" @@ -37,47 +107,90 @@ async def get_app_version() -> str: def build_sql_query_string(request_arg: list, page_size: int = None, offset: int = 0) -> str: sql_query = [] + aggregators = defaultdict(list) sorted = False for arg in request_arg: - argument, value = arg.split("=") - if "__" in argument: - *column_split, comparator = argument.split("__") - normalized_comparator = comparator.lower() - # handling headers with "__" and special characters - # we're escaping the " because they are the encapsulators of the 
label - column = '"{}"'.format("__".join(column_split).replace('"', '\\"')) - - if normalized_comparator == "sort": - if value == "asc": - sql_query.append(f"order={column}.asc,__id.asc") - elif value == "desc": - sql_query.append(f"order={column}.desc,__id.asc") - sorted = True - elif normalized_comparator == "exact": - sql_query.append(f"{column}=eq.{value}") - elif normalized_comparator == "differs": - sql_query.append(f"{column}=neq.{value}") - elif normalized_comparator == "contains": - sql_query.append(f"{column}=ilike.*{value}*") - elif normalized_comparator == "in": - sql_query.append(f"{column}=in.({value})") - elif normalized_comparator == "less": - sql_query.append(f"{column}=lte.{value}") - elif normalized_comparator == "greater": - sql_query.append(f"{column}=gte.{value}") - elif normalized_comparator == "strictly_less": - sql_query.append(f"{column}=lt.{value}") - elif normalized_comparator == "strictly_greater": - sql_query.append(f"{column}=gt.{value}") + _split = arg.split("=") + # filters are expected to have the syntax `<column_name>__<operator>=<value>` + if len(_split) == 2: + _filter, _sorted = add_filter(*_split) + if _filter: + sorted = sorted or _sorted + sql_query.append(_filter) + # aggregators are expected to have the syntax `<column_name>__<operator>` + elif len(_split) == 1: + column, operator = add_aggregator(_split[0]) + if column: + aggregators[operator].append(column) + else: + raise ValueError(f"argument '{arg}' could not be parsed") + if aggregators: + agg_query = "select=" + for operator in aggregators: + if operator == "groupby": + agg_query += f"{','.join(aggregators[operator])}," + else: + for column in aggregators[operator]: + # aggregated columns are named `<column_name>__<operator>` + # we pop the heading and trailing " that were added upstream + # and put them around the new column name + agg_query += f'"{column[1:-1]}__{operator}":{column}.{operator}(),' + # we pop the trailing comma (it's always there, by construction) + sql_query.append(agg_query[:-1]) if page_size:
sql_query.append(f"limit={page_size}") if offset >= 1: sql_query.append(f"offset={offset}") - if not sorted: + if not sorted and not aggregators: sql_query.append("order=__id.asc") return "&".join(sql_query) +def get_column_and_operator(argument: str) -> tuple[str, str]: + *column_split, comparator = argument.split("__") + normalized_comparator = comparator.lower() + # handling headers with "__" and special characters + # we're escaping the " because they are the encapsulators of the label + column = '"{}"'.format("__".join(column_split).replace('"', '\\"')) + return column, normalized_comparator + + +def add_filter(argument: str, value: str) -> tuple[Optional[str], bool]: + if argument in ["page", "page_size"]: # processed differently + return None, False + if "__" in argument: + column, normalized_comparator = get_column_and_operator(argument) + if normalized_comparator == "sort": + q = f"order={column}.{value}" + return q, True + elif normalized_comparator == "exact": + return f"{column}=eq.{value}", False + elif normalized_comparator == "differs": + return f"{column}=neq.{value}", False + elif normalized_comparator == "contains": + return f"{column}=ilike.*{value}*", False + elif normalized_comparator == "in": + return f"{column}=in.({value})", False + elif normalized_comparator == "less": + return f"{column}=lte.{value}", False + elif normalized_comparator == "greater": + return f"{column}=gte.{value}", False + elif normalized_comparator == "strictly_less": + return f"{column}=lt.{value}", False + elif normalized_comparator == "strictly_greater": + return f"{column}=gt.{value}", False + raise ValueError(f"argument '{argument}={value}' could not be parsed") + + +def add_aggregator(argument: str) -> tuple[str, str]: + operator = None + if "__" in argument: + column, operator = get_column_and_operator(argument) + if operator in ["avg", "count", "max", "min", "sum", "groupby"]: + return column, operator + raise ValueError(f"argument '{argument}' could not be 
parsed") + + def process_total(res: Response) -> int: # the Content-Range looks like this: '0-49/21777' # see https://docs.postgrest.org/en/stable/references/api/pagination_count.html @@ -86,25 +199,25 @@ def process_total(res: Response) -> int: return int(str_total) -def external_url(url): +def external_url(url) -> str: return f"{config.SCHEME}://{config.SERVER_NAME}{url}" -def build_link_with_page(request: Request, query_string: str, page: int, page_size: int): +def build_link_with_page(request: Request, query_string: str, page: int, page_size: int) -> str: q = [string for string in query_string if not string.startswith("page")] q.extend([f"page={page}", f"page_size={page_size}"]) rebuilt_q = "&".join(q) return external_url(f"{request.path}?{rebuilt_q}") -def url_for(request: Request, route: str, *args, **kwargs): +def url_for(request: Request, route: str, *args, **kwargs) -> str: router = request.app.router if kwargs.pop("_external", None): return external_url(router[route].url_for(**kwargs)) return router[route].url_for(**kwargs) -def swagger_parameters(resource_columns): +def swagger_parameters(resource_columns: dict) -> list: parameters_list = [ { "name": "page", @@ -125,42 +238,19 @@ def swagger_parameters(resource_columns): # see metier_to_python here: https://github.com/datagouv/csv-detective/blob/master/csv_detective/explore_csv.py # see cast for db here: https://github.com/datagouv/hydra/blob/main/udata_hydra/analysis/csv.py for key, value in resource_columns.items(): - if "exact" in TYPE_POSSIBILITIES[value["python_type"]]: - parameters_list.extend( - [ - { - "name": f"{key}__exact=value", - "in": "query", - "description": f"Exact match in column: {key}", - "required": False, - "schema": {"type": "string"}, - }, - ] - ) - if "differs" in TYPE_POSSIBILITIES[value["python_type"]]: - parameters_list.extend( - [ - { - "name": f"{key}__differs=value", - "in": "query", - "description": f"Differs from in column: {key}", - "required": False, - "schema": 
{"type": "string"}, - }, - ] - ) - if "in" in TYPE_POSSIBILITIES[value["python_type"]]: - parameters_list.extend( - [ - { - "name": f"{key}__in=value1,value2,...", - "in": "query", - "description": f"Value in list in column: {key}", - "required": False, - "schema": {"type": "string"}, - }, - ] - ) + for op in OPERATORS_DESCRIPTIONS: + if op in TYPE_POSSIBILITIES[value["python_type"]]: + parameters_list.extend( + [ + { + "name": OPERATORS_DESCRIPTIONS[op]["name"].format(key), + "in": "query", + "description": OPERATORS_DESCRIPTIONS[op]["description"].format(key), + "required": False, + "schema": {"type": "string"}, + }, + ] + ) if "sort" in TYPE_POSSIBILITIES[value["python_type"]]: parameters_list.extend( [ @@ -180,18 +270,6 @@ def swagger_parameters(resource_columns): }, ] ) - if "contains" in TYPE_POSSIBILITIES[value["python_type"]]: - parameters_list.extend( - [ - { - "name": f"{key}__contains=value", - "in": "query", - "description": f"String contains in column: {key}", - "required": False, - "schema": {"type": "string"}, - }, - ] - ) if "compare" in TYPE_POSSIBILITIES[value["python_type"]]: parameters_list.extend( [ @@ -228,7 +306,7 @@ def swagger_parameters(resource_columns): return parameters_list -def swagger_component(resource_columns): +def swagger_component(resource_columns: dict) -> dict: resource_prop_dict = {} for key, value in resource_columns.items(): type = MAP_TYPES.get(value["python_type"], "string") @@ -275,7 +353,7 @@ def swagger_component(resource_columns): return component_dict -def build_swagger_file(resource_columns, rid): +def build_swagger_file(resource_columns: dict, rid: str) -> str: parameters_list = swagger_parameters(resource_columns) component_dict = swagger_component(resource_columns) swagger_dict = { diff --git a/docker-compose.yml b/docker-compose.yml index ec759e1..a10a21c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: '3' services: postgrest: - image: "postgrest/postgrest:v10.1.2" + image: 
"postgrest/postgrest:v12.2.3" environment: # connect to hydra database-csv - PGRST_DB_URI=postgres://csvapi:csvapi@postgres-test:5432/csvapi @@ -10,6 +10,7 @@ services: - PGRST_SERVER_PORT=8080 - PGRST_DB_ANON_ROLE=csvapi - PGRST_DB_SCHEMA=csvapi + - PGRST_DB_AGGREGATES_ENABLED=true ports: - 8080:8080 postgres-test: diff --git a/tests/test_api.py b/tests/test_api.py index 9f8f912..22518aa 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -137,8 +137,8 @@ async def test_api_resource_data_with_args_error(client, rmock): "errors": [ { "code": None, - "detail": "Malformed query", "title": "Invalid query string", + "detail": "Malformed query: argument 'TESTCOLUM_NAME__EXACT=BIDULEpage=1' could not be parsed", } ] } @@ -211,16 +211,15 @@ async def test_api_with_unsupported_args(client, rmock): headers={"Content-Range": "0-10/10"}, ) res = await client.get(f"/api/resources/{RESOURCE_ID}/data/?limit=1&select=numnum") - assert res.status == 200 + assert res.status == 400 body = { - "data": {"such": "data"}, - "links": { - "next": None, - "prev": None, - "profile": external_url("/api/resources/aaaaaaaa-1111-bbbb-2222-cccccccccccc/profile/"), - "swagger": external_url("/api/resources/aaaaaaaa-1111-bbbb-2222-cccccccccccc/swagger/"), - }, - "meta": {"page": 1, "page_size": 20, "total": 10}, + "errors": [ + { + "code": None, + "title": "Invalid query string", + "detail": "Malformed query: argument 'limit=1' could not be parsed", + }, + ], } assert await res.json() == body diff --git a/tests/test_query.py b/tests/test_query.py index 0e1775d..f8bab6f 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,3 +1,5 @@ +import pytest + from api_tabular.utils import build_sql_query_string @@ -16,13 +18,13 @@ def test_query_build_offset(): def test_query_build_sort_asc(): query_str = ["column_name__sort=asc"] result = build_sql_query_string(query_str, 50) - assert result == 'order="column_name".asc,__id.asc&limit=50' + assert result == 'order="column_name".asc&limit=50' 
def test_query_build_sort_asc_without_limit(): query_str = ["column_name__sort=asc"] result = build_sql_query_string(query_str) - assert result == 'order="column_name".asc,__id.asc' + assert result == 'order="column_name".asc' def test_query_build_sort_asc_with_page_in_query(): @@ -32,13 +34,13 @@ def test_query_build_sort_asc_with_page_in_query(): "page_size=20", ] result = build_sql_query_string(query_str) - assert result == 'order="column_name".asc,__id.asc' + assert result == 'order="column_name".asc' def test_query_build_sort_desc(): query_str = ["column_name__sort=desc"] result = build_sql_query_string(query_str, 50) - assert result == 'order="column_name".desc,__id.asc&limit=50' + assert result == 'order="column_name".desc&limit=50' def test_query_build_exact(): @@ -92,5 +94,27 @@ def test_query_build_multiple(): def test_query_build_multiple_with_unknown(): query_str = ["select=numnum"] - result = build_sql_query_string(query_str, 50) - assert result == "limit=50&order=__id.asc" + with pytest.raises(ValueError): + build_sql_query_string(query_str, 50) + + +def test_query_aggregators(): + query_str = [ + "column_name__groupby", + "column_name__min", + "column_name__avg", + ] + results = build_sql_query_string(query_str, 50).split("&") + assert "limit=50" in results + assert "order=__id.asc" not in results # no sort if aggregators + select = [_ for _ in results if "select" in _] + assert len(select) == 1 + params = select[0].replace("select=", "").split(",") + assert all( + _ in params + for _ in [ + '"column_name"', + '"column_name__min":"column_name".min()', + '"column_name__avg":"column_name".avg()', + ] + ) diff --git a/tests/test_swagger.py b/tests/test_swagger.py index 3291b74..a383aa0 100644 --- a/tests/test_swagger.py +++ b/tests/test_swagger.py @@ -49,5 +49,8 @@ async def test_swagger_content(client, rmock): elif p == "in": value = "value1,value2,..." 
for _p in _params: - if f"{c}__{_p}={value}" not in params: + if ( + f"{c}__{_p}={value}" not in params # filters + and f"{c}__{_p}" not in params # aggregators + ): raise ValueError(f"{c}__{_p} is missing in {output} output")