diff --git a/sde_collections/sinequa_api.py b/sde_collections/sinequa_api.py
index 868afb77..78ac4e05 100644
--- a/sde_collections/sinequa_api.py
+++ b/sde_collections/sinequa_api.py
@@ -1,4 +1,5 @@
 import json
+from collections.abc import Iterator
 from typing import Any
 
 import requests
@@ -61,7 +62,7 @@ class Api:
     def __init__(self, server_name: str = None, user: str = None, password: str = None, token: str = None) -> None:
         self.server_name = server_name
         if server_name not in server_configs:
-            raise ValueError(f"Server name '{server_name}' is not in server_configs")
+            raise ValueError(f"Invalid server configuration: '{server_name}' is not a recognized server name")
 
         self.config = server_configs[server_name]
         self.app_name: str = self.config["app_name"]
@@ -69,7 +70,6 @@ def __init__(self, server_name: str = None, user: str = None, password: str = No
         self.base_url: str = self.config["base_url"]
         self.dev_servers = ["xli", "lrm_dev", "lrm_qa"]
 
-        # Store provided values only
        self._provided_user = user
        self._provided_password = password
        self._provided_token = token
@@ -113,7 +113,8 @@ def query(self, page: int, collection_config_folder: str | None = None, source:
             password = self._get_password()
             if not user or not password:
                 raise ValueError(
-                    "User and password are required for the query endpoint on the following servers: {self.dev_servers}"
+                    f"Authentication error: Missing credentials for dev server '{self.server_name}'. "
+                    f"Both username and password are required for servers: {', '.join(self.dev_servers)}"
                 )
             authentication = f"?Password={password}&User={user}"
             url = f"{url}{authentication}"
@@ -135,11 +136,22 @@ def query(self, page: int, collection_config_folder: str | None = None, source:
 
         return self.process_response(url, payload)
 
-    def sql_query(self, sql: str) -> Any:
-        """Executes an SQL query on the configured server using token-based authentication."""
+    def _execute_sql_query(self, sql: str) -> dict:
+        """
+        Executes a SQL query against the Sinequa API.
+
+        Args:
+            sql (str): The SQL query to execute
+
+        Returns:
+            dict: The JSON response from the API containing 'Rows' and 'TotalRowCount'
+
+        Raises:
+            ValueError: If no token is available for authentication
+        """
         token = self._get_token()
         if not token:
-            raise ValueError("A token is required to use the SQL endpoint")
+            raise ValueError("Authentication error: Token is required for SQL endpoint access")
 
         url = f"{self.base_url}/api/v1/engine.sql"
         headers = {"Content-Type": "application/json", "Authorization": f"Bearer {token}"}
@@ -153,42 +165,117 @@ def sql_query(self, sql: str) -> Any:
 
         return self.process_response(url, headers=headers, raw_data=raw_payload)
 
-    def get_full_texts(self, collection_config_folder: str, source: str = None) -> Any:
+    def _process_rows_to_records(self, rows: list) -> list[dict]:
         """
-        Retrieves the full texts, URLs, and titles for a specified collection.
+        Converts raw SQL row data into structured record dictionaries.
+
+        Args:
+            rows (list): List of rows, where each row is [url, full_text, title]
 
         Returns:
-            dict: A JSON response containing the results of the SQL query,
-                  where each item has 'url', 'text', and 'title'.
-
-        Example:
-            Calling get_full_texts("example_collection") might return:
-            [
-                {
-                    'url': 'http://example.com/article1',
-                    'text': 'Here is the full text of the first article...',
-                    'title': 'Article One Title'
-                },
-                {
-                    'url': 'http://example.com/article2',
-                    'text': 'Here is the full text of the second article...',
-                    'title': 'Article Two Title'
-                }
-            ]
+            list[dict]: List of processed records with url, full_text, and title keys
+
+        Raises:
+            ValueError: If any row doesn't contain exactly three elements
+        """
+        processed_records = []
+        for idx, row in enumerate(rows):
+            if len(row) != 3:
+                raise ValueError(
+                    f"Invalid row format at index {idx}: Expected exactly three elements (url, full_text, title). "
+                    f"Received {len(row)} elements."
+                )
+            processed_records.append({"url": row[0], "full_text": row[1], "title": row[2]})
+        return processed_records
+
+    def get_full_texts(self, collection_config_folder: str, source: str | None = None) -> Iterator[list[dict]]:
+        """
+        Retrieves and yields batches of text records from the SQL database for a given collection.
+        Uses pagination to handle large datasets efficiently.
+
+        Args:
+            collection_config_folder (str): The collection folder to query (e.g., "EARTHDATA", "SMD")
+            source (str, optional): The source to query. If None, defaults to "scrapers" for dev servers
+                or "SDE" for other servers.
+
+        Yields:
+            list[dict]: Batches of records, where each record is a dictionary containing:
+                {
+                    "url": str,        # The URL of the document
+                    "full_text": str,  # The full text content of the document
+                    "title": str       # The title of the document
+                }
+
+        Raises:
+            ValueError: If the server's index is not defined in its configuration
+
+        Example batch:
+            [
+                {
+                    "url": "https://example.nasa.gov/doc1",
+                    "full_text": "This is the content of doc1...",
+                    "title": "Document 1 Title"
+                },
+                {
+                    "url": "https://example.nasa.gov/doc2",
+                    "full_text": "This is the content of doc2...",
+                    "title": "Document 2 Title"
+                }
+            ]
+
+        Note:
+            - Results are paginated in batches of 5000 records
+            - Each batch is processed into clean dictionaries before being yielded
+            - The iterator will stop when either:
+                1. No more rows are returned from the query
+                2. The total count of records has been reached
+        """
         if not source:
             source = self._get_source_name()
 
         if (index := self.config.get("index")) is None:
-            raise ValueError("Index not defined for this server")
+            raise ValueError(
+                f"Configuration error: Index not defined for server '{self.server_name}'. "
+                "Please update server configuration with the required index."
+            )
 
         sql = f"SELECT url1, text, title FROM {index} WHERE collection = '/{source}/{collection_config_folder}/'"
-        full_text_response = self.sql_query(sql)
-        return self._process_full_text_response(full_text_response)
+
+        page_size = 5000
+        total_processed = 0
+
+        while True:
+            paginated_sql = f"{sql} SKIP {total_processed} COUNT {page_size}"
+            response = self._execute_sql_query(paginated_sql)
+
+            rows = response.get("Rows", [])
+            if not rows:  # Stop if we get an empty batch
+                break
+
+            yield self._process_rows_to_records(rows)
+
+            total_processed += len(rows)
+            total_count = response.get("TotalRowCount", 0)
+
+            if total_processed >= total_count:  # Stop if we've processed all records
+                break
 
     @staticmethod
-    def _process_full_text_response(full_text_response: str):
-        return [
-            {"url": url, "full_text": full_text, "title": title} for url, full_text, title in full_text_response["Rows"]
-        ]
+    def _process_full_text_response(batch_data: dict) -> list[dict]:
+        if "Rows" not in batch_data or not isinstance(batch_data["Rows"], list):
+            raise ValueError(
+                "Invalid response format: Expected 'Rows' key with list data in Sinequa server response. "
+                f"Received: {type(batch_data.get('Rows', None))}"
+            )
+
+        processed_data = []
+        for idx, row in enumerate(batch_data["Rows"]):
+            if len(row) != 3:
+                raise ValueError(
+                    f"Invalid row format at index {idx}: Expected exactly three elements (url, full_text, title). "
+                    f"Received {len(row)} elements."
+                )
+            url, full_text, title = row
+            processed_data.append({"url": url, "full_text": full_text, "title": title})
+        return processed_data
diff --git a/sde_collections/tasks.py b/sde_collections/tasks.py
index 47c96338..8d4a4c4d 100644
--- a/sde_collections/tasks.py
+++ b/sde_collections/tasks.py
@@ -7,7 +7,7 @@
 from django.conf import settings
 from django.core import management
 from django.core.management.commands import loaddata
-from django.db import IntegrityError
+from django.db import transaction
 
 from config import celery_app
 
@@ -145,44 +145,53 @@ def resolve_title_pattern(title_pattern_id):
     title_pattern.apply()
 
 
-@celery_app.task
+@celery_app.task(soft_time_limit=600)
 def fetch_and_replace_full_text(collection_id, server_name):
     """
-    Task to fetch and replace full text and metadata for all URLs associated with a specified collection
-    from a given server. This task deletes all existing DumpUrl entries for the collection and creates
-    new entries based on the latest fetched data.
-
-    Args:
-        collection_id (int): The identifier for the collection in the database.
-        server_name (str): The name of the server.
-
-    Returns:
-        str: A message indicating the result of the operation, including the number of URLs processed.
+    Task to fetch and replace full text and metadata for a collection.
+    Handles data in batches to manage memory usage.
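+
+    Args:
+        collection_id (int): The identifier for the collection in the database.
+        server_name (str): The name of the server.
+
+    Returns:
+        str: A message indicating the number of records processed.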
""" collection = Collection.objects.get(id=collection_id) api = Api(server_name) - documents = api.get_full_texts(collection.config_folder) - # Step 1: Delete all existing DumpUrl entries for the collection + # Step 1: Delete existing DumpUrl entries deleted_count, _ = DumpUrl.objects.filter(collection=collection).delete() - - # Step 2: Create new DumpUrl entries from the fetched documents - processed_count = 0 - for doc in documents: - try: - DumpUrl.objects.create( - url=doc["url"], - collection=collection, - scraped_text=doc.get("full_text", ""), - scraped_title=doc.get("title", ""), - ) - processed_count += 1 - except IntegrityError: - # Handle duplicate URL case if needed - print(f"Duplicate URL found, skipping: {doc['url']}") - - collection.migrate_dump_to_delta() - - print(f"Processed {processed_count} new records.") - - return f"Successfully processed {len(documents)} records and updated the database." + print(f"Deleted {deleted_count} old records.") + + # Step 2: Process data in batches + total_processed = 0 + + try: + for batch in api.get_full_texts(collection.config_folder): + # Use bulk_create for efficiency, with a transaction per batch + with transaction.atomic(): + DumpUrl.objects.bulk_create( + [ + DumpUrl( + url=record["url"], + collection=collection, + scraped_text=record["full_text"], + scraped_title=record["title"], + ) + for record in batch + ] + ) + + total_processed += len(batch) + print(f"Processed batch of {len(batch)} records. Total: {total_processed}") + + # Step 3: Migrate dump URLs to delta URLs + collection.migrate_dump_to_delta() + + return f"Successfully processed {total_processed} records and updated the database." + + except Exception as e: + print(f"Error processing records: {str(e)}") + raise diff --git a/sde_collections/tests/api_tests.py b/sde_collections/tests/api_tests.py new file mode 100644 index 00000000..88a0f44f --- /dev/null +++ b/sde_collections/tests/api_tests.py @@ -0,0 +1,162 @@ +# docker-compose -f local.yml run --rm django pytest sde_collections/tests/api_tests.py +from unittest.mock import MagicMock, patch + +import pytest +from django.utils import timezone + +from sde_collections.models.collection import WorkflowStatusChoices +from sde_collections.sinequa_api import Api +from sde_collections.tests.factories import CollectionFactory, UserFactory + + +@pytest.mark.django_db +class TestApiClass: + @pytest.fixture + def collection(self): + """Fixture to create a collection object for testing.""" + user = UserFactory() + return CollectionFactory( + curated_by=user, + curation_started=timezone.now(), + config_folder="example_config", + workflow_status=WorkflowStatusChoices.RESEARCH_IN_PROGRESS, + ) + + @pytest.fixture + def api_instance(self): + """Fixture to create an Api instance with mocked server configs.""" + with patch( + "sde_collections.sinequa_api.server_configs", + { + "test_server": { + "app_name": "test_app", + "query_name": "test_query", + "base_url": "http://testserver.com/api", + "index": "test_index", + } + }, + ): + return Api(server_name="test_server", user="test_user", password="test_pass", token="test_token") + + @patch("requests.post") + def test_process_response_success(self, mock_post, api_instance): + """Test that process_response handles successful responses.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"key": "value"} + mock_post.return_value = mock_response + + response = api_instance.process_response("http://example.com", payload={"test": "data"}) + assert 
response == {"key": "value"} + + @patch("requests.post") + def test_process_response_failure(self, mock_post, api_instance): + """Test that process_response raises an exception on failure.""" + mock_response = MagicMock() + mock_response.status_code = 500 + mock_post.return_value = mock_response + mock_response.raise_for_status.side_effect = Exception("Internal Server Error") + + with pytest.raises(Exception, match="Internal Server Error"): + api_instance.process_response("http://example.com", payload={"test": "data"}) + + @patch("sde_collections.sinequa_api.Api.process_response") + def test_query(self, mock_process_response, api_instance): + """Test that query sends correct payload and processes response.""" + mock_process_response.return_value = {"result": "success"} + response = api_instance.query(page=1, collection_config_folder="folder") + assert response == {"result": "success"} + + def test_process_rows_to_records(self, api_instance): + """Test processing row data into record dictionaries.""" + # Test valid input + valid_rows = [["http://example.com/1", "Text 1", "Title 1"], ["http://example.com/2", "Text 2", "Title 2"]] + expected_output = [ + {"url": "http://example.com/1", "full_text": "Text 1", "title": "Title 1"}, + {"url": "http://example.com/2", "full_text": "Text 2", "title": "Title 2"}, + ] + assert api_instance._process_rows_to_records(valid_rows) == expected_output + + # Test invalid row length + invalid_rows = [["http://example.com", "Text"]] # Missing title + with pytest.raises(ValueError, match="Invalid row format at index 0"): + api_instance._process_rows_to_records(invalid_rows) + + @patch("sde_collections.sinequa_api.Api.process_response") + def test_execute_sql_query(self, mock_process_response, api_instance): + """Test SQL query execution.""" + mock_process_response.return_value = {"Rows": [], "TotalRowCount": 0} + + # Test successful query + result = api_instance._execute_sql_query("SELECT * FROM test") + assert result == {"Rows": [], "TotalRowCount": 0} + + # Test query with missing token + api_instance._provided_token = None + with pytest.raises(ValueError, match="Token is required"): + api_instance._execute_sql_query("SELECT * FROM test") + + @patch("sde_collections.sinequa_api.Api._execute_sql_query") + def test_get_full_texts_pagination(self, mock_execute_sql, api_instance): + """Test that get_full_texts correctly handles pagination.""" + # Mock responses for two pages of results + mock_execute_sql.side_effect = [ + { + "Rows": [["http://example.com/1", "Text 1", "Title 1"], ["http://example.com/2", "Text 2", "Title 2"]], + "TotalRowCount": 3, + }, + {"Rows": [["http://example.com/3", "Text 3", "Title 3"]], "TotalRowCount": 3}, + {"Rows": [], "TotalRowCount": 3}, + ] + + # Collect all batches from the iterator + batches = list(api_instance.get_full_texts("test_folder")) + + assert len(batches) == 2 # Should have two batches + assert len(batches[0]) == 2 # First batch has 2 records + assert len(batches[1]) == 1 # Second batch has 1 record + + # Verify content of first batch + assert batches[0] == [ + {"url": "http://example.com/1", "full_text": "Text 1", "title": "Title 1"}, + {"url": "http://example.com/2", "full_text": "Text 2", "title": "Title 2"}, + ] + + # Verify content of second batch + assert batches[1] == [{"url": "http://example.com/3", "full_text": "Text 3", "title": "Title 3"}] + + def test_get_full_texts_missing_index(self, api_instance): + """Test that get_full_texts raises error when index is missing from config.""" + 
+        api_instance.config.pop("index", None)
+        with pytest.raises(ValueError, match="Index not defined for server"):
+            next(api_instance.get_full_texts("test_folder"))
+
+    @pytest.mark.parametrize(
+        "server_name,expect_auth",
+        [
+            ("xli", True),  # dev server should have auth
+            ("production", False),  # prod server should not have auth
+        ],
+    )
+    @patch("requests.post")
+    def test_query_authentication(self, mock_post, server_name, expect_auth, api_instance):
+        """Test authentication handling for different server types."""
+        api_instance.server_name = server_name
+        mock_post.return_value = MagicMock(status_code=200, json=lambda: {"result": "success"})
+
+        response = api_instance.query(page=1, collection_config_folder="folder")
+        assert response == {"result": "success"}
+
+        called_url = mock_post.call_args[0][0]
+        auth_present = "?Password=test_pass&User=test_user" in called_url
+        assert auth_present == expect_auth
+
+    @patch("requests.post")
+    def test_query_dev_server_missing_credentials(self, mock_post, api_instance):
+        """Test that dev servers raise error when credentials are missing."""
+        api_instance.server_name = "xli"
+        api_instance._provided_user = None
+        api_instance._provided_password = None
+
+        with pytest.raises(ValueError, match="Authentication error: Missing credentials for dev server"):
+            api_instance.query(page=1)
diff --git a/sde_collections/tests/test_import_fulltexts.py b/sde_collections/tests/test_import_fulltexts.py
index b4256bde..d39f1633 100644
--- a/sde_collections/tests/test_import_fulltexts.py
+++ b/sde_collections/tests/test_import_fulltexts.py
@@ -4,39 +4,69 @@
 
 import pytest
 
-from sde_collections.models.delta_url import CuratedUrl, DeltaUrl, DumpUrl
+from sde_collections.models.delta_url import DeltaUrl, DumpUrl
 from sde_collections.tasks import fetch_and_replace_full_text
 from sde_collections.tests.factories import CollectionFactory
 
 
 @pytest.mark.django_db
 def test_fetch_and_replace_full_text():
-    # Create a test collection
-    collection = CollectionFactory()
+    collection = CollectionFactory(config_folder="test_folder")
 
-    # Mock API response
-    mock_documents = [
+    mock_batch = [
         {"url": "http://example.com/1", "full_text": "Test Text 1", "title": "Test Title 1"},
         {"url": "http://example.com/2", "full_text": "Test Text 2", "title": "Test Title 2"},
     ]
 
+    def mock_generator():
+        yield mock_batch
+
     with patch("sde_collections.sinequa_api.Api.get_full_texts") as mock_get_full_texts:
-        mock_get_full_texts.return_value = mock_documents
+        mock_get_full_texts.return_value = mock_generator()
 
-        # Call the function
         fetch_and_replace_full_text(collection.id, "lrm_dev")
 
-        # Assertions
         assert DumpUrl.objects.filter(collection=collection).count() == 0
-        assert DeltaUrl.objects.filter(collection=collection).count() == len(mock_documents)
-        assert CuratedUrl.objects.filter(collection=collection).count() == 0
-
-        for doc in mock_documents:
-            assert (
-                DeltaUrl.objects.filter(collection=collection)
-                .filter(
-                    url=doc["url"],
-                    scraped_text=doc["full_text"],
-                )
-                .exists()
-            )
+        assert DeltaUrl.objects.filter(collection=collection).count() == 2
+
+
+@pytest.mark.django_db
+def test_fetch_and_replace_full_text_large_dataset():
+    """Test processing a large number of records with proper pagination and batching."""
+    collection = CollectionFactory(config_folder="test_folder")
+
+    # Create sample data - 20,000 records in total
+    def create_batch(start_idx, size):
+        return [
+            {"url": f"http://example.com/{i}", "full_text": f"Test Text {i}", "title": f"Test Title {i}"}
+            for i in range(start_idx, start_idx + size)
+        ]
+
+    # Mock the API to return data in batches of 5000 (matching actual API pagination)
+    def mock_batch_generator():
+        batch_size = 5000
+        total_records = 20000
+
+        for start in range(0, total_records, batch_size):
+            yield create_batch(start, min(batch_size, total_records - start))
+
+    with patch("sde_collections.sinequa_api.Api.get_full_texts") as mock_get_full_texts:
+        mock_get_full_texts.return_value = mock_batch_generator()
+
+        # Execute the task
+        result = fetch_and_replace_full_text(collection.id, "lrm_dev")
+
+        # Verify total number of records
+        assert DeltaUrl.objects.filter(collection=collection).count() == 20000
+
+        # Verify some random records exist and have correct data
+        for i in [0, 4999, 5000, 19999]:  # Check boundaries and middle
+            url = DeltaUrl.objects.get(url=f"http://example.com/{i}")
+            assert url.scraped_text == f"Test Text {i}"
+            assert url.scraped_title == f"Test Title {i}"
+
+        # Verify batch processing worked by checking the success message
+        assert "Successfully processed 20000 records" in result
+
+        # Verify no DumpUrls remain (should all be migrated to DeltaUrls)
+        assert DumpUrl.objects.filter(collection=collection).count() == 0
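
Review note: the previous implementation skipped duplicate URLs by catching IntegrityError
around each create() call. The batched bulk_create() in fetch_and_replace_full_text will
instead raise and abort the task if a batch ever contains a URL that already exists. If
duplicates are possible in practice, Django's bulk_create(..., ignore_conflicts=True) can
restore the old skip-duplicates behavior. A minimal sketch, assuming the uniqueness
constraint lives on DumpUrl.url (the helper name store_batch is hypothetical, not part of
this change):

    # Hypothetical drop-in for the per-batch write in fetch_and_replace_full_text.
    # ignore_conflicts=True tells the database to discard rows that violate a unique
    # constraint instead of raising IntegrityError, mirroring the old per-row skip.
    from django.db import transaction

    from sde_collections.models.delta_url import DumpUrl


    def store_batch(collection, batch: list[dict]) -> None:
        with transaction.atomic():
            DumpUrl.objects.bulk_create(
                [
                    DumpUrl(
                        url=record["url"],
                        collection=collection,
                        scraped_text=record["full_text"],
                        scraped_title=record["title"],
                    )
                    for record in batch
                ],
                ignore_conflicts=True,
            )

One trade-off worth noting: with ignore_conflicts=True, bulk_create does not report how
many rows were actually inserted, so total_processed would count fetched records rather
than stored ones.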