Commit: authors as string
mruwnik committed Jul 18, 2023
1 parent aafefbc commit 1ac1b47
Showing 5 changed files with 50 additions and 68 deletions.
35 changes: 24 additions & 11 deletions align_data/common/alignment_dataset.py
@@ -2,6 +2,7 @@
 import time
 import zipfile
 from dataclasses import dataclass, field, KW_ONLY
+from itertools import islice
 from pathlib import Path
 from typing import List
 from sqlalchemy import select
@@ -12,7 +13,7 @@
 import pytz
 from dateutil.parser import parse, ParserError
 from tqdm import tqdm
-from align_data.db.models import Article, Author
+from align_data.db.models import Article
 from align_data.db.session import make_session


@@ -86,7 +87,7 @@ def make_data_entry(self, data, **kwargs):
         data = dict(data, **kwargs)
         # TODO: Don't keep adding the same authors - come up with some way to reuse them
         # TODO: Prettify this
-        data['authors'] = [Author(name=name) for name in data.get('authors', [])]
+        data['authors'] = ','.join(data.get('authors', []))
         if summary := ('summary' in data and data.pop('summary')):
            data['summaries'] = [summary]
         return Article(
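The commit replaces Author rows with a single comma-separated string on the article. A standalone sketch of the round trip between the two forms, plus the format's main caveat; this is illustrative code, not part of the commit:

```python
# Serialize the way make_data_entry now does.
authors = ['Jane Doe', 'John Smith']
stored = ','.join(authors)  # 'Jane Doe,John Smith'

# Deserialize the way Article.to_dict now does.
parsed = [name.strip() for name in stored.split(',')]
assert parsed == authors

# Caveat: a comma inside one author's name is indistinguishable from a
# separator, so 'Doe, Jane' comes back as two authors.
assert [n.strip() for n in 'Doe, Jane'.split(',')] == ['Doe', 'Jane']
```

Since the column is declared String(1024), unusually long author lists may also no longer fit.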
@@ -106,21 +107,33 @@ def to_jsonl(self, out_path=None, filename=None):
             for article in self.read_entries():
                 jsonl_writer.write(article.to_dict())

-    def read_entries(self):
+    def read_entries(self, sort_by=None):
         """Iterate through all the saved entries."""
         with make_session() as session:
-            for item in session.scalars(select(Article).where(Article.source==self.name)):
+            query = select(Article).where(Article.source==self.name)
+            if sort_by is not None:
+                query = query.order_by(sort_by)
+            for item in session.scalars(query):
                 yield item

     def add_entries(self, entries):
+        def commit():
+            try:
+                session.commit()
+                return True
+            except IntegrityError:
+                session.rollback()
+
         with make_session() as session:
-            for entry in entries:
-                session.add(entry)
-                try:
-                    session.commit()
-                except IntegrityError:
-                    logger.error(f'found duplicate of {entry}')
-                    session.rollback()
+            while batch := tuple(islice(entries, 20)):
+                session.add_all(batch)
+                # there might be duplicates in the batch, so if they cause
+                # an exception, try to commit them one by one
+                if not commit():
+                    for entry in batch:
+                        session.add(entry)
+                        if not commit():
+                            logger.error(f'found duplicate of {entry}')

     def setup(self):
         # make sure the path to the raw data exists
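The new sort_by parameter accepts any SQLAlchemy ordering clause, which is what lets a scraper pull its newest stored entry first (greaterwrong.py below does exactly that). An illustrative call, with `dataset` assumed to be any AlignmentDataset instance:

```python
# Illustrative usage, not code from the commit.
from align_data.db.models import Article

newest_first = dataset.read_entries(sort_by=Article.date_published.desc())
latest = next(newest_first, None)  # None when this source has no entries yet
```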
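add_entries now commits in batches of 20 and falls back to one-by-one commits only when a batch raises IntegrityError. The load-bearing piece is the `while batch := tuple(islice(entries, 20))` line: repeated islice calls on the same iterator each consume the next chunk, so the generator is walked exactly once. A self-contained sketch of that idiom (the walrus operator requires Python 3.8+):

```python
from itertools import islice

def batched(iterable, size):
    """Yield consecutive tuples of at most `size` items."""
    iterator = iter(iterable)  # one shared iterator, so each islice resumes where the last stopped
    while batch := tuple(islice(iterator, size)):
        yield batch

assert list(batched(range(7), 3)) == [(0, 1, 2), (3, 4, 5), (6,)]
```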
24 changes: 3 additions & 21 deletions align_data/db/models.py
@@ -2,33 +2,15 @@
 import hashlib
 from datetime import datetime
 from typing import List, Optional
-from sqlalchemy import JSON, DateTime, ForeignKey, Table, String, Column, Integer, func, Text, event
+from sqlalchemy import JSON, DateTime, ForeignKey, String, func, Text, event
 from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
 from sqlalchemy.ext.associationproxy import association_proxy, AssociationProxy
 from sqlalchemy.dialects.mysql import LONGTEXT


 class Base(DeclarativeBase):
     pass


-author_article = Table(
-    'author_article',
-    Base.metadata,
-    Column('article_id', Integer, ForeignKey('articles.id'), primary_key=True),
-    Column('author_id', Integer, ForeignKey('authors.id'), primary_key=True),
-)
-
-
-class Author(Base):
-
-    __tablename__ = "authors"
-
-    id: Mapped[int] = mapped_column(primary_key=True)
-    name: Mapped[str] = mapped_column(String(256), nullable=False)
-    articles: Mapped[List["Article"]] = relationship(secondary=author_article, back_populates="authors")
-
-
 class Summary(Base):

     __tablename__ = "summaries"
@@ -50,13 +32,13 @@ class Article(Base):
     url: Mapped[Optional[str]] = mapped_column(String(1028))
     source: Mapped[Optional[str]] = mapped_column(String(128))
     source_type: Mapped[Optional[str]] = mapped_column(String(128))
+    authors: Mapped[str] = mapped_column(String(1024))
     text: Mapped[Optional[str]] = mapped_column(LONGTEXT)
     date_published: Mapped[Optional[datetime]]
     meta: Mapped[Optional[JSON]] = mapped_column(JSON, name='metadata', default='{}')
     date_created: Mapped[datetime] = mapped_column(DateTime, default=func.now())
     date_updated: Mapped[Optional[datetime]] = mapped_column(DateTime, onupdate=func.current_timestamp())

-    authors: Mapped[List['Author']] = relationship(secondary=author_article, back_populates="articles")
     summaries: Mapped[List["Summary"]] = relationship(back_populates="article", cascade="all, delete-orphan")

     __id_fields = ['title', 'url']
@@ -103,7 +85,7 @@ def to_dict(self):
             'source_type': self.source_type,
             'text': self.text,
             'date_published': date,
-            'authors': [a.name for a in self.authors],
+            'authors': [i.strip() for i in self.authors.split(',')] if self.authors.strip() else [],
             'summaries': [s.text for s in self.summaries],
             **self.meta,
         }
24 changes: 13 additions & 11 deletions align_data/greaterwrong/greaterwrong.py
@@ -1,17 +1,15 @@
-from datetime import datetime, timezone
-from dateutil.parser import parse
+from datetime import datetime
 import logging
 import time
 from dataclasses import dataclass
 from pathlib import Path

 import requests
-import jsonlines
 from bs4 import BeautifulSoup
 from tqdm import tqdm
 from markdownify import markdownify

 from align_data.common.alignment_dataset import AlignmentDataset
+from align_data.db.models import Article

 logger = logging.getLogger(__name__)
@@ -139,14 +137,18 @@ def fetch_posts(self, query: str):
         return res.json()['data']['posts']

     @property
-    def items_list(self):
-        next_date = datetime(self.start_year, 1, 1).isoformat() + 'Z'
-        if self.jsonl_path.exists() and self.jsonl_path.lstat().st_size:
-            with jsonlines.open(self.jsonl_path) as f:
-                for item in f:
-                    if item['date_published'] > next_date:
-                        next_date = item['date_published']
+    def last_date_published(self):
+        try:
+            prev_item = next(self.read_entries(sort_by=Article.date_published.desc()))
+            if prev_item and prev_item.date_published:
+                return prev_item.date_published.isoformat() + 'Z'
+        except StopIteration:
+            pass
+        return datetime(self.start_year, 1, 1).isoformat() + 'Z'
+
+    @property
+    def items_list(self):
+        next_date = self.last_date_published
         logger.info('Starting from %s', next_date)
         while next_date:
             posts = self.fetch_posts(self.make_query(next_date))
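Both the old and the new cursor lean on the same property: ISO-8601 timestamps in one fixed format order chronologically as plain strings. That is why the removed code could track the maximum date with a string comparison, and why last_date_published can return `isoformat() + 'Z'` (which assumes the stored datetimes are naive UTC). A quick demonstration:

```python
from datetime import datetime

# Fixed-format ISO-8601 strings sort chronologically as strings.
dates = ['2022-12-31T23:59:59Z', '2023-01-02T00:00:00Z', '2023-01-01T12:00:00Z']
assert max(dates) == '2023-01-02T00:00:00Z'

# last_date_published builds the same kind of cursor from a naive datetime.
assert datetime(2023, 1, 2).isoformat() + 'Z' == '2023-01-02T00:00:00Z'
```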
@@ -1,16 +1,16 @@
 """initial structure

-Revision ID: 8c11b666e86f
+Revision ID: 983b5bdef5f6
 Revises:
-Create Date: 2023-07-14 15:48:49.149905
+Create Date: 2023-07-18 15:54:58.299651

 """
 from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import mysql

 # revision identifiers, used by Alembic.
-revision = '8c11b666e86f'
+revision = '983b5bdef5f6'
 down_revision = None
 branch_labels = None
 depends_on = None
@@ -25,6 +25,7 @@ def upgrade() -> None:
         sa.Column('url', sa.String(length=1028), nullable=True),
         sa.Column('source', sa.String(length=128), nullable=True),
         sa.Column('source_type', sa.String(length=128), nullable=True),
+        sa.Column('authors', sa.String(length=1024), nullable=False),
         sa.Column('text', mysql.LONGTEXT(), nullable=True),
         sa.Column('date_published', sa.DateTime(), nullable=True),
         sa.Column('metadata', sa.JSON(), nullable=True),
@@ -33,20 +34,6 @@ def upgrade() -> None:
         sa.PrimaryKeyConstraint('id'),
         sa.UniqueConstraint('hash_id')
     )
-    op.create_table(
-        'authors',
-        sa.Column('id', sa.Integer(), nullable=False),
-        sa.Column('name', sa.String(length=256), nullable=False),
-        sa.PrimaryKeyConstraint('id')
-    )
-    op.create_table(
-        'author_article',
-        sa.Column('article_id', sa.Integer(), nullable=False),
-        sa.Column('author_id', sa.Integer(), nullable=False),
-        sa.ForeignKeyConstraint(['article_id'], ['articles.id'], ),
-        sa.ForeignKeyConstraint(['author_id'], ['authors.id'], ),
-        sa.PrimaryKeyConstraint('article_id', 'author_id')
-    )
     op.create_table(
         'summaries',
         sa.Column('id', sa.Integer(), nullable=False),
@@ -60,6 +47,4 @@ def upgrade() -> None:

 def downgrade() -> None:
     op.drop_table('summaries')
-    op.drop_table('author_article')
-    op.drop_table('authors')
     op.drop_table('articles')
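Note that the initial migration is regenerated in place: the revision ID changes from 8c11b666e86f to 983b5bdef5f6 while down_revision stays None, so there is no upgrade path from the old schema, and databases stamped with the old revision have to be recreated. A sketch of applying the new schema to a fresh database via Alembic's programmatic API; the alembic.ini location is an assumption:

```python
# command.upgrade is the real Alembic API; the config path is illustrative.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "head")  # creates articles (with the authors column) and summaries
```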
12 changes: 6 additions & 6 deletions tests/align_data/test_greater_wrong.py
@@ -112,8 +112,6 @@ def fetcher(next_date):
 def test_items_list_with_previous_items(dataset):
     dataset.ai_tags = {'tag1', 'tag2'}
-    with open(dataset.jsonl_path, 'w') as f:
-        f.write('{"date_published": "2014-12-12T01:23:45Z"}\n')

     def make_item(date):
         return {
@@ -135,12 +133,14 @@ def fetcher(next_date):
         ]
         return {'results': results}

+    mock_items = (i for i in [Mock(date_published=datetime.fromisoformat('2014-12-12T01:23:45'))])
     with patch.object(dataset, 'fetch_posts', fetcher):
         with patch.object(dataset, 'make_query', lambda next_date: next_date):
-            # All items that are older than the newest item in the jsonl file are ignored
-            assert list(dataset.items_list) == [
-                make_item(datetime(2014, 12, 12, 1, 23, 45).replace(tzinfo=pytz.UTC) + timedelta(days=i*30))
-                for i in range(1, 4)
-            ]
+            with patch.object(dataset, 'read_entries', return_value=mock_items):
+                # All items older than the newest item already in the database are ignored
+                assert list(dataset.items_list) == [
+                    make_item(datetime(2014, 12, 12, 1, 23, 45).replace(tzinfo=pytz.UTC) + timedelta(days=i*30))
+                    for i in range(1, 4)
+                ]

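The updated test swaps the database read for a one-item generator, so last_date_published sees 2014-12-12T01:23:45 as the newest stored date and items_list resumes from there. A minimal standalone version of the same patching pattern, with illustrative names:

```python
from datetime import datetime
from unittest.mock import Mock, patch

source = Mock()  # stands in for the dataset object
items = (i for i in [Mock(date_published=datetime(2014, 12, 12, 1, 23, 45))])
with patch.object(source, 'read_entries', return_value=items):
    newest = next(source.read_entries(sort_by=None))
assert newest.date_published.isoformat() == '2014-12-12T01:23:45'
```

One wrinkle of `return_value=<generator>`: every patched call returns the same generator object, so it yields its items only once; that is fine here because items_list reads the entries a single time.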

