Skip to content

Commit

Permalink
Update pinecone (#181)
Browse files Browse the repository at this point in the history
* More pinecone statuses

* Add last checked column

* Remove invalid items from pinecone
  • Loading branch information
mruwnik authored Sep 4, 2023
1 parent d9232a4 commit bbd6d81
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 33 deletions.
15 changes: 12 additions & 3 deletions align_data/db/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import re
import json
import re
Expand All @@ -10,6 +11,7 @@
from sqlalchemy import (
JSON,
DateTime,
Enum,
ForeignKey,
String,
Boolean,
Expand Down Expand Up @@ -43,6 +45,13 @@ class Summary(Base):
article: Mapped["Article"] = relationship(back_populates="summaries")


class PineconeStatus(enum.Enum):
absent = 1
pending_removal = 2
pending_addition = 3
added = 4


class Article(Base):
__tablename__ = "articles"

Expand All @@ -66,7 +75,7 @@ class Article(Base):
status: Mapped[Optional[str]] = mapped_column(String(256))
comments: Mapped[Optional[str]] = mapped_column(LONGTEXT) # Editor comments. Can be anything

pinecone_update_required: Mapped[bool] = mapped_column(Boolean, default=False)
pinecone_status: Mapped[PineconeStatus] = mapped_column(Enum(PineconeStatus), default=PineconeStatus.absent)

summaries: Mapped[List["Summary"]] = relationship(
back_populates="article", cascade="all, delete-orphan"
Expand Down Expand Up @@ -186,8 +195,8 @@ def check_for_changes(cls, mapper, connection, target):
monitored_attributes = list(PineconeMetadata.__annotations__.keys())
monitored_attributes.remove("hash_id")

changed = any(get_history(target, attr).has_changes() for attr in monitored_attributes)
target.pinecone_update_required = changed
if any(get_history(target, attr).has_changes() for attr in monitored_attributes):
target.pinecone_status = PineconeStatus.pending_addition

def to_dict(self) -> Dict[str, Any]:
if date := self.date_published:
Expand Down
40 changes: 34 additions & 6 deletions align_data/db/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy import create_engine, or_
from sqlalchemy.orm import Session
from align_data.settings import DB_CONNECTION_URI, MIN_CONFIDENCE
from align_data.db.models import Article
from align_data.db.models import Article, PineconeStatus


logger = logging.getLogger(__name__)
Expand All @@ -25,29 +25,43 @@ def make_session(auto_commit=False):
def get_pinecone_articles(
session: Session,
force_update: bool = False,
statuses: List[PineconeStatus] = [PineconeStatus.pending_addition],
):
return (
session.query(Article)
.filter(or_(Article.pinecone_update_required.is_(True), force_update))
.filter(or_(Article.pinecone_status.in_(statuses), force_update))
.filter(Article.is_valid)
.filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE))
.filter(or_(Article.confidence == None, Article.confidence >= MIN_CONFIDENCE))
)


def get_pinecone_articles_to_remove(session: Session):
return (
session.query(Article)
.filter(or_(
Article.pinecone_status == PineconeStatus.pending_removal,
Article.is_valid == False,
Article.confidence < MIN_CONFIDENCE
))
)


def get_pinecone_articles_by_sources(
session: Session,
custom_sources: List[str],
force_update: bool = False,
statuses: List[PineconeStatus] = [PineconeStatus.pending_addition],
):
return get_pinecone_articles(session, force_update).filter(Article.source.in_(custom_sources))
return get_pinecone_articles(session, force_update, statuses).filter(Article.source.in_(custom_sources))


def get_pinecone_articles_by_ids(
session: Session,
hash_ids: List[int],
hash_ids: List[str],
force_update: bool = False,
statuses: List[PineconeStatus] = [PineconeStatus.pending_addition],
):
return get_pinecone_articles(session, force_update).filter(Article.id.in_(hash_ids))
return get_pinecone_articles(session, force_update, statuses).filter(Article.id.in_(hash_ids))


def get_all_valid_article_ids(session: Session) -> List[str]:
Expand All @@ -59,3 +73,17 @@ def get_all_valid_article_ids(session: Session) -> List[str]:
.all()
)
return [item[0] for item in query_result]


