diff --git a/align_data/__init__.py b/align_data/__init__.py
index a602f121..9f6c9893 100644
--- a/align_data/__init__.py
+++ b/align_data/__init__.py
@@ -25,6 +25,7 @@
ALL_DATASETS = sorted([dataset.name for dataset in DATASET_REGISTRY])
DATASET_MAP = {dataset.name: dataset for dataset in DATASET_REGISTRY}
+
def get_dataset(name):
try:
return DATASET_MAP[name]
diff --git a/align_data/analysis/analyse_jsonl_data.py b/align_data/analysis/analyse_jsonl_data.py
index bb38e887..b8c5103a 100644
--- a/align_data/analysis/analyse_jsonl_data.py
+++ b/align_data/analysis/analyse_jsonl_data.py
@@ -4,6 +4,7 @@
from collections import defaultdict
+
def is_valid_date_format(data_dict, format="%Y-%m-%dT%H:%M:%SZ"):
"""
Checks if the given date string matches the expected format.
@@ -15,20 +16,25 @@ def is_valid_date_format(data_dict, format="%Y-%m-%dT%H:%M:%SZ"):
except ValueError:
return False
+
def validate_data(data_dict):
"""
- Processes each dictionary element in the jsonl file.
+ Processes each dictionary element in the jsonl file.
"""
if not is_valid_date_format(data_dict):
- raise ValueError(f"Invalid date format for source: {data_dict['source']}, title: {data_dict['title'][:30]}, date_pub: {data_dict['date_published']}")
+ raise ValueError(
+ f"Invalid date format for source: {data_dict['source']}, title: {data_dict['title'][:30]}, date_pub: {data_dict['date_published']}"
+ )
# TODO: add more checks here
+
def check_for_duplicates(data_dict, seen_urls):
- id = data_dict.get('id')
+ id = data_dict.get("id")
seen_urls[id].append(data_dict)
- #TODO: Add more validation logic here
- return seen_urls
+ # TODO: Add more validation logic here
+ return seen_urls
+
def get_data_dict_str(data_dict):
"""
@@ -36,16 +42,18 @@ def get_data_dict_str(data_dict):
"""
return f"source: {data_dict['source']}, title: {data_dict['title'][:50]}, date_pub: {data_dict['date_published']}, url: {data_dict['url']}\n"
+
def files_iterator(data_dir):
"""
- Goes through the data directory, opens every jsonl file sequentially,
+ Goes through the data directory, opens every jsonl file sequentially,
and yields every element (which is a dictionary) in the jsonl file.
"""
- for path in Path(data_dir).glob('*.jsonl'):
+ for path in Path(data_dir).glob("*.jsonl"):
with jsonlines.open(path) as f:
for line in f:
yield line
+
def process_jsonl_files(data_dir):
seen_urls = defaultdict(list) # holds all seen urls
for data_dict in files_iterator(data_dir):
@@ -57,23 +65,29 @@ def process_jsonl_files(data_dir):
except Exception as e:
print(f"Unexpected error: {e}")
dup_count = 0
-
+
for id, duplicates in seen_urls.items():
if len(duplicates) > 1:
- list_of_duplicates = '\n'.join(get_data_dict_str(duplicate) for duplicate in duplicates)
- print(f"{len(duplicates)} duplicate ids found. \nId: {id}\n{list_of_duplicates}\n\n\n\n")
+ list_of_duplicates = "\n".join(
+ get_data_dict_str(duplicate) for duplicate in duplicates
+ )
+ print(
+ f"{len(duplicates)} duplicate ids found. \nId: {id}\n{list_of_duplicates}\n\n\n\n"
+ )
dup_count += 1
print(f"Total number of duplicate ids found: {dup_count}")
+
def delete_all_txt_and_jsonl(data_dir):
"""
Deletes all txt and jsonl files in the given directory.
"""
- for path in Path(data_dir).glob('*.txt'):
+ for path in Path(data_dir).glob("*.txt"):
path.unlink()
- for path in Path(data_dir).glob('*.jsonl'):
+ for path in Path(data_dir).glob("*.jsonl"):
path.unlink()
+
if __name__ == "__main__":
process_jsonl_files("data/")
- #delete_all_txt_and_jsonl("data/")
+ # delete_all_txt_and_jsonl("data/")
diff --git a/align_data/analysis/count_tokens.py b/align_data/analysis/count_tokens.py
index a601f7be..cd099c68 100644
--- a/align_data/analysis/count_tokens.py
+++ b/align_data/analysis/count_tokens.py
@@ -2,11 +2,15 @@
import jsonlines
import logging
from typing import Tuple
+
logger = logging.getLogger(__name__)
-def count_token(merged_dataset_path : str = "data/merged_dataset/alignment_texts.jsonl") -> Tuple[int , int , int]:
+
+def count_token(
+ merged_dataset_path: str = "data/merged_dataset/alignment_texts.jsonl",
+) -> Tuple[int, int, int]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
- total_token_count , total_word_count , total_character_count = 0 , 0 , 0
+ total_token_count, total_word_count, total_character_count = 0, 0, 0
with jsonlines.open(merged_dataset_path) as reader:
for obj in reader:
@@ -18,6 +22,4 @@ def count_token(merged_dataset_path : str = "data/merged_dataset/alignment_texts
logger.info(f"Total token count: {total_token_count}")
logger.info(f"Total word count: {total_word_count}")
logger.info(f"Total character count: {total_character_count}")
- return total_token_count , total_word_count , total_character_count
-
-
+ return total_token_count, total_word_count, total_character_count
diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py
index 13819fef..d8676b57 100644
--- a/align_data/common/alignment_dataset.py
+++ b/align_data/common/alignment_dataset.py
@@ -38,10 +38,10 @@ class AlignmentDataset:
_: KW_ONLY
- files_path = Path('')
+ files_path = Path("")
"""The path where data can be found. Usually a folder"""
- done_key = 'id'
+ done_key = "id"
"""The key of the entry to use as the id when checking if already processed."""
COOLDOWN = 0
@@ -58,30 +58,30 @@ class AlignmentDataset:
_outputted_items = set()
"""A set of the ids of all previously processed items"""
_: KW_ONLY
- id_fields: List[str] = field(default_factory=lambda: ['url', 'title'])
+ id_fields: List[str] = field(default_factory=lambda: ["url", "title"])
"""A list of fields to use as the id of the entry. If not set, will use ['url', 'title']"""
def __str__(self) -> str:
return self.name
- def __post_init__(self, data_path=Path(__file__).parent / '../../data/'):
+ def __post_init__(self, data_path=Path(__file__).parent / "../../data/"):
self.data_path = data_path
- self.raw_data_path = self.data_path / 'raw'
+ self.raw_data_path = self.data_path / "raw"
# set the default place to look for data
self.files_path = self.raw_data_path / self.name
def _add_authors(self, article: Article, authors: List[str]) -> Article:
# TODO: Don't keep adding the same authors - come up with some way to reuse them
- article.authors = ','.join(authors)
+ article.authors = ",".join(authors)
if len(article.authors) > 1024:
- article.authors = ','.join(article.authors[:1024].split(',')[:-1])
+ article.authors = ",".join(article.authors[:1024].split(",")[:-1])
return article
def make_data_entry(self, data, **kwargs) -> Article:
data = dict(data, **kwargs)
- summary = data.pop('summary', None)
- authors = data.pop('authors', [])
+ summary = data.pop("summary", None)
+ authors = data.pop("authors", [])
article = Article(
id_fields=self.id_fields,
@@ -95,13 +95,13 @@ def make_data_entry(self, data, **kwargs) -> Article:
def to_jsonl(self, out_path=None, filename=None) -> Path:
if not out_path:
- out_path=Path(__file__).parent / '../../data/'
+ out_path = Path(__file__).parent / "../../data/"
if not filename:
filename = f"{self.name}.jsonl"
filename = Path(out_path) / filename
- with jsonlines.open(filename, 'w') as jsonl_writer:
+ with jsonlines.open(filename, "w") as jsonl_writer:
for article in self.read_entries():
jsonl_writer.write(article.to_dict())
return filename.resolve()
@@ -109,7 +109,7 @@ def to_jsonl(self, out_path=None, filename=None) -> Path:
def read_entries(self, sort_by=None):
"""Iterate through all the saved entries."""
with make_session() as session:
- query = select(Article).where(Article.source==self.name)
+ query = select(Article).where(Article.source == self.name)
if sort_by is not None:
query = query.order_by(sort_by)
for item in session.scalars(query):
@@ -136,8 +136,8 @@ def commit():
for entry in batch:
session.add(entry)
if not commit():
- logger.error(f'found duplicate of {entry}')
-
+ logger.error(f"found duplicate of {entry}")
+
def setup(self):
self._outputted_items = self._load_outputted_items()
@@ -160,9 +160,14 @@ def _load_outputted_items(self) -> Set[str]:
# This doesn't filter by self.name. The good thing about that is that it should handle a lot more
# duplicates. The bad thing is that this could potentially return a massive amount of data if there
# are lots of items.
- return set(session.scalars(select(getattr(Article, self.done_key))).all())
+ return set(
+ session.scalars(select(getattr(Article, self.done_key))).all()
+ )
# TODO: Properly handle this - it should create a proper SQL JSON select
- return {item.get(self.done_key) for item in session.scalars(select(Article.meta)).all()}
+ return {
+ item.get(self.done_key)
+ for item in session.scalars(select(Article.meta)).all()
+ }
def unprocessed_items(self, items=None) -> Iterable:
"""Return a list of all items to be processed.
@@ -213,7 +218,6 @@ def _get_published_date(self, date) -> Optional[datetime]:
class SummaryDataset(AlignmentDataset):
-
def unprocessed_items(self, items=None) -> Iterable:
# This breaks the possible lazy loading of the items. Should be fine...
items = list(super().unprocessed_items(items))
@@ -221,7 +225,10 @@ def unprocessed_items(self, items=None) -> Iterable:
urls = map(self.get_item_key, items)
with make_session() as session:
self.articles = {
- a.url: a for a in session.query(Article).options(joinedload(Article.summaries)).filter(Article.url.in_(urls))
+ a.url: a
+ for a in session.query(Article)
+ .options(joinedload(Article.summaries))
+ .filter(Article.url.in_(urls))
if a.url
}
@@ -230,7 +237,13 @@ def unprocessed_items(self, items=None) -> Iterable:
def _load_outputted_items(self) -> Set[str]:
"""Load the output file (if it exists) in order to know which items have already been output."""
with make_session() as session:
- return set(session.scalars(select(Article.url).join(Article.summaries).filter(Summary.source == self.name)))
+ return set(
+ session.scalars(
+ select(Article.url)
+ .join(Article.summaries)
+ .filter(Summary.source == self.name)
+ )
+ )
def _add_batch(self, session, batch):
def merge(item):
diff --git a/align_data/common/html_dataset.py b/align_data/common/html_dataset.py
index e374dfc9..9e3799f3 100644
--- a/align_data/common/html_dataset.py
+++ b/align_data/common/html_dataset.py
@@ -16,11 +16,13 @@
logger = logging.getLogger(__name__)
+
@dataclass
class HTMLDataset(AlignmentDataset):
"""
Fetches articles from a different blog by collecting links to articles from an index page.
"""
+
url: str
done_key = "url"
@@ -29,9 +31,9 @@ class HTMLDataset(AlignmentDataset):
source_key: str = None
summary_key: str = None
- item_selector = 'article'
- title_selector = 'article h1'
- text_selector = 'article'
+ item_selector = "article"
+ title_selector = "article h1"
+ text_selector = "article"
source_type = "blog"
ignored_selectors = []
@@ -64,16 +66,18 @@ def process_entry(self, article):
if not text:
return None
- return self.make_data_entry({
- "text": text,
- "url": article_url,
- "title": title,
- "source": self.name,
- "source_type": "blog",
- "date_published": date_published,
- "authors": self.extract_authors(contents),
- **self._extra_values(contents),
- })
+ return self.make_data_entry(
+ {
+ "text": text,
+ "url": article_url,
+ "title": title,
+ "source": self.name,
+ "source_type": "blog",
+ "date_published": date_published,
+ "authors": self.extract_authors(contents),
+ **self._extra_values(contents),
+ }
+ )
def _get_contents(self, url):
logger.info("Fetching {}".format(url))
@@ -93,8 +97,8 @@ def _get_text(self, contents):
def _find_date(self, items):
for i in items:
- if re.match('\w+ \d{1,2}, \d{4}', i.text):
- return datetime.strptime(i.text, '%b %d, %Y').replace(tzinfo=pytz.UTC)
+        if re.match(r"\w+ \d{1,2}, \d{4}", i.text):
+ return datetime.strptime(i.text, "%b %d, %Y").replace(tzinfo=pytz.UTC)
def _extract_markdown(self, element):
return element and markdownify(str(element)).strip()
@@ -102,35 +106,35 @@ def _extract_markdown(self, element):
@dataclass
class RSSDataset(HTMLDataset):
- date_format = '%a, %d %b %Y %H:%M:%S %z'
+ date_format = "%a, %d %b %Y %H:%M:%S %z"
def get_item_key(self, item):
return item
@property
def feed_url(self):
- return f'{self.url}/rss.xml'
+ return f"{self.url}/rss.xml"
def extract_authors(self, item):
- if 'authors' in item:
- return [a['name'] for a in item['authors'] if a.get('name')]
+ if "authors" in item:
+ return [a["name"] for a in item["authors"] if a.get("name")]
return self.authors
@staticmethod
def _get_title(item):
- return item['title']
+ return item["title"]
def _get_published_date(self, item):
- date_published = item.get('published') or item.get('pubDate')
+ date_published = item.get("published") or item.get("pubDate")
return super()._get_published_date(date_published)
def _get_text(self, item):
- text = item.get('content') and item['content'][0].get('value')
+ text = item.get("content") and item["content"][0].get("value")
return self._extract_markdown(text)
def _get_contents(self, url):
item = self.items[url]
- if 'content' in item:
+ if "content" in item:
return item
logger.info("Fetching {}".format(url))
@@ -145,5 +149,5 @@ def _get_contents(self, url):
def items_list(self):
logger.info(f"Fetching entries from {self.feed_url}")
feed = feedparser.parse(self.feed_url)
- self.items = {item['link']: item for item in feed['entries']}
+ self.items = {item["link"]: item for item in feed["entries"]}
return list(self.items.keys())
diff --git a/align_data/db/models.py b/align_data/db/models.py
index 06bee3dd..029c3d1c 100644
--- a/align_data/db/models.py
+++ b/align_data/db/models.py
@@ -4,7 +4,17 @@
import hashlib
from datetime import datetime
from typing import List, Optional
-from sqlalchemy import JSON, DateTime, ForeignKey, String, Boolean, Text, Float, func, event
+from sqlalchemy import (
+ JSON,
+ DateTime,
+ ForeignKey,
+ String,
+ Boolean,
+ Text,
+ Float,
+ func,
+ event,
+)
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
from sqlalchemy.dialects.mysql import LONGTEXT
from align_data.settings import PINECONE_METADATA_KEYS
@@ -18,7 +28,6 @@ class Base(DeclarativeBase):
class Summary(Base):
-
__tablename__ = "summaries"
id: Mapped[int] = mapped_column(primary_key=True)
@@ -30,28 +39,33 @@ class Summary(Base):
class Article(Base):
-
__tablename__ = "articles"
- _id: Mapped[int] = mapped_column('id', primary_key=True)
- id: Mapped[str] = mapped_column('hash_id', String(32), unique=True, nullable=False)
+ _id: Mapped[int] = mapped_column("id", primary_key=True)
+ id: Mapped[str] = mapped_column("hash_id", String(32), unique=True, nullable=False)
title: Mapped[Optional[str]] = mapped_column(String(1028))
url: Mapped[Optional[str]] = mapped_column(String(1028))
source: Mapped[Optional[str]] = mapped_column(String(128))
source_type: Mapped[Optional[str]] = mapped_column(String(128))
authors: Mapped[str] = mapped_column(String(1024))
text: Mapped[Optional[str]] = mapped_column(LONGTEXT)
- confidence: Mapped[Optional[float]] # Describes the confidence in how good this article is, as a value <0, 1>
+ confidence: Mapped[
+ Optional[float]
+    ]  # Describes the confidence in how good this article is, as a value <0, 1>
date_published: Mapped[Optional[datetime]]
- meta: Mapped[Optional[JSON]] = mapped_column(JSON, name='metadata', default='{}')
+ meta: Mapped[Optional[JSON]] = mapped_column(JSON, name="metadata", default="{}")
date_created: Mapped[datetime] = mapped_column(DateTime, default=func.now())
- date_updated: Mapped[Optional[datetime]] = mapped_column(DateTime, onupdate=func.current_timestamp())
-
+ date_updated: Mapped[Optional[datetime]] = mapped_column(
+ DateTime, onupdate=func.current_timestamp()
+ )
+
pinecone_update_required: Mapped[bool] = mapped_column(Boolean, default=False)
-
- summaries: Mapped[List["Summary"]] = relationship(back_populates="article", cascade="all, delete-orphan")
- __id_fields = ['url', 'title']
+ summaries: Mapped[List["Summary"]] = relationship(
+ back_populates="article", cascade="all, delete-orphan"
+ )
+
+ __id_fields = ["url", "title"]
def __init__(self, *args, id_fields, **kwargs):
self.__id_fields = id_fields
@@ -59,32 +73,39 @@ def __init__(self, *args, id_fields, **kwargs):
def __repr__(self) -> str:
return f"Article(id={self.id!r}, title={self.title!r}, url={self.url!r}, source={self.source!r}, authors={self.authors!r}, date_published={self.date_published!r})"
-
+
def is_metadata_keys_equal(self, other):
if not isinstance(other, Article):
- raise TypeError(f"Expected an instance of Article, got {type(other).__name__}")
+ raise TypeError(
+ f"Expected an instance of Article, got {type(other).__name__}"
+ )
return not any(
- getattr(self, key, None) != getattr(other, key, None) # entry_id is implicitly ignored
+ getattr(self, key, None)
+        != getattr(other, key, None)  # entry_id is implicitly ignored
for key in PINECONE_METADATA_KEYS
)
def generate_id_string(self) -> str:
- return ''.join(str(getattr(self, field)) for field in self.__id_fields).encode("utf-8")
+ return "".join(str(getattr(self, field)) for field in self.__id_fields).encode(
+ "utf-8"
+ )
def verify_fields(self):
missing = [field for field in self.__id_fields if not getattr(self, field)]
- assert not missing, f'Entry is missing the following fields: {missing}'
-
+ assert not missing, f"Entry is missing the following fields: {missing}"
+
def verify_id(self):
assert self.id is not None, "Entry is missing id"
id_string = self.generate_id_string()
id_from_fields = hashlib.md5(id_string).hexdigest()
- assert self.id == id_from_fields, f"Entry id {self.id} does not match id from id_fields, {id_from_fields}"
+ assert (
+ self.id == id_from_fields
+ ), f"Entry id {self.id} does not match id from id_fields, {id_from_fields}"
def update(self, other):
for field in self.__table__.columns.keys():
- if field not in ['id', 'hash_id', 'metadata'] and getattr(other, field):
+ if field not in ["id", "hash_id", "metadata"] and getattr(other, field):
setattr(self, field, getattr(other, field))
self.meta.update({k: v for k, v in other.meta.items() if k and v})
@@ -114,21 +135,21 @@ def to_dict(self):
authors = []
if self.authors and self.authors.strip():
- authors = [i.strip() for i in self.authors.split(',')]
+ authors = [i.strip() for i in self.authors.split(",")]
return {
- 'id': self.id,
- 'title': self.title,
- 'url': self.url,
- 'source': self.source,
- 'source_type': self.source_type,
- 'text': self.text,
- 'date_published': date,
- 'authors': authors,
- 'summaries': [s.text for s in (self.summaries or [])],
+ "id": self.id,
+ "title": self.title,
+ "url": self.url,
+ "source": self.source,
+ "source_type": self.source_type,
+ "text": self.text,
+ "date_published": date,
+ "authors": authors,
+ "summaries": [s.text for s in (self.summaries or [])],
**(self.meta or {}),
}
-event.listen(Article, 'before_insert', Article.before_write)
-event.listen(Article, 'before_update', Article.before_write)
+event.listen(Article, "before_insert", Article.before_write)
+event.listen(Article, "before_update", Article.before_write)
diff --git a/align_data/db/session.py b/align_data/db/session.py
index f3eb6468..ace0ff8a 100644
--- a/align_data/db/session.py
+++ b/align_data/db/session.py
@@ -24,6 +24,4 @@ def stream_pinecone_updates(session, custom_sources: List[str]):
"""Yield Pinecone entries that require an update."""
yield from session.query(Article).filter(
Article.pinecone_update_required.is_(True)
- ).filter(
- Article.source.in_(custom_sources)
- ).yield_per(1000)
\ No newline at end of file
+ ).filter(Article.source.in_(custom_sources)).yield_per(1000)
diff --git a/align_data/pinecone/pinecone_db_handler.py b/align_data/pinecone/pinecone_db_handler.py
index 4168cb70..d8f565df 100644
--- a/align_data/pinecone/pinecone_db_handler.py
+++ b/align_data/pinecone/pinecone_db_handler.py
@@ -5,7 +5,14 @@
import pinecone
-from align_data.settings import PINECONE_INDEX_NAME, PINECONE_VALUES_DIMS, PINECONE_METRIC, PINECONE_METADATA_KEYS, PINECONE_API_KEY, PINECONE_ENVIRONMENT
+from align_data.settings import (
+ PINECONE_INDEX_NAME,
+ PINECONE_VALUES_DIMS,
+ PINECONE_METRIC,
+ PINECONE_METADATA_KEYS,
+ PINECONE_API_KEY,
+ PINECONE_ENVIRONMENT,
+)
logger = logging.getLogger(__name__)
@@ -25,58 +32,60 @@ def __init__(
self.values_dims = values_dims
self.metric = metric
self.metadata_keys = metadata_keys
-
+
pinecone.init(
- api_key = PINECONE_API_KEY,
- environment = PINECONE_ENVIRONMENT,
+ api_key=PINECONE_API_KEY,
+ environment=PINECONE_ENVIRONMENT,
)
-
+
if create_index:
self.create_index()
-
+
self.index = pinecone.Index(index_name=self.index_name)
-
+
if log_index_stats:
index_stats_response = self.index.describe_index_stats()
logger.info(f"{self.index_name}:\n{index_stats_response}")
-
+
def upsert_entry(self, entry: Dict, upsert_size=100):
self.index.upsert(
vectors=list(
zip(
- [f"{entry['id']}_{str(i).zfill(6)}" for i in range(len(entry['text_chunks']))],
- entry['embeddings'].tolist(),
+ [
+ f"{entry['id']}_{str(i).zfill(6)}"
+ for i in range(len(entry["text_chunks"]))
+ ],
+ entry["embeddings"].tolist(),
[
{
- 'entry_id': entry['id'],
- 'source': entry['source'],
- 'title': entry['title'],
- 'authors': entry['authors'],
- 'text': text_chunk,
- } for text_chunk in entry['text_chunks']
- ]
+ "entry_id": entry["id"],
+ "source": entry["source"],
+ "title": entry["title"],
+ "authors": entry["authors"],
+ "text": text_chunk,
+ }
+ for text_chunk in entry["text_chunks"]
+ ],
)
),
- batch_size=upsert_size
+ batch_size=upsert_size,
)
-
+
def delete_entries(self, ids):
- self.index.delete(
- filter={"entry_id": {"$in": ids}}
- )
+ self.index.delete(filter={"entry_id": {"$in": ids}})
def create_index(self, replace_current_index: bool = True):
if replace_current_index:
self.delete_index()
-
+
pinecone.create_index(
name=self.index_name,
dimension=self.values_dims,
metric=self.metric,
- metadata_config = {"indexed": self.metadata_keys},
+ metadata_config={"indexed": self.metadata_keys},
)
def delete_index(self):
if self.index_name in pinecone.list_indexes():
logger.info(f"Deleting index '{self.index_name}'.")
- pinecone.delete_index(self.index_name)
\ No newline at end of file
+ pinecone.delete_index(self.index_name)
diff --git a/align_data/pinecone/text_splitter.py b/align_data/pinecone/text_splitter.py
index 76bb29b8..c732c99c 100644
--- a/align_data/pinecone/text_splitter.py
+++ b/align_data/pinecone/text_splitter.py
@@ -7,29 +7,33 @@
class ParagraphSentenceUnitTextSplitter(TextSplitter):
"""A custom TextSplitter that breaks text by paragraphs, sentences, and then units (chars/words/tokens/etc).
-
+
@param min_chunk_size: The minimum number of units in a chunk.
@param max_chunk_size: The maximum number of units in a chunk.
@param length_function: A function that returns the length of a string in units.
@param truncate_function: A function that truncates a string to a given unit length.
"""
-
+
DEFAULT_MIN_CHUNK_SIZE = 900
DEFAULT_MAX_CHUNK_SIZE = 1100
DEFAULT_LENGTH_FUNCTION = lambda string: len(string)
- DEFAULT_TRUNCATE_FUNCTION = lambda string, length, from_end=False: string[-length:] if from_end else string[:length]
+ DEFAULT_TRUNCATE_FUNCTION = (
+ lambda string, length, from_end=False: string[-length:]
+ if from_end
+ else string[:length]
+ )
def __init__(
- self,
+ self,
min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE,
max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE,
length_function: Callable[[str], int] = DEFAULT_LENGTH_FUNCTION,
truncate_function: Callable[[str, int], str] = DEFAULT_TRUNCATE_FUNCTION,
- **kwargs: Any
+ **kwargs: Any,
):
super().__init__(**kwargs)
self.min_chunk_size = min_chunk_size
- self.max_chunk_size = max_chunk_size
+ self.max_chunk_size = max_chunk_size
self._length_function = length_function
self._truncate_function = truncate_function
@@ -43,26 +47,30 @@ def split_text(self, text: str) -> List[str]:
current_block += "\n\n" + paragraph
block_length = self._length_function(current_block)
- if block_length > self.max_chunk_size: # current block is too large, truncate it
- current_block = self._handle_large_paragraph(current_block, blocks, paragraph)
+ if (
+ block_length > self.max_chunk_size
+ ): # current block is too large, truncate it
+ current_block = self._handle_large_paragraph(
+ current_block, blocks, paragraph
+ )
elif block_length >= self.min_chunk_size:
blocks.append(current_block)
current_block = ""
else: # current block is too small, continue appending to it
continue
-
+
blocks = self._handle_remaining_text(current_block, blocks)
return [block.strip() for block in blocks]
def _handle_large_paragraph(self, current_block, blocks, paragraph):
# Undo adding the whole paragraph
- current_block = current_block[:-(len(paragraph)+2)] # +2 accounts for "\n\n"
+ current_block = current_block[: -(len(paragraph) + 2)] # +2 accounts for "\n\n"
sentences = sent_tokenize(paragraph)
for sentence in sentences:
current_block += f" {sentence}"
-
+
block_length = self._length_function(current_block)
if block_length < self.min_chunk_size:
continue
@@ -70,19 +78,23 @@ def _handle_large_paragraph(self, current_block, blocks, paragraph):
blocks.append(current_block)
current_block = ""
else:
- current_block = self._truncate_large_block(current_block, blocks, sentence)
-
+ current_block = self._truncate_large_block(
+ current_block, blocks, sentence
+ )
+
return current_block
def _truncate_large_block(self, current_block, blocks, sentence):
while self._length_function(current_block) > self.max_chunk_size:
# Truncate current_block to max size, set remaining sentence as next sentence
- truncated_block = self._truncate_function(current_block, self.max_chunk_size)
+ truncated_block = self._truncate_function(
+ current_block, self.max_chunk_size
+ )
blocks.append(truncated_block)
- remaining_sentence = current_block[len(truncated_block):].lstrip()
+ remaining_sentence = current_block[len(truncated_block) :].lstrip()
current_block = sentence = remaining_sentence
-
+
return current_block
def _handle_remaining_text(self, current_block, blocks):
@@ -93,13 +105,17 @@ def _handle_remaining_text(self, current_block, blocks):
if len_current_block < self.min_chunk_size:
# it needs to take the last min_chunk_size-len_current_block units from the previous block
previous_block = blocks[-1]
- required_units = self.min_chunk_size - len_current_block # calculate the required units
+ required_units = (
+ self.min_chunk_size - len_current_block
+ ) # calculate the required units
- part_prev_block = self._truncate_function(previous_block, required_units, from_end=True) # get the required units from the previous block
+ part_prev_block = self._truncate_function(
+ previous_block, required_units, from_end=True
+ ) # get the required units from the previous block
last_block = part_prev_block + current_block
blocks.append(last_block)
else:
blocks.append(current_block)
- return blocks
\ No newline at end of file
+ return blocks
diff --git a/align_data/pinecone/update_pinecone.py b/align_data/pinecone/update_pinecone.py
index 9e52276d..649a7a2a 100644
--- a/align_data/pinecone/update_pinecone.py
+++ b/align_data/pinecone/update_pinecone.py
@@ -11,10 +11,17 @@
from align_data.db.session import make_session, stream_pinecone_updates
from align_data.pinecone.pinecone_db_handler import PineconeDB
from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter
-from align_data.settings import USE_OPENAI_EMBEDDINGS, OPENAI_EMBEDDINGS_MODEL, \
- OPENAI_EMBEDDINGS_DIMS, OPENAI_EMBEDDINGS_RATE_LIMIT, \
- SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS, \
- CHUNK_SIZE, MAX_NUM_AUTHORS_IN_SIGNATURE, EMBEDDING_LENGTH_BIAS
+from align_data.settings import (
+ USE_OPENAI_EMBEDDINGS,
+ OPENAI_EMBEDDINGS_MODEL,
+ OPENAI_EMBEDDINGS_DIMS,
+ OPENAI_EMBEDDINGS_RATE_LIMIT,
+ SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
+ SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS,
+ CHUNK_SIZE,
+ MAX_NUM_AUTHORS_IN_SIGNATURE,
+ EMBEDDING_LENGTH_BIAS,
+)
logger = logging.getLogger(__name__)
@@ -29,54 +36,69 @@ class PineconeEntry(BaseModel):
authors: List[str]
text_chunks: List[str]
embeddings: np.ndarray
-
+
class Config:
arbitrary_types_allowed = True
def __repr__(self):
return f"PineconeEntry(id={self.id!r}, source={self.source!r}, title={self.title!r}, url={self.url!r}, date_published={self.date_published!r}, authors={self.authors!r}, text_chunks={self.text_chunks[:5]!r})"
- @validator('id', 'source', 'title', 'url', 'date_published', 'authors', 'text_chunks', pre=True, always=True)
+ @validator(
+ "id",
+ "source",
+ "title",
+ "url",
+ "date_published",
+ "authors",
+ "text_chunks",
+ pre=True,
+ always=True,
+ )
def empty_strings_not_allowed(cls, value):
if not str(value).strip():
raise ValueError("Attribute should not be empty.")
return value
-
+
class PineconeUpdater:
def __init__(
- self,
+ self,
min_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MIN_CHUNK_SIZE,
max_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MAX_CHUNK_SIZE,
- length_function: Callable[[str], int] = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION,
- truncate_function: Callable[[str, int], str] = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION,
+ length_function: Callable[
+ [str], int
+ ] = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION,
+ truncate_function: Callable[
+ [str, int], str
+ ] = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION,
):
self.min_chunk_size = min_chunk_size
self.max_chunk_size = max_chunk_size
self.length_function = length_function
self.truncate_function = truncate_function
-
+
self.text_splitter = ParagraphSentenceUnitTextSplitter(
min_chunk_size=self.min_chunk_size,
max_chunk_size=self.max_chunk_size,
length_function=self.length_function,
- truncate_function=self.truncate_function
+ truncate_function=self.truncate_function,
)
self.pinecone_db = PineconeDB()
-
+
if USE_OPENAI_EMBEDDINGS:
import openai
- openai.api_key = os.environ['OPENAI_API_KEY']
+
+ openai.api_key = os.environ["OPENAI_API_KEY"]
else:
import torch
from langchain.embeddings import HuggingFaceEmbeddings
-
+
self.hf_embeddings = HuggingFaceEmbeddings(
model_name=SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
- model_kwargs={'device': "cuda" if torch.cuda.is_available() else "cpu"},
- encode_kwargs={'show_progress_bar': False}
+ model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
+ encode_kwargs={"show_progress_bar": False},
)
-
+
def update(self, custom_sources: List[str]):
"""
Update the given sources. If no sources are provided, updates all sources.
@@ -90,8 +112,10 @@ def update(self, custom_sources: List[str]):
article.pinecone_update_required = False
session.add(article)
session.commit()
-
- def process_entries(self, article_stream: Generator[Article, None, None]) -> Generator[Tuple[Article, PineconeEntry], None, None]:
+
+ def process_entries(
+ self, article_stream: Generator[Article, None, None]
+ ) -> Generator[Tuple[Article, PineconeEntry], None, None]:
for article in article_stream:
try:
text_chunks = self.get_text_chunks(article)
@@ -101,45 +125,54 @@ def process_entries(self, article_stream: Generator[Article, None, None]) -> Gen
title=article.title,
url=article.url,
date_published=article.date_published,
- authors=[author.strip() for author in article.authors.split(',') if author.strip()],
+ authors=[
+ author.strip()
+ for author in article.authors.split(",")
+ if author.strip()
+ ],
text_chunks=text_chunks,
- embeddings=self.extract_embeddings(text_chunks, [article.source] * len(text_chunks))
+ embeddings=self.extract_embeddings(
+ text_chunks, [article.source] * len(text_chunks)
+ ),
)
except (ValueError, ValidationError) as e:
logger.exception(e)
-
+
def get_text_chunks(self, article: Article) -> List[str]:
signature = f"Title: {article.title}, Author(s): {self.get_authors_str(article.authors)}"
text_chunks = self.text_splitter.split_text(article.text)
text_chunks = [f"- {signature}\n\n{text_chunk}" for text_chunk in text_chunks]
return text_chunks
-
+
def extract_embeddings(self, chunks_batch, sources_batch):
if USE_OPENAI_EMBEDDINGS:
return self.get_openai_embeddings(chunks_batch, sources_batch)
else:
- return np.array(self.hf_embeddings.embed_documents(chunks_batch, sources_batch))
+ return np.array(
+ self.hf_embeddings.embed_documents(chunks_batch, sources_batch)
+ )
@staticmethod
- def get_openai_embeddings(chunks, sources=''):
+ def get_openai_embeddings(chunks, sources=""):
embeddings = np.zeros((len(chunks), OPENAI_EMBEDDINGS_DIMS))
-
+
openai_output = openai.Embedding.create(
- model=OPENAI_EMBEDDINGS_MODEL,
- input=chunks
- )['data']
-
+ model=OPENAI_EMBEDDINGS_MODEL, input=chunks
+ )["data"]
+
for i, (embedding, source) in enumerate(zip(openai_output, sources)):
bias = EMBEDDING_LENGTH_BIAS.get(source, 1.0)
- embeddings[i] = bias * np.array(embedding['embedding'])
-
+ embeddings[i] = bias * np.array(embedding["embedding"])
+
return embeddings
@staticmethod
def get_authors_str(authors_lst: List[str]) -> str:
- if authors_lst == []: return 'n/a'
- if len(authors_lst) == 1: return authors_lst[0]
+ if authors_lst == []:
+ return "n/a"
+ if len(authors_lst) == 1:
+ return authors_lst[0]
else:
authors_lst = authors_lst[:MAX_NUM_AUTHORS_IN_SIGNATURE]
authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}"
- return authors_str
\ No newline at end of file
+ return authors_str
diff --git a/align_data/postprocess/postprocess.py b/align_data/postprocess/postprocess.py
index 3f26885e..1366bc11 100644
--- a/align_data/postprocess/postprocess.py
+++ b/align_data/postprocess/postprocess.py
@@ -1,4 +1,4 @@
-#%%
+# %%
from dataclasses import dataclass
import jsonlines
from tqdm import tqdm
@@ -6,54 +6,69 @@
from path import Path
import pylab as plt
+
# import seaborn as sns
import pandas as pd
logger = logging.getLogger(__name__)
+
@dataclass
class PostProcesser:
"""
This class is used to postprocess the data
"""
- jsonl_path : Path = Path('../../data/')
+
+ jsonl_path: Path = Path("../../data/")
def __init__(self) -> None:
- self.jsonl_list = sorted(self.jsonl_path.files('*.jsonl'))
- self.source_list = [path.name.split('.jsonl')[0] for path in self.jsonl_list]
+ self.jsonl_list = sorted(self.jsonl_path.files("*.jsonl"))
+ self.source_list = [path.name.split(".jsonl")[0] for path in self.jsonl_list]
def compute_statistics(self) -> None:
- self.all_stats = {key : {} for key in self.source_list}
- for source_name , path in tqdm(zip(self.source_list , self.jsonl_list)):
+ self.all_stats = {key: {} for key in self.source_list}
+ for source_name, path in tqdm(zip(self.source_list, self.jsonl_list)):
with jsonlines.open(path) as reader:
for obj in reader:
- self.all_stats[source_name]['num_entries'] = self.all_stats[source_name].get('num_entries' , 0) + 1
- self.all_stats[source_name]['num_tokens'] = self.all_stats[source_name].get('num_tokens' , 0) + len(obj['text'].split())
- self.all_stats[source_name]['num_chars'] = self.all_stats[source_name].get('num_chars' , 0) + len(obj['text'])
- self.all_stats[source_name]['num_words'] = self.all_stats[source_name].get('num_words' , 0) + len(obj['text'].split())
- self.all_stats[source_name]['num_sentences'] = self.all_stats[source_name].get('num_sentences' , 0) + len(obj['text'].split('.'))
- self.all_stats[source_name]['num_paragraphs'] = self.all_stats[source_name].get('num_paragraphs' , 0) + len(obj['text'].splitlines())
-
+ self.all_stats[source_name]["num_entries"] = (
+ self.all_stats[source_name].get("num_entries", 0) + 1
+ )
+ self.all_stats[source_name]["num_tokens"] = self.all_stats[
+ source_name
+ ].get("num_tokens", 0) + len(obj["text"].split())
+ self.all_stats[source_name]["num_chars"] = self.all_stats[
+ source_name
+ ].get("num_chars", 0) + len(obj["text"])
+ self.all_stats[source_name]["num_words"] = self.all_stats[
+ source_name
+ ].get("num_words", 0) + len(obj["text"].split())
+ self.all_stats[source_name]["num_sentences"] = self.all_stats[
+ source_name
+ ].get("num_sentences", 0) + len(obj["text"].split("."))
+ self.all_stats[source_name]["num_paragraphs"] = self.all_stats[
+ source_name
+ ].get("num_paragraphs", 0) + len(obj["text"].splitlines())
+
def plot_statistics(self) -> None:
all_df = pd.DataFrame(self.all_stats).T
- plt.figure(figsize = (5 , 5))
- sns.barplot(x = all_df.index , y = all_df['num_entries'])
-
+ plt.figure(figsize=(5, 5))
+        sns.barplot(x=all_df.index, y=all_df["num_entries"])
+
-    def merge_all_files(self , out_dir : str) -> str:
+ def merge_all_files(self, out_dir: str) -> str:
pass
def deduplicate(self) -> None:
for path in tqdm(self.jsonl_list):
- with jsonlines.open(path , 'r') as reader:
- all_obj = {obj['id'] : obj for obj in reader}
- with jsonlines.open(path , 'w') as writer:
+ with jsonlines.open(path, "r") as reader:
+ all_obj = {obj["id"]: obj for obj in reader}
+ with jsonlines.open(path, "w") as writer:
for obj in all_obj.values():
writer.write(obj)
- def clean_dataset(self , merged_dataset_path : str) -> str:
+ def clean_dataset(self, merged_dataset_path: str) -> str:
pass
+
pp = PostProcesser()
# %%
pp.source_list
diff --git a/align_data/settings.py b/align_data/settings.py
index 43d62e8a..1244861f 100644
--- a/align_data/settings.py
+++ b/align_data/settings.py
@@ -1,30 +1,35 @@
import os
from dotenv import load_dotenv
+
load_dotenv()
### CODA ###
CODA_TOKEN = os.environ.get("CODA_TOKEN")
CODA_DOC_ID = os.environ.get("CODA_DOC_ID", "fau7sl2hmG")
-ON_SITE_TABLE = os.environ.get('CODA_ON_SITE_TABLE', 'table-aOTSHIz_mN')
+ON_SITE_TABLE = os.environ.get("CODA_ON_SITE_TABLE", "table-aOTSHIz_mN")
### GOOGLE DRIVE ###
-PDFS_FOLDER_ID = os.environ.get('PDF_FOLDER_ID', '1etWiXPRl0QqdgYzivVIj6wCv9xj5VYoN')
+PDFS_FOLDER_ID = os.environ.get("PDF_FOLDER_ID", "1etWiXPRl0QqdgYzivVIj6wCv9xj5VYoN")
### GOOGLE SHEETS ###
-METADATA_SOURCE_SPREADSHEET = os.environ.get('METADATA_SOURCE_SPREADSHEET', '1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI')
-METADATA_SOURCE_SHEET = os.environ.get('METADATA_SOURCE_SHEET', 'special_docs.csv')
-METADATA_OUTPUT_SPREADSHEET = os.environ.get('METADATA_OUTPUT_SPREADSHEET', '1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4')
+METADATA_SOURCE_SPREADSHEET = os.environ.get(
+ "METADATA_SOURCE_SPREADSHEET", "1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI"
+)
+METADATA_SOURCE_SHEET = os.environ.get("METADATA_SOURCE_SHEET", "special_docs.csv")
+METADATA_OUTPUT_SPREADSHEET = os.environ.get(
+ "METADATA_OUTPUT_SPREADSHEET", "1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4"
+)
### YouTube ###
-YOUTUBE_API_KEY = os.environ.get('YOUTUBE_API_KEY')
+YOUTUBE_API_KEY = os.environ.get("YOUTUBE_API_KEY")
### MYSQL ###
-user = os.environ.get('ARD_DB_USER', 'user')
-password = os.environ.get('ARD_DB_PASSWORD', 'we all live in a yellow submarine')
-host = os.environ.get('ARD_DB_HOST', '127.0.0.1')
-port = os.environ.get('ARD_DB_PORT', '3306')
-db_name = os.environ.get('ARD_DB_NAME', 'alignment_research_dataset')
-DB_CONNECTION_URI = f'mysql+mysqldb://{user}:{password}@{host}:{port}/{db_name}'
+user = os.environ.get("ARD_DB_USER", "user")
+password = os.environ.get("ARD_DB_PASSWORD", "we all live in a yellow submarine")
+host = os.environ.get("ARD_DB_HOST", "127.0.0.1")
+port = os.environ.get("ARD_DB_PORT", "3306")
+db_name = os.environ.get("ARD_DB_NAME", "alignment_research_dataset")
+DB_CONNECTION_URI = f"mysql+mysqldb://{user}:{password}@{host}:{port}/{db_name}"
### EMBEDDINGS ###
USE_OPENAI_EMBEDDINGS = True # If false, SentenceTransformer embeddings will be used.
@@ -36,14 +41,20 @@
OPENAI_EMBEDDINGS_DIMS = 1536
OPENAI_EMBEDDINGS_RATE_LIMIT = 3500
-SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1"
+SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = (
+ "sentence-transformers/multi-qa-mpnet-base-cos-v1"
+)
SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768
### PINECONE ###
PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME", "stampy-chat-ard")
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
-PINECONE_VALUES_DIMS = OPENAI_EMBEDDINGS_DIMS if USE_OPENAI_EMBEDDINGS else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS
+PINECONE_VALUES_DIMS = (
+ OPENAI_EMBEDDINGS_DIMS
+ if USE_OPENAI_EMBEDDINGS
+ else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS
+)
PINECONE_METRIC = "dotproduct"
PINECONE_METADATA_KEYS = ["entry_id", "source", "title", "authors", "text", "url"]
diff --git a/align_data/sources/alignment_newsletter/__init__.py b/align_data/sources/alignment_newsletter/__init__.py
index 770eb5d5..e2d65fb1 100644
--- a/align_data/sources/alignment_newsletter/__init__.py
+++ b/align_data/sources/alignment_newsletter/__init__.py
@@ -2,7 +2,7 @@
import os
ALIGNMENT_NEWSLETTER_REGISTRY = [
- AlignmentNewsletter(
- name = "alignment_newsletter" , id_fields=['url', 'title', 'source']
- ),
-]
\ No newline at end of file
+ AlignmentNewsletter(
+ name="alignment_newsletter", id_fields=["url", "title", "source"]
+ ),
+]
diff --git a/align_data/sources/alignment_newsletter/alignment_newsletter.py b/align_data/sources/alignment_newsletter/alignment_newsletter.py
index fa613640..2b68e32f 100644
--- a/align_data/sources/alignment_newsletter/alignment_newsletter.py
+++ b/align_data/sources/alignment_newsletter/alignment_newsletter.py
@@ -12,12 +12,11 @@
@dataclass
class AlignmentNewsletter(SummaryDataset):
-
done_key = "url"
- def __post_init__(self, data_path=Path(__file__).parent / '../../../data/'):
+ def __post_init__(self, data_path=Path(__file__).parent / "../../../data/"):
self.data_path = data_path
- self.raw_data_path = self.data_path / 'raw'
+ self.raw_data_path = self.data_path / "raw"
def setup(self) -> None:
super().setup()
@@ -54,26 +53,30 @@ def process_entry(self, row):
def handle_na(v, cast=None):
if not self.maybe(v):
- return ''
+ return ""
if cast:
return cast(v)
return v
- return self.make_data_entry({
- "url": handle_na(row.URL),
- "source": handle_na(self.name),
- "converted_with": "python",
- "source_type": "google-sheets",
- "venue": handle_na(row.Venue, str), # arXiv, Distill, LessWrong, Alignment Forum, ICML 2018, etc
- "newsletter_category": handle_na(row.Category, str),
- "highlight": row[2] == "Highlight",
- "newsletter_number": handle_na(row.Email, str),
- "summarizer": handle_na(row.Summarizer, str),
- "opinion": handle_na(row[11], str),
- "prerequisites": handle_na(row.Prerequisites, str),
- "read_more": handle_na(row[13], str),
- "title": handle_na(row.Title, str),
- "authors": [i.strip() for i in str(row.Authors).split(',')],
- "date_published": self._get_published_date(row.Year),
- "summary": handle_na(row.Summary, str),
- })
+ return self.make_data_entry(
+ {
+ "url": handle_na(row.URL),
+ "source": handle_na(self.name),
+ "converted_with": "python",
+ "source_type": "google-sheets",
+ "venue": handle_na(
+ row.Venue, str
+ ), # arXiv, Distill, LessWrong, Alignment Forum, ICML 2018, etc
+ "newsletter_category": handle_na(row.Category, str),
+ "highlight": row[2] == "Highlight",
+ "newsletter_number": handle_na(row.Email, str),
+ "summarizer": handle_na(row.Summarizer, str),
+ "opinion": handle_na(row[11], str),
+ "prerequisites": handle_na(row.Prerequisites, str),
+ "read_more": handle_na(row[13], str),
+ "title": handle_na(row.Title, str),
+ "authors": [i.strip() for i in str(row.Authors).split(",")],
+ "date_published": self._get_published_date(row.Year),
+ "summary": handle_na(row.Summary, str),
+ }
+ )
diff --git a/align_data/sources/arbital/__init__.py b/align_data/sources/arbital/__init__.py
index ad077e15..fda6c3db 100644
--- a/align_data/sources/arbital/__init__.py
+++ b/align_data/sources/arbital/__init__.py
@@ -1,4 +1,4 @@
from .arbital import Arbital
-ARBITAL_REGISTRY = [Arbital(name='arbital')]
+ARBITAL_REGISTRY = [Arbital(name="arbital")]
diff --git a/align_data/sources/arbital/arbital.py b/align_data/sources/arbital/arbital.py
index 4f147e12..40ce16c3 100644
--- a/align_data/sources/arbital/arbital.py
+++ b/align_data/sources/arbital/arbital.py
@@ -11,13 +11,13 @@
def parse_arbital_link(contents):
- text = contents[1].split(' ')
- url = f'https://arbital.com/p/{text[0]}'
+ text = contents[1].split(" ")
+ url = f"https://arbital.com/p/{text[0]}"
if len(text) > 1:
- title = ' '.join(text[1:])
+ title = " ".join(text[1:])
else:
title = url
- return f'[{title}]({url})'
+ return f"[{title}]({url})"
def flatten(val):
@@ -45,73 +45,78 @@ def markdownify_text(current, view):
in_link = False
for part, next_part in view:
- if part == '[':
+ if part == "[":
# Recursively try to parse this new section - it's probably a link, but can be something else
current.append(markdownify_text([part], view))
- elif part == ']' and next_part == '(':
+ elif part == "]" and next_part == "(":
# mark that it's now in the url part of a markdown link
- current.append(']')
+ current.append("]")
in_link = True
- elif part == ']':
+ elif part == "]":
# this is the arbital summary - just join it for now, but it'll have to be handled later
- if current[1].startswith('summary'):
- return ''.join(current[1:])
+ if current[1].startswith("summary"):
+ return "".join(current[1:])
# if this was a TODO section, then ignore it
- if current[1].startswith('todo'):
- return ''
+ if current[1].startswith("todo"):
+ return ""
# Otherwise it's an arbital link
return parse_arbital_link(current)
- elif in_link and part == ')':
+ elif in_link and part == ")":
# this is the end of a markdown link - just join the contents, as they're already correct
- return ''.join(current + [part])
- elif in_link and current[-1] == '(' and next_part != ')':
+ return "".join(current + [part])
+ elif in_link and current[-1] == "(" and next_part != ")":
# This link is strange... looks like it could be malformed?
# Assuming that it's malformed and missing a closing `)`
# This will remove any additional info in the link, but that seems a reasonable price?
- words = part.split(' ')
- return ''.join(current + [words[0], ') ', ' '.join(words[1:])])
+ words = part.split(" ")
+ return "".join(current + [words[0], ") ", " ".join(words[1:])])
else:
# Just your basic text - add it to the processed parts and go on your merry way
current.append(part)
# Check if the first item is the summary - if so, extract it
- summary = ''
- if current[0].startswith('summary'):
- _, summary = re.split(r'summary[()\w]*:', current[0], 1)
+ summary = ""
+ if current[0].startswith("summary"):
+ _, summary = re.split(r"summary[()\w]*:", current[0], 1)
current = current[1:]
# Otherwise just join all the parts back together
- return summary.strip(), ''.join(flatten(current)).strip()
+ return summary.strip(), "".join(flatten(current)).strip()
def extract_text(text):
- parts = [i for i in re.split('([\[\]()])', text) if i]
+    parts = [i for i in re.split(r"([\[\]()])", text) if i]
return markdownify_text([], zip(parts, parts[1:] + [None]))
+
@dataclass
class Arbital(AlignmentDataset):
- summary_key: str = 'summary'
+ summary_key: str = "summary"
- ARBITAL_SUBSPACES = ['ai_alignment', 'math', 'rationality']
+ ARBITAL_SUBSPACES = ["ai_alignment", "math", "rationality"]
done_key = "alias"
headers = {
- 'authority': 'arbital.com',
- 'accept': 'application/json, text/plain, */*',
- 'content-type': 'application/json;charset=UTF-8',
- 'sec-ch-ua-mobile': '?0',
- 'origin': 'https://arbital.com',
- 'sec-fetch-site': 'same-origin',
- 'sec-fetch-mode': 'cors',
- 'sec-fetch-dest': 'empty',
- 'accept-language': 'en-US,en;q=0.9',
+ "authority": "arbital.com",
+ "accept": "application/json, text/plain, */*",
+ "content-type": "application/json;charset=UTF-8",
+ "sec-ch-ua-mobile": "?0",
+ "origin": "https://arbital.com",
+ "sec-fetch-site": "same-origin",
+ "sec-fetch-mode": "cors",
+ "sec-fetch-dest": "empty",
+ "accept-language": "en-US,en;q=0.9",
}
titles_map = {}
@property
def items_list(self):
- logger.info('Getting page aliases')
- items = [alias for subspace in self.ARBITAL_SUBSPACES for alias in self.get_arbital_page_aliases(subspace)]
- logger.info('Got %s page aliases', len(items))
+ logger.info("Getting page aliases")
+ items = [
+ alias
+ for subspace in self.ARBITAL_SUBSPACES
+ for alias in self.get_arbital_page_aliases(subspace)
+ ]
+ logger.info("Got %s page aliases", len(items))
return items
def get_item_key(self, item):
@@ -120,44 +125,50 @@ def get_item_key(self, item):
def process_entry(self, alias):
try:
page = self.get_page(alias)
- summary, text = extract_text(page['text'])
-
- return self.make_data_entry({
- 'title': page.get('title') or '',
- 'text': text,
- 'date_published': self._get_published_date(page),
- 'url': f'https://arbital.com/p/{page.get("alias") or alias}',
- 'source': self.name,
- 'source_type': 'text',
- 'authors': self.extract_authors(page),
- 'alias': alias,
- 'tags': list(filter(None, map(self.get_title, page['tagIds']))),
- 'summary': summary,
- })
+ summary, text = extract_text(page["text"])
+
+ return self.make_data_entry(
+ {
+ "title": page.get("title") or "",
+ "text": text,
+ "date_published": self._get_published_date(page),
+ "url": f'https://arbital.com/p/{page.get("alias") or alias}',
+ "source": self.name,
+ "source_type": "text",
+ "authors": self.extract_authors(page),
+ "alias": alias,
+ "tags": list(filter(None, map(self.get_title, page["tagIds"]))),
+ "summary": summary,
+ }
+ )
except Exception as e:
logger.error(f"Error getting page {alias}: {e}")
return None
def get_arbital_page_aliases(self, subspace):
headers = self.headers.copy()
- headers['referer'] = f'https://arbital.com/explore/{subspace}/'
+ headers["referer"] = f"https://arbital.com/explore/{subspace}/"
data = f'{{"pageAlias":"{subspace}"}}'
- response = requests.post('https://arbital.com/json/explore/', headers=headers, data=data).json()
- return list(response['pages'].keys())
+ response = requests.post(
+ "https://arbital.com/json/explore/", headers=headers, data=data
+ ).json()
+ return list(response["pages"].keys())
@staticmethod
def _get_published_date(page):
- date_published = page.get('editCreatedAt') or page.get('pageCreatedAt')
+ date_published = page.get("editCreatedAt") or page.get("pageCreatedAt")
if date_published:
return parse(date_published).astimezone(timezone.utc)
return None
def get_page(self, alias):
headers = self.headers.copy()
- headers['referer'] = 'https://arbital.com/'
+ headers["referer"] = "https://arbital.com/"
data = f'{{"pageAlias":"{alias}"}}'
- response = requests.post('https://arbital.com/json/primaryPage/', headers=headers, data=data)
- return response.json()['pages'][alias]
+ response = requests.post(
+ "https://arbital.com/json/primaryPage/", headers=headers, data=data
+ )
+ return response.json()["pages"][alias]
def get_title(self, itemId):
if title := self.titles_map.get(itemId):
@@ -170,7 +181,7 @@ def get_title(self, itemId):
logger.error(e)
return None
- if title := page.get('title'):
+ if title := page.get("title"):
self.titles_map[itemId] = title
return title
return None
@@ -180,6 +191,6 @@ def extract_authors(self, page):
This will work faster the more its used, as it only fetches info for authors it hasn't yet seen.
"""
- authors = {c['userId'] for c in page.get('changeLogs', [])}
+ authors = {c["userId"] for c in page.get("changeLogs", [])}
return list(filter(None, map(self.get_title, authors)))
diff --git a/align_data/sources/articles/__init__.py b/align_data/sources/articles/__init__.py
index 6775e496..01c5521f 100644
--- a/align_data/sources/articles/__init__.py
+++ b/align_data/sources/articles/__init__.py
@@ -1,36 +1,41 @@
from align_data.sources.articles.datasets import (
- EbookArticles, DocArticles, HTMLArticles, MarkdownArticles, PDFArticles, XMLArticles
+ EbookArticles,
+ DocArticles,
+ HTMLArticles,
+ MarkdownArticles,
+ PDFArticles,
+ XMLArticles,
)
ARTICLES_REGISTRY = [
PDFArticles(
- name='pdfs',
- spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4',
- sheet_id='0'
+ name="pdfs",
+ spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4",
+ sheet_id="0",
),
HTMLArticles(
- name='html_articles',
- spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4',
- sheet_id='759210636'
+ name="html_articles",
+ spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4",
+ sheet_id="759210636",
),
EbookArticles(
- name='ebooks',
- spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4',
- sheet_id='1800487220'
+ name="ebooks",
+ spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4",
+ sheet_id="1800487220",
),
XMLArticles(
- name='xmls',
- spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4',
- sheet_id='823056509'
+ name="xmls",
+ spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4",
+ sheet_id="823056509",
),
MarkdownArticles(
- name='markdown',
- spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4',
- sheet_id='1003473759'
+ name="markdown",
+ spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4",
+ sheet_id="1003473759",
),
DocArticles(
- name='gdocs',
- spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4',
- sheet_id='1293295703'
+ name="gdocs",
+ spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4",
+ sheet_id="1293295703",
),
]
diff --git a/align_data/sources/articles/articles.py b/align_data/sources/articles/articles.py
index c7f82e65..941de037 100644
--- a/align_data/sources/articles/articles.py
+++ b/align_data/sources/articles/articles.py
@@ -4,7 +4,14 @@
from tqdm import tqdm
import gspread
-from align_data.sources.articles.google_cloud import iterate_rows, get_spreadsheet, get_sheet, upload_file, OK, with_retry
+from align_data.sources.articles.google_cloud import (
+ iterate_rows,
+ get_spreadsheet,
+ get_sheet,
+ upload_file,
+ OK,
+ with_retry,
+)
from align_data.sources.articles.parsers import item_metadata, fetch
from align_data.sources.articles.indices import fetch_all
from align_data.sources.articles.html import with_retry
@@ -15,8 +22,8 @@
# Careful changing these - the sheets assume this ordering
-REQUIRED_FIELDS = ['url', 'source_url', 'title', 'source_type', 'date_published']
-OPTIONAL_FIELDS = ['authors', 'summary']
+REQUIRED_FIELDS = ["url", "source_url", "title", "source_type", "date_published"]
+OPTIONAL_FIELDS = ["authors", "summary"]
def save_pdf(filename, link):
@@ -27,47 +34,47 @@ def save_pdf(filename, link):
:returns: the google drive id of the resulting pdf file
"""
res = fetch(link)
- if not filename.lower().endswith('.pdf'):
- filename += '.pdf'
+ if not filename.lower().endswith(".pdf"):
+ filename += ".pdf"
return upload_file(
filename,
bytes_contents=io.BytesIO(res.content),
- mimetype=res.headers.get('Content-Type'),
- parent_id=PDFS_FOLDER_ID
+ mimetype=res.headers.get("Content-Type"),
+ parent_id=PDFS_FOLDER_ID,
)
@with_retry(times=3, exceptions=gspread.exceptions.APIError)
def process_row(row, sheets):
"""Check the given `row` and fetch its metadata + optional extra stuff."""
- logger.info('Checking "%s"', row['title'])
+ logger.info('Checking "%s"', row["title"])
missing = [field for field in REQUIRED_FIELDS if not row.get(field)]
if missing:
- row.set_status('missing keys: ' + ', '.join(missing))
- logger.error('missing keys: ' + ', '.join(missing))
+ row.set_status("missing keys: " + ", ".join(missing))
+ logger.error("missing keys: " + ", ".join(missing))
return
- source_url = row.get('source_url')
+ source_url = row.get("source_url")
contents = item_metadata(source_url)
- if not contents or 'error' in contents:
- error = (contents and contents.get('error')) or 'text could not be fetched'
+ if not contents or "error" in contents:
+ error = (contents and contents.get("error")) or "text could not be fetched"
logger.error(error)
row.set_status(error)
return
- data_source = contents.get('data_source')
+ data_source = contents.get("data_source")
if data_source not in sheets:
- error = 'Unhandled data type'
+ error = "Unhandled data type"
logger.error(error)
row.set_status(error)
return
extra_fields = []
- if data_source == 'pdf':
- extra_fields = [save_pdf(row['title'], source_url)]
+ if data_source == "pdf":
+ extra_fields = [save_pdf(row["title"], source_url)]
sheets[data_source].append_row(
[row.get(field) for field in REQUIRED_FIELDS + OPTIONAL_FIELDS] + extra_fields
@@ -83,19 +90,19 @@ def process_spreadsheets(source_sheet, output_sheets):
:param Worksheet source_sheet: the worksheet to be processed - each row should be a separate entry
:param Dict[str, Worksheet] output_sheets: a dict of per data type worksheets to be updated
"""
- logger.info('fetching seen urls')
+ logger.info("fetching seen urls")
seen = {
url
for sheet in output_sheets.values()
for record in sheet.get_all_records()
- for url in [record.get('url'), record.get('source_url')]
+ for url in [record.get("url"), record.get("source_url")]
if url
}
for row in tqdm(iterate_rows(source_sheet)):
- if not row.get('source_url'):
- row['source_url'] = row['url']
- if row.get('source_url') in seen:
- title = row.get('title')
+ if not row.get("source_url"):
+ row["source_url"] = row["url"]
+ if row.get("source_url") in seen:
+ title = row.get("title")
logger.info(f'skipping "{title}", as it has already been seen')
else:
process_row(row, output_sheets)
@@ -104,31 +111,48 @@ def process_spreadsheets(source_sheet, output_sheets):
def update_new_items(source_spreadsheet, source_sheet, output_spreadsheet):
"""Go through all unprocessed items from the source worksheet, updating the appropriate metadata in the output one."""
source_sheet = get_sheet(source_spreadsheet, source_sheet)
- sheets = {sheet.title: sheet for sheet in get_spreadsheet(output_spreadsheet).worksheets()}
+ sheets = {
+ sheet.title: sheet for sheet in get_spreadsheet(output_spreadsheet).worksheets()
+ }
return process_spreadsheets(source_sheet, sheets)
def check_new_articles(source_spreadsheet, source_sheet):
"""Goes through the special indices looking for unseen articles."""
source_sheet = get_sheet(source_spreadsheet, source_sheet)
- current = {row.get('title'): row for row in iterate_rows(source_sheet)}
- seen_urls = {url for item in current.values() for url in [item.get('url'), item.get('source_url')] if url}
+ current = {row.get("title"): row for row in iterate_rows(source_sheet)}
+ seen_urls = {
+ url
+ for item in current.values()
+ for url in [item.get("url"), item.get("source_url")]
+ if url
+ }
indices_items = fetch_all()
missing = [
- item for title, item in indices_items.items()
- if title not in current and not {item.get('url'), item.get('source_url')} & seen_urls
+ item
+ for title, item in indices_items.items()
+ if title not in current
+ and not {item.get("url"), item.get("source_url")} & seen_urls
]
if not missing:
- logger.info('No new articles found')
+ logger.info("No new articles found")
return 0
- columns = ['status', 'source_url', 'url', 'title', 'date_published', 'authors', 'publication_title', 'source_type']
- res = source_sheet.append_rows([
- [item.get(col) for col in columns]
- for item in missing
- ])
- updated = res['updates']['updatedRows']
- logger.info('Added %s rows', updated)
+ columns = [
+ "status",
+ "source_url",
+ "url",
+ "title",
+ "date_published",
+ "authors",
+ "publication_title",
+ "source_type",
+ ]
+ res = source_sheet.append_rows(
+ [[item.get(col) for col in columns] for item in missing]
+ )
+ updated = res["updates"]["updatedRows"]
+ logger.info("Added %s rows", updated)
return updated
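
The sheets assume this ordering because process_row builds each appended row by iterating REQUIRED_FIELDS + OPTIONAL_FIELDS and tacking any extra fields on the end. A minimal sketch of that contract, assuming a FakeSheet stand-in for a gspread Worksheet (the stub, the example row values and the "<gdrive-id>" extra field are illustrative only, not part of the codebase):

REQUIRED_FIELDS = ["url", "source_url", "title", "source_type", "date_published"]
OPTIONAL_FIELDS = ["authors", "summary"]


class FakeSheet:
    """Hypothetical stand-in for a gspread Worksheet - only records appended rows."""

    def __init__(self):
        self.rows = []

    def append_row(self, row):
        self.rows.append(row)


row = {
    "url": "https://example.org/paper",
    "source_url": "https://example.org/paper.pdf",
    "title": "An example entry",
    "source_type": "pdf",
    "date_published": "2023-01-01T00:00:00Z",
    "authors": "A. Author, B. Author",
}

sheet = FakeSheet()
# Required fields first, then optional ones, then any extra fields
# (e.g. the google drive id of an uploaded pdf) - the ordering the output sheets rely on.
sheet.append_row(
    [row.get(field) for field in REQUIRED_FIELDS + OPTIONAL_FIELDS] + ["<gdrive-id>"]
)
print(sheet.rows[0])
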
diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py
index a6328f42..2f25a606 100644
--- a/align_data/sources/articles/datasets.py
+++ b/align_data/sources/articles/datasets.py
@@ -19,7 +19,6 @@
@dataclass
class SpreadsheetDataset(AlignmentDataset):
-
spreadsheet_id: str
sheet_id: str
done_key = "title"
@@ -40,9 +39,15 @@ def is_val(val):
@property
def items_list(self):
- logger.info(f'Fetching https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=CS&gid={self.sheet_id}')
- df = pd.read_csv(f'https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}')
- return (item for item in df.itertuples() if not pd.isna(self.get_item_key(item)))
+ logger.info(
+ f"Fetching https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=CS&gid={self.sheet_id}"
+ )
+ df = pd.read_csv(
+ f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}"
+ )
+ return (
+ item for item in df.itertuples() if not pd.isna(self.get_item_key(item))
+ )
def get_item_key(self, item):
return getattr(item, self.done_key)
@@ -55,30 +60,31 @@ def _get_text(item):
def extract_authors(item):
if not SpreadsheetDataset.maybe(item.authors):
return []
- return [author.strip() for author in item.authors.split(',') if author.strip()]
+ return [author.strip() for author in item.authors.split(",") if author.strip()]
def process_entry(self, item):
text = self._get_text(item)
if not text:
- logger.error('Could not get text for %s - skipping for now', item.title)
+ logger.error("Could not get text for %s - skipping for now", item.title)
return None
- return self.make_data_entry({
- 'text': markdownify(text).strip(),
- 'url': self.maybe(item.url),
- 'title': self.maybe(item.title),
- 'source': self.name,
- 'source_type': self.maybe(item.source_type),
- 'source_filetype': self.source_filetype,
- 'date_published': self._get_published_date(item.date_published),
- 'authors': self.extract_authors(item),
- 'summary': self.maybe(item.summary),
- })
+ return self.make_data_entry(
+ {
+ "text": markdownify(text).strip(),
+ "url": self.maybe(item.url),
+ "title": self.maybe(item.title),
+ "source": self.name,
+ "source_type": self.maybe(item.source_type),
+ "source_filetype": self.source_filetype,
+ "date_published": self._get_published_date(item.date_published),
+ "authors": self.extract_authors(item),
+ "summary": self.maybe(item.summary),
+ }
+ )
class PDFArticles(SpreadsheetDataset):
-
- source_filetype = 'pdf'
+ source_filetype = "pdf"
COOLDOWN = 1
batch_size = 1
@@ -87,28 +93,26 @@ def setup(self):
self.files_path.mkdir(exist_ok=True, parents=True)
def _get_text(self, item):
- url = f'https://drive.google.com/uc?id={item.file_id}'
+ url = f"https://drive.google.com/uc?id={item.file_id}"
- filename = self.files_path / f'{item.title}.pdf'
+ filename = self.files_path / f"{item.title}.pdf"
if download(output=str(filename), id=item.file_id):
return read_pdf(filename)
class HTMLArticles(SpreadsheetDataset):
-
- source_filetype = 'html'
+ source_filetype = "html"
@staticmethod
def _get_text(item):
- domain = urlparse(item.source_url).netloc.lstrip('www.')
+ domain = urlparse(item.source_url).netloc.lstrip("www.")
if parser := HTML_PARSERS.get(domain):
return parser(item.source_url)
class EbookArticles(SpreadsheetDataset):
-
- source_filetype = 'epub'
- COOLDOWN = 10 # Add a large cooldown, as google complains a lot
+ source_filetype = "epub"
+ COOLDOWN = 10 # Add a large cooldown, as google complains a lot
batch_size = 1
def setup(self):
@@ -116,44 +120,43 @@ def setup(self):
self.files_path.mkdir(exist_ok=True, parents=True)
def _get_text(self, item):
- file_id = item.source_url.split('/')[-2]
- filename = download(output=str(self.files_path / f'{item.title}.epub'), id=file_id)
- return convert_file(filename, "plain",'epub', extra_args=['--wrap=none'])
+ file_id = item.source_url.split("/")[-2]
+ filename = download(
+ output=str(self.files_path / f"{item.title}.epub"), id=file_id
+ )
+ return convert_file(filename, "plain", "epub", extra_args=["--wrap=none"])
class XMLArticles(SpreadsheetDataset):
-
- source_filetype = 'xml'
+ source_filetype = "xml"
def _get_text(self, item):
vals = extract_gdrive_contents(item.source_url)
- return vals['text']
+ return vals["text"]
class MarkdownArticles(SpreadsheetDataset):
-
- source_filetype = 'md'
+ source_filetype = "md"
def _get_text(self, item):
- file_id = item.source_url.split('/')[-2]
+ file_id = item.source_url.split("/")[-2]
vals = fetch_markdown(file_id)
- return vals['text']
+ return vals["text"]
class DocArticles(SpreadsheetDataset):
-
- source_filetype = 'docx'
+ source_filetype = "docx"
def setup(self):
super().setup()
self.files_path.mkdir(exist_ok=True, parents=True)
def _get_text(self, item):
- pandoc_path = Path('data/raw/pandoc/pandoc/')
+ pandoc_path = Path("data/raw/pandoc/pandoc/")
if pandoc_path.exists():
logger.info("Make sure pandoc is configured correctly.")
os.environ.setdefault("PYPANDOC_PANDOC", str(pandoc_path))
- file_id = item.source_url.split('/')[-2]
+ file_id = item.source_url.split("/")[-2]
file_name = fetch_file(file_id)
- return convert_file(file_name, "md", format='docx', extra_args=['--wrap=none'])
+ return convert_file(file_name, "md", format="docx", extra_args=["--wrap=none"])
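
Each SpreadsheetDataset subclass above only swaps out _get_text; the shared items_list just reads the sheet's CSV export and skips rows whose key column is empty. A standalone sketch of that pattern, assuming a placeholder spreadsheet id and gid (not the real sheets these datasets read from):

import pandas as pd

SPREADSHEET_ID = "example-spreadsheet-id"  # placeholder, not one of the real sheets
SHEET_GID = "0"  # placeholder gid
DONE_KEY = "title"


def items_list(spreadsheet_id, sheet_id, done_key=DONE_KEY):
    # Same CSV-export trick as SpreadsheetDataset.items_list above.
    url = (
        f"https://docs.google.com/spreadsheets/d/{spreadsheet_id}"
        f"/export?format=csv&gid={sheet_id}"
    )
    df = pd.read_csv(url)
    # Skip rows whose key column is empty - they can't be processed or deduplicated.
    return (row for row in df.itertuples() if not pd.isna(getattr(row, done_key)))


# for item in items_list(SPREADSHEET_ID, SHEET_GID):
#     print(item.title)
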
diff --git a/align_data/sources/articles/google_cloud.py b/align_data/sources/articles/google_cloud.py
index 36946e89..d0f8646a 100644
--- a/align_data/sources/articles/google_cloud.py
+++ b/align_data/sources/articles/google_cloud.py
@@ -13,17 +13,17 @@
SCOPES = [
- 'https://www.googleapis.com/auth/spreadsheets',
- 'https://www.googleapis.com/auth/drive'
+ "https://www.googleapis.com/auth/spreadsheets",
+ "https://www.googleapis.com/auth/drive",
]
-OK = 'ok'
-OUTPUT_SPREADSHEET_ID = '1bg-6vL-I82CBRkxvWQs1-Ao0nTvHyfn4yns5MdlbCmY'
-sheet_name = 'Sheet1'
+OK = "ok"
+OUTPUT_SPREADSHEET_ID = "1bg-6vL-I82CBRkxvWQs1-Ao0nTvHyfn4yns5MdlbCmY"
+sheet_name = "Sheet1"
-def get_credentials(credentials_file='credentials.json'):
+def get_credentials(credentials_file="credentials.json"):
return Credentials.from_service_account_file(credentials_file, scopes=SCOPES)
@@ -53,20 +53,20 @@ def update_value(self, col, value):
self.sheet.update_cell(self.row_id, self.columns.index(col) + 1, value)
def update_colour(self, col, colour):
- col_letter = chr(ord('A') + self.columns.index(col))
- self.sheet.format(f'{col_letter}{self.row_id}', {"backgroundColor": colour})
+ col_letter = chr(ord("A") + self.columns.index(col))
+ self.sheet.format(f"{col_letter}{self.row_id}", {"backgroundColor": colour})
- def set_status(self, status, status_col='status'):
+ def set_status(self, status, status_col="status"):
if self.get(status_col) == status:
# Don't update anything if the status is the same - this saves on gdocs calls
return
if status == OK:
- colour = {'red': 0, 'green': 1, 'blue': 0}
- elif status == '':
- colour = {'red': 1, 'green': 1, 'blue': 1}
+ colour = {"red": 0, "green": 1, "blue": 0}
+ elif status == "":
+ colour = {"red": 1, "green": 1, "blue": 1}
else:
- colour = {'red': 1, 'green': 0, 'blue': 0}
+ colour = {"red": 1, "green": 0, "blue": 0}
self.update_value(status_col, status)
self.update_colour(status_col, colour)
@@ -91,37 +91,41 @@ def upload_file(filename, bytes_contents, mimetype, parent_id=None):
"""
credentials = get_credentials()
- drive_service = build('drive', 'v3', credentials=credentials)
+ drive_service = build("drive", "v3", credentials=credentials)
- file_metadata = {
- 'name': filename,
- 'parents': parent_id and [parent_id]
- }
- media = drive_service.files().create(
- body=file_metadata,
- media_body=MediaIoBaseUpload(bytes_contents, mimetype=mimetype)
- ).execute()
- return media.get('id')
+ file_metadata = {"name": filename, "parents": parent_id and [parent_id]}
+ media = (
+ drive_service.files()
+ .create(
+ body=file_metadata,
+ media_body=MediaIoBaseUpload(bytes_contents, mimetype=mimetype),
+ )
+ .execute()
+ )
+ return media.get("id")
def with_retry(times=3):
"""A decorator that will retry the wrapped function up to `times` times in case of google sheets errors."""
+
def wrapper(f):
def retrier(*args, **kwargs):
for i in range(times):
try:
return f(*args, **kwargs)
except gspread.exceptions.APIError as e:
- logger.error(f'{e} - retrying up to {times - i} times')
+ logger.error(f"{e} - retrying up to {times - i} times")
# Back off quadratically between attempts
time.sleep((i + 1) ** 2)
- raise ValueError(f'Gave up after {times} tries')
+ raise ValueError(f"Gave up after {times} tries")
+
return retrier
+
return wrapper
def fetch_file(file_id):
- data_path = Path('data/raw/')
+ data_path = Path("data/raw/")
data_path.mkdir(parents=True, exist_ok=True)
file_name = data_path / file_id
return gdown.download(id=file_id, output=str(file_name), quiet=False)
@@ -131,8 +135,8 @@ def fetch_markdown(file_id):
try:
file_name = fetch_file(file_id)
return {
- 'text': Path(file_name).read_text(),
- 'data_source': 'markdown',
+ "text": Path(file_name).read_text(),
+ "data_source": "markdown",
}
except Exception as e:
- return {'error': str(e)}
+ return {"error": str(e)}
diff --git a/align_data/sources/articles/html.py b/align_data/sources/articles/html.py
index 152ea8dc..80b5dc8b 100644
--- a/align_data/sources/articles/html.py
+++ b/align_data/sources/articles/html.py
@@ -10,27 +10,30 @@
DEFAULT_HEADERS = {
- 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0',
+ "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0",
}
def with_retry(times=3, exceptions=requests.exceptions.RequestException):
"""A decorator that will retry the wrapped function up to `times` times in case of google sheets errors."""
+
def wrapper(f):
def retrier(*args, **kwargs):
for i in range(times):
try:
return f(*args, **kwargs)
except exceptions as e:
- logger.error(f'{e} - retrying up to {times - i} times')
+ logger.error(f"{e} - retrying up to {times - i} times")
# Back off quadratically between attempts
time.sleep((i + 1) ** 2)
- raise ValueError(f'Gave up after {times} tries')
+ raise ValueError(f"Gave up after {times} tries")
+
return retrier
+
return wrapper
-def fetch(url, method='get', headers=DEFAULT_HEADERS):
+def fetch(url, method="get", headers=DEFAULT_HEADERS):
"""Fetch the given `url`.
This function exists so that there is a single place to manage headers etc.
@@ -43,7 +46,7 @@ def fetch_element(url: str, selector: str, headers=DEFAULT_HEADERS) -> Union[Tag
try:
resp = fetch(url, headers=headers)
except requests.exceptions.ConnectionError:
- logger.error('Could not connect to %s', url)
+ logger.error("Could not connect to %s", url)
return None
soup = BeautifulSoup(resp.content, "html.parser")
@@ -57,6 +60,7 @@ def element_extractor(selector, remove=[]):
:param List[str] remove: An optional list of selectors to be removed from the resulting HTML. Useful for removing footers etc.
:returns: A function that expects to get a URL, and which will then return the contents of the selected HTML element as markdown.
"""
+
def getter(url):
elem = fetch_element(url, selector)
if not elem:
diff --git a/align_data/sources/articles/indices.py b/align_data/sources/articles/indices.py
index 6eb5761c..0cf2c45a 100644
--- a/align_data/sources/articles/indices.py
+++ b/align_data/sources/articles/indices.py
@@ -11,7 +11,7 @@
def get_text(tag, selector: str) -> str:
if item := tag.select_one(selector):
return item.text
- return ''
+ return ""
def indice_fetcher(url, main_selector, item_selector, formatter):
@@ -19,107 +19,114 @@ def fetcher():
if contents := fetch_element(url, main_selector):
return list(filter(None, map(formatter, contents.select(item_selector))))
return []
+
return fetcher
def reading_what_we_can_items():
- res = fetch('https://readingwhatwecan.com/books.js')
+ res = fetch("https://readingwhatwecan.com/books.js")
items = {
item
- for section in re.findall(r'\[(.*?)\]', res.text, re.DOTALL)
- for item in re.findall(r'Name: "(.*?)",.*?Link: "(.*?)",.*?Author: "(.*?)"', section, re.DOTALL)
+ for section in re.findall(r"\[(.*?)\]", res.text, re.DOTALL)
+ for item in re.findall(
+ r'Name: "(.*?)",.*?Link: "(.*?)",.*?Author: "(.*?)"', section, re.DOTALL
+ )
}
- return [{
- 'title': item[0],
- 'url': item[1],
- 'authors': item[2]
- } for item in items]
+ return [{"title": item[0], "url": item[1], "authors": item[2]} for item in items]
def aisafetysupport():
- contents = fetch_element('https://www.aisafetysupport.org/resources/lots-of-links', 'header + div')
- sections = ['Research Maps and Reviews', 'Research Agendas', 'Books, papers, podcasts, videos']
- sections = [s for s in contents.select('section') if get_text(s, 'h2') in sections]
+ contents = fetch_element(
+ "https://www.aisafetysupport.org/resources/lots-of-links", "header + div"
+ )
+ sections = [
+ "Research Maps and Reviews",
+ "Research Agendas",
+ "Books, papers, podcasts, videos",
+ ]
+ sections = [s for s in contents.select("section") if get_text(s, "h2") in sections]
return [
- {'title': a.text, 'url': a.get('href')}
+ {"title": a.text, "url": a.get("href")}
for section in sections
- for a in section.select('a')
- if a.text and a.get('href').startswith('http')
+ for a in section.select("a")
+ if a.text and a.get("href").startswith("http")
]
def format_mlsafety_course(a):
- if (a.get('href') or '').startswith('http'):
- return {'title': a.text, 'url': a.get('href')}
+ if (a.get("href") or "").startswith("http"):
+ return {"title": a.text, "url": a.get("href")}
def format_anthropic(post):
- if date_published := parse(get_text(post, 'div.post-date')):
+ if date_published := parse(get_text(post, "div.post-date")):
date_published = AlignmentDataset._format_datetime(date_published)
- url = post.get('href')
+ url = post.get("href")
- if source_url := fetch_element(url, 'article .post-heading a.btn-primary'):
- source_url = source_url.get('href')
+ if source_url := fetch_element(url, "article .post-heading a.btn-primary"):
+ source_url = source_url.get("href")
return {
- 'title': get_text(post, 'div.post-heading'),
- 'url': url,
- 'source_url': source_url,
- 'date_published': date_published,
+ "title": get_text(post, "div.post-heading"),
+ "url": url,
+ "source_url": source_url,
+ "date_published": date_published,
}
def format_transformer_circuits(item):
- if not item.get('href').startswith('http'):
+ if not item.get("href").startswith("http"):
url = f'https://transformer-circuits.pub/{item.get("href")}'
return {
- 'title': get_text(item, 'h3'),
- 'url': url,
- 'source_url': url,
+ "title": get_text(item, "h3"),
+ "url": url,
+ "source_url": url,
}
def format_safe_ai(item):
return {
- 'title': get_text(item, 'h4'),
- 'url': item.find('a').get('href'),
- 'source_url': item.find('a').get('href'),
- 'authors': get_text(item, 'h4 ~ p')
+ "title": get_text(item, "h4"),
+ "url": item.find("a").get("href"),
+ "source_url": item.find("a").get("href"),
+ "authors": get_text(item, "h4 ~ p"),
}
def format_far_ai(item):
return {
- 'title': get_text(item, '.article-title'),
- 'url': f'https://www.safe.ai/research{item.select_one(".article-title a").get("href")}',
- 'source_url': item.select_one('div.btn-links a:-soup-contains("PDF")').get('href'),
- 'authors': ', '.join(i.text for i in item.select('.article-metadata a')),
+ "title": get_text(item, ".article-title"),
+ "url": f'https://www.safe.ai/research{item.select_one(".article-title a").get("href")}',
+ "source_url": item.select_one('div.btn-links a:-soup-contains("PDF")').get(
+ "href"
+ ),
+ "authors": ", ".join(i.text for i in item.select(".article-metadata a")),
}
def format_redwoodresearch(item):
- url = item.select_one('.list-item-content__button-container a').get('href')
- authors = get_text(item, 'em')
+ url = item.select_one(".list-item-content__button-container a").get("href")
+ authors = get_text(item, "em")
try:
- parts = authors.split(', ')
+ parts = authors.split(", ")
date_published = parse(parts[-1])
date_published = AlignmentDataset._format_datetime(date_published)
- authors = ', '.join(parts[:-1])
+ authors = ", ".join(parts[:-1])
except ParserError:
date_published = None
return {
- 'title': get_text(item, 'h2'),
- 'url': url,
- 'source_url': url,
- 'authors': authors,
- 'date_published': date_published,
+ "title": get_text(item, "h2"),
+ "url": url,
+ "source_url": url,
+ "authors": authors,
+ "date_published": date_published,
}
def format_chai_research(item):
- author_block = next(item.children).strip().strip('.')
- authors = parts = author_block.split('.')
+ author_block = next(item.children).strip().strip(".")
+ authors = parts = author_block.split(".")
try:
int(parts[-1].strip())
date_published = parts[-1].strip()
@@ -127,47 +134,47 @@ def format_chai_research(item):
except ValueError:
date_published = None
- url = item.select_one('a').get('href')
+ url = item.select_one("a").get("href")
return {
- 'title': get_text(item, 'a'),
- 'url': url,
- 'source_url': url,
- 'authors': ', '.join(authors),
- 'date_published': date_published,
+ "title": get_text(item, "a"),
+ "url": url,
+ "source_url": url,
+ "authors": ", ".join(authors),
+ "date_published": date_published,
}
def format_chai_bibliography(item):
- return {
- 'title': get_text(item, '.bib-entry-title a'),
- 'url': item.select_one('.bib-entry-title a').get('href'),
- 'authors': item.select_one('.bib-entry-title a').next_sibling.strip(',. ')
- }
+ return {
+ "title": get_text(item, ".bib-entry-title a"),
+ "url": item.select_one(".bib-entry-title a").get("href"),
+ "authors": item.select_one(".bib-entry-title a").next_sibling.strip(",. "),
+ }
def format_chai_newsletter(item):
- if item.text.strip().startswith('CHAI Newsletter'):
+ if item.text.strip().startswith("CHAI Newsletter"):
return {
- 'title': item.text,
- 'url': item.get('href'),
- 'source_url': item.get('href'),
+ "title": item.text,
+ "url": item.get("href"),
+ "source_url": item.get("href"),
}
def format_neel_nanda_fav(item):
- url = item.find('a').get('href').strip()
- if not url.startswith('http'):
+ url = item.find("a").get("href").strip()
+ if not url.startswith("http"):
return None
try:
- title = item.find('p').extract().text
+ title = item.find("p").extract().text
except:
- title = get_text(item, 'a')
+ title = get_text(item, "a")
return {
- 'title': title.replace('\n', ' '),
- 'url': url,
- 'summary': MarkdownConverter().convert_soup(item).strip()
+ "title": title.replace("\n", " "),
+ "url": url,
+ "summary": MarkdownConverter().convert_soup(item).strip(),
}
@@ -175,20 +182,70 @@ def fetch_all():
fetchers = [
reading_what_we_can_items,
aisafetysupport,
- indice_fetcher('https://www.neelnanda.io/mechanistic-interpretability/favourite-papers', 'article', 'div > ul > li', format_neel_nanda_fav),
- indice_fetcher('https://course.mlsafety.org/readings/', 'div.main-content', 'a', format_mlsafety_course),
- indice_fetcher('https://www.anthropic.com/research', 'div.b-postList', 'a', format_anthropic),
- indice_fetcher('https://transformer-circuits.pub/', 'div.toc', 'a', format_transformer_circuits),
- indice_fetcher('https://www.safe.ai/research', '#guiding-principles', 'div.card.is-document', format_safe_ai),
- indice_fetcher('https://far.ai/publication/', '#container-publications', 'div.media-body', format_far_ai),
- indice_fetcher('https://www.redwoodresearch.org/research', 'article', '.list-item', format_redwoodresearch),
- indice_fetcher('https://humancompatible.ai/research', 'article', '.publications li', format_chai_research),
- indice_fetcher('https://humancompatible.ai/bibliography', '#content', '.bib-entry', format_chai_bibliography),
- indice_fetcher('https://humancompatible.ai/newsletter/', 'article', 'a', format_chai_newsletter),
+ indice_fetcher(
+ "https://www.neelnanda.io/mechanistic-interpretability/favourite-papers",
+ "article",
+ "div > ul > li",
+ format_neel_nanda_fav,
+ ),
+ indice_fetcher(
+ "https://course.mlsafety.org/readings/",
+ "div.main-content",
+ "a",
+ format_mlsafety_course,
+ ),
+ indice_fetcher(
+ "https://www.anthropic.com/research",
+ "div.b-postList",
+ "a",
+ format_anthropic,
+ ),
+ indice_fetcher(
+ "https://transformer-circuits.pub/",
+ "div.toc",
+ "a",
+ format_transformer_circuits,
+ ),
+ indice_fetcher(
+ "https://www.safe.ai/research",
+ "#guiding-principles",
+ "div.card.is-document",
+ format_safe_ai,
+ ),
+ indice_fetcher(
+ "https://far.ai/publication/",
+ "#container-publications",
+ "div.media-body",
+ format_far_ai,
+ ),
+ indice_fetcher(
+ "https://www.redwoodresearch.org/research",
+ "article",
+ ".list-item",
+ format_redwoodresearch,
+ ),
+ indice_fetcher(
+ "https://humancompatible.ai/research",
+ "article",
+ ".publications li",
+ format_chai_research,
+ ),
+ indice_fetcher(
+ "https://humancompatible.ai/bibliography",
+ "#content",
+ ".bib-entry",
+ format_chai_bibliography,
+ ),
+ indice_fetcher(
+ "https://humancompatible.ai/newsletter/",
+ "article",
+ "a",
+ format_chai_newsletter,
+ ),
]
articles = defaultdict(dict)
for func in tqdm(fetchers):
for item in func():
- articles[item['title']].update(item)
+ articles[item["title"]].update(item)
return articles
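
indice_fetcher builds each index scraper from a URL, a container selector, an item selector and a formatter, and fetch_all merges their results by title. A sketch of the same shape against an inline HTML snippet - the fetch_element stub, the HTML and the format_item helper are illustrative, not the real index pages:

from bs4 import BeautifulSoup

HTML = """
<article>
  <ul>
    <li><a href="https://example.org/a">Paper A</a></li>
    <li><a href="https://example.org/b">Paper B</a></li>
  </ul>
</article>
"""


def fetch_element(source, selector):
    # Stand-in for the real fetch_element: parses an in-memory string instead of a URL.
    return BeautifulSoup(source, "html.parser").select_one(selector)


def indice_fetcher(source, main_selector, item_selector, formatter):
    def fetcher():
        if contents := fetch_element(source, main_selector):
            return list(filter(None, map(formatter, contents.select(item_selector))))
        return []

    return fetcher


def format_item(a):
    if (a.get("href") or "").startswith("http"):
        return {"title": a.text, "url": a.get("href")}


fetcher = indice_fetcher(HTML, "article", "li a", format_item)
print(fetcher())  # [{'title': 'Paper A', ...}, {'title': 'Paper B', ...}]
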
diff --git a/align_data/sources/articles/parsers.py b/align_data/sources/articles/parsers.py
index 3e8c60bc..f8708b5f 100644
--- a/align_data/sources/articles/parsers.py
+++ b/align_data/sources/articles/parsers.py
@@ -5,7 +5,13 @@
import grobid_tei_xml
import regex as re
from align_data.sources.articles.html import element_extractor, fetch, fetch_element
-from align_data.sources.articles.pdf import doi_getter, fetch_pdf, get_pdf_from_page, get_arxiv_pdf, parse_vanity
+from align_data.sources.articles.pdf import (
+ doi_getter,
+ fetch_pdf,
+ get_pdf_from_page,
+ get_arxiv_pdf,
+ parse_vanity,
+)
from align_data.sources.articles.google_cloud import fetch_markdown
from markdownify import MarkdownConverter
from bs4 import BeautifulSoup
@@ -15,12 +21,14 @@
def google_doc(url: str) -> str:
"""Fetch the contents of the given gdoc url as markdown."""
- res = re.search(r'https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/', url)
+ res = re.search(r"https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/", url)
if not res:
return None
doc_id = res.group(1)
- body = fetch_element(f'https://docs.google.com/document/d/{doc_id}/export?format=html', 'body')
+ body = fetch_element(
+ f"https://docs.google.com/document/d/{doc_id}/export?format=html", "body"
+ )
if body:
return MarkdownConverter().convert_soup(body).strip()
@@ -28,12 +36,12 @@ def google_doc(url: str) -> str:
def medium_blog(url):
"""Return the contents of the medium article at the given URL as markdown."""
# Medium does some magic redirects if it detects that the request is from firefox
- article = fetch_element(url, 'article', headers=None)
+ article = fetch_element(url, "article", headers=None)
if not article:
return None
# remove the header
- if title := article.find('h1'):
+ if title := article.find("h1"):
title.parent.extract()
return MarkdownConverter().convert_soup(article).strip()
@@ -41,12 +49,15 @@ def medium_blog(url):
def parse_grobid(contents):
doc_dict = grobid_tei_xml.parse_document_xml(contents).to_dict()
- authors = [xx["full_name"].strip(' !') for xx in doc_dict.get("header", {}).get("authors", [])]
+ authors = [
+ xx["full_name"].strip(" !")
+ for xx in doc_dict.get("header", {}).get("authors", [])
+ ]
- if not doc_dict.get('body'):
+ if not doc_dict.get("body"):
return {
- 'error': 'No contents in XML file',
- 'data_source': 'xml',
+ "error": "No contents in XML file",
+ "data_source": "xml",
}
return {
@@ -59,218 +70,268 @@ def parse_grobid(contents):
def get_content_type(res):
- header = res.headers.get('Content-Type') or ''
- parts = [c_type.strip().lower() for c_type in header.split(';')]
+ header = res.headers.get("Content-Type") or ""
+ parts = [c_type.strip().lower() for c_type in header.split(";")]
return set(filter(None, parts))
def extract_gdrive_contents(link):
- file_id = link.split('/')[-2]
- url = f'https://drive.google.com/uc?id={file_id}'
- res = fetch(url, 'head')
+ file_id = link.split("/")[-2]
+ url = f"https://drive.google.com/uc?id={file_id}"
+ res = fetch(url, "head")
if res.status_code == 403:
- logger.error('Could not fetch the file at %s - 403 returned', link)
- return {'error': 'Could not read file from google drive - forbidden'}
+ logger.error("Could not fetch the file at %s - 403 returned", link)
+ return {"error": "Could not read file from google drive - forbidden"}
if res.status_code >= 400:
- logger.error('Could not fetch the file at %s - are you sure that link is correct?', link)
- return {'error': 'Could not read file from google drive'}
+ logger.error(
+ "Could not fetch the file at %s - are you sure that link is correct?", link
+ )
+ return {"error": "Could not read file from google drive"}
result = {
- 'source_url': link,
- 'downloaded_from': 'google drive',
+ "source_url": link,
+ "downloaded_from": "google drive",
}
content_type = get_content_type(res)
if not content_type:
- result['error'] = 'no content type'
- elif content_type & {'application/octet-stream', 'application/pdf'}:
+ result["error"] = "no content type"
+ elif content_type & {"application/octet-stream", "application/pdf"}:
result.update(fetch_pdf(url))
- elif content_type & {'text/markdown'}:
+ elif content_type & {"text/markdown"}:
result.update(fetch_markdown(file_id))
- elif content_type & {'application/epub+zip', 'application/epub'}:
- result['data_source'] = 'ebook'
- elif content_type & {'text/html'}:
+ elif content_type & {"application/epub+zip", "application/epub"}:
+ result["data_source"] = "ebook"
+ elif content_type & {"text/html"}:
res = fetch(url)
- if 'Google Drive - Virus scan warning' in res.text:
- element_extractor('form')
+ if "Google Drive - Virus scan warning" in res.text:
+ element_extractor("form")
soup = BeautifulSoup(res.content, "html.parser")
- res = fetch(soup.select_one('form').get('action'))
+ res = fetch(soup.select_one("form").get("action"))
content_type = get_content_type(res)
- if content_type & {'text/xml'}:
+ if content_type & {"text/xml"}:
result.update(parse_grobid(res.content))
- elif content_type & {'text/html'}:
+ elif content_type & {"text/html"}:
soup = BeautifulSoup(res.content, "html.parser")
- result.update({
- 'text': MarkdownConverter().convert_soup(soup.select_one('body')).strip(),
- 'data_source': 'html',
- })
+ result.update(
+ {
+ "text": MarkdownConverter()
+ .convert_soup(soup.select_one("body"))
+ .strip(),
+ "data_source": "html",
+ }
+ )
else:
- result['error'] = f'unknown content type: {content_type}'
+ result["error"] = f"unknown content type: {content_type}"
else:
- result['error'] = f'unknown content type: {content_type}'
+ result["error"] = f"unknown content type: {content_type}"
return result
def error(error_msg):
"""Returns a url handler function that just logs the provided `error` string."""
+
def func(url):
if error_msg:
logger.error(error_msg)
return error_msg
+
return func
def multistrategy(*funcs):
"""Merges multiple getter functions, returning the result of the first function call to succeed."""
+
def getter(url):
for func in funcs:
res = func(url)
- if res and 'error' not in res:
+ if res and "error" not in res:
return res
+
return getter
UNIMPLEMENTED_PARSERS = {
# Unhandled items that will be caught later, though it would be good to handle them properly too
- 'oxford.universitypressscholarship.com': error(''),
-
+ "oxford.universitypressscholarship.com": error(""),
# Paywalled journal
- 'linkinghub.elsevier.com': error('Elsevier is a known parasite - no point in looking to them for content'),
- 'link.springer.com': error('This article looks paywalled'),
- 'www.dl.begellhouse.com': error('This article is paywalled'),
-
+ "linkinghub.elsevier.com": error(
+ "Elsevier is a known parasite - no point in looking to them for content"
+ ),
+ "link.springer.com": error("This article looks paywalled"),
+ "www.dl.begellhouse.com": error("This article is paywalled"),
# To be implemented
- 'goodreads.com': error('Ebooks are not yet handled'),
- 'judiciary.senate.gov': error(''),
- 'taylorfrancis.com': error('Ebooks are not yet handled'),
- 'YouTube.com': error('Youtube videos are not yet handled'),
- 'researchgate.net': error('Researchgate makes it hard to auto download pdf - please provide a DOI or a different url to the contents'),
- 'repository.cam.ac.uk': error(''),
+ "goodreads.com": error("Ebooks are not yet handled"),
+ "judiciary.senate.gov": error(""),
+ "taylorfrancis.com": error("Ebooks are not yet handled"),
+ "YouTube.com": error("Youtube videos are not yet handled"),
+ "researchgate.net": error(
+ "Researchgate makes it hard to auto download pdf - please provide a DOI or a different url to the contents"
+ ),
+ "repository.cam.ac.uk": error(""),
}
HTML_PARSERS = {
- 'academic.oup.com': element_extractor('#ContentTab'),
- 'ai.googleblog.com': element_extractor('div.post-body.entry-content'),
- 'arxiv-vanity.com': parse_vanity,
- 'ar5iv.labs.arxiv.org': parse_vanity,
- 'bair.berkeley.edu': element_extractor('article'),
- 'mediangroup.org': element_extractor('div.entry-content'),
- 'www.alexirpan.com': element_extractor('article'),
- 'www.incompleteideas.net': element_extractor('body'),
- 'ai-alignment.com': medium_blog,
- 'aisrp.org': element_extractor('article'),
- 'bounded-regret.ghost.io': element_extractor('div.post-content'),
- 'carnegieendowment.org': element_extractor('div.article-body', remove=['.no-print', '.related-pubs']),
- 'casparoesterheld.com': element_extractor('.entry-content', remove=['div.sharedaddy']),
- 'cullenokeefe.com': element_extractor('div.sqs-block-content'),
- 'deepmindsafetyresearch.medium.com': medium_blog,
- 'docs.google.com': google_doc,
- 'docs.microsoft.com': element_extractor('div.content'),
- 'digichina.stanford.edu': element_extractor('div.h_editor-content'),
- 'en.wikipedia.org': element_extractor('main.mw-body'),
- 'eng.uber.com': element_extractor('div.article-body'),
- 'futureoflife.org': multistrategy(
- element_extractor('div.body-content'),
- element_extractor('#main-content'),
+ "academic.oup.com": element_extractor("#ContentTab"),
+ "ai.googleblog.com": element_extractor("div.post-body.entry-content"),
+ "arxiv-vanity.com": parse_vanity,
+ "ar5iv.labs.arxiv.org": parse_vanity,
+ "bair.berkeley.edu": element_extractor("article"),
+ "mediangroup.org": element_extractor("div.entry-content"),
+ "www.alexirpan.com": element_extractor("article"),
+ "www.incompleteideas.net": element_extractor("body"),
+ "ai-alignment.com": medium_blog,
+ "aisrp.org": element_extractor("article"),
+ "bounded-regret.ghost.io": element_extractor("div.post-content"),
+ "carnegieendowment.org": element_extractor(
+ "div.article-body", remove=[".no-print", ".related-pubs"]
+ ),
+ "casparoesterheld.com": element_extractor(
+ ".entry-content", remove=["div.sharedaddy"]
+ ),
+ "cullenokeefe.com": element_extractor("div.sqs-block-content"),
+ "deepmindsafetyresearch.medium.com": medium_blog,
+ "docs.google.com": google_doc,
+ "docs.microsoft.com": element_extractor("div.content"),
+ "digichina.stanford.edu": element_extractor("div.h_editor-content"),
+ "en.wikipedia.org": element_extractor("main.mw-body"),
+ "eng.uber.com": element_extractor("div.article-body"),
+ "futureoflife.org": multistrategy(
+ element_extractor("div.body-content"),
+ element_extractor("#main-content"),
),
- 'gcrinstitute.org': element_extractor('div.blog-content'),
- 'jbkjr.me': element_extractor('section.page__content'),
- 'link.springer.com': element_extractor('article.c-article-body'),
- 'longtermrisk.org': element_extractor('div.entry-content'),
- 'lukemuehlhauser.com': element_extractor('div.entry-content'),
- 'medium.com': medium_blog,
- 'openai.com': element_extractor('#content'),
- 'ought.org': element_extractor('div.BlogPostBodyContainer'),
- 'sideways-view.com': element_extractor('article', remove=['header']),
- 'slatestarcodex.com': element_extractor('div.pjgm-postcontent'),
- 'techpolicy.press': element_extractor('div.post-content', remove=['div.before_content', '.sabox-guest-authors-container', '.jp-relatedposts']),
- 'theconversation.com': element_extractor('div.content-body'),
- 'thegradient.pub': element_extractor('div.c-content'),
- 'towardsdatascience.com': medium_blog,
- 'unstableontology.com': element_extractor('.entry-content', remove=['div.sharedaddy']),
- 'waitbutwhy.com': element_extractor('article', remove=['.entry-header']),
- 'weightagnostic.github.io': element_extractor('dt-article', remove=['#authors_section', 'dt-byline']),
- 'cnas.org': element_extractor('#mainbar-toc'),
- 'econlib.org': element_extractor('div.post-content'),
- 'humanityplus.org': element_extractor('div.content'),
- 'gleech.org': element_extractor('article.post-content', remove=['center', 'div.accordion']),
- 'ibm.com': element_extractor('div:has(> p)'), # IBM's HTML is really ugly...
- 'microsoft.com': element_extractor('div.content-container'),
- 'mdpi.com': element_extractor(
- 'article', remove=[
- '.article-icons', '.title', '.art-authors', '.art-affiliations', '.bib-identity',
- '.pubhistory', '.belongsTo', '.highlight-box1', '.additional-content'
- ]
+ "gcrinstitute.org": element_extractor("div.blog-content"),
+ "jbkjr.me": element_extractor("section.page__content"),
+ "link.springer.com": element_extractor("article.c-article-body"),
+ "longtermrisk.org": element_extractor("div.entry-content"),
+ "lukemuehlhauser.com": element_extractor("div.entry-content"),
+ "medium.com": medium_blog,
+ "openai.com": element_extractor("#content"),
+ "ought.org": element_extractor("div.BlogPostBodyContainer"),
+ "sideways-view.com": element_extractor("article", remove=["header"]),
+ "slatestarcodex.com": element_extractor("div.pjgm-postcontent"),
+ "techpolicy.press": element_extractor(
+ "div.post-content",
+ remove=[
+ "div.before_content",
+ ".sabox-guest-authors-container",
+ ".jp-relatedposts",
+ ],
+ ),
+ "theconversation.com": element_extractor("div.content-body"),
+ "thegradient.pub": element_extractor("div.c-content"),
+ "towardsdatascience.com": medium_blog,
+ "unstableontology.com": element_extractor(
+ ".entry-content", remove=["div.sharedaddy"]
+ ),
+ "waitbutwhy.com": element_extractor("article", remove=[".entry-header"]),
+ "weightagnostic.github.io": element_extractor(
+ "dt-article", remove=["#authors_section", "dt-byline"]
+ ),
+ "cnas.org": element_extractor("#mainbar-toc"),
+ "econlib.org": element_extractor("div.post-content"),
+ "humanityplus.org": element_extractor("div.content"),
+ "gleech.org": element_extractor(
+ "article.post-content", remove=["center", "div.accordion"]
+ ),
+ "ibm.com": element_extractor("div:has(> p)"), # IBM's HTML is really ugly...
+ "microsoft.com": element_extractor("div.content-container"),
+ "mdpi.com": element_extractor(
+ "article",
+ remove=[
+ ".article-icons",
+ ".title",
+ ".art-authors",
+ ".art-affiliations",
+ ".bib-identity",
+ ".pubhistory",
+ ".belongsTo",
+ ".highlight-box1",
+ ".additional-content",
+ ],
+ ),
+ "nature.com": element_extractor(
+ "article", remove=["header", "#rights link-section", "#article-info-section"]
),
- 'nature.com': element_extractor('article', remove=['header', '#rights link-section', '#article-info-section']),
- 'ncbi.nlm.nih.gov': element_extractor('div.article'),
- 'openphilanthropy.org': element_extractor('div.pagenav-content'),
- 'safe.ai': element_extractor('#open-letter'),
- 'sciencedirect.com': element_extractor(
- 'article',
+ "ncbi.nlm.nih.gov": element_extractor("div.article"),
+ "openphilanthropy.org": element_extractor("div.pagenav-content"),
+ "safe.ai": element_extractor("#open-letter"),
+ "sciencedirect.com": element_extractor(
+ "article",
remove=[
- '#section-cited-by', '.Copyright', '.issue-navigation', '.ReferencedArticles',
- '.LicenseInfo', '.ArticleIdentifierLinks', '.Banner', '.screen-reader-main-title', '.Publication'
- ]
+ "#section-cited-by",
+ ".Copyright",
+ ".issue-navigation",
+ ".ReferencedArticles",
+ ".LicenseInfo",
+ ".ArticleIdentifierLinks",
+ ".Banner",
+ ".screen-reader-main-title",
+ ".Publication",
+ ],
),
- 'transformer-circuits.pub': error('not handled yet - same codebase as distill'),
- 'vox.com': element_extractor('did.c-entry-content', remove=['c-article-footer']),
- 'weforum.org': element_extractor('div.wef-0'),
- 'www6.inrae.fr': element_extractor('div.ArticleContent'),
- 'aleph.se': element_extractor('body'),
- 'yoshuabengio.org': element_extractor('div.post-content'),
+ "transformer-circuits.pub": error("not handled yet - same codebase as distill"),
+ "vox.com": element_extractor("did.c-entry-content", remove=["c-article-footer"]),
+ "weforum.org": element_extractor("div.wef-0"),
+ "www6.inrae.fr": element_extractor("div.ArticleContent"),
+ "aleph.se": element_extractor("body"),
+ "yoshuabengio.org": element_extractor("div.post-content"),
}
PDF_PARSERS = {
# Domain specific handlers
- 'apcz.umk.pl': get_pdf_from_page('.galleys_links a.pdf', 'a.download'),
- 'arxiv.org': get_arxiv_pdf,
- 'academic.oup.com': get_pdf_from_page('a.article-pdfLink'),
- 'cset.georgetown.edu': get_pdf_from_page('a:-soup-contains("Download Full")'),
- 'drive.google.com': extract_gdrive_contents,
- 'doi.org': doi_getter,
- 'dl.acm.org': fetch_pdf,
- 'dspace.mit.edu': get_pdf_from_page('a.btn-primary.download-button'),
- 'globalprioritiesinstitute.org': get_pdf_from_page('a:-soup-contains("PDF")'),
- 'link.springer.com': multistrategy(
- get_pdf_from_page('div.c-pdf-download a'),
+ "apcz.umk.pl": get_pdf_from_page(".galleys_links a.pdf", "a.download"),
+ "arxiv.org": get_arxiv_pdf,
+ "academic.oup.com": get_pdf_from_page("a.article-pdfLink"),
+ "cset.georgetown.edu": get_pdf_from_page('a:-soup-contains("Download Full")'),
+ "drive.google.com": extract_gdrive_contents,
+ "doi.org": doi_getter,
+ "dl.acm.org": fetch_pdf,
+ "dspace.mit.edu": get_pdf_from_page("a.btn-primary.download-button"),
+ "globalprioritiesinstitute.org": get_pdf_from_page('a:-soup-contains("PDF")'),
+ "link.springer.com": multistrategy(
+ get_pdf_from_page("div.c-pdf-download a"),
doi_getter,
),
- 'openaccess.thecvf.com': get_pdf_from_page('a:-soup-contains("pdf")'),
- 'openreview.net': get_pdf_from_page('a.note_content_pdf'),
- 'ora.ox.ac.uk': fetch_pdf,
- 'papers.nips.cc': get_pdf_from_page('a:-soup-contains("Paper")'),
- 'papers.ssrn.com': get_pdf_from_page('.abstract-buttons a.button-link:-soup-contains("Download")'),
- 'par.nsf.gov': get_pdf_from_page('a:-soup-contains("Accepted Manuscript")'),
- 'proceedings.neurips.cc': get_pdf_from_page('a:-soup-contains("Paper")'),
- 'psyarxiv.com': lambda url: fetch_pdf(url.rstrip('/') + '/download'),
- 'rowanzellers.com': get_pdf_from_page('main a:-soup-contains("Paper")'),
- 'governance.ai': get_pdf_from_page('a.read-paper-button:not([href="#"])'),
- 'ijcai.org': get_pdf_from_page('a.btn-download:-soup-contains("PDF")'),
- 'jair.org': get_pdf_from_page('div.download a.pdf', 'a.download'),
- 'jstor.org': doi_getter,
- 'ri.cmu.edu': get_pdf_from_page('a.pub-link'),
- 'risksciences.ucla.edu': get_pdf_from_page('a:-soup-contains("Download")'),
- 'ssrn.com': get_pdf_from_page('.abstract-buttons a.button-link:-soup-contains("Download")'),
- 'yjolt.org': get_pdf_from_page('span.file a'),
+ "openaccess.thecvf.com": get_pdf_from_page('a:-soup-contains("pdf")'),
+ "openreview.net": get_pdf_from_page("a.note_content_pdf"),
+ "ora.ox.ac.uk": fetch_pdf,
+ "papers.nips.cc": get_pdf_from_page('a:-soup-contains("Paper")'),
+ "papers.ssrn.com": get_pdf_from_page(
+ '.abstract-buttons a.button-link:-soup-contains("Download")'
+ ),
+ "par.nsf.gov": get_pdf_from_page('a:-soup-contains("Accepted Manuscript")'),
+ "proceedings.neurips.cc": get_pdf_from_page('a:-soup-contains("Paper")'),
+ "psyarxiv.com": lambda url: fetch_pdf(url.rstrip("/") + "/download"),
+ "rowanzellers.com": get_pdf_from_page('main a:-soup-contains("Paper")'),
+ "governance.ai": get_pdf_from_page('a.read-paper-button:not([href="#"])'),
+ "ijcai.org": get_pdf_from_page('a.btn-download:-soup-contains("PDF")'),
+ "jair.org": get_pdf_from_page("div.download a.pdf", "a.download"),
+ "jstor.org": doi_getter,
+ "ri.cmu.edu": get_pdf_from_page("a.pub-link"),
+ "risksciences.ucla.edu": get_pdf_from_page('a:-soup-contains("Download")'),
+ "ssrn.com": get_pdf_from_page(
+ '.abstract-buttons a.button-link:-soup-contains("Download")'
+ ),
+ "yjolt.org": get_pdf_from_page("span.file a"),
}
def item_metadata(url) -> Dict[str, str]:
- domain = urlparse(url).netloc.lstrip('www.')
- res = fetch(url, 'head')
- content_type = {item.strip() for item in res.headers.get('Content-Type').split(';')}
+ domain = urlparse(url).netloc.lstrip("www.")
+ res = fetch(url, "head")
+ content_type = {item.strip() for item in res.headers.get("Content-Type").split(";")}
- if content_type & {'text/html', 'text/xml'}:
+ if content_type & {"text/html", "text/xml"}:
# If the url points to an html webpage, then it either contains the text as html, or
# there is a link to a pdf on it
if parser := HTML_PARSERS.get(domain):
if res := parser(url):
# Proper contents were found on the page, so use them
- return {'source_url': url, 'data_source': 'html'}
+ return {"source_url": url, "data_source": "html"}
if parser := PDF_PARSERS.get(domain):
if res := parser(url):
@@ -278,17 +339,19 @@ def item_metadata(url) -> Dict[str, str]:
return res
if parser := UNIMPLEMENTED_PARSERS.get(domain):
- return {'error': parser(url)}
-
- if domain not in (HTML_PARSERS.keys() | PDF_PARSERS.keys() | UNIMPLEMENTED_PARSERS.keys()):
- return {'error': 'No domain handler defined'}
- return {'error': 'could not parse url'}
- elif content_type & {'application/octet-stream', 'application/pdf'}:
+ return {"error": parser(url)}
+
+ if domain not in (
+ HTML_PARSERS.keys() | PDF_PARSERS.keys() | UNIMPLEMENTED_PARSERS.keys()
+ ):
+ return {"error": "No domain handler defined"}
+ return {"error": "could not parse url"}
+ elif content_type & {"application/octet-stream", "application/pdf"}:
# this looks like it could be a pdf - try to download it as one
return fetch_pdf(url)
- elif content_type & {'application/epub+zip', 'application/epub'}:
+ elif content_type & {"application/epub+zip", "application/epub"}:
# it looks like an ebook. Assume it's fine.
# TODO: validate that the ebook is readable
- return {'source_url': url, 'data_source': 'ebook'}
+ return {"source_url": url, "data_source": "ebook"}
else:
- return {'error': f'Unhandled content type: {content_type}'}
+ return {"error": f"Unhandled content type: {content_type}"}
diff --git a/align_data/sources/articles/pdf.py b/align_data/sources/articles/pdf.py
index f6a020ce..2120dc56 100644
--- a/align_data/sources/articles/pdf.py
+++ b/align_data/sources/articles/pdf.py
@@ -22,21 +22,21 @@ def sci_hub_pdf(identifier):
large file containing multiple articles, e.g. a whole journal or book, in which case this function
will ignore the result.
"""
- elem = fetch_element(f'https://sci-hub.st/{identifier}', 'embed')
+ elem = fetch_element(f"https://sci-hub.st/{identifier}", "embed")
if not elem:
return None
- src = elem.get('src').strip()
- if src.startswith('//'):
- src = 'https:' + src
- elif src.startswith('/'):
- src = f'https://sci-hub.st/{src}'
+ src = elem.get("src").strip()
+ if src.startswith("//"):
+ src = "https:" + src
+ elif src.startswith("/"):
+ src = f"https://sci-hub.st/{src}"
return src
def read_pdf(filename):
try:
pdf_reader = PdfReader(filename)
- return '\n'.join(page.extract_text() for page in pdf_reader.pages)
+ return "\n".join(page.extract_text() for page in pdf_reader.pages)
except PdfReadError as e:
logger.error(e)
return None
@@ -50,36 +50,45 @@ def fetch_pdf(link):
:returns: a dict with the pdf's extracted text, or an error description."""
res = fetch(link)
if res.status_code >= 400:
- logger.error('Could not fetch the pdf file at %s - are you sure that link is correct?', link)
+ logger.error(
+ "Could not fetch the pdf file at %s - are you sure that link is correct?",
+ link,
+ )
- content_type = {c_type.strip().lower() for c_type in res.headers.get('Content-Type').split(';')}
- if not content_type & {'application/octet-stream', 'application/pdf'}:
+ content_type = {
+ c_type.strip().lower() for c_type in res.headers.get("Content-Type").split(";")
+ }
+ if not content_type & {"application/octet-stream", "application/pdf"}:
return {
- 'error': f'Wrong content type retrieved: {content_type} - {link}',
- 'contents': res.content,
+ "error": f"Wrong content type retrieved: {content_type} - {link}",
+ "contents": res.content,
}
try:
pdf_reader = PdfReader(io.BytesIO(res.content))
return {
- 'source_url': link,
- 'text': '\n'.join(page.extract_text() for page in pdf_reader.pages),
- 'data_source': 'pdf',
+ "source_url": link,
+ "text": "\n".join(page.extract_text() for page in pdf_reader.pages),
+ "data_source": "pdf",
}
except PdfReadError as e:
- logger.error('Could not read PDF file: %s', e)
- return {'error': str(e)}
+ logger.error("Could not read PDF file: %s", e)
+ return {"error": str(e)}
filenames = [
- i.strip().split('=')[1]
- for i in res.headers.get('Content-Disposition', '').split(';')
- if 'filename' in i
+ i.strip().split("=")[1]
+ for i in res.headers.get("Content-Disposition", "").split(";")
+ if "filename" in i
]
- if filenames and 'pdf' not in filenames[0].lower():
- logger.error('Are you sure %s points to a pdf file? The response says the file should be called %s', link, filenames[0])
- error = f'Probably bad file type: {filenames[0]} - {link}'
+ if filenames and "pdf" not in filenames[0].lower():
+ logger.error(
+ "Are you sure %s points to a pdf file? The response says the file should be called %s",
+ link,
+ filenames[0],
+ )
+ error = f"Probably bad file type: {filenames[0]} - {link}"
- return {'error': error}
+ return {"error": error}
def get_arxiv_link(doi):
@@ -88,14 +97,16 @@ def get_arxiv_link(doi):
if res.status_code != 200:
return None
- vals = [i for i in response.json().get('values') if i.get('type', '').upper() == 'URL']
+ vals = [
+ i for i in res.json().get("values") if i.get("type", "").upper() == "URL"
+ ]
if not vals:
return None
return vals[0]["data"]["value"].replace("/abs/", "/pdf/") + ".pdf"
def get_arxiv_pdf(link):
- return fetch_pdf(link.replace('/abs/', '/pdf/'))
+ return fetch_pdf(link.replace("/abs/", "/pdf/"))
def get_doi(doi):
@@ -104,23 +115,23 @@ def get_doi(doi):
This will look for it in sci-hub and arxiv (if applicable), as those are likely the most
comprehensive sources of pdfs.
"""
- if 'arXiv' in doi:
+ if "arXiv" in doi:
link = get_arxiv_link(doi)
- pdf = (link and fetch_pdf(link))
- if pdf and 'text' in pdf:
- pdf['downloaded_from'] = 'arxiv'
+ pdf = link and fetch_pdf(link)
+ if pdf and "text" in pdf:
+ pdf["downloaded_from"] = "arxiv"
return pdf
if link := sci_hub_pdf(doi):
if pdf := fetch_pdf(link):
- pdf['downloaded_from'] = 'scihub'
+ pdf["downloaded_from"] = "scihub"
return pdf
- return {'error': 'Could not find pdf of article by DOI'}
+ return {"error": "Could not find pdf of article by DOI"}
def doi_getter(url):
"""Extract the DOI from the given `url` and fetch the contents of its article."""
- return get_doi(urlparse(url).path.lstrip('/'))
+ return get_doi(urlparse(url).path.lstrip("/"))
def get_pdf_from_page(*link_selectors):
@@ -133,34 +144,38 @@ def get_pdf_from_page(*link_selectors):
:param List[str] link_selectors: CSS selectors applied in order to find the final download link
:returns: a handler that fetches the given page and returns the contents of the linked pdf file as a dict
"""
+
def getter(url):
link = url
for selector in link_selectors:
elem = fetch_element(link, selector)
if not elem:
- return {'error': f'Could not find pdf download link for {link} using \'{selector}\''}
+ return {
+ "error": f"Could not find pdf download link for {link} using '{selector}'"
+ }
- link = elem.get('href')
- if not link.startswith('http') or not link.startswith('//'):
+ link = elem.get("href")
+ if not link.startswith("http") or not link.startswith("//"):
link = urljoin(url, link)
# Some pages link to google drive previews of pdf files, which need to be
# mangled to get the URL of the actual pdf file
- if 'drive.google.com' in link and '/view' in link:
+ if "drive.google.com" in link and "/view" in link:
return extract_gdrive_contents(link)
if pdf := fetch_pdf(link):
return pdf
- return {'error': f'Could not fetch pdf from {link}'}
+ return {"error": f"Could not fetch pdf from {link}"}
+
return getter
def parse_vanity(url):
- contents = fetch_element(url, 'article')
+ contents = fetch_element(url, "article")
if not contents:
return None
- if title := contents.select_one('h1.ltx_title'):
+ if title := contents.select_one("h1.ltx_title"):
title = title.text
def get_first_child(item):
@@ -170,24 +185,28 @@ def get_first_child(item):
if not isinstance(child, str):
child = child.text
- return child.split(',')
+ return child.split(",")
authors = [
- a.strip() for item in contents.select('div.ltx_authors .ltx_personname') for a in get_first_child(item)
+ a.strip()
+ for item in contents.select("div.ltx_authors .ltx_personname")
+ for a in get_first_child(item)
]
- if date_published := contents.select_one('div.ltx_dates'):
- date_published = date_published.text.strip('()')
+ if date_published := contents.select_one("div.ltx_dates"):
+ date_published = date_published.text.strip("()")
- text = '\n\n'.join([
- MarkdownConverter().convert_soup(elem).strip()
- for elem in contents.select('section.ltx_section')
- ])
+ text = "\n\n".join(
+ [
+ MarkdownConverter().convert_soup(elem).strip()
+ for elem in contents.select("section.ltx_section")
+ ]
+ )
return {
- 'title': title,
- 'authors': authors,
- 'text': text,
- 'date_published': date_published,
- 'data_source': 'html',
+ "title": title,
+ "authors": authors,
+ "text": text,
+ "date_published": date_published,
+ "data_source": "html",
}
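
get_pdf_from_page walks one or more CSS selectors, following the href found by each one until it reaches the pdf itself. A stubbed sketch of that chaining - the LINKS/PDFS dictionaries and the journal.example URLs stand in for fetch_element and fetch_pdf so the example runs offline:

from urllib.parse import urljoin

# Fake "pages": (page url, selector) -> the href that would be found on that page.
LINKS = {
    ("https://journal.example/article/42", ".galleys_links a.pdf"): "/article/42/galley/7",
    ("https://journal.example/article/42/galley/7", "a.download"): "/article/42/download/7.pdf",
}
PDFS = {
    "https://journal.example/article/42/download/7.pdf": {"data_source": "pdf", "text": "..."},
}


def fetch_element(url, selector):
    return LINKS.get((url, selector))


def fetch_pdf(link):
    return PDFS.get(link, {"error": f"Could not fetch pdf from {link}"})


def get_pdf_from_page(*link_selectors):
    def getter(url):
        link = url
        for selector in link_selectors:
            href = fetch_element(link, selector)
            if not href:
                return {"error": f"Could not find pdf download link for {link} using '{selector}'"}
            # Relative hrefs are joined onto the original page URL, as in the code above.
            link = urljoin(url, href)
        return fetch_pdf(link)

    return getter


getter = get_pdf_from_page(".galleys_links a.pdf", "a.download")
print(getter("https://journal.example/article/42"))
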
diff --git a/align_data/sources/arxiv_papers/__init__.py b/align_data/sources/arxiv_papers/__init__.py
index f9bc2080..29258480 100644
--- a/align_data/sources/arxiv_papers/__init__.py
+++ b/align_data/sources/arxiv_papers/__init__.py
@@ -3,7 +3,7 @@
ARXIV_REGISTRY = [
ArxivPapers(
name="arxiv",
- spreadsheet_id='1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI',
- sheet_id='655836697'
+ spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI",
+ sheet_id="655836697",
)
]
diff --git a/align_data/sources/arxiv_papers/arxiv_papers.py b/align_data/sources/arxiv_papers/arxiv_papers.py
index ae9b7cb9..1a61ecc6 100644
--- a/align_data/sources/arxiv_papers/arxiv_papers.py
+++ b/align_data/sources/arxiv_papers/arxiv_papers.py
@@ -11,7 +11,7 @@
@dataclass
class ArxivPapers(SpreadsheetDataset):
- summary_key: str = 'summary'
+ summary_key: str = "summary"
COOLDOWN: int = 1
done_key = "url"
batch_size = 1
@@ -28,45 +28,52 @@ def _get_arxiv_metadata(self, paper_id) -> arxiv.Result:
return None
def get_id(self, item):
- if res := re.search(r'https://arxiv.org/abs/(.*?)/?$', item.url):
+ if res := re.search(r"https://arxiv.org/abs/(.*?)/?$", item.url):
return res.group(1)
def get_contents(self, item) -> dict:
paper_id = self.get_id(item)
- for link in [f"https://www.arxiv-vanity.com/papers/{paper_id}", f"https://ar5iv.org/abs/{paper_id}"]:
+ for link in [
+ f"https://www.arxiv-vanity.com/papers/{paper_id}",
+ f"https://ar5iv.org/abs/{paper_id}",
+ ]:
if contents := parse_vanity(link):
return contents
- return fetch_pdf(f'https://arxiv.org/pdf/{paper_id}.pdf')
+ return fetch_pdf(f"https://arxiv.org/pdf/{paper_id}.pdf")
def process_entry(self, item) -> None:
logger.info(f"Processing {item.title}")
paper = self.get_contents(item)
- if not paper or not paper.get('text'):
+ if not paper or not paper.get("text"):
return None
metadata = self._get_arxiv_metadata(self.get_id(item))
if self.is_val(item.authors) and item.authors.strip():
- authors = item.authors.split(',')
+ authors = item.authors.split(",")
elif metadata and metadata.authors:
authors = metadata.authors
else:
- authors = paper.get('authors') or []
+ authors = paper.get("authors") or []
authors = [str(a).strip() for a in authors]
- return self.make_data_entry({
- "url": self.get_item_key(item),
- "source": self.name,
- "source_type": paper['data_source'],
- "title": self.is_val(item.title) or paper.get('title'),
- "authors": authors,
- "date_published": self._get_published_date(self.is_val(item.date_published) or paper.get('date_published')),
- "data_last_modified": str(metadata.updated),
- "summary": metadata.summary.replace("\n", " "),
- "author_comment": metadata.comment,
- "journal_ref": metadata.journal_ref,
- "doi": metadata.doi,
- "primary_category": metadata.primary_category,
- "categories": metadata.categories,
- "text": paper['text'],
- })
+ return self.make_data_entry(
+ {
+ "url": self.get_item_key(item),
+ "source": self.name,
+ "source_type": paper["data_source"],
+ "title": self.is_val(item.title) or paper.get("title"),
+ "authors": authors,
+ "date_published": self._get_published_date(
+ self.is_val(item.date_published) or paper.get("date_published")
+ ),
+ "data_last_modified": str(metadata.updated),
+ "summary": metadata.summary.replace("\n", " "),
+ "author_comment": metadata.comment,
+ "journal_ref": metadata.journal_ref,
+ "doi": metadata.doi,
+ "primary_category": metadata.primary_category,
+ "categories": metadata.categories,
+ "text": paper["text"],
+ }
+ )
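
ArxivPapers.get_contents tries the arxiv-vanity render, then ar5iv, and only falls back to the raw pdf when neither HTML version is available. A small sketch of that fallback chain with stubbed fetchers (the stubs and the paper id are placeholders; the real code uses parse_vanity and fetch_pdf from align_data.sources.articles):

def parse_vanity(link):
    # Stub: pretend only ar5iv has a readable HTML render of this paper.
    if "ar5iv.org" in link:
        return {"text": "full text from ar5iv", "data_source": "html"}
    return None


def fetch_pdf(link):
    return {"text": "text extracted from the pdf", "data_source": "pdf"}


def get_contents(paper_id):
    for link in [
        f"https://www.arxiv-vanity.com/papers/{paper_id}",
        f"https://ar5iv.org/abs/{paper_id}",
    ]:
        if contents := parse_vanity(link):
            return contents
    return fetch_pdf(f"https://arxiv.org/pdf/{paper_id}.pdf")


print(get_contents("2206.12345")["data_source"])  # "html" - the ar5iv render was used
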
diff --git a/align_data/sources/blogs/__init__.py b/align_data/sources/blogs/__init__.py
index 7021c994..ed55dc81 100644
--- a/align_data/sources/blogs/__init__.py
+++ b/align_data/sources/blogs/__init__.py
@@ -2,7 +2,12 @@
from align_data.sources.blogs.medium_blog import MediumBlog
from align_data.sources.blogs.gwern_blog import GwernBlog
from align_data.sources.blogs.blogs import (
- ColdTakes, GenerativeInk, CaradoMoe, EleutherAI, OpenAIResearch, DeepMindTechnicalBlog
+ ColdTakes,
+ GenerativeInk,
+ CaradoMoe,
+ EleutherAI,
+ OpenAIResearch,
+ DeepMindTechnicalBlog,
)
from align_data.sources.blogs.substack_blog import SubstackBlog
@@ -14,34 +19,43 @@
WordpressBlog(name="jsteinhardt_blog", url="https://jsteinhardt.wordpress.com"),
WordpressBlog(name="vkrakovna_blog", url="https://vkrakovna.wordpress.com"),
WordpressBlog(name="yudkowsky_blog", url="https://yudkowsky.net"),
- MediumBlog(name="deepmind_blog", url="https://deepmindsafetyresearch.medium.com/", authors=["DeepMind Safety Research"]),
- GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]),
+ MediumBlog(
+ name="deepmind_blog",
+ url="https://deepmindsafetyresearch.medium.com/",
+ authors=["DeepMind Safety Research"],
+ ),
+ GwernBlog(
+ name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]
+ ),
ColdTakes(
name="cold_takes",
url="https://www.cold-takes.com/",
- authors=['Holden Karnofsky'],
+ authors=["Holden Karnofsky"],
),
GenerativeInk(
name="generative.ink",
url="https://generative.ink/posts/",
- authors=['janus'],
+ authors=["janus"],
),
CaradoMoe(
name="carado.moe",
- url='https://carado.moe',
- authors=['Tamsin Leake'],
+ url="https://carado.moe",
+ authors=["Tamsin Leake"],
),
SubstackBlog(
name="importai",
url="https://importai.substack.com",
- id_fields=['url', 'title', 'source']
+ id_fields=["url", "title", "source"],
),
SubstackBlog(
name="ml_safety_newsletter",
url="https://newsletter.mlsafety.org",
- id_fields=['url', 'title', 'source']
+ id_fields=["url", "title", "source"],
+ ),
+ EleutherAI(name="eleuther.ai", url="https://blog.eleuther.ai/"),
+ OpenAIResearch(name="openai.research", url="https://openai.com/research"),
+ DeepMindTechnicalBlog(
+ name="deepmind_technical_blog",
+ url="https://www.deepmind.com/blog-categories/technical-blogs",
),
- EleutherAI(name='eleuther.ai', url='https://blog.eleuther.ai/'),
- OpenAIResearch(name='openai.research', url='https://openai.com/research'),
- DeepMindTechnicalBlog(name='deepmind_technical_blog', url='https://www.deepmind.com/blog-categories/technical-blogs'),
]
diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py
index 65ad95e6..ae4439a9 100644
--- a/align_data/sources/blogs/blogs.py
+++ b/align_data/sources/blogs/blogs.py
@@ -9,82 +9,89 @@
logger = logging.getLogger(__name__)
+
class ColdTakes(HTMLDataset):
- item_selector = 'div.post-feed article'
+ item_selector = "div.post-feed article"
- ignored_selectors = ['center', 'div[style*="display:flex"]', 'footer']
+ ignored_selectors = ["center", 'div[style*="display:flex"]', "footer"]
def _get_published_date(self, contents):
- header = contents.select_one('article header').extract()
- date = header.find('time').get('datetime')
+ header = contents.select_one("article header").extract()
+ date = header.find("time").get("datetime")
return super()._get_published_date(date)
class GenerativeInk(HTMLDataset):
- item_selector = 'div.post.on-list'
+ item_selector = "div.post.on-list"
def _get_published_date(self, contents):
possible_date_elements = [
- elem for info in contents.select('div.post-info')
- for elem in info.children
+ elem for info in contents.select("div.post-info") for elem in info.children
]
return self._find_date(possible_date_elements)
class CaradoMoe(RSSDataset):
def _get_text(self, item):
- contents = item['soup']
- meta = contents.find('p', {'class': 'postmeta'})
- return self._extract_markdown(meta.find_next_sibling('div'))
+ contents = item["soup"]
+ meta = contents.find("p", {"class": "postmeta"})
+ return self._extract_markdown(meta.find_next_sibling("div"))
class EleutherAI(HTMLDataset):
-
- item_selector = 'div.archive-entry'
- text_selector = 'div.post-content'
+ item_selector = "div.archive-entry"
+ text_selector = "div.post-content"
def _get_published_date(self, contents):
try:
- date = contents.select_one('header .post-meta').text.split('·')[0].strip()
+ date = contents.select_one("header .post-meta").text.split("·")[0].strip()
return super()._get_published_date(date)
except (ValueError, ParserError):
- return ''
+ return ""
def extract_authors(self, article):
- return article.select_one('header .post-meta').text.split('·')[1].strip().split(', ')
+ return (
+ article.select_one("header .post-meta")
+ .text.split("·")[1]
+ .strip()
+ .split(", ")
+ )
class OpenAIResearch(HTMLDataset):
-
- item_selector = 'li.group-item'
- title_selector = '.container h1'
+ item_selector = "li.group-item"
+ title_selector = ".container h1"
def _get_published_date(self, contents):
- if date := contents.select_one('.container .f-meta-2'):
+ if date := contents.select_one(".container .f-meta-2"):
return super()._get_published_date(date.text)
- return ''
+ return ""
def _get_text(self, contents):
- if paper_link := contents.select_one('.container .cols-container a.ui-link:-soup-contains("Read paper")'):
- return item_metadata(paper_link.get('href')).get('text')
+ if paper_link := contents.select_one(
+ '.container .cols-container a.ui-link:-soup-contains("Read paper")'
+ ):
+ return item_metadata(paper_link.get("href")).get("text")
def extract_authors(self, article):
- authors = (
- article.select_one('div:-soup-contains("Authors") + div .f-body-1') or
- article.select_one('div:-soup-contains("Acknowledgments") + div .f-body-1')
- )
+ authors = article.select_one(
+ 'div:-soup-contains("Authors") + div .f-body-1'
+ ) or article.select_one('div:-soup-contains("Acknowledgments") + div .f-body-1')
if not authors:
return []
- return [i.split('(')[0].strip() for i in authors.select_one('p').children if not i.name]
+ return [
+ i.split("(")[0].strip()
+ for i in authors.select_one("p").children
+ if not i.name
+ ]
class DeepMindTechnicalBlog(HTMLDataset):
-
- item_selector = 'div.w-dyn-item .c_card_list__item__blog'
- title_selector = '.c_banner__blog__card h2'
- text_selector = '.c_rich-text__cms'
- ignored_selectors = ['.article-gtag-buttons']
+ item_selector = "div.w-dyn-item .c_card_list__item__blog"
+ title_selector = ".c_banner__blog__card h2"
+ text_selector = ".c_rich-text__cms"
+ ignored_selectors = [".article-gtag-buttons"]
@property
def items_list(self):
@@ -93,7 +100,9 @@ def items_list(self):
with tqdm(desc=f"Loading {self.name} pages") as pbar:
while True:
logger.info(f"Fetching entries from {self.url}")
- response = requests.get(self.url, allow_redirects=True, params={'73df3071_page': page})
+ response = requests.get(
+ self.url, allow_redirects=True, params={"73df3071_page": page}
+ )
soup = BeautifulSoup(response.content, "html.parser")
items = soup.select(self.item_selector)
if not items:
@@ -103,18 +112,22 @@ def items_list(self):
page += 1
# update the tqdm progress bar
- pbar.set_postfix_str(f"page {page}", refresh=True) # Set postfix to "page X"
+ pbar.set_postfix_str(
+ f"page {page}", refresh=True
+ ) # Set postfix to "page X"
pbar.update() # Here we increment the progress bar by 1
- logger.info('Got %s pages', len(articles))
+ logger.info("Got %s pages", len(articles))
return articles
def _get_published_date(self, contents):
- if date := contents.select_one('.c_banner__blog__card__meta'):
+ if date := contents.select_one(".c_banner__blog__card__meta"):
return super()._get_published_date(date.text)
- return ''
+ return ""
def extract_authors(self, article):
- if div := article.select_one('.c_cms_content__meta__wrapper div:-soup-contains("Authors") + div'):
- return [author.strip() for author in div.text.split(',')]
+ if div := article.select_one(
+ '.c_cms_content__meta__wrapper div:-soup-contains("Authors") + div'
+ ):
+ return [author.strip() for author in div.text.split(",")]
return []
diff --git a/align_data/sources/blogs/gwern_blog.py b/align_data/sources/blogs/gwern_blog.py
index 325bd7d3..9328d874 100644
--- a/align_data/sources/blogs/gwern_blog.py
+++ b/align_data/sources/blogs/gwern_blog.py
@@ -22,48 +22,50 @@ def get_item_key(self, item):
@property
def items_list(self):
return [
- 'https://www.gwern.net/Scaling-hypothesis.page',
- 'https://www.gwern.net/Tanks.page',
- 'https://www.gwern.net/Clippy.page',
- 'https://www.gwern.net/complexity.page',
- 'https://www.gwern.net/Tool-AI.page',
- 'https://www.gwern.net/Backstop.page',
- 'https://www.gwern.net/Hyperbolic-Time-Chamber.page'
+ "https://www.gwern.net/Scaling-hypothesis.page",
+ "https://www.gwern.net/Tanks.page",
+ "https://www.gwern.net/Clippy.page",
+ "https://www.gwern.net/complexity.page",
+ "https://www.gwern.net/Tool-AI.page",
+ "https://www.gwern.net/Backstop.page",
+ "https://www.gwern.net/Hyperbolic-Time-Chamber.page",
]
def process_entry(self, post_href):
article = self._get_article(post_href)
if article.status_code != 200:
- logger.error(f'Could not fetch {post_href}')
+ logger.error(f"Could not fetch {post_href}")
return None
# Some pages are returned as markdown, some as HTML, so handle both
- if 'text/html' in article.headers.get('Content-Type', ''):
+ if "text/html" in article.headers.get("Content-Type", ""):
return super().process_entry(post_href)
return self._process_markdown(post_href, article)
def _process_markdown(self, post_href, article):
- parts = article.text.split('...')
+ parts = article.text.split("...")
metadata = self._get_metadata(parts[0])
- text = self._extract_markdown('...'.join(parts[1:]))
-
- return self.make_data_entry({
- "source": self.name,
- "source_type": self.source_type,
- "url": post_href,
- "title": metadata.get('title'),
- "authors": self.authors,
- "date_published": self._get_published_date(metadata),
- "text": text,
- })
+ text = self._extract_markdown("...".join(parts[1:]))
+
+ return self.make_data_entry(
+ {
+ "source": self.name,
+ "source_type": self.source_type,
+ "url": post_href,
+ "title": metadata.get("title"),
+ "authors": self.authors,
+ "date_published": self._get_published_date(metadata),
+ "text": text,
+ }
+ )
@staticmethod
def _get_metadata(header):
def extract(item):
- parts = item.split(': ')
+ parts = item.split(": ")
if len(parts) > 1:
- return (parts[0].strip(), ': '.join(parts[1:]))
+ return (parts[0].strip(), ": ".join(parts[1:]))
return None
return dict(filter(None, map(extract, header.splitlines())))
@@ -74,17 +76,17 @@ def _get_article(self, url):
@staticmethod
def _get_title(contents):
- return contents.find('header').find('h1').text
+ return contents.find("header").find("h1").text
def _get_published_date(self, contents):
if isinstance(contents, dict):
- date_published = contents.get('modified') or contents.get('created')
+ date_published = contents.get("modified") or contents.get("created")
else:
date_published = (
- contents.select_one('.page-date-range .page-modified') or
- contents.select_one('.page-date-range .page-created')
+ contents.select_one(".page-date-range .page-modified")
+ or contents.select_one(".page-date-range .page-created")
).text.strip()
return super()._get_published_date(date_published)
def _get_text(self, contents):
- return self._extract_markdown(contents.select_one('div#markdownBody'))
+ return self._extract_markdown(contents.select_one("div#markdownBody"))
diff --git a/align_data/sources/blogs/medium_blog.py b/align_data/sources/blogs/medium_blog.py
index 9d57e2ab..5d80dfee 100644
--- a/align_data/sources/blogs/medium_blog.py
+++ b/align_data/sources/blogs/medium_blog.py
@@ -5,6 +5,7 @@
logger = logging.getLogger(__name__)
+
@dataclass
class MediumBlog(HTMLDataset):
"""
@@ -27,8 +28,8 @@ class MediumBlog(HTMLDataset):
"""
source_type = "medium_blog"
- ignored_selectors = ['div:first-child span']
+ ignored_selectors = ["div:first-child span"]
def _get_published_date(self, contents):
- possible_date_elements = contents.select('article div:first-child span')
+ possible_date_elements = contents.select("article div:first-child span")
return self._find_date(possible_date_elements)
diff --git a/align_data/sources/blogs/substack_blog.py b/align_data/sources/blogs/substack_blog.py
index ec6aa481..526bbb4f 100644
--- a/align_data/sources/blogs/substack_blog.py
+++ b/align_data/sources/blogs/substack_blog.py
@@ -3,8 +3,8 @@
class SubstackBlog(RSSDataset):
source_type = "substack"
- date_format = '%a, %d %b %Y %H:%M:%S %Z'
+ date_format = "%a, %d %b %Y %H:%M:%S %Z"
@property
def feed_url(self):
- return self.url + '/feed'
+ return self.url + "/feed"
diff --git a/align_data/sources/blogs/wp_blog.py b/align_data/sources/blogs/wp_blog.py
index 197cc078..c0132301 100644
--- a/align_data/sources/blogs/wp_blog.py
+++ b/align_data/sources/blogs/wp_blog.py
@@ -11,7 +11,7 @@
@dataclass
class WordpressBlog(RSSDataset):
- summary_key = 'summary'
+ summary_key = "summary"
@property
def feed_url(self):
@@ -31,19 +31,21 @@ def items_list(self):
logging.info(f"Fetching {paged_url}")
feed = feedparser.parse(paged_url)
- title = feed.get('feed', {}).get('title')
+ title = feed.get("feed", {}).get("title")
if not title or title == prev_title:
break
prev_title = feed["feed"]["title"]
page_number += 1
- for item in feed['entries']:
- self.items[item['link']] = item
+ for item in feed["entries"]:
+ self.items[item["link"]] = item
# update the tqdm progress bar
- pbar.set_postfix_str(f"page {page_number}", refresh=True) # Set postfix to "page X"
+ pbar.set_postfix_str(
+ f"page {page_number}", refresh=True
+ ) # Set postfix to "page X"
pbar.update() # Here we increment the progress bar by 1
- logger.info(f'Got {len(self.items)} pages')
+ logger.info(f"Got {len(self.items)} pages")
return list(self.items.keys())
diff --git a/align_data/sources/distill/__init__.py b/align_data/sources/distill/__init__.py
index 80e40f24..66684922 100644
--- a/align_data/sources/distill/__init__.py
+++ b/align_data/sources/distill/__init__.py
@@ -3,7 +3,7 @@
DISTILL_REGISTRY = [
Distill(
- name = "distill",
- url='https://distill.pub',
+ name="distill",
+ url="https://distill.pub",
),
]
diff --git a/align_data/sources/distill/distill.py b/align_data/sources/distill/distill.py
index 4a9ea388..f54fb554 100644
--- a/align_data/sources/distill/distill.py
+++ b/align_data/sources/distill/distill.py
@@ -5,29 +5,30 @@
@dataclass
class Distill(RSSDataset):
- source_type = 'html'
- done_key = 'url'
- summary_key = 'summary'
+ source_type = "html"
+ done_key = "url"
+ summary_key = "summary"
def extract_authors(self, item):
- return [a.text for a in item['soup'].select('.authors-affiliations p.author a')]
+ return [a.text for a in item["soup"].select(".authors-affiliations p.author a")]
def _get_text(self, item):
- article = item['soup'].find('d-article') or item['soup'].find('dt-article')
+ article = item["soup"].find("d-article") or item["soup"].find("dt-article")
return self._extract_markdown(article)
def _extra_values(self, item):
- soup = item['soup']
+ soup = item["soup"]
- doi_elem = soup.find('h3', string='DOI')
- doi_elem = doi_elem and doi_elem.find_next_sibling('p')
+ doi_elem = soup.find("h3", string="DOI")
+ doi_elem = doi_elem and doi_elem.find_next_sibling("p")
return {
- 'doi': doi_elem and doi_elem.text,
- 'summary': item['summary'],
- 'journal_ref': 'distill-pub',
- 'bibliography': [
- {'title': el.find('span').text, 'link': el.find('a').get('href')}
- for el in soup.select('.references li') if el.find('a')
- ]
+ "doi": doi_elem and doi_elem.text,
+ "summary": item["summary"],
+ "journal_ref": "distill-pub",
+ "bibliography": [
+ {"title": el.find("span").text, "link": el.find("a").get("href")}
+ for el in soup.select(".references li")
+ if el.find("a")
+ ],
}
diff --git a/align_data/sources/ebooks/__init__.py b/align_data/sources/ebooks/__init__.py
index 7334c938..0055f5e0 100644
--- a/align_data/sources/ebooks/__init__.py
+++ b/align_data/sources/ebooks/__init__.py
@@ -2,7 +2,6 @@
EBOOK_REGISTRY = [
AgentModels(
- name='agentmodels',
- repo='https://github.com/agentmodels/agentmodels.org.git'
+ name="agentmodels", repo="https://github.com/agentmodels/agentmodels.org.git"
),
]
diff --git a/align_data/sources/ebooks/agentmodels.py b/align_data/sources/ebooks/agentmodels.py
index ee9593bb..cfd68a79 100644
--- a/align_data/sources/ebooks/agentmodels.py
+++ b/align_data/sources/ebooks/agentmodels.py
@@ -6,6 +6,7 @@
logger = logging.getLogger(__name__)
+
@dataclass
class AgentModels(AlignmentDataset):
"""
@@ -13,30 +14,39 @@ class AgentModels(AlignmentDataset):
John Salvatier, and Daniel Filan as .md from GitHub
"""
- repo: str = 'https://github.com/agentmodels/agentmodels.org.git'
+ repo: str = "https://github.com/agentmodels/agentmodels.org.git"
done_key = "filename"
def setup(self):
super().setup()
- self.base_dir = self.raw_data_path / 'agentmodels.org'
- if not self.base_dir.exists() or not list(self.base_dir.glob('*')):
+ self.base_dir = self.raw_data_path / "agentmodels.org"
+ if not self.base_dir.exists() or not list(self.base_dir.glob("*")):
logger.info("Cloning repo")
Repo.clone_from(self.repo, self.base_dir)
self.repository = Repo(self.base_dir)
- self.files_path = self.base_dir / 'chapters'
+ self.files_path = self.base_dir / "chapters"
def _get_published_date(self, filename):
- last_commit = next(self.repository.iter_commits(paths=f'chapters/{filename.name}'))
+ last_commit = next(
+ self.repository.iter_commits(paths=f"chapters/{filename.name}")
+ )
return last_commit.committed_datetime.astimezone(timezone.utc)
def process_entry(self, filename):
- return self.make_data_entry({
- 'source': self.name,
- 'source_type': 'markdown',
- 'authors': ['Owain Evans', 'Andreas Stuhlmüller', 'John Salvatier', 'Daniel Filan'],
- 'date_published': self._get_published_date(filename),
- 'title': 'Modeling Agents with Probabilistic Programs',
- 'url': f'https://agentmodels.org/chapters/{filename.stem}.html',
- 'filename': filename.name,
- 'text': filename.read_text(encoding='utf-8'),
- })
+ return self.make_data_entry(
+ {
+ "source": self.name,
+ "source_type": "markdown",
+ "authors": [
+ "Owain Evans",
+ "Andreas Stuhlmüller",
+ "John Salvatier",
+ "Daniel Filan",
+ ],
+ "date_published": self._get_published_date(filename),
+ "title": "Modeling Agents with Probabilistic Programs",
+ "url": f"https://agentmodels.org/chapters/{filename.stem}.html",
+ "filename": filename.name,
+ "text": filename.read_text(encoding="utf-8"),
+ }
+ )
diff --git a/align_data/sources/greaterwrong/__init__.py b/align_data/sources/greaterwrong/__init__.py
index 8f4a079c..300fd00b 100644
--- a/align_data/sources/greaterwrong/__init__.py
+++ b/align_data/sources/greaterwrong/__init__.py
@@ -3,23 +3,23 @@
GREATERWRONG_REGISTRY = [
GreaterWrong(
name="lesswrong",
- base_url='https://www.lesswrong.com',
+ base_url="https://www.lesswrong.com",
start_year=2005,
min_karma=1,
af=False,
),
GreaterWrong(
name="alignmentforum",
- base_url='https://www.alignmentforum.org',
+ base_url="https://www.alignmentforum.org",
start_year=2009,
min_karma=1,
af=True,
),
GreaterWrong(
name="eaforum",
- base_url='https://forum.effectivealtruism.org',
+ base_url="https://forum.effectivealtruism.org",
start_year=2011,
min_karma=1,
af=False,
- )
+ ),
]
diff --git a/align_data/sources/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py
index f00e9d06..f746e552 100644
--- a/align_data/sources/greaterwrong/greaterwrong.py
+++ b/align_data/sources/greaterwrong/greaterwrong.py
@@ -16,33 +16,37 @@
def fetch_LW_tags(url):
res = requests.get(
- url + '/tag/ai',
- headers={'User-Agent': 'Mozilla /5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0'},
+ url + "/tag/ai",
+ headers={
+ "User-Agent": "Mozilla /5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0"
+ },
)
soup = BeautifulSoup(res.content, "html.parser")
- tags = soup.select('div.TagPage-description .table a')
- return {a.text.strip() for a in tags if '/tag/' in a.get('href')}
+ tags = soup.select("div.TagPage-description .table a")
+ return {a.text.strip() for a in tags if "/tag/" in a.get("href")}
def fetch_ea_forum_topics(url):
- res = requests.get(url + '/topics/ai-safety')
+ res = requests.get(url + "/topics/ai-safety")
soup = BeautifulSoup(res.content, "html.parser")
- links = soup.select('div.SidebarSubtagsBox-root a')
- return {a.text.strip() for a in links if '/topics/' in a.get('href', '')}
+ links = soup.select("div.SidebarSubtagsBox-root a")
+ return {a.text.strip() for a in links if "/topics/" in a.get("href", "")}
def get_allowed_tags(url, name):
- if name == 'alignmentforum':
+ if name == "alignmentforum":
return set()
try:
- if name == 'lesswrong':
+ if name == "lesswrong":
return fetch_LW_tags(url)
- if name == 'eaforum':
+ if name == "eaforum":
return fetch_ea_forum_topics(url)
except Exception:
- raise ValueError('Could not fetch tags! Please retry')
+ raise ValueError("Could not fetch tags! Please retry")
- raise ValueError(f'Could not fetch tags for unknown datasource: "{name}". Must be one of alignmentforum|lesswrong|eaforum')
+ raise ValueError(
+ f'Could not fetch tags for unknown datasource: "{name}". Must be one of alignmentforum|lesswrong|eaforum'
+ )
@dataclass
@@ -61,8 +65,8 @@ class GreaterWrong(AlignmentDataset):
"""Whether alignment forum posts should be returned"""
limit = 50
- COOLDOWN_TIME : float = 0.5
- summary_key: str = 'summary'
+ COOLDOWN_TIME: float = 0.5
+ summary_key: str = "summary"
done_key = "url"
lazy_eval = True
@@ -73,26 +77,30 @@ def setup(self):
self.ai_tags = get_allowed_tags(self.base_url, self.name)
def tags_ok(self, post):
- return not self.ai_tags or {t['name'] for t in post['tags'] if t.get('name')} & self.ai_tags
+ return (
+ not self.ai_tags
+ or {t["name"] for t in post["tags"] if t.get("name")} & self.ai_tags
+ )
def get_item_key(self, item):
- return item['pageUrl']
+ return item["pageUrl"]
def _get_published_date(self, item):
- return super()._get_published_date(item.get('postedAt'))
+ return super()._get_published_date(item.get("postedAt"))
def make_query(self, after: str):
- return """{
+ return (
+ """{
posts(input: {
terms: {
excludeEvents: true
view: "old"
- """ \
- f" af: {self.af}\n" \
- f" limit: {self.limit}\n" \
- f" karmaThreshold: {self.min_karma}\n" \
- f' after: "{after}"\n' \
- """ filter: "tagged"
+ """
+ f" af: {self.af}\n"
+ f" limit: {self.limit}\n"
+ f" karmaThreshold: {self.min_karma}\n"
+ f' after: "{after}"\n'
+ """ filter: "tagged"
}
}) {
totalCount
@@ -123,60 +131,65 @@ def make_query(self, after: str):
}
}
}"""
+ )
def fetch_posts(self, query: str):
res = requests.post(
- f'{self.base_url}/graphql',
+ f"{self.base_url}/graphql",
# The GraphQL endpoint returns a 403 if the user agent isn't set... Makes sense, but is annoying
- headers={'User-Agent': 'Mozilla /5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0'},
- json={'query': query}
+ headers={
+ "User-Agent": "Mozilla /5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0"
+ },
+ json={"query": query},
)
- return res.json()['data']['posts']
+ return res.json()["data"]["posts"]
@property
def last_date_published(self):
try:
prev_item = next(self.read_entries(sort_by=Article.date_published.desc()))
if prev_item and prev_item.date_published:
- return prev_item.date_published.isoformat() + 'Z'
+ return prev_item.date_published.isoformat() + "Z"
except StopIteration:
pass
- return datetime(self.start_year, 1, 1).isoformat() + 'Z'
+ return datetime(self.start_year, 1, 1).isoformat() + "Z"
@property
def items_list(self):
next_date = self.last_date_published
- logger.info('Starting from %s', next_date)
+ logger.info("Starting from %s", next_date)
while next_date:
posts = self.fetch_posts(self.make_query(next_date))
- if not posts['results']:
+ if not posts["results"]:
return
- for post in posts['results']:
- if post['htmlBody'] and self.tags_ok(post):
+ for post in posts["results"]:
+ if post["htmlBody"] and self.tags_ok(post):
yield post
- next_date = posts['results'][-1]['postedAt']
+ next_date = posts["results"][-1]["postedAt"]
time.sleep(self.COOLDOWN)
def process_entry(self, item):
- authors = item['coauthors']
- if item['user']:
- authors = [item['user']] + authors
- authors = [a['displayName'] for a in authors]
- return self.make_data_entry({
- 'title': item['title'],
- 'text': markdownify(item['htmlBody']).strip(),
- 'url': item['pageUrl'],
- 'date_published': self._get_published_date(item),
- 'modified_at': item['modifiedAt'],
- "source": self.name,
- "source_type": "GreaterWrong",
- 'votes': item['voteCount'],
- 'karma': item['baseScore'],
- 'tags': [t['name'] for t in item['tags']],
- 'words': item['wordCount'],
- 'comment_count': item['commentCount'],
- # Some posts don't have authors, for some reaason
- 'authors': authors,
- })
+ authors = item["coauthors"]
+ if item["user"]:
+ authors = [item["user"]] + authors
+ authors = [a["displayName"] for a in authors]
+ return self.make_data_entry(
+ {
+ "title": item["title"],
+ "text": markdownify(item["htmlBody"]).strip(),
+ "url": item["pageUrl"],
+ "date_published": self._get_published_date(item),
+ "modified_at": item["modifiedAt"],
+ "source": self.name,
+ "source_type": "GreaterWrong",
+ "votes": item["voteCount"],
+ "karma": item["baseScore"],
+ "tags": [t["name"] for t in item["tags"]],
+ "words": item["wordCount"],
+ "comment_count": item["commentCount"],
+                # Some posts don't have authors, for some reason
+ "authors": authors,
+ }
+ )
diff --git a/align_data/sources/stampy/__init__.py b/align_data/sources/stampy/__init__.py
index a2a31645..cb0ce2d8 100644
--- a/align_data/sources/stampy/__init__.py
+++ b/align_data/sources/stampy/__init__.py
@@ -1,5 +1,5 @@
from .stampy import Stampy
STAMPY_REGISTRY = [
- Stampy(name='aisafety.info', id_fields=['url']),
+ Stampy(name="aisafety.info", id_fields=["url"]),
]
diff --git a/align_data/sources/stampy/stampy.py b/align_data/sources/stampy/stampy.py
index 025e4658..88a7149e 100644
--- a/align_data/sources/stampy/stampy.py
+++ b/align_data/sources/stampy/stampy.py
@@ -16,12 +16,13 @@
@dataclass
class Stampy(AlignmentDataset):
-
done_key = "title"
def setup(self):
if not CODA_TOKEN:
- print(f'No CODA_TOKEN found! Please provide a valid Read token for the {CODA_DOC_ID} table')
+ print(
+ f"No CODA_TOKEN found! Please provide a valid Read token for the {CODA_DOC_ID} table"
+ )
sys.exit(1)
super().setup()
@@ -30,34 +31,40 @@ def setup(self):
def items_list(self):
coda = Coda(CODA_TOKEN)
doc = Document(CODA_DOC_ID, coda=coda)
- logger.info('Fetching table: %s', CODA_DOC_ID)
+ logger.info("Fetching table: %s", CODA_DOC_ID)
table = doc.get_table(ON_SITE_TABLE)
- return table.to_dict() # a list of dicts
+ return table.to_dict() # a list of dicts
def get_item_key(self, entry):
- return html.unescape(entry['Question'])
+ return html.unescape(entry["Question"])
def _get_published_date(self, entry):
- date_published = entry['Doc Last Edited']
+ date_published = entry["Doc Last Edited"]
return super()._get_published_date(date_published)
def process_entry(self, entry):
def clean_text(text):
text = html.unescape(text)
- return re.sub(r'\(/\?state=(\w+)\)', r'(http://aisafety.info?state=\1)', text)
+ return re.sub(
+ r"\(/\?state=(\w+)\)", r"(http://aisafety.info?state=\1)", text
+ )
- question = clean_text(entry['Question']) # raise an error if the entry has no question
- answer = clean_text(entry['Rich Text'])
- url = 'https://aisafety.info?state=' + entry['UI ID']
+ question = clean_text(
+ entry["Question"]
+ ) # raise an error if the entry has no question
+ answer = clean_text(entry["Rich Text"])
+ url = "https://aisafety.info?state=" + entry["UI ID"]
logger.info(f"Processing {question}")
- return self.make_data_entry({
- "source": self.name,
- "source_type": "markdown",
- "url": url,
- "title": question,
- "authors": ['Stampy aisafety.info'],
- "date_published": self._get_published_date(entry),
- "text": answer,
- })
+ return self.make_data_entry(
+ {
+ "source": self.name,
+ "source_type": "markdown",
+ "url": url,
+ "title": question,
+ "authors": ["Stampy aisafety.info"],
+ "date_published": self._get_published_date(entry),
+ "text": answer,
+ }
+ )
diff --git a/align_data/sources/youtube/__init__.py b/align_data/sources/youtube/__init__.py
index fd393cc5..06c8defe 100644
--- a/align_data/sources/youtube/__init__.py
+++ b/align_data/sources/youtube/__init__.py
@@ -1,39 +1,42 @@
-from align_data.sources.youtube.youtube import YouTubeChannelDataset, YouTubePlaylistDataset
+from align_data.sources.youtube.youtube import (
+ YouTubeChannelDataset,
+ YouTubePlaylistDataset,
+)
YOUTUBE_REGISTRY = [
YouTubeChannelDataset(
- name='rob_miles_ai_safety',
- channel_id='UCLB7AzTwc6VFZrBsO2ucBMg',
- authors=['Rob Miles'],
+ name="rob_miles_ai_safety",
+ channel_id="UCLB7AzTwc6VFZrBsO2ucBMg",
+ authors=["Rob Miles"],
),
YouTubeChannelDataset(
- name='ai_safety_talks',
- channel_id='UCXowyqjXvFS-tMKF1GwhpkA',
- authors=['Evan Hubinger'],
+ name="ai_safety_talks",
+ channel_id="UCXowyqjXvFS-tMKF1GwhpkA",
+ authors=["Evan Hubinger"],
),
YouTubeChannelDataset(
- name='ai_safety_reading_group',
- channel_id='UC-C23F-9rK2gtRiJZMWsTzQ',
+ name="ai_safety_reading_group",
+ channel_id="UC-C23F-9rK2gtRiJZMWsTzQ",
authors=[],
),
YouTubeChannelDataset(
- name='ai_tech_tu_delft',
- channel_id='UCPK-Ell2WYxyfP5UYzRzjAA',
+ name="ai_tech_tu_delft",
+ channel_id="UCPK-Ell2WYxyfP5UYzRzjAA",
authors=[],
),
YouTubeChannelDataset(
- name='ai_explained',
- channel_id='UCNJ1Ymd5yFuUPtn21xtRbbw',
+ name="ai_explained",
+ channel_id="UCNJ1Ymd5yFuUPtn21xtRbbw",
authors=[],
),
YouTubePlaylistDataset(
- name='ai_alignment_playlist',
+ name="ai_alignment_playlist",
playlist_ids=[
- 'PLqYmG7hTraZCRwoyGxvQkqVrZgDQi4m-5',
- 'PLqYmG7hTraZBiUr6_Qf8YTS2Oqy3OGZEj',
- 'PLAPVC5uNprwY0q4_nyeeHqIT07wZqwjGO',
- 'PLCRVRLd2RhZTpdUdEzJjo3qhmX3y3skWA',
- 'PLTYHZYmxohXpn5uf8JZ2OouB1PsDJAk-x',
- ]
+ "PLqYmG7hTraZCRwoyGxvQkqVrZgDQi4m-5",
+ "PLqYmG7hTraZBiUr6_Qf8YTS2Oqy3OGZEj",
+ "PLAPVC5uNprwY0q4_nyeeHqIT07wZqwjGO",
+ "PLCRVRLd2RhZTpdUdEzJjo3qhmX3y3skWA",
+ "PLTYHZYmxohXpn5uf8JZ2OouB1PsDJAk-x",
+ ],
),
]
diff --git a/align_data/sources/youtube/youtube.py b/align_data/sources/youtube/youtube.py
index e5912dc2..8670b691 100644
--- a/align_data/sources/youtube/youtube.py
+++ b/align_data/sources/youtube/youtube.py
@@ -5,7 +5,11 @@
from googleapiclient.discovery import build
from youtube_transcript_api import YouTubeTranscriptApi
-from youtube_transcript_api._errors import NoTranscriptFound, VideoUnavailable, TranscriptsDisabled
+from youtube_transcript_api._errors import (
+ NoTranscriptFound,
+ VideoUnavailable,
+ TranscriptsDisabled,
+)
from align_data.settings import YOUTUBE_API_KEY
from align_data.common.alignment_dataset import AlignmentDataset
@@ -15,8 +19,7 @@
class YouTubeDataset(AlignmentDataset):
-
- done_key = 'url'
+ done_key = "url"
batch_size = 1
# COOLDOWN = 2
authors = None
@@ -25,34 +28,34 @@ class YouTubeDataset(AlignmentDataset):
def setup(self):
super().setup()
if not YOUTUBE_API_KEY:
- raise ValueError('No YOUTUBE_API_KEY provided!')
- self.youtube = build('youtube', 'v3', developerKey=YOUTUBE_API_KEY)
+ raise ValueError("No YOUTUBE_API_KEY provided!")
+ self.youtube = build("youtube", "v3", developerKey=YOUTUBE_API_KEY)
def next_page(self, collection_id, next_page_token):
- return {'items': []}
+ return {"items": []}
@staticmethod
def _get_id(item):
- if item.get('kind') == 'youtube#searchResult':
- resource = item['id']
- elif item.get('kind') == 'youtube#playlistItem':
- resource = item['snippet']['resourceId']
+ if item.get("kind") == "youtube#searchResult":
+ resource = item["id"]
+ elif item.get("kind") == "youtube#playlistItem":
+ resource = item["snippet"]["resourceId"]
else:
return None
- if resource['kind'] == 'youtube#video':
- return resource['videoId']
+ if resource["kind"] == "youtube#video":
+ return resource["videoId"]
def fetch_videos(self, collection_id):
next_page_token = None
while True:
videos_response = self.next_page(collection_id, next_page_token)
- for item in videos_response.get('items'):
+ for item in videos_response.get("items"):
if self._get_id(item):
yield item
- next_page_token = videos_response.get('nextPageToken')
+ next_page_token = videos_response.get("nextPageToken")
if not next_page_token:
return
@@ -66,23 +69,29 @@ def items_list(self):
def get_item_key(self, item):
video_id = self._get_id(item)
- return f'https://www.youtube.com/watch?v={video_id}'
+ return f"https://www.youtube.com/watch?v={video_id}"
def _get_contents(self, video):
video_id = self._get_id(video)
try:
- transcript = YouTubeTranscriptApi.list_transcripts(video_id).find_transcript(['en', 'en-GB']).fetch()
- return '\n'.join([i['text'] for i in transcript])
+ transcript = (
+ YouTubeTranscriptApi.list_transcripts(video_id)
+ .find_transcript(["en", "en-GB"])
+ .fetch()
+ )
+ return "\n".join([i["text"] for i in transcript])
except (NoTranscriptFound, VideoUnavailable):
return None
except TranscriptsDisabled:
- logger.error(f'Transcripts disabled for https://www.youtube.com/watch?v={video_id} - skipping')
+ logger.error(
+ f"Transcripts disabled for https://www.youtube.com/watch?v={video_id} - skipping"
+ )
return None
def extract_authors(self, video):
if self.authors:
return self.authors
- return [video['snippet']['channelTitle'].strip()]
+ return [video["snippet"]["channelTitle"].strip()]
def process_entry(self, video):
video_url = self.get_item_key(video)
@@ -91,20 +100,21 @@ def process_entry(self, video):
if not contents:
return None
- return self.make_data_entry({
- "text": contents,
- "url": video_url,
- "title": video['snippet']['title'],
- "source": self.name,
- "source_type": "youtube",
- "date_published": self._get_published_date(video),
- "authors": self.extract_authors(video),
- })
+ return self.make_data_entry(
+ {
+ "text": contents,
+ "url": video_url,
+ "title": video["snippet"]["title"],
+ "source": self.name,
+ "source_type": "youtube",
+ "date_published": self._get_published_date(video),
+ "authors": self.extract_authors(video),
+ }
+ )
@dataclass
class YouTubeChannelDataset(YouTubeDataset):
-
channel_id: str
authors: List[str]
@@ -113,20 +123,23 @@ def collection_ids(self):
return [self.channel_id]
def next_page(self, collection_id, next_page_token):
- return self.youtube.search().list(
- part='snippet',
- channelId=collection_id,
- maxResults=50,
- pageToken=next_page_token
- ).execute()
+ return (
+ self.youtube.search()
+ .list(
+ part="snippet",
+ channelId=collection_id,
+ maxResults=50,
+ pageToken=next_page_token,
+ )
+ .execute()
+ )
def _get_published_date(self, video):
- return super()._get_published_date(video['snippet']['publishTime'])
+ return super()._get_published_date(video["snippet"]["publishTime"])
@dataclass
class YouTubePlaylistDataset(YouTubeDataset):
-
playlist_ids: str
@property
@@ -134,12 +147,16 @@ def collection_ids(self):
return self.playlist_ids
def next_page(self, collection_id, next_page_token):
- return self.youtube.playlistItems().list(
- part='snippet',
- playlistId=collection_id,
- maxResults=50,
- pageToken=next_page_token,
- ).execute()
+ return (
+ self.youtube.playlistItems()
+ .list(
+ part="snippet",
+ playlistId=collection_id,
+ maxResults=50,
+ pageToken=next_page_token,
+ )
+ .execute()
+ )
def _get_published_date(self, video):
- return super()._get_published_date(video['snippet']['publishedAt'])
+ return super()._get_published_date(video["snippet"]["publishedAt"])
diff --git a/main.py b/main.py
index 4ea6bf3d..ae11b641 100644
--- a/main.py
+++ b/main.py
@@ -10,7 +10,9 @@
from align_data.sources.articles.articles import update_new_items, check_new_articles
from align_data.pinecone.update_pinecone import PineconeUpdater
from align_data.settings import (
- METADATA_OUTPUT_SPREADSHEET, METADATA_SOURCE_SHEET, METADATA_SOURCE_SPREADSHEET
+ METADATA_OUTPUT_SPREADSHEET,
+ METADATA_SOURCE_SHEET,
+ METADATA_SOURCE_SPREADSHEET,
)
@@ -19,7 +21,6 @@
@dataclass
class AlignmentDataset:
-
out_path: str = "data"
"""The path to the directory where the data will be downloaded, defaults to data"""
@@ -34,13 +35,13 @@ def fetch(self, *names) -> None:
:param str name: The name of the dataset to fetch
:return: The path to the file that was written to.
"""
- if names == ('all',):
+ if names == ("all",):
names = ALL_DATASETS
missing = {name for name in names if name not in ALL_DATASETS}
assert not missing, f"{missing} are not valid dataset names"
for name in names:
dataset = get_dataset(name)
-
+
dataset.add_entries(dataset.fetch_entries())
def fetch_all(self, *skip) -> None:
@@ -62,7 +63,7 @@ def generate_jsonl_files(self, *names):
:param List[str] names: The names of the datasets to generate
"""
- if names == ('all',):
+ if names == ("all",):
names = ALL_DATASETS
missing = {name for name in names if name not in ALL_DATASETS}
assert not missing, f"{missing} are not valid dataset names"
@@ -75,12 +76,16 @@ def count_tokens(self, merged_dataset_path: str) -> None:
This function counts the number of tokens, words, and characters in the dataset
:return: None
"""
- assert os.path.exists(merged_dataset_path), "The path to the merged dataset does not exist"
+ assert os.path.exists(
+ merged_dataset_path
+ ), "The path to the merged dataset does not exist"
count_token(merged_dataset_path)
def update_metadata(
- self, source_spreadsheet=METADATA_SOURCE_SPREADSHEET,
- source_sheet=METADATA_SOURCE_SHEET, output_spreadsheet=METADATA_OUTPUT_SPREADSHEET
+ self,
+ source_spreadsheet=METADATA_SOURCE_SPREADSHEET,
+ source_sheet=METADATA_SOURCE_SHEET,
+ output_spreadsheet=METADATA_OUTPUT_SPREADSHEET,
):
"""Go through all unprocessed items from the source worksheet, updating the appropriate metadata in the output one.
@@ -90,7 +95,11 @@ def update_metadata(
"""
return update_new_items(source_spreadsheet, source_sheet, output_spreadsheet)
- def fetch_new_articles(self, source_spreadsheet=METADATA_SOURCE_SPREADSHEET, source_sheet=METADATA_SOURCE_SHEET):
+ def fetch_new_articles(
+ self,
+ source_spreadsheet=METADATA_SOURCE_SPREADSHEET,
+ source_sheet=METADATA_SOURCE_SHEET,
+ ):
"""Look for unseen articles in the special indices, adding any that are found to the provided spreadsheet.
:param str source_spreadsheet: The id of the google docs spreadsheet containing the items to be processed
@@ -102,12 +111,12 @@ def pinecone_update(self, *names) -> None:
"""
This function updates the Pinecone vector DB.
"""
- if names == ('all',):
+ if names == ("all",):
names = ALL_DATASETS
missing = {name for name in names if name not in ALL_DATASETS}
assert not missing, f"{missing} are not valid dataset names"
PineconeUpdater().update(names)
-
+
def pinecone_update_all(self, *skip) -> None:
"""
This function updates the Pinecone vector DB.
@@ -117,4 +126,4 @@ def pinecone_update_all(self, *skip) -> None:
if __name__ == "__main__":
- fire.Fire(AlignmentDataset)
\ No newline at end of file
+ fire.Fire(AlignmentDataset)
diff --git a/migrations/env.py b/migrations/env.py
index 838bfb97..1d07d5d2 100644
--- a/migrations/env.py
+++ b/migrations/env.py
@@ -15,13 +15,15 @@
fileConfig(config.config_file_name)
from align_data.settings import DB_CONNECTION_URI
-config.set_main_option('sqlalchemy.url', DB_CONNECTION_URI)
+
+config.set_main_option("sqlalchemy.url", DB_CONNECTION_URI)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
from align_data.db.models import Base
+
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
@@ -68,9 +70,7 @@ def run_migrations_online() -> None:
)
with connectable.connect() as connection:
- context.configure(
- connection=connection, target_metadata=target_metadata
- )
+ context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
diff --git a/migrations/versions/0a0041c28458_confidence_column.py b/migrations/versions/0a0041c28458_confidence_column.py
index 1c39b30f..53c6eeb9 100644
--- a/migrations/versions/0a0041c28458_confidence_column.py
+++ b/migrations/versions/0a0041c28458_confidence_column.py
@@ -10,15 +10,15 @@
# revision identifiers, used by Alembic.
-revision = '0a0041c28458'
-down_revision = '983b5bdef5f6'
+revision = "0a0041c28458"
+down_revision = "983b5bdef5f6"
branch_labels = None
depends_on = None
def upgrade() -> None:
- op.add_column('articles', sa.Column('confidence', sa.Float(), nullable=True))
+ op.add_column("articles", sa.Column("confidence", sa.Float(), nullable=True))
def downgrade() -> None:
- op.drop_column('articles', 'confidence')
+ op.drop_column("articles", "confidence")
diff --git a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py
index 74792469..7a8485fe 100644
--- a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py
+++ b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py
@@ -10,19 +10,21 @@
# revision identifiers, used by Alembic.
-revision = '59ac3cb671e3'
-down_revision = '0a0041c28458'
+revision = "59ac3cb671e3"
+down_revision = "0a0041c28458"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column('articles', sa.Column('pinecone_update_required', sa.Boolean(), nullable=False))
+ op.add_column(
+ "articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False)
+ )
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_column('articles', 'pinecone_update_required')
+ op.drop_column("articles", "pinecone_update_required")
# ### end Alembic commands ###
diff --git a/migrations/versions/983b5bdef5f6_initial_structure.py b/migrations/versions/983b5bdef5f6_initial_structure.py
index ff1ef321..947a37c1 100644
--- a/migrations/versions/983b5bdef5f6_initial_structure.py
+++ b/migrations/versions/983b5bdef5f6_initial_structure.py
@@ -10,7 +10,7 @@
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
-revision = '983b5bdef5f6'
+revision = "983b5bdef5f6"
down_revision = None
branch_labels = None
depends_on = None
@@ -18,33 +18,36 @@
def upgrade() -> None:
op.create_table(
- 'articles',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('hash_id', sa.String(length=32), nullable=False),
- sa.Column('title', sa.String(length=1028), nullable=True),
- sa.Column('url', sa.String(length=1028), nullable=True),
- sa.Column('source', sa.String(length=128), nullable=True),
- sa.Column('source_type', sa.String(length=128), nullable=True),
- sa.Column('authors', sa.String(length=1024), nullable=False),
- sa.Column('text', mysql.LONGTEXT(), nullable=True),
- sa.Column('date_published', sa.DateTime(), nullable=True),
- sa.Column('metadata', sa.JSON(), nullable=True),
- sa.Column('date_created', sa.DateTime(), nullable=False),
- sa.Column('date_updated', sa.DateTime(), nullable=True),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('hash_id')
+ "articles",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("hash_id", sa.String(length=32), nullable=False),
+ sa.Column("title", sa.String(length=1028), nullable=True),
+ sa.Column("url", sa.String(length=1028), nullable=True),
+ sa.Column("source", sa.String(length=128), nullable=True),
+ sa.Column("source_type", sa.String(length=128), nullable=True),
+ sa.Column("authors", sa.String(length=1024), nullable=False),
+ sa.Column("text", mysql.LONGTEXT(), nullable=True),
+ sa.Column("date_published", sa.DateTime(), nullable=True),
+ sa.Column("metadata", sa.JSON(), nullable=True),
+ sa.Column("date_created", sa.DateTime(), nullable=False),
+ sa.Column("date_updated", sa.DateTime(), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("hash_id"),
)
op.create_table(
- 'summaries',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('text', sa.Text(), nullable=False),
- sa.Column('source', sa.String(length=256), nullable=True),
- sa.Column('article_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['article_id'], ['articles.id'], ),
- sa.PrimaryKeyConstraint('id')
+ "summaries",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("text", sa.Text(), nullable=False),
+ sa.Column("source", sa.String(length=256), nullable=True),
+ sa.Column("article_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["article_id"],
+ ["articles.id"],
+ ),
+ sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
- op.drop_table('summaries')
- op.drop_table('articles')
+ op.drop_table("summaries")
+ op.drop_table("articles")
diff --git a/setup.py b/setup.py
index a1106c58..871c1d32 100644
--- a/setup.py
+++ b/setup.py
@@ -4,13 +4,13 @@
long_description = fh.read()
setuptools.setup(
- name='align_data',
- version='0.0.1',
+ name="align_data",
+ version="0.0.1",
description="A framework for constructing a dataset for alignment research",
long_description=long_description,
long_description_content_type="text/markdown",
packages=setuptools.find_packages(),
- python_requires='>=3.6',
+ python_requires=">=3.6",
install_requires=[
"bs4==0.0.1",
"python-dateutil==2.8.2",
@@ -21,5 +21,5 @@
"GitPython",
"gdown",
"pypandoc",
- ]
+ ],
)
diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py
index 48340000..36c47071 100644
--- a/tests/align_data/articles/test_datasets.py
+++ b/tests/align_data/articles/test_datasets.py
@@ -3,245 +3,288 @@
import pandas as pd
import pytest
from align_data.sources.articles.datasets import (
- EbookArticles, DocArticles, HTMLArticles, MarkdownArticles, PDFArticles, SpreadsheetDataset, XMLArticles
+ EbookArticles,
+ DocArticles,
+ HTMLArticles,
+ MarkdownArticles,
+ PDFArticles,
+ SpreadsheetDataset,
+ XMLArticles,
)
@pytest.fixture
def articles():
- source_type = 'something'
+ source_type = "something"
articles = [
{
- 'source_url': f'http://example.com/source_url/{i}',
- 'url': f'http://example.com/item/{i}',
- 'title': f'article no {i}',
- 'source_type': source_type,
- 'date_published': f'2023/01/0{i + 1} 12:32:11',
- 'authors': f'John Snow, mr Blobby',
- 'summary': f'the summary of article {i}',
- 'file_id': str(i),
- } for i in range(5)
+ "source_url": f"http://example.com/source_url/{i}",
+ "url": f"http://example.com/item/{i}",
+ "title": f"article no {i}",
+ "source_type": source_type,
+ "date_published": f"2023/01/0{i + 1} 12:32:11",
+ "authors": f"John Snow, mr Blobby",
+ "summary": f"the summary of article {i}",
+ "file_id": str(i),
+ }
+ for i in range(5)
]
return pd.DataFrame(articles)
def test_spreadsheet_dataset_items_list(articles):
- dataset = SpreadsheetDataset(name='bla', spreadsheet_id='123', sheet_id='456')
+ dataset = SpreadsheetDataset(name="bla", spreadsheet_id="123", sheet_id="456")
df = pd.concat(
- [articles, pd.DataFrame([{'title': None}, {'summary': 'bla'}])],
- ignore_index=True
+ [articles, pd.DataFrame([{"title": None}, {"summary": "bla"}])],
+ ignore_index=True,
)
- with patch('pandas.read_csv', return_value=df):
+ with patch("pandas.read_csv", return_value=df):
assert list(dataset.items_list) == list(pd.DataFrame(articles).itertuples())
def test_spreadsheet_dataset_get_item_key():
- dataset = SpreadsheetDataset(name='bla', spreadsheet_id='123', sheet_id='456')
- assert dataset.get_item_key(Mock(bla='ble', title='the key')) == 'the key'
-
-
-@pytest.mark.parametrize('authors, expected', (
- ('', []),
- (' \n \n \t', []),
- ('John Snow', ['John Snow']),
- ('John Snow, mr. Blobby', ['John Snow', 'mr. Blobby']),
-))
+ dataset = SpreadsheetDataset(name="bla", spreadsheet_id="123", sheet_id="456")
+ assert dataset.get_item_key(Mock(bla="ble", title="the key")) == "the key"
+
+
+@pytest.mark.parametrize(
+ "authors, expected",
+ (
+ ("", []),
+ (" \n \n \t", []),
+ ("John Snow", ["John Snow"]),
+ ("John Snow, mr. Blobby", ["John Snow", "mr. Blobby"]),
+ ),
+)
def test_spreadsheet_dataset_extract_authors(authors, expected):
- dataset = SpreadsheetDataset(name='bla', spreadsheet_id='123', sheet_id='456')
+ dataset = SpreadsheetDataset(name="bla", spreadsheet_id="123", sheet_id="456")
assert dataset.extract_authors(Mock(authors=authors)) == expected
def test_pdf_articles_get_text():
- dataset = PDFArticles(name='bla', spreadsheet_id='123', sheet_id='456')
- item = Mock(file_id='23423', title='bla bla bla')
+ dataset = PDFArticles(name="bla", spreadsheet_id="123", sheet_id="456")
+ item = Mock(file_id="23423", title="bla bla bla")
def check_downloads(output, id):
- assert output == str(dataset.files_path / 'bla bla bla.pdf')
- assert id == '23423'
+ assert output == str(dataset.files_path / "bla bla bla.pdf")
+ assert id == "23423"
return output
def read_pdf(filename):
- assert filename == dataset.files_path / 'bla bla bla.pdf'
- return 'pdf contents'
+ assert filename == dataset.files_path / "bla bla bla.pdf"
+ return "pdf contents"
- with patch('align_data.sources.articles.datasets.download', check_downloads):
- with patch('align_data.sources.articles.datasets.read_pdf', read_pdf):
- assert dataset._get_text(item) == 'pdf contents'
+ with patch("align_data.sources.articles.datasets.download", check_downloads):
+ with patch("align_data.sources.articles.datasets.read_pdf", read_pdf):
+ assert dataset._get_text(item) == "pdf contents"
def test_pdf_articles_process_item(articles):
- dataset = PDFArticles(name='bla', spreadsheet_id='123', sheet_id='456')
- with patch('pandas.read_csv', return_value=articles):
+ dataset = PDFArticles(name="bla", spreadsheet_id="123", sheet_id="456")
+ with patch("pandas.read_csv", return_value=articles):
item = list(dataset.items_list)[0]
- with patch('align_data.sources.articles.datasets.download'):
-            with patch('align_data.sources.articles.datasets.read_pdf', return_value='pdf contents <a href="asd.com">bla</a>'):
+ with patch("align_data.sources.articles.datasets.download"):
+ with patch(
+ "align_data.sources.articles.datasets.read_pdf",
+                return_value='pdf contents <a href="asd.com">bla</a>',
+ ):
assert dataset.process_entry(item).to_dict() == {
- 'authors': ['John Snow', 'mr Blobby'],
- 'date_published': '2023-01-01T12:32:11Z',
- 'id': None,
- 'source': 'bla',
- 'source_filetype': 'pdf',
- 'source_type': 'something',
- 'summaries': ['the summary of article 0'],
- 'text': 'pdf contents [bla](asd.com)',
- 'title': 'article no 0',
- 'url': 'http://example.com/item/0',
+ "authors": ["John Snow", "mr Blobby"],
+ "date_published": "2023-01-01T12:32:11Z",
+ "id": None,
+ "source": "bla",
+ "source_filetype": "pdf",
+ "source_type": "something",
+ "summaries": ["the summary of article 0"],
+ "text": "pdf contents [bla](asd.com)",
+ "title": "article no 0",
+ "url": "http://example.com/item/0",
}
def test_html_articles_get_text():
def parser(url):
- assert url == 'http://example.org/bla.bla'
- return 'html contents'
+ assert url == "http://example.org/bla.bla"
+ return "html contents"
- with patch('align_data.sources.articles.datasets.HTML_PARSERS', {'example.org': parser}):
- assert HTMLArticles._get_text(Mock(source_url='http://example.org/bla.bla')) == 'html contents'
+ with patch(
+ "align_data.sources.articles.datasets.HTML_PARSERS", {"example.org": parser}
+ ):
+ assert (
+ HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla"))
+ == "html contents"
+ )
def test_html_articles_get_text_no_parser():
- with patch('align_data.sources.articles.datasets.HTML_PARSERS', {}):
- assert HTMLArticles._get_text(Mock(source_url='http://example.org/bla.bla')) is None
+ with patch("align_data.sources.articles.datasets.HTML_PARSERS", {}):
+ assert (
+ HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla"))
+ is None
+ )
def test_html_articles_process_entry(articles):
- dataset = HTMLArticles(name='bla', spreadsheet_id='123', sheet_id='456')
- with patch('pandas.read_csv', return_value=articles):
+ dataset = HTMLArticles(name="bla", spreadsheet_id="123", sheet_id="456")
+ with patch("pandas.read_csv", return_value=articles):
item = list(dataset.items_list)[0]
-    parsers = {'example.com': lambda _: ' html contents with <a href="bla.com">proper elements</a> ble ble '}
- with patch('align_data.sources.articles.datasets.HTML_PARSERS', parsers):
+ parsers = {
+        "example.com": lambda _: ' html contents with <a href="bla.com">proper elements</a> ble ble '
+ }
+ with patch("align_data.sources.articles.datasets.HTML_PARSERS", parsers):
assert dataset.process_entry(item).to_dict() == {
- 'authors': ['John Snow', 'mr Blobby'],
- 'date_published': '2023-01-01T12:32:11Z',
- 'id': None,
- 'source': 'bla',
- 'source_filetype': 'html',
- 'source_type': 'something',
- 'summaries': ['the summary of article 0'],
- 'text': 'html contents with [proper elements](bla.com) ble ble',
- 'title': 'article no 0',
- 'url': 'http://example.com/item/0',
+ "authors": ["John Snow", "mr Blobby"],
+ "date_published": "2023-01-01T12:32:11Z",
+ "id": None,
+ "source": "bla",
+ "source_filetype": "html",
+ "source_type": "something",
+ "summaries": ["the summary of article 0"],
+ "text": "html contents with [proper elements](bla.com) ble ble",
+ "title": "article no 0",
+ "url": "http://example.com/item/0",
}
def test_ebook_articles_get_text():
- dataset = EbookArticles(name='bla', spreadsheet_id='123', sheet_id='456')
+ dataset = EbookArticles(name="bla", spreadsheet_id="123", sheet_id="456")
item = Mock(
- source_url='https://drive.google.com/file/d/123456/view?usp=drive_link',
- title='bla bla bla'
+ source_url="https://drive.google.com/file/d/123456/view?usp=drive_link",
+ title="bla bla bla",
)
def check_downloads(output, id):
- assert output == str(dataset.files_path / 'bla bla bla.epub')
- assert id == '123456'
+ assert output == str(dataset.files_path / "bla bla bla.epub")
+ assert id == "123456"
return output
def read_ebook(filename, *args, **kwargs):
- return 'ebook contents'
+ return "ebook contents"
- with patch('align_data.sources.articles.datasets.download', check_downloads):
- with patch('align_data.sources.articles.datasets.convert_file', read_ebook):
- assert dataset._get_text(item) == 'ebook contents'
+ with patch("align_data.sources.articles.datasets.download", check_downloads):
+ with patch("align_data.sources.articles.datasets.convert_file", read_ebook):
+ assert dataset._get_text(item) == "ebook contents"
def test_ebook_articles_process_entry(articles):
- dataset = EbookArticles(name='bla', spreadsheet_id='123', sheet_id='456')
- with patch('pandas.read_csv', return_value=articles):
+ dataset = EbookArticles(name="bla", spreadsheet_id="123", sheet_id="456")
+ with patch("pandas.read_csv", return_value=articles):
item = list(dataset.items_list)[0]
    contents = ' html contents with <a href="bla.com">proper elements</a> ble ble '
- with patch('align_data.sources.articles.datasets.download'):
- with patch('align_data.sources.articles.datasets.convert_file', return_value=contents):
+ with patch("align_data.sources.articles.datasets.download"):
+ with patch(
+ "align_data.sources.articles.datasets.convert_file", return_value=contents
+ ):
assert dataset.process_entry(item).to_dict() == {
- 'authors': ['John Snow', 'mr Blobby'],
- 'date_published': '2023-01-01T12:32:11Z',
- 'id': None,
- 'source': 'bla',
- 'source_filetype': 'epub',
- 'source_type': 'something',
- 'summaries': ['the summary of article 0'],
- 'text': 'html contents with [proper elements](bla.com) ble ble',
- 'title': 'article no 0',
- 'url': 'http://example.com/item/0',
+ "authors": ["John Snow", "mr Blobby"],
+ "date_published": "2023-01-01T12:32:11Z",
+ "id": None,
+ "source": "bla",
+ "source_filetype": "epub",
+ "source_type": "something",
+ "summaries": ["the summary of article 0"],
+ "text": "html contents with [proper elements](bla.com) ble ble",
+ "title": "article no 0",
+ "url": "http://example.com/item/0",
}
def test_xml_articles_get_text():
- dataset = XMLArticles(name='bla', spreadsheet_id='123', sheet_id='456')
- with patch('align_data.sources.articles.datasets.extract_gdrive_contents', return_value={'text': 'bla bla'}):
- assert dataset._get_text(Mock(source_url='bla.com')) == 'bla bla'
+ dataset = XMLArticles(name="bla", spreadsheet_id="123", sheet_id="456")
+ with patch(
+ "align_data.sources.articles.datasets.extract_gdrive_contents",
+ return_value={"text": "bla bla"},
+ ):
+ assert dataset._get_text(Mock(source_url="bla.com")) == "bla bla"
def test_xml_articles_process_entry(articles):
- dataset = XMLArticles(name='bla', spreadsheet_id='123', sheet_id='456')
- with patch('pandas.read_csv', return_value=articles):
+ dataset = XMLArticles(name="bla", spreadsheet_id="123", sheet_id="456")
+ with patch("pandas.read_csv", return_value=articles):
item = list(dataset.items_list)[0]
- with patch('align_data.sources.articles.datasets.extract_gdrive_contents', return_value={'text': 'bla bla'}):
+ with patch(
+ "align_data.sources.articles.datasets.extract_gdrive_contents",
+ return_value={"text": "bla bla"},
+ ):
assert dataset.process_entry(item).to_dict() == {
- 'authors': ['John Snow', 'mr Blobby'],
- 'date_published': '2023-01-01T12:32:11Z',
- 'id': None,
- 'source': 'bla',
- 'source_filetype': 'xml',
- 'source_type': 'something',
- 'summaries': ['the summary of article 0'],
- 'text': 'bla bla',
- 'title': 'article no 0',
- 'url': 'http://example.com/item/0',
+ "authors": ["John Snow", "mr Blobby"],
+ "date_published": "2023-01-01T12:32:11Z",
+ "id": None,
+ "source": "bla",
+ "source_filetype": "xml",
+ "source_type": "something",
+ "summaries": ["the summary of article 0"],
+ "text": "bla bla",
+ "title": "article no 0",
+ "url": "http://example.com/item/0",
}
def test_markdown_articles_get_text():
- dataset = MarkdownArticles(name='bla', spreadsheet_id='123', sheet_id='456')
- with patch('align_data.sources.articles.datasets.fetch_markdown', return_value={'text': 'bla bla'}):
- assert dataset._get_text(Mock(source_url='bla.com/bla/123/bla')) == 'bla bla'
+ dataset = MarkdownArticles(name="bla", spreadsheet_id="123", sheet_id="456")
+ with patch(
+ "align_data.sources.articles.datasets.fetch_markdown",
+ return_value={"text": "bla bla"},
+ ):
+ assert dataset._get_text(Mock(source_url="bla.com/bla/123/bla")) == "bla bla"
def test_markdown_articles_process_entry(articles):
- dataset = MarkdownArticles(name='bla', spreadsheet_id='123', sheet_id='456')
- with patch('pandas.read_csv', return_value=articles):
+ dataset = MarkdownArticles(name="bla", spreadsheet_id="123", sheet_id="456")
+ with patch("pandas.read_csv", return_value=articles):
item = list(dataset.items_list)[0]
- with patch('align_data.sources.articles.datasets.fetch_markdown', return_value={'text': 'bla bla'}):
+ with patch(
+ "align_data.sources.articles.datasets.fetch_markdown",
+ return_value={"text": "bla bla"},
+ ):
assert dataset.process_entry(item).to_dict() == {
- 'authors': ['John Snow', 'mr Blobby'],
- 'date_published': '2023-01-01T12:32:11Z',
- 'id': None,
- 'source': 'bla',
- 'source_filetype': 'md',
- 'source_type': 'something',
- 'summaries': ['the summary of article 0'],
- 'text': 'bla bla',
- 'title': 'article no 0',
- 'url': 'http://example.com/item/0',
+ "authors": ["John Snow", "mr Blobby"],
+ "date_published": "2023-01-01T12:32:11Z",
+ "id": None,
+ "source": "bla",
+ "source_filetype": "md",
+ "source_type": "something",
+ "summaries": ["the summary of article 0"],
+ "text": "bla bla",
+ "title": "article no 0",
+ "url": "http://example.com/item/0",
}


def test_doc_articles_get_text():
- dataset = DocArticles(name='bla', spreadsheet_id='123', sheet_id='456')
- with patch('align_data.sources.articles.datasets.fetch_file'):
- with patch('align_data.sources.articles.datasets.convert_file', return_value='bla bla'):
- assert dataset._get_text(Mock(source_url='bla.com/bla/123/bla')) == 'bla bla'
+ dataset = DocArticles(name="bla", spreadsheet_id="123", sheet_id="456")
+ with patch("align_data.sources.articles.datasets.fetch_file"):
+ with patch(
+ "align_data.sources.articles.datasets.convert_file", return_value="bla bla"
+ ):
+ assert (
+ dataset._get_text(Mock(source_url="bla.com/bla/123/bla")) == "bla bla"
+ )


def test_doc_articles_process_entry(articles):
- dataset = DocArticles(name='bla', spreadsheet_id='123', sheet_id='456')
- with patch('pandas.read_csv', return_value=articles):
+ dataset = DocArticles(name="bla", spreadsheet_id="123", sheet_id="456")
+ with patch("pandas.read_csv", return_value=articles):
item = list(dataset.items_list)[0]
- with patch('align_data.sources.articles.datasets.fetch_file'):
- with patch('align_data.sources.articles.datasets.convert_file', return_value='bla bla'):
+ with patch("align_data.sources.articles.datasets.fetch_file"):
+ with patch(
+ "align_data.sources.articles.datasets.convert_file", return_value="bla bla"
+ ):
assert dataset.process_entry(item).to_dict() == {
- 'authors': ['John Snow', 'mr Blobby'],
- 'date_published': '2023-01-01T12:32:11Z',
- 'id': None,
- 'source': 'bla',
- 'source_filetype': 'docx',
- 'source_type': 'something',
- 'summaries': ['the summary of article 0'],
- 'text': 'bla bla',
- 'title': 'article no 0',
- 'url': 'http://example.com/item/0',
+ "authors": ["John Snow", "mr Blobby"],
+ "date_published": "2023-01-01T12:32:11Z",
+ "id": None,
+ "source": "bla",
+ "source_filetype": "docx",
+ "source_type": "something",
+ "summaries": ["the summary of article 0"],
+ "text": "bla bla",
+ "title": "article no 0",
+ "url": "http://example.com/item/0",
}
diff --git a/tests/align_data/articles/test_parsers.py b/tests/align_data/articles/test_parsers.py
index 9abdb43b..9f43c231 100644
--- a/tests/align_data/articles/test_parsers.py
+++ b/tests/align_data/articles/test_parsers.py
@@ -5,7 +5,11 @@
from bs4 import BeautifulSoup
from align_data.sources.articles.parsers import (
- google_doc, medium_blog, parse_grobid, get_content_type, extract_gdrive_contents
+ google_doc,
+ medium_blog,
+ parse_grobid,
+ get_content_type,
+ extract_gdrive_contents,
)
@@ -47,10 +51,15 @@
"""
+
def test_google_doc():
def fetcher(url, *args, **kwargs):
- assert url == 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html'
- return Mock(content="""
+ assert (
+ url
+ == "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html"
+ )
+ return Mock(
+ content="""