From adc6d266fbfd9c44d2d0dc21f93e82dd6b40dfbc Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Sun, 17 Dec 2023 23:14:43 +0100 Subject: [PATCH] refresh session --- align_data/common/alignment_dataset.py | 6 +++--- align_data/db/session.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 06da73d..91db5ac 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -109,9 +109,9 @@ def commit() -> bool: session.rollback() return False - with make_session() as session: - items = iter(entries) - while batch := tuple(islice(items, self.batch_size)): + items = iter(entries) + while batch := tuple(islice(items, self.batch_size)): + with make_session() as session: self._add_batch(session, batch) # there might be duplicates in the batch, so if they cause # an exception, try to commit them one by one diff --git a/align_data/db/session.py b/align_data/db/session.py index 2e80c4b..ae26dd6 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) # We create a single engine for the entire application -engine = create_engine(DB_CONNECTION_URI, echo=False) +engine = create_engine(DB_CONNECTION_URI, echo=False, pool_pre_ping=True) @contextmanager @@ -35,13 +35,12 @@ def get_pinecone_articles( def get_pinecone_articles_to_remove(session: Session): - return ( - session.query(Article) - .filter(or_( + return session.query(Article).filter( + or_( Article.pinecone_status == PineconeStatus.pending_removal, Article.is_valid == False, - Article.confidence < MIN_CONFIDENCE - )) + Article.confidence < MIN_CONFIDENCE, + ) ) @@ -51,7 +50,9 @@ def get_pinecone_articles_by_sources( force_update: bool = False, statuses: List[PineconeStatus] = [PineconeStatus.pending_addition], ): - return get_pinecone_articles(session, force_update, statuses).filter(Article.source.in_(custom_sources)) + return get_pinecone_articles(session, force_update, statuses).filter( + Article.source.in_(custom_sources) + ) def get_pinecone_articles_by_ids(