def get_pinecone_to_delete_by_sources(
session: Session,
custom_sources: List[str],
):
return get_pinecone_articles_to_remove(session).filter(Article.source.in_(custom_sources))


def get_pinecone_to_delete_by_ids(
session: Session,
hash_ids: List[str],
):
return get_pinecone_articles_to_remove(session).filter(Article.id.in_(hash_ids))
114 changes: 94 additions & 20 deletions align_data/embeddings/pinecone/update_pinecone.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from datetime import datetime
import logging
from itertools import islice
from typing import Callable, List, Tuple, Generator, Iterator, Optional
from typing import Any, Callable, Iterable, List, Tuple, Generator, Iterator

from sqlalchemy.orm import Session
from pydantic import ValidationError

from align_data.embeddings.embedding_utils import get_embeddings
from align_data.db.models import Article
from align_data.db.models import Article, PineconeStatus
from align_data.db.session import (
make_session,
get_pinecone_articles_by_sources,
get_pinecone_articles_by_ids,
get_pinecone_to_delete_by_sources,
get_pinecone_to_delete_by_ids,
)
from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB
from align_data.embeddings.pinecone.pinecone_models import (
Expand All @@ -30,10 +31,17 @@
TruncateFunctionType = Callable[[str, int], str]


class PineconeUpdater:
def __init__(self):
self.text_splitter = ParagraphSentenceUnitTextSplitter()
self.pinecone_db = PineconeDB()
class PineconeAction:
batch_size = 10

def __init__(self, pinecone=None):
self.pinecone_db = pinecone or PineconeDB()

def _articles_by_source(self, session: Session, sources: List[str], force_update: bool) -> Iterable[Article]:
raise NotImplementedError

def _articles_by_id(self, session: Session, ids: List[str], force_update: bool) -> Iterable[Article]:
raise NotImplementedError

def update(self, custom_sources: List[str], force_update: bool = False):
"""
Expand All @@ -42,39 +50,61 @@ def update(self, custom_sources: List[str], force_update: bool = False):
:param custom_sources: List of sources to update.
"""
with make_session() as session:
articles_to_update = get_pinecone_articles_by_sources(
session, custom_sources, force_update
)
articles_to_update = self._articles_by_source(session, custom_sources, force_update)
logger.info('Processing %s items', articles_to_update.count())
for batch in self.batch_entries(articles_to_update):
self.save_batch(session, batch)

def update_articles_by_ids(self, hash_ids: List[int], force_update: bool = False):
"""Update the Pinecone entries of specific articles based on their hash_ids."""
with make_session() as session:
articles_to_update = get_pinecone_articles_by_ids(session, hash_ids, force_update)
articles_to_update = self._articles_by_id(session, hash_ids, force_update)
for batch in self.batch_entries(articles_to_update):
self.save_batch(session, batch)

def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry | None]]):
try:
for article, pinecone_entry in batch:
if pinecone_entry:
self.pinecone_db.upsert_entry(pinecone_entry)

article.pinecone_update_required = False
session.add(article)
def process_batch(self, batch: List[Tuple[Article, PineconeEntry | None]]) -> List[Article]:
raise NotImplementedError

def save_batch(self, session: Session, batch: List[Any]):
try:
session.add_all(self.process_batch(batch))
session.commit()

except Exception as e:
# Rollback on any kind of error. The next run will redo this batch, but in the meantime keep trucking
logger.error(e)
session.rollback()

def batch_entries(self, article_stream: Generator[Article, None, None]) -> Iterator[List[Article]]:
while batch := tuple(islice(article_stream, self.batch_size)):
yield list(batch)


class PineconeAdder(PineconeAction):
batch_size = 10

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.text_splitter = ParagraphSentenceUnitTextSplitter()

def _articles_by_source(self, session, sources: List[str], force_update: bool):
return get_pinecone_articles_by_sources(session, sources, force_update)

def _articles_by_id(self, session, ids: List[str], force_update: bool):
return get_pinecone_articles_by_ids(session, ids, force_update)

def process_batch(self, batch: List[Tuple[Article, PineconeEntry | None]]):
for article, pinecone_entry in batch:
if pinecone_entry:
self.pinecone_db.upsert_entry(pinecone_entry)

