Skip to content

Commit

Permalink
Added confidence, added summaries to Pinecone DB, fixed session
Browse files: browse the repository at this point in the history
  • Loading branch information
henri123lemoine committed Aug 29, 2023
1 parent 6fa3584 commit f92d633
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 22 deletions.
13 changes: 6 additions & 7 deletions align_data/db/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,32 @@ def make_session(auto_commit=False):
session.commit()


def get_pinecone(
def get_pinecone_query(
session: Session,
force_update: bool = False,
):
yield from (
return (
session.query(Article)
.filter(or_(Article.pinecone_update_required.is_(True), force_update))
.filter(Article.is_valid)
.filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE))
)


def get_pinecone_from_sources(
def get_pinecone_from_sources_query(
session: Session,
custom_sources: List[str],
force_update: bool = False,
):
yield from get_pinecone(session, force_update).filter(Article.source.in_(custom_sources))
return get_pinecone_query(session, force_update).filter(Article.source.in_(custom_sources))


def get_pinecone_articles_by_ids(
def get_pinecone_articles_by_ids_query(
session: Session,
hash_ids: List[int],
force_update: bool = False,
):
"""Yield Pinecone entries that require an update and match the given IDs."""
yield from get_pinecone_from_sources(session, force_update).filter(Article.id.in_(hash_ids))
return get_pinecone_from_sources_query(session, force_update).filter(Article.id.in_(hash_ids))


def get_all_valid_article_ids(session: Session) -> List[str]:
Expand Down
25 changes: 16 additions & 9 deletions align_data/embeddings/pinecone/pinecone_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class PineconeMetadata(TypedDict):
date_published: float
authors: List[str]
text: str
confidence: float | None


class PineconeEntry(BaseModel):
Expand All @@ -30,6 +31,7 @@ class PineconeEntry(BaseModel):
date_published: float
authors: List[str]
text_chunks: List[str]
confidence: float | None
embeddings: List[List[float] | None]

def __init__(self, **data):
Expand Down Expand Up @@ -62,15 +64,20 @@ def create_pinecone_vectors(self) -> List[Vector]:
Vector(
id=f"{self.hash_id}_{str(i).zfill(6)}",
values=self.embeddings[i],
metadata=PineconeMetadata(
hash_id=self.hash_id,
source=self.source,
title=self.title,
authors=self.authors,
url=self.url,
date_published=self.date_published,
text=self.text_chunks[i],
),
metadata={
key: value
for key, value in PineconeMetadata(
hash_id=self.hash_id,
source=self.source,
title=self.title,
authors=self.authors,
url=self.url,
date_published=self.date_published,
text=self.text_chunks[i],
confidence=self.confidence,
).items()
if value is not None # Filter out keys with None values
},
)
for i in range(self.chunk_num)
if self.embeddings[i] # Skips flagged chunks
Expand Down
17 changes: 11 additions & 6 deletions align_data/embeddings/pinecone/update_pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from align_data.db.models import Article
from align_data.db.session import (
make_session,
get_pinecone_articles_to_update,
get_pinecone_articles_by_ids,
get_pinecone_from_sources_query,
get_pinecone_articles_by_ids_query,
)
from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB
from align_data.embeddings.pinecone.pinecone_models import (
Expand Down Expand Up @@ -42,21 +42,21 @@ def update(self, custom_sources: List[str], force_update: bool = False):
:param custom_sources: List of sources to update.
"""
with make_session() as session:
articles_to_update_stream = get_pinecone_articles_to_update(
articles_to_update_query = get_pinecone_from_sources_query(
session, custom_sources, force_update
)
for batch in self.batch_entries(articles_to_update_stream):
for batch in self.batch_entries(articles_to_update_query):
self.save_batch(session, batch)

def update_articles_by_ids(
self, custom_sources: List[str], hash_ids: List[int], force_update: bool = False
):
"""Update the Pinecone entries of specific articles based on their hash_ids."""
with make_session() as session:
articles_to_update_stream = get_pinecone_articles_by_ids(
articles_to_update_query = get_pinecone_articles_by_ids_query(
session, hash_ids, custom_sources, force_update
)
for batch in self.batch_entries(articles_to_update_stream):
for batch in self.batch_entries(articles_to_update_query):
self.save_batch(session, batch)

def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry]]):
Expand Down Expand Up @@ -108,6 +108,7 @@ def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None:
authors=[author.strip() for author in article.authors.split(",") if author.strip()],
text_chunks=text_chunks,
embeddings=embeddings,
confidence=article.confidence,
)
except (
ValueError,
Expand Down Expand Up @@ -135,7 +136,11 @@ def get_text_chunks(
authors = get_authors_str(authors_lst)

signature = f"Title: {title}; Author(s): {authors}."

text_chunks = text_splitter.split_text(article.text)
for summary in article.summaries:
text_chunks += text_splitter.split_text(summary.text)

return [f'###{signature}###\n"""{text_chunk}"""' for text_chunk in text_chunks]


Expand Down

0 comments on commit f92d633

Please sign in to comment.