diff --git a/align_data/analysis/analyse_jsonl_data.py b/align_data/analysis/analyse_jsonl_data.py index 0aed124d..9ef49649 100644 --- a/align_data/analysis/analyse_jsonl_data.py +++ b/align_data/analysis/analyse_jsonl_data.py @@ -1,8 +1,9 @@ from datetime import datetime from pathlib import Path +from collections import defaultdict + import jsonlines -from collections import defaultdict def is_valid_date_format(data_dict, format="%Y-%m-%dT%H:%M:%SZ"): diff --git a/align_data/analysis/count_tokens.py b/align_data/analysis/count_tokens.py index cd099c68..bc0232d3 100644 --- a/align_data/analysis/count_tokens.py +++ b/align_data/analysis/count_tokens.py @@ -1,7 +1,8 @@ +from typing import Tuple +import logging + from transformers import AutoTokenizer import jsonlines -import logging -from typing import Tuple logger = logging.getLogger(__name__) diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 344ee89d..5e35d9fd 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -3,23 +3,23 @@ from itertools import islice import logging import time -from dataclasses import dataclass, KW_ONLY +from dataclasses import dataclass, field, KW_ONLY from pathlib import Path -from typing import Iterable, List, Optional, Set -from sqlalchemy import select -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import joinedload +from typing import List, Optional, Set, Iterable, Tuple, Generator -import jsonlines import pytz +from sqlalchemy import select, Select, JSON +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import joinedload, Session +import jsonlines from dateutil.parser import parse, ParserError from tqdm import tqdm + from align_data.db.models import Article, Summary from align_data.db.session import make_session from align_data.settings import ARTICLE_MAIN_KEYS from align_data.sources.utils import merge_dicts - logger = logging.getLogger(__name__) @@ -28,40 +28,42 @@ class AlignmentDataset: """The base dataset class.""" name: str - """The name of the dataset""" + """The name of the dataset.""" _: KW_ONLY - files_path = Path("") - """The path where data can be found. Usually a folder""" + data_path: Path = Path(__file__).parent / "../../data/" + """The path where data can be found. Usually a folder.""" + + # Derived paths + raw_data_path: Path = field(init=False) + files_path: Path = field(init=False) + + # Internal housekeeping variables + _outputted_items: Set[str] = field(default_factory=set, init=False) + """A set of the ids of all previously processed items.""" done_key = "id" """The key of the entry to use as the id when checking if already processed.""" COOLDOWN = 0 - """An optional cool down between processing entries""" + """An optional cool down between processing entries.""" lazy_eval = False """Whether to lazy fetch items. This is nice in that it will start processing, but messes up the progress bar.""" + batch_size = 20 """The number of items to collect before flushing to the database.""" - # Internal housekeeping variables - _entry_idx = 0 - """Used internally for writing debugging info - each file write will increment it""" - _outputted_items = set() - """A set of the ids of all previously processed items""" + def __post_init__(self): + self.data_path = self.data_path.resolve() - def __str__(self) -> str: - return self.name - - def __post_init__(self, data_path=Path(__file__).parent / "../../data/"): - self.data_path = data_path self.raw_data_path = self.data_path / "raw" - - # set the default place to look for data self.files_path = self.raw_data_path / self.name + def __str__(self) -> str: + return self.name + def _add_authors(self, article: Article, authors: List[str]) -> Article: # TODO: Don't keep adding the same authors - come up with some way to reuse them article.authors = ",".join(authors) @@ -87,42 +89,42 @@ def make_data_entry(self, data, **kwargs) -> Article: article.summaries += [Summary(text=summary, source=self.name) for summary in summaries] return article - def to_jsonl(self, out_path=None, filename=None) -> Path: - if not out_path: - out_path = Path(__file__).parent / "../../data/" - - if not filename: - filename = f"{self.name}.jsonl" - filename = Path(out_path) / filename + def to_jsonl(self, out_path: Path | None = None, filename: str | None = None) -> Path: + out_path = out_path or self.data_path + filename = filename or f"{self.name}.jsonl" + filepath = out_path / filename - with jsonlines.open(filename, "w") as jsonl_writer: + with jsonlines.open(filepath, "w") as jsonl_writer: for article in self.read_entries(): jsonl_writer.write(article.to_dict()) - return filename.resolve() + return filepath.resolve() @property - def _query_items(self): + def _query_items(self) -> Select[Tuple[Article]]: return select(Article).where(Article.source == self.name) - def read_entries(self, sort_by=None): + def read_entries(self, sort_by=None) -> Iterable[Article]: """Iterate through all the saved entries.""" with make_session() as session: query = self._query_items.options(joinedload(Article.summaries)) if sort_by is not None: query = query.order_by(sort_by) - for item in session.scalars(query).unique(): - yield item + + result = session.scalars(query) + for article in result.unique(): # removes duplicates + yield article - def _add_batch(self, session, batch): + def _add_batch(self, session: Session, batch: tuple): session.add_all(batch) def add_entries(self, entries): - def commit(): + def commit() -> bool: try: session.commit() return True except IntegrityError: session.rollback() + return False with make_session() as session: items = iter(entries) @@ -183,7 +185,11 @@ def _normalize_urls(self, urls: Iterable[str]) -> Set[str]: def _load_outputted_items(self) -> Set[str]: - """Load the output file (if it exists) in order to know which items have already been output.""" + """ + Loads the outputted items from the database and returns them as a set. + + if the done_key is not an attribute of Article, it will try to load it from the meta field. + """ with make_session() as session: items = set() if hasattr(Article, self.done_key): @@ -203,23 +209,24 @@ def not_processed(self, item) -> bool: # If it get's to that level, consider batching it somehow return self._normalize_url(self.get_item_key(item)) not in self._outputted_items - def unprocessed_items(self, items=None) -> Iterable: + def unprocessed_items(self, items=None) -> list | filter: """Return a list of all items to be processed. This will automatically remove any items that have already been processed, based on the contents of the output file. """ self.setup() + items = items or self.items_list - filtered = filter(self.not_processed, items or self.items_list) + items_to_process = filter(self.not_processed, items) # greedily fetch all items if not lazy eval. This makes the progress bar look nice if not self.lazy_eval: - filtered = list(filtered) + return list(items_to_process) - return filtered + return items_to_process - def fetch_entries(self): + def fetch_entries(self) -> Generator[Article, None, None]: """Get all entries to be written to the file.""" for item in tqdm(self.unprocessed_items(), desc=f"Processing {self.name}"): entry = self.process_entry(item) @@ -242,7 +249,7 @@ def process_entry(self, entry) -> Article | None: raise NotImplementedError @staticmethod - def _format_datetime(date) -> str: + def _format_datetime(date: datetime) -> str: return date.strftime("%Y-%m-%dT%H:%M:%SZ") @staticmethod @@ -280,7 +287,7 @@ def _load_outputted_items(self) -> Set[str]: ) ) - def _add_batch(self, session, batch): + def _add_batch(self, session: Session, batch: tuple): def merge(item): if prev := self.articles.get(item.url): return session.merge(item.update(prev)) diff --git a/align_data/common/html_dataset.py b/align_data/common/html_dataset.py index e5e4d277..a1b748e3 100644 --- a/align_data/common/html_dataset.py +++ b/align_data/common/html_dataset.py @@ -1,16 +1,18 @@ -import pytz -import regex as re import logging from datetime import datetime -from dataclasses import dataclass, field, KW_ONLY +from dataclasses import dataclass, field from urllib.parse import urljoin -from typing import List +from typing import List, Dict, Any +import re +import pytz import requests import feedparser from bs4 import BeautifulSoup +from bs4.element import ResultSet, Tag from markdownify import markdownify +from align_data.db.models import Article from align_data.common.alignment_dataset import AlignmentDataset logger = logging.getLogger(__name__) @@ -26,9 +28,6 @@ class HTMLDataset(AlignmentDataset): done_key = "url" authors: List[str] = field(default_factory=list) - _: KW_ONLY - source_key: str = None - summary_key: str = None item_selector = "article" title_selector = "article h1" @@ -39,12 +38,14 @@ class HTMLDataset(AlignmentDataset): def extract_authors(self, article): return self.authors - def get_item_key(self, item) -> str: - article_url = item.find_all("a")[0]["href"].split("?")[0] - return urljoin(self.url, article_url) + + def get_item_key(self, item: Tag) -> str: + first_href = item.find("a")["href"] + href_base, *_ = first_href.split("?") + return urljoin(self.url, href_base) @property - def items_list(self): + def items_list(self) -> ResultSet[Tag]: logger.info(f"Fetching entries from {self.url}") response = requests.get(self.url, allow_redirects=True) soup = BeautifulSoup(response.content, "html.parser") @@ -52,10 +53,10 @@ def items_list(self): logger.info(f"Found {len(articles)} articles") return articles - def _extra_values(self, contents): + def _extra_values(self, contents: BeautifulSoup): return {} - def get_contents(self, article_url: str): + def get_contents(self, article_url: str) -> Dict[str, Any]: contents = self.fetch_contents(article_url) title = self._get_title(contents) @@ -72,7 +73,7 @@ def get_contents(self, article_url: str): **self._extra_values(contents), } - def process_entry(self, article): + def process_entry(self, article: Tag) -> Article: article_url = self.get_item_key(article) contents = self.get_contents(article_url) if not contents.get("text"): @@ -80,8 +81,8 @@ def process_entry(self, article): return self.make_data_entry(contents) - def fetch_contents(self, url): - logger.info("Fetching {}".format(url)) + def fetch_contents(self, url: str): + logger.info(f"Fetching {url}") resp = requests.get(url, allow_redirects=True) return BeautifulSoup(resp.content, "html.parser") @@ -136,7 +137,7 @@ def _get_text(self, item): text = item.get("content") and item["content"][0].get("value") return self._extract_markdown(text) - def fetch_contents(self, url): + def fetch_contents(self, url: str): item = self.items[url] if "content" in item: return item diff --git a/align_data/db/models.py b/align_data/db/models.py index e79da232..d67ebfd5 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -221,6 +221,7 @@ def to_dict(self) -> Dict[str, Any]: } + event.listen(Article, "before_insert", Article.before_write) event.listen(Article, "before_update", Article.before_write) event.listen(Article, "before_insert", Article.check_for_changes) diff --git a/align_data/db/session.py b/align_data/db/session.py index 331de9b7..2e80c4b4 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -7,7 +7,6 @@ from align_data.settings import DB_CONNECTION_URI, MIN_CONFIDENCE from align_data.db.models import Article, PineconeStatus - logger = logging.getLogger(__name__) # We create a single engine for the entire application diff --git a/align_data/embeddings/pinecone/pinecone_db_handler.py b/align_data/embeddings/pinecone/pinecone_db_handler.py index b0b09b9f..3cf32112 100644 --- a/align_data/embeddings/pinecone/pinecone_db_handler.py +++ b/align_data/embeddings/pinecone/pinecone_db_handler.py @@ -19,7 +19,6 @@ PINECONE_NAMESPACE, ) - logger = logging.getLogger(__name__) diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py index 4e6ac03a..92fa726f 100644 --- a/align_data/embeddings/pinecone/update_pinecone.py +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -22,10 +22,8 @@ ) from align_data.embeddings.text_splitter import ParagraphSentenceUnitTextSplitter - logger = logging.getLogger(__name__) - # Define type aliases for the Callables LengthFunctionType = Callable[[str], int] TruncateFunctionType = Callable[[str, int], str] diff --git a/align_data/postprocess/postprocess.py b/align_data/postprocess/postprocess.py index 05e9dbde..9db1a480 100644 --- a/align_data/postprocess/postprocess.py +++ b/align_data/postprocess/postprocess.py @@ -1,53 +1,81 @@ # %% from collections import defaultdict, Counter -from dataclasses import dataclass -import jsonlines -from tqdm import tqdm +from dataclasses import dataclass, field +from typing import List, DefaultDict import logging -from path import Path +from pathlib import Path +import jsonlines +from tqdm import tqdm import pylab as plt -import seaborn as sns +from nltk.tokenize import sent_tokenize, word_tokenize +import seaborn as sns #TODO: install seaborn or fix this file import pandas as pd logger = logging.getLogger(__name__) +#TODO: fix this file @dataclass class PostProcesser: """ This class is used to postprocess the data """ + jsonl_path: Path = field(default_factory=lambda: (Path(__file__).parent / '../../data/').resolve()) + + def __post_init__(self) -> None: + print(f"Looking for data in {self.jsonl_path}") - jsonl_path: Path = Path("../../data/") + # Check if the directory exists + if not self.jsonl_path.is_dir(): + raise FileNotFoundError(f"Data directory {self.jsonl_path} does not exist") - def __init__(self) -> None: - self.jsonl_list = sorted(self.jsonl_path.files("*.jsonl")) - self.source_list = [path.name.split(".jsonl")[0] for path in self.jsonl_list] - self.all_stats = defaultdict(Counter) + self.jsonl_list: List[Path] = sorted(self.jsonl_path.glob("*.jsonl")) + self.source_list: List[str] = [path.stem for path in self.jsonl_list] + self.all_stats: DefaultDict[str, Counter] = defaultdict(Counter) def compute_statistics(self) -> None: for source_name, path in tqdm(zip(self.source_list, self.jsonl_list)): with jsonlines.open(path) as reader: for obj in reader: - text = obj["text"] + text: str = obj['text'] source_stats = self.all_stats[source_name] source_stats["num_entries"] += 1 - source_stats["num_tokens"] += len(text.split()) # TODO: Use tokenizer + source_stats["num_tokens"] += len(word_tokenize(text)) source_stats["num_chars"] += len(text) source_stats["num_words"] += len(text.split()) - source_stats["num_sentences"] += len( - text.split(".") - ) # TODO: Use NLTK/Spacy or similar - source_stats["num_paragraphs"] += len(text.splitlines()) + source_stats["num_sentences"] += len(sent_tokenize(text)) + source_stats["num_newlines"] += len(text.split("\n")) + source_stats["num_paragraphs"] += len(text.split("\n\n")) def plot_statistics(self) -> None: all_df = pd.DataFrame(self.all_stats).T - plt.figure(figsize=(5, 5)) - sns.barplot(x=all_df.index, y=all_df["num_entries"]) + + fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(15, 15)) + metrics_to_plot = [ + "num_entries", + "num_tokens", + "num_words", + "num_sentences", + "num_paragraphs", + "num_chars", + ] + + for i, metric in enumerate(metrics_to_plot): + ax = axes[i // 2, i % 2] + sns.barplot(x=all_df.index, y=all_df[metric], ax=ax) + ax.set_title(metric) + ax.set_ylabel('') + ax.tick_params(axis='x', rotation=45) + # Uncomment the next line if you want to apply a log scale for better visualization. + # ax.set_yscale("log") + + plt.tight_layout() + plt.show() + def merge_all_files(self, out_dir: str) -> str: - pass + raise NotImplementedError def deduplicate(self) -> None: for path in tqdm(self.jsonl_list): @@ -58,7 +86,7 @@ def deduplicate(self) -> None: writer.write(obj) def clean_dataset(self, merged_dataset_path: str) -> str: - pass + raise NotImplementedError pp = PostProcesser() @@ -66,6 +94,8 @@ def clean_dataset(self, merged_dataset_path: str) -> str: pp.source_list # %% pp.compute_statistics() +print(pp.all_stats) +pp.plot_statistics() # %% pp.deduplicate() # %% diff --git a/align_data/sources/alignment_newsletter/alignment_newsletter.py b/align_data/sources/alignment_newsletter/alignment_newsletter.py index 87541d4a..ebd98c1d 100644 --- a/align_data/sources/alignment_newsletter/alignment_newsletter.py +++ b/align_data/sources/alignment_newsletter/alignment_newsletter.py @@ -1,10 +1,10 @@ # %% import logging from datetime import datetime, timezone -from pathlib import Path +from dataclasses import dataclass + import pandas as pd -from dataclasses import dataclass from align_data.common.alignment_dataset import SummaryDataset logger = logging.getLogger(__name__) @@ -14,10 +14,6 @@ class AlignmentNewsletter(SummaryDataset): done_key = "url" - def __post_init__(self, data_path=Path(__file__).parent / "../../../data/"): - self.data_path = data_path - self.raw_data_path = self.data_path / "raw" - def setup(self) -> None: super().setup() @@ -42,7 +38,7 @@ def _get_published_date(self, year): def items_list(self): return self.df.itertuples() - def process_entry(self, row): + def process_entry(self, row: pd.Series): """ For each row in the dataframe, create a new entry with the following fields: url, source, converted_with, source_type, venue, newsletter_category, highlight, newsletter_number, diff --git a/align_data/sources/arbital/arbital.py b/align_data/sources/arbital/arbital.py index b08393c4..aa7b9633 100644 --- a/align_data/sources/arbital/arbital.py +++ b/align_data/sources/arbital/arbital.py @@ -5,11 +5,9 @@ from typing import List, Tuple, Iterator, Dict, Union, Any, TypedDict import requests -from datetime import datetime, timezone from dateutil.parser import parse from align_data.common.alignment_dataset import AlignmentDataset -from dataclasses import dataclass logger = logging.getLogger(__name__) diff --git a/align_data/sources/articles/articles.py b/align_data/sources/articles/articles.py index 7db94a7b..4d497bb1 100644 --- a/align_data/sources/articles/articles.py +++ b/align_data/sources/articles/articles.py @@ -1,10 +1,13 @@ import io import logging +from typing import Dict, Set from tqdm import tqdm import gspread +from gspread.worksheet import Worksheet from align_data.sources.articles.google_cloud import ( + SheetRow, iterate_rows, get_spreadsheet, get_sheet, @@ -18,10 +21,8 @@ from align_data.sources.articles.updater import ReplacerDataset from align_data.settings import PDFS_FOLDER_ID - logger = logging.getLogger(__name__) - # Careful changing these - the sheets assume this ordering REQUIRED_FIELDS = ["url", "source_url", "title", "source_type", "date_published"] OPTIONAL_FIELDS = ["authors", "summary"] @@ -45,11 +46,10 @@ def save_pdf(filename, link): parent_id=PDFS_FOLDER_ID, ) - @with_retry(times=3, exceptions=gspread.exceptions.APIError) -def process_row(row, sheets): +def process_row(row: SheetRow, sheets: Dict[str, Worksheet]): """Check the given `row` and fetch its metadata + optional extra stuff.""" - logger.info('Checking "%s"', row["title"]) + logger.info('Checking "%s" at "%s', row["title"], row["url"]) missing = [field for field in REQUIRED_FIELDS if not row.get(field)] if missing: @@ -60,10 +60,13 @@ def process_row(row, sheets): source_url = row.get("source_url") contents = item_metadata(source_url) - if not contents or "error" in contents: - error = (contents and contents.get("error")) or "text could not be fetched" - logger.error(error) - row.set_status(error) + if not contents: + logger.error("text could not be fetched") + row.set_status("text could not be fetched") + return + elif "error" in contents: + logger.error(contents["error"]) + row.set_status(contents["error"]) return data_source = contents.get("source_type") @@ -83,7 +86,7 @@ def process_row(row, sheets): row.set_status(OK) -def process_spreadsheets(source_sheet, output_sheets): +def process_spreadsheets(source_sheet: Worksheet, output_sheets: Dict[str, Worksheet]) -> None: """Go through all entries in `source_sheet` and update the appropriate metadata in `output_sheets`. `output_sheets` should be a dict with a key for each possible data type, e.g. html, pdf etc. @@ -92,43 +95,49 @@ def process_spreadsheets(source_sheet, output_sheets): :param Dict[str, Worksheet] output_sheets: a dict of per data type worksheets to be updated """ logger.info("fetching seen urls") - seen = { + seen: Set[str] = { url - for sheet in output_sheets.values() - for record in sheet.get_all_records() + for output_sheet in output_sheets.values() + for record in output_sheet.get_all_records() for url in [record.get("url"), record.get("source_url")] if url - } + } + # TODO: This requires our output_sheet to already have the headers for + # the different sheets. otherwise we raise an error, but we could have it be added + # automatically instead + for row in tqdm(iterate_rows(source_sheet)): - title = row.get("title") if not row.get("source_url"): row["source_url"] = row["url"] + if row.get("source_url") in seen: - logger.info(f'skipping "{title}", as it has already been seen') - elif row.get("status"): - logger.info( - f'skipping "{title}", as it has a status set - remove it for this row to be processed' - ) + logger.info(f'skipping "{row.get("title")}", as it has already been seen') + elif row.get('status'): + logger.info(f'skipping "{row.get("title")}", as it has a status set - remove it for this row to be processed') else: process_row(row, output_sheets) -def update_new_items(source_spreadsheet, source_sheet, output_spreadsheet): +def update_new_items(source_spreadsheet_id: str, source_sheet_name: str, output_spreadsheet_id: str) -> None: """Go through all unprocessed items from the source worksheet, updating the appropriate metadata in the output one.""" - source_sheet = get_sheet(source_spreadsheet, source_sheet) - sheets = {sheet.title: sheet for sheet in get_spreadsheet(output_spreadsheet).worksheets()} - return process_spreadsheets(source_sheet, sheets) + source_sheet = get_sheet(source_spreadsheet_id, source_sheet_name) + output_sheets = { + sheet.title: sheet for sheet in get_spreadsheet(output_spreadsheet_id).worksheets() + } + process_spreadsheets(source_sheet, output_sheets) -def check_new_articles(source_spreadsheet, source_sheet): +def check_new_articles(source_spreadsheet_id: str, source_sheet_name: str): """Goes through the special indices looking for unseen articles.""" - source_sheet = get_sheet(source_spreadsheet, source_sheet) - current = {row.get("title"): row for row in iterate_rows(source_sheet)} + source_sheet = get_sheet(source_spreadsheet_id, source_sheet_name) + current: Dict[str, SheetRow] = {row.get("title"): row for row in iterate_rows(source_sheet)} + logger.info('Found %s articles in the sheet', len(current)) + seen_urls = { url - for item in current.values() - for key in ("url", "source_url") - if (url := item.get(key)) is not None + for row in current.values() + for url_key in ("url", "source_url") + if (url := row.get(url_key)) is not None } indices_items = fetch_all() diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py index cbf7f9d9..ec372126 100644 --- a/align_data/sources/articles/datasets.py +++ b/align_data/sources/articles/datasets.py @@ -2,13 +2,14 @@ import os from dataclasses import dataclass from pathlib import Path -from typing import Dict, Iterable +from typing import Dict, Tuple, Iterable +from urllib.parse import urlparse import pandas as pd from gdown.download import download from markdownify import markdownify from pypandoc import convert_file -from sqlalchemy import select +from sqlalchemy import select, Select from align_data.common.alignment_dataset import AlignmentDataset from align_data.db.models import Article @@ -33,7 +34,7 @@ class SpreadsheetDataset(AlignmentDataset): spreadsheet_id: str sheet_id: str done_key = "url" - source_filetype = None + source_filetype = None # type: str batch_size = 1 @staticmethod @@ -51,7 +52,11 @@ def items_list(self) -> Iterable[tuple]: url = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}" logger.info(f"Fetching {url}") df = pd.read_csv(url) - return (item for item in df.itertuples() if self.get_item_key(item)) + return ( + item + for item in df.itertuples() + if self.get_item_key(item) is not None + ) @staticmethod def _get_text(item): @@ -90,7 +95,7 @@ def process_entry(self, item: tuple): class SpecialDocs(SpreadsheetDataset): @property - def _query_items(self): + def _query_items(self) -> Select[Tuple[Article]]: special_docs_types = ["pdf", "html", "xml", "markdown", "docx"] return select(Article).where(Article.source.in_(special_docs_types)) @@ -146,7 +151,6 @@ def process_entry(self, item): return self.make_data_entry(contents) - class PDFArticles(SpreadsheetDataset): source_filetype = "pdf" COOLDOWN = 1 diff --git a/align_data/sources/articles/google_cloud.py b/align_data/sources/articles/google_cloud.py index ca310235..9ff56a2c 100644 --- a/align_data/sources/articles/google_cloud.py +++ b/align_data/sources/articles/google_cloud.py @@ -2,89 +2,99 @@ import time from collections import UserDict from pathlib import Path -from typing import Dict, Optional -import regex as re +from typing import Dict, Any, Iterator, Union, List, Set +import re +import requests import gdown import grobid_tei_xml import gspread +from gspread.worksheet import Worksheet +from gspread.spreadsheet import Spreadsheet from bs4 import BeautifulSoup from google.oauth2.service_account import Credentials from googleapiclient.discovery import build from googleapiclient.http import MediaIoBaseUpload from markdownify import MarkdownConverter + from align_data.sources.articles.html import fetch, fetch_element from align_data.sources.articles.pdf import fetch_pdf logger = logging.getLogger(__name__) - SCOPES = [ "https://www.googleapis.com/auth/spreadsheets", "https://www.googleapis.com/auth/drive", ] - OK = "ok" -OUTPUT_SPREADSHEET_ID = "1bg-6vL-I82CBRkxvWQs1-Ao0nTvHyfn4yns5MdlbCmY" -sheet_name = "Sheet1" +OUTPUT_SPREADSHEET_ID = "1bg-6vL-I82CBRkxvWQs1-Ao0nTvHyfn4yns5MdlbCmY" # TODO: remove this +sheet_name = "Sheet1" # TODO: remove this -def get_credentials(credentials_file="credentials.json"): +def get_credentials(credentials_file: Union[Path, str] = "credentials.json") -> Credentials: return Credentials.from_service_account_file(credentials_file, scopes=SCOPES) -def get_spreadsheet(spreadsheet_id, credentials=None): +def get_spreadsheet(spreadsheet_id: str, credentials: Credentials = None) -> Spreadsheet: client = gspread.authorize(credentials or get_credentials()) return client.open_by_key(spreadsheet_id) -def get_sheet(spreadsheet_id, sheet_name, credentials=None): +def get_sheet(spreadsheet_id: str, sheet_name: str, credentials: Credentials = None) -> Worksheet: spreadsheet = get_spreadsheet(spreadsheet_id, credentials) return spreadsheet.worksheet(title=sheet_name) -class Row(UserDict): - sheet = None +class SheetRow(UserDict): + """A row in a Google Sheet.""" + sheet = None # type: Worksheet + columns = None # type: List[str | None] @classmethod - def set_sheet(cls, sheet): + def set_sheet(cls, sheet: Worksheet): cls.sheet = sheet - cls.columns = sheet.row_values(1) + cols = sheet.row_values(1) + # if there is no first column, we raise an error + if not isinstance(cols, list) or not cols: + raise ValueError(f"Sheet {sheet.title} has no header row") + + cls.columns = cols - def __init__(self, row_id, data): + def __init__(self, row_id: int, data: Dict[str, Any]): self.row_id = row_id super().__init__(data) - def update_value(self, col, value): + def update_value(self, col: str, value: str): self.sheet.update_cell(self.row_id, self.columns.index(col) + 1, value) - def update_colour(self, col, colour): + def update_colour(self, col: str, colour: Dict[str, float]): col_letter = chr(ord("A") + self.columns.index(col)) self.sheet.format(f"{col_letter}{self.row_id}", {"backgroundColor": colour}) - def set_status(self, status, status_col="status"): + def set_status(self, status: str, status_col: str = "status"): if self.get(status_col) == status: # Don't update anything if the status is the same - this saves on gdocs calls return if status == OK: - colour = {"red": 0, "green": 1, "blue": 0} - elif status == "": - colour = {"red": 1, "green": 1, "blue": 1} + colour = {"red": 0.0, "green": 1.0, "blue": 0.0} + elif status == "": # TODO: this should never be reached + colour = {"red": 1.0, "green": 1.0, "blue": 1.0} else: - colour = {"red": 1, "green": 0, "blue": 0} + colour = {"red": 1.0, "green": 0.0, "blue": 0.0} self.update_value(status_col, status) self.update_colour(status_col, colour) -def iterate_rows(sheet): +def iterate_rows(sheet: Worksheet) -> Iterator[SheetRow]: """Iterate over all the rows of the given `sheet`.""" - Row.set_sheet(sheet) + SheetRow.set_sheet(sheet) - for i, row in enumerate(sheet.get_all_records(), 2): - yield Row(i, row) + # we start the enumeration at 2 to avoid the header row + for row_id, row_data in enumerate(sheet.get_all_records(), 2): + yield SheetRow(row_id, row_data) def upload_file(filename, bytes_contents, mimetype, parent_id=None): @@ -131,14 +141,14 @@ def retrier(*args, **kwargs): return wrapper -def fetch_file(file_id): +def fetch_file(file_id: str): data_path = Path("data/raw/") data_path.mkdir(parents=True, exist_ok=True) file_name = data_path / file_id return gdown.download(id=file_id, output=str(file_name), quiet=False) -def fetch_markdown(file_id): +def fetch_markdown(file_id: str) -> Dict[str, str]: try: file_name = fetch_file(file_id) return { @@ -149,9 +159,11 @@ def fetch_markdown(file_id): return {"error": str(e)} -def parse_grobid(contents): +def parse_grobid(contents: str | bytes) -> Dict[str, Any]: + if isinstance(contents, bytes): + contents = contents.decode('utf-8') doc_dict = grobid_tei_xml.parse_document_xml(contents).to_dict() - authors = [xx["full_name"].strip(" !") for xx in doc_dict.get("header", {}).get("authors", [])] + authors: List[str] = [author["full_name"].strip(" !") for author in doc_dict.get("header", {}).get("authors", [])] if not doc_dict.get("body"): return { @@ -168,13 +180,13 @@ def parse_grobid(contents): } -def get_content_type(res): +def get_content_type(res: requests.Response) -> Set[str]: header = res.headers.get("Content-Type") or "" parts = [c_type.strip().lower() for c_type in header.split(";")] return set(filter(None, parts)) -def extract_gdrive_contents(link): +def extract_gdrive_contents(link: str) -> Dict[str, Any]: file_id = link.split("/")[-2] url = f"https://drive.google.com/uc?id={file_id}" res = fetch(url, "head") @@ -185,9 +197,9 @@ def extract_gdrive_contents(link): logger.error("Could not fetch the file at %s - are you sure that link is correct?", link) return {"error": "Could not read file from google drive"} - result = { - "source_url": link, - "downloaded_from": "google drive", + result: Dict[str, Any] = { + 'source_url': link, + 'downloaded_from': 'google drive', } content_type = get_content_type(res) @@ -203,7 +215,16 @@ def extract_gdrive_contents(link): res = fetch(url) if "Google Drive - Virus scan warning" in res.text: soup = BeautifulSoup(res.content, "html.parser") - res = fetch(soup.select_one("form").get("action")) + + form_tag = soup.select_one('form') + if form_tag is None: + return {**result, 'error': 'Virus scan warning - no form tag'} + + form_action_url = form_tag.get('action') + if not isinstance(form_action_url, str): + return {**result, 'error': 'Virus scan warning - no form action url'} + + res = fetch(form_action_url) content_type = get_content_type(res) if content_type & {"text/xml"}: @@ -224,7 +245,7 @@ def extract_gdrive_contents(link): return result -def google_doc(url: str) -> Dict: +def google_doc(url: str) -> Dict[str, Any]: """Fetch the contents of the given gdoc url as markdown.""" res = re.search(r"https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/", url) if not res: diff --git a/align_data/sources/articles/html.py b/align_data/sources/articles/html.py index d3c2490c..fb280c37 100644 --- a/align_data/sources/articles/html.py +++ b/align_data/sources/articles/html.py @@ -1,6 +1,6 @@ import time import logging -from typing import Union +from typing import Optional, Dict, Literal, Optional, Any, List import requests from bs4 import BeautifulSoup, Tag @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) - DEFAULT_HEADERS = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0", } @@ -33,7 +32,11 @@ def retrier(*args, **kwargs): return wrapper -def fetch(url, method="get", headers=DEFAULT_HEADERS): +def fetch( + url: str, + method: Literal["get", "post", "put", "delete", "patch", "options", "head"] = "get", + headers: Dict[str, str] = DEFAULT_HEADERS +) -> requests.Response: """Fetch the given `url`. This function is to have a single place to manage headers etc. @@ -41,7 +44,7 @@ def fetch(url, method="get", headers=DEFAULT_HEADERS): return getattr(requests, method)(url, allow_redirects=True, headers=headers) -def fetch_element(url: str, selector: str, headers=DEFAULT_HEADERS) -> Union[Tag, None]: +def fetch_element(url: str, selector: str, headers: Dict[str, str] = DEFAULT_HEADERS) -> Tag | None: """Fetch the first HTML element that matches the given CSS `selector` on the page found at `url`.""" try: resp = fetch(url, headers=headers) @@ -53,15 +56,16 @@ def fetch_element(url: str, selector: str, headers=DEFAULT_HEADERS) -> Union[Tag return soup.select_one(selector) -def element_extractor(selector, remove=[]): +def element_extractor(selector: str, remove: Optional[List[str]] = None): """Returns a function that will extract the first element that matches the given CSS selector. :params str selector: a CSS selector to run on the HTML of the page provided as the parameter of the function :param List[str] remove: An optional list of selectors to be removed from the resulting HTML. Useful for removing footers etc. :returns: A function that expects to get an URL, and which will then return the contents of the selected HTML element as markdown. """ + remove = remove or [] - def getter(url): + def getter(url: str) -> Dict[str, Any]: elem = fetch_element(url, selector) if not elem: return {} diff --git a/align_data/sources/articles/indices.py b/align_data/sources/articles/indices.py index 6deb02d8..1a61e0eb 100644 --- a/align_data/sources/articles/indices.py +++ b/align_data/sources/articles/indices.py @@ -1,6 +1,7 @@ import logging import re from collections import defaultdict +from typing import Callable from dateutil.parser import ParserError, parse from markdownify import MarkdownConverter @@ -19,16 +20,22 @@ def get_text(tag, selector: str) -> str: return "" -def indice_fetcher(url, main_selector, item_selector, formatter): +def indice_fetcher(url: str, main_selector: str, item_selector: str, formatter: Callable): def fetcher(): if contents := fetch_element(url, main_selector): return list(filter(None, map(formatter, contents.select(item_selector)))) return [] - + fetcher.__name__ = formatter.__name__.replace("format_", "") + '_fetcher' + # formatter called "format_anthropic" -> fetcher called "anthropic_fetcher" + #TODO: Make this more explicit return fetcher def reading_what_we_can_items(): + # We fetch the books.js page of readingwhatwecan. + # It has 4 sections: first_entry, ml, ais, and scifi, + # which contain a dozen items (books, stories, papers) each. + res = fetch("https://readingwhatwecan.com/books.js") items = { item @@ -240,8 +247,11 @@ def fetch_all(): articles = defaultdict(dict) for func in tqdm(fetchers): + logger.info(f"Processing function: {func.__name__}") for item in func(): - articles[item["title"]].update(item) + logger.info(f"Processing item: {item}") + articles[item['title']].update(item) + logger.info(f"Found {len(articles)} articles") return articles diff --git a/align_data/sources/articles/parsers.py b/align_data/sources/articles/parsers.py index 4bd493be..d5fe2578 100644 --- a/align_data/sources/articles/parsers.py +++ b/align_data/sources/articles/parsers.py @@ -1,6 +1,6 @@ import logging from urllib.parse import urlparse, urljoin -from typing import Dict +from typing import Dict, Optional, Callable, Any from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema @@ -10,9 +10,10 @@ from align_data.sources.arxiv_papers import fetch_arxiv from align_data.common.html_dataset import HTMLDataset - logger = logging.getLogger(__name__) +ParserFunc = Callable[[str], Dict[str, Any]] + def get_pdf_from_page(*link_selectors: str): """Get a function that receives an `url` to a page containing a pdf link and returns the pdf's contents as text. @@ -21,32 +22,40 @@ def get_pdf_from_page(*link_selectors: str): * if there are more selectors left, fetch the contents at the extracted link and continue * otherwise return the pdf contents at the last URL - :param List[str] link_selectors: CSS selector used to find the final download link + :param str *link_selectors: CSS selectors used to find the final download link :returns: the contents of the pdf file as a string """ + def getter(url: str) -> Dict[str, Any]: + current_url: str = url - def getter(url: str): - link: str = url for selector in link_selectors: - elem = fetch_element(link, selector) + elem = fetch_element(current_url, selector) if not elem: - return {"error": f"Could not find pdf download link for {link} using '{selector}'"} + return {"error": f"Could not find pdf download link for {current_url} using '{selector}'"} + + # Extracting href, considering it can be a string or a list of strings + href = elem.get("href") + if isinstance(href, list): + href = href[0] if href else None + + if not href: + return {"error": f"Could not extract href for {current_url} using '{selector}'"} - link = elem.get("href") - if not link.startswith("http") or not link.startswith("//"): - link = urljoin(url, link) + # Making sure the link is absolute + if not href.startswith(("http", "//")): + href = urljoin(url, href) + current_url = href # Some pages keep link to google drive previews of pdf files, which need to be # mangled to get the URL of the actual pdf file - if "drive.google.com" in link and "/view" in link: - return extract_gdrive_contents(link) + if "drive.google.com" in current_url and "/view" in current_url: + return extract_gdrive_contents(current_url) - if parse_domain(link) == "arxiv.org": - return fetch_arxiv(link) - if pdf := fetch_pdf(link): + if parse_domain(current_url) == "arxiv.org": + return fetch_arxiv(current_url) + if pdf := fetch_pdf(current_url): return pdf - return {"error": f"Could not fetch pdf from {link}"} - + return {"error": f"Could not fetch pdf from {current_url}"} return getter @@ -64,6 +73,7 @@ class MediumParser(HTMLDataset): It is possible that there is additional variation in the layout that hasn't been represented in the blogs tested so far. In that case, additional fixes to this code may be needed. + #TODO: investigate this """ source_type = "MediumParser(name='html', url='')" @@ -73,14 +83,13 @@ def _get_published_date(self, contents): possible_date_elements = contents.select("article div:first-child span") return self._find_date(possible_date_elements) - def __call__(self, url): + def __call__(self, url: str) -> Dict[str, Any]: return self.get_contents(url) -def error(error_msg): +def error(error_msg: str): """Returns a url handler function that just logs the provided `error` string.""" - - def func(url): + def func(url: str) -> Dict[str, Any]: if error_msg: logger.error(error_msg) return {"error": error_msg, "source_url": url} @@ -88,10 +97,13 @@ def func(url): return func -def multistrategy(*funcs): - """Merges multiple getter functions, returning the result of the first function call to succeed.""" +def multistrategy(*funcs: ParserFunc): + """ + Merges multiple getter functions, returning the result + of the first function call to succeed. + """ - def getter(url): + def getter(url: str) -> Dict[str, Any]: for func in funcs: res = func(url) if res and "error" not in res: @@ -100,7 +112,7 @@ def getter(url): return getter -UNIMPLEMENTED_PARSERS = { +UNIMPLEMENTED_PARSERS: Dict[str, ParserFunc] = { # Unhandled items that will be caught later. Though it would be good for them also to be done properly "oxford.universitypressscholarship.com": error(""), # Paywalled journal @@ -109,6 +121,8 @@ def getter(url): ), "link.springer.com": error("This article looks paywalled"), "www.dl.begellhouse.com": error("This article is paywalled"), + "dl.begellhouse.com": error("Begell house is not yet handled"), + # To be implemented "goodreads.com": error("Ebooks are not yet handled"), "judiciary.senate.gov": error(""), @@ -120,10 +134,23 @@ def getter(url): "Researchgate makes it hard to auto download pdf - please provide a DOI or a different url to the contents" ), "repository.cam.ac.uk": error(""), + + "deliverypdf.ssrn.com": error("SSRN is not yet handled"), + "doi.wiley.com": error("Wiley is not yet handled"), + "onlinelibrary.wiley.com": error("Wiley is not yet handled"), + "globalprioritiesproject.org": error("Global priorities project is not yet handled"), + "ieeexplore.ieee.org": error("IEEE is not yet handled"), + "pdcnet.org": error("pdcnet.org is not yet handled"), + "sciencemag.org": error("sciencemag.org is not yet handled"), + "iopscience.iop.org": error("iopscience.iop.org is not yet handled"), + "journals.aom.org": error("journals.aom.org is not yet handled"), + "cambridge.org": error("cambridge.org is not yet handled"), + "transformer-circuits.pub": error("not handled yet - same codebase as distill"), + } -HTML_PARSERS = { +HTML_PARSERS: Dict[str, ParserFunc] = { "academic.oup.com": element_extractor("#ContentTab"), "ai.googleblog.com": element_extractor("div.post-body.entry-content"), "arxiv-vanity.com": parse_vanity, @@ -218,7 +245,6 @@ def getter(url): ".Publication", ], ), - "transformer-circuits.pub": error("not handled yet - same codebase as distill"), "vox.com": element_extractor("did.c-entry-content", remove=["c-article-footer"]), "weforum.org": element_extractor("div.wef-0"), "www6.inrae.fr": element_extractor("div.ArticleContent"), @@ -226,7 +252,7 @@ def getter(url): "yoshuabengio.org": element_extractor("div.post-content"), } -PDF_PARSERS = { +PDF_PARSERS: Dict[str, ParserFunc] = { # Domain sepecific handlers "apcz.umk.pl": get_pdf_from_page(".galleys_links a.pdf", "a.download"), "arxiv.org": fetch_arxiv, @@ -264,15 +290,21 @@ def getter(url): def parse_domain(url: str) -> str: - return url and urlparse(url).netloc.lstrip("www.") + net_loc = urlparse(url).netloc + return net_loc[4:] if net_loc.startswith("www.") else net_loc -def item_metadata(url: str) -> Dict[str, any]: +def item_metadata(url: str) -> Dict[str, Any]: + if not url: + return {"error": "No url was given to item_metadata"} domain = parse_domain(url) try: res = fetch(url, "head") except (MissingSchema, InvalidSchema, ConnectionError) as e: return {"error": str(e)} + + if not res.headers.get('Content-Type'): + return {'error': 'No content type found'} content_type = {item.strip() for item in res.headers.get("Content-Type", "").split(";")} @@ -286,15 +318,17 @@ def item_metadata(url: str) -> Dict[str, any]: return res if parser := PDF_PARSERS.get(domain): - if res := parser(url): + if content := parser(url): # A pdf was found - use it, though it might not be useable - return res + return content if parser := UNIMPLEMENTED_PARSERS.get(domain): return parser(url) - if domain not in (HTML_PARSERS.keys() | PDF_PARSERS.keys() | UNIMPLEMENTED_PARSERS.keys()): - return {"error": "No domain handler defined"} + if domain not in ( + HTML_PARSERS.keys() | PDF_PARSERS.keys() | UNIMPLEMENTED_PARSERS.keys() + ): + return {"error": f"No domain handler defined for {domain}"} return {"error": "could not parse url"} elif content_type & {"application/octet-stream", "application/pdf"}: if domain == "arxiv.org": diff --git a/align_data/sources/articles/pdf.py b/align_data/sources/articles/pdf.py index d5af5bf4..d54e6dda 100644 --- a/align_data/sources/articles/pdf.py +++ b/align_data/sources/articles/pdf.py @@ -1,6 +1,8 @@ import io import logging +from typing import Dict, Any, List from urllib.parse import urlparse +from pathlib import Path from typing import Dict, Any from dateutil.parser import ParserError, parse @@ -8,14 +10,15 @@ from PyPDF2 import PdfReader from PyPDF2.errors import PdfReadError from markdownify import MarkdownConverter +from bs4.element import Tag from align_data.sources.articles.html import fetch, fetch_element, with_retry logger = logging.getLogger(__name__) -def sci_hub_pdf(identifier): - """Search Sci-hub for a link to a pdf of the article with the given identifier. +def sci_hub_pdf(identifier: str) -> str | None: + """Search Sci-hub for a link to a pdf of the article with the given identifier (doi). This will only get pdf that are directly served by Sci-hub. Sometimes it will redirect to a large file containing multiple articles, e.g. a whole journal or book, in which case this function @@ -24,7 +27,16 @@ def sci_hub_pdf(identifier): elem = fetch_element(f"https://sci-hub.st/{identifier}", "embed") if not elem: return None - src = elem.get("src").strip() + + src = elem.get("src") + + if isinstance(src, list): + src = src[0] if src else None + + if src is None: + return None + + src = src.strip() if src.startswith("//"): src = "https:" + src elif src.startswith("/"): @@ -32,7 +44,7 @@ def sci_hub_pdf(identifier): return src -def read_pdf(filename): +def read_pdf(filename: Path) -> str | None: try: pdf_reader = PdfReader(filename) return "\n".join(page.extract_text() for page in pdf_reader.pages) @@ -42,7 +54,7 @@ def read_pdf(filename): @with_retry(times=3) -def fetch_pdf(link): +def fetch_pdf(link: str) -> Dict[str, str]: """Return the contents of the pdf file at `link` as a markdown string. :param str link: the URL to check for a pdf file @@ -53,6 +65,7 @@ def fetch_pdf(link): "Could not fetch the pdf file at %s - are you sure that link is correct?", link, ) + return {"error": "Could not read pdf file"} content_type = {c_type.strip().lower() for c_type in res.headers.get("Content-Type").split(";")} if not content_type & {"application/octet-stream", "application/pdf"}: @@ -68,8 +81,8 @@ def fetch_pdf(link): "source_type": "pdf", } except (TypeError, PdfReadError) as e: - logger.error("Could not read PDF file: %s", e) - return {"error": str(e)} + logger.error('Could not read PDF file: %s', e) + error = str(e) filenames = [ i.strip().split("=")[1] @@ -87,20 +100,22 @@ def fetch_pdf(link): return {"error": error} -def get_arxiv_link(doi): +def get_arxiv_link(doi: str) -> str | None: """Find the URL to the pdf of the given arXiv DOI.""" res = requests.get(f"https://doi.org/api/handles/{doi}") if res.status_code != 200: return None - vals = [val for val in response.json().get("values") if val.get("type", "").upper() == "URL"] + vals = [ + val + for val in res.json().get("values") + if val.get("type", "").upper() == "URL" + ] - if not vals: - return None - return vals[0]["data"]["value"].replace("/abs/", "/pdf/") + ".pdf" + return vals and vals[0]["data"]["value"].replace("/abs/", "/pdf/") + ".pdf" -def get_doi(doi): +def get_doi(doi: str) -> Dict[str, Any]: """Get the article with the given `doi`. This will look for it in sci-hub and arxiv (if applicable), as those are likely the most @@ -110,49 +125,44 @@ def get_doi(doi): link = get_arxiv_link(doi) pdf = link and fetch_pdf(link) if pdf and "text" in pdf: - pdf["downloaded_from"] = "arxiv" - return pdf + return {**pdf, "downloaded_from": "arxiv"} if link := sci_hub_pdf(doi): if pdf := fetch_pdf(link): - pdf["downloaded_from"] = "scihub" - return pdf + return {**pdf, "downloaded_from": "scihub"} return {"error": "Could not find pdf of article by DOI"} -def doi_getter(url): +def doi_getter(url: str) -> Dict[str, Any]: """Extract the DOI from the given `url` and fetch the contents of its article.""" return get_doi(urlparse(url).path.lstrip("/")) -def parse_vanity(url) -> Dict[str, Any]: +def parse_vanity(url: str) -> Dict[str, Any]: contents = fetch_element(url, "article") if not contents: return {"error": "Could not fetch from arxiv vanity"} - if title := contents.select_one("h1.ltx_title"): - title = title.text + selected_title = contents.select_one("h1.ltx_title") + title = selected_title.text if selected_title else None - def get_first_child(item): - child = next(item.children) + def get_first_child(item: Tag) -> List[str]: + child = next(iter(item.children), None) if not child: return [] - - if not isinstance(child, str): - child = child.text - return child.split(",") + return child.text.split(",") authors = [ - a.strip() + author.strip() for item in contents.select("div.ltx_authors .ltx_personname") - for a in get_first_child(item) + for author in get_first_child(item) ] - if date_published := contents.select_one("div.ltx_dates"): - try: - date_published = parse(date_published.text.strip("()")) - except ParserError: - "If the date couldn't be parsed, hope that later phases will be more successful" + selected_date = contents.select_one("div.ltx_dates") + try: + date_published = parse(selected_date.text.strip("()")) if selected_date else None + except ParserError: + date_published = None text = "\n\n".join( MarkdownConverter().convert_soup(elem).strip() diff --git a/align_data/sources/articles/updater.py b/align_data/sources/articles/updater.py index c2a1d29e..d9dccaef 100644 --- a/align_data/sources/articles/updater.py +++ b/align_data/sources/articles/updater.py @@ -1,21 +1,27 @@ import logging from collections import namedtuple from dataclasses import dataclass +from typing import List, Optional, Union, Tuple, NamedTuple +from pathlib import Path import pandas as pd -from sqlalchemy import select, or_ +from sqlalchemy import select, or_, Select from align_data.common.alignment_dataset import AlignmentDataset from align_data.db.models import Article from align_data.sources.articles.parsers import item_metadata +from sqlalchemy.orm import Session logger = logging.getLogger(__name__) -Item = namedtuple("Item", ["updates", "article"]) + +class Item(NamedTuple): + updates: NamedTuple + article: Article @dataclass class ReplacerDataset(AlignmentDataset): - csv_path: str + csv_path: str | Path delimiter: str done_key = "url" @@ -30,25 +36,31 @@ def maybe(item, key): return val @property - def items_list(self): + def items_list(self) -> List[Item]: df = pd.read_csv(self.csv_path, delimiter=self.delimiter) self.csv_items = [ item for item in df.itertuples() if self.maybe(item, "id") or self.maybe(item, "hash_id") ] - by_id = {i.id: i for i in self.csv_items if self.maybe(i, "id")} - by_hash_id = {i.hash_id: i for i in self.csv_items if self.maybe(i, "hash_id")} - - return [Item(by_id.get(a._id) or by_hash_id.get(a.id), a) for a in self.read_entries()] + by_id = {id: item for item in self.csv_items if (id := self.maybe(item, 'id'))} + by_hash_id = {hash_id: item for item in self.csv_items if (hash_id := self.maybe(item, 'hash_id'))} + + return [ + Item( + updates=by_id.get(article._id) or by_hash_id.get(article.id), + article=article + ) + for article in self.read_entries() + ] @property - def _query_items(self): + def _query_items(self) -> Select[Tuple[Article]]: ids = [i.id for i in self.csv_items if self.maybe(i, "id")] hash_ids = [i.hash_id for i in self.csv_items if self.maybe(i, "hash_id")] return select(Article).where(or_(Article.id.in_(hash_ids), Article._id.in_(ids))) - def update_text(self, updates, article): + def update_text(self, updates: NamedTuple, article: Article): # If the url is the same as it was before, and there isn't a source url provided, assume that the # previous text is still valid if article.url == self.maybe(updates, "url") and not self.maybe(updates, "source_url"): @@ -65,10 +77,10 @@ def update_text(self, updates, article): metadata = item_metadata(url) # Only change the text if it could be fetched - better to have outdated values than none if metadata.get("text"): - article.text = metadata.get("text") + article.text = metadata["text"] article.status = metadata.get("error") - def process_entry(self, item): + def process_entry(self, item: Item) -> Article: updates, article = item for key in ["url", "title", "source", "authors", "comment", "confidence"]: @@ -84,5 +96,5 @@ def process_entry(self, item): return article - def _add_batch(self, session, batch): + def _add_batch(self, session: Session, batch: tuple): session.add_all(map(session.merge, batch)) diff --git a/align_data/sources/arxiv_papers.py b/align_data/sources/arxiv_papers.py index 30002a78..45b3148b 100644 --- a/align_data/sources/arxiv_papers.py +++ b/align_data/sources/arxiv_papers.py @@ -3,6 +3,7 @@ from typing import Dict, Optional, Any import arxiv + from align_data.sources.articles.pdf import fetch_pdf, parse_vanity from align_data.sources.articles.html import fetch_element from align_data.sources.utils import merge_dicts @@ -10,7 +11,7 @@ logger = logging.getLogger(__name__) -def get_arxiv_metadata(paper_id) -> arxiv.Result: +def get_arxiv_metadata(paper_id: str) -> arxiv.Result | None: """ Get metadata from arxiv """ @@ -25,6 +26,7 @@ def get_arxiv_metadata(paper_id) -> arxiv.Result: def get_id(url: str) -> str | None: if res := re.search(r"https?://arxiv.org/(?:abs|pdf)/(.*?)(?:v\d+)?(?:/|\.pdf)?$", url): return res.group(1) + return None def canonical_url(url: str) -> str: @@ -50,13 +52,13 @@ def get_version(id: str) -> str | None: return res.group(1) -def is_withdrawn(url: str): - if elem := fetch_element(canonical_url(url), ".extra-services .full-text ul"): - return elem.text.strip().lower() == "withdrawn" - return None +def is_withdrawn(url: str) -> bool: + if elem := fetch_element(canonical_url(url), '.extra-services .full-text ul'): + return elem.text.strip().lower() == 'withdrawn' + return False -def add_metadata(data, paper_id): +def add_metadata(data: Dict[str, Any], paper_id: str) -> Dict[str, Any]: metadata = get_arxiv_metadata(paper_id) if not metadata: return {} @@ -78,7 +80,7 @@ def add_metadata(data, paper_id): ) -def fetch_arxiv(url) -> Dict: +def fetch_arxiv(url: str) -> Dict[str, Any]: paper_id = get_id(url) if not paper_id: return {"error": "Could not extract arxiv id"} diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py index 1245aec6..e3a25f0d 100644 --- a/align_data/sources/blogs/blogs.py +++ b/align_data/sources/blogs/blogs.py @@ -2,12 +2,13 @@ from urllib.parse import urljoin import requests -from align_data.sources.articles.parsers import item_metadata -from align_data.common.html_dataset import HTMLDataset, RSSDataset from bs4 import BeautifulSoup from dateutil.parser import ParserError from tqdm import tqdm +from align_data.sources.articles.parsers import item_metadata +from align_data.common.html_dataset import HTMLDataset, RSSDataset + logger = logging.getLogger(__name__) @@ -77,7 +78,11 @@ def extract_authors(self, article): authors = [] if authors_div: authors = [ - i.split("(")[0].strip() for i in authors_div.select_one("p").children if not i.name + i.split("(")[0].strip() + for i in authors_div.select_one("p").children + if not i.name and i.strip() + # i.name is non-empty if it's a tag, ie
has name br + # but "OpenAI Research" has no name ] return authors or ["OpenAI Research"] diff --git a/align_data/sources/blogs/gwern_blog.py b/align_data/sources/blogs/gwern_blog.py index 1d573a8e..0375bb8c 100644 --- a/align_data/sources/blogs/gwern_blog.py +++ b/align_data/sources/blogs/gwern_blog.py @@ -1,6 +1,7 @@ -import requests -import logging from dataclasses import dataclass +import logging + +import requests from align_data.common.html_dataset import HTMLDataset @@ -14,7 +15,6 @@ class GwernBlog(HTMLDataset): """ COOLDOWN: int = 1 - done_key = "url" def get_item_key(self, item: str) -> str: return item diff --git a/align_data/sources/blogs/wp_blog.py b/align_data/sources/blogs/wp_blog.py index cd409d98..b0c9e9f1 100644 --- a/align_data/sources/blogs/wp_blog.py +++ b/align_data/sources/blogs/wp_blog.py @@ -1,17 +1,16 @@ from dataclasses import dataclass import logging + import feedparser from tqdm import tqdm from align_data.common.html_dataset import RSSDataset - logger = logging.getLogger(__name__) @dataclass class WordpressBlog(RSSDataset): - summary_key = "summary" @property def feed_url(self): @@ -28,7 +27,7 @@ def items_list(self): with tqdm(desc=f"Loading {self.name} pages") as pbar: while True: paged_url = f"{self.feed_url}?paged={page_number}" - logging.info(f"Fetching {paged_url}") + logger.info(f"Fetching {paged_url}") feed = feedparser.parse(paged_url) title = feed.get("feed", {}).get("title") diff --git a/align_data/sources/ebooks/agentmodels.py b/align_data/sources/ebooks/agentmodels.py index 65b52502..3756524a 100644 --- a/align_data/sources/ebooks/agentmodels.py +++ b/align_data/sources/ebooks/agentmodels.py @@ -1,9 +1,11 @@ -from align_data.common.alignment_dataset import AlignmentDataset from dataclasses import dataclass -from git import Repo import logging from datetime import timezone +from git import Repo + +from align_data.common.alignment_dataset import AlignmentDataset + logger = logging.getLogger(__name__) diff --git a/align_data/sources/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py index 8925fc22..579a4680 100644 --- a/align_data/sources/greaterwrong/greaterwrong.py +++ b/align_data/sources/greaterwrong/greaterwrong.py @@ -5,7 +5,6 @@ from typing import Set, Tuple import requests -import jsonlines from bs4 import BeautifulSoup from markdownify import markdownify from sqlalchemy import select @@ -69,7 +68,6 @@ class GreaterWrong(AlignmentDataset): limit = 50 COOLDOWN_TIME: float = 0.5 - summary_key: str = "summary" done_key = "url" lazy_eval = True source_type = 'GreaterWrong' @@ -112,49 +110,48 @@ def _get_published_date(self, item): return super()._get_published_date(item.get("postedAt")) def make_query(self, after: str): - return ( - """{ - posts(input: { - terms: { - excludeEvents: true - view: "old" - """ - f" af: {self.af}\n" - f" limit: {self.limit}\n" - f" karmaThreshold: {self.min_karma}\n" - f' after: "{after}"\n' - """ filter: "tagged" - } - }) { - totalCount - results { - _id - title - slug - pageUrl - postedAt - modifiedAt - score - extendedScore - baseScore - voteCount - commentCount - wordCount - tags { - name - } - user { - displayName - } - coauthors { - displayName - } - af - htmlBody - } - } - }""" - ) + return f''' + {{ + posts(input: {{ + terms: {{ + excludeEvents: true + view: "old" + af: {self.af} + limit: {self.limit} + karmaThreshold: {self.min_karma} + after: "{after}" + filter: "tagged" + }} + }}) {{ + totalCount + results {{ + _id + title + slug + pageUrl + postedAt + modifiedAt + score + extendedScore + baseScore + voteCount + commentCount + wordCount + tags {{ + name + }} + user {{ + displayName + }} + coauthors {{ + displayName + }} + af + htmlBody + }} + }} + }} + ''' def fetch_posts(self, query: str): res = requests.post( @@ -168,14 +165,18 @@ def fetch_posts(self, query: str): return res.json()["data"]["posts"] @property - def last_date_published(self): - try: - prev_item = next(self.read_entries(sort_by=Article.date_published.desc())) - if prev_item and prev_item.date_published: - return prev_item.date_published.isoformat() + "Z" - except StopIteration: - pass - return datetime(self.start_year, 1, 1).isoformat() + "Z" + def last_date_published(self) -> str: + entries = self.read_entries(sort_by=Article.date_published.desc()) + + # Get the first entry if exists, else return a default datetime + prev_item = next(entries, None) + + # If there is no previous item or it doesn't have a published date, return default datetime + if not prev_item or not prev_item.date_published: + return datetime(self.start_year, 1, 1).isoformat() + 'Z' + + # If the previous item has a published date, return it in isoformat + return prev_item.date_published.isoformat() + 'Z' @property def items_list(self): diff --git a/align_data/sources/stampy/stampy.py b/align_data/sources/stampy/stampy.py index 95319820..cc40e9ac 100644 --- a/align_data/sources/stampy/stampy.py +++ b/align_data/sources/stampy/stampy.py @@ -2,15 +2,15 @@ import re import logging from dataclasses import dataclass + from codaio import Coda, Document +import html from align_data.common.alignment_dataset import AlignmentDataset from align_data.settings import CODA_TOKEN, CODA_DOC_ID, ON_SITE_TABLE logger = logging.getLogger(__name__) -import html - @dataclass class Stampy(AlignmentDataset): diff --git a/align_data/sources/youtube/youtube.py b/align_data/sources/youtube/youtube.py index 740597d0..608cd96d 100644 --- a/align_data/sources/youtube/youtube.py +++ b/align_data/sources/youtube/youtube.py @@ -1,8 +1,8 @@ import logging -from dataclasses import dataclass -from typing import List +from dataclasses import dataclass, field +from typing import List, Optional, Iterable -from googleapiclient.discovery import build +from googleapiclient.discovery import build, Resource from youtube_transcript_api import YouTubeTranscriptApi from youtube_transcript_api._errors import ( NoTranscriptFound, @@ -13,7 +13,6 @@ from align_data.settings import YOUTUBE_API_KEY from align_data.common.alignment_dataset import AlignmentDataset - logger = logging.getLogger(__name__) @@ -21,16 +20,17 @@ class YouTubeDataset(AlignmentDataset): done_key = "url" batch_size = 1 # COOLDOWN = 2 - authors = None - collection_ids = [] + authors: Optional[List[str]] = None + collection_ids: List[str] = field(default_factory=list) + def setup(self): super().setup() if not YOUTUBE_API_KEY: raise ValueError("No YOUTUBE_API_KEY provided!") - self.youtube = build("youtube", "v3", developerKey=YOUTUBE_API_KEY) + self.youtube: Resource = build("youtube", "v3", developerKey=YOUTUBE_API_KEY) - def next_page(self, collection_id, next_page_token): + def next_page(self, collection_id: str, next_page_token: list) -> dict: return {"items": []} @staticmethod @@ -45,7 +45,7 @@ def _get_id(item) -> str | None: if resource["kind"] == "youtube#video": return resource["videoId"] - def fetch_videos(self, collection_id): + def fetch_videos(self, collection_id: str) -> Iterable[dict]: next_page_token = None while True: videos_response = self.next_page(collection_id, next_page_token) @@ -74,7 +74,8 @@ def _get_contents(self, video): video_id = self._get_id(video) try: transcript = ( - YouTubeTranscriptApi.list_transcripts(video_id) + YouTubeTranscriptApi + .list_transcripts(video_id) .find_transcript(["en", "en-GB"]) .fetch() ) @@ -139,13 +140,14 @@ def _get_published_date(self, video): @dataclass class YouTubePlaylistDataset(YouTubeDataset): - playlist_ids: str + + playlist_ids: List[str] @property def collection_ids(self): return self.playlist_ids - def next_page(self, collection_id, next_page_token): + def next_page(self, collection_id: str, next_page_token: list): return ( self.youtube.playlistItems() .list( diff --git a/main.py b/main.py index 82c30f07..8ad9f88e 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ -import logging import os from dataclasses import dataclass from typing import List +import logging import fire @@ -20,7 +20,6 @@ METADATA_SOURCE_SPREADSHEET, ) - logger = logging.getLogger(__name__) diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py index 7f911208..acdcdfcd 100644 --- a/tests/align_data/articles/test_datasets.py +++ b/tests/align_data/articles/test_datasets.py @@ -360,7 +360,7 @@ def test_arxiv_process_entry_retracted(mock_arxiv):
- + """ with patch("requests.get", return_value=Mock(content=response)): diff --git a/upload_to_huggingface.py b/upload_to_huggingface.py index 9e7481f8..43956b29 100644 --- a/upload_to_huggingface.py +++ b/upload_to_huggingface.py @@ -145,11 +145,7 @@ def update_readme(api, files, repo_name): for name in files: upload_data_file(api, name + ".jsonl", "alignment-research-dataset") - update_readme( - api, - [name for _, name in files if name in DATASOURCES], - "alignment-research-dataset", - ) - update_readme(api, [name for _, name in files], "ard-private") + update_readme(api, files, "alignment-research-dataset") + update_readme(api, files, "ard-private") print("done")