article.pinecone_status = PineconeStatus.added
return [a for a, _ in batch]

def batch_entries(
self, article_stream: Generator[Article, None, None]
) -> Iterator[List[Tuple[Article, PineconeEntry | None]]]:
while batch := tuple(islice(article_stream, 10)):
while batch := tuple(islice(article_stream, self.batch_size)):
yield [(article, self._make_pinecone_entry(article)) for article in batch]

def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None:
Expand Down Expand Up @@ -123,6 +153,50 @@ def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None:
raise


class PineconeDeleter(PineconeAction):
batch_size = 100
pinecone_statuses = [PineconeStatus.pending_removal]

def _articles_by_source(self, session, sources: List[str], _force_update: bool):
return get_pinecone_to_delete_by_sources(session, sources)

def _articles_by_id(self, session, ids: List[str], _force_update: bool):
return get_pinecone_to_delete_by_ids(session, ids)

def process_batch(self, batch: List[Article]):
self.pinecone_db.delete_entries([a.id for a in batch])
logger.info('removing batch %s', len(batch))
for article in batch:
article.pinecone_status = PineconeStatus.removed
return batch


class PineconeUpdater(PineconeAction):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.adder = PineconeAdder(*args, pinecone=self.pinecone_db, **kwargs)
self.remover = PineconeDeleter(*args, pinecone=self.pinecone_db, **kwargs)

def update(self, custom_sources: List[str], force_update: bool = False):
"""
Update the given sources. If no sources are provided, updates all sources.
:param custom_sources: List of sources to update.
"""
logger.info('Adding pinecone entries for %s', custom_sources)
self.adder.update(custom_sources, force_update)

logger.info('Removing pinecone entries for %s', custom_sources)
self.remover.update(custom_sources, force_update)

def update_articles_by_ids(self, hash_ids: List[int], force_update: bool = False):
"""Update the Pinecone entries of specific articles based on their hash_ids."""
logger.info('Adding pinecone entries by hash_id')
self.adder.update_articles_by_ids(hash_ids, force_update)
logger.info('Removing pinecone entries by hash_id')
self.remover.update_articles_by_ids(hash_ids, force_update)


def get_text_chunks(
article: Article, text_splitter: ParagraphSentenceUnitTextSplitter
) -> List[str]:
Expand Down
54 changes: 54 additions & 0 deletions migrations/versions/1866340e456a_pinecone_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""pinecone status
Revision ID: 1866340e456a
Revises: f5a2bcfa6b2c
Create Date: 2023-09-03 15:34:02.755588
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql


# revision identifiers, used by Alembic.
revision = '1866340e456a'
down_revision = 'f5a2bcfa6b2c'
branch_labels = None
depends_on = None


def upgrade() -> None:
## Set the pinecone status
op.add_column('articles', sa.Column('pinecone_status', sa.String(length=32), nullable=False))

IS_VALID = """(
articles.status IS NULL AND
articles.text IS NOT NULL AND
articles.url IS NOT NULL AND
articles.title IS NOT NULL AND
articles.authors IS NOT NULL
)"""
op.execute(f"""
UPDATE articles SET pinecone_status = 'absent'
WHERE NOT articles.pinecone_update_required AND NOT {IS_VALID}
""")
op.execute(f"""
UPDATE articles SET pinecone_status = 'pending_removal'
WHERE articles.pinecone_update_required AND NOT {IS_VALID}
""")
op.execute(f"""
UPDATE articles SET pinecone_status = 'pending_addition'
WHERE articles.pinecone_update_required AND {IS_VALID}
""")
op.execute(f"""
UPDATE articles SET pinecone_status = 'added'
WHERE NOT articles.pinecone_update_required AND {IS_VALID}
""")

op.drop_column('articles', 'pinecone_update_required')


def downgrade() -> None:
op.add_column("articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False))
op.execute("UPDATE articles SET articles.pinecone_update_required = (pinecone_status = 'pending_addition')")
op.drop_column('articles', 'pinecone_status')
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,8 @@


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("articles", "pinecone_update_required")
# ### end Alembic commands ###

0 comments on commit bbd6d81

Please sign in to comment.