Pagination on the Sinequa sql.engine Api #1104

Merged
merged 22 commits into from Dec 4, 2024
Changes from all commits
154 changes: 122 additions & 32 deletions sde_collections/sinequa_api.py
@@ -1,4 +1,5 @@
import json
from collections.abc import Iterator
from typing import Any

import requests
@@ -61,15 +62,14 @@ class Api:
def __init__(self, server_name: str = None, user: str = None, password: str = None, token: str = None) -> None:
self.server_name = server_name
if server_name not in server_configs:
raise ValueError(f"Server name '{server_name}' is not in server_configs")
raise ValueError(f"Invalid server configuration: '{server_name}' is not a recognized server name")

self.config = server_configs[server_name]
self.app_name: str = self.config["app_name"]
self.query_name: str = self.config["query_name"]
self.base_url: str = self.config["base_url"]
self.dev_servers = ["xli", "lrm_dev", "lrm_qa"]

# Store provided values only
self._provided_user = user
self._provided_password = password
self._provided_token = token
@@ -113,7 +113,8 @@ def query(self, page: int, collection_config_folder: str | None = None, source:
password = self._get_password()
if not user or not password:
raise ValueError(
"User and password are required for the query endpoint on the following servers: {self.dev_servers}"
f"Authentication error: Missing credentials for dev server '{self.server_name}'. "
f"Both username and password are required for servers: {', '.join(self.dev_servers)}"
)
authentication = f"?Password={password}&User={user}"
url = f"{url}{authentication}"
@@ -135,11 +136,22 @@ def query(self, page: int, collection_config_folder: str | None = None, source:

return self.process_response(url, payload)

def sql_query(self, sql: str) -> Any:
"""Executes an SQL query on the configured server using token-based authentication."""
def _execute_sql_query(self, sql: str) -> dict:
"""
Executes a SQL query against the Sinequa API.

Args:
sql (str): The SQL query to execute

Returns:
dict: The JSON response from the API containing 'Rows' and 'TotalRowCount'

Raises:
ValueError: If no token is available for authentication
"""
token = self._get_token()
if not token:
raise ValueError("A token is required to use the SQL endpoint")
raise ValueError("Authentication error: Token is required for SQL endpoint access")

url = f"{self.base_url}/api/v1/engine.sql"
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {token}"}
@@ -153,42 +165,120 @@ def sql_query(self, sql: str) -> Any:

return self.process_response(url, headers=headers, raw_data=raw_payload)
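
For orientation, a minimal usage sketch of this helper; the server name, token, index name, and collection path below are placeholders, and the SKIP/COUNT clause follows the pagination syntax used later in get_full_texts:

# Illustrative only: "lrm_dev", the token, and "sde_index" are assumed values.
api = Api(server_name="lrm_dev", token="example-token")
response = api._execute_sql_query(
    "SELECT url1, text, title FROM sde_index "
    "WHERE collection = '/scrapers/EXAMPLE/' SKIP 0 COUNT 5000"
)
print(response.get("TotalRowCount", 0))  # total number of matching rows
print(len(response.get("Rows", [])))     # rows returned in this batch (at most 5000)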

def get_full_texts(self, collection_config_folder: str, source: str = None) -> Any:
def _process_rows_to_records(self, rows: list) -> list[dict]:
"""
Retrieves the full texts, URLs, and titles for a specified collection.
Converts raw SQL row data into structured record dictionaries.

Args:
rows (list): List of rows, where each row is [url, full_text, title]

Returns:
dict: A JSON response containing the results of the SQL query,
where each item has 'url', 'text', and 'title'.

Example:
Calling get_full_texts("example_collection") might return:
[
{
'url': 'http://example.com/article1',
'text': 'Here is the full text of the first article...',
'title': 'Article One Title'
},
{
'url': 'http://example.com/article2',
'text': 'Here is the full text of the second article...',
'title': 'Article Two Title'
}
]
list[dict]: List of processed records with url, full_text, and title keys

Raises:
ValueError: If any row doesn't contain exactly 3 elements
"""
processed_records = []
for idx, row in enumerate(rows):
if len(row) != 3:
raise ValueError(
f"Invalid row format at index {idx}: Expected exactly three elements (url, full_text, title). "
f"Received {len(row)} elements."
)
processed_records.append({"url": row[0], "full_text": row[1], "title": row[2]})
return processed_records
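
A quick illustration of the row-to-record conversion described above, reusing the api instance from the earlier sketch and made-up rows:

rows = [
    ["http://example.com/article1", "Full text of the first article...", "Article One Title"],
    ["http://example.com/article2", "Full text of the second article...", "Article Two Title"],
]
records = api._process_rows_to_records(rows)
# records[0] == {"url": "http://example.com/article1",
#                "full_text": "Full text of the first article...",
#                "title": "Article One Title"}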

def get_full_texts(self, collection_config_folder: str, source: str = None) -> Iterator[dict]:
"""
Retrieves and yields batches of text records from the SQL database for a given collection.
Uses pagination to handle large datasets efficiently.

Args:
collection_config_folder (str): The collection folder to query (e.g., "EARTHDATA", "SMD")
source (str, optional): The source to query. If None, defaults to "scrapers" for dev servers
or "SDE" for other servers.

Yields:
list[dict]: Batches of records, where each record is a dictionary containing:
{
"url": str, # The URL of the document
"full_text": str, # The full text content of the document
"title": str # The title of the document
}

Raises:
ValueError: If the server's index is not defined in its configuration

Example batch:
[
{
"url": "https://example.nasa.gov/doc1",
"full_text": "This is the content of doc1...",
"title": "Document 1 Title"
},
{
"url": "https://example.nasa.gov/doc2",
"full_text": "This is the content of doc2...",
"title": "Document 2 Title"
}
]

Note:
- Results are paginated in batches of 5000 records
- Each batch is processed into clean dictionaries before being yielded
- The iterator will stop when either:
1. No more rows are returned from the query
2. The total count of records has been reached
"""

if not source:
source = self._get_source_name()

if (index := self.config.get("index")) is None:
raise ValueError("Index not defined for this server")
raise ValueError(
f"Configuration error: Index not defined for server '{self.server_name}'. "
"Please update server configuration with the required index."
)

sql = f"SELECT url1, text, title FROM {index} WHERE collection = '/{source}/{collection_config_folder}/'"
full_text_response = self.sql_query(sql)
return self._process_full_text_response(full_text_response)

page = 0
page_size = 5000
total_processed = 0

while True:
paginated_sql = f"{sql} SKIP {total_processed} COUNT {page_size}"
response = self._execute_sql_query(paginated_sql)

rows = response.get("Rows", [])
if not rows: # Stop if we get an empty batch
break

yield self._process_rows_to_records(rows)

total_processed += len(rows)
total_count = response.get("TotalRowCount", 0)

if total_processed >= total_count: # Stop if we've processed all records
break

page += 1
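
A minimal sketch of consuming the generator, again with placeholder server and collection names; each yielded batch is a list of record dictionaries:

api = Api(server_name="lrm_dev", token="example-token")  # assumed configuration
total = 0
for batch in api.get_full_texts("EARTHDATA"):
    total += len(batch)  # each batch holds at most 5000 records
    first = batch[0]
    print(first["url"], first["title"])
print(f"Fetched {total} documents in total")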

@staticmethod
def _process_full_text_response(full_text_response: str):
return [
{"url": url, "full_text": full_text, "title": title} for url, full_text, title in full_text_response["Rows"]
]
def _process_full_text_response(batch_data: dict):
if "Rows" not in batch_data or not isinstance(batch_data["Rows"], list):
raise ValueError(
"Invalid response format: Expected 'Rows' key with list data in Sinequa server response. "
f"Received: {type(batch_data.get('Rows', None))}"
)

processed_data = []
for idx, row in enumerate(batch_data["Rows"]):
if len(row) != 3:
raise ValueError(
f"Invalid row format at index {idx}: Expected exactly three elements (url, full_text, title). "
f"Received {len(row)} elements."
)
url, full_text, title = row
processed_data.append({"url": url, "full_text": full_text, "title": title})
return processed_data
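
And a short sketch of the validation performed by _process_full_text_response, with deliberately malformed sample data:

batch_data = {"Rows": [
    ["https://example.nasa.gov/doc1", "This is the content of doc1...", "Document 1 Title"],
    ["https://example.nasa.gov/doc2", "missing a title"],  # only two elements
]}
try:
    Api._process_full_text_response(batch_data)
except ValueError as err:
    print(err)  # Invalid row format at index 1: Expected exactly three elements ...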
72 changes: 37 additions & 35 deletions sde_collections/tasks.py
@@ -7,7 +7,7 @@
from django.conf import settings
from django.core import management
from django.core.management.commands import loaddata
from django.db import IntegrityError
from django.db import transaction

from config import celery_app

@@ -145,44 +145,46 @@ def resolve_title_pattern(title_pattern_id):
title_pattern.apply()


@celery_app.task
@celery_app.task(soft_time_limit=600)
def fetch_and_replace_full_text(collection_id, server_name):
"""
Task to fetch and replace full text and metadata for all URLs associated with a specified collection
from a given server. This task deletes all existing DumpUrl entries for the collection and creates
new entries based on the latest fetched data.

Args:
collection_id (int): The identifier for the collection in the database.
server_name (str): The name of the server.

Returns:
str: A message indicating the result of the operation, including the number of URLs processed.
Task to fetch and replace full text and metadata for a collection.
Handles data in batches to manage memory usage.
"""
collection = Collection.objects.get(id=collection_id)
api = Api(server_name)
documents = api.get_full_texts(collection.config_folder)

# Step 1: Delete all existing DumpUrl entries for the collection
# Step 1: Delete existing DumpUrl entries
deleted_count, _ = DumpUrl.objects.filter(collection=collection).delete()

# Step 2: Create new DumpUrl entries from the fetched documents
processed_count = 0
for doc in documents:
try:
DumpUrl.objects.create(
url=doc["url"],
collection=collection,
scraped_text=doc.get("full_text", ""),
scraped_title=doc.get("title", ""),
)
processed_count += 1
except IntegrityError:
# Handle duplicate URL case if needed
print(f"Duplicate URL found, skipping: {doc['url']}")

collection.migrate_dump_to_delta()

print(f"Processed {processed_count} new records.")

return f"Successfully processed {len(documents)} records and updated the database."
print(f"Deleted {deleted_count} old records.")

# Step 2: Process data in batches
total_processed = 0

try:
for batch in api.get_full_texts(collection.config_folder):
# Use bulk_create for efficiency, with a transaction per batch
with transaction.atomic():
DumpUrl.objects.bulk_create(
[
DumpUrl(
url=record["url"],
collection=collection,
scraped_text=record["full_text"],
scraped_title=record["title"],
)
for record in batch
]
)

total_processed += len(batch)
print(f"Processed batch of {len(batch)} records. Total: {total_processed}")

# Step 3: Migrate dump URLs to delta URLs
collection.migrate_dump_to_delta()

return f"Successfully processed {total_processed} records and updated the database."

except Exception as e:
print(f"Error processing records: {str(e)}")
raise
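
Finally, a hedged sketch of how the batched task might be dispatched; the collection id and server name are placeholders:

# Enqueue asynchronously through Celery (subject to the 600-second soft time limit)
fetch_and_replace_full_text.delay(collection_id=42, server_name="lrm_dev")

# Or call synchronously from a shell while debugging
result = fetch_and_replace_full_text(42, "lrm_dev")
print(result)  # "Successfully processed <N> records and updated the database."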