diff --git a/align_data/db/models.py b/align_data/db/models.py index ae24650d..e79da232 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -1,3 +1,4 @@ +import enum import re import json import re @@ -10,6 +11,7 @@ from sqlalchemy import ( JSON, DateTime, + Enum, ForeignKey, String, Boolean, @@ -43,6 +45,13 @@ class Summary(Base): article: Mapped["Article"] = relationship(back_populates="summaries") +class PineconeStatus(enum.Enum): + absent = 1 + pending_removal = 2 + pending_addition = 3 + added = 4 + + class Article(Base): __tablename__ = "articles" @@ -66,7 +75,7 @@ class Article(Base): status: Mapped[Optional[str]] = mapped_column(String(256)) comments: Mapped[Optional[str]] = mapped_column(LONGTEXT) # Editor comments. Can be anything - pinecone_update_required: Mapped[bool] = mapped_column(Boolean, default=False) + pinecone_status: Mapped[PineconeStatus] = mapped_column(Enum(PineconeStatus), default=PineconeStatus.absent) summaries: Mapped[List["Summary"]] = relationship( back_populates="article", cascade="all, delete-orphan" @@ -186,8 +195,8 @@ def check_for_changes(cls, mapper, connection, target): monitored_attributes = list(PineconeMetadata.__annotations__.keys()) monitored_attributes.remove("hash_id") - changed = any(get_history(target, attr).has_changes() for attr in monitored_attributes) - target.pinecone_update_required = changed + if any(get_history(target, attr).has_changes() for attr in monitored_attributes): + target.pinecone_status = PineconeStatus.pending_addition def to_dict(self) -> Dict[str, Any]: if date := self.date_published: diff --git a/align_data/db/session.py b/align_data/db/session.py index d76fc796..331de9b7 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -5,7 +5,7 @@ from sqlalchemy import create_engine, or_ from sqlalchemy.orm import Session from align_data.settings import DB_CONNECTION_URI, MIN_CONFIDENCE -from align_data.db.models import Article +from align_data.db.models import Article, PineconeStatus logger = logging.getLogger(__name__) @@ -25,12 +25,24 @@ def make_session(auto_commit=False): def get_pinecone_articles( session: Session, force_update: bool = False, + statuses: List[PineconeStatus] = [PineconeStatus.pending_addition], ): return ( session.query(Article) - .filter(or_(Article.pinecone_update_required.is_(True), force_update)) + .filter(or_(Article.pinecone_status.in_(statuses), force_update)) .filter(Article.is_valid) - .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) + .filter(or_(Article.confidence == None, Article.confidence >= MIN_CONFIDENCE)) + ) + + +def get_pinecone_articles_to_remove(session: Session): + return ( + session.query(Article) + .filter(or_( + Article.pinecone_status == PineconeStatus.pending_removal, + Article.is_valid == False, + Article.confidence < MIN_CONFIDENCE + )) ) @@ -38,16 +50,18 @@ def get_pinecone_articles_by_sources( session: Session, custom_sources: List[str], force_update: bool = False, + statuses: List[PineconeStatus] = [PineconeStatus.pending_addition], ): - return get_pinecone_articles(session, force_update).filter(Article.source.in_(custom_sources)) + return get_pinecone_articles(session, force_update, statuses).filter(Article.source.in_(custom_sources)) def get_pinecone_articles_by_ids( session: Session, - hash_ids: List[int], + hash_ids: List[str], force_update: bool = False, + statuses: List[PineconeStatus] = [PineconeStatus.pending_addition], ): - return get_pinecone_articles(session, force_update).filter(Article.id.in_(hash_ids)) + return get_pinecone_articles(session, force_update, statuses).filter(Article.id.in_(hash_ids)) def get_all_valid_article_ids(session: Session) -> List[str]: @@ -59,3 +73,17 @@ def get_all_valid_article_ids(session: Session) -> List[str]: .all() ) return [item[0] for item in query_result] + + +def get_pinecone_to_delete_by_sources( + session: Session, + custom_sources: List[str], +): + return get_pinecone_articles_to_remove(session).filter(Article.source.in_(custom_sources)) + + +def get_pinecone_to_delete_by_ids( + session: Session, + hash_ids: List[str], +): + return get_pinecone_articles_to_remove(session).filter(Article.id.in_(hash_ids)) diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py index 301b67eb..321fe6b1 100644 --- a/align_data/embeddings/pinecone/update_pinecone.py +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -1,17 +1,18 @@ -from datetime import datetime import logging from itertools import islice -from typing import Callable, List, Tuple, Generator, Iterator, Optional +from typing import Any, Callable, Iterable, List, Tuple, Generator, Iterator from sqlalchemy.orm import Session from pydantic import ValidationError from align_data.embeddings.embedding_utils import get_embeddings -from align_data.db.models import Article +from align_data.db.models import Article, PineconeStatus from align_data.db.session import ( make_session, get_pinecone_articles_by_sources, get_pinecone_articles_by_ids, + get_pinecone_to_delete_by_sources, + get_pinecone_to_delete_by_ids, ) from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB from align_data.embeddings.pinecone.pinecone_models import ( @@ -30,10 +31,17 @@ TruncateFunctionType = Callable[[str, int], str] -class PineconeUpdater: - def __init__(self): - self.text_splitter = ParagraphSentenceUnitTextSplitter() - self.pinecone_db = PineconeDB() +class PineconeAction: + batch_size = 10 + + def __init__(self, pinecone=None): + self.pinecone_db = pinecone or PineconeDB() + + def _articles_by_source(self, session: Session, sources: List[str], force_update: bool) -> Iterable[Article]: + raise NotImplementedError + + def _articles_by_id(self, session: Session, ids: List[str], force_update: bool) -> Iterable[Article]: + raise NotImplementedError def update(self, custom_sources: List[str], force_update: bool = False): """ @@ -42,28 +50,24 @@ def update(self, custom_sources: List[str], force_update: bool = False): :param custom_sources: List of sources to update. """ with make_session() as session: - articles_to_update = get_pinecone_articles_by_sources( - session, custom_sources, force_update - ) + articles_to_update = self._articles_by_source(session, custom_sources, force_update) + logger.info('Processing %s items', articles_to_update.count()) for batch in self.batch_entries(articles_to_update): self.save_batch(session, batch) def update_articles_by_ids(self, hash_ids: List[int], force_update: bool = False): """Update the Pinecone entries of specific articles based on their hash_ids.""" with make_session() as session: - articles_to_update = get_pinecone_articles_by_ids(session, hash_ids, force_update) + articles_to_update = self._articles_by_id(session, hash_ids, force_update) for batch in self.batch_entries(articles_to_update): self.save_batch(session, batch) - def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry | None]]): - try: - for article, pinecone_entry in batch: - if pinecone_entry: - self.pinecone_db.upsert_entry(pinecone_entry) - - article.pinecone_update_required = False - session.add(article) + def process_batch(self, batch: List[Tuple[Article, PineconeEntry | None]]) -> List[Article]: + raise NotImplementedError + def save_batch(self, session: Session, batch: List[Any]): + try: + session.add_all(self.process_batch(batch)) session.commit() except Exception as e: @@ -71,10 +75,36 @@ def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry logger.error(e) session.rollback() + def batch_entries(self, article_stream: Generator[Article, None, None]) -> Iterator[List[Article]]: + while batch := tuple(islice(article_stream, self.batch_size)): + yield list(batch) + + +class PineconeAdder(PineconeAction): + batch_size = 10 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.text_splitter = ParagraphSentenceUnitTextSplitter() + + def _articles_by_source(self, session, sources: List[str], force_update: bool): + return get_pinecone_articles_by_sources(session, sources, force_update) + + def _articles_by_id(self, session, ids: List[str], force_update: bool): + return get_pinecone_articles_by_ids(session, ids, force_update) + + def process_batch(self, batch: List[Tuple[Article, PineconeEntry | None]]): + for article, pinecone_entry in batch: + if pinecone_entry: + self.pinecone_db.upsert_entry(pinecone_entry) + + article.pinecone_status = PineconeStatus.added + return [a for a, _ in batch] + def batch_entries( self, article_stream: Generator[Article, None, None] ) -> Iterator[List[Tuple[Article, PineconeEntry | None]]]: - while batch := tuple(islice(article_stream, 10)): + while batch := tuple(islice(article_stream, self.batch_size)): yield [(article, self._make_pinecone_entry(article)) for article in batch] def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None: @@ -123,6 +153,50 @@ def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None: raise +class PineconeDeleter(PineconeAction): + batch_size = 100 + pinecone_statuses = [PineconeStatus.pending_removal] + + def _articles_by_source(self, session, sources: List[str], _force_update: bool): + return get_pinecone_to_delete_by_sources(session, sources) + + def _articles_by_id(self, session, ids: List[str], _force_update: bool): + return get_pinecone_to_delete_by_ids(session, ids) + + def process_batch(self, batch: List[Article]): + self.pinecone_db.delete_entries([a.id for a in batch]) + logger.info('removing batch %s', len(batch)) + for article in batch: + article.pinecone_status = PineconeStatus.removed + return batch + + +class PineconeUpdater(PineconeAction): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.adder = PineconeAdder(*args, pinecone=self.pinecone_db, **kwargs) + self.remover = PineconeDeleter(*args, pinecone=self.pinecone_db, **kwargs) + + def update(self, custom_sources: List[str], force_update: bool = False): + """ + Update the given sources. If no sources are provided, updates all sources. + + :param custom_sources: List of sources to update. + """ + logger.info('Adding pinecone entries for %s', custom_sources) + self.adder.update(custom_sources, force_update) + + logger.info('Removing pinecone entries for %s', custom_sources) + self.remover.update(custom_sources, force_update) + + def update_articles_by_ids(self, hash_ids: List[int], force_update: bool = False): + """Update the Pinecone entries of specific articles based on their hash_ids.""" + logger.info('Adding pinecone entries by hash_id') + self.adder.update_articles_by_ids(hash_ids, force_update) + logger.info('Removing pinecone entries by hash_id') + self.remover.update_articles_by_ids(hash_ids, force_update) + + def get_text_chunks( article: Article, text_splitter: ParagraphSentenceUnitTextSplitter ) -> List[str]: diff --git a/migrations/versions/1866340e456a_pinecone_status.py b/migrations/versions/1866340e456a_pinecone_status.py new file mode 100644 index 00000000..8a73964a --- /dev/null +++ b/migrations/versions/1866340e456a_pinecone_status.py @@ -0,0 +1,54 @@ +"""pinecone status + +Revision ID: 1866340e456a +Revises: f5a2bcfa6b2c +Create Date: 2023-09-03 15:34:02.755588 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + + +# revision identifiers, used by Alembic. +revision = '1866340e456a' +down_revision = 'f5a2bcfa6b2c' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + ## Set the pinecone status + op.add_column('articles', sa.Column('pinecone_status', sa.String(length=32), nullable=False)) + + IS_VALID = """( + articles.status IS NULL AND + articles.text IS NOT NULL AND + articles.url IS NOT NULL AND + articles.title IS NOT NULL AND + articles.authors IS NOT NULL + )""" + op.execute(f""" + UPDATE articles SET pinecone_status = 'absent' + WHERE NOT articles.pinecone_update_required AND NOT {IS_VALID} + """) + op.execute(f""" + UPDATE articles SET pinecone_status = 'pending_removal' + WHERE articles.pinecone_update_required AND NOT {IS_VALID} + """) + op.execute(f""" + UPDATE articles SET pinecone_status = 'pending_addition' + WHERE articles.pinecone_update_required AND {IS_VALID} + """) + op.execute(f""" + UPDATE articles SET pinecone_status = 'added' + WHERE NOT articles.pinecone_update_required AND {IS_VALID} + """) + + op.drop_column('articles', 'pinecone_update_required') + + +def downgrade() -> None: + op.add_column("articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False)) + op.execute("UPDATE articles SET articles.pinecone_update_required = (pinecone_status = 'pending_addition')") + op.drop_column('articles', 'pinecone_status') diff --git a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py index e5b9a303..92a63bb6 100644 --- a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py +++ b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py @@ -17,12 +17,8 @@ def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.add_column("articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False)) - # ### end Alembic commands ### def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.drop_column("articles", "pinecone_update_required") - # ### end Alembic commands ###