From 6e1923127802e877f2f4c40acb764cb1fc9fbffc Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Thu, 5 Oct 2023 16:57:31 +0200 Subject: [PATCH] Fix missing index column in /filter (#1935) * Fix missing index column in /filter * Fix test * Fix e2e test --- e2e/tests/test_53_filter.py | 8 +++----- services/search/src/search/routes/filter.py | 5 ++++- services/search/tests/routes/test_filter.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/e2e/tests/test_53_filter.py b/e2e/tests/test_53_filter.py index 8c7503f801..7130ce58a7 100644 --- a/e2e/tests/test_53_filter.py +++ b/e2e/tests/test_53_filter.py @@ -15,7 +15,7 @@ def test_filter_endpoint( headers = auth_headers[auth] offset = 1 length = 2 - where = "col_4 = 'B'" + where = "col_4='B'" filter_response = poll_until_ready_and_assert( relative_url=( f"/filter?dataset={dataset}&config={config}&split={split}&offset={offset}&length={length}&where={where}" @@ -47,7 +47,7 @@ def test_filter_endpoint( "truncated_cells": [], }, rows[0] assert rows[1] == { - "row_idx": 2, + "row_idx": 3, "row": { "col_1": "The wingman spots the pirateship coming at him and warns the Dark Lord", "col_2": 3, @@ -77,15 +77,13 @@ def test_filter_endpoint( ("col_2<3 OR col_4='B'", 4), ], ) -def test_where_parameter_in_filter_endpoint( +def test_filter_endpoint_parameter_where( where: str, expected_num_rows: int, hf_public_dataset_repo_csv_data: str ) -> None: dataset = hf_public_dataset_repo_csv_data config, split = get_default_config_split() response = poll_until_ready_and_assert( relative_url=f"/filter?dataset={dataset}&config={config}&split={split}&where={where}", - expected_status_code=200, - expected_error_code=None, check_x_revision=True, ) content = response.json() diff --git a/services/search/src/search/routes/filter.py b/services/search/src/search/routes/filter.py index e58a81dc27..dc4298dabf 100644 --- a/services/search/src/search/routes/filter.py +++ b/services/search/src/search/routes/filter.py @@ -222,7 +222,10 @@ def execute_filter_query( ) -> tuple[int, pa.Table]: with duckdb_connect(database=index_file_location) as con: filter_query = FILTER_QUERY.format( - columns=",".join([f'"{column}"' for column in columns]), where=where, limit=limit, offset=offset + columns=",".join([f'"{column}"' for column in [ROW_IDX_COLUMN] + columns]), + where=where, + limit=limit, + offset=offset, ) filter_count_query = FILTER_COUNT_QUERY.format(where=where) try: diff --git a/services/search/tests/routes/test_filter.py b/services/search/tests/routes/test_filter.py index fbe90b1ff4..5d7da5a603 100644 --- a/services/search/tests/routes/test_filter.py +++ b/services/search/tests/routes/test_filter.py @@ -64,7 +64,7 @@ def test_execute_filter_query(index_file_location: str) -> None: index_file_location=index_file_location, columns=columns, where=where, limit=limit, offset=offset ) assert num_rows_total == 2 - assert pa_table == pa.Table.from_pydict({"name": ["Simone"], "age": [30]}) + assert pa_table == pa.Table.from_pydict({"__hf_index_id": [3], "name": ["Simone"], "age": [30]}) @pytest.mark.parametrize("where", ["non-existing-column=30", "name=30", "name>30"])