
Commit

Allow sorting by column (#2691)
* Test orderby in filter

* Add orderby parameter to filter endpoint

* Rename validate_where_parameter

* Validate orderby parameter

* Test validate orderby parameter

* Make where parameter optional

* Test execute_filter_query without where param

* Add e2e test of filter endpoint parameter orderby

* Add e2e test of filter endpoint without where parameter

* Test e2e filter endpoint with orderby parameter

* Add orderby parameter to docs

* Update openapi.json
albertvillanova authored May 3, 2024
1 parent 3ace461 commit 906fec5
Showing 6 changed files with 136 additions and 63 deletions.
4 changes: 4 additions & 0 deletions docs/source/filter.md
@@ -16,6 +16,7 @@ The `/filter` endpoint accepts the following query parameters:
- `config`: the configuration name, for example `cola`
- `split`: the split name, for example `train`
- `where`: the filter condition
- `orderby`: the order-by clause
- `offset`: the offset of the slice, for example `150`
- `length`: the length of the slice, for example `10` (maximum: `100`)

@@ -39,6 +40,9 @@ either the string "name" column is equal to 'Simone' or the integer "children" c
for example: <code>'O''Hara'</code>.
</Tip>

The `orderby` parameter must contain the name of the column whose values will be sorted (in ascending order by default).
To sort the rows in descending order, append the `DESC` keyword, for example `orderby=age DESC`.
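As a sketch, the query parameters above can be assembled into a request URL like this (the base URL is the public dataset viewer API; the `build_filter_url` helper, the dataset, and the `age` column are illustrative, not part of the API itself):

```python
from urllib.parse import urlencode

# Public dataset viewer API endpoint, per the documentation above.
BASE_URL = "https://datasets-server.huggingface.co/filter"

def build_filter_url(dataset, config, split, where="", orderby="", offset=0, length=10):
    """Build a /filter request URL; where and orderby are optional."""
    params = {"dataset": dataset, "config": config, "split": split}
    if where:
        params["where"] = where
    if orderby:
        params["orderby"] = orderby  # e.g. "age DESC" for descending order
    params.update(offset=offset, length=length)
    # urlencode percent-escapes special characters and encodes spaces as '+'
    return f"{BASE_URL}?{urlencode(params)}"

url = build_filter_url("ibm/duorc", "SelfRC", "train",
                       where="no_answer=false", orderby="age DESC")
print(url)
```

Note that `urlencode` takes care of escaping the `=` in the `where` condition and the space in `age DESC`.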

For example, let's filter the rows where no_answer=false in the `train` split of the `SelfRC` configuration of the `ibm/duorc` dataset, restricting the results to the slice 150-151:

<inferencesnippet>
24 changes: 9 additions & 15 deletions docs/source/openapi.json
@@ -3878,12 +3878,6 @@
"error": "Parameter 'split' is required"
}
},
"missing-where": {
"summary": "The where parameter is missing.",
"value": {
"error": "Parameter 'where' is required"
}
},
"empty-dataset": {
"summary": "The dataset parameter is empty.",
"value": {
@@ -3902,12 +3896,6 @@
"error": "Parameter 'split' is required"
}
},
"empty-where": {
"summary": "The where parameter is empty.",
"value": {
"error": "Parameter 'where' is required"
}
},
"non-integer-offset": {
"summary": "The offset must be integer.",
"value": {
@@ -3944,10 +3932,16 @@
"error": "Parameter 'where' contains invalid symbols"
}
},
"invalid-where": {
"summary": "The where parameter is invalid.",
"orderby-with-invalid-symbols": {
"summary": "The orderby parameter contains invalid symbols.",
"value": {
"error": "Parameter 'orderby' contains invalid symbols"
}
},
"invalid-parameter": {
"summary": "A query parameter is invalid.",
"value": {
"error": "Parameter 'where' is invalid"
"error": "A query parameter is invalid"
}
}
}
24 changes: 12 additions & 12 deletions docs/source/quick_start.md
@@ -18,18 +18,18 @@ In this quickstart, you'll learn how to use the dataset viewer's REST API to:

Each feature is served through an endpoint summarized in the table below:

| Endpoint | Method | Description | Query parameters |
| --------------------------- | ------ | ------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [/is-valid](./valid) | GET | Check whether a specific dataset is valid. | `dataset`: name of the dataset |
| [/splits](./splits) | GET | Get the list of configurations and splits of a dataset. | `dataset`: name of the dataset |
| [/first-rows](./first_rows) | GET | Get the first rows of a dataset split. | - `dataset`: name of the dataset<br>- `config`: name of the config<br>- `split`: name of the split |
| [/rows](./rows) | GET | Get a slice of rows of a dataset split. | - `dataset`: name of the dataset<br>- `config`: name of the config<br>- `split`: name of the split<br>- `offset`: offset of the slice<br>- `length`: length of the slice (maximum 100) |
| [/search](./search) | GET | Search text in a dataset split. | - `dataset`: name of the dataset<br>- `config`: name of the config<br>- `split`: name of the split<br>- `query`: text to search for<br> |
| [/filter](./filter) | GET | Filter rows in a dataset split. | - `dataset`: name of the dataset<br>- `config`: name of the config<br>- `split`: name of the split<br>- `where`: filter query<br>- `offset`: offset of the slice<br>- `length`: length of the slice (maximum 100) |
| [/parquet](./parquet) | GET | Get the list of parquet files of a dataset. | `dataset`: name of the dataset |
| [/size](./size) | GET | Get the size of a dataset. | `dataset`: name of the dataset |
| [/statistics](./statistics) | GET | Get statistics about a dataset split. | - `dataset`: name of the dataset<br>- `config`: name of the config<br>- `split`: name of the split |
| [/croissant](./croissant) | GET | Get Croissant metadata about a dataset. | - `dataset`: name of the dataset |
| Endpoint | Method | Description | Query parameters |
|-----------------------------|--------|---------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [/is-valid](./valid) | GET | Check whether a specific dataset is valid. | `dataset`: name of the dataset |
| [/splits](./splits) | GET | Get the list of configurations and splits of a dataset. | `dataset`: name of the dataset |
| [/first-rows](./first_rows) | GET | Get the first rows of a dataset split. | - `dataset`: name of the dataset<br>- `config`: name of the config<br>- `split`: name of the split |
| [/rows](./rows) | GET | Get a slice of rows of a dataset split. | - `dataset`: name of the dataset<br>- `config`: name of the config<br>- `split`: name of the split<br>- `offset`: offset of the slice<br>- `length`: length of the slice (maximum 100) |
| [/search](./search) | GET | Search text in a dataset split. | - `dataset`: name of the dataset<br>- `config`: name of the config<br>- `split`: name of the split<br>- `query`: text to search for<br> |
| [/filter](./filter) | GET | Filter rows in a dataset split. | - `dataset`: name of the dataset<br>- `config`: name of the config<br>- `split`: name of the split<br>- `where`: filter query<br>- `orderby`: order-by clause<br>- `offset`: offset of the slice<br>- `length`: length of the slice (maximum 100) |
| [/parquet](./parquet) | GET | Get the list of parquet files of a dataset. | `dataset`: name of the dataset |
| [/size](./size) | GET | Get the size of a dataset. | `dataset`: name of the dataset |
| [/statistics](./statistics) | GET | Get statistics about a dataset split. | - `dataset`: name of the dataset<br>- `config`: name of the config<br>- `split`: name of the split |
| [/croissant](./croissant) | GET | Get Croissant metadata about a dataset. | - `dataset`: name of the dataset |

There is no installation or setup required to use the dataset viewer API.

47 changes: 40 additions & 7 deletions e2e/tests/test_53_filter.py
@@ -13,9 +13,10 @@ def test_filter_endpoint(normal_user_public_dataset: str) -> None:
offset = 1
length = 2
where = "col_4='B'"
orderby = "col_2 DESC"
filter_response = poll_until_ready_and_assert(
relative_url=(
f"/filter?dataset={dataset}&config={config}&split={split}&offset={offset}&length={length}&where={where}"
f"/filter?dataset={dataset}&config={config}&split={split}&offset={offset}&length={length}&where={where}&orderby={orderby}"
),
check_x_revision=True,
dataset=dataset,
@@ -44,13 +45,13 @@ def test_filter_endpoint(normal_user_public_dataset: str) -> None:
"truncated_cells": [],
}, rows[0]
assert rows[1] == {
"row_idx": 3,
"row_idx": 0,
"row": {
"col_1": "The wingman spots the pirateship coming at him and warns the Dark Lord",
"col_2": 3,
"col_3": 3.0,
"col_1": "There goes another one.",
"col_2": 0,
"col_3": 0.0,
"col_4": "B",
"col_5": None,
"col_5": True,
},
"truncated_cells": [],
}, rows[1]
@@ -66,6 +67,7 @@ def test_filter_endpoint(normal_user_public_dataset: str) -> None:
@pytest.mark.parametrize(
"where,expected_num_rows",
[
("", 4),
("col_2=3", 1),
("col_2<3", 3),
("col_2>3", 0),
@@ -79,8 +81,11 @@ def test_filter_endpoint(normal_user_public_dataset: str) -> None:
def test_filter_endpoint_parameter_where(where: str, expected_num_rows: int, normal_user_public_dataset: str) -> None:
dataset = normal_user_public_dataset
config, split = get_default_config_split()
relative_url = f"/filter?dataset={dataset}&config={config}&split={split}"
if where:
relative_url += f"&where={where}"
response = poll_until_ready_and_assert(
relative_url=f"/filter?dataset={dataset}&config={config}&split={split}&where={where}",
relative_url=relative_url,
check_x_revision=True,
dataset=dataset,
)
@@ -89,6 +94,34 @@ def test_filter_endpoint_parameter_where(where: str, expected_num_rows: int, nor
assert len(content["rows"]) == expected_num_rows


@pytest.mark.parametrize(
"orderby, expected_first_row_idx",
[
("", 1),
("col_4", 2),
("col_3 DESC", 3),
],
)
def test_filter_endpoint_parameter_orderby(
orderby: str, expected_first_row_idx: int, normal_user_public_dataset: str
) -> None:
dataset = normal_user_public_dataset
config, split = get_default_config_split()
where = "col_2>0"
relative_url = f"/filter?dataset={dataset}&config={config}&split={split}&where={where}"
if orderby:
relative_url += f"&orderby={orderby}"
response = poll_until_ready_and_assert(
relative_url=relative_url,
check_x_revision=True,
dataset=dataset,
)
content = response.json()
assert "rows" in content, response
rows = content["rows"]
assert rows[0]["row_idx"] == expected_first_row_idx, rows[0]


def test_filter_images_endpoint(normal_user_images_public_dataset: str) -> None:
dataset = normal_user_images_public_dataset
config, split = get_default_config_split()
32 changes: 20 additions & 12 deletions services/search/src/search/routes/filter.py
@@ -42,14 +42,15 @@
FILTER_QUERY = """\
SELECT {columns}
FROM data
WHERE {where}
{where}
{orderby}
LIMIT {limit}
OFFSET {offset}"""

FILTER_COUNT_QUERY = """\
SELECT COUNT(*)
FROM data
WHERE {where}"""
{where}"""

SQL_INVALID_SYMBOLS = "|".join([";", "--", r"/\*", r"\*/"])
SQL_INVALID_SYMBOLS_PATTERN = re.compile(rf"(?:{SQL_INVALID_SYMBOLS})", flags=re.IGNORECASE)
@@ -83,11 +84,15 @@ async def filter_endpoint(request: Request) -> Response:
dataset = get_request_parameter(request, "dataset", required=True)
config = get_request_parameter(request, "config", required=True)
split = get_request_parameter(request, "split", required=True)
where = get_request_parameter(request, "where", required=True)
validate_where_parameter(where)
where = get_request_parameter(request, "where")
validate_query_parameter(where, "where")
orderby = get_request_parameter(request, "orderby")
validate_query_parameter(orderby, "orderby")
offset = get_request_parameter_offset(request)
length = get_request_parameter_length(request)
logger.info(f"/filter, {dataset=}, {config=}, {split=}, {where=}, {offset=}, {length=}")
logger.info(
f"/filter, {dataset=}, {config=}, {split=}, {where=}, {orderby=}, {offset=}, {length=}"
)
with StepProfiler(method="filter_endpoint", step="check authentication"):
# If auth_check fails, it will raise an exception that will be caught below
await auth_check(
@@ -154,6 +159,7 @@ async def filter_endpoint(request: Request) -> Response:
index_file_location,
supported_columns,
where,
orderby,
length,
offset,
extensions_directory,
@@ -195,26 +201,28 @@ def execute_filter_query(
index_file_location: str,
columns: list[str],
where: str,
orderby: str,
limit: int,
offset: int,
extensions_directory: Optional[str] = None,
) -> tuple[int, pa.Table]:
with duckdb_connect(extensions_directory=extensions_directory, database=index_file_location) as con:
filter_query = FILTER_QUERY.format(
columns=",".join([f'"{column}"' for column in columns]),
where=where,
where=f"WHERE {where}" if where else "",
orderby=f"ORDER BY {orderby}" if orderby else "",
limit=limit,
offset=offset,
)
filter_count_query = FILTER_COUNT_QUERY.format(where=where)
filter_count_query = FILTER_COUNT_QUERY.format(where=f"WHERE {where}" if where else "")
try:
pa_table = con.sql(filter_query).arrow()
num_rows_total = con.sql(filter_count_query).fetchall()[0][0]
except duckdb.Error:
raise InvalidParameterError(message="Parameter 'where' is invalid")
except duckdb.Error as err:
raise InvalidParameterError(message="A query parameter is invalid") from err
return num_rows_total, pa_table


def validate_where_parameter(where: str) -> None:
if SQL_INVALID_SYMBOLS_PATTERN.search(where):
raise InvalidParameterError(message="Parameter 'where' contains invalid symbols")
def validate_query_parameter(parameter_value: str, parameter_name: str) -> None:
if SQL_INVALID_SYMBOLS_PATTERN.search(parameter_value):
raise InvalidParameterError(message=f"Parameter '{parameter_name}' contains invalid symbols")
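To illustrate the change in this file, here is a minimal, self-contained sketch of how the now-optional `where` and `orderby` clauses are validated and composed into the final query. It reproduces the templates and symbol list from the diff above outside the service, so DuckDB is not needed and `ValueError` stands in for the service's `InvalidParameterError`:

```python
import re

# Mirrors SQL_INVALID_SYMBOLS in filter.py: a crude guard against
# statement chaining and comment injection in user-supplied clauses.
SQL_INVALID_SYMBOLS = "|".join([";", "--", r"/\*", r"\*/"])
SQL_INVALID_SYMBOLS_PATTERN = re.compile(rf"(?:{SQL_INVALID_SYMBOLS})", flags=re.IGNORECASE)

def validate_query_parameter(parameter_value: str, parameter_name: str) -> None:
    # The service raises InvalidParameterError; ValueError keeps this sketch standalone.
    if SQL_INVALID_SYMBOLS_PATTERN.search(parameter_value):
        raise ValueError(f"Parameter '{parameter_name}' contains invalid symbols")

FILTER_QUERY = """\
SELECT {columns}
FROM data
{where}
{orderby}
LIMIT {limit}
OFFSET {offset}"""

def build_query(columns, where="", orderby="", limit=100, offset=0):
    # Empty where/orderby parameters simply drop their clause from the query.
    return FILTER_QUERY.format(
        columns=",".join(f'"{c}"' for c in columns),
        where=f"WHERE {where}" if where else "",
        orderby=f"ORDER BY {orderby}" if orderby else "",
        limit=limit,
        offset=offset,
    )

print(build_query(["age"], where="age>18", orderby="age DESC"))
```

Keeping the `WHERE`/`ORDER BY` keywords out of the templates and adding them only when the corresponding parameter is non-empty is what lets the endpoint accept `where`, `orderby`, both, or neither.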
