
Commit

add automatic batch size reduction to sinequa_api
CarsonDavis committed Dec 7, 2024
1 parent 1b71c2d commit 2b811b6
Showing 2 changed files with 110 additions and 34 deletions.
79 changes: 45 additions & 34 deletions sde_collections/sinequa_api.py
@@ -188,15 +188,26 @@ def _process_rows_to_records(self, rows: list) -> list[dict]:
processed_records.append({"url": row[0], "full_text": row[1], "title": row[2]})
return processed_records

def get_full_texts(self, collection_config_folder: str, source: str = None) -> Iterator[dict]:
def get_full_texts(
self,
collection_config_folder: str,
source: str = None,
start_at: int = 0,
batch_size: int = 500,
min_batch_size: int = 1,
) -> Iterator[dict]:
"""
Retrieves and yields batches of text records from the SQL database for a given collection.
Uses pagination to handle large datasets efficiently.
Uses pagination to handle large datasets efficiently. If a query fails, it automatically
halves the batch size and retries; if a query still fails at the minimum batch size, a ValueError is raised.
Args:
collection_config_folder (str): The collection folder to query (e.g., "EARTHDATA", "SMD")
collection_config_folder (str): The collection folder to query (e.g., "EARTHDATA", "CASEI")
source (str, optional): The source to query. If None, defaults to "scrapers" for dev servers
or "SDE" for other servers.
start_at (int, optional): Starting offset for records. Defaults to 0.
batch_size (int, optional): Initial number of records per batch. Defaults to 500.
min_batch_size (int, optional): Minimum batch size before giving up. Defaults to 1.
Yields:
list[dict]: Batches of records, where each record is a dictionary containing:
@@ -208,29 +219,16 @@ def get_full_texts(self, collection_config_folder: str, source: str = None) -> Iterator[dict]:
Raises:
ValueError: If the server's index is not defined in its configuration
Example batch:
[
{
"url": "https://example.nasa.gov/doc1",
"full_text": "This is the content of doc1...",
"title": "Document 1 Title"
},
{
"url": "https://example.nasa.gov/doc2",
"full_text": "This is the content of doc2...",
"title": "Document 2 Title"
}
]
ValueError: If batch size reaches minimum without success
Note:
- Results are paginated in batches of 5000 records
- Results are paginated with adaptive batch sizing
- Each batch is processed into clean dictionaries before being yielded
- The iterator will stop when either:
1. No more rows are returned from the query
2. The total count of records has been reached
- Batch size is halved after each failed query until the minimum batch size is reached
"""

if not source:
source = self._get_source_name()
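
For orientation before the main loop diff below, a minimal consumer sketch (not part of the commit): it assumes an already-configured Api instance named api, and the collection folder name is only an illustration.

# Illustrative usage sketch (assumption: `api` is an already-configured Api instance).
total_records = 0
for batch in api.get_full_texts("EARTHDATA", batch_size=500, min_batch_size=1):
    total_records += len(batch)  # each record is a dict with "url", "full_text", "title"
print(f"Fetched {total_records} records in adaptive batches")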

@@ -240,29 +238,42 @@ def get_full_texts(self, collection_config_folder: str, source: str = None) -> Iterator[dict]:
"Please update server configuration with the required index."
)

sql = f"SELECT url1, text, title FROM {index} WHERE collection = '/{source}/{collection_config_folder}/'"
base_sql = f"SELECT url1, text, title FROM {index} WHERE collection = '/{source}/{collection_config_folder}/'"

page = 0
page_size = 5000
total_processed = 0
current_offset = start_at
current_batch_size = batch_size
total_count = None

while True:
paginated_sql = f"{sql} SKIP {total_processed} COUNT {page_size}"
response = self._execute_sql_query(paginated_sql)
sql = f"{base_sql} SKIP {current_offset} COUNT {current_batch_size}"

try:
response = self._execute_sql_query(sql)
rows = response.get("Rows", [])

if not rows: # Stop if we get an empty batch
break

if total_count is None:
total_count = response.get("TotalRowCount", 0)

rows = response.get("Rows", [])
if not rows: # Stop if we get an empty batch
break
yield self._process_rows_to_records(rows)

yield self._process_rows_to_records(rows)
current_offset += len(rows)

total_processed += len(rows)
total_count = response.get("TotalRowCount", 0)
if total_count and current_offset >= total_count: # Stop if we've processed all records
break

if total_processed >= total_count: # Stop if we've processed all records
break
except (requests.RequestException, ValueError) as e:
if current_batch_size <= min_batch_size:
raise ValueError(
f"Failed to process batch even at minimum size {min_batch_size}. " f"Last error: {str(e)}"
)

page += 1
# Halve the batch size and retry
current_batch_size = max(current_batch_size // 2, min_batch_size)
print(f"Reducing batch size to {current_batch_size} and retrying...")
continue

@staticmethod
def _process_full_text_response(batch_data: dict):
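To make the retry behavior concrete, a short sketch (not part of the commit) of the batch sizes a persistently failing query would be attempted with under the halving rule in the except branch:

# Illustrative only: the retry ladder implied by halving down to min_batch_size.
def retry_sizes(batch_size: int, min_batch_size: int = 1) -> list[int]:
    sizes = [batch_size]
    while sizes[-1] > min_batch_size:
        sizes.append(max(sizes[-1] // 2, min_batch_size))
    return sizes

print(retry_sizes(100))  # [100, 50, 25, 12, 6, 3, 1] -- at most 7 attempts before the ValueError
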
65 changes: 65 additions & 0 deletions sde_collections/tests/api_tests.py
@@ -2,6 +2,7 @@
from unittest.mock import MagicMock, patch

import pytest
import requests
from django.utils import timezone

from sde_collections.models.collection import WorkflowStatusChoices
@@ -160,3 +161,67 @@ def test_query_dev_server_missing_credentials(self, mock_post, api_instance):

with pytest.raises(ValueError, match="Authentication error: Missing credentials for dev server"):
api_instance.query(page=1)

@patch("sde_collections.sinequa_api.Api._execute_sql_query")
def test_get_full_texts_batch_size_reduction(self, mock_execute_sql, api_instance):
"""Test that batch size reduces appropriately on failure and continues processing."""
# Mock first query to fail, then succeed with smaller batch
mock_execute_sql.side_effect = [
requests.RequestException("Query too large"), # First attempt fails
{
"Rows": [["http://example.com/1", "Text 1", "Title 1"]],
"TotalRowCount": 1,
}, # Succeeds with smaller batch
]

batches = list(api_instance.get_full_texts("test_folder", batch_size=100, min_batch_size=1))

# Verify the batches were processed correctly after size reduction
assert len(batches) == 1
assert len(batches[0]) == 1
assert batches[0][0]["url"] == "http://example.com/1"

# Verify the calls made - first with original size, then with reduced size
assert mock_execute_sql.call_count == 2
first_call = mock_execute_sql.call_args_list[0][0][0]
second_call = mock_execute_sql.call_args_list[1][0][0]
assert "COUNT 100" in first_call
assert "COUNT 50" in second_call # Should be halved from 100

@patch("sde_collections.sinequa_api.Api._execute_sql_query")
def test_get_full_texts_minimum_batch_size(self, mock_execute_sql, api_instance):
"""Test behavior when reaching minimum batch size."""
mock_execute_sql.side_effect = requests.RequestException("Query failed")

# Start with batch_size=4, min_batch_size=1
# Should try: 4 -> 2 -> 1 -> raise error
with pytest.raises(ValueError, match="Failed to process batch even at minimum size 1"):
list(api_instance.get_full_texts("test_folder", batch_size=4, min_batch_size=1))

# Should have tried 3 times before giving up
assert mock_execute_sql.call_count == 3
calls = mock_execute_sql.call_args_list
assert "COUNT 4" in calls[0][0][0] # First try with 4
assert "COUNT 2" in calls[1][0][0] # Second try with 2
assert "COUNT 1" in calls[2][0][0] # Final try with 1

@patch("sde_collections.sinequa_api.Api._execute_sql_query")
def test_get_full_texts_batch_size_progression(self, mock_execute_sql, api_instance):
"""Test multiple batch size reductions followed by successful query."""
mock_execute_sql.side_effect = [
requests.RequestException("First failure"),
requests.RequestException("Second failure"),
{"Rows": [["http://example.com/1", "Text 1", "Title 1"]], "TotalRowCount": 1},
]

# Start with batch_size=100, should reduce to 25 before succeeding
batches = list(api_instance.get_full_texts("test_folder", batch_size=100, min_batch_size=1))

assert len(batches) == 1 # Should get one successful batch
assert mock_execute_sql.call_count == 3

calls = mock_execute_sql.call_args_list
# Verify the progression of batch sizes
assert "COUNT 100" in calls[0][0][0] # First attempt
assert "COUNT 50" in calls[1][0][0] # After first failure
assert "COUNT 25" in calls[2][0][0] # After second failure
