From bc656b296f69618e9b19196cc5a10df474192c1e Mon Sep 17 00:00:00 2001 From: Thomas Lemoine Date: Sun, 13 Aug 2023 05:07:19 -0400 Subject: [PATCH 1/7] text splitter minor refactor --- align_data/pinecone/text_splitter.py | 86 +++++++++++++++++----------- 1 file changed, 52 insertions(+), 34 deletions(-) diff --git a/align_data/pinecone/text_splitter.py b/align_data/pinecone/text_splitter.py index 76bb29b8..615ce2fe 100644 --- a/align_data/pinecone/text_splitter.py +++ b/align_data/pinecone/text_splitter.py @@ -4,39 +4,50 @@ from langchain.text_splitter import TextSplitter from nltk.tokenize import sent_tokenize +# TODO: Fix this. +# sent_tokenize has strange behavior sometimes: 'The units could be anything (characters, words, sentences, etc.), depending on how you want to chunk your text.' +# splits into ['The units could be anything (characters, words, sentences, etc.', '), depending on how you want to chunk your text.'] + +StrToIntFunction = Callable[[str], int] +StrIntBoolToStrFunction = Callable[[str, int, bool], str] + +def default_truncate_function(string: str, length: int, from_end: bool = False) -> str: + return string[-length:] if from_end else string[:length] class ParagraphSentenceUnitTextSplitter(TextSplitter): """A custom TextSplitter that breaks text by paragraphs, sentences, and then units (chars/words/tokens/etc). @param min_chunk_size: The minimum number of units in a chunk. @param max_chunk_size: The maximum number of units in a chunk. - @param length_function: A function that returns the length of a string in units. + @param length_function: A function that returns the length of a string in units. Defaults to len(). @param truncate_function: A function that truncates a string to a given unit length. """ - DEFAULT_MIN_CHUNK_SIZE = 900 - DEFAULT_MAX_CHUNK_SIZE = 1100 - DEFAULT_LENGTH_FUNCTION = lambda string: len(string) - DEFAULT_TRUNCATE_FUNCTION = lambda string, length, from_end=False: string[-length:] if from_end else string[:length] + DEFAULT_MIN_CHUNK_SIZE: int = 900 + DEFAULT_MAX_CHUNK_SIZE: int = 1100 + DEFAULT_LENGTH_FUNCTION: StrToIntFunction = len + DEFAULT_TRUNCATE_FUNCTION: StrIntBoolToStrFunction = default_truncate_function def __init__( self, min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, - length_function: Callable[[str], int] = DEFAULT_LENGTH_FUNCTION, - truncate_function: Callable[[str, int], str] = DEFAULT_TRUNCATE_FUNCTION, + length_function: StrToIntFunction = DEFAULT_LENGTH_FUNCTION, + truncate_function: StrIntBoolToStrFunction = DEFAULT_TRUNCATE_FUNCTION, **kwargs: Any ): super().__init__(**kwargs) self.min_chunk_size = min_chunk_size - self.max_chunk_size = max_chunk_size + self.max_chunk_size = max_chunk_size + assert self.min_chunk_size <= self.max_chunk_size, "min_chunk_size must be less than or equal to max_chunk_size" self._length_function = length_function self._truncate_function = truncate_function def split_text(self, text: str) -> List[str]: - blocks = [] - current_block = "" + """Split text into chunks of length between min_chunk_size and max_chunk_size.""" + blocks: List[str] = [] + current_block: str = "" paragraphs = text.split("\n\n") for paragraph in paragraphs: @@ -52,10 +63,9 @@ def split_text(self, text: str) -> List[str]: continue blocks = self._handle_remaining_text(current_block, blocks) - return [block.strip() for block in blocks] - def _handle_large_paragraph(self, current_block, blocks, paragraph): + def _handle_large_paragraph(self, current_block: str, blocks: List[str], 
paragraph: str) -> str: # Undo adding the whole paragraph current_block = current_block[:-(len(paragraph)+2)] # +2 accounts for "\n\n" @@ -70,36 +80,44 @@ def _handle_large_paragraph(self, current_block, blocks, paragraph): blocks.append(current_block) current_block = "" else: - current_block = self._truncate_large_block(current_block, blocks, sentence) - + current_block = self._truncate_large_block(current_block, blocks) return current_block - def _truncate_large_block(self, current_block, blocks, sentence): + def _truncate_large_block(self, current_block: str, blocks: List[str]) -> str: while self._length_function(current_block) > self.max_chunk_size: - # Truncate current_block to max size, set remaining sentence as next sentence + # Truncate current_block to max size, set remaining text as current_block truncated_block = self._truncate_function(current_block, self.max_chunk_size) blocks.append(truncated_block) - remaining_sentence = current_block[len(truncated_block):].lstrip() - current_block = sentence = remaining_sentence + current_block = current_block[len(truncated_block):].lstrip() return current_block - def _handle_remaining_text(self, current_block, blocks): + def _handle_remaining_text(self, last_block: str, blocks: List[str]) -> List[str]: if blocks == []: # no blocks were added - return [current_block] - elif current_block: # any leftover text - len_current_block = self._length_function(current_block) - if len_current_block < self.min_chunk_size: - # it needs to take the last min_chunk_size-len_current_block units from the previous block - previous_block = blocks[-1] - required_units = self.min_chunk_size - len_current_block # calculate the required units - - part_prev_block = self._truncate_function(previous_block, required_units, from_end=True) # get the required units from the previous block - last_block = part_prev_block + current_block - - blocks.append(last_block) - else: - blocks.append(current_block) + return [last_block] + elif last_block: # any leftover text + len_last_block = self._length_function(last_block) + len_to_add_to_last_block_from_prev_block = self.min_chunk_size - len_last_block + if len_to_add_to_last_block_from_prev_block > 0: + # Add text from previous block to last block if the last_block is too short + part_prev_block = self._truncate_function( + string=blocks[-1], + length=len_to_add_to_last_block_from_prev_block, + from_end=True + ) + last_block = part_prev_block + last_block + + blocks.append(last_block) + + return blocks + - return blocks \ No newline at end of file +if __name__ == '__main__': + #Test + splitter = ParagraphSentenceUnitTextSplitter() + text = """This is a TextSplitter class implementation which is used for dividing a piece of text into chunks based on certain criteria. It inherits from another TextSplitter class and overrides some of its methods to provide custom text splitting functionality. Here's a high-level overview: The TextSplitter receives a string of text and splits it into blocks, with each block being a sequence of paragraphs, sentences, and then units (chars/words/tokens/etc). The split_text method is the main method where the splitting occurs. It starts by splitting the text into paragraphs and then iteratively goes through each paragraph, checking if the length of the current block of text is larger than max_chunk_size or smaller than min_chunk_size, and acting accordingly. 
If the current block becomes too large, the _handle_large_paragraph method is called, which reverts the addition of the last paragraph and splits it into sentences instead, adding them one by one to the current block. If adding a sentence makes the block too large, the _truncate_large_block method is called. It repeatedly truncates the block to the max_chunk_size and moves the remaining text to the next block until the block is small enough. After all paragraphs have been processed, the _handle_remaining_text method is called to handle any text that didn't make it into a block. The a = b = c pattern is used in Python to assign the same value to multiple variables at once. In this case, current_block = sentence = remaining_sentence is setting both current_block and sentence to the value of remaining_sentence. This means that in the next iteration, both current_block and sentence will start as the remaining part of the sentence that didn't fit into the previous block. The _truncate_function is used to truncate a string to a certain length. By default, it either takes the first or last length characters from the string, depending on the from_end argument. However, you can provide a different function to use for truncation when you create the TextSplitter. Note that this class requires a _length_function to be defined in a parent or in this class itself. This function should take a string and return its length in units. The units could be anything (characters, words, sentences, etc.), depending on how you want to chunk your text.""" + text += ' ' + text += ' '.join(str(i) for i in range(900)) + chunks = splitter.split_text(text) + print('\n\n\n-----------------------------------------------------\n\n\n'.join(chunks)) \ No newline at end of file From 603b4cae288dbc14a4c97d838bdf714ddea0265d Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Mon, 14 Aug 2023 10:05:23 +0200 Subject: [PATCH 2/7] Default authors (#130) --- align_data/common/alignment_dataset.py | 1 + align_data/sources/blogs/blogs.py | 18 +++++----- align_data/sources/distill/distill.py | 2 +- .../sources/greaterwrong/greaterwrong.py | 2 +- tests/align_data/test_blogs.py | 2 +- tests/align_data/test_distill.py | 7 ++++ tests/align_data/test_greater_wrong.py | 34 +++++++++++++++++++ 7 files changed, 54 insertions(+), 12 deletions(-) diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 22d7df02..fba20ce2 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -24,6 +24,7 @@ "title": None, "url": None, "authors": lambda: [], + "source_type": None, } logger = logging.getLogger(__name__) diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py index 294ee205..940ba5b4 100644 --- a/align_data/sources/blogs/blogs.py +++ b/align_data/sources/blogs/blogs.py @@ -76,16 +76,16 @@ def _get_text(self, contents): def extract_authors(self, article): author_selector = 'div:-soup-contains("Authors") + div .f-body-1' ack_selector = 'div:-soup-contains("Acknowledgments") + div .f-body-1' - - authors_div = article.select_one(author_selector) or article.select_one(ack_selector) - if not authors_div: - return [] - return [ - i.split("(")[0].strip() - for i in authors_div.select_one("p").children - if not i.name - ] + authors_div = article.select_one(author_selector) or article.select_one(ack_selector) + authors = [] + if authors_div: + authors = [ + i.split("(")[0].strip() + for i in authors_div.select_one("p").children + if not i.name + 
] + return authors or ["OpenAI Research"] class DeepMindTechnicalBlog(HTMLDataset): diff --git a/align_data/sources/distill/distill.py b/align_data/sources/distill/distill.py index f54fb554..5c889c9a 100644 --- a/align_data/sources/distill/distill.py +++ b/align_data/sources/distill/distill.py @@ -10,7 +10,7 @@ class Distill(RSSDataset): summary_key = "summary" def extract_authors(self, item): - return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] + return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] or ["Distill"] def _get_text(self, item): article = item["soup"].find("d-article") or item["soup"].find("dt-article") diff --git a/align_data/sources/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py index f746e552..deb6d4d1 100644 --- a/align_data/sources/greaterwrong/greaterwrong.py +++ b/align_data/sources/greaterwrong/greaterwrong.py @@ -174,7 +174,7 @@ def process_entry(self, item): authors = item["coauthors"] if item["user"]: authors = [item["user"]] + authors - authors = [a["displayName"] for a in authors] + authors = [a["displayName"] for a in authors] or ['anonymous'] return self.make_data_entry( { "title": item["title"], diff --git a/tests/align_data/test_blogs.py b/tests/align_data/test_blogs.py index d09edd16..65e98c10 100644 --- a/tests/align_data/test_blogs.py +++ b/tests/align_data/test_blogs.py @@ -682,7 +682,7 @@ def test_openai_research_get_text(): """, - [], + ["OpenAI Research"], ), ), ) diff --git a/tests/align_data/test_distill.py b/tests/align_data/test_distill.py index 02a45d98..6ced02df 100644 --- a/tests/align_data/test_distill.py +++ b/tests/align_data/test_distill.py @@ -30,6 +30,13 @@ def test_extract_authors(): ] +def test_extract_authors_none(): + dataset = Distill(name="distill", url="bla.bla") + + soup = BeautifulSoup("", "html.parser") + assert dataset.extract_authors({"soup": soup}) == ["Distill"] + + @pytest.mark.parametrize( "text", ( diff --git a/tests/align_data/test_greater_wrong.py b/tests/align_data/test_greater_wrong.py index 2f89bac0..29140794 100644 --- a/tests/align_data/test_greater_wrong.py +++ b/tests/align_data/test_greater_wrong.py @@ -200,3 +200,37 @@ def test_process_entry(dataset): "votes": 12, "words": 123, } + + +def test_process_entry_no_authors(dataset): + entry = { + "coauthors": [], + "user": {}, + "title": "The title", + "pageUrl": "http://example.com/bla", + "modifiedAt": "2001-02-10", + "postedAt": "2012/02/01 12:23:34", + "htmlBody": '\n\n bla bla a link ', + "voteCount": 12, + "baseScore": 32, + "tags": [{"name": "tag1"}, {"name": "tag2"}], + "wordCount": 123, + "commentCount": 423, + } + assert dataset.process_entry(entry).to_dict() == { + "authors": ["anonymous"], + "comment_count": 423, + "date_published": "2012-02-01T12:23:34Z", + "id": None, + "karma": 32, + "modified_at": "2001-02-10", + "source": "bla", + "source_type": "GreaterWrong", + "summaries": [], + "tags": ["tag1", "tag2"], + "text": "bla bla [a link](bla.com)", + "title": "The title", + "url": "http://example.com/bla", + "votes": 12, + "words": 123, + } From 188c82703525282e70cd1a13d508d5e72a8053e3 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Mon, 14 Aug 2023 11:31:25 +0200 Subject: [PATCH 3/7] Special docs (#124) * Export special docs to HF * Add status column * Add comments column --- .github/workflows/upload-to-huggingface.yml | 5 +- align_data/common/alignment_dataset.py | 8 +- align_data/db/models.py | 23 +++- align_data/sources/articles/datasets.py | 27 +++-- 
.../f5a2bcfa6b2c_add_status_column.py | 26 +++++ tests/align_data/articles/test_datasets.py | 2 +- .../common/test_alignment_dataset.py | 108 +++--------------- 7 files changed, 85 insertions(+), 114 deletions(-) create mode 100644 migrations/versions/f5a2bcfa6b2c_add_status_column.py diff --git a/.github/workflows/upload-to-huggingface.yml b/.github/workflows/upload-to-huggingface.yml index eaac2ceb..7788ad32 100644 --- a/.github/workflows/upload-to-huggingface.yml +++ b/.github/workflows/upload-to-huggingface.yml @@ -28,22 +28,19 @@ jobs: - distill - eaforum - eleuther.ai - - gdocs - generative.ink - gwern_blog - html_articles - importai - jsteinhardt_blog - lesswrong - - markdown - miri - ml_safety_newsletter - openai.research - - pdfs - rob_miles_ai_safety + - special_docs - vkrakovna_blog - yudkowsky_blog - - xmls uses: ./.github/workflows/push-dataset.yml with: diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index fba20ce2..6bbb2109 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -86,7 +86,7 @@ def make_data_entry(self, data, **kwargs) -> Article: article = Article( id_fields=self.id_fields, - meta={k: v for k, v in data.items() if k not in INIT_DICT}, + meta={k: v for k, v in data.items() if k not in INIT_DICT and v is not None}, **{k: v for k, v in data.items() if k in INIT_DICT}, ) self._add_authors(article, authors) @@ -107,10 +107,14 @@ def to_jsonl(self, out_path=None, filename=None) -> Path: jsonl_writer.write(article.to_dict()) return filename.resolve() + @property + def _query_items(self): + return select(Article).where(Article.source == self.name) + def read_entries(self, sort_by=None): """Iterate through all the saved entries.""" with make_session() as session: - query = select(Article).where(Article.source == self.name) + query = self._query_items if sort_by is not None: query = query.order_by(sort_by) for item in session.scalars(query): diff --git a/align_data/db/models.py b/align_data/db/models.py index 029c3d1c..b3da7042 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -11,11 +11,10 @@ String, Boolean, Text, - Float, func, event, ) -from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.dialects.mysql import LONGTEXT from align_data.settings import PINECONE_METADATA_KEYS @@ -58,6 +57,8 @@ class Article(Base): date_updated: Mapped[Optional[datetime]] = mapped_column( DateTime, onupdate=func.current_timestamp() ) + status: Mapped[Optional[str]] = mapped_column(String(256)) + comments: Mapped[Optional[str]] = mapped_column(LONGTEXT) # Editor comments. 
Can be anything pinecone_update_required: Mapped[bool] = mapped_column(Boolean, default=False) @@ -90,8 +91,12 @@ def generate_id_string(self) -> str: "utf-8" ) + @property + def missing_fields(self): + return [field for field in self.__id_fields if not getattr(self, field)] + def verify_fields(self): - missing = [field for field in self.__id_fields if not getattr(self, field)] + missing = self.missing_fields assert not missing, f"Entry is missing the following fields: {missing}" def verify_id(self): @@ -107,7 +112,7 @@ def update(self, other): for field in self.__table__.columns.keys(): if field not in ["id", "hash_id", "metadata"] and getattr(other, field): setattr(self, field, getattr(other, field)) - self.meta.update({k: v for k, v in other.meta.items() if k and v}) + self.meta = dict((self.meta or {}), **{k: v for k, v in other.meta.items() if k and v}) if other._id: self._id = other._id @@ -120,12 +125,18 @@ def _set_id(self): @classmethod def before_write(cls, mapper, connection, target): - target.verify_fields() + if not target.status and target.missing_fields: + target.status = f'missing fields: {", ".join(target.missing_fields)}' if target.id: target.verify_id() else: target._set_id() + + # This assumes that status pretty much just notes down that an entry is invalid. If it has + # all fields set and is being written to the database, then it must have been modified, ergo + # should be also updated in pinecone + if not target.status: target.pinecone_update_required = True def to_dict(self): @@ -147,7 +158,7 @@ def to_dict(self): "date_published": date, "authors": authors, "summaries": [s.text for s in (self.summaries or [])], - **(self.meta or {}), + **(meta or {}), } diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py index a7c3bb47..ceec27f5 100644 --- a/align_data/sources/articles/datasets.py +++ b/align_data/sources/articles/datasets.py @@ -1,18 +1,20 @@ -import os import logging +import os from dataclasses import dataclass from pathlib import Path from urllib.parse import urlparse -from pypandoc import convert_file import pandas as pd from gdown.download import download from markdownify import markdownify +from pypandoc import convert_file +from sqlalchemy import select -from align_data.sources.articles.pdf import read_pdf -from align_data.sources.articles.parsers import HTML_PARSERS, extract_gdrive_contents, item_metadata -from align_data.sources.articles.google_cloud import fetch_markdown, fetch_file from align_data.common.alignment_dataset import AlignmentDataset +from align_data.db.models import Article +from align_data.sources.articles.google_cloud import fetch_file, fetch_markdown +from align_data.sources.articles.parsers import HTML_PARSERS, extract_gdrive_contents, item_metadata +from align_data.sources.articles.pdf import read_pdf logger = logging.getLogger(__name__) @@ -74,16 +76,16 @@ def process_entry(self, item): class SpecialDocs(SpreadsheetDataset): + @property + def _query_items(self): + special_docs_types = ["pdf", "html", "xml", "markdown", "docx"] + return select(Article).where(Article.source.in_(special_docs_types)) + def process_entry(self, item): metadata = {} if url := self.maybe(item.source_url) or self.maybe(item.url): metadata = item_metadata(url) - text = metadata.get('text') - if not text: - logger.error('Could not get text for %s - skipping for now', item.title) - return None - return self.make_data_entry({ 'source': metadata.get('data_source') or self.name, 'url': self.maybe(item.url), @@ -91,7 +93,8 @@ 
def process_entry(self, item): 'source_type': self.maybe(item.source_type), 'date_published': self._get_published_date(item.date_published) or metadata.get('date_published'), 'authors': self.extract_authors(item) or metadata.get('authors', []), - 'text': text, + 'text': metadata.get('text'), + 'status': metadata.get('error'), }) @@ -148,7 +151,7 @@ def _get_text(self, item): class MarkdownArticles(SpreadsheetDataset): - source_filetype = "md" + source_filetype = "markdown" def _get_text(self, item): file_id = item.source_url.split("/")[-2] diff --git a/migrations/versions/f5a2bcfa6b2c_add_status_column.py b/migrations/versions/f5a2bcfa6b2c_add_status_column.py new file mode 100644 index 00000000..76c89ee0 --- /dev/null +++ b/migrations/versions/f5a2bcfa6b2c_add_status_column.py @@ -0,0 +1,26 @@ +"""Add status column + +Revision ID: f5a2bcfa6b2c +Revises: 59ac3cb671e3 +Create Date: 2023-08-12 15:59:44.741360 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision = 'f5a2bcfa6b2c' +down_revision = '59ac3cb671e3' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column('articles', sa.Column('status', sa.String(length=256), nullable=True)) + op.add_column('articles', sa.Column('comments', mysql.LONGTEXT(), nullable=True)) + + +def downgrade() -> None: + op.drop_column('articles', 'comments') + op.drop_column('articles', 'status') diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py index dd59d0c7..eed98094 100644 --- a/tests/align_data/articles/test_datasets.py +++ b/tests/align_data/articles/test_datasets.py @@ -247,7 +247,7 @@ def test_markdown_articles_process_entry(articles): "date_published": "2023-01-01T12:32:11Z", "id": None, "source": "bla", - "source_filetype": "md", + "source_filetype": "markdown", "source_type": "something", "summaries": ["the summary of article 0"], "text": "bla bla", diff --git a/tests/align_data/common/test_alignment_dataset.py b/tests/align_data/common/test_alignment_dataset.py index fcefeacd..8d19f66b 100644 --- a/tests/align_data/common/test_alignment_dataset.py +++ b/tests/align_data/common/test_alignment_dataset.py @@ -75,78 +75,26 @@ def test_data_entry_id_from_urls_and_title(): ) -def test_data_entry_no_url_and_title(): - dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry({"key1": 12, "key2": 312}) - with pytest.raises( - AssertionError, - match="Entry is missing the following fields: \\['url', 'title'\\]", - ): - Article.before_write(None, None, entry) - - -def test_data_entry_no_url(): - dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry( - {"key1": 12, "key2": 312, "title": "wikipedia goes to war on porcupines"} - ) - with pytest.raises( - AssertionError, match="Entry is missing the following fields: \\['url'\\]" - ): - Article.before_write(None, None, entry) - - -def test_data_entry_none_url(): - dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry({"key1": 12, "key2": 312, "url": None}) - with pytest.raises( - AssertionError, - match="Entry is missing the following fields: \\['url', 'title'\\]", - ): - Article.before_write(None, None, entry) - - -def test_data_entry_none_title(): - dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry( - {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": None} - ) - with pytest.raises( - AssertionError, match="Entry is missing the following fields: 
\\['title'\\]" - ): - Article.before_write(None, None, entry) - - -def test_data_entry_empty_url_and_title(): - dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry({"key1": 12, "key2": 312, "url": "", "title": ""}) - with pytest.raises( - AssertionError, - match="Entry is missing the following fields: \\['url', 'title'\\]", - ): - Article.before_write(None, None, entry) - - -def test_data_entry_empty_url_only(): - dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry( - {"key1": 12, "key2": 312, "url": "", "title": "once upon a time"} - ) - with pytest.raises( - AssertionError, match="Entry is missing the following fields: \\['url'\\]" - ): - Article.before_write(None, None, entry) - - -def test_data_entry_empty_title_only(): +@pytest.mark.parametrize('item, error', ( + ({"key1": 12, "key2": 312}, 'missing fields: url, title'), + ( + {"key1": 12, "key2": 312, "title": "wikipedia goes to war on porcupines"}, + 'missing fields: url' + ), + ({"key1": 12, "key2": 312, "url": None}, 'missing fields: url, title'), + ( + {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": None}, + 'missing fields: title' + ), + ({"key1": 12, "key2": 312, "url": "", "title": ""}, 'missing fields: url, title'), + ({"key1": 12, "key2": 312, "url": "", "title": "once upon a time"}, 'missing fields: url'), + ({"key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": ""}, 'missing fields: title'), +)) +def test_data_entry_missing(item, error): dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry( - {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": ""} - ) - with pytest.raises( - AssertionError, match="Entry is missing the following fields: \\['title'\\]" - ): - Article.before_write(None, None, entry) + entry = dataset.make_data_entry(item) + Article.before_write(None, None, entry) + assert entry.status == error def test_data_entry_verify_id_passes(): @@ -176,24 +124,6 @@ def test_data_entry_verify_id_fails(): entry.verify_id() -def test_data_entry_id_fields_url_no_url(): - dataset = AlignmentDataset(name="blaa", id_fields=["url"]) - entry = dataset.make_data_entry({"source": "arbital", "text": "once upon a time"}) - with pytest.raises( - AssertionError, match="Entry is missing the following fields: \\['url'\\]" - ): - Article.before_write(None, None, entry) - - -def test_data_entry_id_fields_url_empty_url(): - dataset = AlignmentDataset(name="blaa", id_fields=["url"]) - entry = dataset.make_data_entry({"url": ""}) - with pytest.raises( - AssertionError, match="Entry is missing the following fields: \\['url'\\]" - ): - Article.before_write(None, None, entry) - - def test_data_entry_id_fields_url(): dataset = AlignmentDataset(name="blaa", id_fields=["url"]) entry = dataset.make_data_entry({"url": "https://www.google.ca/once_upon_a_time"}) From e47d8b0b78f5f2466b86897267ed55b9f9dbe339 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Mon, 14 Aug 2023 21:00:59 +0200 Subject: [PATCH 4/7] Handle Arxiv links in special docs (#131) * Handle Arxiv links in special docs * Handle retracted arxiv articles * rename Retracted -> Withdrawn --- align_data/__init__.py | 2 - align_data/common/alignment_dataset.py | 5 +- align_data/sources/articles/__init__.py | 8 +- align_data/sources/articles/articles.py | 2 +- align_data/sources/articles/datasets.py | 46 ++++- align_data/sources/articles/google_cloud.py | 10 +- align_data/sources/articles/parsers.py | 10 +- align_data/sources/articles/pdf.py | 4 +- 
align_data/sources/arxiv_papers/__init__.py | 9 - .../sources/arxiv_papers/arxiv_papers.py | 143 ++++++++------- tests/align_data/articles/test_datasets.py | 169 ++++++++++++++++++ .../align_data/articles/test_google_cloud.py | 12 +- tests/align_data/test_arxiv.py | 75 +++----- 13 files changed, 341 insertions(+), 154 deletions(-) diff --git a/align_data/__init__.py b/align_data/__init__.py index 9f6c9893..54041500 100644 --- a/align_data/__init__.py +++ b/align_data/__init__.py @@ -2,7 +2,6 @@ import align_data.sources.articles as articles import align_data.sources.blogs as blogs import align_data.sources.ebooks as ebooks -import align_data.sources.arxiv_papers as arxiv_papers import align_data.sources.greaterwrong as greaterwrong import align_data.sources.stampy as stampy import align_data.sources.alignment_newsletter as alignment_newsletter @@ -14,7 +13,6 @@ + articles.ARTICLES_REGISTRY + blogs.BLOG_REGISTRY + ebooks.EBOOK_REGISTRY - + arxiv_papers.ARXIV_REGISTRY + greaterwrong.GREATERWRONG_REGISTRY + stampy.STAMPY_REGISTRY + distill.DISTILL_REGISTRY diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 6bbb2109..b129407a 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -209,7 +209,7 @@ def fetch_entries(self): if self.COOLDOWN: time.sleep(self.COOLDOWN) - def process_entry(self, entry): + def process_entry(self, entry) -> Optional[Article]: """Process a single entry.""" raise NotImplementedError @@ -217,7 +217,8 @@ def process_entry(self, entry): def _format_datetime(date) -> str: return date.strftime("%Y-%m-%dT%H:%M:%SZ") - def _get_published_date(self, date) -> Optional[datetime]: + @staticmethod + def _get_published_date(date) -> Optional[datetime]: try: # Totally ignore any timezone info, forcing everything to UTC return parse(str(date)).replace(tzinfo=pytz.UTC) diff --git a/align_data/sources/articles/__init__.py b/align_data/sources/articles/__init__.py index 7e9fdbde..da7f3a6b 100644 --- a/align_data/sources/articles/__init__.py +++ b/align_data/sources/articles/__init__.py @@ -1,5 +1,6 @@ from align_data.sources.articles.datasets import ( - EbookArticles, DocArticles, HTMLArticles, MarkdownArticles, PDFArticles, SpecialDocs, XMLArticles + ArxivPapers, EbookArticles, DocArticles, HTMLArticles, + MarkdownArticles, PDFArticles, SpecialDocs, XMLArticles ) from align_data.sources.articles.indices import IndicesDataset @@ -39,5 +40,10 @@ spreadsheet_id='1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI', sheet_id='980957638', ), + ArxivPapers( + name="arxiv", + spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI", + sheet_id="655836697", + ), IndicesDataset('indices'), ] diff --git a/align_data/sources/articles/articles.py b/align_data/sources/articles/articles.py index 7485ce9e..9f16da77 100644 --- a/align_data/sources/articles/articles.py +++ b/align_data/sources/articles/articles.py @@ -65,7 +65,7 @@ def process_row(row, sheets): row.set_status(error) return - data_source = contents.get("data_source") + data_source = contents.get("source_type") if data_source not in sheets: error = "Unhandled data type" logger.error(error) diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py index ceec27f5..6b223b61 100644 --- a/align_data/sources/articles/datasets.py +++ b/align_data/sources/articles/datasets.py @@ -2,6 +2,7 @@ import os from dataclasses import dataclass from pathlib import Path +from typing import Dict from urllib.parse import urlparse 
import pandas as pd @@ -13,8 +14,11 @@ from align_data.common.alignment_dataset import AlignmentDataset from align_data.db.models import Article from align_data.sources.articles.google_cloud import fetch_file, fetch_markdown -from align_data.sources.articles.parsers import HTML_PARSERS, extract_gdrive_contents, item_metadata +from align_data.sources.articles.parsers import ( + HTML_PARSERS, extract_gdrive_contents, item_metadata, parse_domain +) from align_data.sources.articles.pdf import read_pdf +from align_data.sources.arxiv_papers.arxiv_papers import fetch as fetch_arxiv logger = logging.getLogger(__name__) @@ -81,21 +85,30 @@ def _query_items(self): special_docs_types = ["pdf", "html", "xml", "markdown", "docx"] return select(Article).where(Article.source.in_(special_docs_types)) - def process_entry(self, item): + def get_contents(self, item) -> Dict: metadata = {} if url := self.maybe(item.source_url) or self.maybe(item.url): metadata = item_metadata(url) - return self.make_data_entry({ - 'source': metadata.get('data_source') or self.name, + return { 'url': self.maybe(item.url), 'title': self.maybe(item.title) or metadata.get('title'), + 'source': metadata.get('source_type') or self.name, 'source_type': self.maybe(item.source_type), 'date_published': self._get_published_date(item.date_published) or metadata.get('date_published'), 'authors': self.extract_authors(item) or metadata.get('authors', []), 'text': metadata.get('text'), 'status': metadata.get('error'), - }) + } + + def process_entry(self, item): + if parse_domain(item.url) == "arxiv.org": + contents = ArxivPapers.get_contents(item) + contents['source'] = 'arxiv' + else: + contents = self.get_contents(item) + + return self.make_data_entry(contents) class PDFArticles(SpreadsheetDataset): @@ -175,3 +188,26 @@ def _get_text(self, item): file_id = item.source_url.split("/")[-2] file_name = fetch_file(file_id) return convert_file(file_name, "md", format="docx", extra_args=["--wrap=none"]) + + +class ArxivPapers(SpreadsheetDataset): + COOLDOWN: int = 1 + + @classmethod + def get_contents(cls, item) -> Dict: + contents = fetch_arxiv(item.url or item.source_url) + + if cls.maybe(item.authors) and item.authors.strip(): + contents['authors'] = [i.strip() for i in item.authors.split(',')] + if cls.maybe(item.title): + contents['title'] = cls.maybe(item.title) + + contents['date_published'] = cls._get_published_date( + cls.maybe(item.date_published) or contents.get('date_published') + ) + return contents + + def process_entry(self, item): + logger.info(f"Processing {item.title}") + + return self.make_data_entry(self.get_contents(item), source=self.name) diff --git a/align_data/sources/articles/google_cloud.py b/align_data/sources/articles/google_cloud.py index 6cd9e337..b1e957f8 100644 --- a/align_data/sources/articles/google_cloud.py +++ b/align_data/sources/articles/google_cloud.py @@ -143,7 +143,7 @@ def fetch_markdown(file_id): file_name = fetch_file(file_id) return { "text": Path(file_name).read_text(), - "data_source": "markdown", + "source_type": "markdown", } except Exception as e: return {'error': str(e)} @@ -156,7 +156,7 @@ def parse_grobid(contents): if not doc_dict.get('body'): return { 'error': 'No contents in XML file', - 'data_source': 'xml', + 'source_type': 'xml', } return { @@ -164,7 +164,7 @@ def parse_grobid(contents): "abstract": doc_dict.get("abstract"), "text": doc_dict["body"], "authors": list(filter(None, authors)), - "data_source": "xml", + "source_type": "xml", } @@ -198,7 +198,7 @@ def 
extract_gdrive_contents(link): elif content_type & {'text/markdown'}: result.update(fetch_markdown(file_id)) elif content_type & {'application/epub+zip', 'application/epub'}: - result['data_source'] = 'ebook' + result['source_type'] = 'ebook' elif content_type & {'text/html'}: res = fetch(url) if 'Google Drive - Virus scan warning' in res.text: @@ -213,7 +213,7 @@ def extract_gdrive_contents(link): soup = BeautifulSoup(res.content, "html.parser") result.update({ 'text': MarkdownConverter().convert_soup(soup.select_one('body')).strip(), - 'data_source': 'html', + 'source_type': 'html', }) else: result['error'] = f'unknown content type: {content_type}' diff --git a/align_data/sources/articles/parsers.py b/align_data/sources/articles/parsers.py index 85d23fe8..42c25c9f 100644 --- a/align_data/sources/articles/parsers.py +++ b/align_data/sources/articles/parsers.py @@ -250,8 +250,12 @@ def getter(url): } +def parse_domain(url: str) -> str: + return url and urlparse(url).netloc.lstrip('www.') + + def item_metadata(url) -> Dict[str, str]: - domain = urlparse(url).netloc.lstrip('www.') + domain = parse_domain(url) try: res = fetch(url, 'head') except (MissingSchema, InvalidSchema, ConnectionError) as e: @@ -265,7 +269,7 @@ def item_metadata(url) -> Dict[str, str]: if parser := HTML_PARSERS.get(domain): if res := parser(url): # Proper contents were found on the page, so use them - return {'source_url': url, 'data_source': 'html', 'text': res} + return {'source_url': url, 'source_type': 'html', 'text': res} if parser := PDF_PARSERS.get(domain): if res := parser(url): @@ -286,6 +290,6 @@ def item_metadata(url) -> Dict[str, str]: elif content_type & {"application/epub+zip", "application/epub"}: # it looks like an ebook. Assume it's fine. # TODO: validate that the ebook is readable - return {"source_url": url, "data_source": "ebook"} + return {"source_url": url, "source_type": "ebook"} else: return {"error": f"Unhandled content type: {content_type}"} diff --git a/align_data/sources/articles/pdf.py b/align_data/sources/articles/pdf.py index 9db52b9b..aca627f1 100644 --- a/align_data/sources/articles/pdf.py +++ b/align_data/sources/articles/pdf.py @@ -66,7 +66,7 @@ def fetch_pdf(link): return { "source_url": link, "text": "\n".join(page.extract_text() for page in pdf_reader.pages), - "data_source": "pdf", + "source_type": "pdf", } except (TypeError, PdfReadError) as e: logger.error('Could not read PDF file: %s', e) @@ -170,5 +170,5 @@ def get_first_child(item): "authors": authors, "text": text, "date_published": date_published, - "data_source": "html", + "source_type": "html", } diff --git a/align_data/sources/arxiv_papers/__init__.py b/align_data/sources/arxiv_papers/__init__.py index 29258480..e69de29b 100644 --- a/align_data/sources/arxiv_papers/__init__.py +++ b/align_data/sources/arxiv_papers/__init__.py @@ -1,9 +0,0 @@ -from .arxiv_papers import ArxivPapers - -ARXIV_REGISTRY = [ - ArxivPapers( - name="arxiv", - spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI", - sheet_id="655836697", - ) -] diff --git a/align_data/sources/arxiv_papers/arxiv_papers.py b/align_data/sources/arxiv_papers/arxiv_papers.py index 42f6dcaa..04bb85b8 100644 --- a/align_data/sources/arxiv_papers/arxiv_papers.py +++ b/align_data/sources/arxiv_papers/arxiv_papers.py @@ -1,75 +1,86 @@ import logging import re -from dataclasses import dataclass +from typing import Dict, Optional import arxiv -from align_data.sources.articles.datasets import SpreadsheetDataset from align_data.sources.articles.pdf import fetch_pdf, 
parse_vanity +from align_data.sources.articles.html import fetch_element logger = logging.getLogger(__name__) -@dataclass -class ArxivPapers(SpreadsheetDataset): - summary_key: str = "summary" - COOLDOWN: int = 1 - done_key = "url" - batch_size = 1 - - def _get_arxiv_metadata(self, paper_id) -> arxiv.Result: - """ - Get metadata from arxiv - """ - try: - search = arxiv.Search(id_list=[paper_id], max_results=1) - return next(search.results()) - except Exception as e: - logger.error(e) - return None - - def get_id(self, item): - if res := re.search(r"https://arxiv.org/abs/(.*?)/?$", item.url): - return res.group(1) - - def get_contents(self, item) -> dict: - paper_id = self.get_id(item) - for link in [ - f"https://www.arxiv-vanity.com/papers/{paper_id}", - f"https://ar5iv.org/abs/{paper_id}", - ]: - if contents := parse_vanity(link): - return contents - return fetch_pdf(f"https://arxiv.org/pdf/{paper_id}.pdf") - - def process_entry(self, item) -> None: - logger.info(f"Processing {item.title}") - - paper = self.get_contents(item) - if not paper or not paper.get("text"): - return None - - metadata = self._get_arxiv_metadata(self.get_id(item)) - if self.maybe(item.authors) and item.authors.strip(): - authors = item.authors.split(',') - elif metadata and metadata.authors: - authors = metadata.authors - else: - authors = paper.get("authors") or [] - authors = [str(a).strip() for a in authors] - - return self.make_data_entry({ - "url": self.get_item_key(item), - "source": self.name, - "source_type": paper['data_source'], - "title": self.maybe(item.title) or paper.get('title'), - "authors": authors, - "date_published": self._get_published_date(self.maybe(item.date_published) or paper.get('date_published')), - "data_last_modified": str(metadata.updated), - "summary": metadata.summary.replace("\n", " "), - "author_comment": metadata.comment, - "journal_ref": metadata.journal_ref, - "doi": metadata.doi, - "primary_category": metadata.primary_category, - "categories": metadata.categories, - "text": paper['text'], - }) +def get_arxiv_metadata(paper_id) -> arxiv.Result: + """ + Get metadata from arxiv + """ + try: + search = arxiv.Search(id_list=[paper_id], max_results=1) + return next(search.results()) + except Exception as e: + logger.error(e) + return None + + +def get_id(url: str) -> Optional[str]: + if res := re.search(r"https?://arxiv.org/(?:abs|pdf)/(.*?)(?:v\d+)?(?:/|\.pdf)?$", url): + return res.group(1) + + +def canonical_url(url: str) -> str: + if paper_id := get_id(url): + return f'https://arxiv.org/abs/{paper_id}' + return url + + +def get_contents(paper_id: str) -> dict: + for link in [ + f"https://www.arxiv-vanity.com/papers/{paper_id}", + f"https://ar5iv.org/abs/{paper_id}", + ]: + if contents := parse_vanity(link): + return contents + return fetch_pdf(f"https://arxiv.org/pdf/{paper_id}.pdf") + + +def get_version(id: str) -> Optional[str]: + if res := re.search(r'.*v(\d+)$', id): + return res.group(1) + + +def is_withdrawn(url: str): + if elem := fetch_element(canonical_url(url), '.extra-services .full-text ul'): + return elem.text.strip().lower() == 'withdrawn' + return None + + +def fetch(url) -> Dict: + paper_id = get_id(url) + if not paper_id: + return {'error': 'Could not extract arxiv id'} + + metadata = get_arxiv_metadata(paper_id) + + if is_withdrawn(url): + paper = {'status': 'Withdrawn'} + else: + paper = get_contents(paper_id) + if metadata and metadata.authors: + authors = metadata.authors + else: + authors = paper.get("authors") or [] + authors = [str(a).strip() for a in 
authors] + + return dict({ + "title": metadata.title, + "url": canonical_url(url), + "authors": authors, + "date_published": metadata.published, + "data_last_modified": metadata.updated.isoformat(), + "summary": metadata.summary.replace("\n", " "), + "comment": metadata.comment, + "journal_ref": metadata.journal_ref, + "doi": metadata.doi, + "primary_category": metadata.primary_category, + "categories": metadata.categories, + "version": get_version(metadata.get_short_id()), + }, **paper) diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py index eed98094..11a02816 100644 --- a/tests/align_data/articles/test_datasets.py +++ b/tests/align_data/articles/test_datasets.py @@ -1,14 +1,17 @@ +from datetime import datetime from unittest.mock import Mock, patch import pandas as pd import pytest from align_data.sources.articles.datasets import ( + ArxivPapers, EbookArticles, DocArticles, HTMLArticles, MarkdownArticles, PDFArticles, SpreadsheetDataset, + SpecialDocs, XMLArticles, ) @@ -32,6 +35,26 @@ def articles(): return pd.DataFrame(articles) +@pytest.fixture +def mock_arxiv(): + metadata = Mock( + summary="abstract bla bla", + comment="no comment", + categories="wut", + updated=datetime.fromisoformat("2023-01-01T00:00:00"), + authors=[], + doi="123", + journal_ref="sdf", + primary_category="cat", + ) + metadata.get_short_id.return_value = '2001.11038' + arxiv = Mock() + arxiv.Search.return_value.results.return_value = iter([metadata]) + + with patch("align_data.sources.arxiv_papers.arxiv_papers.arxiv", arxiv): + yield + + def test_spreadsheet_dataset_items_list(articles): dataset = SpreadsheetDataset(name="bla", spreadsheet_id="123", sheet_id="456") df = pd.concat( @@ -288,3 +311,149 @@ def test_doc_articles_process_entry(articles): "title": "article no 0", "url": "http://example.com/item/0", } + + +@patch('requests.get', return_value=Mock(content='')) +def test_arxiv_process_entry(_, mock_arxiv): + dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da") + item = Mock( + title="this is the title", + url="https://arxiv.org/abs/2001.11038", + authors="", + date_published="2020-01-29", + ) + contents = { + "text": "this is the text", + "date_published": "December 12, 2021", + "authors": ["mr blobby"], + "source_type": "html", + } + with patch( + "align_data.sources.arxiv_papers.arxiv_papers.parse_vanity", return_value=contents + ): + assert dataset.process_entry(item).to_dict() == { + "comment": "no comment", + "authors": ["mr blobby"], + "categories": "wut", + "data_last_modified": "2023-01-01T00:00:00", + "date_published": "2020-01-29T00:00:00Z", + "doi": "123", + "id": None, + "journal_ref": "sdf", + "primary_category": "cat", + "source": "asd", + "source_type": "html", + "summaries": ["abstract bla bla"], + "text": "this is the text", + "title": "this is the title", + "url": "https://arxiv.org/abs/2001.11038", + } + + +def test_arxiv_process_entry_retracted(mock_arxiv): + dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da") + item = Mock( + title="this is the title", + url="https://arxiv.org/abs/2001.11038", + authors="", + date_published="2020-01-29", + ) + response = """ +
+        <div class="extra-services">
+          <div class="full-text">
+            <a>Full-text links:</a>
+            <h2>Download:</h2>
+            <ul>
+              <li>Withdrawn</li>
+            </ul>
+          </div>
+        </div>
+ """ + + with patch('requests.get', return_value=Mock(content=response)): + assert dataset.process_entry(item).to_dict() == { + "comment": "no comment", + "authors": [], + "categories": "wut", + "data_last_modified": "2023-01-01T00:00:00", + "date_published": "2020-01-29T00:00:00Z", + "doi": "123", + "id": None, + "journal_ref": "sdf", + "primary_category": "cat", + "source": "asd", + "source_type": None, + "summaries": ["abstract bla bla"], + "title": "this is the title", + "url": "https://arxiv.org/abs/2001.11038", + "status": "Withdrawn", + "text": None, + } + + +def test_special_docs_process_entry(): + dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da") + item = Mock( + title="this is the title", + url="https://bla.bla.bla", + authors="mr. blobby", + date_published="2023-10-02T01:23:45", + source_type=None, + ) + contents = { + "text": "this is the text", + "date_published": "December 12, 2021", + "authors": ["mr blobby"], + "source_type": "html", + } + + with patch("align_data.sources.articles.datasets.item_metadata", return_value=contents): + assert dataset.process_entry(item).to_dict() == { + 'authors': ['mr. blobby'], + 'date_published': '2023-10-02T01:23:45Z', + 'id': None, + 'source': 'html', + 'source_type': None, + 'summaries': [], + 'text': 'this is the text', + 'title': 'this is the title', + 'url': 'https://bla.bla.bla', + } + + +@patch('requests.get', return_value=Mock(content='')) +def test_special_docs_process_entry_arxiv(_, mock_arxiv): + dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da") + item = Mock( + title="this is the title", + url="https://arxiv.org/abs/2001.11038", + authors="", + date_published="2020-01-29", + ) + contents = { + "text": "this is the text", + "date_published": "December 12, 2021", + "authors": ["mr blobby"], + "source_type": "pdf", + } + + with patch( + "align_data.sources.arxiv_papers.arxiv_papers.parse_vanity", return_value=contents + ): + assert dataset.process_entry(item).to_dict() == { + "comment": "no comment", + "authors": ["mr blobby"], + "categories": "wut", + "data_last_modified": "2023-01-01T00:00:00", + "date_published": "2020-01-29T00:00:00Z", + "doi": "123", + "id": None, + "journal_ref": "sdf", + "primary_category": "cat", + "source": "arxiv", + "source_type": "pdf", + "summaries": ["abstract bla bla"], + "text": "this is the text", + "title": "this is the title", + "url": "https://arxiv.org/abs/2001.11038", + } diff --git a/tests/align_data/articles/test_google_cloud.py b/tests/align_data/articles/test_google_cloud.py index 3232978a..7b268e43 100644 --- a/tests/align_data/articles/test_google_cloud.py +++ b/tests/align_data/articles/test_google_cloud.py @@ -78,7 +78,7 @@ def test_parse_grobid(): 'authors': ['Cullen Oâ\x80\x99Keefe'], 'text': 'This is the contents', 'title': 'The title!!', - 'data_source': 'xml', + 'source_type': 'xml', } @@ -100,7 +100,7 @@ def test_parse_grobid_no_body(): """ - assert parse_grobid(xml) == {'error': 'No contents in XML file', 'data_source': 'xml'} + assert parse_grobid(xml) == {'error': 'No contents in XML file', 'source_type': 'xml'} @pytest.mark.parametrize('header, expected', ( @@ -160,7 +160,7 @@ def test_extract_gdrive_contents_ebook(header): assert extract_gdrive_contents(url) == { 'downloaded_from': 'google drive', 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'data_source': 'ebook', + 'source_type': 'ebook', } @@ -185,7 +185,7 @@ def test_extract_gdrive_contents_html(): 'downloaded_from': 'google 
drive', 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', 'text': 'bla bla', - 'data_source': 'html', + 'source_type': 'html', } @@ -207,7 +207,7 @@ def test_extract_gdrive_contents_xml(): 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', 'text': 'This is the contents', 'title': 'The title!!', - 'data_source': 'xml', + 'source_type': 'xml', } @@ -238,7 +238,7 @@ def fetcher(link, *args, **kwargs): 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', 'text': 'This is the contents', 'title': 'The title!!', - 'data_source': 'xml', + 'source_type': 'xml', } diff --git a/tests/align_data/test_arxiv.py b/tests/align_data/test_arxiv.py index 5817fd2c..898fef70 100644 --- a/tests/align_data/test_arxiv.py +++ b/tests/align_data/test_arxiv.py @@ -1,7 +1,5 @@ -from datetime import datetime -from unittest.mock import patch, Mock import pytest -from align_data.sources.arxiv_papers.arxiv_papers import ArxivPapers +from align_data.sources.arxiv_papers.arxiv_papers import get_id, canonical_url, get_version @pytest.mark.parametrize( @@ -13,55 +11,28 @@ ), ) def test_get_id(url, expected): - dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da") - assert dataset.get_id(Mock(url="https://arxiv.org/abs/2001.11038")) == "2001.11038" + assert get_id("https://arxiv.org/abs/2001.11038") == "2001.11038" -def test_process_entry(): - dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da") - item = Mock( - title="this is the title", - url="https://arxiv.org/abs/2001.11038", - authors="", - date_published="2020-01-29", - ) - contents = { - "text": "this is the text", - "date_published": "December 12, 2021", - "authors": ["mr blobby"], - "data_source": "html", - } - metadata = Mock( - summary="abstract bla bla", - comment="no comment", - categories="wut", - updated="2023-01-01", - authors=[], - doi="123", - journal_ref="sdf", - primary_category="cat", - ) - arxiv = Mock() - arxiv.Search.return_value.results.return_value = iter([metadata]) +@pytest.mark.parametrize('url, expected', ( + ("http://bla.bla", "http://bla.bla"), + ("http://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/abs/2001.11038/", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/pdf/2001.11038", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/pdf/2001.11038.pdf", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/pdf/2001.11038v3.pdf", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/abs/math/2001.11038", "https://arxiv.org/abs/math/2001.11038"), +)) +def test_canonical_url(url, expected): + assert canonical_url(url) == expected - with patch( - "align_data.arxiv_papers.arxiv_papers.parse_vanity", return_value=contents - ): - with patch("align_data.arxiv_papers.arxiv_papers.arxiv", arxiv): - assert dataset.process_entry(item).to_dict() == { - "author_comment": "no comment", - "authors": ["mr blobby"], - "categories": "wut", - "data_last_modified": "2023-01-01", - "date_published": "2020-01-29T00:00:00Z", - "doi": "123", - "id": None, - "journal_ref": "sdf", - "primary_category": "cat", - "source": "asd", - "source_type": "html", - "summaries": ["abstract bla bla"], - "text": "this is the text", - "title": "this is the title", - "url": "https://arxiv.org/abs/2001.11038", - } + +@pytest.mark.parametrize('id, version', ( + 
('123.123', None), + ('math/312', None), + ('3123123v1', '1'), + ('3123123v123', '123'), +)) +def test_get_version(id, version): + assert get_version(id) == version From 327dd63ba07e4fab323a444988f21272c7b24f86 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Mon, 14 Aug 2023 21:49:06 +0200 Subject: [PATCH 5/7] Set status for missing fields (#133) * check for missing fields * Do not accept items with no id keys --- align_data/common/alignment_dataset.py | 2 + align_data/db/models.py | 18 ++- align_data/sources/articles/datasets.py | 3 +- .../sources/arxiv_papers/arxiv_papers.py | 1 + tests/align_data/articles/test_datasets.py | 5 +- .../common/test_alignment_dataset.py | 135 ++++++++---------- 6 files changed, 76 insertions(+), 88 deletions(-) diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index b129407a..31f254e7 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -25,6 +25,8 @@ "url": None, "authors": lambda: [], "source_type": None, + "status": None, + "comments": None, } logger = logging.getLogger(__name__) diff --git a/align_data/db/models.py b/align_data/db/models.py index b3da7042..3114a098 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -93,11 +93,8 @@ def generate_id_string(self) -> str: @property def missing_fields(self): - return [field for field in self.__id_fields if not getattr(self, field)] - - def verify_fields(self): - missing = self.missing_fields - assert not missing, f"Entry is missing the following fields: {missing}" + fields = set(self.__id_fields) | {'text', 'title', 'url', 'source', 'date_published'} + return sorted([field for field in fields if not getattr(self, field, None)]) def verify_id(self): assert self.id is not None, "Entry is missing id" @@ -106,7 +103,11 @@ def verify_id(self): id_from_fields = hashlib.md5(id_string).hexdigest() assert ( self.id == id_from_fields - ), f"Entry id {self.id} does not match id from id_fields, {id_from_fields}" + ), f"Entry id {self.id} does not match id from id_fields: {id_from_fields}" + + def verify_id_fields(self): + missing = [field for field in self.__id_fields if not getattr(self, field)] + assert not missing, f"Entry is missing the following fields: {missing}" def update(self, other): for field in self.__table__.columns.keys(): @@ -125,8 +126,11 @@ def _set_id(self): @classmethod def before_write(cls, mapper, connection, target): + target.verify_id_fields() + if not target.status and target.missing_fields: - target.status = f'missing fields: {", ".join(target.missing_fields)}' + target.status = 'Missing fields' + target.comments = f'missing fields: {", ".join(target.missing_fields)}' if target.id: target.verify_id() diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py index 6b223b61..20b0a5f6 100644 --- a/align_data/sources/articles/datasets.py +++ b/align_data/sources/articles/datasets.py @@ -98,7 +98,8 @@ def get_contents(self, item) -> Dict: 'date_published': self._get_published_date(item.date_published) or metadata.get('date_published'), 'authors': self.extract_authors(item) or metadata.get('authors', []), 'text': metadata.get('text'), - 'status': metadata.get('error'), + 'status': 'Invalid' if metadata.get('error') else None, + 'comments': metadata.get('error'), } def process_entry(self, item): diff --git a/align_data/sources/arxiv_papers/arxiv_papers.py b/align_data/sources/arxiv_papers/arxiv_papers.py index 04bb85b8..3ec3fbe5 100644 --- 
a/align_data/sources/arxiv_papers/arxiv_papers.py +++ b/align_data/sources/arxiv_papers/arxiv_papers.py @@ -64,6 +64,7 @@ def fetch(url) -> Dict: paper = {'status': 'Withdrawn'} else: paper = get_contents(paper_id) + if metadata and metadata.authors: authors = metadata.authors else: diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py index 11a02816..94f26172 100644 --- a/tests/align_data/articles/test_datasets.py +++ b/tests/align_data/articles/test_datasets.py @@ -371,7 +371,9 @@ def test_arxiv_process_entry_retracted(mock_arxiv): """ with patch('requests.get', return_value=Mock(content=response)): - assert dataset.process_entry(item).to_dict() == { + article = dataset.process_entry(item) + assert article.status == 'Withdrawn' + assert article.to_dict() == { "comment": "no comment", "authors": [], "categories": "wut", @@ -386,7 +388,6 @@ def test_arxiv_process_entry_retracted(mock_arxiv): "summaries": ["abstract bla bla"], "title": "this is the title", "url": "https://arxiv.org/abs/2001.11038", - "status": "Withdrawn", "text": None, } diff --git a/tests/align_data/common/test_alignment_dataset.py b/tests/align_data/common/test_alignment_dataset.py index 8d19f66b..879e6295 100644 --- a/tests/align_data/common/test_alignment_dataset.py +++ b/tests/align_data/common/test_alignment_dataset.py @@ -76,25 +76,41 @@ def test_data_entry_id_from_urls_and_title(): @pytest.mark.parametrize('item, error', ( - ({"key1": 12, "key2": 312}, 'missing fields: url, title'), ( - {"key1": 12, "key2": 312, "title": "wikipedia goes to war on porcupines"}, - 'missing fields: url' + {"key1": 12, "key2": 312, "title": "wikipedia goes to war on porcupines", "url": "asd"}, + 'missing fields: date_published, source, text' ), - ({"key1": 12, "key2": 312, "url": None}, 'missing fields: url, title'), ( - {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": None}, - 'missing fields: title' + {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "text": "asdasd", "title": "asdasd"}, + 'missing fields: date_published, source' + ), + ( + { + "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", + "source": "dwe", "date_published": "dwe" + }, + 'missing fields: text' + ), + ( + { + "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", + "text": "asdasd", "date_published": "dwe" + }, + 'missing fields: source' + ), + ( + { + "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", "text": "asdasd", "source": "dwe" + }, + 'missing fields: date_published' ), - ({"key1": 12, "key2": 312, "url": "", "title": ""}, 'missing fields: url, title'), - ({"key1": 12, "key2": 312, "url": "", "title": "once upon a time"}, 'missing fields: url'), - ({"key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": ""}, 'missing fields: title'), )) def test_data_entry_missing(item, error): dataset = AlignmentDataset(name="blaa") entry = dataset.make_data_entry(item) Article.before_write(None, None, entry) - assert entry.status == error + assert entry.status == 'Missing fields' + assert entry.comments == error def test_data_entry_verify_id_passes(): @@ -120,10 +136,40 @@ def test_data_entry_verify_id_fails(): "id": "f2b4e02fc1dd8ae43845e4f930f2d84f", } ) - with pytest.raises(AssertionError, match="Entry id does not match id_fields"): + expected = 'Entry id f2b4e02fc1dd8ae43845e4f930f2d84f does not match id from id_fields: 770fe57c8c2130eda08dc392b8696f97' + with pytest.raises(AssertionError, match=expected): entry.verify_id() 
+@pytest.mark.parametrize( + "data, error", + ( + ({"id": "123"}, "Entry is missing the following fields: \\['url', 'title'\\]"), + ( + {"id": "123", "url": None}, + "Entry is missing the following fields: \\['url', 'title'\\]", + ), + ( + {"id": "123", "url": "www.google.com/"}, + "Entry is missing the following fields: \\['title'\\]", + ), + ( + {"id": "123", "url": "google", "text": None}, + "Entry is missing the following fields: \\['title'\\]", + ), + ( + {"id": "123", "url": "", "title": ""}, + "Entry is missing the following fields: \\['url', 'title'\\]", + ), + ), +) +def test_data_entry_verify_fields_fails(data, error): + dataset = AlignmentDataset(name="blaa", id_fields=["url", "title"]) + entry = dataset.make_data_entry(data) + with pytest.raises(AssertionError, match=error): + entry.verify_id_fields() + + def test_data_entry_id_fields_url(): dataset = AlignmentDataset(name="blaa", id_fields=["url"]) entry = dataset.make_data_entry({"url": "https://www.google.ca/once_upon_a_time"}) @@ -153,73 +199,6 @@ def test_data_entry_different_id_from_different_url(): assert entry1.generate_id_string() != entry2.generate_id_string() -@pytest.mark.parametrize( - "data, error", - ( - ({"text": "bla bla bla"}, "Entry is missing id"), - ({"text": "bla bla bla", "id": None}, "Entry is missing id"), - ( - { - "id": "123", - "url": "www.google.com/winter_wonderland", - "title": "winter wonderland", - }, - "Entry id 123 does not match id from id_fields, [0-9a-fA-F]{32}", - ), - ( - { - "id": "457c21e0ecabebcb85c12022d481d9f4", - "url": "www.google.com", - "title": "winter wonderland", - }, - "Entry id [0-9a-fA-F]{32} does not match id from id_fields, [0-9a-fA-F]{32}", - ), - ( - { - "id": "457c21e0ecabebcb85c12022d481d9f4", - "url": "www.google.com", - "title": "Once upon a time", - }, - "Entry id [0-9a-fA-F]{32} does not match id from id_fields, [0-9a-fA-F]{32}", - ), - ), -) -def test_data_entry_verify_id_fails(data, error): - dataset = AlignmentDataset(name="blaa", id_fields=["url", "title"]) - entry = dataset.make_data_entry(data) - with pytest.raises(AssertionError, match=error): - entry.verify_id() - - -@pytest.mark.parametrize( - "data, error", - ( - ({"id": "123"}, "Entry is missing the following fields: \\['url', 'title'\\]"), - ( - {"id": "123", "url": None}, - "Entry is missing the following fields: \\['url', 'title'\\]", - ), - ( - {"id": "123", "url": "www.google.com/"}, - "Entry is missing the following fields: \\['title'\\]", - ), - ( - {"id": "123", "url": "google", "text": None}, - "Entry is missing the following fields: \\['title'\\]", - ), - ( - {"id": "123", "url": "", "title": ""}, - "Entry is missing the following fields: \\['url', 'title'\\]", - ), - ), -) -def test_data_entry_verify_fields_fails(data, error): - dataset = AlignmentDataset(name="blaa", id_fields=["url", "title"]) - entry = dataset.make_data_entry(data) - with pytest.raises(AssertionError, match=error): - entry.verify_fields() - - @pytest.fixture def numbers_dataset(): """Make a dataset that raises its items to the power of 2.""" From dacbc3432344d50338287a1bdb5e8d69e94763e7 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Mon, 14 Aug 2023 23:21:31 +0200 Subject: [PATCH 6/7] Allow manual batch metadata updates (#129) * Allow manual batch metadata updates * save source url --- .github/workflows/update-metadata.yml | 34 ++++ align_data/common/alignment_dataset.py | 13 +- align_data/db/models.py | 5 + align_data/sources/articles/articles.py | 6 + align_data/sources/articles/datasets.py | 42 +++-- 
align_data/sources/articles/updater.py | 90 ++++++++++ .../sources/arxiv_papers/arxiv_papers.py | 1 + main.py | 14 +- tests/align_data/articles/test_datasets.py | 10 +- tests/align_data/articles/test_updater.py | 162 ++++++++++++++++++ 10 files changed, 347 insertions(+), 30 deletions(-) create mode 100644 .github/workflows/update-metadata.yml create mode 100644 align_data/sources/articles/updater.py create mode 100644 tests/align_data/articles/test_updater.py diff --git a/.github/workflows/update-metadata.yml b/.github/workflows/update-metadata.yml new file mode 100644 index 00000000..70e2ddd5 --- /dev/null +++ b/.github/workflows/update-metadata.yml @@ -0,0 +1,34 @@ +name: Update metadata +on: + workflow_dispatch: + inputs: + csv_url: + description: 'URL of CSV' + required: true + delimiter: + description: 'The column delimiter' + default: ',' + +jobs: + update: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' + + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Download CSV + id: download + run: curl -L "${{ inputs.csv_url }}" -o data.csv + + - name: Run Script + run: python main.py update data.csv ${{ inputs.delimiter }} diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 31f254e7..6e9cd731 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -116,10 +116,10 @@ def _query_items(self): def read_entries(self, sort_by=None): """Iterate through all the saved entries.""" with make_session() as session: - query = self._query_items + query = self._query_items.options(joinedload(Article.summaries)) if sort_by is not None: query = query.order_by(sort_by) - for item in session.scalars(query): + for item in session.scalars(query).unique(): yield item def _add_batch(self, session, batch): @@ -236,13 +236,8 @@ def unprocessed_items(self, items=None) -> Iterable: urls = map(self.get_item_key, items) with make_session() as session: - self.articles = { - a.url: a - for a in session.query(Article) - .options(joinedload(Article.summaries)) - .filter(Article.url.in_(urls)) - if a.url - } + articles = session.query(Article).options(joinedload(Article.summaries)).filter(Article.url.in_(urls)) + self.articles = {a.url: a for a in articles if a.url} return items diff --git a/align_data/db/models.py b/align_data/db/models.py index 3114a098..157c4340 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -124,6 +124,11 @@ def _set_id(self): id_string = self.generate_id_string() self.id = hashlib.md5(id_string).hexdigest() + def add_meta(self, key, val): + if self.meta is None: + self.meta = {} + self.meta[key] = val + @classmethod def before_write(cls, mapper, connection, target): target.verify_id_fields() diff --git a/align_data/sources/articles/articles.py b/align_data/sources/articles/articles.py index 9f16da77..39cc31b7 100644 --- a/align_data/sources/articles/articles.py +++ b/align_data/sources/articles/articles.py @@ -15,6 +15,7 @@ from align_data.sources.articles.parsers import item_metadata, fetch from align_data.sources.articles.indices import fetch_all from align_data.sources.articles.html import with_retry +from align_data.sources.articles.updater import ReplacerDataset from align_data.settings import PDFS_FOLDER_ID @@ -158,3 +159,8 @@ def check_new_articles(source_spreadsheet, source_sheet): 
updated = res["updates"]["updatedRows"] logger.info("Added %s rows", updated) return updated + + +def update_articles(csv_file, delimiter): + dataset = ReplacerDataset(name='updater', csv_path=csv_file, delimiter=delimiter) + dataset.add_entries(dataset.fetch_entries()) diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py index 20b0a5f6..acfc2d85 100644 --- a/align_data/sources/articles/datasets.py +++ b/align_data/sources/articles/datasets.py @@ -32,20 +32,21 @@ class SpreadsheetDataset(AlignmentDataset): batch_size = 1 @staticmethod - def maybe(val): + def maybe(item, key: str): + val = getattr(item, key, None) if pd.isna(val): return None return val + def get_item_key(self, item): + return self.maybe(item, self.done_key) + @property def items_list(self): url = f'https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}' logger.info(f'Fetching {url}') df = pd.read_csv(url) - return (item for item in df.itertuples() if self.maybe(self.get_item_key(item))) - - def get_item_key(self, item): - return getattr(item, self.done_key) + return (item for item in df.itertuples() if self.get_item_key(item)) @staticmethod def _get_text(item): @@ -53,7 +54,7 @@ def _get_text(item): @staticmethod def extract_authors(item): - if not SpreadsheetDataset.maybe(item.authors): + if not SpreadsheetDataset.maybe(item, "authors"): return [] return [author.strip() for author in item.authors.split(",") if author.strip()] @@ -63,17 +64,21 @@ def process_entry(self, item): logger.error("Could not get text for %s - skipping for now", item.title) return None + url = self.maybe(item, "url") + source_url = self.maybe(item, "source_url") + return self.make_data_entry( { "text": markdownify(text).strip(), - "url": self.maybe(item.url), - "title": self.maybe(item.title), + "url": url, + "source_url": source_url if source_url != url else None, + "title": self.maybe(item, "title"), "source": self.name, - "source_type": self.maybe(item.source_type), + "source_type": self.maybe(item, "source_type"), "source_filetype": self.source_filetype, "date_published": self._get_published_date(item.date_published), "authors": self.extract_authors(item), - "summary": self.maybe(item.summary), + "summary": self.maybe(item, "summary"), } ) @@ -87,14 +92,15 @@ def _query_items(self): def get_contents(self, item) -> Dict: metadata = {} - if url := self.maybe(item.source_url) or self.maybe(item.url): + if url := self.maybe(item, "source_url") or self.maybe(item, "url"): metadata = item_metadata(url) return { - 'url': self.maybe(item.url), - 'title': self.maybe(item.title) or metadata.get('title'), + 'url': self.maybe(item, "url"), + 'title': self.maybe(item, "title") or metadata.get('title'), 'source': metadata.get('source_type') or self.name, - 'source_type': self.maybe(item.source_type), + 'source_url': self.maybe(item, "source_url"), + 'source_type': metadata.get('source_type') or self.maybe(item, "source_type"), 'date_published': self._get_published_date(item.date_published) or metadata.get('date_published'), 'authors': self.extract_authors(item) or metadata.get('authors', []), 'text': metadata.get('text'), @@ -198,13 +204,13 @@ class ArxivPapers(SpreadsheetDataset): def get_contents(cls, item) -> Dict: contents = fetch_arxiv(item.url or item.source_url) - if cls.maybe(item.authors) and item.authors.strip(): + if cls.maybe(item, "authors") and item.authors.strip(): contents['authors'] = [i.strip() for i in item.authors.split(',')] - if cls.maybe(item.title): 
- contents['title'] = cls.maybe(item.title) + if cls.maybe(item, "title"): + contents['title'] = cls.maybe(item, "title") contents['date_published'] = cls._get_published_date( - cls.maybe(item.date_published) or contents.get('date_published') + cls.maybe(item, "date_published") or contents.get('date_published') ) return contents diff --git a/align_data/sources/articles/updater.py b/align_data/sources/articles/updater.py new file mode 100644 index 00000000..dc6a8e39 --- /dev/null +++ b/align_data/sources/articles/updater.py @@ -0,0 +1,90 @@ +import logging +from collections import namedtuple +from dataclasses import dataclass + +import pandas as pd +from sqlalchemy import select, or_ +from align_data.common.alignment_dataset import AlignmentDataset +from align_data.db.models import Article +from align_data.sources.articles.parsers import item_metadata + +logger = logging.getLogger(__name__) + +Item = namedtuple('Item', ['updates', 'article']) + + +@dataclass +class ReplacerDataset(AlignmentDataset): + csv_path: str + delimiter: str + done_key = "url" + + def get_item_key(self, item): + return None + + @staticmethod + def maybe(item, key): + val = getattr(item, key, None) + if pd.isna(val): + return None + return val + + @property + def items_list(self): + df = pd.read_csv(self.csv_path, delimiter=self.delimiter) + self.csv_items = [ + item for item in df.itertuples() + if self.maybe(item, 'id') or self.maybe(item, 'hash_id') + ] + by_id = {i.id: i for i in self.csv_items if self.maybe(i, 'id')} + by_hash_id = {i.hash_id: i for i in self.csv_items if self.maybe(i, 'hash_id')} + + return [ + Item(by_id.get(a._id) or by_hash_id.get(a.id), a) + for a in self.read_entries() + ] + + @property + def _query_items(self): + ids = [i.id for i in self.csv_items if self.maybe(i, 'id')] + hash_ids = [i.hash_id for i in self.csv_items if self.maybe(i, 'hash_id')] + return select(Article).where(or_(Article.id.in_(hash_ids), Article._id.in_(ids))) + + def update_text(self, updates, article): + # If the url is the same as it was before, and there isn't a source url provided, assume that the + # previous text is still valid + if article.url == self.maybe(updates, 'url') and not self.maybe(updates, 'source_url'): + return + + # If no url found, then don't bother fetching the text - assume it was successfully fetched previously + url = self.maybe(updates, 'source_url') or self.maybe(updates, 'url') + if not url: + return + + if article.url != url: + article.add_meta('source_url', url) + + metadata = item_metadata(url) + # Only change the text if it could be fetched - better to have outdated values than none + if metadata.get('text'): + article.text = metadata.get('text') + article.status = metadata.get('error') + + def process_entry(self, item): + updates, article = item + + for key in ['url', 'title', 'source', 'authors', 'comment', 'confidence']: + value = self.maybe(updates, key) + if value and getattr(article, key, None) != value: + setattr(article, key, value) + + if date := getattr(updates, 'date_published', None): + article.date_published = self._get_published_date(date) + + self.update_text(updates, article) + article._set_id() + + return article + + def _add_batch(self, session, batch): + session.add_all(map(session.merge, batch)) diff --git a/align_data/sources/arxiv_papers/arxiv_papers.py b/align_data/sources/arxiv_papers/arxiv_papers.py index 3ec3fbe5..1671fe5a 100644 --- a/align_data/sources/arxiv_papers/arxiv_papers.py +++ b/align_data/sources/arxiv_papers/arxiv_papers.py @@ -74,6 +74,7 @@ def 
fetch(url) -> Dict: return dict({ "title": metadata.title, "url": canonical_url(url), + "source_type": paper.get('data_source'), "authors": authors, "date_published": metadata.published, "data_last_modified": metadata.updated.isoformat(), diff --git a/main.py b/main.py index ae11b641..67478a12 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,7 @@ from align_data import ALL_DATASETS, get_dataset from align_data.analysis.count_tokens import count_token -from align_data.sources.articles.articles import update_new_items, check_new_articles +from align_data.sources.articles.articles import update_new_items, check_new_articles, update_articles from align_data.pinecone.update_pinecone import PineconeUpdater from align_data.settings import ( METADATA_OUTPUT_SPREADSHEET, @@ -32,7 +32,7 @@ def fetch(self, *names) -> None: """ > This function takes a dataset name and writes the entries of that dataset to a file - :param str name: The name of the dataset to fetch + :param str name: The name of the dataset to fetch, or 'all' for all of them :return: The path to the file that was written to. """ if names == ("all",): @@ -81,6 +81,14 @@ def count_tokens(self, merged_dataset_path: str) -> None: ), "The path to the merged dataset does not exist" count_token(merged_dataset_path) + def update(self, csv_path, delimiter=','): + """Update all articles in the provided csv files, overwriting the provided values and fetching new text if a different url provided. + + :param str csv_path: The path to the csv file to be processed + :param str delimiter: Specifies what's used as a column delimiter + """ + update_articles(csv_path, delimiter) + def update_metadata( self, source_spreadsheet=METADATA_SOURCE_SPREADSHEET, @@ -110,6 +118,8 @@ def fetch_new_articles( def pinecone_update(self, *names) -> None: """ This function updates the Pinecone vector DB. 
+ + :param List[str] names: The name of the dataset to update, or 'all' for all of them """ if names == ("all",): names = ALL_DATASETS diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py index 94f26172..737864b3 100644 --- a/tests/align_data/articles/test_datasets.py +++ b/tests/align_data/articles/test_datasets.py @@ -123,6 +123,7 @@ def test_pdf_articles_process_item(articles): "text": "pdf contents [bla](asd.com)", "title": "article no 0", "url": "http://example.com/item/0", + 'source_url': 'http://example.com/source_url/0', } @@ -168,6 +169,7 @@ def test_html_articles_process_entry(articles): "text": "html contents with [proper elements](bla.com) ble ble", "title": "article no 0", "url": "http://example.com/item/0", + 'source_url': 'http://example.com/source_url/0', } @@ -212,6 +214,7 @@ def test_ebook_articles_process_entry(articles): "text": "html contents with [proper elements](bla.com) ble ble", "title": "article no 0", "url": "http://example.com/item/0", + 'source_url': 'http://example.com/source_url/0', } @@ -244,6 +247,7 @@ def test_xml_articles_process_entry(articles): "text": "bla bla", "title": "article no 0", "url": "http://example.com/item/0", + 'source_url': 'http://example.com/source_url/0', } @@ -276,6 +280,7 @@ def test_markdown_articles_process_entry(articles): "text": "bla bla", "title": "article no 0", "url": "http://example.com/item/0", + 'source_url': 'http://example.com/source_url/0', } @@ -310,6 +315,7 @@ def test_doc_articles_process_entry(articles): "text": "bla bla", "title": "article no 0", "url": "http://example.com/item/0", + 'source_url': 'http://example.com/source_url/0', } @@ -400,6 +406,7 @@ def test_special_docs_process_entry(): authors="mr. blobby", date_published="2023-10-02T01:23:45", source_type=None, + source_url="https://ble.ble.com" ) contents = { "text": "this is the text", @@ -414,7 +421,8 @@ def test_special_docs_process_entry(): 'date_published': '2023-10-02T01:23:45Z', 'id': None, 'source': 'html', - 'source_type': None, + 'source_url': "https://ble.ble.com", + 'source_type': 'html', 'summaries': [], 'text': 'this is the text', 'title': 'this is the title', diff --git a/tests/align_data/articles/test_updater.py b/tests/align_data/articles/test_updater.py new file mode 100644 index 00000000..b3b26a27 --- /dev/null +++ b/tests/align_data/articles/test_updater.py @@ -0,0 +1,162 @@ +from unittest.mock import Mock, patch +from csv import DictWriter + +import pandas as pd +import pytest +from align_data.db.models import Article +from align_data.sources.articles.updater import ReplacerDataset, Item + +SAMPLE_UPDATES = [ + {}, + {'title': 'no id - should be ignored'}, + + {'id': '122', 'hash_id': 'deadbeef000'}, + { + 'id': '123', 'hash_id': 'deadbeef001', + 'title': 'bla bla', + 'url': 'http://bla.com', + 'source_url': 'http://bla.bla.com', + 'authors': 'mr. blobby, johnny', + }, { + 'id': '124', + 'title': 'no hash id', + 'url': 'http://bla.com', + 'source_url': 'http://bla.bla.com', + 'authors': 'mr. blobby', + }, { + 'hash_id': 'deadbeef002', + 'title': 'no id', + 'url': 'http://bla.com', + 'source_url': 'http://bla.bla.com', + 'authors': 'mr. blobby', + }, { + 'id': '125', + 'title': 'no hash id, url or title', + 'authors': 'mr. 
blobby', + } +] + +@pytest.fixture +def csv_file(tmp_path): + filename = tmp_path / 'data.csv' + with open(filename, 'w', newline='') as csvfile: + fieldnames = ['id', 'hash_id', 'title', 'url', 'source_url', 'authors'] + writer = DictWriter(csvfile, fieldnames=fieldnames) + + writer.writeheader() + for row in SAMPLE_UPDATES: + writer.writerow(row) + return filename + + +def test_items_list(csv_file): + dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + + def mock_entries(): + return [ + Mock( + _id=dataset.maybe(v, 'id'), + id=dataset.maybe(v, 'hash_id'), + title=dataset.maybe(v, 'title'), + url=dataset.maybe(v, 'url'), + authors=dataset.maybe(v, 'authors') + ) + for v in dataset.csv_items + ] + + with patch.object(dataset, 'read_entries', mock_entries): + items = dataset.items_list + assert len(items) == 5, "items_list should only contain items with valid ids - something is wrong" + for item in items: + assert dataset.maybe(item.updates, 'id') == item.article._id + assert dataset.maybe(item.updates, 'hash_id') == item.article.id + assert dataset.maybe(item.updates, 'title') == item.article.title + assert dataset.maybe(item.updates, 'url') == item.article.url + assert dataset.maybe(item.updates, 'authors') == item.article.authors + + +@pytest.mark.parametrize('updates', ( + Mock(url='http://some.other.url'), + Mock(source_url='http://some.other.url'), + Mock(url='http://some.other.url', source_url='http://another.url'), +)) +def test_update_text(csv_file, updates): + dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + + article = Mock(text='this should be changed', status='as should this', url='http:/bla.bla.com') + + with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + dataset.update_text(updates, article) + assert article.text == 'bla bla bla' + assert article.status == None + + +@pytest.mark.parametrize('updates', ( + Mock(url='http://some.other.url'), + Mock(source_url='http://some.other.url'), + Mock(url='http://some.other.url', source_url='http://another.url'), +)) +def test_update_text_error(csv_file, updates): + dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + + article = Mock(text='this should not be changed', status='but this should be', url='http:/bla.bla.com') + + with patch('align_data.sources.articles.updater.item_metadata', return_value={'error': 'oh noes!'}): + dataset.update_text(updates, article) + assert article.text == 'this should not be changed' + assert article.status == 'oh noes!' 
+ + +@pytest.mark.parametrize('updates', ( + Mock(url='http://bla.bla.com', source_url=None, comment='Same url as article, no source_url'), + Mock(url='http://bla.bla.com', source_url='', comment='Same url as article, empty source_url'), + Mock(url=None, source_url=None, comment='no urls provided'), +)) +def test_update_text_no_update(csv_file, updates): + dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + + article = Mock(text='this should not be changed', status='as should not this', url='http://bla.bla.com') + + with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + dataset.update_text(updates, article) + assert article.text == 'this should not be changed' + assert article.status == 'as should not this' + + +def test_process_entry(csv_file): + dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + + article = Article( + _id=123, id='deadbeef0123', + title='this should be changed', + url='this should be changed', + text='this should be changed', + authors='this should be changed', + date_published='this should be changed', + id_fields=['url', 'title'], + ) + + updates = Mock( + id='123', + hash_id='deadbeef001', + title='bla bla', + url='http://bla.com', + source_url='http://bla.bla.com', + source='tests', + authors='mr. blobby, johnny', + date_published='2000-12-23T10:32:43Z', + ) + + with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + assert dataset.process_entry(Item(updates, article)).to_dict() == { + 'authors': ['mr. blobby', 'johnny'], + 'date_published': '2000-12-23T10:32:43Z', + 'id': 'd8d8cad8d28739a0862654a0e6e8ce6e', + 'source': 'tests', + 'source_type': None, + 'summaries': [], + 'text': 'bla bla bla', + 'title': 'bla bla', + 'url': 'http://bla.com', + 'source_url': 'http://bla.bla.com', + } From fecc5b16d0e6e9a92cfe3af2596cc07ac2931742 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Tue, 15 Aug 2023 00:40:55 +0200 Subject: [PATCH 7/7] handle missing metadata (#136) --- align_data/sources/articles/datasets.py | 4 ++ align_data/sources/articles/updater.py | 2 +- .../sources/arxiv_papers/arxiv_papers.py | 46 ++++++++++--------- tests/align_data/articles/test_updater.py | 42 +++++++++++++++++ 4 files changed, 72 insertions(+), 22 deletions(-) diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py index acfc2d85..2cf52ee8 100644 --- a/align_data/sources/articles/datasets.py +++ b/align_data/sources/articles/datasets.py @@ -115,6 +115,10 @@ def process_entry(self, item): else: contents = self.get_contents(item) + # Skip items that can't be saved because missing fields + if not all(contents.get(key) for key in self.id_fields): + return None + return self.make_data_entry(contents) diff --git a/align_data/sources/articles/updater.py b/align_data/sources/articles/updater.py index dc6a8e39..861d0af6 100644 --- a/align_data/sources/articles/updater.py +++ b/align_data/sources/articles/updater.py @@ -25,7 +25,7 @@ def get_item_key(self, item): @staticmethod def maybe(item, key): val = getattr(item, key, None) - if pd.isna(val): + if pd.isna(val) or (isinstance(val, str) and not val.strip()): return None return val diff --git a/align_data/sources/arxiv_papers/arxiv_papers.py b/align_data/sources/arxiv_papers/arxiv_papers.py index 1671fe5a..dc551946 100644 --- a/align_data/sources/arxiv_papers/arxiv_papers.py +++ b/align_data/sources/arxiv_papers/arxiv_papers.py @@ -53,29 +53,13 @@ def 
is_withdrawn(url: str): return None -def fetch(url) -> Dict: - paper_id = get_id(url) - if not paper_id: - return {'error': 'Could not extract arxiv id'} - +def add_metadata(data, paper_id): metadata = get_arxiv_metadata(paper_id) - - if is_withdrawn(url): - paper = {'status': 'Withdrawn'} - else: - paper = get_contents(paper_id) - - if metadata and metadata.authors: - authors = metadata.authors - else: - authors = paper.get("authors") or [] - authors = [str(a).strip() for a in authors] - + if not metadata: + return {} return dict({ + "authors": metadata.authors, "title": metadata.title, - "url": canonical_url(url), - "source_type": paper.get('data_source'), - "authors": authors, "date_published": metadata.published, "data_last_modified": metadata.updated.isoformat(), "summary": metadata.summary.replace("\n", " "), @@ -85,4 +69,24 @@ def fetch(url) -> Dict: "primary_category": metadata.primary_category, "categories": metadata.categories, "version": get_version(metadata.get_short_id()), - }, **paper) + }, **data) + + +def fetch(url) -> Dict: + paper_id = get_id(url) + if not paper_id: + return {'error': 'Could not extract arxiv id'} + + if is_withdrawn(url): + paper = {'status': 'Withdrawn'} + else: + paper = get_contents(paper_id) + + data = add_metadata({ + "url": canonical_url(url), + "source_type": paper.get('data_source'), + }, paper_id) + authors = data.get('authors') or paper.get("authors") or [] + data['authors'] = [str(a).strip() for a in authors] + + return dict(data, **paper) diff --git a/tests/align_data/articles/test_updater.py b/tests/align_data/articles/test_updater.py index b3b26a27..72cf7243 100644 --- a/tests/align_data/articles/test_updater.py +++ b/tests/align_data/articles/test_updater.py @@ -1,5 +1,6 @@ from unittest.mock import Mock, patch from csv import DictWriter +from numpy import source import pandas as pd import pytest @@ -160,3 +161,44 @@ def test_process_entry(csv_file): 'url': 'http://bla.com', 'source_url': 'http://bla.bla.com', } + + +def test_process_entry_empty(csv_file): + dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + + article = Article( + _id=123, id='deadbeef0123', + title='this should not be changed', + url='this should not be changed', + source='this should not be changed', + authors='this should not be changed', + + text='this should be changed', + date_published='this should be changed', + id_fields=['url', 'title'], + ) + + updates = Mock( + id='123', + hash_id='deadbeef001', + title=None, + url='', + source_url='http://bla.bla.com', + source=' ', + authors=' \n \n \t \t ', + date_published='2000-12-23T10:32:43Z', + ) + + with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + assert dataset.process_entry(Item(updates, article)).to_dict() == { + 'authors': ['this should not be changed'], + 'date_published': '2000-12-23T10:32:43Z', + 'id': '606e9224254f508d297bcb17bcc6d104', + 'source': 'this should not be changed', + 'source_type': None, + 'summaries': [], + 'text': 'bla bla bla', + 'title': 'this should not be changed', + 'url': 'this should not be changed', + 'source_url': 'http://bla.bla.com', + }
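
A minimal end-to-end sketch of the batch metadata update flow introduced in PATCH 6/7, assuming the align_data package from this series is installed and its database settings are configured. The column names mirror those exercised in tests/align_data/articles/test_updater.py; the file name data.csv and the example row values are purely illustrative.

    # build_update_csv.py - illustrative sketch, not part of the patch series
    from csv import DictWriter

    # ReplacerDataset matches each row to an existing article by numeric `id` or
    # by `hash_id`, then overwrites the stored values with any non-empty columns
    # (see align_data/sources/articles/updater.py in PATCH 6/7).
    rows = [
        {
            "hash_id": "deadbeef001",               # hypothetical article hash id
            "title": "Corrected title",
            "url": "http://example.com/article",    # illustrative urls
            "source_url": "http://example.com/article.pdf",
            "authors": "mr. blobby, johnny",
        },
    ]

    with open("data.csv", "w", newline="") as f:
        writer = DictWriter(
            f,
            fieldnames=["id", "hash_id", "title", "url", "source_url", "authors"],
        )
        writer.writeheader()
        writer.writerows(rows)

    # The series wires this into the CLI (see main.py and the update-metadata
    # workflow), so the rows would then be applied with:
    #
    #   python main.py update data.csv ','
    #
    # which calls update_articles() and, through ReplacerDataset.update_text(),
    # re-fetches the article text whenever a different url or a source_url is
    # supplied, while keeping the previous text if fetching fails.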