diff --git a/align_data/analysis/analyse_jsonl_data.py b/align_data/analysis/analyse_jsonl_data.py
index 0aed124d..9ef49649 100644
--- a/align_data/analysis/analyse_jsonl_data.py
+++ b/align_data/analysis/analyse_jsonl_data.py
@@ -1,8 +1,9 @@
from datetime import datetime
from pathlib import Path
+from collections import defaultdict
+
import jsonlines
-from collections import defaultdict
def is_valid_date_format(data_dict, format="%Y-%m-%dT%H:%M:%SZ"):
diff --git a/align_data/analysis/count_tokens.py b/align_data/analysis/count_tokens.py
index cd099c68..bc0232d3 100644
--- a/align_data/analysis/count_tokens.py
+++ b/align_data/analysis/count_tokens.py
@@ -1,7 +1,8 @@
+from typing import Tuple
+import logging
+
from transformers import AutoTokenizer
import jsonlines
-import logging
-from typing import Tuple
logger = logging.getLogger(__name__)
diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py
index 344ee89d..5e35d9fd 100644
--- a/align_data/common/alignment_dataset.py
+++ b/align_data/common/alignment_dataset.py
@@ -3,23 +3,23 @@
from itertools import islice
import logging
import time
-from dataclasses import dataclass, KW_ONLY
+from dataclasses import dataclass, field, KW_ONLY
from pathlib import Path
-from typing import Iterable, List, Optional, Set
-from sqlalchemy import select
-from sqlalchemy.exc import IntegrityError
-from sqlalchemy.orm import joinedload
+from typing import List, Optional, Set, Iterable, Tuple, Generator
-import jsonlines
import pytz
+from sqlalchemy import select, Select, JSON
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.orm import joinedload, Session
+import jsonlines
from dateutil.parser import parse, ParserError
from tqdm import tqdm
+
from align_data.db.models import Article, Summary
from align_data.db.session import make_session
from align_data.settings import ARTICLE_MAIN_KEYS
from align_data.sources.utils import merge_dicts
-
logger = logging.getLogger(__name__)
@@ -28,40 +28,42 @@ class AlignmentDataset:
"""The base dataset class."""
name: str
- """The name of the dataset"""
+ """The name of the dataset."""
_: KW_ONLY
- files_path = Path("")
- """The path where data can be found. Usually a folder"""
+ data_path: Path = Path(__file__).parent / "../../data/"
+ """The path where data can be found. Usually a folder."""
+
+ # Derived paths
+ raw_data_path: Path = field(init=False)
+ files_path: Path = field(init=False)
+
+ # Internal housekeeping variables
+ _outputted_items: Set[str] = field(default_factory=set, init=False)
+ """A set of the ids of all previously processed items."""
done_key = "id"
"""The key of the entry to use as the id when checking if already processed."""
COOLDOWN = 0
- """An optional cool down between processing entries"""
+ """An optional cool down between processing entries."""
lazy_eval = False
"""Whether to lazy fetch items. This is nice in that it will start processing, but messes up the progress bar."""
+
batch_size = 20
"""The number of items to collect before flushing to the database."""
- # Internal housekeeping variables
- _entry_idx = 0
- """Used internally for writing debugging info - each file write will increment it"""
- _outputted_items = set()
- """A set of the ids of all previously processed items"""
+ def __post_init__(self):
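+        # resolve to an absolute path, then derive the raw and per-dataset folders from it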
+ self.data_path = self.data_path.resolve()
- def __str__(self) -> str:
- return self.name
-
- def __post_init__(self, data_path=Path(__file__).parent / "../../data/"):
- self.data_path = data_path
self.raw_data_path = self.data_path / "raw"
-
- # set the default place to look for data
self.files_path = self.raw_data_path / self.name
+ def __str__(self) -> str:
+ return self.name
+
def _add_authors(self, article: Article, authors: List[str]) -> Article:
# TODO: Don't keep adding the same authors - come up with some way to reuse them
article.authors = ",".join(authors)
@@ -87,42 +89,42 @@ def make_data_entry(self, data, **kwargs) -> Article:
article.summaries += [Summary(text=summary, source=self.name) for summary in summaries]
return article
- def to_jsonl(self, out_path=None, filename=None) -> Path:
- if not out_path:
- out_path = Path(__file__).parent / "../../data/"
-
- if not filename:
- filename = f"{self.name}.jsonl"
- filename = Path(out_path) / filename
+ def to_jsonl(self, out_path: Path | None = None, filename: str | None = None) -> Path:
+ out_path = out_path or self.data_path
+ filename = filename or f"{self.name}.jsonl"
+ filepath = out_path / filename
- with jsonlines.open(filename, "w") as jsonl_writer:
+ with jsonlines.open(filepath, "w") as jsonl_writer:
for article in self.read_entries():
jsonl_writer.write(article.to_dict())
- return filename.resolve()
+ return filepath.resolve()
@property
- def _query_items(self):
+ def _query_items(self) -> Select[Tuple[Article]]:
return select(Article).where(Article.source == self.name)
- def read_entries(self, sort_by=None):
+ def read_entries(self, sort_by=None) -> Iterable[Article]:
"""Iterate through all the saved entries."""
with make_session() as session:
query = self._query_items.options(joinedload(Article.summaries))
if sort_by is not None:
query = query.order_by(sort_by)
- for item in session.scalars(query).unique():
- yield item
+
+ result = session.scalars(query)
+ for article in result.unique(): # removes duplicates
+ yield article
- def _add_batch(self, session, batch):
+ def _add_batch(self, session: Session, batch: tuple):
session.add_all(batch)
def add_entries(self, entries):
- def commit():
+ def commit() -> bool:
try:
session.commit()
return True
except IntegrityError:
session.rollback()
+ return False
with make_session() as session:
items = iter(entries)
@@ -183,7 +185,11 @@ def _normalize_urls(self, urls: Iterable[str]) -> Set[str]:
def _load_outputted_items(self) -> Set[str]:
- """Load the output file (if it exists) in order to know which items have already been output."""
+ """
+ Loads the outputted items from the database and returns them as a set.
+
+    If the done_key is not an attribute of Article, it will try to load it from the meta field.
+ """
with make_session() as session:
items = set()
if hasattr(Article, self.done_key):
@@ -203,23 +209,24 @@ def not_processed(self, item) -> bool:
# If it get's to that level, consider batching it somehow
return self._normalize_url(self.get_item_key(item)) not in self._outputted_items
- def unprocessed_items(self, items=None) -> Iterable:
+ def unprocessed_items(self, items=None) -> list | filter:
"""Return a list of all items to be processed.
This will automatically remove any items that have already been processed,
based on the contents of the output file.
"""
self.setup()
+ items = items or self.items_list
- filtered = filter(self.not_processed, items or self.items_list)
+ items_to_process = filter(self.not_processed, items)
# greedily fetch all items if not lazy eval. This makes the progress bar look nice
if not self.lazy_eval:
- filtered = list(filtered)
+ return list(items_to_process)
- return filtered
+ return items_to_process
- def fetch_entries(self):
+ def fetch_entries(self) -> Generator[Article, None, None]:
"""Get all entries to be written to the file."""
for item in tqdm(self.unprocessed_items(), desc=f"Processing {self.name}"):
entry = self.process_entry(item)
@@ -242,7 +249,7 @@ def process_entry(self, entry) -> Article | None:
raise NotImplementedError
@staticmethod
- def _format_datetime(date) -> str:
+ def _format_datetime(date: datetime) -> str:
return date.strftime("%Y-%m-%dT%H:%M:%SZ")
@staticmethod
@@ -280,7 +287,7 @@ def _load_outputted_items(self) -> Set[str]:
)
)
- def _add_batch(self, session, batch):
+ def _add_batch(self, session: Session, batch: tuple):
def merge(item):
if prev := self.articles.get(item.url):
return session.merge(item.update(prev))
diff --git a/align_data/common/html_dataset.py b/align_data/common/html_dataset.py
index e5e4d277..a1b748e3 100644
--- a/align_data/common/html_dataset.py
+++ b/align_data/common/html_dataset.py
@@ -1,16 +1,18 @@
-import pytz
-import regex as re
import logging
from datetime import datetime
-from dataclasses import dataclass, field, KW_ONLY
+from dataclasses import dataclass, field
from urllib.parse import urljoin
-from typing import List
+from typing import List, Dict, Any
+import re
+import pytz
import requests
import feedparser
from bs4 import BeautifulSoup
+from bs4.element import ResultSet, Tag
from markdownify import markdownify
+from align_data.db.models import Article
from align_data.common.alignment_dataset import AlignmentDataset
logger = logging.getLogger(__name__)
@@ -26,9 +28,6 @@ class HTMLDataset(AlignmentDataset):
done_key = "url"
authors: List[str] = field(default_factory=list)
- _: KW_ONLY
- source_key: str = None
- summary_key: str = None
item_selector = "article"
title_selector = "article h1"
@@ -39,12 +38,14 @@ class HTMLDataset(AlignmentDataset):
def extract_authors(self, article):
return self.authors
- def get_item_key(self, item) -> str:
- article_url = item.find_all("a")[0]["href"].split("?")[0]
- return urljoin(self.url, article_url)
+
+ def get_item_key(self, item: Tag) -> str:
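+        # take the item's first link and drop any query string before resolving it against the base url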
+ first_href = item.find("a")["href"]
+ href_base, *_ = first_href.split("?")
+ return urljoin(self.url, href_base)
@property
- def items_list(self):
+ def items_list(self) -> ResultSet[Tag]:
logger.info(f"Fetching entries from {self.url}")
response = requests.get(self.url, allow_redirects=True)
soup = BeautifulSoup(response.content, "html.parser")
@@ -52,10 +53,10 @@ def items_list(self):
logger.info(f"Found {len(articles)} articles")
return articles
- def _extra_values(self, contents):
+ def _extra_values(self, contents: BeautifulSoup):
return {}
- def get_contents(self, article_url: str):
+ def get_contents(self, article_url: str) -> Dict[str, Any]:
contents = self.fetch_contents(article_url)
title = self._get_title(contents)
@@ -72,7 +73,7 @@ def get_contents(self, article_url: str):
**self._extra_values(contents),
}
- def process_entry(self, article):
+ def process_entry(self, article: Tag) -> Article:
article_url = self.get_item_key(article)
contents = self.get_contents(article_url)
if not contents.get("text"):
@@ -80,8 +81,8 @@ def process_entry(self, article):
return self.make_data_entry(contents)
- def fetch_contents(self, url):
- logger.info("Fetching {}".format(url))
+ def fetch_contents(self, url: str):
+ logger.info(f"Fetching {url}")
resp = requests.get(url, allow_redirects=True)
return BeautifulSoup(resp.content, "html.parser")
@@ -136,7 +137,7 @@ def _get_text(self, item):
text = item.get("content") and item["content"][0].get("value")
return self._extract_markdown(text)
- def fetch_contents(self, url):
+ def fetch_contents(self, url: str):
item = self.items[url]
if "content" in item:
return item
diff --git a/align_data/db/models.py b/align_data/db/models.py
index e79da232..d67ebfd5 100644
--- a/align_data/db/models.py
+++ b/align_data/db/models.py
@@ -221,6 +221,7 @@ def to_dict(self) -> Dict[str, Any]:
}
+
event.listen(Article, "before_insert", Article.before_write)
event.listen(Article, "before_update", Article.before_write)
event.listen(Article, "before_insert", Article.check_for_changes)
diff --git a/align_data/db/session.py b/align_data/db/session.py
index 331de9b7..2e80c4b4 100644
--- a/align_data/db/session.py
+++ b/align_data/db/session.py
@@ -7,7 +7,6 @@
from align_data.settings import DB_CONNECTION_URI, MIN_CONFIDENCE
from align_data.db.models import Article, PineconeStatus
-
logger = logging.getLogger(__name__)
# We create a single engine for the entire application
diff --git a/align_data/embeddings/pinecone/pinecone_db_handler.py b/align_data/embeddings/pinecone/pinecone_db_handler.py
index b0b09b9f..3cf32112 100644
--- a/align_data/embeddings/pinecone/pinecone_db_handler.py
+++ b/align_data/embeddings/pinecone/pinecone_db_handler.py
@@ -19,7 +19,6 @@
PINECONE_NAMESPACE,
)
-
logger = logging.getLogger(__name__)
diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py
index 4e6ac03a..92fa726f 100644
--- a/align_data/embeddings/pinecone/update_pinecone.py
+++ b/align_data/embeddings/pinecone/update_pinecone.py
@@ -22,10 +22,8 @@
)
from align_data.embeddings.text_splitter import ParagraphSentenceUnitTextSplitter
-
logger = logging.getLogger(__name__)
-
# Define type aliases for the Callables
LengthFunctionType = Callable[[str], int]
TruncateFunctionType = Callable[[str, int], str]
diff --git a/align_data/postprocess/postprocess.py b/align_data/postprocess/postprocess.py
index 05e9dbde..9db1a480 100644
--- a/align_data/postprocess/postprocess.py
+++ b/align_data/postprocess/postprocess.py
@@ -1,53 +1,81 @@
# %%
from collections import defaultdict, Counter
-from dataclasses import dataclass
-import jsonlines
-from tqdm import tqdm
+from dataclasses import dataclass, field
+from typing import List, DefaultDict
import logging
-from path import Path
+from pathlib import Path
+import jsonlines
+from tqdm import tqdm
import pylab as plt
-import seaborn as sns
+from nltk.tokenize import sent_tokenize, word_tokenize
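+# NOTE: sent_tokenize/word_tokenize assume the NLTK "punkt" tokenizer data is available (nltk.download("punkt"))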
+import seaborn as sns #TODO: install seaborn or fix this file
import pandas as pd
logger = logging.getLogger(__name__)
+#TODO: fix this file
@dataclass
class PostProcesser:
"""
This class is used to postprocess the data
"""
+ jsonl_path: Path = field(default_factory=lambda: (Path(__file__).parent / '../../data/').resolve())
+
+ def __post_init__(self) -> None:
+ print(f"Looking for data in {self.jsonl_path}")
- jsonl_path: Path = Path("../../data/")
+ # Check if the directory exists
+ if not self.jsonl_path.is_dir():
+ raise FileNotFoundError(f"Data directory {self.jsonl_path} does not exist")
- def __init__(self) -> None:
- self.jsonl_list = sorted(self.jsonl_path.files("*.jsonl"))
- self.source_list = [path.name.split(".jsonl")[0] for path in self.jsonl_list]
- self.all_stats = defaultdict(Counter)
+ self.jsonl_list: List[Path] = sorted(self.jsonl_path.glob("*.jsonl"))
+ self.source_list: List[str] = [path.stem for path in self.jsonl_list]
+ self.all_stats: DefaultDict[str, Counter] = defaultdict(Counter)
def compute_statistics(self) -> None:
for source_name, path in tqdm(zip(self.source_list, self.jsonl_list)):
with jsonlines.open(path) as reader:
for obj in reader:
- text = obj["text"]
+ text: str = obj['text']
source_stats = self.all_stats[source_name]
source_stats["num_entries"] += 1
- source_stats["num_tokens"] += len(text.split()) # TODO: Use tokenizer
+ source_stats["num_tokens"] += len(word_tokenize(text))
source_stats["num_chars"] += len(text)
source_stats["num_words"] += len(text.split())
- source_stats["num_sentences"] += len(
- text.split(".")
- ) # TODO: Use NLTK/Spacy or similar
- source_stats["num_paragraphs"] += len(text.splitlines())
+ source_stats["num_sentences"] += len(sent_tokenize(text))
+ source_stats["num_newlines"] += len(text.split("\n"))
+ source_stats["num_paragraphs"] += len(text.split("\n\n"))
def plot_statistics(self) -> None:
all_df = pd.DataFrame(self.all_stats).T
- plt.figure(figsize=(5, 5))
- sns.barplot(x=all_df.index, y=all_df["num_entries"])
+
+ fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(15, 15))
+ metrics_to_plot = [
+ "num_entries",
+ "num_tokens",
+ "num_words",
+ "num_sentences",
+ "num_paragraphs",
+ "num_chars",
+ ]
+
+ for i, metric in enumerate(metrics_to_plot):
+ ax = axes[i // 2, i % 2]
+ sns.barplot(x=all_df.index, y=all_df[metric], ax=ax)
+ ax.set_title(metric)
+ ax.set_ylabel('')
+ ax.tick_params(axis='x', rotation=45)
+ # Uncomment the next line if you want to apply a log scale for better visualization.
+ # ax.set_yscale("log")
+
+ plt.tight_layout()
+ plt.show()
+
def merge_all_files(self, out_dir: str) -> str:
- pass
+ raise NotImplementedError
def deduplicate(self) -> None:
for path in tqdm(self.jsonl_list):
@@ -58,7 +86,7 @@ def deduplicate(self) -> None:
writer.write(obj)
def clean_dataset(self, merged_dataset_path: str) -> str:
- pass
+ raise NotImplementedError
pp = PostProcesser()
@@ -66,6 +94,8 @@ def clean_dataset(self, merged_dataset_path: str) -> str:
pp.source_list
# %%
pp.compute_statistics()
+print(pp.all_stats)
+pp.plot_statistics()
# %%
pp.deduplicate()
# %%
diff --git a/align_data/sources/alignment_newsletter/alignment_newsletter.py b/align_data/sources/alignment_newsletter/alignment_newsletter.py
index 87541d4a..ebd98c1d 100644
--- a/align_data/sources/alignment_newsletter/alignment_newsletter.py
+++ b/align_data/sources/alignment_newsletter/alignment_newsletter.py
@@ -1,10 +1,10 @@
# %%
import logging
from datetime import datetime, timezone
-from pathlib import Path
+from dataclasses import dataclass
+
import pandas as pd
-from dataclasses import dataclass
from align_data.common.alignment_dataset import SummaryDataset
logger = logging.getLogger(__name__)
@@ -14,10 +14,6 @@
class AlignmentNewsletter(SummaryDataset):
done_key = "url"
- def __post_init__(self, data_path=Path(__file__).parent / "../../../data/"):
- self.data_path = data_path
- self.raw_data_path = self.data_path / "raw"
-
def setup(self) -> None:
super().setup()
@@ -42,7 +38,7 @@ def _get_published_date(self, year):
def items_list(self):
return self.df.itertuples()
- def process_entry(self, row):
+ def process_entry(self, row: pd.Series):
"""
For each row in the dataframe, create a new entry with the following fields: url, source,
converted_with, source_type, venue, newsletter_category, highlight, newsletter_number,
diff --git a/align_data/sources/arbital/arbital.py b/align_data/sources/arbital/arbital.py
index b08393c4..aa7b9633 100644
--- a/align_data/sources/arbital/arbital.py
+++ b/align_data/sources/arbital/arbital.py
@@ -5,11 +5,9 @@
from typing import List, Tuple, Iterator, Dict, Union, Any, TypedDict
import requests
-from datetime import datetime, timezone
from dateutil.parser import parse
from align_data.common.alignment_dataset import AlignmentDataset
-from dataclasses import dataclass
logger = logging.getLogger(__name__)
diff --git a/align_data/sources/articles/articles.py b/align_data/sources/articles/articles.py
index 7db94a7b..4d497bb1 100644
--- a/align_data/sources/articles/articles.py
+++ b/align_data/sources/articles/articles.py
@@ -1,10 +1,13 @@
import io
import logging
+from typing import Dict, Set
from tqdm import tqdm
import gspread
+from gspread.worksheet import Worksheet
from align_data.sources.articles.google_cloud import (
+ SheetRow,
iterate_rows,
get_spreadsheet,
get_sheet,
@@ -18,10 +21,8 @@
from align_data.sources.articles.updater import ReplacerDataset
from align_data.settings import PDFS_FOLDER_ID
-
logger = logging.getLogger(__name__)
-
# Careful changing these - the sheets assume this ordering
REQUIRED_FIELDS = ["url", "source_url", "title", "source_type", "date_published"]
OPTIONAL_FIELDS = ["authors", "summary"]
@@ -45,11 +46,10 @@ def save_pdf(filename, link):
parent_id=PDFS_FOLDER_ID,
)
-
@with_retry(times=3, exceptions=gspread.exceptions.APIError)
-def process_row(row, sheets):
+def process_row(row: SheetRow, sheets: Dict[str, Worksheet]):
"""Check the given `row` and fetch its metadata + optional extra stuff."""
- logger.info('Checking "%s"', row["title"])
+    logger.info('Checking "%s" at "%s"', row["title"], row["url"])
missing = [field for field in REQUIRED_FIELDS if not row.get(field)]
if missing:
@@ -60,10 +60,13 @@ def process_row(row, sheets):
source_url = row.get("source_url")
contents = item_metadata(source_url)
- if not contents or "error" in contents:
- error = (contents and contents.get("error")) or "text could not be fetched"
- logger.error(error)
- row.set_status(error)
+ if not contents:
+ logger.error("text could not be fetched")
+ row.set_status("text could not be fetched")
+ return
+ elif "error" in contents:
+ logger.error(contents["error"])
+ row.set_status(contents["error"])
return
data_source = contents.get("source_type")
@@ -83,7 +86,7 @@ def process_row(row, sheets):
row.set_status(OK)
-def process_spreadsheets(source_sheet, output_sheets):
+def process_spreadsheets(source_sheet: Worksheet, output_sheets: Dict[str, Worksheet]) -> None:
"""Go through all entries in `source_sheet` and update the appropriate metadata in `output_sheets`.
`output_sheets` should be a dict with a key for each possible data type, e.g. html, pdf etc.
@@ -92,43 +95,49 @@ def process_spreadsheets(source_sheet, output_sheets):
:param Dict[str, Worksheet] output_sheets: a dict of per data type worksheets to be updated
"""
logger.info("fetching seen urls")
- seen = {
+ seen: Set[str] = {
url
- for sheet in output_sheets.values()
- for record in sheet.get_all_records()
+ for output_sheet in output_sheets.values()
+ for record in output_sheet.get_all_records()
for url in [record.get("url"), record.get("source_url")]
if url
- }
+ }
+ # TODO: This requires our output_sheet to already have the headers for
+ # the different sheets. otherwise we raise an error, but we could have it be added
+ # automatically instead
+
for row in tqdm(iterate_rows(source_sheet)):
- title = row.get("title")
if not row.get("source_url"):
row["source_url"] = row["url"]
+
if row.get("source_url") in seen:
- logger.info(f'skipping "{title}", as it has already been seen')
- elif row.get("status"):
- logger.info(
- f'skipping "{title}", as it has a status set - remove it for this row to be processed'
- )
+ logger.info(f'skipping "{row.get("title")}", as it has already been seen')
+ elif row.get('status'):
+ logger.info(f'skipping "{row.get("title")}", as it has a status set - remove it for this row to be processed')
else:
process_row(row, output_sheets)
-def update_new_items(source_spreadsheet, source_sheet, output_spreadsheet):
+def update_new_items(source_spreadsheet_id: str, source_sheet_name: str, output_spreadsheet_id: str) -> None:
"""Go through all unprocessed items from the source worksheet, updating the appropriate metadata in the output one."""
- source_sheet = get_sheet(source_spreadsheet, source_sheet)
- sheets = {sheet.title: sheet for sheet in get_spreadsheet(output_spreadsheet).worksheets()}
- return process_spreadsheets(source_sheet, sheets)
+ source_sheet = get_sheet(source_spreadsheet_id, source_sheet_name)
+ output_sheets = {
+ sheet.title: sheet for sheet in get_spreadsheet(output_spreadsheet_id).worksheets()
+ }
+ process_spreadsheets(source_sheet, output_sheets)
-def check_new_articles(source_spreadsheet, source_sheet):
+def check_new_articles(source_spreadsheet_id: str, source_sheet_name: str):
"""Goes through the special indices looking for unseen articles."""
- source_sheet = get_sheet(source_spreadsheet, source_sheet)
- current = {row.get("title"): row for row in iterate_rows(source_sheet)}
+ source_sheet = get_sheet(source_spreadsheet_id, source_sheet_name)
+ current: Dict[str, SheetRow] = {row.get("title"): row for row in iterate_rows(source_sheet)}
+ logger.info('Found %s articles in the sheet', len(current))
+
seen_urls = {
url
- for item in current.values()
- for key in ("url", "source_url")
- if (url := item.get(key)) is not None
+ for row in current.values()
+ for url_key in ("url", "source_url")
+ if (url := row.get(url_key)) is not None
}
indices_items = fetch_all()
diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py
index cbf7f9d9..ec372126 100644
--- a/align_data/sources/articles/datasets.py
+++ b/align_data/sources/articles/datasets.py
@@ -2,13 +2,14 @@
import os
from dataclasses import dataclass
from pathlib import Path
-from typing import Dict, Iterable
+from typing import Dict, Tuple, Iterable
+from urllib.parse import urlparse
import pandas as pd
from gdown.download import download
from markdownify import markdownify
from pypandoc import convert_file
-from sqlalchemy import select
+from sqlalchemy import select, Select
from align_data.common.alignment_dataset import AlignmentDataset
from align_data.db.models import Article
@@ -33,7 +34,7 @@ class SpreadsheetDataset(AlignmentDataset):
spreadsheet_id: str
sheet_id: str
done_key = "url"
- source_filetype = None
+ source_filetype = None # type: str
batch_size = 1
@staticmethod
@@ -51,7 +52,11 @@ def items_list(self) -> Iterable[tuple]:
url = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}"
logger.info(f"Fetching {url}")
df = pd.read_csv(url)
- return (item for item in df.itertuples() if self.get_item_key(item))
+ return (
+ item
+ for item in df.itertuples()
+ if self.get_item_key(item) is not None
+ )
@staticmethod
def _get_text(item):
@@ -90,7 +95,7 @@ def process_entry(self, item: tuple):
class SpecialDocs(SpreadsheetDataset):
@property
- def _query_items(self):
+ def _query_items(self) -> Select[Tuple[Article]]:
special_docs_types = ["pdf", "html", "xml", "markdown", "docx"]
return select(Article).where(Article.source.in_(special_docs_types))
@@ -146,7 +151,6 @@ def process_entry(self, item):
return self.make_data_entry(contents)
-
class PDFArticles(SpreadsheetDataset):
source_filetype = "pdf"
COOLDOWN = 1
diff --git a/align_data/sources/articles/google_cloud.py b/align_data/sources/articles/google_cloud.py
index ca310235..9ff56a2c 100644
--- a/align_data/sources/articles/google_cloud.py
+++ b/align_data/sources/articles/google_cloud.py
@@ -2,89 +2,99 @@
import time
from collections import UserDict
from pathlib import Path
-from typing import Dict, Optional
-import regex as re
+from typing import Dict, Any, Iterator, Union, List, Set
+import re
+import requests
import gdown
import grobid_tei_xml
import gspread
+from gspread.worksheet import Worksheet
+from gspread.spreadsheet import Spreadsheet
from bs4 import BeautifulSoup
from google.oauth2.service_account import Credentials
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseUpload
from markdownify import MarkdownConverter
+
from align_data.sources.articles.html import fetch, fetch_element
from align_data.sources.articles.pdf import fetch_pdf
logger = logging.getLogger(__name__)
-
SCOPES = [
"https://www.googleapis.com/auth/spreadsheets",
"https://www.googleapis.com/auth/drive",
]
-
OK = "ok"
-OUTPUT_SPREADSHEET_ID = "1bg-6vL-I82CBRkxvWQs1-Ao0nTvHyfn4yns5MdlbCmY"
-sheet_name = "Sheet1"
+OUTPUT_SPREADSHEET_ID = "1bg-6vL-I82CBRkxvWQs1-Ao0nTvHyfn4yns5MdlbCmY" # TODO: remove this
+sheet_name = "Sheet1" # TODO: remove this
-def get_credentials(credentials_file="credentials.json"):
+def get_credentials(credentials_file: Union[Path, str] = "credentials.json") -> Credentials:
return Credentials.from_service_account_file(credentials_file, scopes=SCOPES)
-def get_spreadsheet(spreadsheet_id, credentials=None):
+def get_spreadsheet(spreadsheet_id: str, credentials: Credentials = None) -> Spreadsheet:
client = gspread.authorize(credentials or get_credentials())
return client.open_by_key(spreadsheet_id)
-def get_sheet(spreadsheet_id, sheet_name, credentials=None):
+def get_sheet(spreadsheet_id: str, sheet_name: str, credentials: Credentials = None) -> Worksheet:
spreadsheet = get_spreadsheet(spreadsheet_id, credentials)
return spreadsheet.worksheet(title=sheet_name)
-class Row(UserDict):
- sheet = None
+class SheetRow(UserDict):
+ """A row in a Google Sheet."""
+ sheet = None # type: Worksheet
+ columns = None # type: List[str | None]
@classmethod
- def set_sheet(cls, sheet):
+ def set_sheet(cls, sheet: Worksheet):
cls.sheet = sheet
- cls.columns = sheet.row_values(1)
+ cols = sheet.row_values(1)
+        # the first row is used as the header; raise if it is missing
+ if not isinstance(cols, list) or not cols:
+ raise ValueError(f"Sheet {sheet.title} has no header row")
+
+ cls.columns = cols
- def __init__(self, row_id, data):
+ def __init__(self, row_id: int, data: Dict[str, Any]):
self.row_id = row_id
super().__init__(data)
- def update_value(self, col, value):
+ def update_value(self, col: str, value: str):
self.sheet.update_cell(self.row_id, self.columns.index(col) + 1, value)
- def update_colour(self, col, colour):
+ def update_colour(self, col: str, colour: Dict[str, float]):
col_letter = chr(ord("A") + self.columns.index(col))
self.sheet.format(f"{col_letter}{self.row_id}", {"backgroundColor": colour})
- def set_status(self, status, status_col="status"):
+ def set_status(self, status: str, status_col: str = "status"):
if self.get(status_col) == status:
# Don't update anything if the status is the same - this saves on gdocs calls
return
if status == OK:
- colour = {"red": 0, "green": 1, "blue": 0}
- elif status == "":
- colour = {"red": 1, "green": 1, "blue": 1}
+ colour = {"red": 0.0, "green": 1.0, "blue": 0.0}
+ elif status == "": # TODO: this should never be reached
+ colour = {"red": 1.0, "green": 1.0, "blue": 1.0}
else:
- colour = {"red": 1, "green": 0, "blue": 0}
+ colour = {"red": 1.0, "green": 0.0, "blue": 0.0}
self.update_value(status_col, status)
self.update_colour(status_col, colour)
-def iterate_rows(sheet):
+def iterate_rows(sheet: Worksheet) -> Iterator[SheetRow]:
"""Iterate over all the rows of the given `sheet`."""
- Row.set_sheet(sheet)
+ SheetRow.set_sheet(sheet)
- for i, row in enumerate(sheet.get_all_records(), 2):
- yield Row(i, row)
+ # we start the enumeration at 2 to avoid the header row
+ for row_id, row_data in enumerate(sheet.get_all_records(), 2):
+ yield SheetRow(row_id, row_data)
def upload_file(filename, bytes_contents, mimetype, parent_id=None):
@@ -131,14 +141,14 @@ def retrier(*args, **kwargs):
return wrapper
-def fetch_file(file_id):
+def fetch_file(file_id: str):
data_path = Path("data/raw/")
data_path.mkdir(parents=True, exist_ok=True)
file_name = data_path / file_id
return gdown.download(id=file_id, output=str(file_name), quiet=False)
-def fetch_markdown(file_id):
+def fetch_markdown(file_id: str) -> Dict[str, str]:
try:
file_name = fetch_file(file_id)
return {
@@ -149,9 +159,11 @@ def fetch_markdown(file_id):
return {"error": str(e)}
-def parse_grobid(contents):
+def parse_grobid(contents: str | bytes) -> Dict[str, Any]:
+ if isinstance(contents, bytes):
+ contents = contents.decode('utf-8')
doc_dict = grobid_tei_xml.parse_document_xml(contents).to_dict()
- authors = [xx["full_name"].strip(" !") for xx in doc_dict.get("header", {}).get("authors", [])]
+ authors: List[str] = [author["full_name"].strip(" !") for author in doc_dict.get("header", {}).get("authors", [])]
if not doc_dict.get("body"):
return {
@@ -168,13 +180,13 @@ def parse_grobid(contents):
}
-def get_content_type(res):
+def get_content_type(res: requests.Response) -> Set[str]:
header = res.headers.get("Content-Type") or ""
parts = [c_type.strip().lower() for c_type in header.split(";")]
return set(filter(None, parts))
-def extract_gdrive_contents(link):
+def extract_gdrive_contents(link: str) -> Dict[str, Any]:
file_id = link.split("/")[-2]
url = f"https://drive.google.com/uc?id={file_id}"
res = fetch(url, "head")
@@ -185,9 +197,9 @@ def extract_gdrive_contents(link):
logger.error("Could not fetch the file at %s - are you sure that link is correct?", link)
return {"error": "Could not read file from google drive"}
- result = {
- "source_url": link,
- "downloaded_from": "google drive",
+ result: Dict[str, Any] = {
+ 'source_url': link,
+ 'downloaded_from': 'google drive',
}
content_type = get_content_type(res)
@@ -203,7 +215,16 @@ def extract_gdrive_contents(link):
res = fetch(url)
if "Google Drive - Virus scan warning" in res.text:
soup = BeautifulSoup(res.content, "html.parser")
- res = fetch(soup.select_one("form").get("action"))
+
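+            # the warning page wraps the real download in a confirmation form; follow its action url instead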
+ form_tag = soup.select_one('form')
+ if form_tag is None:
+ return {**result, 'error': 'Virus scan warning - no form tag'}
+
+ form_action_url = form_tag.get('action')
+ if not isinstance(form_action_url, str):
+ return {**result, 'error': 'Virus scan warning - no form action url'}
+
+ res = fetch(form_action_url)
content_type = get_content_type(res)
if content_type & {"text/xml"}:
@@ -224,7 +245,7 @@ def extract_gdrive_contents(link):
return result
-def google_doc(url: str) -> Dict:
+def google_doc(url: str) -> Dict[str, Any]:
"""Fetch the contents of the given gdoc url as markdown."""
res = re.search(r"https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/", url)
if not res:
diff --git a/align_data/sources/articles/html.py b/align_data/sources/articles/html.py
index d3c2490c..fb280c37 100644
--- a/align_data/sources/articles/html.py
+++ b/align_data/sources/articles/html.py
@@ -1,6 +1,6 @@
import time
import logging
-from typing import Union
+from typing import Optional, Dict, Literal, Any, List
import requests
from bs4 import BeautifulSoup, Tag
@@ -8,7 +8,6 @@
logger = logging.getLogger(__name__)
-
DEFAULT_HEADERS = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0",
}
@@ -33,7 +32,11 @@ def retrier(*args, **kwargs):
return wrapper
-def fetch(url, method="get", headers=DEFAULT_HEADERS):
+def fetch(
+ url: str,
+ method: Literal["get", "post", "put", "delete", "patch", "options", "head"] = "get",
+ headers: Dict[str, str] = DEFAULT_HEADERS
+) -> requests.Response:
"""Fetch the given `url`.
This function is to have a single place to manage headers etc.
@@ -41,7 +44,7 @@ def fetch(url, method="get", headers=DEFAULT_HEADERS):
return getattr(requests, method)(url, allow_redirects=True, headers=headers)
-def fetch_element(url: str, selector: str, headers=DEFAULT_HEADERS) -> Union[Tag, None]:
+def fetch_element(url: str, selector: str, headers: Dict[str, str] = DEFAULT_HEADERS) -> Tag | None:
"""Fetch the first HTML element that matches the given CSS `selector` on the page found at `url`."""
try:
resp = fetch(url, headers=headers)
@@ -53,15 +56,16 @@ def fetch_element(url: str, selector: str, headers=DEFAULT_HEADERS) -> Union[Tag
return soup.select_one(selector)
-def element_extractor(selector, remove=[]):
+def element_extractor(selector: str, remove: Optional[List[str]] = None):
"""Returns a function that will extract the first element that matches the given CSS selector.
:params str selector: a CSS selector to run on the HTML of the page provided as the parameter of the function
:param List[str] remove: An optional list of selectors to be removed from the resulting HTML. Useful for removing footers etc.
:returns: A function that expects to get an URL, and which will then return the contents of the selected HTML element as markdown.
"""
+ remove = remove or []
- def getter(url):
+ def getter(url: str) -> Dict[str, Any]:
elem = fetch_element(url, selector)
if not elem:
return {}
diff --git a/align_data/sources/articles/indices.py b/align_data/sources/articles/indices.py
index 6deb02d8..1a61e0eb 100644
--- a/align_data/sources/articles/indices.py
+++ b/align_data/sources/articles/indices.py
@@ -1,6 +1,7 @@
import logging
import re
from collections import defaultdict
+from typing import Callable
from dateutil.parser import ParserError, parse
from markdownify import MarkdownConverter
@@ -19,16 +20,22 @@ def get_text(tag, selector: str) -> str:
return ""
-def indice_fetcher(url, main_selector, item_selector, formatter):
+def indice_fetcher(url: str, main_selector: str, item_selector: str, formatter: Callable):
def fetcher():
if contents := fetch_element(url, main_selector):
return list(filter(None, map(formatter, contents.select(item_selector))))
return []
-
+ fetcher.__name__ = formatter.__name__.replace("format_", "") + '_fetcher'
+ # formatter called "format_anthropic" -> fetcher called "anthropic_fetcher"
+ #TODO: Make this more explicit
return fetcher
def reading_what_we_can_items():
+ # We fetch the books.js page of readingwhatwecan.
+ # It has 4 sections: first_entry, ml, ais, and scifi,
+ # which contain a dozen items (books, stories, papers) each.
+
res = fetch("https://readingwhatwecan.com/books.js")
items = {
item
@@ -240,8 +247,11 @@ def fetch_all():
articles = defaultdict(dict)
for func in tqdm(fetchers):
+ logger.info(f"Processing function: {func.__name__}")
for item in func():
- articles[item["title"]].update(item)
+ logger.info(f"Processing item: {item}")
+ articles[item['title']].update(item)
+ logger.info(f"Found {len(articles)} articles")
return articles
diff --git a/align_data/sources/articles/parsers.py b/align_data/sources/articles/parsers.py
index 4bd493be..d5fe2578 100644
--- a/align_data/sources/articles/parsers.py
+++ b/align_data/sources/articles/parsers.py
@@ -1,6 +1,6 @@
import logging
from urllib.parse import urlparse, urljoin
-from typing import Dict
+from typing import Dict, Optional, Callable, Any
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
@@ -10,9 +10,10 @@
from align_data.sources.arxiv_papers import fetch_arxiv
from align_data.common.html_dataset import HTMLDataset
-
logger = logging.getLogger(__name__)
+ParserFunc = Callable[[str], Dict[str, Any]]
+
def get_pdf_from_page(*link_selectors: str):
"""Get a function that receives an `url` to a page containing a pdf link and returns the pdf's contents as text.
@@ -21,32 +22,40 @@ def get_pdf_from_page(*link_selectors: str):
* if there are more selectors left, fetch the contents at the extracted link and continue
* otherwise return the pdf contents at the last URL
- :param List[str] link_selectors: CSS selector used to find the final download link
+ :param str *link_selectors: CSS selectors used to find the final download link
:returns: the contents of the pdf file as a string
"""
+ def getter(url: str) -> Dict[str, Any]:
+ current_url: str = url
- def getter(url: str):
- link: str = url
for selector in link_selectors:
- elem = fetch_element(link, selector)
+ elem = fetch_element(current_url, selector)
if not elem:
- return {"error": f"Could not find pdf download link for {link} using '{selector}'"}
+ return {"error": f"Could not find pdf download link for {current_url} using '{selector}'"}
+
+ # Extracting href, considering it can be a string or a list of strings
+ href = elem.get("href")
+ if isinstance(href, list):
+ href = href[0] if href else None
+
+ if not href:
+ return {"error": f"Could not extract href for {current_url} using '{selector}'"}
- link = elem.get("href")
- if not link.startswith("http") or not link.startswith("//"):
- link = urljoin(url, link)
+ # Making sure the link is absolute
+ if not href.startswith(("http", "//")):
+ href = urljoin(url, href)
+ current_url = href
# Some pages keep link to google drive previews of pdf files, which need to be
# mangled to get the URL of the actual pdf file
- if "drive.google.com" in link and "/view" in link:
- return extract_gdrive_contents(link)
+ if "drive.google.com" in current_url and "/view" in current_url:
+ return extract_gdrive_contents(current_url)
- if parse_domain(link) == "arxiv.org":
- return fetch_arxiv(link)
- if pdf := fetch_pdf(link):
+ if parse_domain(current_url) == "arxiv.org":
+ return fetch_arxiv(current_url)
+ if pdf := fetch_pdf(current_url):
return pdf
- return {"error": f"Could not fetch pdf from {link}"}
-
+ return {"error": f"Could not fetch pdf from {current_url}"}
return getter
@@ -64,6 +73,7 @@ class MediumParser(HTMLDataset):
It is possible that there is additional variation in the layout that hasn't been represented
in the blogs tested so far. In that case, additional fixes to this code may be needed.
+ #TODO: investigate this
"""
source_type = "MediumParser(name='html', url='')"
@@ -73,14 +83,13 @@ def _get_published_date(self, contents):
possible_date_elements = contents.select("article div:first-child span")
return self._find_date(possible_date_elements)
- def __call__(self, url):
+ def __call__(self, url: str) -> Dict[str, Any]:
return self.get_contents(url)
-def error(error_msg):
+def error(error_msg: str):
"""Returns a url handler function that just logs the provided `error` string."""
-
- def func(url):
+ def func(url: str) -> Dict[str, Any]:
if error_msg:
logger.error(error_msg)
return {"error": error_msg, "source_url": url}
@@ -88,10 +97,13 @@ def func(url):
return func
-def multistrategy(*funcs):
- """Merges multiple getter functions, returning the result of the first function call to succeed."""
+def multistrategy(*funcs: ParserFunc):
+ """
+ Merges multiple getter functions, returning the result
+ of the first function call to succeed.
+ """
- def getter(url):
+ def getter(url: str) -> Dict[str, Any]:
for func in funcs:
res = func(url)
if res and "error" not in res:
@@ -100,7 +112,7 @@ def getter(url):
return getter
-UNIMPLEMENTED_PARSERS = {
+UNIMPLEMENTED_PARSERS: Dict[str, ParserFunc] = {
# Unhandled items that will be caught later. Though it would be good for them also to be done properly
"oxford.universitypressscholarship.com": error(""),
# Paywalled journal
@@ -109,6 +121,8 @@ def getter(url):
),
"link.springer.com": error("This article looks paywalled"),
"www.dl.begellhouse.com": error("This article is paywalled"),
+ "dl.begellhouse.com": error("Begell house is not yet handled"),
+
# To be implemented
"goodreads.com": error("Ebooks are not yet handled"),
"judiciary.senate.gov": error(""),
@@ -120,10 +134,23 @@ def getter(url):
"Researchgate makes it hard to auto download pdf - please provide a DOI or a different url to the contents"
),
"repository.cam.ac.uk": error(""),
+
+ "deliverypdf.ssrn.com": error("SSRN is not yet handled"),
+ "doi.wiley.com": error("Wiley is not yet handled"),
+ "onlinelibrary.wiley.com": error("Wiley is not yet handled"),
+ "globalprioritiesproject.org": error("Global priorities project is not yet handled"),
+ "ieeexplore.ieee.org": error("IEEE is not yet handled"),
+ "pdcnet.org": error("pdcnet.org is not yet handled"),
+ "sciencemag.org": error("sciencemag.org is not yet handled"),
+ "iopscience.iop.org": error("iopscience.iop.org is not yet handled"),
+ "journals.aom.org": error("journals.aom.org is not yet handled"),
+ "cambridge.org": error("cambridge.org is not yet handled"),
+ "transformer-circuits.pub": error("not handled yet - same codebase as distill"),
+
}
-HTML_PARSERS = {
+HTML_PARSERS: Dict[str, ParserFunc] = {
"academic.oup.com": element_extractor("#ContentTab"),
"ai.googleblog.com": element_extractor("div.post-body.entry-content"),
"arxiv-vanity.com": parse_vanity,
@@ -218,7 +245,6 @@ def getter(url):
".Publication",
],
),
- "transformer-circuits.pub": error("not handled yet - same codebase as distill"),
"vox.com": element_extractor("did.c-entry-content", remove=["c-article-footer"]),
"weforum.org": element_extractor("div.wef-0"),
"www6.inrae.fr": element_extractor("div.ArticleContent"),
@@ -226,7 +252,7 @@ def getter(url):
"yoshuabengio.org": element_extractor("div.post-content"),
}
-PDF_PARSERS = {
+PDF_PARSERS: Dict[str, ParserFunc] = {
# Domain sepecific handlers
"apcz.umk.pl": get_pdf_from_page(".galleys_links a.pdf", "a.download"),
"arxiv.org": fetch_arxiv,
@@ -264,15 +290,21 @@ def getter(url):
def parse_domain(url: str) -> str:
- return url and urlparse(url).netloc.lstrip("www.")
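+    # str.lstrip("www.") strips any leading "w"/"." characters rather than the literal prefix, hence the explicit check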
+ net_loc = urlparse(url).netloc
+ return net_loc[4:] if net_loc.startswith("www.") else net_loc
-def item_metadata(url: str) -> Dict[str, any]:
+def item_metadata(url: str) -> Dict[str, Any]:
+ if not url:
+ return {"error": "No url was given to item_metadata"}
domain = parse_domain(url)
try:
res = fetch(url, "head")
except (MissingSchema, InvalidSchema, ConnectionError) as e:
return {"error": str(e)}
+
+ if not res.headers.get('Content-Type'):
+ return {'error': 'No content type found'}
content_type = {item.strip() for item in res.headers.get("Content-Type", "").split(";")}
@@ -286,15 +318,17 @@ def item_metadata(url: str) -> Dict[str, any]:
return res
if parser := PDF_PARSERS.get(domain):
- if res := parser(url):
+ if content := parser(url):
# A pdf was found - use it, though it might not be useable
- return res
+ return content
if parser := UNIMPLEMENTED_PARSERS.get(domain):
return parser(url)
- if domain not in (HTML_PARSERS.keys() | PDF_PARSERS.keys() | UNIMPLEMENTED_PARSERS.keys()):
- return {"error": "No domain handler defined"}
+ if domain not in (
+ HTML_PARSERS.keys() | PDF_PARSERS.keys() | UNIMPLEMENTED_PARSERS.keys()
+ ):
+ return {"error": f"No domain handler defined for {domain}"}
return {"error": "could not parse url"}
elif content_type & {"application/octet-stream", "application/pdf"}:
if domain == "arxiv.org":
diff --git a/align_data/sources/articles/pdf.py b/align_data/sources/articles/pdf.py
index d5af5bf4..d54e6dda 100644
--- a/align_data/sources/articles/pdf.py
+++ b/align_data/sources/articles/pdf.py
@@ -1,6 +1,8 @@
import io
import logging
+from typing import Dict, Any, List
from urllib.parse import urlparse
+from pathlib import Path
-from typing import Dict, Any
from dateutil.parser import ParserError, parse
@@ -8,14 +10,15 @@
from PyPDF2 import PdfReader
from PyPDF2.errors import PdfReadError
from markdownify import MarkdownConverter
+from bs4.element import Tag
from align_data.sources.articles.html import fetch, fetch_element, with_retry
logger = logging.getLogger(__name__)
-def sci_hub_pdf(identifier):
- """Search Sci-hub for a link to a pdf of the article with the given identifier.
+def sci_hub_pdf(identifier: str) -> str | None:
+ """Search Sci-hub for a link to a pdf of the article with the given identifier (doi).
This will only get pdf that are directly served by Sci-hub. Sometimes it will redirect to a
large file containing multiple articles, e.g. a whole journal or book, in which case this function
@@ -24,7 +27,16 @@ def sci_hub_pdf(identifier):
elem = fetch_element(f"https://sci-hub.st/{identifier}", "embed")
if not elem:
return None
- src = elem.get("src").strip()
+
+ src = elem.get("src")
+
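+    # bs4 may return the attribute as a string or a list of strings; use the first value if it's a list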
+ if isinstance(src, list):
+ src = src[0] if src else None
+
+ if src is None:
+ return None
+
+ src = src.strip()
if src.startswith("//"):
src = "https:" + src
elif src.startswith("/"):
@@ -32,7 +44,7 @@ def sci_hub_pdf(identifier):
return src
-def read_pdf(filename):
+def read_pdf(filename: Path) -> str | None:
try:
pdf_reader = PdfReader(filename)
return "\n".join(page.extract_text() for page in pdf_reader.pages)
@@ -42,7 +54,7 @@ def read_pdf(filename):
@with_retry(times=3)
-def fetch_pdf(link):
+def fetch_pdf(link: str) -> Dict[str, str]:
"""Return the contents of the pdf file at `link` as a markdown string.
:param str link: the URL to check for a pdf file
@@ -53,6 +65,7 @@ def fetch_pdf(link):
"Could not fetch the pdf file at %s - are you sure that link is correct?",
link,
)
+ return {"error": "Could not read pdf file"}
content_type = {c_type.strip().lower() for c_type in res.headers.get("Content-Type").split(";")}
if not content_type & {"application/octet-stream", "application/pdf"}:
@@ -68,8 +81,8 @@ def fetch_pdf(link):
"source_type": "pdf",
}
except (TypeError, PdfReadError) as e:
- logger.error("Could not read PDF file: %s", e)
- return {"error": str(e)}
+ logger.error('Could not read PDF file: %s', e)
+ error = str(e)
filenames = [
i.strip().split("=")[1]
@@ -87,20 +100,22 @@ def fetch_pdf(link):
return {"error": error}
-def get_arxiv_link(doi):
+def get_arxiv_link(doi: str) -> str | None:
"""Find the URL to the pdf of the given arXiv DOI."""
res = requests.get(f"https://doi.org/api/handles/{doi}")
if res.status_code != 200:
return None
- vals = [val for val in response.json().get("values") if val.get("type", "").upper() == "URL"]
+ vals = [
+ val
+ for val in res.json().get("values")
+ if val.get("type", "").upper() == "URL"
+ ]
- if not vals:
- return None
- return vals[0]["data"]["value"].replace("/abs/", "/pdf/") + ".pdf"
+    if not vals:
+        return None
+    return vals[0]["data"]["value"].replace("/abs/", "/pdf/") + ".pdf"
-def get_doi(doi):
+def get_doi(doi: str) -> Dict[str, Any]:
"""Get the article with the given `doi`.
This will look for it in sci-hub and arxiv (if applicable), as those are likely the most
@@ -110,49 +125,44 @@ def get_doi(doi):
link = get_arxiv_link(doi)
pdf = link and fetch_pdf(link)
if pdf and "text" in pdf:
- pdf["downloaded_from"] = "arxiv"
- return pdf
+ return {**pdf, "downloaded_from": "arxiv"}
if link := sci_hub_pdf(doi):
if pdf := fetch_pdf(link):
- pdf["downloaded_from"] = "scihub"
- return pdf
+ return {**pdf, "downloaded_from": "scihub"}
return {"error": "Could not find pdf of article by DOI"}
-def doi_getter(url):
+def doi_getter(url: str) -> Dict[str, Any]:
"""Extract the DOI from the given `url` and fetch the contents of its article."""
return get_doi(urlparse(url).path.lstrip("/"))
-def parse_vanity(url) -> Dict[str, Any]:
+def parse_vanity(url: str) -> Dict[str, Any]:
contents = fetch_element(url, "article")
if not contents:
return {"error": "Could not fetch from arxiv vanity"}
- if title := contents.select_one("h1.ltx_title"):
- title = title.text
+ selected_title = contents.select_one("h1.ltx_title")
+ title = selected_title.text if selected_title else None
- def get_first_child(item):
- child = next(item.children)
+ def get_first_child(item: Tag) -> List[str]:
+ child = next(iter(item.children), None)
if not child:
return []
-
- if not isinstance(child, str):
- child = child.text
- return child.split(",")
+ return child.text.split(",")
authors = [
- a.strip()
+ author.strip()
for item in contents.select("div.ltx_authors .ltx_personname")
- for a in get_first_child(item)
+ for author in get_first_child(item)
]
- if date_published := contents.select_one("div.ltx_dates"):
- try:
- date_published = parse(date_published.text.strip("()"))
- except ParserError:
- "If the date couldn't be parsed, hope that later phases will be more successful"
+ selected_date = contents.select_one("div.ltx_dates")
+ try:
+ date_published = parse(selected_date.text.strip("()")) if selected_date else None
+ except ParserError:
+ date_published = None
text = "\n\n".join(
MarkdownConverter().convert_soup(elem).strip()
diff --git a/align_data/sources/articles/updater.py b/align_data/sources/articles/updater.py
index c2a1d29e..d9dccaef 100644
--- a/align_data/sources/articles/updater.py
+++ b/align_data/sources/articles/updater.py
@@ -1,21 +1,27 @@
import logging
from collections import namedtuple
from dataclasses import dataclass
+from typing import List, Optional, Union, Tuple, NamedTuple
+from pathlib import Path
import pandas as pd
-from sqlalchemy import select, or_
+from sqlalchemy import select, or_, Select
from align_data.common.alignment_dataset import AlignmentDataset
from align_data.db.models import Article
from align_data.sources.articles.parsers import item_metadata
+from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
-Item = namedtuple("Item", ["updates", "article"])
+
+class Item(NamedTuple):
+ updates: NamedTuple
+ article: Article
@dataclass
class ReplacerDataset(AlignmentDataset):
- csv_path: str
+ csv_path: str | Path
delimiter: str
done_key = "url"
@@ -30,25 +36,31 @@ def maybe(item, key):
return val
@property
- def items_list(self):
+ def items_list(self) -> List[Item]:
df = pd.read_csv(self.csv_path, delimiter=self.delimiter)
self.csv_items = [
item
for item in df.itertuples()
if self.maybe(item, "id") or self.maybe(item, "hash_id")
]
- by_id = {i.id: i for i in self.csv_items if self.maybe(i, "id")}
- by_hash_id = {i.hash_id: i for i in self.csv_items if self.maybe(i, "hash_id")}
-
- return [Item(by_id.get(a._id) or by_hash_id.get(a.id), a) for a in self.read_entries()]
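+        # index the csv rows by id and by hash_id so each stored article can be matched to its update row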
+ by_id = {id: item for item in self.csv_items if (id := self.maybe(item, 'id'))}
+ by_hash_id = {hash_id: item for item in self.csv_items if (hash_id := self.maybe(item, 'hash_id'))}
+
+ return [
+ Item(
+ updates=by_id.get(article._id) or by_hash_id.get(article.id),
+ article=article
+ )
+ for article in self.read_entries()
+ ]
@property
- def _query_items(self):
+ def _query_items(self) -> Select[Tuple[Article]]:
ids = [i.id for i in self.csv_items if self.maybe(i, "id")]
hash_ids = [i.hash_id for i in self.csv_items if self.maybe(i, "hash_id")]
return select(Article).where(or_(Article.id.in_(hash_ids), Article._id.in_(ids)))
- def update_text(self, updates, article):
+ def update_text(self, updates: NamedTuple, article: Article):
# If the url is the same as it was before, and there isn't a source url provided, assume that the
# previous text is still valid
if article.url == self.maybe(updates, "url") and not self.maybe(updates, "source_url"):
@@ -65,10 +77,10 @@ def update_text(self, updates, article):
metadata = item_metadata(url)
# Only change the text if it could be fetched - better to have outdated values than none
if metadata.get("text"):
- article.text = metadata.get("text")
+ article.text = metadata["text"]
article.status = metadata.get("error")
- def process_entry(self, item):
+ def process_entry(self, item: Item) -> Article:
updates, article = item
for key in ["url", "title", "source", "authors", "comment", "confidence"]:
@@ -84,5 +96,5 @@ def process_entry(self, item):
return article
- def _add_batch(self, session, batch):
+ def _add_batch(self, session: Session, batch: tuple):
session.add_all(map(session.merge, batch))
diff --git a/align_data/sources/arxiv_papers.py b/align_data/sources/arxiv_papers.py
index 30002a78..45b3148b 100644
--- a/align_data/sources/arxiv_papers.py
+++ b/align_data/sources/arxiv_papers.py
@@ -3,6 +3,7 @@
from typing import Dict, Optional, Any
import arxiv
+
from align_data.sources.articles.pdf import fetch_pdf, parse_vanity
from align_data.sources.articles.html import fetch_element
from align_data.sources.utils import merge_dicts
@@ -10,7 +11,7 @@
logger = logging.getLogger(__name__)
-def get_arxiv_metadata(paper_id) -> arxiv.Result:
+def get_arxiv_metadata(paper_id: str) -> arxiv.Result | None:
"""
Get metadata from arxiv
"""
@@ -25,6 +26,7 @@ def get_arxiv_metadata(paper_id) -> arxiv.Result:
def get_id(url: str) -> str | None:
if res := re.search(r"https?://arxiv.org/(?:abs|pdf)/(.*?)(?:v\d+)?(?:/|\.pdf)?$", url):
return res.group(1)
+ return None
def canonical_url(url: str) -> str:
@@ -50,13 +52,13 @@ def get_version(id: str) -> str | None:
return res.group(1)
-def is_withdrawn(url: str):
- if elem := fetch_element(canonical_url(url), ".extra-services .full-text ul"):
- return elem.text.strip().lower() == "withdrawn"
- return None
+def is_withdrawn(url: str) -> bool:
+ if elem := fetch_element(canonical_url(url), '.extra-services .full-text ul'):
+ return elem.text.strip().lower() == 'withdrawn'
+ return False
-def add_metadata(data, paper_id):
+def add_metadata(data: Dict[str, Any], paper_id: str) -> Dict[str, Any]:
metadata = get_arxiv_metadata(paper_id)
if not metadata:
return {}
@@ -78,7 +80,7 @@ def add_metadata(data, paper_id):
)
-def fetch_arxiv(url) -> Dict:
+def fetch_arxiv(url: str) -> Dict[str, Any]:
paper_id = get_id(url)
if not paper_id:
return {"error": "Could not extract arxiv id"}
diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py
index 1245aec6..e3a25f0d 100644
--- a/align_data/sources/blogs/blogs.py
+++ b/align_data/sources/blogs/blogs.py
@@ -2,12 +2,13 @@
from urllib.parse import urljoin
import requests
-from align_data.sources.articles.parsers import item_metadata
-from align_data.common.html_dataset import HTMLDataset, RSSDataset
from bs4 import BeautifulSoup
from dateutil.parser import ParserError
from tqdm import tqdm
+from align_data.sources.articles.parsers import item_metadata
+from align_data.common.html_dataset import HTMLDataset, RSSDataset
+
logger = logging.getLogger(__name__)
@@ -77,7 +78,11 @@ def extract_authors(self, article):
authors = []
if authors_div:
authors = [
- i.split("(")[0].strip() for i in authors_div.select_one("p").children if not i.name
+ i.split("(")[0].strip()
+ for i in authors_div.select_one("p").children
+ if not i.name and i.strip()
+                # i.name is non-empty if it's a tag, e.g. <br/> has name "br",
+                # but a plain text node like "OpenAI Research" has no name
]
return authors or ["OpenAI Research"]
diff --git a/align_data/sources/blogs/gwern_blog.py b/align_data/sources/blogs/gwern_blog.py
index 1d573a8e..0375bb8c 100644
--- a/align_data/sources/blogs/gwern_blog.py
+++ b/align_data/sources/blogs/gwern_blog.py
@@ -1,6 +1,7 @@
-import requests
-import logging
from dataclasses import dataclass
+import logging
+
+import requests
from align_data.common.html_dataset import HTMLDataset
@@ -14,7 +15,6 @@ class GwernBlog(HTMLDataset):
"""
COOLDOWN: int = 1
- done_key = "url"
def get_item_key(self, item: str) -> str:
return item
diff --git a/align_data/sources/blogs/wp_blog.py b/align_data/sources/blogs/wp_blog.py
index cd409d98..b0c9e9f1 100644
--- a/align_data/sources/blogs/wp_blog.py
+++ b/align_data/sources/blogs/wp_blog.py
@@ -1,17 +1,16 @@
from dataclasses import dataclass
import logging
+
import feedparser
from tqdm import tqdm
from align_data.common.html_dataset import RSSDataset
-
logger = logging.getLogger(__name__)
@dataclass
class WordpressBlog(RSSDataset):
- summary_key = "summary"
@property
def feed_url(self):
@@ -28,7 +27,7 @@ def items_list(self):
with tqdm(desc=f"Loading {self.name} pages") as pbar:
while True:
paged_url = f"{self.feed_url}?paged={page_number}"
- logging.info(f"Fetching {paged_url}")
+ logger.info(f"Fetching {paged_url}")
feed = feedparser.parse(paged_url)
title = feed.get("feed", {}).get("title")
diff --git a/align_data/sources/ebooks/agentmodels.py b/align_data/sources/ebooks/agentmodels.py
index 65b52502..3756524a 100644
--- a/align_data/sources/ebooks/agentmodels.py
+++ b/align_data/sources/ebooks/agentmodels.py
@@ -1,9 +1,11 @@
-from align_data.common.alignment_dataset import AlignmentDataset
from dataclasses import dataclass
-from git import Repo
import logging
from datetime import timezone
+from git import Repo
+
+from align_data.common.alignment_dataset import AlignmentDataset
+
logger = logging.getLogger(__name__)
diff --git a/align_data/sources/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py
index 8925fc22..579a4680 100644
--- a/align_data/sources/greaterwrong/greaterwrong.py
+++ b/align_data/sources/greaterwrong/greaterwrong.py
@@ -5,7 +5,6 @@
from typing import Set, Tuple
import requests
-import jsonlines
from bs4 import BeautifulSoup
from markdownify import markdownify
from sqlalchemy import select
@@ -69,7 +68,6 @@ class GreaterWrong(AlignmentDataset):
limit = 50
COOLDOWN_TIME: float = 0.5
- summary_key: str = "summary"
done_key = "url"
lazy_eval = True
source_type = 'GreaterWrong'
@@ -112,49 +110,48 @@ def _get_published_date(self, item):
return super()._get_published_date(item.get("postedAt"))
def make_query(self, after: str):
- return (
- """{
- posts(input: {
- terms: {
- excludeEvents: true
- view: "old"
- """
- f" af: {self.af}\n"
- f" limit: {self.limit}\n"
- f" karmaThreshold: {self.min_karma}\n"
- f' after: "{after}"\n'
- """ filter: "tagged"
- }
- }) {
- totalCount
- results {
- _id
- title
- slug
- pageUrl
- postedAt
- modifiedAt
- score
- extendedScore
- baseScore
- voteCount
- commentCount
- wordCount
- tags {
- name
- }
- user {
- displayName
- }
- coauthors {
- displayName
- }
- af
- htmlBody
- }
- }
- }"""
- )
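+        # NOTE: doubled braces ({{ }}) render as literal braces in the GraphQL query; single braces interpolate the parameters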
+ return f'''
+ {{
+ posts(input: {{
+ terms: {{
+ excludeEvents: true
+ view: "old"
+ af: {self.af}
+ limit: {self.limit}
+ karmaThreshold: {self.min_karma}
+ after: "{after}"
+ filter: "tagged"
+ }}
+ }}) {{
+ totalCount
+ results {{
+ _id
+ title
+ slug
+ pageUrl
+ postedAt
+ modifiedAt
+ score
+ extendedScore
+ baseScore
+ voteCount
+ commentCount
+ wordCount
+ tags {{
+ name
+ }}
+ user {{
+ displayName
+ }}
+ coauthors {{
+ displayName
+ }}
+ af
+ htmlBody
+ }}
+ }}
+ }}
+ '''
def fetch_posts(self, query: str):
res = requests.post(
@@ -168,14 +165,18 @@ def fetch_posts(self, query: str):
return res.json()["data"]["posts"]
@property
- def last_date_published(self):
- try:
- prev_item = next(self.read_entries(sort_by=Article.date_published.desc()))
- if prev_item and prev_item.date_published:
- return prev_item.date_published.isoformat() + "Z"
- except StopIteration:
- pass
- return datetime(self.start_year, 1, 1).isoformat() + "Z"
+ def last_date_published(self) -> str:
+ entries = self.read_entries(sort_by=Article.date_published.desc())
+
+ # Get the first entry if exists, else return a default datetime
+ prev_item = next(entries, None)
+
+ # If there is no previous item or it doesn't have a published date, return default datetime
+ if not prev_item or not prev_item.date_published:
+ return datetime(self.start_year, 1, 1).isoformat() + 'Z'
+
+ # If the previous item has a published date, return it in isoformat
+ return prev_item.date_published.isoformat() + 'Z'
@property
def items_list(self):
diff --git a/align_data/sources/stampy/stampy.py b/align_data/sources/stampy/stampy.py
index 95319820..cc40e9ac 100644
--- a/align_data/sources/stampy/stampy.py
+++ b/align_data/sources/stampy/stampy.py
@@ -2,15 +2,15 @@
import re
import logging
from dataclasses import dataclass
+
from codaio import Coda, Document
+import html
from align_data.common.alignment_dataset import AlignmentDataset
from align_data.settings import CODA_TOKEN, CODA_DOC_ID, ON_SITE_TABLE
logger = logging.getLogger(__name__)
-import html
-
@dataclass
class Stampy(AlignmentDataset):
diff --git a/align_data/sources/youtube/youtube.py b/align_data/sources/youtube/youtube.py
index 740597d0..608cd96d 100644
--- a/align_data/sources/youtube/youtube.py
+++ b/align_data/sources/youtube/youtube.py
@@ -1,8 +1,8 @@
import logging
-from dataclasses import dataclass
-from typing import List
+from dataclasses import dataclass, field
+from typing import List, Optional, Iterable
-from googleapiclient.discovery import build
+from googleapiclient.discovery import build, Resource
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api._errors import (
NoTranscriptFound,
@@ -13,7 +13,6 @@
from align_data.settings import YOUTUBE_API_KEY
from align_data.common.alignment_dataset import AlignmentDataset
-
logger = logging.getLogger(__name__)
@@ -21,16 +20,17 @@ class YouTubeDataset(AlignmentDataset):
done_key = "url"
batch_size = 1
# COOLDOWN = 2
- authors = None
- collection_ids = []
+ authors: Optional[List[str]] = None
+ collection_ids: List[str] = field(default_factory=list)
+
def setup(self):
super().setup()
if not YOUTUBE_API_KEY:
raise ValueError("No YOUTUBE_API_KEY provided!")
- self.youtube = build("youtube", "v3", developerKey=YOUTUBE_API_KEY)
+ self.youtube: Resource = build("youtube", "v3", developerKey=YOUTUBE_API_KEY)
- def next_page(self, collection_id, next_page_token):
+    def next_page(self, collection_id: str, next_page_token: str | None) -> dict:
return {"items": []}
@staticmethod
@@ -45,7 +45,7 @@ def _get_id(item) -> str | None:
if resource["kind"] == "youtube#video":
return resource["videoId"]
- def fetch_videos(self, collection_id):
+ def fetch_videos(self, collection_id: str) -> Iterable[dict]:
next_page_token = None
while True:
videos_response = self.next_page(collection_id, next_page_token)
@@ -74,7 +74,8 @@ def _get_contents(self, video):
video_id = self._get_id(video)
try:
transcript = (
- YouTubeTranscriptApi.list_transcripts(video_id)
+ YouTubeTranscriptApi
+ .list_transcripts(video_id)
.find_transcript(["en", "en-GB"])
.fetch()
)
@@ -139,13 +140,14 @@ def _get_published_date(self, video):
@dataclass
class YouTubePlaylistDataset(YouTubeDataset):
- playlist_ids: str
+
+ playlist_ids: List[str]
@property
def collection_ids(self):
return self.playlist_ids
- def next_page(self, collection_id, next_page_token):
+    def next_page(self, collection_id: str, next_page_token: str | None):
return (
self.youtube.playlistItems()
.list(
diff --git a/main.py b/main.py
index 82c30f07..8ad9f88e 100644
--- a/main.py
+++ b/main.py
@@ -1,7 +1,7 @@
-import logging
import os
from dataclasses import dataclass
from typing import List
+import logging
import fire
@@ -20,7 +20,6 @@
METADATA_SOURCE_SPREADSHEET,
)
-
logger = logging.getLogger(__name__)
diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py
index 7f911208..acdcdfcd 100644
--- a/tests/align_data/articles/test_datasets.py
+++ b/tests/align_data/articles/test_datasets.py
@@ -360,7 +360,7 @@ def test_arxiv_process_entry_retracted(mock_arxiv):