diff --git a/.github/workflows/update-metadata.yml b/.github/workflows/update-metadata.yml new file mode 100644 index 00000000..70e2ddd5 --- /dev/null +++ b/.github/workflows/update-metadata.yml @@ -0,0 +1,34 @@ +name: Update metadata +on: + workflow_dispatch: + inputs: + csv_url: + description: 'URL of CSV' + required: true + delimiter: + description: 'The column delimiter' + default: ',' + +jobs: + update: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' + + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Download CSV + id: download + run: curl -L "${{ inputs.csv_url }}" -o data.csv + + - name: Run Script + run: python main.py update data.csv ${{ inputs.delimiter }} diff --git a/.github/workflows/upload-to-huggingface.yml b/.github/workflows/upload-to-huggingface.yml index eaac2ceb..7788ad32 100644 --- a/.github/workflows/upload-to-huggingface.yml +++ b/.github/workflows/upload-to-huggingface.yml @@ -28,22 +28,19 @@ jobs: - distill - eaforum - eleuther.ai - - gdocs - generative.ink - gwern_blog - html_articles - importai - jsteinhardt_blog - lesswrong - - markdown - miri - ml_safety_newsletter - openai.research - - pdfs - rob_miles_ai_safety + - special_docs - vkrakovna_blog - yudkowsky_blog - - xmls uses: ./.github/workflows/push-dataset.yml with: diff --git a/align_data/__init__.py b/align_data/__init__.py index 9f6c9893..54041500 100644 --- a/align_data/__init__.py +++ b/align_data/__init__.py @@ -2,7 +2,6 @@ import align_data.sources.articles as articles import align_data.sources.blogs as blogs import align_data.sources.ebooks as ebooks -import align_data.sources.arxiv_papers as arxiv_papers import align_data.sources.greaterwrong as greaterwrong import align_data.sources.stampy as stampy import align_data.sources.alignment_newsletter as alignment_newsletter @@ -14,7 +13,6 @@ + articles.ARTICLES_REGISTRY + blogs.BLOG_REGISTRY + ebooks.EBOOK_REGISTRY - + arxiv_papers.ARXIV_REGISTRY + greaterwrong.GREATERWRONG_REGISTRY + stampy.STAMPY_REGISTRY + distill.DISTILL_REGISTRY diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 22d7df02..6e9cd731 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -24,6 +24,9 @@ "title": None, "url": None, "authors": lambda: [], + "source_type": None, + "status": None, + "comments": None, } logger = logging.getLogger(__name__) @@ -85,7 +88,7 @@ def make_data_entry(self, data, **kwargs) -> Article: article = Article( id_fields=self.id_fields, - meta={k: v for k, v in data.items() if k not in INIT_DICT}, + meta={k: v for k, v in data.items() if k not in INIT_DICT and v is not None}, **{k: v for k, v in data.items() if k in INIT_DICT}, ) self._add_authors(article, authors) @@ -106,13 +109,17 @@ def to_jsonl(self, out_path=None, filename=None) -> Path: jsonl_writer.write(article.to_dict()) return filename.resolve() + @property + def _query_items(self): + return select(Article).where(Article.source == self.name) + def read_entries(self, sort_by=None): """Iterate through all the saved entries.""" with make_session() as session: - query = select(Article).where(Article.source == self.name) + query = self._query_items.options(joinedload(Article.summaries)) if sort_by is not None: query = query.order_by(sort_by) - for item in session.scalars(query): + for item in 
session.scalars(query).unique(): yield item def _add_batch(self, session, batch): @@ -204,7 +211,7 @@ def fetch_entries(self): if self.COOLDOWN: time.sleep(self.COOLDOWN) - def process_entry(self, entry): + def process_entry(self, entry) -> Optional[Article]: """Process a single entry.""" raise NotImplementedError @@ -212,7 +219,8 @@ def process_entry(self, entry): def _format_datetime(date) -> str: return date.strftime("%Y-%m-%dT%H:%M:%SZ") - def _get_published_date(self, date) -> Optional[datetime]: + @staticmethod + def _get_published_date(date) -> Optional[datetime]: try: # Totally ignore any timezone info, forcing everything to UTC return parse(str(date)).replace(tzinfo=pytz.UTC) @@ -228,13 +236,8 @@ def unprocessed_items(self, items=None) -> Iterable: urls = map(self.get_item_key, items) with make_session() as session: - self.articles = { - a.url: a - for a in session.query(Article) - .options(joinedload(Article.summaries)) - .filter(Article.url.in_(urls)) - if a.url - } + articles = session.query(Article).options(joinedload(Article.summaries)).filter(Article.url.in_(urls)) + self.articles = {a.url: a for a in articles if a.url} return items diff --git a/align_data/db/models.py b/align_data/db/models.py index 029c3d1c..157c4340 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -11,11 +11,10 @@ String, Boolean, Text, - Float, func, event, ) -from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.dialects.mysql import LONGTEXT from align_data.settings import PINECONE_METADATA_KEYS @@ -58,6 +57,8 @@ class Article(Base): date_updated: Mapped[Optional[datetime]] = mapped_column( DateTime, onupdate=func.current_timestamp() ) + status: Mapped[Optional[str]] = mapped_column(String(256)) + comments: Mapped[Optional[str]] = mapped_column(LONGTEXT) # Editor comments. 
Can be anything pinecone_update_required: Mapped[bool] = mapped_column(Boolean, default=False) @@ -90,9 +91,10 @@ def generate_id_string(self) -> str: "utf-8" ) - def verify_fields(self): - missing = [field for field in self.__id_fields if not getattr(self, field)] - assert not missing, f"Entry is missing the following fields: {missing}" + @property + def missing_fields(self): + fields = set(self.__id_fields) | {'text', 'title', 'url', 'source', 'date_published'} + return sorted([field for field in fields if not getattr(self, field, None)]) def verify_id(self): assert self.id is not None, "Entry is missing id" @@ -101,13 +103,17 @@ def verify_id(self): id_from_fields = hashlib.md5(id_string).hexdigest() assert ( self.id == id_from_fields - ), f"Entry id {self.id} does not match id from id_fields, {id_from_fields}" + ), f"Entry id {self.id} does not match id from id_fields: {id_from_fields}" + + def verify_id_fields(self): + missing = [field for field in self.__id_fields if not getattr(self, field)] + assert not missing, f"Entry is missing the following fields: {missing}" def update(self, other): for field in self.__table__.columns.keys(): if field not in ["id", "hash_id", "metadata"] and getattr(other, field): setattr(self, field, getattr(other, field)) - self.meta.update({k: v for k, v in other.meta.items() if k and v}) + self.meta = dict((self.meta or {}), **{k: v for k, v in other.meta.items() if k and v}) if other._id: self._id = other._id @@ -118,14 +124,28 @@ def _set_id(self): id_string = self.generate_id_string() self.id = hashlib.md5(id_string).hexdigest() + def add_meta(self, key, val): + if self.meta is None: + self.meta = {} + self.meta[key] = val + @classmethod def before_write(cls, mapper, connection, target): - target.verify_fields() + target.verify_id_fields() + + if not target.status and target.missing_fields: + target.status = 'Missing fields' + target.comments = f'missing fields: {", ".join(target.missing_fields)}' if target.id: target.verify_id() else: target._set_id() + + # This assumes that status pretty much just notes down that an entry is invalid. If it has + # all fields set and is being written to the database, then it must have been modified, ergo + # should be also updated in pinecone + if not target.status: target.pinecone_update_required = True def to_dict(self): @@ -147,7 +167,7 @@ def to_dict(self): "date_published": date, "authors": authors, "summaries": [s.text for s in (self.summaries or [])], - **(self.meta or {}), + **(meta or {}), } diff --git a/align_data/pinecone/text_splitter.py b/align_data/pinecone/text_splitter.py index 03c74b57..b8af09a3 100644 --- a/align_data/pinecone/text_splitter.py +++ b/align_data/pinecone/text_splitter.py @@ -4,6 +4,15 @@ from langchain.text_splitter import TextSplitter from nltk.tokenize import sent_tokenize +# TODO: Fix this. +# sent_tokenize has strange behavior sometimes: 'The units could be anything (characters, words, sentences, etc.), depending on how you want to chunk your text.' 
+# splits into ['The units could be anything (characters, words, sentences, etc.', '), depending on how you want to chunk your text.'] + +StrToIntFunction = Callable[[str], int] +StrIntBoolToStrFunction = Callable[[str, int, bool], str] + +def default_truncate_function(string: str, length: int, from_end: bool = False) -> str: + return string[-length:] if from_end else string[:length] def default_truncate_function(string: str, length: int, from_end: bool = False) -> str: return string[-length:] if from_end else string[:length] @@ -14,22 +23,22 @@ class ParagraphSentenceUnitTextSplitter(TextSplitter): @param min_chunk_size: The minimum number of units in a chunk. @param max_chunk_size: The maximum number of units in a chunk. - @param length_function: A function that returns the length of a string in units. + @param length_function: A function that returns the length of a string in units. Defaults to len(). @param truncate_function: A function that truncates a string to a given unit length. """ - - DEFAULT_MIN_CHUNK_SIZE = 900 - DEFAULT_MAX_CHUNK_SIZE = 1100 - DEFAULT_LENGTH_FUNCTION = lambda string: len(string) - DEFAULT_TRUNCATE_FUNCTION = default_truncate_function + DEFAULT_MIN_CHUNK_SIZE: int = 900 + DEFAULT_MAX_CHUNK_SIZE: int = 1100 + DEFAULT_LENGTH_FUNCTION: StrToIntFunction = len + DEFAULT_TRUNCATE_FUNCTION: StrIntBoolToStrFunction = default_truncate_function + def __init__( self, min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, - length_function: Callable[[str], int] = DEFAULT_LENGTH_FUNCTION, - truncate_function: Callable[[str, int], str] = DEFAULT_TRUNCATE_FUNCTION, - **kwargs: Any, + length_function: StrToIntFunction = DEFAULT_LENGTH_FUNCTION, + truncate_function: StrIntBoolToStrFunction = DEFAULT_TRUNCATE_FUNCTION, + **kwargs: Any ): super().__init__(**kwargs) self.min_chunk_size = min_chunk_size @@ -39,8 +48,9 @@ def __init__( self._truncate_function = truncate_function def split_text(self, text: str) -> List[str]: - blocks = [] - current_block = "" + """Split text into chunks of length between min_chunk_size and max_chunk_size.""" + blocks: List[str] = [] + current_block: str = "" paragraphs = text.split("\n\n") for paragraph in paragraphs: @@ -56,10 +66,9 @@ def split_text(self, text: str) -> List[str]: continue blocks = self._handle_remaining_text(current_block, blocks) - return [block.strip() for block in blocks] - def _handle_large_paragraph(self, current_block, blocks, paragraph): + def _handle_large_paragraph(self, current_block: str, blocks: List[str], paragraph: str) -> str: # Undo adding the whole paragraph offset = len(paragraph) + 2 # +2 accounts for "\n\n" current_block = current_block[:-offset] @@ -75,44 +84,35 @@ def _handle_large_paragraph(self, current_block, blocks, paragraph): blocks.append(current_block) current_block = "" else: - current_block = self._truncate_large_block( - current_block, blocks, sentence - ) - + current_block = self._truncate_large_block(current_block, blocks) return current_block - def _truncate_large_block(self, current_block, blocks, sentence): + def _truncate_large_block(self, current_block: str, blocks: List[str]) -> str: while self._length_function(current_block) > self.max_chunk_size: - # Truncate current_block to max size, set remaining sentence as next sentence + # Truncate current_block to max size, set remaining text as current_block truncated_block = self._truncate_function( - current_block, self.max_chunk_size + current_block, self.max_chunk_size ) blocks.append(truncated_block) - 
remaining_sentence = current_block[len(truncated_block) :].lstrip() - current_block = sentence = remaining_sentence - + current_block = current_block[len(truncated_block):].lstrip() + return current_block - def _handle_remaining_text(self, current_block, blocks): + def _handle_remaining_text(self, last_block: str, blocks: List[str]) -> List[str]: if blocks == []: # no blocks were added - return [current_block] - elif current_block: # any leftover text - len_current_block = self._length_function(current_block) - if len_current_block < self.min_chunk_size: - # it needs to take the last min_chunk_size-len_current_block units from the previous block - previous_block = blocks[-1] - required_units = ( - self.min_chunk_size - len_current_block - ) # calculate the required units - + return [last_block] + elif last_block: # any leftover text + len_last_block = self._length_function(last_block) + if self.min_chunk_size - len_last_block > 0: + # Add text from previous block to last block if last_block is too short part_prev_block = self._truncate_function( - previous_block, required_units, from_end=True - ) # get the required units from the previous block - last_block = part_prev_block + current_block + string=blocks[-1], + length=self.min_chunk_size - len_last_block, + from_end=True + ) + last_block = part_prev_block + last_block - blocks.append(last_block) - else: - blocks.append(current_block) + blocks.append(last_block) return blocks diff --git a/align_data/sources/articles/__init__.py b/align_data/sources/articles/__init__.py index 7e9fdbde..da7f3a6b 100644 --- a/align_data/sources/articles/__init__.py +++ b/align_data/sources/articles/__init__.py @@ -1,5 +1,6 @@ from align_data.sources.articles.datasets import ( - EbookArticles, DocArticles, HTMLArticles, MarkdownArticles, PDFArticles, SpecialDocs, XMLArticles + ArxivPapers, EbookArticles, DocArticles, HTMLArticles, + MarkdownArticles, PDFArticles, SpecialDocs, XMLArticles ) from align_data.sources.articles.indices import IndicesDataset @@ -39,5 +40,10 @@ spreadsheet_id='1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI', sheet_id='980957638', ), + ArxivPapers( + name="arxiv", + spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI", + sheet_id="655836697", + ), IndicesDataset('indices'), ] diff --git a/align_data/sources/articles/articles.py b/align_data/sources/articles/articles.py index 7485ce9e..39cc31b7 100644 --- a/align_data/sources/articles/articles.py +++ b/align_data/sources/articles/articles.py @@ -15,6 +15,7 @@ from align_data.sources.articles.parsers import item_metadata, fetch from align_data.sources.articles.indices import fetch_all from align_data.sources.articles.html import with_retry +from align_data.sources.articles.updater import ReplacerDataset from align_data.settings import PDFS_FOLDER_ID @@ -65,7 +66,7 @@ def process_row(row, sheets): row.set_status(error) return - data_source = contents.get("data_source") + data_source = contents.get("source_type") if data_source not in sheets: error = "Unhandled data type" logger.error(error) @@ -158,3 +159,8 @@ def check_new_articles(source_spreadsheet, source_sheet): updated = res["updates"]["updatedRows"] logger.info("Added %s rows", updated) return updated + + +def update_articles(csv_file, delimiter): + dataset = ReplacerDataset(name='updater', csv_path=csv_file, delimiter=delimiter) + dataset.add_entries(dataset.fetch_entries()) diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py index a7c3bb47..2cf52ee8 100644 --- 
a/align_data/sources/articles/datasets.py +++ b/align_data/sources/articles/datasets.py @@ -1,18 +1,24 @@ -import os import logging +import os from dataclasses import dataclass from pathlib import Path +from typing import Dict from urllib.parse import urlparse -from pypandoc import convert_file import pandas as pd from gdown.download import download from markdownify import markdownify +from pypandoc import convert_file +from sqlalchemy import select -from align_data.sources.articles.pdf import read_pdf -from align_data.sources.articles.parsers import HTML_PARSERS, extract_gdrive_contents, item_metadata -from align_data.sources.articles.google_cloud import fetch_markdown, fetch_file from align_data.common.alignment_dataset import AlignmentDataset +from align_data.db.models import Article +from align_data.sources.articles.google_cloud import fetch_file, fetch_markdown +from align_data.sources.articles.parsers import ( + HTML_PARSERS, extract_gdrive_contents, item_metadata, parse_domain +) +from align_data.sources.articles.pdf import read_pdf +from align_data.sources.arxiv_papers.arxiv_papers import fetch as fetch_arxiv logger = logging.getLogger(__name__) @@ -26,20 +32,21 @@ class SpreadsheetDataset(AlignmentDataset): batch_size = 1 @staticmethod - def maybe(val): + def maybe(item, key: str): + val = getattr(item, key, None) if pd.isna(val): return None return val + def get_item_key(self, item): + return self.maybe(item, self.done_key) + @property def items_list(self): url = f'https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}' logger.info(f'Fetching {url}') df = pd.read_csv(url) - return (item for item in df.itertuples() if self.maybe(self.get_item_key(item))) - - def get_item_key(self, item): - return getattr(item, self.done_key) + return (item for item in df.itertuples() if self.get_item_key(item)) @staticmethod def _get_text(item): @@ -47,7 +54,7 @@ def _get_text(item): @staticmethod def extract_authors(item): - if not SpreadsheetDataset.maybe(item.authors): + if not SpreadsheetDataset.maybe(item, "authors"): return [] return [author.strip() for author in item.authors.split(",") if author.strip()] @@ -57,42 +64,62 @@ def process_entry(self, item): logger.error("Could not get text for %s - skipping for now", item.title) return None + url = self.maybe(item, "url") + source_url = self.maybe(item, "source_url") + return self.make_data_entry( { "text": markdownify(text).strip(), - "url": self.maybe(item.url), - "title": self.maybe(item.title), + "url": url, + "source_url": source_url if source_url != url else None, + "title": self.maybe(item, "title"), "source": self.name, - "source_type": self.maybe(item.source_type), + "source_type": self.maybe(item, "source_type"), "source_filetype": self.source_filetype, "date_published": self._get_published_date(item.date_published), "authors": self.extract_authors(item), - "summary": self.maybe(item.summary), + "summary": self.maybe(item, "summary"), } ) class SpecialDocs(SpreadsheetDataset): - def process_entry(self, item): + @property + def _query_items(self): + special_docs_types = ["pdf", "html", "xml", "markdown", "docx"] + return select(Article).where(Article.source.in_(special_docs_types)) + + def get_contents(self, item) -> Dict: metadata = {} - if url := self.maybe(item.source_url) or self.maybe(item.url): + if url := self.maybe(item, "source_url") or self.maybe(item, "url"): metadata = item_metadata(url) - text = metadata.get('text') - if not text: - logger.error('Could not get text for %s - 
skipping for now', item.title) - return None - - return self.make_data_entry({ - 'source': metadata.get('data_source') or self.name, - 'url': self.maybe(item.url), - 'title': self.maybe(item.title) or metadata.get('title'), - 'source_type': self.maybe(item.source_type), + return { + 'url': self.maybe(item, "url"), + 'title': self.maybe(item, "title") or metadata.get('title'), + 'source': metadata.get('source_type') or self.name, + 'source_url': self.maybe(item, "source_url"), + 'source_type': metadata.get('source_type') or self.maybe(item, "source_type"), 'date_published': self._get_published_date(item.date_published) or metadata.get('date_published'), 'authors': self.extract_authors(item) or metadata.get('authors', []), - 'text': text, - }) + 'text': metadata.get('text'), + 'status': 'Invalid' if metadata.get('error') else None, + 'comments': metadata.get('error'), + } + + def process_entry(self, item): + if parse_domain(item.url) == "arxiv.org": + contents = ArxivPapers.get_contents(item) + contents['source'] = 'arxiv' + else: + contents = self.get_contents(item) + + # Skip items that can't be saved because missing fields + if not all(contents.get(key) for key in self.id_fields): + return None + + return self.make_data_entry(contents) class PDFArticles(SpreadsheetDataset): @@ -148,7 +175,7 @@ def _get_text(self, item): class MarkdownArticles(SpreadsheetDataset): - source_filetype = "md" + source_filetype = "markdown" def _get_text(self, item): file_id = item.source_url.split("/")[-2] @@ -172,3 +199,26 @@ def _get_text(self, item): file_id = item.source_url.split("/")[-2] file_name = fetch_file(file_id) return convert_file(file_name, "md", format="docx", extra_args=["--wrap=none"]) + + +class ArxivPapers(SpreadsheetDataset): + COOLDOWN: int = 1 + + @classmethod + def get_contents(cls, item) -> Dict: + contents = fetch_arxiv(item.url or item.source_url) + + if cls.maybe(item, "authors") and item.authors.strip(): + contents['authors'] = [i.strip() for i in item.authors.split(',')] + if cls.maybe(item, "title"): + contents['title'] = cls.maybe(item, "title") + + contents['date_published'] = cls._get_published_date( + cls.maybe(item, "date_published") or contents.get('date_published') + ) + return contents + + def process_entry(self, item): + logger.info(f"Processing {item.title}") + + return self.make_data_entry(self.get_contents(item), source=self.name) diff --git a/align_data/sources/articles/google_cloud.py b/align_data/sources/articles/google_cloud.py index 6cd9e337..b1e957f8 100644 --- a/align_data/sources/articles/google_cloud.py +++ b/align_data/sources/articles/google_cloud.py @@ -143,7 +143,7 @@ def fetch_markdown(file_id): file_name = fetch_file(file_id) return { "text": Path(file_name).read_text(), - "data_source": "markdown", + "source_type": "markdown", } except Exception as e: return {'error': str(e)} @@ -156,7 +156,7 @@ def parse_grobid(contents): if not doc_dict.get('body'): return { 'error': 'No contents in XML file', - 'data_source': 'xml', + 'source_type': 'xml', } return { @@ -164,7 +164,7 @@ def parse_grobid(contents): "abstract": doc_dict.get("abstract"), "text": doc_dict["body"], "authors": list(filter(None, authors)), - "data_source": "xml", + "source_type": "xml", } @@ -198,7 +198,7 @@ def extract_gdrive_contents(link): elif content_type & {'text/markdown'}: result.update(fetch_markdown(file_id)) elif content_type & {'application/epub+zip', 'application/epub'}: - result['data_source'] = 'ebook' + result['source_type'] = 'ebook' elif content_type & {'text/html'}: res 
= fetch(url) if 'Google Drive - Virus scan warning' in res.text: @@ -213,7 +213,7 @@ def extract_gdrive_contents(link): soup = BeautifulSoup(res.content, "html.parser") result.update({ 'text': MarkdownConverter().convert_soup(soup.select_one('body')).strip(), - 'data_source': 'html', + 'source_type': 'html', }) else: result['error'] = f'unknown content type: {content_type}' diff --git a/align_data/sources/articles/parsers.py b/align_data/sources/articles/parsers.py index 85d23fe8..42c25c9f 100644 --- a/align_data/sources/articles/parsers.py +++ b/align_data/sources/articles/parsers.py @@ -250,8 +250,12 @@ def getter(url): } +def parse_domain(url: str) -> str: + return url and urlparse(url).netloc.lstrip('www.') + + def item_metadata(url) -> Dict[str, str]: - domain = urlparse(url).netloc.lstrip('www.') + domain = parse_domain(url) try: res = fetch(url, 'head') except (MissingSchema, InvalidSchema, ConnectionError) as e: @@ -265,7 +269,7 @@ def item_metadata(url) -> Dict[str, str]: if parser := HTML_PARSERS.get(domain): if res := parser(url): # Proper contents were found on the page, so use them - return {'source_url': url, 'data_source': 'html', 'text': res} + return {'source_url': url, 'source_type': 'html', 'text': res} if parser := PDF_PARSERS.get(domain): if res := parser(url): @@ -286,6 +290,6 @@ def item_metadata(url) -> Dict[str, str]: elif content_type & {"application/epub+zip", "application/epub"}: # it looks like an ebook. Assume it's fine. # TODO: validate that the ebook is readable - return {"source_url": url, "data_source": "ebook"} + return {"source_url": url, "source_type": "ebook"} else: return {"error": f"Unhandled content type: {content_type}"} diff --git a/align_data/sources/articles/pdf.py b/align_data/sources/articles/pdf.py index 9db52b9b..aca627f1 100644 --- a/align_data/sources/articles/pdf.py +++ b/align_data/sources/articles/pdf.py @@ -66,7 +66,7 @@ def fetch_pdf(link): return { "source_url": link, "text": "\n".join(page.extract_text() for page in pdf_reader.pages), - "data_source": "pdf", + "source_type": "pdf", } except (TypeError, PdfReadError) as e: logger.error('Could not read PDF file: %s', e) @@ -170,5 +170,5 @@ def get_first_child(item): "authors": authors, "text": text, "date_published": date_published, - "data_source": "html", + "source_type": "html", } diff --git a/align_data/sources/articles/updater.py b/align_data/sources/articles/updater.py new file mode 100644 index 00000000..861d0af6 --- /dev/null +++ b/align_data/sources/articles/updater.py @@ -0,0 +1,90 @@ +import logging +from collections import namedtuple +from dataclasses import dataclass + +import pandas as pd +from sqlalchemy import select, or_ +from align_data.common.alignment_dataset import AlignmentDataset +from align_data.db.models import Article +from align_data.sources.articles.parsers import item_metadata + +logger = logging.getLogger(__name__) + +Item = namedtuple('Item', ['updates', 'article']) + + +@dataclass +class ReplacerDataset(AlignmentDataset): + csv_path: str + delimiter: str + done_key = "url" + + def get_item_key(self, item): + return None + + @staticmethod + def maybe(item, key): + val = getattr(item, key, None) + if pd.isna(val) or (isinstance(val, str) and not val.strip()): + return None + return val + + @property + def items_list(self): + df = pd.read_csv(self.csv_path, delimiter=self.delimiter) + self.csv_items = [ + item for item in df.itertuples() + if self.maybe(item, 'id') or self.maybe(item, 'hash_id') + ] + by_id = {i.id: i for i in self.csv_items if 
self.maybe(i, 'id')} + by_hash_id = {i.hash_id: i for i in self.csv_items if self.maybe(i, 'hash_id')} + + return [ + Item(by_id.get(a._id) or by_hash_id.get(a.id), a) + for a in self.read_entries() + ] + + @property + def _query_items(self): + ids = [i.id for i in self.csv_items if self.maybe(i, 'id')] + hash_ids = [i.hash_id for i in self.csv_items if self.maybe(i, 'hash_id')] + return select(Article).where(or_(Article.id.in_(hash_ids), Article._id.in_(ids))) + + def update_text(self, updates, article): + # If the url is the same as it was before, and there isn't a source url provided, assume that the + # previous text is still valid + if article.url == self.maybe(updates, 'url') and not self.maybe(updates, 'source_url'): + return + + # If no url found, then don't bother fetching the text - assume it was successfully fetched previously + url = self.maybe(updates, 'source_url') or self.maybe(updates, 'url') + if not url: + return + + if article.url != url: + article.add_meta('source_url', url) + + metadata = item_metadata(url) + # Only change the text if it could be fetched - better to have outdated values than none + if metadata.get('text'): + article.text = metadata.get('text') + article.status = metadata.get('error') + + def process_entry(self, item): + updates, article = item + + for key in ['url', 'title', 'source', 'authors', 'comment', 'confidence']: + value = self.maybe(updates, key) + if value and getattr(article, key, None) != value: + setattr(article, key, value) + + if date := getattr(updates, 'date_published', None): + article.date_published = self._get_published_date(date) + + self.update_text(updates, article) + article._set_id() + + return article + + def _add_batch(self, session, batch): + session.add_all(map(session.merge, batch)) diff --git a/align_data/sources/arxiv_papers/__init__.py b/align_data/sources/arxiv_papers/__init__.py index 29258480..e69de29b 100644 --- a/align_data/sources/arxiv_papers/__init__.py +++ b/align_data/sources/arxiv_papers/__init__.py @@ -1,9 +0,0 @@ -from .arxiv_papers import ArxivPapers - -ARXIV_REGISTRY = [ - ArxivPapers( - name="arxiv", - spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI", - sheet_id="655836697", - ) -] diff --git a/align_data/sources/arxiv_papers/arxiv_papers.py b/align_data/sources/arxiv_papers/arxiv_papers.py index 42f6dcaa..dc551946 100644 --- a/align_data/sources/arxiv_papers/arxiv_papers.py +++ b/align_data/sources/arxiv_papers/arxiv_papers.py @@ -1,75 +1,92 @@ import logging import re -from dataclasses import dataclass +from typing import Dict, Optional import arxiv -from align_data.sources.articles.datasets import SpreadsheetDataset from align_data.sources.articles.pdf import fetch_pdf, parse_vanity +from align_data.sources.articles.html import fetch_element logger = logging.getLogger(__name__) -@dataclass -class ArxivPapers(SpreadsheetDataset): - summary_key: str = "summary" - COOLDOWN: int = 1 - done_key = "url" - batch_size = 1 - - def _get_arxiv_metadata(self, paper_id) -> arxiv.Result: - """ - Get metadata from arxiv - """ - try: - search = arxiv.Search(id_list=[paper_id], max_results=1) - return next(search.results()) - except Exception as e: - logger.error(e) - return None - - def get_id(self, item): - if res := re.search(r"https://arxiv.org/abs/(.*?)/?$", item.url): - return res.group(1) - - def get_contents(self, item) -> dict: - paper_id = self.get_id(item) - for link in [ - f"https://www.arxiv-vanity.com/papers/{paper_id}", - f"https://ar5iv.org/abs/{paper_id}", - ]: - if contents := 
parse_vanity(link): - return contents - return fetch_pdf(f"https://arxiv.org/pdf/{paper_id}.pdf") - - def process_entry(self, item) -> None: - logger.info(f"Processing {item.title}") - - paper = self.get_contents(item) - if not paper or not paper.get("text"): - return None - - metadata = self._get_arxiv_metadata(self.get_id(item)) - if self.maybe(item.authors) and item.authors.strip(): - authors = item.authors.split(',') - elif metadata and metadata.authors: - authors = metadata.authors - else: - authors = paper.get("authors") or [] - authors = [str(a).strip() for a in authors] - - return self.make_data_entry({ - "url": self.get_item_key(item), - "source": self.name, - "source_type": paper['data_source'], - "title": self.maybe(item.title) or paper.get('title'), - "authors": authors, - "date_published": self._get_published_date(self.maybe(item.date_published) or paper.get('date_published')), - "data_last_modified": str(metadata.updated), - "summary": metadata.summary.replace("\n", " "), - "author_comment": metadata.comment, - "journal_ref": metadata.journal_ref, - "doi": metadata.doi, - "primary_category": metadata.primary_category, - "categories": metadata.categories, - "text": paper['text'], - }) +def get_arxiv_metadata(paper_id) -> arxiv.Result: + """ + Get metadata from arxiv + """ + try: + search = arxiv.Search(id_list=[paper_id], max_results=1) + return next(search.results()) + except Exception as e: + logger.error(e) + return None + + +def get_id(url: str) -> Optional[str]: + if res := re.search(r"https?://arxiv.org/(?:abs|pdf)/(.*?)(?:v\d+)?(?:/|\.pdf)?$", url): + return res.group(1) + + +def canonical_url(url: str) -> str: + if paper_id := get_id(url): + return f'https://arxiv.org/abs/{paper_id}' + return url + + +def get_contents(paper_id: str) -> dict: + for link in [ + f"https://www.arxiv-vanity.com/papers/{paper_id}", + f"https://ar5iv.org/abs/{paper_id}", + ]: + if contents := parse_vanity(link): + return contents + return fetch_pdf(f"https://arxiv.org/pdf/{paper_id}.pdf") + + +def get_version(id: str) -> Optional[str]: + if res := re.search(r'.*v(\d+)$', id): + return res.group(1) + + +def is_withdrawn(url: str): + if elem := fetch_element(canonical_url(url), '.extra-services .full-text ul'): + return elem.text.strip().lower() == 'withdrawn' + return None + + +def add_metadata(data, paper_id): + metadata = get_arxiv_metadata(paper_id) + if not metadata: + return {} + return dict({ + "authors": metadata.authors, + "title": metadata.title, + "date_published": metadata.published, + "data_last_modified": metadata.updated.isoformat(), + "summary": metadata.summary.replace("\n", " "), + "comment": metadata.comment, + "journal_ref": metadata.journal_ref, + "doi": metadata.doi, + "primary_category": metadata.primary_category, + "categories": metadata.categories, + "version": get_version(metadata.get_short_id()), + }, **data) + + +def fetch(url) -> Dict: + paper_id = get_id(url) + if not paper_id: + return {'error': 'Could not extract arxiv id'} + + if is_withdrawn(url): + paper = {'status': 'Withdrawn'} + else: + paper = get_contents(paper_id) + + data = add_metadata({ + "url": canonical_url(url), + "source_type": paper.get('data_source'), + }, paper_id) + authors = data.get('authors') or paper.get("authors") or [] + data['authors'] = [str(a).strip() for a in authors] + + return dict(data, **paper) diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py index 294ee205..940ba5b4 100644 --- a/align_data/sources/blogs/blogs.py +++ 
b/align_data/sources/blogs/blogs.py @@ -76,16 +76,16 @@ def _get_text(self, contents): def extract_authors(self, article): author_selector = 'div:-soup-contains("Authors") + div .f-body-1' ack_selector = 'div:-soup-contains("Acknowledgments") + div .f-body-1' - - authors_div = article.select_one(author_selector) or article.select_one(ack_selector) - if not authors_div: - return [] - return [ - i.split("(")[0].strip() - for i in authors_div.select_one("p").children - if not i.name - ] + authors_div = article.select_one(author_selector) or article.select_one(ack_selector) + authors = [] + if authors_div: + authors = [ + i.split("(")[0].strip() + for i in authors_div.select_one("p").children + if not i.name + ] + return authors or ["OpenAI Research"] class DeepMindTechnicalBlog(HTMLDataset): diff --git a/align_data/sources/distill/distill.py b/align_data/sources/distill/distill.py index f54fb554..5c889c9a 100644 --- a/align_data/sources/distill/distill.py +++ b/align_data/sources/distill/distill.py @@ -10,7 +10,7 @@ class Distill(RSSDataset): summary_key = "summary" def extract_authors(self, item): - return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] + return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] or ["Distill"] def _get_text(self, item): article = item["soup"].find("d-article") or item["soup"].find("dt-article") diff --git a/align_data/sources/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py index f746e552..deb6d4d1 100644 --- a/align_data/sources/greaterwrong/greaterwrong.py +++ b/align_data/sources/greaterwrong/greaterwrong.py @@ -174,7 +174,7 @@ def process_entry(self, item): authors = item["coauthors"] if item["user"]: authors = [item["user"]] + authors - authors = [a["displayName"] for a in authors] + authors = [a["displayName"] for a in authors] or ['anonymous'] return self.make_data_entry( { "title": item["title"], diff --git a/main.py b/main.py index acf10b0a..de8412c5 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,7 @@ from align_data import ALL_DATASETS, get_dataset from align_data.analysis.count_tokens import count_token -from align_data.sources.articles.articles import update_new_items, check_new_articles +from align_data.sources.articles.articles import update_new_items, check_new_articles, update_articles from align_data.pinecone.update_pinecone import PineconeUpdater from align_data.settings import ( METADATA_OUTPUT_SPREADSHEET, @@ -32,7 +32,7 @@ def fetch(self, *names) -> None: """ > This function takes a dataset name and writes the entries of that dataset to a file - :param str name: The name of the dataset to fetch + :param str name: The name of the dataset to fetch, or 'all' for all of them :return: The path to the file that was written to. """ if names == ("all",): @@ -81,6 +81,14 @@ def count_tokens(self, merged_dataset_path: str) -> None: ), "The path to the merged dataset does not exist" count_token(merged_dataset_path) + def update(self, csv_path, delimiter=','): + """Update all articles in the provided csv files, overwriting the provided values and fetching new text if a different url provided. 
+ + :param str csv_path: The path to the csv file to be processed + :param str delimiter: Specifies what's used as a column delimiter + """ + update_articles(csv_path, delimiter) + def update_metadata( self, source_spreadsheet=METADATA_SOURCE_SPREADSHEET, @@ -110,6 +118,8 @@ def fetch_new_articles( def pinecone_update(self, *names) -> None: """ This function updates the Pinecone vector DB. + + :param List[str] names: The name of the dataset to update, or 'all' for all of them """ if names == ("all",): names = ALL_DATASETS diff --git a/migrations/versions/f5a2bcfa6b2c_add_status_column.py b/migrations/versions/f5a2bcfa6b2c_add_status_column.py new file mode 100644 index 00000000..76c89ee0 --- /dev/null +++ b/migrations/versions/f5a2bcfa6b2c_add_status_column.py @@ -0,0 +1,26 @@ +"""Add status column + +Revision ID: f5a2bcfa6b2c +Revises: 59ac3cb671e3 +Create Date: 2023-08-12 15:59:44.741360 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision = 'f5a2bcfa6b2c' +down_revision = '59ac3cb671e3' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column('articles', sa.Column('status', sa.String(length=256), nullable=True)) + op.add_column('articles', sa.Column('comments', mysql.LONGTEXT(), nullable=True)) + + +def downgrade() -> None: + op.drop_column('articles', 'comments') + op.drop_column('articles', 'status') diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py index dd59d0c7..737864b3 100644 --- a/tests/align_data/articles/test_datasets.py +++ b/tests/align_data/articles/test_datasets.py @@ -1,14 +1,17 @@ +from datetime import datetime from unittest.mock import Mock, patch import pandas as pd import pytest from align_data.sources.articles.datasets import ( + ArxivPapers, EbookArticles, DocArticles, HTMLArticles, MarkdownArticles, PDFArticles, SpreadsheetDataset, + SpecialDocs, XMLArticles, ) @@ -32,6 +35,26 @@ def articles(): return pd.DataFrame(articles) +@pytest.fixture +def mock_arxiv(): + metadata = Mock( + summary="abstract bla bla", + comment="no comment", + categories="wut", + updated=datetime.fromisoformat("2023-01-01T00:00:00"), + authors=[], + doi="123", + journal_ref="sdf", + primary_category="cat", + ) + metadata.get_short_id.return_value = '2001.11038' + arxiv = Mock() + arxiv.Search.return_value.results.return_value = iter([metadata]) + + with patch("align_data.sources.arxiv_papers.arxiv_papers.arxiv", arxiv): + yield + + def test_spreadsheet_dataset_items_list(articles): dataset = SpreadsheetDataset(name="bla", spreadsheet_id="123", sheet_id="456") df = pd.concat( @@ -100,6 +123,7 @@ def test_pdf_articles_process_item(articles): "text": "pdf contents [bla](asd.com)", "title": "article no 0", "url": "http://example.com/item/0", + 'source_url': 'http://example.com/source_url/0', } @@ -145,6 +169,7 @@ def test_html_articles_process_entry(articles): "text": "html contents with [proper elements](bla.com) ble ble", "title": "article no 0", "url": "http://example.com/item/0", + 'source_url': 'http://example.com/source_url/0', } @@ -189,6 +214,7 @@ def test_ebook_articles_process_entry(articles): "text": "html contents with [proper elements](bla.com) ble ble", "title": "article no 0", "url": "http://example.com/item/0", + 'source_url': 'http://example.com/source_url/0', } @@ -221,6 +247,7 @@ def test_xml_articles_process_entry(articles): "text": "bla bla", "title": "article no 0", "url": "http://example.com/item/0", 
+ 'source_url': 'http://example.com/source_url/0', } @@ -247,12 +274,13 @@ def test_markdown_articles_process_entry(articles): "date_published": "2023-01-01T12:32:11Z", "id": None, "source": "bla", - "source_filetype": "md", + "source_filetype": "markdown", "source_type": "something", "summaries": ["the summary of article 0"], "text": "bla bla", "title": "article no 0", "url": "http://example.com/item/0", + 'source_url': 'http://example.com/source_url/0', } @@ -287,4 +315,154 @@ def test_doc_articles_process_entry(articles): "text": "bla bla", "title": "article no 0", "url": "http://example.com/item/0", + 'source_url': 'http://example.com/source_url/0', } + + +@patch('requests.get', return_value=Mock(content='')) +def test_arxiv_process_entry(_, mock_arxiv): + dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da") + item = Mock( + title="this is the title", + url="https://arxiv.org/abs/2001.11038", + authors="", + date_published="2020-01-29", + ) + contents = { + "text": "this is the text", + "date_published": "December 12, 2021", + "authors": ["mr blobby"], + "source_type": "html", + } + with patch( + "align_data.sources.arxiv_papers.arxiv_papers.parse_vanity", return_value=contents + ): + assert dataset.process_entry(item).to_dict() == { + "comment": "no comment", + "authors": ["mr blobby"], + "categories": "wut", + "data_last_modified": "2023-01-01T00:00:00", + "date_published": "2020-01-29T00:00:00Z", + "doi": "123", + "id": None, + "journal_ref": "sdf", + "primary_category": "cat", + "source": "asd", + "source_type": "html", + "summaries": ["abstract bla bla"], + "text": "this is the text", + "title": "this is the title", + "url": "https://arxiv.org/abs/2001.11038", + } + + +def test_arxiv_process_entry_retracted(mock_arxiv): + dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da") + item = Mock( + title="this is the title", + url="https://arxiv.org/abs/2001.11038", + authors="", + date_published="2020-01-29", + ) + response = """ +
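+        <!-- minimal markup matching the .extra-services .full-text ul selector checked by is_withdrawn -->
+        <div class="extra-services">
+            <div class="full-text">
+                <ul>
+                    <li>Withdrawn</li>
+                </ul>
+            </div>
+        </div>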
+ """ + + with patch('requests.get', return_value=Mock(content=response)): + article = dataset.process_entry(item) + assert article.status == 'Withdrawn' + assert article.to_dict() == { + "comment": "no comment", + "authors": [], + "categories": "wut", + "data_last_modified": "2023-01-01T00:00:00", + "date_published": "2020-01-29T00:00:00Z", + "doi": "123", + "id": None, + "journal_ref": "sdf", + "primary_category": "cat", + "source": "asd", + "source_type": None, + "summaries": ["abstract bla bla"], + "title": "this is the title", + "url": "https://arxiv.org/abs/2001.11038", + "text": None, + } + + +def test_special_docs_process_entry(): + dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da") + item = Mock( + title="this is the title", + url="https://bla.bla.bla", + authors="mr. blobby", + date_published="2023-10-02T01:23:45", + source_type=None, + source_url="https://ble.ble.com" + ) + contents = { + "text": "this is the text", + "date_published": "December 12, 2021", + "authors": ["mr blobby"], + "source_type": "html", + } + + with patch("align_data.sources.articles.datasets.item_metadata", return_value=contents): + assert dataset.process_entry(item).to_dict() == { + 'authors': ['mr. blobby'], + 'date_published': '2023-10-02T01:23:45Z', + 'id': None, + 'source': 'html', + 'source_url': "https://ble.ble.com", + 'source_type': 'html', + 'summaries': [], + 'text': 'this is the text', + 'title': 'this is the title', + 'url': 'https://bla.bla.bla', + } + + +@patch('requests.get', return_value=Mock(content='')) +def test_special_docs_process_entry_arxiv(_, mock_arxiv): + dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da") + item = Mock( + title="this is the title", + url="https://arxiv.org/abs/2001.11038", + authors="", + date_published="2020-01-29", + ) + contents = { + "text": "this is the text", + "date_published": "December 12, 2021", + "authors": ["mr blobby"], + "source_type": "pdf", + } + + with patch( + "align_data.sources.arxiv_papers.arxiv_papers.parse_vanity", return_value=contents + ): + assert dataset.process_entry(item).to_dict() == { + "comment": "no comment", + "authors": ["mr blobby"], + "categories": "wut", + "data_last_modified": "2023-01-01T00:00:00", + "date_published": "2020-01-29T00:00:00Z", + "doi": "123", + "id": None, + "journal_ref": "sdf", + "primary_category": "cat", + "source": "arxiv", + "source_type": "pdf", + "summaries": ["abstract bla bla"], + "text": "this is the text", + "title": "this is the title", + "url": "https://arxiv.org/abs/2001.11038", + } diff --git a/tests/align_data/articles/test_google_cloud.py b/tests/align_data/articles/test_google_cloud.py index 3232978a..7b268e43 100644 --- a/tests/align_data/articles/test_google_cloud.py +++ b/tests/align_data/articles/test_google_cloud.py @@ -78,7 +78,7 @@ def test_parse_grobid(): 'authors': ['Cullen Oâ\x80\x99Keefe'], 'text': 'This is the contents', 'title': 'The title!!', - 'data_source': 'xml', + 'source_type': 'xml', } @@ -100,7 +100,7 @@ def test_parse_grobid_no_body(): """ - assert parse_grobid(xml) == {'error': 'No contents in XML file', 'data_source': 'xml'} + assert parse_grobid(xml) == {'error': 'No contents in XML file', 'source_type': 'xml'} @pytest.mark.parametrize('header, expected', ( @@ -160,7 +160,7 @@ def test_extract_gdrive_contents_ebook(header): assert extract_gdrive_contents(url) == { 'downloaded_from': 'google drive', 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'data_source': 'ebook', + 
'source_type': 'ebook', } @@ -185,7 +185,7 @@ def test_extract_gdrive_contents_html(): 'downloaded_from': 'google drive', 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', 'text': 'bla bla', - 'data_source': 'html', + 'source_type': 'html', } @@ -207,7 +207,7 @@ def test_extract_gdrive_contents_xml(): 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', 'text': 'This is the contents', 'title': 'The title!!', - 'data_source': 'xml', + 'source_type': 'xml', } @@ -238,7 +238,7 @@ def fetcher(link, *args, **kwargs): 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', 'text': 'This is the contents', 'title': 'The title!!', - 'data_source': 'xml', + 'source_type': 'xml', } diff --git a/tests/align_data/articles/test_updater.py b/tests/align_data/articles/test_updater.py new file mode 100644 index 00000000..72cf7243 --- /dev/null +++ b/tests/align_data/articles/test_updater.py @@ -0,0 +1,204 @@ +from unittest.mock import Mock, patch +from csv import DictWriter +from numpy import source + +import pandas as pd +import pytest +from align_data.db.models import Article +from align_data.sources.articles.updater import ReplacerDataset, Item + +SAMPLE_UPDATES = [ + {}, + {'title': 'no id - should be ignored'}, + + {'id': '122', 'hash_id': 'deadbeef000'}, + { + 'id': '123', 'hash_id': 'deadbeef001', + 'title': 'bla bla', + 'url': 'http://bla.com', + 'source_url': 'http://bla.bla.com', + 'authors': 'mr. blobby, johnny', + }, { + 'id': '124', + 'title': 'no hash id', + 'url': 'http://bla.com', + 'source_url': 'http://bla.bla.com', + 'authors': 'mr. blobby', + }, { + 'hash_id': 'deadbeef002', + 'title': 'no id', + 'url': 'http://bla.com', + 'source_url': 'http://bla.bla.com', + 'authors': 'mr. blobby', + }, { + 'id': '125', + 'title': 'no hash id, url or title', + 'authors': 'mr. 
blobby', + } +] + +@pytest.fixture +def csv_file(tmp_path): + filename = tmp_path / 'data.csv' + with open(filename, 'w', newline='') as csvfile: + fieldnames = ['id', 'hash_id', 'title', 'url', 'source_url', 'authors'] + writer = DictWriter(csvfile, fieldnames=fieldnames) + + writer.writeheader() + for row in SAMPLE_UPDATES: + writer.writerow(row) + return filename + + +def test_items_list(csv_file): + dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + + def mock_entries(): + return [ + Mock( + _id=dataset.maybe(v, 'id'), + id=dataset.maybe(v, 'hash_id'), + title=dataset.maybe(v, 'title'), + url=dataset.maybe(v, 'url'), + authors=dataset.maybe(v, 'authors') + ) + for v in dataset.csv_items + ] + + with patch.object(dataset, 'read_entries', mock_entries): + items = dataset.items_list + assert len(items) == 5, "items_list should only contain items with valid ids - something is wrong" + for item in items: + assert dataset.maybe(item.updates, 'id') == item.article._id + assert dataset.maybe(item.updates, 'hash_id') == item.article.id + assert dataset.maybe(item.updates, 'title') == item.article.title + assert dataset.maybe(item.updates, 'url') == item.article.url + assert dataset.maybe(item.updates, 'authors') == item.article.authors + + +@pytest.mark.parametrize('updates', ( + Mock(url='http://some.other.url'), + Mock(source_url='http://some.other.url'), + Mock(url='http://some.other.url', source_url='http://another.url'), +)) +def test_update_text(csv_file, updates): + dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + + article = Mock(text='this should be changed', status='as should this', url='http:/bla.bla.com') + + with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + dataset.update_text(updates, article) + assert article.text == 'bla bla bla' + assert article.status == None + + +@pytest.mark.parametrize('updates', ( + Mock(url='http://some.other.url'), + Mock(source_url='http://some.other.url'), + Mock(url='http://some.other.url', source_url='http://another.url'), +)) +def test_update_text_error(csv_file, updates): + dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + + article = Mock(text='this should not be changed', status='but this should be', url='http:/bla.bla.com') + + with patch('align_data.sources.articles.updater.item_metadata', return_value={'error': 'oh noes!'}): + dataset.update_text(updates, article) + assert article.text == 'this should not be changed' + assert article.status == 'oh noes!' 
+ + +@pytest.mark.parametrize('updates', ( + Mock(url='http://bla.bla.com', source_url=None, comment='Same url as article, no source_url'), + Mock(url='http://bla.bla.com', source_url='', comment='Same url as article, empty source_url'), + Mock(url=None, source_url=None, comment='no urls provided'), +)) +def test_update_text_no_update(csv_file, updates): + dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + + article = Mock(text='this should not be changed', status='as should not this', url='http://bla.bla.com') + + with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + dataset.update_text(updates, article) + assert article.text == 'this should not be changed' + assert article.status == 'as should not this' + + +def test_process_entry(csv_file): + dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + + article = Article( + _id=123, id='deadbeef0123', + title='this should be changed', + url='this should be changed', + text='this should be changed', + authors='this should be changed', + date_published='this should be changed', + id_fields=['url', 'title'], + ) + + updates = Mock( + id='123', + hash_id='deadbeef001', + title='bla bla', + url='http://bla.com', + source_url='http://bla.bla.com', + source='tests', + authors='mr. blobby, johnny', + date_published='2000-12-23T10:32:43Z', + ) + + with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + assert dataset.process_entry(Item(updates, article)).to_dict() == { + 'authors': ['mr. blobby', 'johnny'], + 'date_published': '2000-12-23T10:32:43Z', + 'id': 'd8d8cad8d28739a0862654a0e6e8ce6e', + 'source': 'tests', + 'source_type': None, + 'summaries': [], + 'text': 'bla bla bla', + 'title': 'bla bla', + 'url': 'http://bla.com', + 'source_url': 'http://bla.bla.com', + } + + +def test_process_entry_empty(csv_file): + dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + + article = Article( + _id=123, id='deadbeef0123', + title='this should not be changed', + url='this should not be changed', + source='this should not be changed', + authors='this should not be changed', + + text='this should be changed', + date_published='this should be changed', + id_fields=['url', 'title'], + ) + + updates = Mock( + id='123', + hash_id='deadbeef001', + title=None, + url='', + source_url='http://bla.bla.com', + source=' ', + authors=' \n \n \t \t ', + date_published='2000-12-23T10:32:43Z', + ) + + with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + assert dataset.process_entry(Item(updates, article)).to_dict() == { + 'authors': ['this should not be changed'], + 'date_published': '2000-12-23T10:32:43Z', + 'id': '606e9224254f508d297bcb17bcc6d104', + 'source': 'this should not be changed', + 'source_type': None, + 'summaries': [], + 'text': 'bla bla bla', + 'title': 'this should not be changed', + 'url': 'this should not be changed', + 'source_url': 'http://bla.bla.com', + } diff --git a/tests/align_data/common/test_alignment_dataset.py b/tests/align_data/common/test_alignment_dataset.py index fcefeacd..879e6295 100644 --- a/tests/align_data/common/test_alignment_dataset.py +++ b/tests/align_data/common/test_alignment_dataset.py @@ -75,78 +75,42 @@ def test_data_entry_id_from_urls_and_title(): ) -def test_data_entry_no_url_and_title(): - dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry({"key1": 12, "key2": 312}) - with pytest.raises( - 
AssertionError, - match="Entry is missing the following fields: \\['url', 'title'\\]", - ): - Article.before_write(None, None, entry) - - -def test_data_entry_no_url(): - dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry( - {"key1": 12, "key2": 312, "title": "wikipedia goes to war on porcupines"} - ) - with pytest.raises( - AssertionError, match="Entry is missing the following fields: \\['url'\\]" - ): - Article.before_write(None, None, entry) - - -def test_data_entry_none_url(): - dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry({"key1": 12, "key2": 312, "url": None}) - with pytest.raises( - AssertionError, - match="Entry is missing the following fields: \\['url', 'title'\\]", - ): - Article.before_write(None, None, entry) - - -def test_data_entry_none_title(): - dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry( - {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": None} - ) - with pytest.raises( - AssertionError, match="Entry is missing the following fields: \\['title'\\]" - ): - Article.before_write(None, None, entry) - - -def test_data_entry_empty_url_and_title(): - dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry({"key1": 12, "key2": 312, "url": "", "title": ""}) - with pytest.raises( - AssertionError, - match="Entry is missing the following fields: \\['url', 'title'\\]", - ): - Article.before_write(None, None, entry) - - -def test_data_entry_empty_url_only(): - dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry( - {"key1": 12, "key2": 312, "url": "", "title": "once upon a time"} - ) - with pytest.raises( - AssertionError, match="Entry is missing the following fields: \\['url'\\]" - ): - Article.before_write(None, None, entry) - - -def test_data_entry_empty_title_only(): +@pytest.mark.parametrize('item, error', ( + ( + {"key1": 12, "key2": 312, "title": "wikipedia goes to war on porcupines", "url": "asd"}, + 'missing fields: date_published, source, text' + ), + ( + {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "text": "asdasd", "title": "asdasd"}, + 'missing fields: date_published, source' + ), + ( + { + "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", + "source": "dwe", "date_published": "dwe" + }, + 'missing fields: text' + ), + ( + { + "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", + "text": "asdasd", "date_published": "dwe" + }, + 'missing fields: source' + ), + ( + { + "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", "text": "asdasd", "source": "dwe" + }, + 'missing fields: date_published' + ), +)) +def test_data_entry_missing(item, error): dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry( - {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": ""} - ) - with pytest.raises( - AssertionError, match="Entry is missing the following fields: \\['title'\\]" - ): - Article.before_write(None, None, entry) + entry = dataset.make_data_entry(item) + Article.before_write(None, None, entry) + assert entry.status == 'Missing fields' + assert entry.comments == error def test_data_entry_verify_id_passes(): @@ -172,26 +136,38 @@ def test_data_entry_verify_id_fails(): "id": "f2b4e02fc1dd8ae43845e4f930f2d84f", } ) - with pytest.raises(AssertionError, match="Entry id does not match id_fields"): + expected = 'Entry id f2b4e02fc1dd8ae43845e4f930f2d84f does not match id from id_fields: 770fe57c8c2130eda08dc392b8696f97' + with pytest.raises(AssertionError, 
match=expected): entry.verify_id() -def test_data_entry_id_fields_url_no_url(): - dataset = AlignmentDataset(name="blaa", id_fields=["url"]) - entry = dataset.make_data_entry({"source": "arbital", "text": "once upon a time"}) - with pytest.raises( - AssertionError, match="Entry is missing the following fields: \\['url'\\]" - ): - Article.before_write(None, None, entry) - - -def test_data_entry_id_fields_url_empty_url(): - dataset = AlignmentDataset(name="blaa", id_fields=["url"]) - entry = dataset.make_data_entry({"url": ""}) - with pytest.raises( - AssertionError, match="Entry is missing the following fields: \\['url'\\]" - ): - Article.before_write(None, None, entry) +@pytest.mark.parametrize( + "data, error", + ( + ({"id": "123"}, "Entry is missing the following fields: \\['url', 'title'\\]"), + ( + {"id": "123", "url": None}, + "Entry is missing the following fields: \\['url', 'title'\\]", + ), + ( + {"id": "123", "url": "www.google.com/"}, + "Entry is missing the following fields: \\['title'\\]", + ), + ( + {"id": "123", "url": "google", "text": None}, + "Entry is missing the following fields: \\['title'\\]", + ), + ( + {"id": "123", "url": "", "title": ""}, + "Entry is missing the following fields: \\['url', 'title'\\]", + ), + ), +) +def test_data_entry_verify_fields_fails(data, error): + dataset = AlignmentDataset(name="blaa", id_fields=["url", "title"]) + entry = dataset.make_data_entry(data) + with pytest.raises(AssertionError, match=error): + entry.verify_id_fields() def test_data_entry_id_fields_url(): @@ -223,73 +199,6 @@ def test_data_entry_different_id_from_different_url(): assert entry1.generate_id_string() != entry2.generate_id_string() -@pytest.mark.parametrize( - "data, error", - ( - ({"text": "bla bla bla"}, "Entry is missing id"), - ({"text": "bla bla bla", "id": None}, "Entry is missing id"), - ( - { - "id": "123", - "url": "www.google.com/winter_wonderland", - "title": "winter wonderland", - }, - "Entry id 123 does not match id from id_fields, [0-9a-fA-F]{32}", - ), - ( - { - "id": "457c21e0ecabebcb85c12022d481d9f4", - "url": "www.google.com", - "title": "winter wonderland", - }, - "Entry id [0-9a-fA-F]{32} does not match id from id_fields, [0-9a-fA-F]{32}", - ), - ( - { - "id": "457c21e0ecabebcb85c12022d481d9f4", - "url": "www.google.com", - "title": "Once upon a time", - }, - "Entry id [0-9a-fA-F]{32} does not match id from id_fields, [0-9a-fA-F]{32}", - ), - ), -) -def test_data_entry_verify_id_fails(data, error): - dataset = AlignmentDataset(name="blaa", id_fields=["url", "title"]) - entry = dataset.make_data_entry(data) - with pytest.raises(AssertionError, match=error): - entry.verify_id() - - -@pytest.mark.parametrize( - "data, error", - ( - ({"id": "123"}, "Entry is missing the following fields: \\['url', 'title'\\]"), - ( - {"id": "123", "url": None}, - "Entry is missing the following fields: \\['url', 'title'\\]", - ), - ( - {"id": "123", "url": "www.google.com/"}, - "Entry is missing the following fields: \\['title'\\]", - ), - ( - {"id": "123", "url": "google", "text": None}, - "Entry is missing the following fields: \\['title'\\]", - ), - ( - {"id": "123", "url": "", "title": ""}, - "Entry is missing the following fields: \\['url', 'title'\\]", - ), - ), -) -def test_data_entry_verify_fields_fails(data, error): - dataset = AlignmentDataset(name="blaa", id_fields=["url", "title"]) - entry = dataset.make_data_entry(data) - with pytest.raises(AssertionError, match=error): - entry.verify_fields() - - @pytest.fixture def numbers_dataset(): """Make a dataset 
that raises its items to the power of 2."""
diff --git a/tests/align_data/test_arxiv.py b/tests/align_data/test_arxiv.py
index 5817fd2c..898fef70 100644
--- a/tests/align_data/test_arxiv.py
+++ b/tests/align_data/test_arxiv.py
@@ -1,7 +1,5 @@
-from datetime import datetime
-from unittest.mock import patch, Mock
 import pytest
-from align_data.sources.arxiv_papers.arxiv_papers import ArxivPapers
+from align_data.sources.arxiv_papers.arxiv_papers import get_id, canonical_url, get_version


 @pytest.mark.parametrize(
@@ -13,55 +11,28 @@
     ),
 )
 def test_get_id(url, expected):
-    dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da")
-    assert dataset.get_id(Mock(url="https://arxiv.org/abs/2001.11038")) == "2001.11038"
+    assert get_id(url) == expected


-def test_process_entry():
-    dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da")
-    item = Mock(
-        title="this is the title",
-        url="https://arxiv.org/abs/2001.11038",
-        authors="",
-        date_published="2020-01-29",
-    )
-    contents = {
-        "text": "this is the text",
-        "date_published": "December 12, 2021",
-        "authors": ["mr blobby"],
-        "data_source": "html",
-    }
-    metadata = Mock(
-        summary="abstract bla bla",
-        comment="no comment",
-        categories="wut",
-        updated="2023-01-01",
-        authors=[],
-        doi="123",
-        journal_ref="sdf",
-        primary_category="cat",
-    )
-    arxiv = Mock()
-    arxiv.Search.return_value.results.return_value = iter([metadata])
+@pytest.mark.parametrize('url, expected', (
+    ("http://bla.bla", "http://bla.bla"),
+    ("http://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"),
+    ("https://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"),
+    ("https://arxiv.org/abs/2001.11038/", "https://arxiv.org/abs/2001.11038"),
+    ("https://arxiv.org/pdf/2001.11038", "https://arxiv.org/abs/2001.11038"),
+    ("https://arxiv.org/pdf/2001.11038.pdf", "https://arxiv.org/abs/2001.11038"),
+    ("https://arxiv.org/pdf/2001.11038v3.pdf", "https://arxiv.org/abs/2001.11038"),
+    ("https://arxiv.org/abs/math/2001.11038", "https://arxiv.org/abs/math/2001.11038"),
+))
+def test_canonical_url(url, expected):
+    assert canonical_url(url) == expected

-    with patch(
-        "align_data.arxiv_papers.arxiv_papers.parse_vanity", return_value=contents
-    ):
-        with patch("align_data.arxiv_papers.arxiv_papers.arxiv", arxiv):
-            assert dataset.process_entry(item).to_dict() == {
-                "author_comment": "no comment",
-                "authors": ["mr blobby"],
-                "categories": "wut",
-                "data_last_modified": "2023-01-01",
-                "date_published": "2020-01-29T00:00:00Z",
-                "doi": "123",
-                "id": None,
-                "journal_ref": "sdf",
-                "primary_category": "cat",
-                "source": "asd",
-                "source_type": "html",
-                "summaries": ["abstract bla bla"],
-                "text": "this is the text",
-                "title": "this is the title",
-                "url": "https://arxiv.org/abs/2001.11038",
-            }
+
+@pytest.mark.parametrize('id, version', (
+    ('123.123', None),
+    ('math/312', None),
+    ('3123123v1', '1'),
+    ('3123123v123', '123'),
+))
+def test_get_version(id, version):
+    assert get_version(id) == version
diff --git a/tests/align_data/test_blogs.py b/tests/align_data/test_blogs.py
index d09edd16..65e98c10 100644
--- a/tests/align_data/test_blogs.py
+++ b/tests/align_data/test_blogs.py
@@ -682,7 +682,7 @@ def test_openai_research_get_text():


             """,
-            [],
+            ["OpenAI Research"],
         ),
     ),
 )
diff --git a/tests/align_data/test_distill.py b/tests/align_data/test_distill.py
index 02a45d98..6ced02df 100644
--- a/tests/align_data/test_distill.py
+++ b/tests/align_data/test_distill.py
@@ -30,6 +30,13 @@ def 
test_extract_authors():
     ]


+def test_extract_authors_none():
+    dataset = Distill(name="distill", url="bla.bla")
+
+    soup = BeautifulSoup("", "html.parser")
+    assert dataset.extract_authors({"soup": soup}) == ["Distill"]
+
+
 @pytest.mark.parametrize(
     "text",
     (
diff --git a/tests/align_data/test_greater_wrong.py b/tests/align_data/test_greater_wrong.py
index 2f89bac0..29140794 100644
--- a/tests/align_data/test_greater_wrong.py
+++ b/tests/align_data/test_greater_wrong.py
@@ -200,3 +200,37 @@ def test_process_entry(dataset):
         "votes": 12,
         "words": 123,
     }
+
+
+def test_process_entry_no_authors(dataset):
+    entry = {
+        "coauthors": [],
+        "user": {},
+        "title": "The title",
+        "pageUrl": "http://example.com/bla",
+        "modifiedAt": "2001-02-10",
+        "postedAt": "2012/02/01 12:23:34",
+        "htmlBody": '\n\n bla bla <a href="bla.com">a link</a>',
+        "voteCount": 12,
+        "baseScore": 32,
+        "tags": [{"name": "tag1"}, {"name": "tag2"}],
+        "wordCount": 123,
+        "commentCount": 423,
+    }
+    assert dataset.process_entry(entry).to_dict() == {
+        "authors": ["anonymous"],
+        "comment_count": 423,
+        "date_published": "2012-02-01T12:23:34Z",
+        "id": None,
+        "karma": 32,
+        "modified_at": "2001-02-10",
+        "source": "bla",
+        "source_type": "GreaterWrong",
+        "summaries": [],
+        "tags": ["tag1", "tag2"],
+        "text": "bla bla [a link](bla.com)",
+        "title": "The title",
+        "url": "http://example.com/bla",
+        "votes": 12,
+        "words": 123,
+    }