diff --git a/align_data/embeddings/pinecone/pinecone_db_handler.py b/align_data/embeddings/pinecone/pinecone_db_handler.py
index 956c299..5e912f9 100644
--- a/align_data/embeddings/pinecone/pinecone_db_handler.py
+++ b/align_data/embeddings/pinecone/pinecone_db_handler.py
@@ -1,9 +1,12 @@
 # dataset/pinecone_db_handler.py
+import functools
+import time
 import logging
 from typing import List, Tuple
 
 from pinecone import Pinecone
 from pinecone.core.client.models import ScoredVector
+from urllib3.exceptions import ProtocolError
 
 from align_data.embeddings.embedding_utils import get_embedding
 from align_data.embeddings.pinecone.pinecone_models import (
@@ -22,6 +25,31 @@
 logger = logging.getLogger(__name__)
 
 
+def with_retry(n=3, exceptions=(Exception,)):
+    """Build a decorator that retries the wrapped callable with exponential backoff.
+
+    :param n: maximum number of attempts
+    :param exceptions: exception types that trigger a retry; any other exception propagates at once
+    :raises TimeoutError: from the returned wrapper, once all `n` attempts have failed
+    """
+    def retrier_wrapper(f):
+        @functools.wraps(f)
+        def wrapper(*args, **kwargs):
+            last_error = None
+            for i in range(n):
+                try:
+                    return f(*args, **kwargs)
+                except exceptions as e:
+                    last_error = e
+                    logger.error(f'Got exception while retrying: {e}')
+                # Back off before the next attempt; no point waiting after the final one
+                if i < n - 1:
+                    time.sleep(2 ** i)
+            raise TimeoutError(f'Gave up after {n} tries') from last_error
+        return wrapper
+    return retrier_wrapper
+
+
 class PineconeDB:
     def __init__(
         self,
@@ -49,10 +77,14 @@ def __init__(
             index_stats_response = self.index.describe_index_stats()
             logger.info(f"{self.index_name}:\n{index_stats_response}")
 
-    def upsert_entry(
-        self, pinecone_entry: PineconeEntry, upsert_size: int = 100, show_progress: bool = True
-    ):
-        vectors = pinecone_entry.create_pinecone_vectors()
+    @with_retry(exceptions=(ProtocolError,))
+    def _get_vectors(self, entry):
+        """Build the Pinecone vectors for `entry`, retrying on `ProtocolError`."""
+        return entry.create_pinecone_vectors()
+
+    @with_retry(exceptions=(ProtocolError,))
+    def _upsert(self, vectors, upsert_size: int = 100, show_progress: bool = True):
+        """Upsert `vectors` in batches of `upsert_size`, retrying on `ProtocolError`."""
         self.index.upsert(
             vectors=vectors,
             batch_size=upsert_size,
@@ -60,6 +92,13 @@ def upsert_entry(
             show_progress=show_progress,
         )
 
+    def upsert_entry(
+        self, pinecone_entry: PineconeEntry, upsert_size: int = 100, show_progress: bool = True
+    ):
+        """Create the vectors for `pinecone_entry` and upsert them, retrying each step separately."""
+        vectors = self._get_vectors(pinecone_entry)
+        self._upsert(vectors, upsert_size, show_progress)
+
     def query_vector(
         self,
         query: List[float],
@@ -111,8 +150,27 @@ def query_text(
             **kwargs,
         )
 
+    def _find_items(self, ids):
+        """Return everything `self.index.list` yields for each id in `ids`, used as a prefix."""
+        @with_retry()
+        def get_item(id_):
+            return list(self.index.list(prefix=id_, namespace=PINECONE_NAMESPACE))
+
+        # NOTE(review): the Pinecone client documents `index.list` as yielding pages
+        # (lists) of ids - confirm whether the result needs a second level of flattening.
+        return [i for id_ in ids for i in get_item(id_)]
+
+    @with_retry()
+    def _del_items(self, ids):
+        """Delete the vectors with the given `ids` from the Pinecone namespace."""
+        self.index.delete(ids=ids, namespace=PINECONE_NAMESPACE)
+
+    @with_retry()
     def delete_entries(self, ids):
-        self.index.delete(filter={"hash_id": {"$in": ids}})
+        """Delete all vectors whose ids are prefixed by any of `ids`."""
+        items = self._find_items(ids)
+        if items:
+            self._del_items(items)
 
     def create_index(self, replace_current_index: bool = True):
         if replace_current_index:
diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py
index 75f64f6..dc16b3d 100644
--- a/align_data/embeddings/pinecone/update_pinecone.py
+++ b/align_data/embeddings/pinecone/update_pinecone.py
@@ -17,8 +17,6 @@
 from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB
 from align_data.embeddings.pinecone.pinecone_models import (
     PineconeEntry,
-    MissingFieldsError,
-    MissingEmbeddingModelError,
 )
 from align_data.embeddings.text_splitter import ParagraphSentenceUnitTextSplitter
 
@@ -93,12 +91,13 @@ def _articles_by_id(self, session, ids: List[str], force_update: bool):
         return get_pinecone_articles_by_ids(session, ids, force_update)
 
     def process_batch(self, batch: List[Tuple[Article, PineconeEntry | None]]):
+        logger.info(f'Processing batch of {len(batch)} items')
         for article, pinecone_entry in batch:
             if pinecone_entry:
                 self.pinecone_db.upsert_entry(pinecone_entry)
 
             article.pinecone_status = PineconeStatus.added
-            return [a for a, _ in batch]
+        return [a for a, _ in batch]
 
     def batch_entries(
         self, article_stream: Generator[Article, None, None]