From 838ef4fd7d0b837604bcc261f3f6c6cd9a91fb8c Mon Sep 17 00:00:00 2001
From: Henri Lemoine
Date: Fri, 4 Aug 2023 12:46:15 -0400
Subject: [PATCH] fixed pinecone updater

---
 align_data/pinecone/update_pinecone.py | 168 ++++++++-----------------
 1 file changed, 53 insertions(+), 115 deletions(-)

diff --git a/align_data/pinecone/update_pinecone.py b/align_data/pinecone/update_pinecone.py
index b4e0393c..6a48fc51 100644
--- a/align_data/pinecone/update_pinecone.py
+++ b/align_data/pinecone/update_pinecone.py
@@ -1,9 +1,15 @@
 import os
-from typing import Dict, List, Union
+from typing import Callable, Dict, List, Tuple, Union, Generator
+from pydantic import BaseModel, ValidationError, validator
 
 import numpy as np
 import openai
+import logging
+from dataclasses import dataclass
+from datetime import datetime
+from align_data.db.models import Article
 
 from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter
+from align_data.db.session import MySQLDB
 from align_data.pinecone.pinecone_db_handler import PineconeDB
 from align_data.settings import USE_OPENAI_EMBEDDINGS, OPENAI_EMBEDDINGS_MODEL, \
@@ -11,7 +17,7 @@ SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS, \
     CHUNK_SIZE, MAX_NUM_AUTHORS_IN_SIGNATURE, EMBEDDING_LENGTH_BIAS
 
-import logging
+
 
 
 logger = logging.getLogger(__name__)
 
@@ -43,12 +49,21 @@
     def __init__(
         self,
         min_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MIN_CHUNK_SIZE,
         max_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MAX_CHUNK_SIZE,
+        length_function: Callable[[str], int] = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION,
+        truncate_function: Callable[[str, int], str] = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION,
     ):
+        self.min_chunk_size = min_chunk_size
+        self.max_chunk_size = max_chunk_size
+        self.length_function = length_function
+        self.truncate_function = truncate_function
+
         self.text_splitter = ParagraphSentenceUnitTextSplitter(
-            min_chunk_size=min_chunk_size,
-            max_chunk_size=max_chunk_size,
+            min_chunk_size=self.min_chunk_size,
+            max_chunk_size=self.max_chunk_size,
+            length_function=self.length_function,
+            truncate_function=self.truncate_function
         )
-        
+        self.mysql_db = MySQLDB()
         self.pinecone_db = PineconeDB()
         if USE_OPENAI_EMBEDDINGS:
@@ -63,77 +78,47 @@ def __init__(
                 model_kwargs={'device': "cuda" if torch.cuda.is_available() else "cpu"},
                 encode_kwargs={'show_progress_bar': False}
             )
-        
+    
     def update(self, custom_sources: List[str] = ['all']):
         """
         Update the given sources. If no sources are provided, updates all sources.
 
         :param custom_sources: List of sources to update.
         """
-        
-        for source in custom_sources:
-            self.update_source(source)
-        
-    def update_source(self, source: str):
-        """
-        Updates the entries from the given source.
-        
-        :param source: The name of the source to update.
-        """
-        
-        logger.info(f"Updating {source} entries...")
+        with self.mysql_db.session_scope() as session:
+            entries_stream = self.mysql_db.stream_pinecone_updates(custom_sources)
+            pinecone_entries_stream = self.process_entries(entries_stream)
+            for pinecone_entry in pinecone_entries_stream:
+                self.pinecone_db.upsert_entry(pinecone_entry.dict())
+
+                pinecone_entry_db = session.query(Article).filter(Article.id == pinecone_entry.id).one()
+                pinecone_entry_db.pinecone_update_required = False
+                session.add(pinecone_entry_db)
+                session.commit()
 
-        # TODO: loop through mysql entries and update the pinecone db
-        
-        logger.info(f"Successfully updated {source} entries.")
+    def process_entries(self, article_stream: Generator[Article, None, None]) -> Generator[PineconeEntry, None, None]:
+        for article in article_stream:
+            try:
+                text_chunks = self.get_text_chunks(article)
+                yield PineconeEntry(
+                    id=article.id,
+                    source=article.source,
+                    title=article.title,
+                    url=article.url,
+                    date_published=article.date_published,
+                    authors=[author.strip() for author in article.authors.split(',') if author.strip()],
+                    text_chunks=text_chunks,
+                    embeddings=self.extract_embeddings(text_chunks, [article.source] * len(text_chunks))
+                )
+            except (ValueError, ValidationError) as e:
+                print(e)
+                pass
 
-    def batchify(self, iterable):
-        """
-        Divides the iterable into batches of size ~CHUNK_SIZE.
-        
-        :param iterable: The iterable to divide into batches.
-        :returns: A generator that yields batches from the iterable.
-        """
-        
-        entries_batch = []
-        chunks_batch = []
-        chunks_ids_batch = []
-        sources_batch = []
-        
-        for entry in iterable:
-            chunks, chunks_ids = self.create_chunk_ids_and_authors(entry)
-            
-            entries_batch.append(entry)
-            chunks_batch.extend(chunks)
-            chunks_ids_batch.extend(chunks_ids)
-            sources_batch.extend([entry['source']] * len(chunks))
-            
-            # If this batch is large enough, yield it and start a new one.
-            if len(chunks_batch) >= CHUNK_SIZE:
-                yield self._create_batch(entries_batch, chunks_batch, chunks_ids_batch, sources_batch)
-                
-                entries_batch = []
-                chunks_batch = []
-                chunks_ids_batch = []
-                sources_batch = []
-        
-        # Yield any remaining items.
-        if entries_batch:
-            yield self._create_batch(entries_batch, chunks_batch, chunks_ids_batch, sources_batch)
-    
-    def create_chunk_ids_and_authors(self, entry):
-        signature = f"Title: {entry['title']}, Author(s): {self.get_authors_str(entry['authors'])}"
-        chunks = self.text_splitter.split_text(entry['text'])
-        chunks = [f"- {signature}\n\n{chunk}" for chunk in chunks]
-        chunks_ids = [f"{entry['id']}_{str(i).zfill(6)}" for i in range(len(chunks))]
-        return chunks, chunks_ids
-    
-    def _create_batch(self, entries_batch, chunks_batch, chunks_ids_batch, sources_batch):
-        return {'entries_batch': entries_batch, 'chunks_batch': chunks_batch, 'chunks_ids_batch': chunks_ids_batch, 'sources_batch': sources_batch}
-    
-    def is_sql_entry_upserted(self, entry):
-        """Upserts an entry to the SQL database and returns the success status"""
-        return self.sql_db.upsert_entry(entry)
+    def get_text_chunks(self, article: Article) -> List[str]:
+        signature = f"Title: {article.title}, Author(s): {self.get_authors_str(article.authors)}"
+        text_chunks = self.text_splitter.split_text(article.text)
+        text_chunks = [f"- {signature}\n\n{text_chunk}" for text_chunk in text_chunks]
+        return text_chunks
 
     def extract_embeddings(self, chunks_batch, sources_batch):
         if USE_OPENAI_EMBEDDINGS:
@@ -141,53 +126,6 @@ def extract_embeddings(self, chunks_batch, sources_batch):
         else:
            return np.array(self.hf_embeddings.embed_documents(chunks_batch, sources_batch))
 
-    def reset_dbs(self):
-        self.sql_db.create_tables(True)
-        self.pinecone_db.create_index(True)
-
-    @staticmethod
-    def preprocess_and_validate(entry):
-        """Preprocesses and validates the entry data"""
-        try:
-            ARDUpdater.validate_entry(entry)
-
-            return {
-                'id': entry['id'],
-                'source': entry['source'],
-                'title': entry['title'],
-                'text': entry['text'],
-                'url': entry['url'],
-                'date_published': entry['date_published'],
-                'authors': entry['authors']
-            }
-        except ValueError as e:
-            logger.error(f"Entry validation failed: {str(e)}", exc_info=True)
-            return None
-
-    @staticmethod
-    def validate_entry(entry: Dict[str, Union[str, list]], char_len_lower_limit: int = 0):
-        metadata_types = {
-            'id': str,
-            'source': str,
-            'title': str,
-            'text': str,
-            'url': str,
-            'date_published': str,
-            'authors': list
-        }
-
-        for metadata_type, metadata_type_type in metadata_types.items():
-            if not isinstance(entry.get(metadata_type), metadata_type_type):
-                raise ValueError(f"Entry metadata '{metadata_type}' is not of type '{metadata_type_type}' or is missing.")
-
-        if len(entry['text']) < char_len_lower_limit:
-            raise ValueError(f"Entry text is too short (< {char_len_lower_limit} characters).")
-
-    @staticmethod
-    def is_valid_entry(entry):
-        """Checks if the entry is valid"""
-        return entry is not None
-
     @staticmethod
     def get_openai_embeddings(chunks, sources=''):
         embeddings = np.zeros((len(chunks), OPENAI_EMBEDDINGS_DIMS))
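
The PineconeEntry model that process_entries yields (and whose .dict() output is upserted into Pinecone) is defined elsewhere in this module and does not appear in the hunks above. A minimal sketch of what it plausibly looks like, inferred from the fields passed in process_entries and from the new pydantic imports; the exact field types, the Config flag, and the validator are assumptions:

    from datetime import datetime
    from typing import List

    import numpy as np
    from pydantic import BaseModel, validator


    class PineconeEntry(BaseModel):
        id: str
        source: str
        title: str
        url: str
        date_published: datetime
        authors: List[str]
        text_chunks: List[str]
        embeddings: np.ndarray

        class Config:
            # Assumption: needed so pydantic (v1-style, matching the `validator`
            # import above) accepts the raw numpy array from extract_embeddings.
            arbitrary_types_allowed = True

        @validator("text_chunks")
        def text_chunks_not_empty(cls, value):
            # Assumption: rejecting empty chunk lists is what lets process_entries
            # skip unusable articles via its except (ValueError, ValidationError) branch.
            if not value:
                raise ValueError("no text chunks found for article")
            return value

Raising inside a validator surfaces as a pydantic ValidationError at construction time, which matches the exception pair caught in process_entries.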
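
For context, the reworked flow is driven entirely through update(): it streams Articles flagged pinecone_update_required from MySQL, chunks and embeds them, upserts each entry into Pinecone, and clears the flag in the same session. A rough usage sketch, assuming the class is still named ARDUpdater (as the removed ARDUpdater.validate_entry call suggests) and using a placeholder source name:

    from align_data.pinecone.update_pinecone import ARDUpdater

    updater = ARDUpdater()

    # Default ['all'] processes every source with pending Pinecone updates.
    updater.update()

    # Or restrict the run to specific sources ('alignmentforum' is hypothetical).
    updater.update(['alignmentforum'])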