Limit aggregators (#36)
* feat: add config to limit aggregators

* fix: end of file line

* feat: better logs if bad query

* feat: reject aggregation requests for unspecific resources

* docs: add warning for aggregation requests

* feat: adapt tests

* feat: update test

* fix: fix metrics

* fix: lint

* docs: update changelog

* docs: indicate what's expected in new config

* fix: better swagger test, required pytest-mock
Pierlou authored Nov 28, 2024
1 parent 1360f83 commit d968058
Showing 10 changed files with 123 additions and 35 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,6 +3,7 @@
## Current (in progress)

- Handle queries with aggregators [#35](https://github.com/datagouv/api-tabular/pull/35)
- Restrain aggregators to list of specific resources [#36](https://github.com/datagouv/api-tabular/pull/36)

## 0.2.1 (2024-11-21)

3 changes: 3 additions & 0 deletions README.md
@@ -161,6 +161,9 @@ column_name__max
# sum
column_name__sum
```

> /!\ WARNING: aggregation requests are only available for resources that are listed in the `ALLOW_AGGREGATION` list of the config file.
> NB : passing an aggregation operator (`count`, `avg`, `min`, `max`, `sum`) returns a column that is named `<column_name>__<operator>` (for instance: `?birth__groupby&score__sum` will return a list of dicts with the keys `birth` and `score__sum`).
For instance:
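For illustration, a hypothetical call against the data endpoint might look like the sketch below (the host, resource id, and `{"data": [...]}` response envelope are assumptions, not taken from this diff):

```python
import requests

# Hypothetical resource id: it must be listed in ALLOW_AGGREGATION for the query to succeed.
RESOURCE_ID = "aaaaaaaa-1111-2222-3333-444444444444"
# Assumed host and route shape; see the endpoint documentation earlier in this README.
URL = f"https://tabular.example.org/api/resources/{RESOURCE_ID}/data/?birth__groupby&score__sum"

resp = requests.get(URL)
if resp.status_code == 403:
    # The resource is not allow-listed: the aggregation request is rejected.
    print(resp.json())
else:
    # Each row exposes the grouped column plus the `<column>__<operator>` keys,
    # e.g. {"birth": "1990", "score__sum": 42} (assuming a {"data": [...]} envelope).
    for row in resp.json()["data"]:
        print(row["birth"], row["score__sum"])
```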
9 changes: 6 additions & 3 deletions api_tabular/app.py
@@ -95,10 +95,11 @@ async def resource_data(request):
offset = 0

try:
- sql_query = build_sql_query_string(query_string, page_size, offset)
+ sql_query = build_sql_query_string(query_string, resource_id, page_size, offset)
except ValueError as e:
raise QueryException(400, None, "Invalid query string", f"Malformed query: {e}")
-
+ except PermissionError as e:
+ raise QueryException(403, None, "Unauthorized parameters", str(e))
resource = await get_resource(request.app["csession"], resource_id, ["parsing_table"])
response, total = await get_resource_data(request.app["csession"], resource, sql_query)

@@ -123,9 +124,11 @@ async def resource_data_csv(request):
query_string = request.query_string.split("&") if request.query_string else []

try:
- sql_query = build_sql_query_string(query_string)
+ sql_query = build_sql_query_string(query_string, resource_id)
except ValueError:
raise QueryException(400, None, "Invalid query string", "Malformed query")
+ except PermissionError as e:
+ raise QueryException(403, None, "Unauthorized parameters", str(e))

resource = await get_resource(request.app["csession"], resource_id, ["parsing_table"])

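Both handlers now distinguish malformed queries (400) from disallowed aggregations (403). The pattern reduces to the standalone sketch below; `QueryException` here is a stand-in class with the same four arguments, not the real one from the project:

```python
class QueryException(Exception):
    """Stand-in for the project's QueryException (status, error_code, title, detail)."""

    def __init__(self, status, error_code, title, detail):
        super().__init__(f"{status} {title}: {detail}")
        self.status = status


def build_or_raise(build_sql_query_string, *args, **kwargs) -> str:
    try:
        return build_sql_query_string(*args, **kwargs)
    except ValueError as e:
        # unparseable argument -> client error
        raise QueryException(400, None, "Invalid query string", f"Malformed query: {e}")
    except PermissionError as e:
        # aggregation requested on a resource missing from ALLOW_AGGREGATION -> forbidden
        raise QueryException(403, None, "Unauthorized parameters", str(e))
```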
1 change: 1 addition & 0 deletions api_tabular/config_default.toml
@@ -6,3 +6,4 @@ PAGE_SIZE_DEFAULT = 20
PAGE_SIZE_MAX = 50
BATCH_SIZE = 50000
DOC_PATH = "/api/doc"
ALLOW_AGGREGATION = [] # list of resource_ids
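The new key expects resource ids, i.e. the UUIDs that appear in the API routes. A minimal sanity check for a deployment override could look like this sketch (hypothetical file name and ids; `tomllib` requires Python 3.11+):

```python
import tomllib
import uuid

# config.toml (hypothetical override):
# ALLOW_AGGREGATION = [
#     "aaaaaaaa-1111-2222-3333-444444444444",
#     "bbbbbbbb-5555-6666-7777-888888888888",
# ]
with open("config.toml", "rb") as f:
    conf = tomllib.load(f)

for rid in conf.get("ALLOW_AGGREGATION", []):
    uuid.UUID(rid)  # raises ValueError if an entry is not a valid resource id
```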
10 changes: 5 additions & 5 deletions api_tabular/metrics.py
@@ -78,9 +78,9 @@ async def metrics_data(request):
else:
offset = 0
try:
- sql_query = build_sql_query_string(query_string, page_size, offset)
- except ValueError:
- raise QueryException(400, None, "Invalid query string", "Malformed query")
+ sql_query = build_sql_query_string(query_string, page_size=page_size, offset=offset)
+ except ValueError as e:
+ raise QueryException(400, None, "Invalid query string", f"Malformed query: {e}")

response, total = await get_object_data(request.app["csession"], model, sql_query)

@@ -104,8 +104,8 @@ async def metrics_data_csv(request):

try:
sql_query = build_sql_query_string(query_string)
- except ValueError:
- raise QueryException(400, None, "Invalid query string", "Malformed query")
+ except ValueError as e:
+ raise QueryException(400, None, "Invalid query string", f"Malformed query: {e}")

response_headers = {
"Content-Disposition": f'attachment; filename="{model}.csv"',
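The metrics endpoints never pass a resource id, so their calls switch to keyword arguments; otherwise `page_size` would silently land in the new `resource_id` slot. A toy illustration with a simplified signature:

```python
def build(request_arg, resource_id=None, page_size=None, offset=0):
    return resource_id, page_size, offset

# Written against the old positional signature, this call now misassigns arguments:
print(build([], 20, 0))                    # -> (20, 0, 0): 20 becomes resource_id
# With keywords, the intent survives the signature change:
print(build([], page_size=20, offset=0))   # -> (None, 20, 0)
```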
30 changes: 27 additions & 3 deletions api_tabular/utils.py
@@ -69,30 +69,40 @@
"groupby": {
"name": "{}__groupby",
"description": "Performs `group by values` operation in column: {}",
"is_aggregator": True,
},
"count": {
"name": "{}__count",
"description": "Performs `count values` operation in column: {}",
"is_aggregator": True,
},
"avg": {
"name": "{}__avg",
"description": "Performs `mean` operation in column: {}",
"is_aggregator": True,
},
"min": {
"name": "{}__min",
"description": "Performs `minimum` operation in column: {}",
"is_aggregator": True,
},
"max": {
"name": "{}__max",
"description": "Performs `maximum` operation in column: {}",
"is_aggregator": True,
},
"sum": {
"name": "{}__sum",
"description": "Performs `sum` operation in column: {}",
"is_aggregator": True,
},
}


def is_aggregation_allowed(resource_id: str):
return resource_id in config.ALLOW_AGGREGATION


async def get_app_version() -> str:
"""Parse pyproject.toml and return the version or an error."""
try:
@@ -105,7 +115,12 @@ async def get_app_version() -> str:
return f"unknown ({str(e)})"


- def build_sql_query_string(request_arg: list, page_size: int = None, offset: int = 0) -> str:
+ def build_sql_query_string(
+     request_arg: list,
+     resource_id: Optional[str] = None,
+     page_size: int = None,
+     offset: int = 0,
+ ) -> str:
sql_query = []
aggregators = defaultdict(list)
sorted = False
@@ -125,6 +140,11 @@ def build_sql_query_string(request_arg: list, page_size: int = None, offset: int
else:
raise ValueError(f"argument '{arg}' could not be parsed")
if aggregators:
if resource_id and not is_aggregation_allowed(resource_id):
raise PermissionError(
f"Aggregation parameters `{'`, `'.join(aggregators.keys())}` "
f"are not allowed for resource '{resource_id}'"
)
agg_query = "select="
for operator in aggregators:
if operator == "groupby":
@@ -217,7 +237,7 @@ def url_for(request: Request, route: str, *args, **kwargs) -> str:
return router[route].url_for(**kwargs)


- def swagger_parameters(resource_columns: dict) -> list:
+ def swagger_parameters(resource_columns: dict, resource_id: str) -> list:
parameters_list = [
{
"name": "page",
@@ -239,6 +259,10 @@
# see cast for db here: https://github.com/datagouv/hydra/blob/main/udata_hydra/analysis/csv.py
for key, value in resource_columns.items():
for op in OPERATORS_DESCRIPTIONS:
if not is_aggregation_allowed(resource_id) and OPERATORS_DESCRIPTIONS[op].get(
"is_aggregator"
):
continue
if op in TYPE_POSSIBILITIES[value["python_type"]]:
parameters_list.extend(
[
@@ -354,7 +378,7 @@ def swagger_component(resource_columns: dict) -> dict:


def build_swagger_file(resource_columns: dict, rid: str) -> str:
- parameters_list = swagger_parameters(resource_columns)
+ parameters_list = swagger_parameters(resource_columns, rid)
component_dict = swagger_component(resource_columns)
swagger_dict = {
"openapi": "3.0.3",
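Taken together, the `is_aggregator` flags and `is_aggregation_allowed` implement a simple allow-list. A self-contained sketch of the same logic, with local stand-ins for `config.ALLOW_AGGREGATION` and `OPERATORS_DESCRIPTIONS`:

```python
# Local stand-ins; the real objects live in api_tabular.config and api_tabular.utils.
OPERATORS = {
    "exact": {},
    "contains": {},
    "sum": {"is_aggregator": True},
    "groupby": {"is_aggregator": True},
}
ALLOW_AGGREGATION = ["aaaaaaaa-1111-2222-3333-444444444444"]  # hypothetical resource id


def check_aggregation(resource_id: str, operators: list[str]) -> None:
    used = [op for op in operators if OPERATORS.get(op, {}).get("is_aggregator")]
    if used and resource_id not in ALLOW_AGGREGATION:
        raise PermissionError(
            f"Aggregation parameters `{'`, `'.join(used)}` "
            f"are not allowed for resource '{resource_id}'"
        )


check_aggregation("another-id", ["exact", "contains"])  # fine: no aggregators involved
check_aggregation(ALLOW_AGGREGATION[0], ["sum"])        # fine: resource is allow-listed
try:
    check_aggregation("another-id", ["sum", "groupby"])
except PermissionError as e:
    print(e)  # Aggregation parameters `sum`, `groupby` are not allowed for resource 'another-id'
```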
19 changes: 18 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ aioresponses = "^0.7.4"
bumpx = "^0.3.10"
pytest = "^7.2.1"
pytest-asyncio = "^0.20.3"
pytest-mock = "^3.14.0"
ruff = "^0.6.5"

[tool.ruff]
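`pytest-mock` provides the `mocker` fixture that the updated tests use to patch `config.ALLOW_AGGREGATION` per test case (the patch is undone automatically when the test ends). A minimal usage pattern, assuming the package under test is importable:

```python
from api_tabular import config


def test_allow_aggregation_is_patched(mocker):
    # Replace the setting for the duration of this test only.
    mocker.patch("api_tabular.config.ALLOW_AGGREGATION", ["some-resource-id"])
    assert config.ALLOW_AGGREGATION == ["some-resource-id"]
```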
48 changes: 32 additions & 16 deletions tests/test_query.py
@@ -1,23 +1,26 @@
import pytest

from api_tabular import config
from api_tabular.utils import build_sql_query_string

from .conftest import RESOURCE_ID


def test_query_build_limit():
query_str = []
- result = build_sql_query_string(query_str, 12)
+ result = build_sql_query_string(query_str, page_size=12)
assert result == "limit=12&order=__id.asc"


def test_query_build_offset():
query_str = []
- result = build_sql_query_string(query_str, 12, 12)
+ result = build_sql_query_string(query_str, page_size=12, offset=12)
assert result == "limit=12&offset=12&order=__id.asc"


def test_query_build_sort_asc():
query_str = ["column_name__sort=asc"]
- result = build_sql_query_string(query_str, 50)
+ result = build_sql_query_string(query_str, page_size=50)
assert result == 'order="column_name".asc&limit=50'


@@ -39,43 +42,43 @@ def test_query_build_sort_asc_with_page_in_query():

def test_query_build_sort_desc():
query_str = ["column_name__sort=desc"]
- result = build_sql_query_string(query_str, 50)
+ result = build_sql_query_string(query_str, page_size=50)
assert result == 'order="column_name".desc&limit=50'


def test_query_build_exact():
query_str = ["column_name__exact=BIDULE"]
- result = build_sql_query_string(query_str, 50)
+ result = build_sql_query_string(query_str, page_size=50)
assert result == '"column_name"=eq.BIDULE&limit=50&order=__id.asc'


def test_query_build_differs():
query_str = ["column_name__differs=BIDULE"]
- result = build_sql_query_string(query_str, 50)
+ result = build_sql_query_string(query_str, page_size=50)
assert result == '"column_name"=neq.BIDULE&limit=50&order=__id.asc'


def test_query_build_contains():
query_str = ["column_name__contains=BIDULE"]
- result = build_sql_query_string(query_str, 50)
+ result = build_sql_query_string(query_str, page_size=50)
assert result == '"column_name"=ilike.*BIDULE*&limit=50&order=__id.asc'


def test_query_build_in():
query_str = ["column_name__in=value1,value2,value3"]
- result = build_sql_query_string(query_str, 50)
+ result = build_sql_query_string(query_str, page_size=50)
assert result == '"column_name"=in.(value1,value2,value3)&limit=50&order=__id.asc'


def test_query_build_less():
query_str = ["column_name__less=12"]
- result = build_sql_query_string(query_str, 50, 12)
+ result = build_sql_query_string(query_str, page_size=50, offset=12)
assert result == '"column_name"=lte.12&limit=50&offset=12&order=__id.asc'


def test_query_build_greater():
query_str = ["column_name__greater=12"]
- result = build_sql_query_string(query_str, 50)
+ result = build_sql_query_string(query_str, page_size=50)
assert result == '"column_name"=gte.12&limit=50&order=__id.asc'


@@ -85,7 +88,7 @@ def test_query_build_multiple():
"column_name__greater=12",
"column_name__exact=BIDULE",
]
- result = build_sql_query_string(query_str, 50)
+ result = build_sql_query_string(query_str, page_size=50)
assert (
result
== '"column_name"=eq.BIDULE&"column_name"=gte.12&"column_name"=eq.BIDULE&limit=50&order=__id.asc'
@@ -95,16 +98,29 @@ def test_query_build_multiple_with_unknown():
def test_query_build_multiple_with_unknown():
query_str = ["select=numnum"]
with pytest.raises(ValueError):
- build_sql_query_string(query_str, 50)
-
-
- def test_query_aggregators():
+ build_sql_query_string(query_str, page_size=50)
+
+
+ @pytest.mark.parametrize(
+     "allow_aggregation",
+     [
+         False,
+         True,
+     ],
+ )
+ def test_query_aggregators(allow_aggregation, mocker):
+     if allow_aggregation:
+         mocker.patch("api_tabular.config.ALLOW_AGGREGATION", [RESOURCE_ID])
query_str = [
"column_name__groupby",
"column_name__min",
"column_name__avg",
]
- results = build_sql_query_string(query_str, 50).split("&")
+ if not allow_aggregation:
+     with pytest.raises(PermissionError):
+         build_sql_query_string(query_str, resource_id=RESOURCE_ID, page_size=50)
+     return
+ results = build_sql_query_string(query_str, resource_id=RESOURCE_ID, page_size=50).split("&")
assert "limit=50" in results
assert "order=__id.asc" not in results # no sort if aggregators
select = [_ for _ in results if "select" in _]
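The reworked test covers both branches with one parametrized function: the forbidden case asserts the `PermissionError` and returns early, the allowed case goes on to inspect the generated query. The same structure in a self-contained toy form:

```python
import pytest


def restricted_action(allowed: bool) -> str:
    if not allowed:
        raise PermissionError("aggregation not allowed for this resource")
    return "select=..."


@pytest.mark.parametrize("allow_aggregation", [False, True])
def test_restricted_action(allow_aggregation):
    if not allow_aggregation:
        with pytest.raises(PermissionError):
            restricted_action(allow_aggregation)
        return
    assert restricted_action(allow_aggregation).startswith("select=")
```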
36 changes: 29 additions & 7 deletions tests/test_swagger.py
@@ -4,7 +4,8 @@
import pytest
import yaml

- from api_tabular.utils import TYPE_POSSIBILITIES
+ from api_tabular import config
+ from api_tabular.utils import OPERATORS_DESCRIPTIONS, TYPE_POSSIBILITIES

from .conftest import RESOURCE_ID, TABLES_INDEX_PATTERN

@@ -17,7 +18,16 @@ async def test_swagger_endpoint(client, rmock):
assert res.status == 200


- async def test_swagger_content(client, rmock):
+ @pytest.mark.parametrize(
+     "allow_aggregation",
+     [
+         False,
+         True,
+     ],
+ )
+ async def test_swagger_content(client, rmock, allow_aggregation, mocker):
+     if allow_aggregation:
+         mocker.patch("api_tabular.config.ALLOW_AGGREGATION", [RESOURCE_ID])
with open("db/sample.csv", newline="") as csvfile:
spamreader = csv.reader(csvfile, delimiter=",", quotechar='"')
# getting the csv-detective output in the test file
@@ -49,8 +59,20 @@ async def test_swagger_content(client, rmock):
elif p == "in":
value = "value1,value2,..."
for _p in _params:
- if (
-     f"{c}__{_p}={value}" not in params  # filters
-     and f"{c}__{_p}" not in params  # aggregators
- ):
-     raise ValueError(f"{c}__{_p} is missing in {output} output")
+ if allow_aggregation:
+     if (
+         f"{c}__{_p}={value}" not in params  # filters
+         and f"{c}__{_p}" not in params  # aggregators
+     ):
+         raise ValueError(f"{c}__{_p} is missing in {output} output")
+ else:
+     if (
+         not OPERATORS_DESCRIPTIONS.get(_p, {}).get("is_aggregator")
+         and f"{c}__{_p}={value}" not in params  # filters are in
+     ):
+         raise ValueError(f"{c}__{_p} is missing in {output} output")
+     if (
+         OPERATORS_DESCRIPTIONS.get(_p, {}).get("is_aggregator")
+         and f"{c}__{_p}" in params  # aggregators are out
+     ):
+         raise ValueError(f"{c}__{_p} is in {output} output but should not")
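The swagger test pins down the complementary documentation behaviour: when a resource is not allow-listed, operators flagged `is_aggregator` must not appear among its documented parameters. In isolation, the expected filtering looks roughly like this sketch (simplified inputs):

```python
OPERATORS = {
    "exact": {},
    "contains": {},
    "sum": {"is_aggregator": True},
    "groupby": {"is_aggregator": True},
}


def documented_params(column: str, aggregation_allowed: bool) -> list[str]:
    # Keep every filter; keep aggregators only when the resource is allow-listed.
    return [
        f"{column}__{op}"
        for op, desc in OPERATORS.items()
        if aggregation_allowed or not desc.get("is_aggregator")
    ]


assert documented_params("score", aggregation_allowed=True) == [
    "score__exact", "score__contains", "score__sum", "score__groupby"
]
assert documented_params("score", aggregation_allowed=False) == ["score__exact", "score__contains"]
```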
