fixed pinecone updater
henri123lemoine committed Aug 4, 2023
1 parent 61d93d4 commit 838ef4f
Showing 1 changed file with 53 additions and 115 deletions.
168 changes: 53 additions & 115 deletions align_data/pinecone/update_pinecone.py
@@ -1,17 +1,23 @@
import os
from typing import Dict, List, Union
from typing import Callable, Dict, List, Tuple, Union, Generator
from pydantic import BaseModel, ValidationError, validator
import numpy as np
import openai
import logging
from dataclasses import dataclass
from datetime import datetime

from align_data.db.models import Article
from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter
from align_data.db.session import MySQLDB
from align_data.pinecone.pinecone_db_handler import PineconeDB

from align_data.settings import USE_OPENAI_EMBEDDINGS, OPENAI_EMBEDDINGS_MODEL, \
OPENAI_EMBEDDINGS_DIMS, OPENAI_EMBEDDINGS_RATE_LIMIT, \
SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS, \
CHUNK_SIZE, MAX_NUM_AUTHORS_IN_SIGNATURE, EMBEDDING_LENGTH_BIAS

import logging

logger = logging.getLogger(__name__)


@@ -43,12 +49,21 @@ def __init__(
self,
min_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MIN_CHUNK_SIZE,
max_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MAX_CHUNK_SIZE,
length_function: Callable[[str], int] = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION,
truncate_function: Callable[[str, int], str] = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION,
):
self.min_chunk_size = min_chunk_size
self.max_chunk_size = max_chunk_size
self.length_function = length_function
self.truncate_function = truncate_function

self.text_splitter = ParagraphSentenceUnitTextSplitter(
min_chunk_size=min_chunk_size,
max_chunk_size=max_chunk_size,
min_chunk_size=self.min_chunk_size,
max_chunk_size=self.max_chunk_size,
length_function=self.length_function,
truncate_function=self.truncate_function
)

self.mysql_db = MySQLDB()
self.pinecone_db = PineconeDB()

if USE_OPENAI_EMBEDDINGS:
@@ -63,131 +78,54 @@ def __init__(
model_kwargs={'device': "cuda" if torch.cuda.is_available() else "cpu"},
encode_kwargs={'show_progress_bar': False}
)

def update(self, custom_sources: List[str] = ['all']):
"""
Update the given sources. If no sources are provided, updates all sources.
:param custom_sources: List of sources to update.
"""

for source in custom_sources:
self.update_source(source)

def update_source(self, source: str):
"""
Updates the entries from the given source.
:param source: The name of the source to update.
"""

logger.info(f"Updating {source} entries...")
with self.mysql_db.session_scope() as session:
entries_stream = self.mysql_db.stream_pinecone_updates(custom_sources)
pinecone_entries_stream = self.process_entries(entries_stream)
for pinecone_entry in pinecone_entries_stream:
self.pinecone_db.upsert_entry(pinecone_entry.dict())

pinecone_entry_db = session.query(Article).filter(Article.id == pinecone_entry.id).one()
pinecone_entry_db.pinecone_update_required = False
session.add(pinecone_entry_db)
session.commit()

# TODO: loop through mysql entries and update the pinecone db

logger.info(f"Successfully updated {source} entries.")
def process_entries(self, article_stream: Generator[Article, None, None]) -> Generator[PineconeEntry, None, None]:
for article in article_stream:
try:
text_chunks = self.get_text_chunks(article)
yield PineconeEntry(
id=article.id,
source=article.source,
title=article.title,
url=article.url,
date_published=article.date_published,
authors=[author.strip() for author in article.authors.split(',') if author.strip()],
text_chunks=text_chunks,
embeddings=self.extract_embeddings(text_chunks, [article.source] * len(text_chunks))
)
except (ValueError, ValidationError) as e:
print(e)
pass

def batchify(self, iterable):
"""
Divides the iterable into batches of size ~CHUNK_SIZE.
:param iterable: The iterable to divide into batches.
:returns: A generator that yields batches from the iterable.
"""

entries_batch = []
chunks_batch = []
chunks_ids_batch = []
sources_batch = []

for entry in iterable:
chunks, chunks_ids = self.create_chunk_ids_and_authors(entry)

entries_batch.append(entry)
chunks_batch.extend(chunks)
chunks_ids_batch.extend(chunks_ids)
sources_batch.extend([entry['source']] * len(chunks))

# If this batch is large enough, yield it and start a new one.
if len(chunks_batch) >= CHUNK_SIZE:
yield self._create_batch(entries_batch, chunks_batch, chunks_ids_batch, sources_batch)

entries_batch = []
chunks_batch = []
chunks_ids_batch = []
sources_batch = []

# Yield any remaining items.
if entries_batch:
yield self._create_batch(entries_batch, chunks_batch, chunks_ids_batch, sources_batch)

def create_chunk_ids_and_authors(self, entry):
signature = f"Title: {entry['title']}, Author(s): {self.get_authors_str(entry['authors'])}"
chunks = self.text_splitter.split_text(entry['text'])
chunks = [f"- {signature}\n\n{chunk}" for chunk in chunks]
chunks_ids = [f"{entry['id']}_{str(i).zfill(6)}" for i in range(len(chunks))]
return chunks, chunks_ids

def _create_batch(self, entries_batch, chunks_batch, chunks_ids_batch, sources_batch):
return {'entries_batch': entries_batch, 'chunks_batch': chunks_batch, 'chunks_ids_batch': chunks_ids_batch, 'sources_batch': sources_batch}

def is_sql_entry_upserted(self, entry):
"""Upserts an entry to the SQL database and returns the success status"""
return self.sql_db.upsert_entry(entry)
def get_text_chunks(self, article: Article) -> List[str]:
signature = f"Title: {article.title}, Author(s): {self.get_authors_str(article.authors)}"
text_chunks = self.text_splitter.split_text(article.text)
text_chunks = [f"- {signature}\n\n{text_chunk}" for text_chunk in text_chunks]
return text_chunks

def extract_embeddings(self, chunks_batch, sources_batch):
if USE_OPENAI_EMBEDDINGS:
return self.get_openai_embeddings(chunks_batch, sources_batch)
else:
return np.array(self.hf_embeddings.embed_documents(chunks_batch, sources_batch))

def reset_dbs(self):
self.sql_db.create_tables(True)
self.pinecone_db.create_index(True)

@staticmethod
def preprocess_and_validate(entry):
"""Preprocesses and validates the entry data"""
try:
ARDUpdater.validate_entry(entry)

return {
'id': entry['id'],
'source': entry['source'],
'title': entry['title'],
'text': entry['text'],
'url': entry['url'],
'date_published': entry['date_published'],
'authors': entry['authors']
}
except ValueError as e:
logger.error(f"Entry validation failed: {str(e)}", exc_info=True)
return None

@staticmethod
def validate_entry(entry: Dict[str, Union[str, list]], char_len_lower_limit: int = 0):
metadata_types = {
'id': str,
'source': str,
'title': str,
'text': str,
'url': str,
'date_published': str,
'authors': list
}

for metadata_type, metadata_type_type in metadata_types.items():
if not isinstance(entry.get(metadata_type), metadata_type_type):
raise ValueError(f"Entry metadata '{metadata_type}' is not of type '{metadata_type_type}' or is missing.")

if len(entry['text']) < char_len_lower_limit:
raise ValueError(f"Entry text is too short (< {char_len_lower_limit} characters).")

@staticmethod
def is_valid_entry(entry):
"""Checks if the entry is valid"""
return entry is not None

@staticmethod
def get_openai_embeddings(chunks, sources=''):
embeddings = np.zeros((len(chunks), OPENAI_EMBEDDINGS_DIMS))
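The PineconeEntry model constructed in process_entries is defined earlier in the file, outside the visible part of this diff. Below is a minimal sketch of what it plausibly looks like, inferred only from the keyword arguments used above; the field types (and any validators) are assumptions, not the repository's actual definition.

from datetime import datetime
from typing import List

from pydantic import BaseModel


class PineconeEntry(BaseModel):
    # Field names mirror the keyword arguments passed in process_entries;
    # the types below are guesses, not the actual definition.
    id: str
    source: str
    title: str
    url: str
    date_published: datetime
    authors: List[str]
    text_chunks: List[str]
    embeddings: List[List[float]]  # assumed plain lists rather than an np.ndarray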

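For reference, a minimal usage sketch of the updated class, assuming it is still named ARDUpdater (as in the pre-change code) and that the MySQL and Pinecone connections are configured via align_data.settings; the source names below are hypothetical.

from align_data.pinecone.update_pinecone import ARDUpdater

updater = ARDUpdater()
# Streams articles flagged for a Pinecone update from MySQL, chunks and embeds
# their text, upserts the results into Pinecone, and clears the update flag.
updater.update(custom_sources=["arxiv", "lesswrong"])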