Skip to content

Commit

Permalink
Unify ebooks and transcripts (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
mruwnik authored Jul 30, 2023
1 parent ad6e75b commit 3bc0633
Show file tree
Hide file tree
Showing 16 changed files with 76 additions and 236 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/fetch-dataset.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ on:
- alignmentforum
- alignment_newsletter
- arbital
- audio_transcripts
- carado.moe
- cold_takes
- deepmind_blog
Expand All @@ -49,7 +48,7 @@ on:
- importai
- jsteinhardt_blog
- lesswrong
- markdown.ebooks
- markdown
- miri
- ml_safety_newsletter
- nonarxiv_papers
Expand Down
3 changes: 1 addition & 2 deletions .github/workflows/push-datasets.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ on:
- alignment_newsletter
- arbital
- arxiv
- audio_transcripts
- carado.moe
- cold_takes
- deepmind_blog
Expand All @@ -35,7 +34,7 @@ on:
- importai
- jsteinhardt_blog
- lesswrong
- markdown.ebooks
- markdown
- miri
- ml_safety_newsletter
- nonarxiv_papers
Expand Down
2 changes: 0 additions & 2 deletions align_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import align_data.sources.reports as reports
import align_data.sources.greaterwrong as greaterwrong
import align_data.sources.stampy as stampy
import align_data.sources.audio_transcripts as audio_transcripts
import align_data.sources.alignment_newsletter as alignment_newsletter
import align_data.sources.distill as distill
import align_data.sources.gdocs as gdocs
Expand All @@ -20,7 +19,6 @@
+ reports.REPORT_REGISTRY
+ greaterwrong.GREATERWRONG_REGISTRY
+ stampy.STAMPY_REGISTRY
+ audio_transcripts.AUDIO_TRANSCRIPTS_REGISTRY
+ distill.DISTILL_REGISTRY
+ alignment_newsletter.ALIGNMENT_NEWSLETTER_REGISTRY
+ gdocs.GDOCS_REGISTRY
Expand Down
9 changes: 6 additions & 3 deletions align_data/common/alignment_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class AlignmentDataset:

lazy_eval = False
"""Whether to lazy fetch items. This is nice in that it will start processing, but messes up the progress bar."""
batch_size = 20
"""The number of items to collect before flushing to the database."""

# Internal housekeeping variables
_entry_idx = 0
Expand Down Expand Up @@ -122,8 +124,9 @@ def commit():
session.rollback()

with make_session() as session:
while batch := tuple(islice(entries, 20)):
session.add_all(entries)
items = iter(entries)
while batch := tuple(islice(items, self.batch_size)):
session.add_all(batch)
# there might be duplicates in the batch, so if they cause
# an exception, try to commit them one by one
if not commit():
Expand Down Expand Up @@ -153,7 +156,7 @@ def _load_outputted_items(self):
if hasattr(Article, self.done_key):
return set(session.scalars(select(getattr(Article, self.done_key)).where(Article.source==self.name)).all())
# TODO: Properly handle this - it should create a proper SQL JSON select
return {getattr(item, self.done_key) for item in session.scalars(select(Article.meta).where(Article.source==self.name)).all()}
return {item.get(self.done_key) for item in session.scalars(select(Article.meta).where(Article.source==self.name)).all()}

def unprocessed_items(self, items=None):
"""Return a list of all items to be processed.
Expand Down
2 changes: 1 addition & 1 deletion align_data/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import hashlib
from datetime import datetime
from typing import List, Optional
from sqlalchemy import JSON, DateTime, ForeignKey, String, func, Text, event, Float
from sqlalchemy import JSON, DateTime, ForeignKey, String, func, Text, event
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.dialects.mysql import LONGTEXT

Expand Down
9 changes: 8 additions & 1 deletion align_data/sources/articles/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from align_data.sources.articles.datasets import PDFArticles, HTMLArticles, EbookArticles, XMLArticles
from align_data.sources.articles.datasets import (
PDFArticles, HTMLArticles, EbookArticles, XMLArticles, MarkdownArticles
)

ARTICLES_REGISTRY = [
PDFArticles(
Expand All @@ -21,4 +23,9 @@
spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4',
sheet_id='823056509'
),
MarkdownArticles(
name='markdown',
spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4',
sheet_id='1003473759'
),
]
35 changes: 27 additions & 8 deletions align_data/sources/articles/datasets.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import time
import logging
from dataclasses import dataclass
from dateutil.parser import parse
from urllib.parse import urlparse

import requests
import pypandoc
import pandas as pd
from gdown.download import download
from markdownify import markdownify

from align_data.sources.articles.pdf import fetch_pdf, read_pdf, fetch
from align_data.sources.articles.pdf import read_pdf
from align_data.sources.articles.parsers import HTML_PARSERS, extract_gdrive_contents
from align_data.sources.articles.google_cloud import fetch_markdown
from align_data.common.alignment_dataset import AlignmentDataset

logger = logging.getLogger(__name__)
Expand All @@ -24,6 +22,13 @@ class SpreadsheetDataset(AlignmentDataset):
sheet_id: str
done_key = "title"
source_filetype = None
batch_size = 1

@staticmethod
def is_val(val):
if pd.isna(val):
return None
return val

@property
def items_list(self):
Expand All @@ -40,6 +45,8 @@ def _get_text(item):

@staticmethod
def extract_authors(item):
if not SpreadsheetDataset.is_val(item.authors):
return []
return [author.strip() for author in item.authors.split(',') if author.strip()]

def process_entry(self, item):
Expand All @@ -50,21 +57,22 @@ def process_entry(self, item):

return self.make_data_entry({
'text': markdownify(text).strip(),
'url': item.url,
'title': item.title,
'url': self.is_val(item.url),
'title': self.is_val(item.title),
'source': self.name,
'source_type': item.source_type,
'source_type': self.is_val(item.source_type),
'source_filetype': self.source_filetype,
'date_published': self._get_published_date(item.date_published),
'authors': self.extract_authors(item),
'summary': None if pd.isna(item.summary) else item.summary,
'summary': self.is_val(item.summary),
})


class PDFArticles(SpreadsheetDataset):

source_filetype = 'pdf'
COOLDOWN = 1
batch_size = 1

def _get_text(self, item):
url = f'https://drive.google.com/uc?id={item.file_id}'
Expand All @@ -89,6 +97,7 @@ class EbookArticles(SpreadsheetDataset):

source_filetype = 'epub'
COOLDOWN = 10 # Add a large cooldown, as google complains a lot
batch_size = 1

def _get_text(self, item):
file_id = item.source_url.split('/')[-2]
Expand All @@ -103,3 +112,13 @@ class XMLArticles(SpreadsheetDataset):
def _get_text(self, item):
vals = extract_gdrive_contents(item.source_url)
return vals['text']


class MarkdownArticles(SpreadsheetDataset):
    """Spreadsheet-driven dataset of markdown articles stored on Google Drive."""

    source_filetype = 'md'

    def _get_text(self, item):
        # The drive file id is the second-to-last path segment of the
        # source url (presumably links of the .../<id>/view form —
        # same convention as EbookArticles).
        file_id = item.source_url.split('/')[-2]
        return fetch_markdown(file_id)['text']
16 changes: 16 additions & 0 deletions align_data/sources/articles/google_cloud.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import time
from collections import UserDict
from pathlib import Path

import gdown
import gspread
from google.oauth2.service_account import Credentials
from googleapiclient.discovery import build
Expand Down Expand Up @@ -116,3 +118,17 @@ def retrier(*args, **kwargs):
raise ValueError(f'Gave up after {times} tries')
return retrier
return wrapper


def fetch_markdown(file_id):
    """Download a markdown file from Google Drive and return its contents.

    :param file_id: the Google Drive id of the file to fetch
    :returns: on success, ``{'text': <file contents>, 'data_source': 'markdown'}``;
              on any failure, ``{'error': <description>}`` — callers treat the
              presence of the ``error`` key as the failure signal
    """
    data_path = Path('data/raw/')
    data_path.mkdir(parents=True, exist_ok=True)
    file_name = data_path / file_id
    try:
        # gdown returns the path it actually wrote to; reuse it in case it
        # differs from the requested output path
        file_name = gdown.download(id=file_id, output=str(file_name), quiet=False)
        return {
            # explicit utf-8: the platform default encoding is locale-dependent
            # and would garble or reject non-ASCII markdown on some systems
            'text': Path(file_name).read_text(encoding='utf-8'),
            'data_source': 'markdown',
        }
    except Exception as e:
        # deliberately best-effort: any failure (download, missing file,
        # decode error) is reported back as data rather than raised
        return {'error': str(e)}
9 changes: 7 additions & 2 deletions align_data/sources/articles/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import regex as re
from align_data.sources.articles.html import element_extractor, fetch, fetch_element
from align_data.sources.articles.pdf import doi_getter, fetch_pdf, get_pdf_from_page, get_arxiv_pdf
from align_data.sources.articles.google_cloud import fetch_markdown
from markdownify import MarkdownConverter
from bs4 import BeautifulSoup
from markdownify import MarkdownConverter

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -68,8 +68,11 @@ def extract_gdrive_contents(link):
file_id = link.split('/')[-2]
url = f'https://drive.google.com/uc?id={file_id}'
res = fetch(url, 'head')
if res.status_code == 403:
logger.error('Could not fetch the file at %s - 403 returned', link)
return {'error': 'Could not read file from google drive - forbidden'}
if res.status_code >= 400:
logger.error('Could not fetch the pdf file at %s - are you sure that link is correct?', link)
logger.error('Could not fetch the file at %s - are you sure that link is correct?', link)
return {'error': 'Could not read file from google drive'}

result = {
Expand All @@ -82,6 +85,8 @@ def extract_gdrive_contents(link):
result['error'] = 'no content type'
elif content_type & {'application/octet-stream', 'application/pdf'}:
result.update(fetch_pdf(url))
elif content_type & {'text/markdown'}:
result.update(fetch_markdown(file_id))
elif content_type & {'application/epub+zip', 'application/epub'}:
result['data_source'] = 'ebook'
elif content_type & {'text/html'}:
Expand Down
9 changes: 0 additions & 9 deletions align_data/sources/audio_transcripts/__init__.py

This file was deleted.

78 changes: 0 additions & 78 deletions align_data/sources/audio_transcripts/audio_transcripts.py

This file was deleted.

12 changes: 0 additions & 12 deletions align_data/sources/ebooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
from .agentmodels import AgentModels
from .gdrive_ebooks import GDrive
from .mdebooks import MDEBooks

EBOOK_REGISTRY = [
AgentModels(
name='agentmodels',
repo='https://github.com/agentmodels/agentmodels.org.git'
),
GDrive(
name='gdrive_ebooks',
gdrive_address=
'https://drive.google.com/drive/folders/1V9-uVhUaxfWz5qw1sWLNRt0ikgSstc50'
),
MDEBooks(
name="markdown.ebooks",
gdrive_address=
'https://drive.google.com/uc?id=1diZwPT_HHAPFq-4RuiLx5poKsDu1oq1O'
),
]
4 changes: 3 additions & 1 deletion align_data/sources/ebooks/agentmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ class AgentModels(AlignmentDataset):
"""

repo: str = 'https://github.com/agentmodels/agentmodels.org.git'
done_key = "title"
done_key = "filename"

def setup(self):
super().setup()
self.base_dir = self.raw_data_path / 'agentmodels.org'
if not self.base_dir.exists() or not list(self.base_dir.glob('*')):
logger.info("Cloning repo")
Expand All @@ -36,5 +37,6 @@ def process_entry(self, filename):
'date_published': self._get_published_date(filename),
'title': 'Modeling Agents with Probabilistic Programs',
'url': f'https://agentmodels.org/chapters/{filename.stem}.html',
'filename': filename.name,
'text': filename.read_text(encoding='utf-8'),
})
Loading

0 comments on commit 3bc0633

Please sign in to comment.