From 77b1ec328b321d099bea36aa1579f8abc3da601d Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 26 Nov 2024 01:33:50 -0600 Subject: [PATCH] Fixes #1096 --- sde_collections/sinequa_api.py | 14 +++++++++++-- sde_collections/tests/api_tests.py | 33 +++++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/sde_collections/sinequa_api.py b/sde_collections/sinequa_api.py index c16277d5..d0713062 100644 --- a/sde_collections/sinequa_api.py +++ b/sde_collections/sinequa_api.py @@ -231,5 +231,15 @@ def get_full_texts(self, collection_config_folder: str, source: str = None, coll return self.sql_query(sql, collection) @staticmethod - def _process_full_text_response(batch_data: str): - return [{"url": url, "full_text": full_text, "title": title} for url, full_text, title in batch_data["Rows"]] + def _process_full_text_response(batch_data: dict): + if 'Rows' not in batch_data or not isinstance(batch_data['Rows'], list): + raise ValueError("Expected 'Rows' key with a list of data.") + + processed_data = [] + for row in batch_data['Rows']: + # Ensure each row has exactly three elements (url, full_text, title) + if len(row) != 3: + raise ValueError("Each row must contain exactly three elements (url, full_text, title).") + url, full_text, title = row + processed_data.append({"url": url, "full_text": full_text, "title": title}) + return processed_data diff --git a/sde_collections/tests/api_tests.py b/sde_collections/tests/api_tests.py index 0a7a9245..85db82a8 100644 --- a/sde_collections/tests/api_tests.py +++ b/sde_collections/tests/api_tests.py @@ -147,7 +147,7 @@ def test_api_init(self, server_name, user, password, expected): @patch("requests.post") def test_query_dev_server_authentication(self, mock_post, api_instance): """Test query on dev servers requiring authentication.""" - api_instance.server_name = "xli" # Setting a dev server + api_instance.server_name = "xli" mock_post.return_value = MagicMock(status_code=200, json=lambda: {"result": "success"}) response = api_instance.query(page=1, collection_config_folder="folder") @@ -168,3 +168,34 @@ def test_sql_query_pagination(self, mock_process_response, api_instance, collect result = api_instance.sql_query("SELECT * FROM test_index", collection) assert result == "All 6 records have been processed and updated." + + def test_process_full_text_response(self, api_instance): + """Test that _process_full_text_response correctly processes the data.""" + batch_data = {"Rows": [ + ["http://example.com", "Example text", "Example title"], + ["http://example.net", "Another text", "Another title"] + ]} + expected_output = [ + {"url": "http://example.com", "full_text": "Example text", "title": "Example title"}, + {"url": "http://example.net", "full_text": "Another text", "title": "Another title"} + ] + result = api_instance._process_full_text_response(batch_data) + assert result == expected_output + + def test_process_full_text_response_with_invalid_data(self, api_instance): + """Test that _process_full_text_response raises an error with invalid data.""" + # Test for missing 'Rows' key + invalid_data_no_rows = {} # No 'Rows' key + with pytest.raises(ValueError, match="Expected 'Rows' key with a list of data"): + api_instance._process_full_text_response(invalid_data_no_rows) + + # Test for incorrect row length + invalid_data_wrong_length = {"Rows": [["http://example.com", "Example text"]]} # Missing 'title' + with pytest.raises(ValueError, match="Each row must contain exactly three elements"): + api_instance._process_full_text_response(invalid_data_wrong_length) + + @patch("sde_collections.sinequa_api.Api._get_token", return_value=None) + def test_sql_query_missing_token(self, mock_get_token, api_instance, collection): + """Test that sql_query raises an error when no token is provided.""" + with pytest.raises(ValueError, match="A token is required to use the SQL endpoint"): + api_instance.sql_query("SELECT * FROM test_table", collection)