diff --git a/docs/source/filter.md b/docs/source/filter.md index 5a53ffdc9c..05c4b5c818 100644 --- a/docs/source/filter.md +++ b/docs/source/filter.md @@ -21,27 +21,28 @@ The `/filter` endpoint accepts the following query parameters: - `length`: the length of the slice, for example `10` (maximum: `100`) The `where` parameter must be expressed as a comparison predicate, which can be: -- a simple predicate composed of a column name, a comparison operator, and a value +- a simple predicate composed of a column name in double quotes, a comparison operator, and a value - the comparison operators are: `=`, `<>`, `>`, `>=`, `<`, `<=` - a composite predicate composed of two or more simple predicates (optionally grouped with parentheses to indicate the order of evaluation), combined with logical operators - the logical operators are: `AND`, `OR`, `NOT` For example, the following `where` parameter value ``` -where=age>30 AND (name='Simone' OR children=0) +where="age">30 AND ("name"='Simone' OR "children"=0) ``` will filter the data to select only those rows where the float "age" column is larger than 30 and, either the string "name" column is equal to 'Simone' or the integer "children" column is equal to 0. - Note that, following SQL syntax, string values in comparison predicates must be enclosed in single quotes, - for example: 'Scarlett'. + Note that, following SQL syntax, in comparison predicates, + column names should be enclosed in double quotes ("name"), + and string values must be enclosed in single quotes ('Simone'). Additionally, if the string value contains a single quote, it must be escaped with another single quote, for example: 'O''Hara'. -The `orderby` parameter must contain the column name whose values will be sorted (in ascending order by default). -To sort the rows in descending order, use the DESC keyword, like `orderby=age DESC`. +The `orderby` parameter must contain the column name (in double quotes) whose values will be sorted (in ascending order by default). +To sort the rows in descending order, use the DESC keyword, like `orderby="age" DESC`. For example, let's filter those rows with no_answer=false in the `train` split of the `SelfRC` configuration of the `ibm/duorc` dataset restricting the results to the slice 150-151: @@ -50,7 +51,7 @@ For example, let's filter those rows with no_answer=false in the `train` split o ```python import requests headers = {"Authorization": f"Bearer {API_TOKEN}"} -API_URL = "https://datasets-server.huggingface.co/filter?dataset=ibm/duorc&config=SelfRC&split=train&where=no_answer=true&offset=150&length=2" +API_URL = "https://datasets-server.huggingface.co/filter?dataset=ibm/duorc&config=SelfRC&split=train&where="no_answer"=true&offset=150&length=2" def query(): response = requests.get(API_URL, headers=headers) return response.json() @@ -62,7 +63,7 @@ data = query() import fetch from "node-fetch"; async function query(data) { const response = await fetch( - "https://datasets-server.huggingface.co/filter?dataset=ibm/duorc&config=SelfRC&split=train&where=no_answer=true&offset=150&length=2", + "https://datasets-server.huggingface.co/filter?dataset=ibm/duorc&config=SelfRC&split=train&where="no_answer"=true&offset=150&length=2", { headers: { Authorization: `Bearer ${API_TOKEN}` }, method: "GET" @@ -78,7 +79,7 @@ query().then((response) => { ```curl -curl https://datasets-server.huggingface.co/filter?dataset=ibm/duorc&config=SelfRC&split=train&where=no_answer=true&offset=150&length=2 \ +curl https://datasets-server.huggingface.co/filter?dataset=ibm/duorc&config=SelfRC&split=train&where="no_answer"=true&offset=150&length=2 \ -X GET \ -H "Authorization: Bearer ${API_TOKEN}" ``` diff --git a/e2e/tests/test_53_filter.py b/e2e/tests/test_53_filter.py index 565a398aab..d553cea2b7 100644 --- a/e2e/tests/test_53_filter.py +++ b/e2e/tests/test_53_filter.py @@ -12,8 +12,8 @@ def test_filter_endpoint(normal_user_public_dataset: str) -> None: config, split = get_default_config_split() offset = 1 length = 2 - where = "col_4='B'" - orderby = "col_2 DESC" + where = "\"col_4\"='B'" + orderby = '"col_2" DESC' filter_response = poll_until_ready_and_assert( relative_url=( f"/filter?dataset={dataset}&config={config}&split={split}&offset={offset}&length={length}&where={where}&orderby={orderby}" diff --git a/services/search/tests/routes/test_filter.py b/services/search/tests/routes/test_filter.py index ae3391a350..09e268ce7c 100644 --- a/services/search/tests/routes/test_filter.py +++ b/services/search/tests/routes/test_filter.py @@ -4,6 +4,7 @@ import os from collections.abc import Generator from pathlib import Path +from typing import Union import duckdb import pyarrow as pa @@ -62,7 +63,7 @@ def index_file_location(ds: Dataset) -> Generator[str, None, None]: @pytest.mark.parametrize( - "parameter_name, parameter_value", [("where", "col='A'"), ("orderby", "A"), ("orderby", "A DESC")] + "parameter_name, parameter_value", [("where", "\"col\"='A'"), ("orderby", '"A"'), ("orderby", '"A" DESC')] ) def test_validate_query_parameter(parameter_name: str, parameter_value: str) -> None: validate_query_parameter(parameter_value, parameter_name) @@ -71,15 +72,15 @@ def test_validate_query_parameter(parameter_name: str, parameter_value: str) -> @pytest.mark.parametrize("sql_injection", ["; SELECT * from data", " /*", "--"]) @pytest.mark.parametrize( "parameter_name, parameter_value", - [("where", "col='A'"), ("orderby", "A"), ("orderby", "A DESC")], + [("where", "\"col\"='A'"), ("orderby", '"A"'), ("orderby", '"A" DESC')], ) def test_validate_query_parameter_raises(parameter_name: str, parameter_value: str, sql_injection: str) -> None: with pytest.raises(InvalidParameterError): validate_query_parameter(parameter_value + sql_injection, parameter_name) -@pytest.mark.parametrize("orderby", ["", "age", "age DESC"]) -@pytest.mark.parametrize("where", ["", "gender='female'"]) +@pytest.mark.parametrize("orderby", ["", '"age"', '"age" DESC']) +@pytest.mark.parametrize("where", ["", "\"gender\"='female'"]) @pytest.mark.parametrize("columns", [["name", "age"], ["name"]]) def test_execute_filter_query(columns: list[str], where: str, orderby: str, index_file_location: str) -> None: # in split-duckdb-index we always add the ROW_IDX_COLUMN column @@ -108,15 +109,15 @@ def test_execute_filter_query(columns: list[str], where: str, orderby: str, inde expected_pa_table = expected_pa_table.filter(pc.field("gender") == "female") if orderby: if orderby.endswith(" DESC"): - sorting = [(orderby.removesuffix(" DESC"), "descending")] - expected_pa_table = expected_pa_table.sort_by(sorting) + sorting: Union[str, list[tuple[str, str]]] = [(orderby.removesuffix(" DESC").strip('"'), "descending")] else: - expected_pa_table = expected_pa_table.sort_by(orderby) + sorting = orderby.strip('"') + expected_pa_table = expected_pa_table.sort_by(sorting) expected_pa_table = expected_pa_table.slice(offset, limit).select(columns) assert pa_table == expected_pa_table -@pytest.mark.parametrize("where", ["non-existing-column=30", "name=30", "name>30"]) +@pytest.mark.parametrize("where", ['"non-existing-column"=30', '"name"=30', '"name">30']) def test_execute_filter_query_raises(where: str, index_file_location: str) -> None: columns, limit, offset = ["name", "gender", "age"], 100, 0 with pytest.raises(InvalidParameterError):