diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py
index 2f8d236f..8086a66a 100644
--- a/align_data/common/alignment_dataset.py
+++ b/align_data/common/alignment_dataset.py
@@ -2,6 +2,7 @@
 import time
 import zipfile
 from dataclasses import dataclass, field, KW_ONLY
+from itertools import islice
 from pathlib import Path
 from typing import List
 from sqlalchemy import select
@@ -12,7 +13,7 @@
 import pytz
 from dateutil.parser import parse, ParserError
 from tqdm import tqdm
-from align_data.db.models import Article, Author
+from align_data.db.models import Article
 from align_data.db.session import make_session
 
 
@@ -86,7 +87,7 @@ def make_data_entry(self, data, **kwargs):
         data = dict(data, **kwargs)
         # TODO: Don't keep adding the same authors - come up with some way to reuse them
         # TODO: Prettify this
-        data['authors'] = [Author(name=name) for name in data.get('authors', [])]
+        data['authors'] = ','.join(data.get('authors', []))
         if summary := ('summary' in data and data.pop('summary')):
             data['summaries'] = [summary]
         return Article(
@@ -106,21 +107,33 @@ def to_jsonl(self, out_path=None, filename=None):
             for article in self.read_entries():
                 jsonl_writer.write(article.to_dict())
 
-    def read_entries(self):
+    def read_entries(self, sort_by=None):
         """Iterate through all the saved entries."""
         with make_session() as session:
-            for item in session.scalars(select(Article).where(Article.source==self.name)):
+            query = select(Article).where(Article.source==self.name)
+            if sort_by is not None:
+                query = query.order_by(sort_by)
+            for item in session.scalars(query):
                 yield item
 
     def add_entries(self, entries):
+        def commit():
+            try:
+                session.commit()
+                return True
+            except IntegrityError:
+                session.rollback()
+
         with make_session() as session:
-            for entry in entries:
-                session.add(entry)
-                try:
-                    session.commit()
-                except IntegrityError:
-                    logger.error(f'found duplicate of {entry}')
-                    session.rollback()
+            while batch := tuple(islice(entries, 20)):
+                session.add_all(batch)
+                # there might be duplicates in the batch, so if they cause
+                # an exception, try to commit them one by one
+                if not commit():
+                    for entry in batch:
+                        session.add(entry)
+                        if not commit():
+                            logger.error(f'found duplicate of {entry}')
 
     def setup(self):
         # make sure the path to the raw data exists
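Note on the batching above: `add_entries` now pulls twenty entries at a time with `itertools.islice` and only falls back to row-by-row commits when a batch hits an IntegrityError. A minimal, self-contained sketch of that consumption pattern (the `batched` helper below is purely illustrative and not part of this change):

    from itertools import islice

    def batched(iterable, size=20):
        # islice takes at most `size` items per pass; once the iterator is
        # exhausted it returns an empty tuple, which ends the loop.
        iterator = iter(iterable)
        while batch := tuple(islice(iterator, size)):
            yield batch

    print([len(b) for b in batched(range(50))])  # [20, 20, 10]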
"summaries" @@ -50,13 +32,13 @@ class Article(Base): url: Mapped[Optional[str]] = mapped_column(String(1028)) source: Mapped[Optional[str]] = mapped_column(String(128)) source_type: Mapped[Optional[str]] = mapped_column(String(128)) + authors: Mapped[str] = mapped_column(String(1024)) text: Mapped[Optional[str]] = mapped_column(LONGTEXT) date_published: Mapped[Optional[datetime]] meta: Mapped[Optional[JSON]] = mapped_column(JSON, name='metadata', default='{}') date_created: Mapped[datetime] = mapped_column(DateTime, default=func.now()) date_updated: Mapped[Optional[datetime]] = mapped_column(DateTime, onupdate=func.current_timestamp()) - authors: Mapped[List['Author']] = relationship(secondary=author_article, back_populates="articles") summaries: Mapped[List["Summary"]] = relationship(back_populates="article", cascade="all, delete-orphan") __id_fields = ['title', 'url'] @@ -103,7 +85,7 @@ def to_dict(self): 'source_type': self.source_type, 'text': self.text, 'date_published': date, - 'authors': [a.name for a in self.authors], + 'authors': [i.strip() for i in self.authors.split(',')] if self.authors.strip() else [], 'summaries': [s.text for s in self.summaries], **self.meta, } diff --git a/align_data/greaterwrong/greaterwrong.py b/align_data/greaterwrong/greaterwrong.py index 84974068..850b12b3 100644 --- a/align_data/greaterwrong/greaterwrong.py +++ b/align_data/greaterwrong/greaterwrong.py @@ -1,17 +1,15 @@ -from datetime import datetime, timezone -from dateutil.parser import parse +from datetime import datetime import logging import time from dataclasses import dataclass -from pathlib import Path import requests import jsonlines from bs4 import BeautifulSoup -from tqdm import tqdm from markdownify import markdownify from align_data.common.alignment_dataset import AlignmentDataset +from align_data.db.models import Article logger = logging.getLogger(__name__) @@ -139,14 +137,18 @@ def fetch_posts(self, query: str): return res.json()['data']['posts'] @property - def items_list(self): - next_date = datetime(self.start_year, 1, 1).isoformat() + 'Z' - if self.jsonl_path.exists() and self.jsonl_path.lstat().st_size: - with jsonlines.open(self.jsonl_path) as f: - for item in f: - if item['date_published'] > next_date: - next_date = item['date_published'] + def last_date_published(self): + try: + prev_item = next(self.read_entries(sort_by=Article.date_published.desc())) + if prev_item and prev_item.date_published: + return prev_item.date_published.isoformat() + 'Z' + except StopIteration: + pass + return datetime(self.start_year, 1, 1).isoformat() + 'Z' + @property + def items_list(self): + next_date = self.last_date_published logger.info('Starting from %s', next_date) while next_date: posts = self.fetch_posts(self.make_query(next_date)) diff --git a/migrations/versions/8c11b666e86f_initial_structure.py b/migrations/versions/983b5bdef5f6_initial_structure.py similarity index 68% rename from migrations/versions/8c11b666e86f_initial_structure.py rename to migrations/versions/983b5bdef5f6_initial_structure.py index 113371db..ff1ef321 100644 --- a/migrations/versions/8c11b666e86f_initial_structure.py +++ b/migrations/versions/983b5bdef5f6_initial_structure.py @@ -1,8 +1,8 @@ """initial structure -Revision ID: 8c11b666e86f +Revision ID: 983b5bdef5f6 Revises: -Create Date: 2023-07-14 15:48:49.149905 +Create Date: 2023-07-18 15:54:58.299651 """ from alembic import op @@ -10,7 +10,7 @@ from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. 
diff --git a/migrations/versions/8c11b666e86f_initial_structure.py b/migrations/versions/983b5bdef5f6_initial_structure.py
similarity index 68%
rename from migrations/versions/8c11b666e86f_initial_structure.py
rename to migrations/versions/983b5bdef5f6_initial_structure.py
index 113371db..ff1ef321 100644
--- a/migrations/versions/8c11b666e86f_initial_structure.py
+++ b/migrations/versions/983b5bdef5f6_initial_structure.py
@@ -1,8 +1,8 @@
 """initial structure
 
-Revision ID: 8c11b666e86f
+Revision ID: 983b5bdef5f6
 Revises:
-Create Date: 2023-07-14 15:48:49.149905
+Create Date: 2023-07-18 15:54:58.299651
 
 """
 from alembic import op
@@ -10,7 +10,7 @@
 from sqlalchemy.dialects import mysql
 
 # revision identifiers, used by Alembic.
-revision = '8c11b666e86f'
+revision = '983b5bdef5f6'
 down_revision = None
 branch_labels = None
 depends_on = None
@@ -25,6 +25,7 @@ def upgrade() -> None:
         sa.Column('url', sa.String(length=1028), nullable=True),
         sa.Column('source', sa.String(length=128), nullable=True),
         sa.Column('source_type', sa.String(length=128), nullable=True),
+        sa.Column('authors', sa.String(length=1024), nullable=False),
         sa.Column('text', mysql.LONGTEXT(), nullable=True),
        sa.Column('date_published', sa.DateTime(), nullable=True),
         sa.Column('metadata', sa.JSON(), nullable=True),
@@ -33,20 +34,6 @@ def upgrade() -> None:
         sa.PrimaryKeyConstraint('id'),
         sa.UniqueConstraint('hash_id')
     )
-    op.create_table(
-        'authors',
-        sa.Column('id', sa.Integer(), nullable=False),
-        sa.Column('name', sa.String(length=256), nullable=False),
-        sa.PrimaryKeyConstraint('id')
-    )
-    op.create_table(
-        'author_article',
-        sa.Column('article_id', sa.Integer(), nullable=False),
-        sa.Column('author_id', sa.Integer(), nullable=False),
-        sa.ForeignKeyConstraint(['article_id'], ['articles.id'], ),
-        sa.ForeignKeyConstraint(['author_id'], ['authors.id'], ),
-        sa.PrimaryKeyConstraint('article_id', 'author_id')
-    )
     op.create_table(
         'summaries',
         sa.Column('id', sa.Integer(), nullable=False),
@@ -60,6 +47,4 @@ def downgrade() -> None:
     op.drop_table('summaries')
-    op.drop_table('author_article')
-    op.drop_table('authors')
     op.drop_table('articles')
diff --git a/tests/align_data/test_greater_wrong.py b/tests/align_data/test_greater_wrong.py
index c5cca825..72b32b83 100644
--- a/tests/align_data/test_greater_wrong.py
+++ b/tests/align_data/test_greater_wrong.py
@@ -112,8 +112,6 @@ def fetcher(next_date):
 
 def test_items_list_with_previous_items(dataset):
     dataset.ai_tags = {'tag1', 'tag2'}
-    with open(dataset.jsonl_path, 'w') as f:
-        f.write('{"date_published": "2014-12-12T01:23:45Z"}\n')
 
     def make_item(date):
         return {
@@ -135,12 +133,14 @@ def fetcher(next_date):
         ]
         return {'results': results}
 
+    mock_items = (i for i in [Mock(date_published=datetime.fromisoformat('2014-12-12T01:23:45'))])
     with patch.object(dataset, 'fetch_posts', fetcher):
         with patch.object(dataset, 'make_query', lambda next_date: next_date):
-            # All items that are older than the newest item in the jsonl file are ignored
-            assert list(dataset.items_list) == [
-                make_item(datetime(2014, 12, 12, 1, 23, 45).replace(tzinfo=pytz.UTC) + timedelta(days=i*30))
-                for i in range(1, 4)
+            with patch.object(dataset, 'read_entries', return_value=mock_items):
+                # All items that are older than the newest item in the database are ignored
+                assert list(dataset.items_list) == [
+                    make_item(datetime(2014, 12, 12, 1, 23, 45).replace(tzinfo=pytz.UTC) + timedelta(days=i*30))
+                    for i in range(1, 4)
             ]
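A side effect of the denormalized `authors` column created above is that per-author lookups become substring matches instead of joins through `author_article`. A hedged sketch of what such a query might look like (the `articles_by_author` helper is hypothetical and not part of this change; only `Article.authors` and SQLAlchemy's `select` are taken from the code above):

    from sqlalchemy import select

    from align_data.db.models import Article

    def articles_by_author(session, name):
        # naive substring match; it also matches authors whose names merely contain `name`
        return session.scalars(
            select(Article).where(Article.authors.like(f'%{name}%'))
        ).all()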