diff --git a/align_data/__init__.py b/align_data/__init__.py index a602f121..9f6c9893 100644 --- a/align_data/__init__.py +++ b/align_data/__init__.py @@ -25,6 +25,7 @@ ALL_DATASETS = sorted([dataset.name for dataset in DATASET_REGISTRY]) DATASET_MAP = {dataset.name: dataset for dataset in DATASET_REGISTRY} + def get_dataset(name): try: return DATASET_MAP[name] diff --git a/align_data/analysis/analyse_jsonl_data.py b/align_data/analysis/analyse_jsonl_data.py index bb38e887..b8c5103a 100644 --- a/align_data/analysis/analyse_jsonl_data.py +++ b/align_data/analysis/analyse_jsonl_data.py @@ -4,6 +4,7 @@ from collections import defaultdict + def is_valid_date_format(data_dict, format="%Y-%m-%dT%H:%M:%SZ"): """ Checks if the given date string matches the expected format. @@ -15,20 +16,25 @@ def is_valid_date_format(data_dict, format="%Y-%m-%dT%H:%M:%SZ"): except ValueError: return False + def validate_data(data_dict): """ - Processes each dictionary element in the jsonl file. + Processes each dictionary element in the jsonl file. """ if not is_valid_date_format(data_dict): - raise ValueError(f"Invalid date format for source: {data_dict['source']}, title: {data_dict['title'][:30]}, date_pub: {data_dict['date_published']}") + raise ValueError( + f"Invalid date format for source: {data_dict['source']}, title: {data_dict['title'][:30]}, date_pub: {data_dict['date_published']}" + ) # TODO: add more checks here + def check_for_duplicates(data_dict, seen_urls): - id = data_dict.get('id') + id = data_dict.get("id") seen_urls[id].append(data_dict) - #TODO: Add more validation logic here - return seen_urls + # TODO: Add more validation logic here + return seen_urls + def get_data_dict_str(data_dict): """ @@ -36,16 +42,18 @@ def get_data_dict_str(data_dict): """ return f"source: {data_dict['source']}, title: {data_dict['title'][:50]}, date_pub: {data_dict['date_published']}, url: {data_dict['url']}\n" + def files_iterator(data_dir): """ - Goes through the data directory, opens every jsonl file sequentially, + Goes through the data directory, opens every jsonl file sequentially, and yields every element (which is a dictionary) in the jsonl file. """ - for path in Path(data_dir).glob('*.jsonl'): + for path in Path(data_dir).glob("*.jsonl"): with jsonlines.open(path) as f: for line in f: yield line + def process_jsonl_files(data_dir): seen_urls = defaultdict(list) # holds all seen urls for data_dict in files_iterator(data_dir): @@ -57,23 +65,29 @@ def process_jsonl_files(data_dir): except Exception as e: print(f"Unexpected error: {e}") dup_count = 0 - + for id, duplicates in seen_urls.items(): if len(duplicates) > 1: - list_of_duplicates = '\n'.join(get_data_dict_str(duplicate) for duplicate in duplicates) - print(f"{len(duplicates)} duplicate ids found. \nId: {id}\n{list_of_duplicates}\n\n\n\n") + list_of_duplicates = "\n".join( + get_data_dict_str(duplicate) for duplicate in duplicates + ) + print( + f"{len(duplicates)} duplicate ids found. \nId: {id}\n{list_of_duplicates}\n\n\n\n" + ) dup_count += 1 print(f"Total number of duplicate ids found: {dup_count}") + def delete_all_txt_and_jsonl(data_dir): """ Deletes all txt and jsonl files in the given directory. 
""" - for path in Path(data_dir).glob('*.txt'): + for path in Path(data_dir).glob("*.txt"): path.unlink() - for path in Path(data_dir).glob('*.jsonl'): + for path in Path(data_dir).glob("*.jsonl"): path.unlink() + if __name__ == "__main__": process_jsonl_files("data/") - #delete_all_txt_and_jsonl("data/") + # delete_all_txt_and_jsonl("data/") diff --git a/align_data/analysis/count_tokens.py b/align_data/analysis/count_tokens.py index a601f7be..cd099c68 100644 --- a/align_data/analysis/count_tokens.py +++ b/align_data/analysis/count_tokens.py @@ -2,11 +2,15 @@ import jsonlines import logging from typing import Tuple + logger = logging.getLogger(__name__) -def count_token(merged_dataset_path : str = "data/merged_dataset/alignment_texts.jsonl") -> Tuple[int , int , int]: + +def count_token( + merged_dataset_path: str = "data/merged_dataset/alignment_texts.jsonl", +) -> Tuple[int, int, int]: tokenizer = AutoTokenizer.from_pretrained("gpt2") - total_token_count , total_word_count , total_character_count = 0 , 0 , 0 + total_token_count, total_word_count, total_character_count = 0, 0, 0 with jsonlines.open(merged_dataset_path) as reader: for obj in reader: @@ -18,6 +22,4 @@ def count_token(merged_dataset_path : str = "data/merged_dataset/alignment_texts logger.info(f"Total token count: {total_token_count}") logger.info(f"Total word count: {total_word_count}") logger.info(f"Total character count: {total_character_count}") - return total_token_count , total_word_count , total_character_count - - + return total_token_count, total_word_count, total_character_count diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 13819fef..d8676b57 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -38,10 +38,10 @@ class AlignmentDataset: _: KW_ONLY - files_path = Path('') + files_path = Path("") """The path where data can be found. Usually a folder""" - done_key = 'id' + done_key = "id" """The key of the entry to use as the id when checking if already processed.""" COOLDOWN = 0 @@ -58,30 +58,30 @@ class AlignmentDataset: _outputted_items = set() """A set of the ids of all previously processed items""" _: KW_ONLY - id_fields: List[str] = field(default_factory=lambda: ['url', 'title']) + id_fields: List[str] = field(default_factory=lambda: ["url", "title"]) """A list of fields to use as the id of the entry. 
If not set, will use ['url', 'title']""" def __str__(self) -> str: return self.name - def __post_init__(self, data_path=Path(__file__).parent / '../../data/'): + def __post_init__(self, data_path=Path(__file__).parent / "../../data/"): self.data_path = data_path - self.raw_data_path = self.data_path / 'raw' + self.raw_data_path = self.data_path / "raw" # set the default place to look for data self.files_path = self.raw_data_path / self.name def _add_authors(self, article: Article, authors: List[str]) -> Article: # TODO: Don't keep adding the same authors - come up with some way to reuse them - article.authors = ','.join(authors) + article.authors = ",".join(authors) if len(article.authors) > 1024: - article.authors = ','.join(article.authors[:1024].split(',')[:-1]) + article.authors = ",".join(article.authors[:1024].split(",")[:-1]) return article def make_data_entry(self, data, **kwargs) -> Article: data = dict(data, **kwargs) - summary = data.pop('summary', None) - authors = data.pop('authors', []) + summary = data.pop("summary", None) + authors = data.pop("authors", []) article = Article( id_fields=self.id_fields, @@ -95,13 +95,13 @@ def make_data_entry(self, data, **kwargs) -> Article: def to_jsonl(self, out_path=None, filename=None) -> Path: if not out_path: - out_path=Path(__file__).parent / '../../data/' + out_path = Path(__file__).parent / "../../data/" if not filename: filename = f"{self.name}.jsonl" filename = Path(out_path) / filename - with jsonlines.open(filename, 'w') as jsonl_writer: + with jsonlines.open(filename, "w") as jsonl_writer: for article in self.read_entries(): jsonl_writer.write(article.to_dict()) return filename.resolve() @@ -109,7 +109,7 @@ def to_jsonl(self, out_path=None, filename=None) -> Path: def read_entries(self, sort_by=None): """Iterate through all the saved entries.""" with make_session() as session: - query = select(Article).where(Article.source==self.name) + query = select(Article).where(Article.source == self.name) if sort_by is not None: query = query.order_by(sort_by) for item in session.scalars(query): @@ -136,8 +136,8 @@ def commit(): for entry in batch: session.add(entry) if not commit(): - logger.error(f'found duplicate of {entry}') - + logger.error(f"found duplicate of {entry}") + def setup(self): self._outputted_items = self._load_outputted_items() @@ -160,9 +160,14 @@ def _load_outputted_items(self) -> Set[str]: # This doesn't filter by self.name. The good thing about that is that it should handle a lot more # duplicates. The bad thing is that this could potentially return a massive amount of data if there # are lots of items. - return set(session.scalars(select(getattr(Article, self.done_key))).all()) + return set( + session.scalars(select(getattr(Article, self.done_key))).all() + ) # TODO: Properly handle this - it should create a proper SQL JSON select - return {item.get(self.done_key) for item in session.scalars(select(Article.meta)).all()} + return { + item.get(self.done_key) + for item in session.scalars(select(Article.meta)).all() + } def unprocessed_items(self, items=None) -> Iterable: """Return a list of all items to be processed. @@ -213,7 +218,6 @@ def _get_published_date(self, date) -> Optional[datetime]: class SummaryDataset(AlignmentDataset): - def unprocessed_items(self, items=None) -> Iterable: # This breaks the possible lazy loading of the items. Should be fine... 
items = list(super().unprocessed_items(items)) @@ -221,7 +225,10 @@ def unprocessed_items(self, items=None) -> Iterable: urls = map(self.get_item_key, items) with make_session() as session: self.articles = { - a.url: a for a in session.query(Article).options(joinedload(Article.summaries)).filter(Article.url.in_(urls)) + a.url: a + for a in session.query(Article) + .options(joinedload(Article.summaries)) + .filter(Article.url.in_(urls)) if a.url } @@ -230,7 +237,13 @@ def unprocessed_items(self, items=None) -> Iterable: def _load_outputted_items(self) -> Set[str]: """Load the output file (if it exists) in order to know which items have already been output.""" with make_session() as session: - return set(session.scalars(select(Article.url).join(Article.summaries).filter(Summary.source == self.name))) + return set( + session.scalars( + select(Article.url) + .join(Article.summaries) + .filter(Summary.source == self.name) + ) + ) def _add_batch(self, session, batch): def merge(item): diff --git a/align_data/common/html_dataset.py b/align_data/common/html_dataset.py index e374dfc9..9e3799f3 100644 --- a/align_data/common/html_dataset.py +++ b/align_data/common/html_dataset.py @@ -16,11 +16,13 @@ logger = logging.getLogger(__name__) + @dataclass class HTMLDataset(AlignmentDataset): """ Fetches articles from a different blog by collecting links to articles from an index page. """ + url: str done_key = "url" @@ -29,9 +31,9 @@ class HTMLDataset(AlignmentDataset): source_key: str = None summary_key: str = None - item_selector = 'article' - title_selector = 'article h1' - text_selector = 'article' + item_selector = "article" + title_selector = "article h1" + text_selector = "article" source_type = "blog" ignored_selectors = [] @@ -64,16 +66,18 @@ def process_entry(self, article): if not text: return None - return self.make_data_entry({ - "text": text, - "url": article_url, - "title": title, - "source": self.name, - "source_type": "blog", - "date_published": date_published, - "authors": self.extract_authors(contents), - **self._extra_values(contents), - }) + return self.make_data_entry( + { + "text": text, + "url": article_url, + "title": title, + "source": self.name, + "source_type": "blog", + "date_published": date_published, + "authors": self.extract_authors(contents), + **self._extra_values(contents), + } + ) def _get_contents(self, url): logger.info("Fetching {}".format(url)) @@ -93,8 +97,8 @@ def _get_text(self, contents): def _find_date(self, items): for i in items: - if re.match('\w+ \d{1,2}, \d{4}', i.text): - return datetime.strptime(i.text, '%b %d, %Y').replace(tzinfo=pytz.UTC) + if re.match("\w+ \d{1,2}, \d{4}", i.text): + return datetime.strptime(i.text, "%b %d, %Y").replace(tzinfo=pytz.UTC) def _extract_markdown(self, element): return element and markdownify(str(element)).strip() @@ -102,35 +106,35 @@ def _extract_markdown(self, element): @dataclass class RSSDataset(HTMLDataset): - date_format = '%a, %d %b %Y %H:%M:%S %z' + date_format = "%a, %d %b %Y %H:%M:%S %z" def get_item_key(self, item): return item @property def feed_url(self): - return f'{self.url}/rss.xml' + return f"{self.url}/rss.xml" def extract_authors(self, item): - if 'authors' in item: - return [a['name'] for a in item['authors'] if a.get('name')] + if "authors" in item: + return [a["name"] for a in item["authors"] if a.get("name")] return self.authors @staticmethod def _get_title(item): - return item['title'] + return item["title"] def _get_published_date(self, item): - date_published = item.get('published') or 
item.get('pubDate') + date_published = item.get("published") or item.get("pubDate") return super()._get_published_date(date_published) def _get_text(self, item): - text = item.get('content') and item['content'][0].get('value') + text = item.get("content") and item["content"][0].get("value") return self._extract_markdown(text) def _get_contents(self, url): item = self.items[url] - if 'content' in item: + if "content" in item: return item logger.info("Fetching {}".format(url)) @@ -145,5 +149,5 @@ def _get_contents(self, url): def items_list(self): logger.info(f"Fetching entries from {self.feed_url}") feed = feedparser.parse(self.feed_url) - self.items = {item['link']: item for item in feed['entries']} + self.items = {item["link"]: item for item in feed["entries"]} return list(self.items.keys()) diff --git a/align_data/db/models.py b/align_data/db/models.py index 06bee3dd..029c3d1c 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -4,7 +4,17 @@ import hashlib from datetime import datetime from typing import List, Optional -from sqlalchemy import JSON, DateTime, ForeignKey, String, Boolean, Text, Float, func, event +from sqlalchemy import ( + JSON, + DateTime, + ForeignKey, + String, + Boolean, + Text, + Float, + func, + event, +) from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship from sqlalchemy.dialects.mysql import LONGTEXT from align_data.settings import PINECONE_METADATA_KEYS @@ -18,7 +28,6 @@ class Base(DeclarativeBase): class Summary(Base): - __tablename__ = "summaries" id: Mapped[int] = mapped_column(primary_key=True) @@ -30,28 +39,33 @@ class Summary(Base): class Article(Base): - __tablename__ = "articles" - _id: Mapped[int] = mapped_column('id', primary_key=True) - id: Mapped[str] = mapped_column('hash_id', String(32), unique=True, nullable=False) + _id: Mapped[int] = mapped_column("id", primary_key=True) + id: Mapped[str] = mapped_column("hash_id", String(32), unique=True, nullable=False) title: Mapped[Optional[str]] = mapped_column(String(1028)) url: Mapped[Optional[str]] = mapped_column(String(1028)) source: Mapped[Optional[str]] = mapped_column(String(128)) source_type: Mapped[Optional[str]] = mapped_column(String(128)) authors: Mapped[str] = mapped_column(String(1024)) text: Mapped[Optional[str]] = mapped_column(LONGTEXT) - confidence: Mapped[Optional[float]] # Describes the confidence in how good this article is, as a value <0, 1> + confidence: Mapped[ + Optional[float] + ] # Describes the confidence in how good this article is, as a value <0, 1> date_published: Mapped[Optional[datetime]] - meta: Mapped[Optional[JSON]] = mapped_column(JSON, name='metadata', default='{}') + meta: Mapped[Optional[JSON]] = mapped_column(JSON, name="metadata", default="{}") date_created: Mapped[datetime] = mapped_column(DateTime, default=func.now()) - date_updated: Mapped[Optional[datetime]] = mapped_column(DateTime, onupdate=func.current_timestamp()) - + date_updated: Mapped[Optional[datetime]] = mapped_column( + DateTime, onupdate=func.current_timestamp() + ) + pinecone_update_required: Mapped[bool] = mapped_column(Boolean, default=False) - - summaries: Mapped[List["Summary"]] = relationship(back_populates="article", cascade="all, delete-orphan") - __id_fields = ['url', 'title'] + summaries: Mapped[List["Summary"]] = relationship( + back_populates="article", cascade="all, delete-orphan" + ) + + __id_fields = ["url", "title"] def __init__(self, *args, id_fields, **kwargs): self.__id_fields = id_fields @@ -59,32 +73,39 @@ def __init__(self, 
*args, id_fields, **kwargs): def __repr__(self) -> str: return f"Article(id={self.id!r}, title={self.title!r}, url={self.url!r}, source={self.source!r}, authors={self.authors!r}, date_published={self.date_published!r})" - + def is_metadata_keys_equal(self, other): if not isinstance(other, Article): - raise TypeError(f"Expected an instance of Article, got {type(other).__name__}") + raise TypeError( + f"Expected an instance of Article, got {type(other).__name__}" + ) return not any( - getattr(self, key, None) != getattr(other, key, None) # entry_id is implicitly ignored + getattr(self, key, None) + != getattr(other, key, None) # entry_id is implicitly ignored for key in PINECONE_METADATA_KEYS ) def generate_id_string(self) -> str: - return ''.join(str(getattr(self, field)) for field in self.__id_fields).encode("utf-8") + return "".join(str(getattr(self, field)) for field in self.__id_fields).encode( + "utf-8" + ) def verify_fields(self): missing = [field for field in self.__id_fields if not getattr(self, field)] - assert not missing, f'Entry is missing the following fields: {missing}' - + assert not missing, f"Entry is missing the following fields: {missing}" + def verify_id(self): assert self.id is not None, "Entry is missing id" id_string = self.generate_id_string() id_from_fields = hashlib.md5(id_string).hexdigest() - assert self.id == id_from_fields, f"Entry id {self.id} does not match id from id_fields, {id_from_fields}" + assert ( + self.id == id_from_fields + ), f"Entry id {self.id} does not match id from id_fields, {id_from_fields}" def update(self, other): for field in self.__table__.columns.keys(): - if field not in ['id', 'hash_id', 'metadata'] and getattr(other, field): + if field not in ["id", "hash_id", "metadata"] and getattr(other, field): setattr(self, field, getattr(other, field)) self.meta.update({k: v for k, v in other.meta.items() if k and v}) @@ -114,21 +135,21 @@ def to_dict(self): authors = [] if self.authors and self.authors.strip(): - authors = [i.strip() for i in self.authors.split(',')] + authors = [i.strip() for i in self.authors.split(",")] return { - 'id': self.id, - 'title': self.title, - 'url': self.url, - 'source': self.source, - 'source_type': self.source_type, - 'text': self.text, - 'date_published': date, - 'authors': authors, - 'summaries': [s.text for s in (self.summaries or [])], + "id": self.id, + "title": self.title, + "url": self.url, + "source": self.source, + "source_type": self.source_type, + "text": self.text, + "date_published": date, + "authors": authors, + "summaries": [s.text for s in (self.summaries or [])], **(self.meta or {}), } -event.listen(Article, 'before_insert', Article.before_write) -event.listen(Article, 'before_update', Article.before_write) +event.listen(Article, "before_insert", Article.before_write) +event.listen(Article, "before_update", Article.before_write) diff --git a/align_data/db/session.py b/align_data/db/session.py index f3eb6468..ace0ff8a 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -24,6 +24,4 @@ def stream_pinecone_updates(session, custom_sources: List[str]): """Yield Pinecone entries that require an update.""" yield from session.query(Article).filter( Article.pinecone_update_required.is_(True) - ).filter( - Article.source.in_(custom_sources) - ).yield_per(1000) \ No newline at end of file + ).filter(Article.source.in_(custom_sources)).yield_per(1000) diff --git a/align_data/pinecone/pinecone_db_handler.py b/align_data/pinecone/pinecone_db_handler.py index 4168cb70..d8f565df 100644 --- 
a/align_data/pinecone/pinecone_db_handler.py +++ b/align_data/pinecone/pinecone_db_handler.py @@ -5,7 +5,14 @@ import pinecone -from align_data.settings import PINECONE_INDEX_NAME, PINECONE_VALUES_DIMS, PINECONE_METRIC, PINECONE_METADATA_KEYS, PINECONE_API_KEY, PINECONE_ENVIRONMENT +from align_data.settings import ( + PINECONE_INDEX_NAME, + PINECONE_VALUES_DIMS, + PINECONE_METRIC, + PINECONE_METADATA_KEYS, + PINECONE_API_KEY, + PINECONE_ENVIRONMENT, +) logger = logging.getLogger(__name__) @@ -25,58 +32,60 @@ def __init__( self.values_dims = values_dims self.metric = metric self.metadata_keys = metadata_keys - + pinecone.init( - api_key = PINECONE_API_KEY, - environment = PINECONE_ENVIRONMENT, + api_key=PINECONE_API_KEY, + environment=PINECONE_ENVIRONMENT, ) - + if create_index: self.create_index() - + self.index = pinecone.Index(index_name=self.index_name) - + if log_index_stats: index_stats_response = self.index.describe_index_stats() logger.info(f"{self.index_name}:\n{index_stats_response}") - + def upsert_entry(self, entry: Dict, upsert_size=100): self.index.upsert( vectors=list( zip( - [f"{entry['id']}_{str(i).zfill(6)}" for i in range(len(entry['text_chunks']))], - entry['embeddings'].tolist(), + [ + f"{entry['id']}_{str(i).zfill(6)}" + for i in range(len(entry["text_chunks"])) + ], + entry["embeddings"].tolist(), [ { - 'entry_id': entry['id'], - 'source': entry['source'], - 'title': entry['title'], - 'authors': entry['authors'], - 'text': text_chunk, - } for text_chunk in entry['text_chunks'] - ] + "entry_id": entry["id"], + "source": entry["source"], + "title": entry["title"], + "authors": entry["authors"], + "text": text_chunk, + } + for text_chunk in entry["text_chunks"] + ], ) ), - batch_size=upsert_size + batch_size=upsert_size, ) - + def delete_entries(self, ids): - self.index.delete( - filter={"entry_id": {"$in": ids}} - ) + self.index.delete(filter={"entry_id": {"$in": ids}}) def create_index(self, replace_current_index: bool = True): if replace_current_index: self.delete_index() - + pinecone.create_index( name=self.index_name, dimension=self.values_dims, metric=self.metric, - metadata_config = {"indexed": self.metadata_keys}, + metadata_config={"indexed": self.metadata_keys}, ) def delete_index(self): if self.index_name in pinecone.list_indexes(): logger.info(f"Deleting index '{self.index_name}'.") - pinecone.delete_index(self.index_name) \ No newline at end of file + pinecone.delete_index(self.index_name) diff --git a/align_data/pinecone/text_splitter.py b/align_data/pinecone/text_splitter.py index 76bb29b8..c732c99c 100644 --- a/align_data/pinecone/text_splitter.py +++ b/align_data/pinecone/text_splitter.py @@ -7,29 +7,33 @@ class ParagraphSentenceUnitTextSplitter(TextSplitter): """A custom TextSplitter that breaks text by paragraphs, sentences, and then units (chars/words/tokens/etc). - + @param min_chunk_size: The minimum number of units in a chunk. @param max_chunk_size: The maximum number of units in a chunk. @param length_function: A function that returns the length of a string in units. @param truncate_function: A function that truncates a string to a given unit length. 
""" - + DEFAULT_MIN_CHUNK_SIZE = 900 DEFAULT_MAX_CHUNK_SIZE = 1100 DEFAULT_LENGTH_FUNCTION = lambda string: len(string) - DEFAULT_TRUNCATE_FUNCTION = lambda string, length, from_end=False: string[-length:] if from_end else string[:length] + DEFAULT_TRUNCATE_FUNCTION = ( + lambda string, length, from_end=False: string[-length:] + if from_end + else string[:length] + ) def __init__( - self, + self, min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, length_function: Callable[[str], int] = DEFAULT_LENGTH_FUNCTION, truncate_function: Callable[[str, int], str] = DEFAULT_TRUNCATE_FUNCTION, - **kwargs: Any + **kwargs: Any, ): super().__init__(**kwargs) self.min_chunk_size = min_chunk_size - self.max_chunk_size = max_chunk_size + self.max_chunk_size = max_chunk_size self._length_function = length_function self._truncate_function = truncate_function @@ -43,26 +47,30 @@ def split_text(self, text: str) -> List[str]: current_block += "\n\n" + paragraph block_length = self._length_function(current_block) - if block_length > self.max_chunk_size: # current block is too large, truncate it - current_block = self._handle_large_paragraph(current_block, blocks, paragraph) + if ( + block_length > self.max_chunk_size + ): # current block is too large, truncate it + current_block = self._handle_large_paragraph( + current_block, blocks, paragraph + ) elif block_length >= self.min_chunk_size: blocks.append(current_block) current_block = "" else: # current block is too small, continue appending to it continue - + blocks = self._handle_remaining_text(current_block, blocks) return [block.strip() for block in blocks] def _handle_large_paragraph(self, current_block, blocks, paragraph): # Undo adding the whole paragraph - current_block = current_block[:-(len(paragraph)+2)] # +2 accounts for "\n\n" + current_block = current_block[: -(len(paragraph) + 2)] # +2 accounts for "\n\n" sentences = sent_tokenize(paragraph) for sentence in sentences: current_block += f" {sentence}" - + block_length = self._length_function(current_block) if block_length < self.min_chunk_size: continue @@ -70,19 +78,23 @@ def _handle_large_paragraph(self, current_block, blocks, paragraph): blocks.append(current_block) current_block = "" else: - current_block = self._truncate_large_block(current_block, blocks, sentence) - + current_block = self._truncate_large_block( + current_block, blocks, sentence + ) + return current_block def _truncate_large_block(self, current_block, blocks, sentence): while self._length_function(current_block) > self.max_chunk_size: # Truncate current_block to max size, set remaining sentence as next sentence - truncated_block = self._truncate_function(current_block, self.max_chunk_size) + truncated_block = self._truncate_function( + current_block, self.max_chunk_size + ) blocks.append(truncated_block) - remaining_sentence = current_block[len(truncated_block):].lstrip() + remaining_sentence = current_block[len(truncated_block) :].lstrip() current_block = sentence = remaining_sentence - + return current_block def _handle_remaining_text(self, current_block, blocks): @@ -93,13 +105,17 @@ def _handle_remaining_text(self, current_block, blocks): if len_current_block < self.min_chunk_size: # it needs to take the last min_chunk_size-len_current_block units from the previous block previous_block = blocks[-1] - required_units = self.min_chunk_size - len_current_block # calculate the required units + required_units = ( + self.min_chunk_size - len_current_block + ) # calculate the required units - 
part_prev_block = self._truncate_function(previous_block, required_units, from_end=True) # get the required units from the previous block + part_prev_block = self._truncate_function( + previous_block, required_units, from_end=True + ) # get the required units from the previous block last_block = part_prev_block + current_block blocks.append(last_block) else: blocks.append(current_block) - return blocks \ No newline at end of file + return blocks diff --git a/align_data/pinecone/update_pinecone.py b/align_data/pinecone/update_pinecone.py index 9e52276d..649a7a2a 100644 --- a/align_data/pinecone/update_pinecone.py +++ b/align_data/pinecone/update_pinecone.py @@ -11,10 +11,17 @@ from align_data.db.session import make_session, stream_pinecone_updates from align_data.pinecone.pinecone_db_handler import PineconeDB from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter -from align_data.settings import USE_OPENAI_EMBEDDINGS, OPENAI_EMBEDDINGS_MODEL, \ - OPENAI_EMBEDDINGS_DIMS, OPENAI_EMBEDDINGS_RATE_LIMIT, \ - SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS, \ - CHUNK_SIZE, MAX_NUM_AUTHORS_IN_SIGNATURE, EMBEDDING_LENGTH_BIAS +from align_data.settings import ( + USE_OPENAI_EMBEDDINGS, + OPENAI_EMBEDDINGS_MODEL, + OPENAI_EMBEDDINGS_DIMS, + OPENAI_EMBEDDINGS_RATE_LIMIT, + SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, + SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS, + CHUNK_SIZE, + MAX_NUM_AUTHORS_IN_SIGNATURE, + EMBEDDING_LENGTH_BIAS, +) logger = logging.getLogger(__name__) @@ -29,54 +36,69 @@ class PineconeEntry(BaseModel): authors: List[str] text_chunks: List[str] embeddings: np.ndarray - + class Config: arbitrary_types_allowed = True def __repr__(self): return f"PineconeEntry(id={self.id!r}, source={self.source!r}, title={self.title!r}, url={self.url!r}, date_published={self.date_published!r}, authors={self.authors!r}, text_chunks={self.text_chunks[:5]!r})" - @validator('id', 'source', 'title', 'url', 'date_published', 'authors', 'text_chunks', pre=True, always=True) + @validator( + "id", + "source", + "title", + "url", + "date_published", + "authors", + "text_chunks", + pre=True, + always=True, + ) def empty_strings_not_allowed(cls, value): if not str(value).strip(): raise ValueError("Attribute should not be empty.") return value - + class PineconeUpdater: def __init__( - self, + self, min_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MIN_CHUNK_SIZE, max_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MAX_CHUNK_SIZE, - length_function: Callable[[str], int] = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION, - truncate_function: Callable[[str, int], str] = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION, + length_function: Callable[ + [str], int + ] = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION, + truncate_function: Callable[ + [str, int], str + ] = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION, ): self.min_chunk_size = min_chunk_size self.max_chunk_size = max_chunk_size self.length_function = length_function self.truncate_function = truncate_function - + self.text_splitter = ParagraphSentenceUnitTextSplitter( min_chunk_size=self.min_chunk_size, max_chunk_size=self.max_chunk_size, length_function=self.length_function, - truncate_function=self.truncate_function + truncate_function=self.truncate_function, ) self.pinecone_db = PineconeDB() - + if USE_OPENAI_EMBEDDINGS: import openai - openai.api_key = os.environ['OPENAI_API_KEY'] + + openai.api_key = os.environ["OPENAI_API_KEY"] else: 
import torch from langchain.embeddings import HuggingFaceEmbeddings - + self.hf_embeddings = HuggingFaceEmbeddings( model_name=SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, - model_kwargs={'device': "cuda" if torch.cuda.is_available() else "cpu"}, - encode_kwargs={'show_progress_bar': False} + model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"}, + encode_kwargs={"show_progress_bar": False}, ) - + def update(self, custom_sources: List[str]): """ Update the given sources. If no sources are provided, updates all sources. @@ -90,8 +112,10 @@ def update(self, custom_sources: List[str]): article.pinecone_update_required = False session.add(article) session.commit() - - def process_entries(self, article_stream: Generator[Article, None, None]) -> Generator[Tuple[Article, PineconeEntry], None, None]: + + def process_entries( + self, article_stream: Generator[Article, None, None] + ) -> Generator[Tuple[Article, PineconeEntry], None, None]: for article in article_stream: try: text_chunks = self.get_text_chunks(article) @@ -101,45 +125,54 @@ def process_entries(self, article_stream: Generator[Article, None, None]) -> Gen title=article.title, url=article.url, date_published=article.date_published, - authors=[author.strip() for author in article.authors.split(',') if author.strip()], + authors=[ + author.strip() + for author in article.authors.split(",") + if author.strip() + ], text_chunks=text_chunks, - embeddings=self.extract_embeddings(text_chunks, [article.source] * len(text_chunks)) + embeddings=self.extract_embeddings( + text_chunks, [article.source] * len(text_chunks) + ), ) except (ValueError, ValidationError) as e: logger.exception(e) - + def get_text_chunks(self, article: Article) -> List[str]: signature = f"Title: {article.title}, Author(s): {self.get_authors_str(article.authors)}" text_chunks = self.text_splitter.split_text(article.text) text_chunks = [f"- {signature}\n\n{text_chunk}" for text_chunk in text_chunks] return text_chunks - + def extract_embeddings(self, chunks_batch, sources_batch): if USE_OPENAI_EMBEDDINGS: return self.get_openai_embeddings(chunks_batch, sources_batch) else: - return np.array(self.hf_embeddings.embed_documents(chunks_batch, sources_batch)) + return np.array( + self.hf_embeddings.embed_documents(chunks_batch, sources_batch) + ) @staticmethod - def get_openai_embeddings(chunks, sources=''): + def get_openai_embeddings(chunks, sources=""): embeddings = np.zeros((len(chunks), OPENAI_EMBEDDINGS_DIMS)) - + openai_output = openai.Embedding.create( - model=OPENAI_EMBEDDINGS_MODEL, - input=chunks - )['data'] - + model=OPENAI_EMBEDDINGS_MODEL, input=chunks + )["data"] + for i, (embedding, source) in enumerate(zip(openai_output, sources)): bias = EMBEDDING_LENGTH_BIAS.get(source, 1.0) - embeddings[i] = bias * np.array(embedding['embedding']) - + embeddings[i] = bias * np.array(embedding["embedding"]) + return embeddings @staticmethod def get_authors_str(authors_lst: List[str]) -> str: - if authors_lst == []: return 'n/a' - if len(authors_lst) == 1: return authors_lst[0] + if authors_lst == []: + return "n/a" + if len(authors_lst) == 1: + return authors_lst[0] else: authors_lst = authors_lst[:MAX_NUM_AUTHORS_IN_SIGNATURE] authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}" - return authors_str \ No newline at end of file + return authors_str diff --git a/align_data/postprocess/postprocess.py b/align_data/postprocess/postprocess.py index 3f26885e..1366bc11 100644 --- a/align_data/postprocess/postprocess.py +++ 
b/align_data/postprocess/postprocess.py @@ -1,4 +1,4 @@ -#%% +# %% from dataclasses import dataclass import jsonlines from tqdm import tqdm @@ -6,54 +6,69 @@ from path import Path import pylab as plt + # import seaborn as sns import pandas as pd logger = logging.getLogger(__name__) + @dataclass class PostProcesser: """ This class is used to postprocess the data """ - jsonl_path : Path = Path('../../data/') + + jsonl_path: Path = Path("../../data/") def __init__(self) -> None: - self.jsonl_list = sorted(self.jsonl_path.files('*.jsonl')) - self.source_list = [path.name.split('.jsonl')[0] for path in self.jsonl_list] + self.jsonl_list = sorted(self.jsonl_path.files("*.jsonl")) + self.source_list = [path.name.split(".jsonl")[0] for path in self.jsonl_list] def compute_statistics(self) -> None: - self.all_stats = {key : {} for key in self.source_list} - for source_name , path in tqdm(zip(self.source_list , self.jsonl_list)): + self.all_stats = {key: {} for key in self.source_list} + for source_name, path in tqdm(zip(self.source_list, self.jsonl_list)): with jsonlines.open(path) as reader: for obj in reader: - self.all_stats[source_name]['num_entries'] = self.all_stats[source_name].get('num_entries' , 0) + 1 - self.all_stats[source_name]['num_tokens'] = self.all_stats[source_name].get('num_tokens' , 0) + len(obj['text'].split()) - self.all_stats[source_name]['num_chars'] = self.all_stats[source_name].get('num_chars' , 0) + len(obj['text']) - self.all_stats[source_name]['num_words'] = self.all_stats[source_name].get('num_words' , 0) + len(obj['text'].split()) - self.all_stats[source_name]['num_sentences'] = self.all_stats[source_name].get('num_sentences' , 0) + len(obj['text'].split('.')) - self.all_stats[source_name]['num_paragraphs'] = self.all_stats[source_name].get('num_paragraphs' , 0) + len(obj['text'].splitlines()) - + self.all_stats[source_name]["num_entries"] = ( + self.all_stats[source_name].get("num_entries", 0) + 1 + ) + self.all_stats[source_name]["num_tokens"] = self.all_stats[ + source_name + ].get("num_tokens", 0) + len(obj["text"].split()) + self.all_stats[source_name]["num_chars"] = self.all_stats[ + source_name + ].get("num_chars", 0) + len(obj["text"]) + self.all_stats[source_name]["num_words"] = self.all_stats[ + source_name + ].get("num_words", 0) + len(obj["text"].split()) + self.all_stats[source_name]["num_sentences"] = self.all_stats[ + source_name + ].get("num_sentences", 0) + len(obj["text"].split(".")) + self.all_stats[source_name]["num_paragraphs"] = self.all_stats[ + source_name + ].get("num_paragraphs", 0) + len(obj["text"].splitlines()) + def plot_statistics(self) -> None: all_df = pd.DataFrame(self.all_stats).T - plt.figure(figsize = (5 , 5)) - sns.barplot(x = all_df.index , y = all_df['num_entries']) - + plt.figure(figsize=(5, 5)) + sns.barplot(x=all_df.index, y=all_df["num_entries"]) - def merge_all_files(self , out_dir : str) -> str: + def merge_all_files(self, out_dir: str) -> str: pass def deduplicate(self) -> None: for path in tqdm(self.jsonl_list): - with jsonlines.open(path , 'r') as reader: - all_obj = {obj['id'] : obj for obj in reader} - with jsonlines.open(path , 'w') as writer: + with jsonlines.open(path, "r") as reader: + all_obj = {obj["id"]: obj for obj in reader} + with jsonlines.open(path, "w") as writer: for obj in all_obj.values(): writer.write(obj) - def clean_dataset(self , merged_dataset_path : str) -> str: + def clean_dataset(self, merged_dataset_path: str) -> str: pass + pp = PostProcesser() # %% pp.source_list diff --git 
a/align_data/settings.py b/align_data/settings.py index 43d62e8a..1244861f 100644 --- a/align_data/settings.py +++ b/align_data/settings.py @@ -1,30 +1,35 @@ import os from dotenv import load_dotenv + load_dotenv() ### CODA ### CODA_TOKEN = os.environ.get("CODA_TOKEN") CODA_DOC_ID = os.environ.get("CODA_DOC_ID", "fau7sl2hmG") -ON_SITE_TABLE = os.environ.get('CODA_ON_SITE_TABLE', 'table-aOTSHIz_mN') +ON_SITE_TABLE = os.environ.get("CODA_ON_SITE_TABLE", "table-aOTSHIz_mN") ### GOOGLE DRIVE ### -PDFS_FOLDER_ID = os.environ.get('PDF_FOLDER_ID', '1etWiXPRl0QqdgYzivVIj6wCv9xj5VYoN') +PDFS_FOLDER_ID = os.environ.get("PDF_FOLDER_ID", "1etWiXPRl0QqdgYzivVIj6wCv9xj5VYoN") ### GOOGLE SHEETS ### -METADATA_SOURCE_SPREADSHEET = os.environ.get('METADATA_SOURCE_SPREADSHEET', '1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI') -METADATA_SOURCE_SHEET = os.environ.get('METADATA_SOURCE_SHEET', 'special_docs.csv') -METADATA_OUTPUT_SPREADSHEET = os.environ.get('METADATA_OUTPUT_SPREADSHEET', '1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4') +METADATA_SOURCE_SPREADSHEET = os.environ.get( + "METADATA_SOURCE_SPREADSHEET", "1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI" +) +METADATA_SOURCE_SHEET = os.environ.get("METADATA_SOURCE_SHEET", "special_docs.csv") +METADATA_OUTPUT_SPREADSHEET = os.environ.get( + "METADATA_OUTPUT_SPREADSHEET", "1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4" +) ### YouTube ### -YOUTUBE_API_KEY = os.environ.get('YOUTUBE_API_KEY') +YOUTUBE_API_KEY = os.environ.get("YOUTUBE_API_KEY") ### MYSQL ### -user = os.environ.get('ARD_DB_USER', 'user') -password = os.environ.get('ARD_DB_PASSWORD', 'we all live in a yellow submarine') -host = os.environ.get('ARD_DB_HOST', '127.0.0.1') -port = os.environ.get('ARD_DB_PORT', '3306') -db_name = os.environ.get('ARD_DB_NAME', 'alignment_research_dataset') -DB_CONNECTION_URI = f'mysql+mysqldb://{user}:{password}@{host}:{port}/{db_name}' +user = os.environ.get("ARD_DB_USER", "user") +password = os.environ.get("ARD_DB_PASSWORD", "we all live in a yellow submarine") +host = os.environ.get("ARD_DB_HOST", "127.0.0.1") +port = os.environ.get("ARD_DB_PORT", "3306") +db_name = os.environ.get("ARD_DB_NAME", "alignment_research_dataset") +DB_CONNECTION_URI = f"mysql+mysqldb://{user}:{password}@{host}:{port}/{db_name}" ### EMBEDDINGS ### USE_OPENAI_EMBEDDINGS = True # If false, SentenceTransformer embeddings will be used. 
@@ -36,14 +41,20 @@ OPENAI_EMBEDDINGS_DIMS = 1536 OPENAI_EMBEDDINGS_RATE_LIMIT = 3500 -SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1" +SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = ( + "sentence-transformers/multi-qa-mpnet-base-cos-v1" +) SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768 ### PINECONE ### PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME", "stampy-chat-ard") PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None) PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None) -PINECONE_VALUES_DIMS = OPENAI_EMBEDDINGS_DIMS if USE_OPENAI_EMBEDDINGS else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS +PINECONE_VALUES_DIMS = ( + OPENAI_EMBEDDINGS_DIMS + if USE_OPENAI_EMBEDDINGS + else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS +) PINECONE_METRIC = "dotproduct" PINECONE_METADATA_KEYS = ["entry_id", "source", "title", "authors", "text", "url"] diff --git a/align_data/sources/alignment_newsletter/__init__.py b/align_data/sources/alignment_newsletter/__init__.py index 770eb5d5..e2d65fb1 100644 --- a/align_data/sources/alignment_newsletter/__init__.py +++ b/align_data/sources/alignment_newsletter/__init__.py @@ -2,7 +2,7 @@ import os ALIGNMENT_NEWSLETTER_REGISTRY = [ - AlignmentNewsletter( - name = "alignment_newsletter" , id_fields=['url', 'title', 'source'] - ), -] \ No newline at end of file + AlignmentNewsletter( + name="alignment_newsletter", id_fields=["url", "title", "source"] + ), +] diff --git a/align_data/sources/alignment_newsletter/alignment_newsletter.py b/align_data/sources/alignment_newsletter/alignment_newsletter.py index fa613640..2b68e32f 100644 --- a/align_data/sources/alignment_newsletter/alignment_newsletter.py +++ b/align_data/sources/alignment_newsletter/alignment_newsletter.py @@ -12,12 +12,11 @@ @dataclass class AlignmentNewsletter(SummaryDataset): - done_key = "url" - def __post_init__(self, data_path=Path(__file__).parent / '../../../data/'): + def __post_init__(self, data_path=Path(__file__).parent / "../../../data/"): self.data_path = data_path - self.raw_data_path = self.data_path / 'raw' + self.raw_data_path = self.data_path / "raw" def setup(self) -> None: super().setup() @@ -54,26 +53,30 @@ def process_entry(self, row): def handle_na(v, cast=None): if not self.maybe(v): - return '' + return "" if cast: return cast(v) return v - return self.make_data_entry({ - "url": handle_na(row.URL), - "source": handle_na(self.name), - "converted_with": "python", - "source_type": "google-sheets", - "venue": handle_na(row.Venue, str), # arXiv, Distill, LessWrong, Alignment Forum, ICML 2018, etc - "newsletter_category": handle_na(row.Category, str), - "highlight": row[2] == "Highlight", - "newsletter_number": handle_na(row.Email, str), - "summarizer": handle_na(row.Summarizer, str), - "opinion": handle_na(row[11], str), - "prerequisites": handle_na(row.Prerequisites, str), - "read_more": handle_na(row[13], str), - "title": handle_na(row.Title, str), - "authors": [i.strip() for i in str(row.Authors).split(',')], - "date_published": self._get_published_date(row.Year), - "summary": handle_na(row.Summary, str), - }) + return self.make_data_entry( + { + "url": handle_na(row.URL), + "source": handle_na(self.name), + "converted_with": "python", + "source_type": "google-sheets", + "venue": handle_na( + row.Venue, str + ), # arXiv, Distill, LessWrong, Alignment Forum, ICML 2018, etc + "newsletter_category": handle_na(row.Category, str), + "highlight": row[2] == "Highlight", + "newsletter_number": handle_na(row.Email, str), + "summarizer": 
handle_na(row.Summarizer, str), + "opinion": handle_na(row[11], str), + "prerequisites": handle_na(row.Prerequisites, str), + "read_more": handle_na(row[13], str), + "title": handle_na(row.Title, str), + "authors": [i.strip() for i in str(row.Authors).split(",")], + "date_published": self._get_published_date(row.Year), + "summary": handle_na(row.Summary, str), + } + ) diff --git a/align_data/sources/arbital/__init__.py b/align_data/sources/arbital/__init__.py index ad077e15..fda6c3db 100644 --- a/align_data/sources/arbital/__init__.py +++ b/align_data/sources/arbital/__init__.py @@ -1,4 +1,4 @@ from .arbital import Arbital -ARBITAL_REGISTRY = [Arbital(name='arbital')] +ARBITAL_REGISTRY = [Arbital(name="arbital")] diff --git a/align_data/sources/arbital/arbital.py b/align_data/sources/arbital/arbital.py index 4f147e12..40ce16c3 100644 --- a/align_data/sources/arbital/arbital.py +++ b/align_data/sources/arbital/arbital.py @@ -11,13 +11,13 @@ def parse_arbital_link(contents): - text = contents[1].split(' ') - url = f'https://arbital.com/p/{text[0]}' + text = contents[1].split(" ") + url = f"https://arbital.com/p/{text[0]}" if len(text) > 1: - title = ' '.join(text[1:]) + title = " ".join(text[1:]) else: title = url - return f'[{title}]({url})' + return f"[{title}]({url})" def flatten(val): @@ -45,73 +45,78 @@ def markdownify_text(current, view): in_link = False for part, next_part in view: - if part == '[': + if part == "[": # Recursively try to parse this new section - it's probably a link, but can be something else current.append(markdownify_text([part], view)) - elif part == ']' and next_part == '(': + elif part == "]" and next_part == "(": # mark that it's now in the url part of a markdown link - current.append(']') + current.append("]") in_link = True - elif part == ']': + elif part == "]": # this is the arbital summary - just join it for now, but it'll have to be handled later - if current[1].startswith('summary'): - return ''.join(current[1:]) + if current[1].startswith("summary"): + return "".join(current[1:]) # if this was a TODO section, then ignore it - if current[1].startswith('todo'): - return '' + if current[1].startswith("todo"): + return "" # Otherwise it's an arbital link return parse_arbital_link(current) - elif in_link and part == ')': + elif in_link and part == ")": # this is the end of a markdown link - just join the contents, as they're already correct - return ''.join(current + [part]) - elif in_link and current[-1] == '(' and next_part != ')': + return "".join(current + [part]) + elif in_link and current[-1] == "(" and next_part != ")": # This link is strange... looks like it could be malformed? # Assuming that it's malformed and missing a closing `)` # This will remove any additional info in the link, but that seems a reasonable price? 
- words = part.split(' ') - return ''.join(current + [words[0], ') ', ' '.join(words[1:])]) + words = part.split(" ") + return "".join(current + [words[0], ") ", " ".join(words[1:])]) else: # Just your basic text - add it to the processed parts and go on your merry way current.append(part) # Check if the first item is the summary - if so, extract it - summary = '' - if current[0].startswith('summary'): - _, summary = re.split(r'summary[()\w]*:', current[0], 1) + summary = "" + if current[0].startswith("summary"): + _, summary = re.split(r"summary[()\w]*:", current[0], 1) current = current[1:] # Otherwise just join all the parts back together - return summary.strip(), ''.join(flatten(current)).strip() + return summary.strip(), "".join(flatten(current)).strip() def extract_text(text): - parts = [i for i in re.split('([\[\]()])', text) if i] + parts = [i for i in re.split("([\[\]()])", text) if i] return markdownify_text([], zip(parts, parts[1:] + [None])) + @dataclass class Arbital(AlignmentDataset): - summary_key: str = 'summary' + summary_key: str = "summary" - ARBITAL_SUBSPACES = ['ai_alignment', 'math', 'rationality'] + ARBITAL_SUBSPACES = ["ai_alignment", "math", "rationality"] done_key = "alias" headers = { - 'authority': 'arbital.com', - 'accept': 'application/json, text/plain, */*', - 'content-type': 'application/json;charset=UTF-8', - 'sec-ch-ua-mobile': '?0', - 'origin': 'https://arbital.com', - 'sec-fetch-site': 'same-origin', - 'sec-fetch-mode': 'cors', - 'sec-fetch-dest': 'empty', - 'accept-language': 'en-US,en;q=0.9', + "authority": "arbital.com", + "accept": "application/json, text/plain, */*", + "content-type": "application/json;charset=UTF-8", + "sec-ch-ua-mobile": "?0", + "origin": "https://arbital.com", + "sec-fetch-site": "same-origin", + "sec-fetch-mode": "cors", + "sec-fetch-dest": "empty", + "accept-language": "en-US,en;q=0.9", } titles_map = {} @property def items_list(self): - logger.info('Getting page aliases') - items = [alias for subspace in self.ARBITAL_SUBSPACES for alias in self.get_arbital_page_aliases(subspace)] - logger.info('Got %s page aliases', len(items)) + logger.info("Getting page aliases") + items = [ + alias + for subspace in self.ARBITAL_SUBSPACES + for alias in self.get_arbital_page_aliases(subspace) + ] + logger.info("Got %s page aliases", len(items)) return items def get_item_key(self, item): @@ -120,44 +125,50 @@ def get_item_key(self, item): def process_entry(self, alias): try: page = self.get_page(alias) - summary, text = extract_text(page['text']) - - return self.make_data_entry({ - 'title': page.get('title') or '', - 'text': text, - 'date_published': self._get_published_date(page), - 'url': f'https://arbital.com/p/{page.get("alias") or alias}', - 'source': self.name, - 'source_type': 'text', - 'authors': self.extract_authors(page), - 'alias': alias, - 'tags': list(filter(None, map(self.get_title, page['tagIds']))), - 'summary': summary, - }) + summary, text = extract_text(page["text"]) + + return self.make_data_entry( + { + "title": page.get("title") or "", + "text": text, + "date_published": self._get_published_date(page), + "url": f'https://arbital.com/p/{page.get("alias") or alias}', + "source": self.name, + "source_type": "text", + "authors": self.extract_authors(page), + "alias": alias, + "tags": list(filter(None, map(self.get_title, page["tagIds"]))), + "summary": summary, + } + ) except Exception as e: logger.error(f"Error getting page {alias}: {e}") return None def get_arbital_page_aliases(self, subspace): headers = 
self.headers.copy() - headers['referer'] = f'https://arbital.com/explore/{subspace}/' + headers["referer"] = f"https://arbital.com/explore/{subspace}/" data = f'{{"pageAlias":"{subspace}"}}' - response = requests.post('https://arbital.com/json/explore/', headers=headers, data=data).json() - return list(response['pages'].keys()) + response = requests.post( + "https://arbital.com/json/explore/", headers=headers, data=data + ).json() + return list(response["pages"].keys()) @staticmethod def _get_published_date(page): - date_published = page.get('editCreatedAt') or page.get('pageCreatedAt') + date_published = page.get("editCreatedAt") or page.get("pageCreatedAt") if date_published: return parse(date_published).astimezone(timezone.utc) return None def get_page(self, alias): headers = self.headers.copy() - headers['referer'] = 'https://arbital.com/' + headers["referer"] = "https://arbital.com/" data = f'{{"pageAlias":"{alias}"}}' - response = requests.post('https://arbital.com/json/primaryPage/', headers=headers, data=data) - return response.json()['pages'][alias] + response = requests.post( + "https://arbital.com/json/primaryPage/", headers=headers, data=data + ) + return response.json()["pages"][alias] def get_title(self, itemId): if title := self.titles_map.get(itemId): @@ -170,7 +181,7 @@ def get_title(self, itemId): logger.error(e) return None - if title := page.get('title'): + if title := page.get("title"): self.titles_map[itemId] = title return title return None @@ -180,6 +191,6 @@ def extract_authors(self, page): This will work faster the more its used, as it only fetches info for authors it hasn't yet seen. """ - authors = {c['userId'] for c in page.get('changeLogs', [])} + authors = {c["userId"] for c in page.get("changeLogs", [])} return list(filter(None, map(self.get_title, authors))) diff --git a/align_data/sources/articles/__init__.py b/align_data/sources/articles/__init__.py index 6775e496..01c5521f 100644 --- a/align_data/sources/articles/__init__.py +++ b/align_data/sources/articles/__init__.py @@ -1,36 +1,41 @@ from align_data.sources.articles.datasets import ( - EbookArticles, DocArticles, HTMLArticles, MarkdownArticles, PDFArticles, XMLArticles + EbookArticles, + DocArticles, + HTMLArticles, + MarkdownArticles, + PDFArticles, + XMLArticles, ) ARTICLES_REGISTRY = [ PDFArticles( - name='pdfs', - spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4', - sheet_id='0' + name="pdfs", + spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4", + sheet_id="0", ), HTMLArticles( - name='html_articles', - spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4', - sheet_id='759210636' + name="html_articles", + spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4", + sheet_id="759210636", ), EbookArticles( - name='ebooks', - spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4', - sheet_id='1800487220' + name="ebooks", + spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4", + sheet_id="1800487220", ), XMLArticles( - name='xmls', - spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4', - sheet_id='823056509' + name="xmls", + spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4", + sheet_id="823056509", ), MarkdownArticles( - name='markdown', - spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4', - sheet_id='1003473759' + name="markdown", + spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4", + sheet_id="1003473759", ), DocArticles( - name='gdocs', - spreadsheet_id='1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4', 
- sheet_id='1293295703' + name="gdocs", + spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4", + sheet_id="1293295703", ), ] diff --git a/align_data/sources/articles/articles.py b/align_data/sources/articles/articles.py index c7f82e65..941de037 100644 --- a/align_data/sources/articles/articles.py +++ b/align_data/sources/articles/articles.py @@ -4,7 +4,14 @@ from tqdm import tqdm import gspread -from align_data.sources.articles.google_cloud import iterate_rows, get_spreadsheet, get_sheet, upload_file, OK, with_retry +from align_data.sources.articles.google_cloud import ( + iterate_rows, + get_spreadsheet, + get_sheet, + upload_file, + OK, + with_retry, +) from align_data.sources.articles.parsers import item_metadata, fetch from align_data.sources.articles.indices import fetch_all from align_data.sources.articles.html import with_retry @@ -15,8 +22,8 @@ # Careful changing these - the sheets assume this ordering -REQUIRED_FIELDS = ['url', 'source_url', 'title', 'source_type', 'date_published'] -OPTIONAL_FIELDS = ['authors', 'summary'] +REQUIRED_FIELDS = ["url", "source_url", "title", "source_type", "date_published"] +OPTIONAL_FIELDS = ["authors", "summary"] def save_pdf(filename, link): @@ -27,47 +34,47 @@ def save_pdf(filename, link): :returns: the google drive id of the resulting pdf file """ res = fetch(link) - if not filename.lower().endswith('.pdf'): - filename += '.pdf' + if not filename.lower().endswith(".pdf"): + filename += ".pdf" return upload_file( filename, bytes_contents=io.BytesIO(res.content), - mimetype=res.headers.get('Content-Type'), - parent_id=PDFS_FOLDER_ID + mimetype=res.headers.get("Content-Type"), + parent_id=PDFS_FOLDER_ID, ) @with_retry(times=3, exceptions=gspread.exceptions.APIError) def process_row(row, sheets): """Check the given `row` and fetch its metadata + optional extra stuff.""" - logger.info('Checking "%s"', row['title']) + logger.info('Checking "%s"', row["title"]) missing = [field for field in REQUIRED_FIELDS if not row.get(field)] if missing: - row.set_status('missing keys: ' + ', '.join(missing)) - logger.error('missing keys: ' + ', '.join(missing)) + row.set_status("missing keys: " + ", ".join(missing)) + logger.error("missing keys: " + ", ".join(missing)) return - source_url = row.get('source_url') + source_url = row.get("source_url") contents = item_metadata(source_url) - if not contents or 'error' in contents: - error = (contents and contents.get('error')) or 'text could not be fetched' + if not contents or "error" in contents: + error = (contents and contents.get("error")) or "text could not be fetched" logger.error(error) row.set_status(error) return - data_source = contents.get('data_source') + data_source = contents.get("data_source") if data_source not in sheets: - error = 'Unhandled data type' + error = "Unhandled data type" logger.error(error) row.set_status(error) return extra_fields = [] - if data_source == 'pdf': - extra_fields = [save_pdf(row['title'], source_url)] + if data_source == "pdf": + extra_fields = [save_pdf(row["title"], source_url)] sheets[data_source].append_row( [row.get(field) for field in REQUIRED_FIELDS + OPTIONAL_FIELDS] + extra_fields @@ -83,19 +90,19 @@ def process_spreadsheets(source_sheet, output_sheets): :param Worksheet source_sheet: the worksheet to be processed - each row should be a separate entry :param Dict[str, Worksheet] output_sheets: a dict of per data type worksheets to be updated """ - logger.info('fetching seen urls') + logger.info("fetching seen urls") seen = { url for sheet in 
output_sheets.values() for record in sheet.get_all_records() - for url in [record.get('url'), record.get('source_url')] + for url in [record.get("url"), record.get("source_url")] if url } for row in tqdm(iterate_rows(source_sheet)): - if not row.get('source_url'): - row['source_url'] = row['url'] - if row.get('source_url') in seen: - title = row.get('title') + if not row.get("source_url"): + row["source_url"] = row["url"] + if row.get("source_url") in seen: + title = row.get("title") logger.info(f'skipping "{title}", as it has already been seen') else: process_row(row, output_sheets) @@ -104,31 +111,48 @@ def process_spreadsheets(source_sheet, output_sheets): def update_new_items(source_spreadsheet, source_sheet, output_spreadsheet): """Go through all unprocessed items from the source worksheet, updating the appropriate metadata in the output one.""" source_sheet = get_sheet(source_spreadsheet, source_sheet) - sheets = {sheet.title: sheet for sheet in get_spreadsheet(output_spreadsheet).worksheets()} + sheets = { + sheet.title: sheet for sheet in get_spreadsheet(output_spreadsheet).worksheets() + } return process_spreadsheets(source_sheet, sheets) def check_new_articles(source_spreadsheet, source_sheet): """Goes through the special indices looking for unseen articles.""" source_sheet = get_sheet(source_spreadsheet, source_sheet) - current = {row.get('title'): row for row in iterate_rows(source_sheet)} - seen_urls = {url for item in current.values() for url in [item.get('url'), item.get('source_url')] if url} + current = {row.get("title"): row for row in iterate_rows(source_sheet)} + seen_urls = { + url + for item in current.values() + for url in [item.get("url"), item.get("source_url")] + if url + } indices_items = fetch_all() missing = [ - item for title, item in indices_items.items() - if title not in current and not {item.get('url'), item.get('source_url')} & seen_urls + item + for title, item in indices_items.items() + if title not in current + and not {item.get("url"), item.get("source_url")} & seen_urls ] if not missing: - logger.info('No new articles found') + logger.info("No new articles found") return 0 - columns = ['status', 'source_url', 'url', 'title', 'date_published', 'authors', 'publication_title', 'source_type'] - res = source_sheet.append_rows([ - [item.get(col) for col in columns] - for item in missing - ]) - updated = res['updates']['updatedRows'] - logger.info('Added %s rows', updated) + columns = [ + "status", + "source_url", + "url", + "title", + "date_published", + "authors", + "publication_title", + "source_type", + ] + res = source_sheet.append_rows( + [[item.get(col) for col in columns] for item in missing] + ) + updated = res["updates"]["updatedRows"] + logger.info("Added %s rows", updated) return updated diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py index a6328f42..2f25a606 100644 --- a/align_data/sources/articles/datasets.py +++ b/align_data/sources/articles/datasets.py @@ -19,7 +19,6 @@ @dataclass class SpreadsheetDataset(AlignmentDataset): - spreadsheet_id: str sheet_id: str done_key = "title" @@ -40,9 +39,15 @@ def is_val(val): @property def items_list(self): - logger.info(f'Fetching https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=CS&gid={self.sheet_id}') - df = pd.read_csv(f'https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}') - return (item for item in df.itertuples() if not pd.isna(self.get_item_key(item))) + logger.info( + 
f"Fetching https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=CS&gid={self.sheet_id}" + ) + df = pd.read_csv( + f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}" + ) + return ( + item for item in df.itertuples() if not pd.isna(self.get_item_key(item)) + ) def get_item_key(self, item): return getattr(item, self.done_key) @@ -55,30 +60,31 @@ def _get_text(item): def extract_authors(item): if not SpreadsheetDataset.maybe(item.authors): return [] - return [author.strip() for author in item.authors.split(',') if author.strip()] + return [author.strip() for author in item.authors.split(",") if author.strip()] def process_entry(self, item): text = self._get_text(item) if not text: - logger.error('Could not get text for %s - skipping for now', item.title) + logger.error("Could not get text for %s - skipping for now", item.title) return None - return self.make_data_entry({ - 'text': markdownify(text).strip(), - 'url': self.maybe(item.url), - 'title': self.maybe(item.title), - 'source': self.name, - 'source_type': self.maybe(item.source_type), - 'source_filetype': self.source_filetype, - 'date_published': self._get_published_date(item.date_published), - 'authors': self.extract_authors(item), - 'summary': self.maybe(item.summary), - }) + return self.make_data_entry( + { + "text": markdownify(text).strip(), + "url": self.maybe(item.url), + "title": self.maybe(item.title), + "source": self.name, + "source_type": self.maybe(item.source_type), + "source_filetype": self.source_filetype, + "date_published": self._get_published_date(item.date_published), + "authors": self.extract_authors(item), + "summary": self.maybe(item.summary), + } + ) class PDFArticles(SpreadsheetDataset): - - source_filetype = 'pdf' + source_filetype = "pdf" COOLDOWN = 1 batch_size = 1 @@ -87,28 +93,26 @@ def setup(self): self.files_path.mkdir(exist_ok=True, parents=True) def _get_text(self, item): - url = f'https://drive.google.com/uc?id={item.file_id}' + url = f"https://drive.google.com/uc?id={item.file_id}" - filename = self.files_path / f'{item.title}.pdf' + filename = self.files_path / f"{item.title}.pdf" if download(output=str(filename), id=item.file_id): return read_pdf(filename) class HTMLArticles(SpreadsheetDataset): - - source_filetype = 'html' + source_filetype = "html" @staticmethod def _get_text(item): - domain = urlparse(item.source_url).netloc.lstrip('www.') + domain = urlparse(item.source_url).netloc.lstrip("www.") if parser := HTML_PARSERS.get(domain): return parser(item.source_url) class EbookArticles(SpreadsheetDataset): - - source_filetype = 'epub' - COOLDOWN = 10 # Add a large cooldown, as google complains a lot + source_filetype = "epub" + COOLDOWN = 10 # Add a large cooldown, as google complains a lot batch_size = 1 def setup(self): @@ -116,44 +120,43 @@ def setup(self): self.files_path.mkdir(exist_ok=True, parents=True) def _get_text(self, item): - file_id = item.source_url.split('/')[-2] - filename = download(output=str(self.files_path / f'{item.title}.epub'), id=file_id) - return convert_file(filename, "plain",'epub', extra_args=['--wrap=none']) + file_id = item.source_url.split("/")[-2] + filename = download( + output=str(self.files_path / f"{item.title}.epub"), id=file_id + ) + return convert_file(filename, "plain", "epub", extra_args=["--wrap=none"]) class XMLArticles(SpreadsheetDataset): - - source_filetype = 'xml' + source_filetype = "xml" def _get_text(self, item): vals = extract_gdrive_contents(item.source_url) - return 
vals['text'] + return vals["text"] class MarkdownArticles(SpreadsheetDataset): - - source_filetype = 'md' + source_filetype = "md" def _get_text(self, item): - file_id = item.source_url.split('/')[-2] + file_id = item.source_url.split("/")[-2] vals = fetch_markdown(file_id) - return vals['text'] + return vals["text"] class DocArticles(SpreadsheetDataset): - - source_filetype = 'docx' + source_filetype = "docx" def setup(self): super().setup() self.files_path.mkdir(exist_ok=True, parents=True) def _get_text(self, item): - pandoc_path = Path('data/raw/pandoc/pandoc/') + pandoc_path = Path("data/raw/pandoc/pandoc/") if pandoc_path.exists(): logger.info("Make sure pandoc is configured correctly.") os.environ.setdefault("PYPANDOC_PANDOC", str(pandoc_path)) - file_id = item.source_url.split('/')[-2] + file_id = item.source_url.split("/")[-2] file_name = fetch_file(file_id) - return convert_file(file_name, "md", format='docx', extra_args=['--wrap=none']) + return convert_file(file_name, "md", format="docx", extra_args=["--wrap=none"]) diff --git a/align_data/sources/articles/google_cloud.py b/align_data/sources/articles/google_cloud.py index 36946e89..d0f8646a 100644 --- a/align_data/sources/articles/google_cloud.py +++ b/align_data/sources/articles/google_cloud.py @@ -13,17 +13,17 @@ SCOPES = [ - 'https://www.googleapis.com/auth/spreadsheets', - 'https://www.googleapis.com/auth/drive' + "https://www.googleapis.com/auth/spreadsheets", + "https://www.googleapis.com/auth/drive", ] -OK = 'ok' -OUTPUT_SPREADSHEET_ID = '1bg-6vL-I82CBRkxvWQs1-Ao0nTvHyfn4yns5MdlbCmY' -sheet_name = 'Sheet1' +OK = "ok" +OUTPUT_SPREADSHEET_ID = "1bg-6vL-I82CBRkxvWQs1-Ao0nTvHyfn4yns5MdlbCmY" +sheet_name = "Sheet1" -def get_credentials(credentials_file='credentials.json'): +def get_credentials(credentials_file="credentials.json"): return Credentials.from_service_account_file(credentials_file, scopes=SCOPES) @@ -53,20 +53,20 @@ def update_value(self, col, value): self.sheet.update_cell(self.row_id, self.columns.index(col) + 1, value) def update_colour(self, col, colour): - col_letter = chr(ord('A') + self.columns.index(col)) - self.sheet.format(f'{col_letter}{self.row_id}', {"backgroundColor": colour}) + col_letter = chr(ord("A") + self.columns.index(col)) + self.sheet.format(f"{col_letter}{self.row_id}", {"backgroundColor": colour}) - def set_status(self, status, status_col='status'): + def set_status(self, status, status_col="status"): if self.get(status_col) == status: # Don't update anything if the status is the same - this saves on gdocs calls return if status == OK: - colour = {'red': 0, 'green': 1, 'blue': 0} - elif status == '': - colour = {'red': 1, 'green': 1, 'blue': 1} + colour = {"red": 0, "green": 1, "blue": 0} + elif status == "": + colour = {"red": 1, "green": 1, "blue": 1} else: - colour = {'red': 1, 'green': 0, 'blue': 0} + colour = {"red": 1, "green": 0, "blue": 0} self.update_value(status_col, status) self.update_colour(status_col, colour) @@ -91,37 +91,41 @@ def upload_file(filename, bytes_contents, mimetype, parent_id=None): """ credentials = get_credentials() - drive_service = build('drive', 'v3', credentials=credentials) + drive_service = build("drive", "v3", credentials=credentials) - file_metadata = { - 'name': filename, - 'parents': parent_id and [parent_id] - } - media = drive_service.files().create( - body=file_metadata, - media_body=MediaIoBaseUpload(bytes_contents, mimetype=mimetype) - ).execute() - return media.get('id') + file_metadata = {"name": filename, "parents": parent_id and [parent_id]} 
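+    # "parents" uses short-circuiting: `parent_id and [parent_id]` evaluates to
+    # [parent_id] when a parent folder id was given, and otherwise to the falsy
+    # parent_id itself (typically None), i.e. no parent folder is set.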
+ media = ( + drive_service.files() + .create( + body=file_metadata, + media_body=MediaIoBaseUpload(bytes_contents, mimetype=mimetype), + ) + .execute() + ) + return media.get("id") def with_retry(times=3): """A decorator that will retry the wrapped function up to `times` times in case of google sheets errors.""" + def wrapper(f): def retrier(*args, **kwargs): for i in range(times): try: return f(*args, **kwargs) except gspread.exceptions.APIError as e: - logger.error(f'{e} - retrying up to {times - i} times') + logger.error(f"{e} - retrying up to {times - i} times") # Do a logarithmic backoff time.sleep((i + 1) ** 2) - raise ValueError(f'Gave up after {times} tries') + raise ValueError(f"Gave up after {times} tries") + return retrier + return wrapper def fetch_file(file_id): - data_path = Path('data/raw/') + data_path = Path("data/raw/") data_path.mkdir(parents=True, exist_ok=True) file_name = data_path / file_id return gdown.download(id=file_id, output=str(file_name), quiet=False) @@ -131,8 +135,8 @@ def fetch_markdown(file_id): try: file_name = fetch_file(file_id) return { - 'text': Path(file_name).read_text(), - 'data_source': 'markdown', + "text": Path(file_name).read_text(), + "data_source": "markdown", } except Exception as e: - return {'error': str(e)} + return {"error": str(e)} diff --git a/align_data/sources/articles/html.py b/align_data/sources/articles/html.py index 152ea8dc..80b5dc8b 100644 --- a/align_data/sources/articles/html.py +++ b/align_data/sources/articles/html.py @@ -10,27 +10,30 @@ DEFAULT_HEADERS = { - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0', + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0", } def with_retry(times=3, exceptions=requests.exceptions.RequestException): """A decorator that will retry the wrapped function up to `times` times in case of google sheets errors.""" + def wrapper(f): def retrier(*args, **kwargs): for i in range(times): try: return f(*args, **kwargs) except exceptions as e: - logger.error(f'{e} - retrying up to {times - i} times') + logger.error(f"{e} - retrying up to {times - i} times") # Do a logarithmic backoff time.sleep((i + 1) ** 2) - raise ValueError(f'Gave up after {times} tries') + raise ValueError(f"Gave up after {times} tries") + return retrier + return wrapper -def fetch(url, method='get', headers=DEFAULT_HEADERS): +def fetch(url, method="get", headers=DEFAULT_HEADERS): """Fetch the given `url`. This function is to have a single place to manage headers etc. @@ -43,7 +46,7 @@ def fetch_element(url: str, selector: str, headers=DEFAULT_HEADERS) -> Union[Tag try: resp = fetch(url, headers=headers) except requests.exceptions.ConnectionError: - logger.error('Could not connect to %s', url) + logger.error("Could not connect to %s", url) return None soup = BeautifulSoup(resp.content, "html.parser") @@ -57,6 +60,7 @@ def element_extractor(selector, remove=[]): :param List[str] remove: An optional list of selectors to be removed from the resulting HTML. Useful for removing footers etc. :returns: A function that expects to get an URL, and which will then return the contents of the selected HTML element as markdown. 
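+
+    Illustrative usage (the selector and url here are made-up examples):
+
+        get_body = element_extractor("div.entry-content", remove=["footer"])
+        markdown = get_body("https://example.com/some-post")  # markdown text, or None if nothing matched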
""" + def getter(url): elem = fetch_element(url, selector) if not elem: diff --git a/align_data/sources/articles/indices.py b/align_data/sources/articles/indices.py index 6eb5761c..0cf2c45a 100644 --- a/align_data/sources/articles/indices.py +++ b/align_data/sources/articles/indices.py @@ -11,7 +11,7 @@ def get_text(tag, selector: str) -> str: if item := tag.select_one(selector): return item.text - return '' + return "" def indice_fetcher(url, main_selector, item_selector, formatter): @@ -19,107 +19,114 @@ def fetcher(): if contents := fetch_element(url, main_selector): return list(filter(None, map(formatter, contents.select(item_selector)))) return [] + return fetcher def reading_what_we_can_items(): - res = fetch('https://readingwhatwecan.com/books.js') + res = fetch("https://readingwhatwecan.com/books.js") items = { item - for section in re.findall(r'\[(.*?)\]', res.text, re.DOTALL) - for item in re.findall(r'Name: "(.*?)",.*?Link: "(.*?)",.*?Author: "(.*?)"', section, re.DOTALL) + for section in re.findall(r"\[(.*?)\]", res.text, re.DOTALL) + for item in re.findall( + r'Name: "(.*?)",.*?Link: "(.*?)",.*?Author: "(.*?)"', section, re.DOTALL + ) } - return [{ - 'title': item[0], - 'url': item[1], - 'authors': item[2] - } for item in items] + return [{"title": item[0], "url": item[1], "authors": item[2]} for item in items] def aisafetysupport(): - contents = fetch_element('https://www.aisafetysupport.org/resources/lots-of-links', 'header + div') - sections = ['Research Maps and Reviews', 'Research Agendas', 'Books, papers, podcasts, videos'] - sections = [s for s in contents.select('section') if get_text(s, 'h2') in sections] + contents = fetch_element( + "https://www.aisafetysupport.org/resources/lots-of-links", "header + div" + ) + sections = [ + "Research Maps and Reviews", + "Research Agendas", + "Books, papers, podcasts, videos", + ] + sections = [s for s in contents.select("section") if get_text(s, "h2") in sections] return [ - {'title': a.text, 'url': a.get('href')} + {"title": a.text, "url": a.get("href")} for section in sections - for a in section.select('a') - if a.text and a.get('href').startswith('http') + for a in section.select("a") + if a.text and a.get("href").startswith("http") ] def format_mlsafety_course(a): - if (a.get('href') or '').startswith('http'): - return {'title': a.text, 'url': a.get('href')} + if (a.get("href") or "").startswith("http"): + return {"title": a.text, "url": a.get("href")} def format_anthropic(post): - if date_published := parse(get_text(post, 'div.post-date')): + if date_published := parse(get_text(post, "div.post-date")): date_published = AlignmentDataset._format_datetime(date_published) - url = post.get('href') + url = post.get("href") - if source_url := fetch_element(url, 'article .post-heading a.btn-primary'): - source_url = source_url.get('href') + if source_url := fetch_element(url, "article .post-heading a.btn-primary"): + source_url = source_url.get("href") return { - 'title': get_text(post, 'div.post-heading'), - 'url': url, - 'source_url': source_url, - 'date_published': date_published, + "title": get_text(post, "div.post-heading"), + "url": url, + "source_url": source_url, + "date_published": date_published, } def format_transformer_circuits(item): - if not item.get('href').startswith('http'): + if not item.get("href").startswith("http"): url = f'https://transformer-circuits.pub/{item.get("href")}' return { - 'title': get_text(item, 'h3'), - 'url': url, - 'source_url': url, + "title": get_text(item, "h3"), + "url": url, + 
"source_url": url, } def format_safe_ai(item): return { - 'title': get_text(item, 'h4'), - 'url': item.find('a').get('href'), - 'source_url': item.find('a').get('href'), - 'authors': get_text(item, 'h4 ~ p') + "title": get_text(item, "h4"), + "url": item.find("a").get("href"), + "source_url": item.find("a").get("href"), + "authors": get_text(item, "h4 ~ p"), } def format_far_ai(item): return { - 'title': get_text(item, '.article-title'), - 'url': f'https://www.safe.ai/research{item.select_one(".article-title a").get("href")}', - 'source_url': item.select_one('div.btn-links a:-soup-contains("PDF")').get('href'), - 'authors': ', '.join(i.text for i in item.select('.article-metadata a')), + "title": get_text(item, ".article-title"), + "url": f'https://www.safe.ai/research{item.select_one(".article-title a").get("href")}', + "source_url": item.select_one('div.btn-links a:-soup-contains("PDF")').get( + "href" + ), + "authors": ", ".join(i.text for i in item.select(".article-metadata a")), } def format_redwoodresearch(item): - url = item.select_one('.list-item-content__button-container a').get('href') - authors = get_text(item, 'em') + url = item.select_one(".list-item-content__button-container a").get("href") + authors = get_text(item, "em") try: - parts = authors.split(', ') + parts = authors.split(", ") date_published = parse(parts[-1]) date_published = AlignmentDataset._format_datetime(date_published) - authors = ', '.join(parts[:-1]) + authors = ", ".join(parts[:-1]) except ParserError: date_published = None return { - 'title': get_text(item, 'h2'), - 'url': url, - 'source_url': url, - 'authors': authors, - 'date_published': date_published, + "title": get_text(item, "h2"), + "url": url, + "source_url": url, + "authors": authors, + "date_published": date_published, } def format_chai_research(item): - author_block = next(item.children).strip().strip('.') - authors = parts = author_block.split('.') + author_block = next(item.children).strip().strip(".") + authors = parts = author_block.split(".") try: int(parts[-1].strip()) date_published = parts[-1].strip() @@ -127,47 +134,47 @@ def format_chai_research(item): except ValueError: date_published = None - url = item.select_one('a').get('href') + url = item.select_one("a").get("href") return { - 'title': get_text(item, 'a'), - 'url': url, - 'source_url': url, - 'authors': ', '.join(authors), - 'date_published': date_published, + "title": get_text(item, "a"), + "url": url, + "source_url": url, + "authors": ", ".join(authors), + "date_published": date_published, } def format_chai_bibliography(item): - return { - 'title': get_text(item, '.bib-entry-title a'), - 'url': item.select_one('.bib-entry-title a').get('href'), - 'authors': item.select_one('.bib-entry-title a').next_sibling.strip(',. ') - } + return { + "title": get_text(item, ".bib-entry-title a"), + "url": item.select_one(".bib-entry-title a").get("href"), + "authors": item.select_one(".bib-entry-title a").next_sibling.strip(",. 
"), + } def format_chai_newsletter(item): - if item.text.strip().startswith('CHAI Newsletter'): + if item.text.strip().startswith("CHAI Newsletter"): return { - 'title': item.text, - 'url': item.get('href'), - 'source_url': item.get('href'), + "title": item.text, + "url": item.get("href"), + "source_url": item.get("href"), } def format_neel_nanda_fav(item): - url = item.find('a').get('href').strip() - if not url.startswith('http'): + url = item.find("a").get("href").strip() + if not url.startswith("http"): return None try: - title = item.find('p').extract().text + title = item.find("p").extract().text except: - title = get_text(item, 'a') + title = get_text(item, "a") return { - 'title': title.replace('\n', ' '), - 'url': url, - 'summary': MarkdownConverter().convert_soup(item).strip() + "title": title.replace("\n", " "), + "url": url, + "summary": MarkdownConverter().convert_soup(item).strip(), } @@ -175,20 +182,70 @@ def fetch_all(): fetchers = [ reading_what_we_can_items, aisafetysupport, - indice_fetcher('https://www.neelnanda.io/mechanistic-interpretability/favourite-papers', 'article', 'div > ul > li', format_neel_nanda_fav), - indice_fetcher('https://course.mlsafety.org/readings/', 'div.main-content', 'a', format_mlsafety_course), - indice_fetcher('https://www.anthropic.com/research', 'div.b-postList', 'a', format_anthropic), - indice_fetcher('https://transformer-circuits.pub/', 'div.toc', 'a', format_transformer_circuits), - indice_fetcher('https://www.safe.ai/research', '#guiding-principles', 'div.card.is-document', format_safe_ai), - indice_fetcher('https://far.ai/publication/', '#container-publications', 'div.media-body', format_far_ai), - indice_fetcher('https://www.redwoodresearch.org/research', 'article', '.list-item', format_redwoodresearch), - indice_fetcher('https://humancompatible.ai/research', 'article', '.publications li', format_chai_research), - indice_fetcher('https://humancompatible.ai/bibliography', '#content', '.bib-entry', format_chai_bibliography), - indice_fetcher('https://humancompatible.ai/newsletter/', 'article', 'a', format_chai_newsletter), + indice_fetcher( + "https://www.neelnanda.io/mechanistic-interpretability/favourite-papers", + "article", + "div > ul > li", + format_neel_nanda_fav, + ), + indice_fetcher( + "https://course.mlsafety.org/readings/", + "div.main-content", + "a", + format_mlsafety_course, + ), + indice_fetcher( + "https://www.anthropic.com/research", + "div.b-postList", + "a", + format_anthropic, + ), + indice_fetcher( + "https://transformer-circuits.pub/", + "div.toc", + "a", + format_transformer_circuits, + ), + indice_fetcher( + "https://www.safe.ai/research", + "#guiding-principles", + "div.card.is-document", + format_safe_ai, + ), + indice_fetcher( + "https://far.ai/publication/", + "#container-publications", + "div.media-body", + format_far_ai, + ), + indice_fetcher( + "https://www.redwoodresearch.org/research", + "article", + ".list-item", + format_redwoodresearch, + ), + indice_fetcher( + "https://humancompatible.ai/research", + "article", + ".publications li", + format_chai_research, + ), + indice_fetcher( + "https://humancompatible.ai/bibliography", + "#content", + ".bib-entry", + format_chai_bibliography, + ), + indice_fetcher( + "https://humancompatible.ai/newsletter/", + "article", + "a", + format_chai_newsletter, + ), ] articles = defaultdict(dict) for func in tqdm(fetchers): for item in func(): - articles[item['title']].update(item) + articles[item["title"]].update(item) return articles diff --git 
a/align_data/sources/articles/parsers.py b/align_data/sources/articles/parsers.py index 3e8c60bc..f8708b5f 100644 --- a/align_data/sources/articles/parsers.py +++ b/align_data/sources/articles/parsers.py @@ -5,7 +5,13 @@ import grobid_tei_xml import regex as re from align_data.sources.articles.html import element_extractor, fetch, fetch_element -from align_data.sources.articles.pdf import doi_getter, fetch_pdf, get_pdf_from_page, get_arxiv_pdf, parse_vanity +from align_data.sources.articles.pdf import ( + doi_getter, + fetch_pdf, + get_pdf_from_page, + get_arxiv_pdf, + parse_vanity, +) from align_data.sources.articles.google_cloud import fetch_markdown from markdownify import MarkdownConverter from bs4 import BeautifulSoup @@ -15,12 +21,14 @@ def google_doc(url: str) -> str: """Fetch the contents of the given gdoc url as markdown.""" - res = re.search(r'https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/', url) + res = re.search(r"https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/", url) if not res: return None doc_id = res.group(1) - body = fetch_element(f'https://docs.google.com/document/d/{doc_id}/export?format=html', 'body') + body = fetch_element( + f"https://docs.google.com/document/d/{doc_id}/export?format=html", "body" + ) if body: return MarkdownConverter().convert_soup(body).strip() @@ -28,12 +36,12 @@ def google_doc(url: str) -> str: def medium_blog(url): """Return the contents of the medium article at the given URL as markdown.""" # Medium does some magic redirects if it detects that the request is from firefox - article = fetch_element(url, 'article', headers=None) + article = fetch_element(url, "article", headers=None) if not article: return None # remove the header - if title := article.find('h1'): + if title := article.find("h1"): title.parent.extract() return MarkdownConverter().convert_soup(article).strip() @@ -41,12 +49,15 @@ def medium_blog(url): def parse_grobid(contents): doc_dict = grobid_tei_xml.parse_document_xml(contents).to_dict() - authors = [xx["full_name"].strip(' !') for xx in doc_dict.get("header", {}).get("authors", [])] + authors = [ + xx["full_name"].strip(" !") + for xx in doc_dict.get("header", {}).get("authors", []) + ] - if not doc_dict.get('body'): + if not doc_dict.get("body"): return { - 'error': 'No contents in XML file', - 'data_source': 'xml', + "error": "No contents in XML file", + "data_source": "xml", } return { @@ -59,218 +70,268 @@ def parse_grobid(contents): def get_content_type(res): - header = res.headers.get('Content-Type') or '' - parts = [c_type.strip().lower() for c_type in header.split(';')] + header = res.headers.get("Content-Type") or "" + parts = [c_type.strip().lower() for c_type in header.split(";")] return set(filter(None, parts)) def extract_gdrive_contents(link): - file_id = link.split('/')[-2] - url = f'https://drive.google.com/uc?id={file_id}' - res = fetch(url, 'head') + file_id = link.split("/")[-2] + url = f"https://drive.google.com/uc?id={file_id}" + res = fetch(url, "head") if res.status_code == 403: - logger.error('Could not fetch the file at %s - 403 returned', link) - return {'error': 'Could not read file from google drive - forbidden'} + logger.error("Could not fetch the file at %s - 403 returned", link) + return {"error": "Could not read file from google drive - forbidden"} if res.status_code >= 400: - logger.error('Could not fetch the file at %s - are you sure that link is correct?', link) - return {'error': 'Could not read file from google drive'} + logger.error( + "Could not fetch the file at %s - are 
you sure that link is correct?", link + ) + return {"error": "Could not read file from google drive"} result = { - 'source_url': link, - 'downloaded_from': 'google drive', + "source_url": link, + "downloaded_from": "google drive", } content_type = get_content_type(res) if not content_type: - result['error'] = 'no content type' - elif content_type & {'application/octet-stream', 'application/pdf'}: + result["error"] = "no content type" + elif content_type & {"application/octet-stream", "application/pdf"}: result.update(fetch_pdf(url)) - elif content_type & {'text/markdown'}: + elif content_type & {"text/markdown"}: result.update(fetch_markdown(file_id)) - elif content_type & {'application/epub+zip', 'application/epub'}: - result['data_source'] = 'ebook' - elif content_type & {'text/html'}: + elif content_type & {"application/epub+zip", "application/epub"}: + result["data_source"] = "ebook" + elif content_type & {"text/html"}: res = fetch(url) - if 'Google Drive - Virus scan warning' in res.text: - element_extractor('form') + if "Google Drive - Virus scan warning" in res.text: + element_extractor("form") soup = BeautifulSoup(res.content, "html.parser") - res = fetch(soup.select_one('form').get('action')) + res = fetch(soup.select_one("form").get("action")) content_type = get_content_type(res) - if content_type & {'text/xml'}: + if content_type & {"text/xml"}: result.update(parse_grobid(res.content)) - elif content_type & {'text/html'}: + elif content_type & {"text/html"}: soup = BeautifulSoup(res.content, "html.parser") - result.update({ - 'text': MarkdownConverter().convert_soup(soup.select_one('body')).strip(), - 'data_source': 'html', - }) + result.update( + { + "text": MarkdownConverter() + .convert_soup(soup.select_one("body")) + .strip(), + "data_source": "html", + } + ) else: - result['error'] = f'unknown content type: {content_type}' + result["error"] = f"unknown content type: {content_type}" else: - result['error'] = f'unknown content type: {content_type}' + result["error"] = f"unknown content type: {content_type}" return result def error(error_msg): """Returns a url handler function that just logs the provided `error` string.""" + def func(url): if error_msg: logger.error(error_msg) return error_msg + return func def multistrategy(*funcs): """Merges multiple getter functions, returning the result of the first function call to succeed.""" + def getter(url): for func in funcs: res = func(url) - if res and 'error' not in res: + if res and "error" not in res: return res + return getter UNIMPLEMENTED_PARSERS = { # Unhandled items that will be caught later. 
Though it would be good for them also to be done properly - 'oxford.universitypressscholarship.com': error(''), - + "oxford.universitypressscholarship.com": error(""), # Paywalled journal - 'linkinghub.elsevier.com': error('Elsevier is a known parasite - no point in looking to them for content'), - 'link.springer.com': error('This article looks paywalled'), - 'www.dl.begellhouse.com': error('This article is paywalled'), - + "linkinghub.elsevier.com": error( + "Elsevier is a known parasite - no point in looking to them for content" + ), + "link.springer.com": error("This article looks paywalled"), + "www.dl.begellhouse.com": error("This article is paywalled"), # To be implemented - 'goodreads.com': error('Ebooks are not yet handled'), - 'judiciary.senate.gov': error(''), - 'taylorfrancis.com': error('Ebooks are not yet handled'), - 'YouTube.com': error('Youtube videos are not yet handled'), - 'researchgate.net': error('Researchgate makes it hard to auto download pdf - please provide a DOI or a different url to the contents'), - 'repository.cam.ac.uk': error(''), + "goodreads.com": error("Ebooks are not yet handled"), + "judiciary.senate.gov": error(""), + "taylorfrancis.com": error("Ebooks are not yet handled"), + "YouTube.com": error("Youtube videos are not yet handled"), + "researchgate.net": error( + "Researchgate makes it hard to auto download pdf - please provide a DOI or a different url to the contents" + ), + "repository.cam.ac.uk": error(""), } HTML_PARSERS = { - 'academic.oup.com': element_extractor('#ContentTab'), - 'ai.googleblog.com': element_extractor('div.post-body.entry-content'), - 'arxiv-vanity.com': parse_vanity, - 'ar5iv.labs.arxiv.org': parse_vanity, - 'bair.berkeley.edu': element_extractor('article'), - 'mediangroup.org': element_extractor('div.entry-content'), - 'www.alexirpan.com': element_extractor('article'), - 'www.incompleteideas.net': element_extractor('body'), - 'ai-alignment.com': medium_blog, - 'aisrp.org': element_extractor('article'), - 'bounded-regret.ghost.io': element_extractor('div.post-content'), - 'carnegieendowment.org': element_extractor('div.article-body', remove=['.no-print', '.related-pubs']), - 'casparoesterheld.com': element_extractor('.entry-content', remove=['div.sharedaddy']), - 'cullenokeefe.com': element_extractor('div.sqs-block-content'), - 'deepmindsafetyresearch.medium.com': medium_blog, - 'docs.google.com': google_doc, - 'docs.microsoft.com': element_extractor('div.content'), - 'digichina.stanford.edu': element_extractor('div.h_editor-content'), - 'en.wikipedia.org': element_extractor('main.mw-body'), - 'eng.uber.com': element_extractor('div.article-body'), - 'futureoflife.org': multistrategy( - element_extractor('div.body-content'), - element_extractor('#main-content'), + "academic.oup.com": element_extractor("#ContentTab"), + "ai.googleblog.com": element_extractor("div.post-body.entry-content"), + "arxiv-vanity.com": parse_vanity, + "ar5iv.labs.arxiv.org": parse_vanity, + "bair.berkeley.edu": element_extractor("article"), + "mediangroup.org": element_extractor("div.entry-content"), + "www.alexirpan.com": element_extractor("article"), + "www.incompleteideas.net": element_extractor("body"), + "ai-alignment.com": medium_blog, + "aisrp.org": element_extractor("article"), + "bounded-regret.ghost.io": element_extractor("div.post-content"), + "carnegieendowment.org": element_extractor( + "div.article-body", remove=[".no-print", ".related-pubs"] + ), + "casparoesterheld.com": element_extractor( + ".entry-content", remove=["div.sharedaddy"] + 
), + "cullenokeefe.com": element_extractor("div.sqs-block-content"), + "deepmindsafetyresearch.medium.com": medium_blog, + "docs.google.com": google_doc, + "docs.microsoft.com": element_extractor("div.content"), + "digichina.stanford.edu": element_extractor("div.h_editor-content"), + "en.wikipedia.org": element_extractor("main.mw-body"), + "eng.uber.com": element_extractor("div.article-body"), + "futureoflife.org": multistrategy( + element_extractor("div.body-content"), + element_extractor("#main-content"), ), - 'gcrinstitute.org': element_extractor('div.blog-content'), - 'jbkjr.me': element_extractor('section.page__content'), - 'link.springer.com': element_extractor('article.c-article-body'), - 'longtermrisk.org': element_extractor('div.entry-content'), - 'lukemuehlhauser.com': element_extractor('div.entry-content'), - 'medium.com': medium_blog, - 'openai.com': element_extractor('#content'), - 'ought.org': element_extractor('div.BlogPostBodyContainer'), - 'sideways-view.com': element_extractor('article', remove=['header']), - 'slatestarcodex.com': element_extractor('div.pjgm-postcontent'), - 'techpolicy.press': element_extractor('div.post-content', remove=['div.before_content', '.sabox-guest-authors-container', '.jp-relatedposts']), - 'theconversation.com': element_extractor('div.content-body'), - 'thegradient.pub': element_extractor('div.c-content'), - 'towardsdatascience.com': medium_blog, - 'unstableontology.com': element_extractor('.entry-content', remove=['div.sharedaddy']), - 'waitbutwhy.com': element_extractor('article', remove=['.entry-header']), - 'weightagnostic.github.io': element_extractor('dt-article', remove=['#authors_section', 'dt-byline']), - 'cnas.org': element_extractor('#mainbar-toc'), - 'econlib.org': element_extractor('div.post-content'), - 'humanityplus.org': element_extractor('div.content'), - 'gleech.org': element_extractor('article.post-content', remove=['center', 'div.accordion']), - 'ibm.com': element_extractor('div:has(> p)'), # IBM's HTML is really ugly... 
- 'microsoft.com': element_extractor('div.content-container'), - 'mdpi.com': element_extractor( - 'article', remove=[ - '.article-icons', '.title', '.art-authors', '.art-affiliations', '.bib-identity', - '.pubhistory', '.belongsTo', '.highlight-box1', '.additional-content' - ] + "gcrinstitute.org": element_extractor("div.blog-content"), + "jbkjr.me": element_extractor("section.page__content"), + "link.springer.com": element_extractor("article.c-article-body"), + "longtermrisk.org": element_extractor("div.entry-content"), + "lukemuehlhauser.com": element_extractor("div.entry-content"), + "medium.com": medium_blog, + "openai.com": element_extractor("#content"), + "ought.org": element_extractor("div.BlogPostBodyContainer"), + "sideways-view.com": element_extractor("article", remove=["header"]), + "slatestarcodex.com": element_extractor("div.pjgm-postcontent"), + "techpolicy.press": element_extractor( + "div.post-content", + remove=[ + "div.before_content", + ".sabox-guest-authors-container", + ".jp-relatedposts", + ], + ), + "theconversation.com": element_extractor("div.content-body"), + "thegradient.pub": element_extractor("div.c-content"), + "towardsdatascience.com": medium_blog, + "unstableontology.com": element_extractor( + ".entry-content", remove=["div.sharedaddy"] + ), + "waitbutwhy.com": element_extractor("article", remove=[".entry-header"]), + "weightagnostic.github.io": element_extractor( + "dt-article", remove=["#authors_section", "dt-byline"] + ), + "cnas.org": element_extractor("#mainbar-toc"), + "econlib.org": element_extractor("div.post-content"), + "humanityplus.org": element_extractor("div.content"), + "gleech.org": element_extractor( + "article.post-content", remove=["center", "div.accordion"] + ), + "ibm.com": element_extractor("div:has(> p)"), # IBM's HTML is really ugly... 
+ "microsoft.com": element_extractor("div.content-container"), + "mdpi.com": element_extractor( + "article", + remove=[ + ".article-icons", + ".title", + ".art-authors", + ".art-affiliations", + ".bib-identity", + ".pubhistory", + ".belongsTo", + ".highlight-box1", + ".additional-content", + ], + ), + "nature.com": element_extractor( + "article", remove=["header", "#rights link-section", "#article-info-section"] ), - 'nature.com': element_extractor('article', remove=['header', '#rights link-section', '#article-info-section']), - 'ncbi.nlm.nih.gov': element_extractor('div.article'), - 'openphilanthropy.org': element_extractor('div.pagenav-content'), - 'safe.ai': element_extractor('#open-letter'), - 'sciencedirect.com': element_extractor( - 'article', + "ncbi.nlm.nih.gov": element_extractor("div.article"), + "openphilanthropy.org": element_extractor("div.pagenav-content"), + "safe.ai": element_extractor("#open-letter"), + "sciencedirect.com": element_extractor( + "article", remove=[ - '#section-cited-by', '.Copyright', '.issue-navigation', '.ReferencedArticles', - '.LicenseInfo', '.ArticleIdentifierLinks', '.Banner', '.screen-reader-main-title', '.Publication' - ] + "#section-cited-by", + ".Copyright", + ".issue-navigation", + ".ReferencedArticles", + ".LicenseInfo", + ".ArticleIdentifierLinks", + ".Banner", + ".screen-reader-main-title", + ".Publication", + ], ), - 'transformer-circuits.pub': error('not handled yet - same codebase as distill'), - 'vox.com': element_extractor('did.c-entry-content', remove=['c-article-footer']), - 'weforum.org': element_extractor('div.wef-0'), - 'www6.inrae.fr': element_extractor('div.ArticleContent'), - 'aleph.se': element_extractor('body'), - 'yoshuabengio.org': element_extractor('div.post-content'), + "transformer-circuits.pub": error("not handled yet - same codebase as distill"), + "vox.com": element_extractor("did.c-entry-content", remove=["c-article-footer"]), + "weforum.org": element_extractor("div.wef-0"), + "www6.inrae.fr": element_extractor("div.ArticleContent"), + "aleph.se": element_extractor("body"), + "yoshuabengio.org": element_extractor("div.post-content"), } PDF_PARSERS = { # Domain sepecific handlers - 'apcz.umk.pl': get_pdf_from_page('.galleys_links a.pdf', 'a.download'), - 'arxiv.org': get_arxiv_pdf, - 'academic.oup.com': get_pdf_from_page('a.article-pdfLink'), - 'cset.georgetown.edu': get_pdf_from_page('a:-soup-contains("Download Full")'), - 'drive.google.com': extract_gdrive_contents, - 'doi.org': doi_getter, - 'dl.acm.org': fetch_pdf, - 'dspace.mit.edu': get_pdf_from_page('a.btn-primary.download-button'), - 'globalprioritiesinstitute.org': get_pdf_from_page('a:-soup-contains("PDF")'), - 'link.springer.com': multistrategy( - get_pdf_from_page('div.c-pdf-download a'), + "apcz.umk.pl": get_pdf_from_page(".galleys_links a.pdf", "a.download"), + "arxiv.org": get_arxiv_pdf, + "academic.oup.com": get_pdf_from_page("a.article-pdfLink"), + "cset.georgetown.edu": get_pdf_from_page('a:-soup-contains("Download Full")'), + "drive.google.com": extract_gdrive_contents, + "doi.org": doi_getter, + "dl.acm.org": fetch_pdf, + "dspace.mit.edu": get_pdf_from_page("a.btn-primary.download-button"), + "globalprioritiesinstitute.org": get_pdf_from_page('a:-soup-contains("PDF")'), + "link.springer.com": multistrategy( + get_pdf_from_page("div.c-pdf-download a"), doi_getter, ), - 'openaccess.thecvf.com': get_pdf_from_page('a:-soup-contains("pdf")'), - 'openreview.net': get_pdf_from_page('a.note_content_pdf'), - 'ora.ox.ac.uk': fetch_pdf, - 'papers.nips.cc': 
get_pdf_from_page('a:-soup-contains("Paper")'), - 'papers.ssrn.com': get_pdf_from_page('.abstract-buttons a.button-link:-soup-contains("Download")'), - 'par.nsf.gov': get_pdf_from_page('a:-soup-contains("Accepted Manuscript")'), - 'proceedings.neurips.cc': get_pdf_from_page('a:-soup-contains("Paper")'), - 'psyarxiv.com': lambda url: fetch_pdf(url.rstrip('/') + '/download'), - 'rowanzellers.com': get_pdf_from_page('main a:-soup-contains("Paper")'), - 'governance.ai': get_pdf_from_page('a.read-paper-button:not([href="#"])'), - 'ijcai.org': get_pdf_from_page('a.btn-download:-soup-contains("PDF")'), - 'jair.org': get_pdf_from_page('div.download a.pdf', 'a.download'), - 'jstor.org': doi_getter, - 'ri.cmu.edu': get_pdf_from_page('a.pub-link'), - 'risksciences.ucla.edu': get_pdf_from_page('a:-soup-contains("Download")'), - 'ssrn.com': get_pdf_from_page('.abstract-buttons a.button-link:-soup-contains("Download")'), - 'yjolt.org': get_pdf_from_page('span.file a'), + "openaccess.thecvf.com": get_pdf_from_page('a:-soup-contains("pdf")'), + "openreview.net": get_pdf_from_page("a.note_content_pdf"), + "ora.ox.ac.uk": fetch_pdf, + "papers.nips.cc": get_pdf_from_page('a:-soup-contains("Paper")'), + "papers.ssrn.com": get_pdf_from_page( + '.abstract-buttons a.button-link:-soup-contains("Download")' + ), + "par.nsf.gov": get_pdf_from_page('a:-soup-contains("Accepted Manuscript")'), + "proceedings.neurips.cc": get_pdf_from_page('a:-soup-contains("Paper")'), + "psyarxiv.com": lambda url: fetch_pdf(url.rstrip("/") + "/download"), + "rowanzellers.com": get_pdf_from_page('main a:-soup-contains("Paper")'), + "governance.ai": get_pdf_from_page('a.read-paper-button:not([href="#"])'), + "ijcai.org": get_pdf_from_page('a.btn-download:-soup-contains("PDF")'), + "jair.org": get_pdf_from_page("div.download a.pdf", "a.download"), + "jstor.org": doi_getter, + "ri.cmu.edu": get_pdf_from_page("a.pub-link"), + "risksciences.ucla.edu": get_pdf_from_page('a:-soup-contains("Download")'), + "ssrn.com": get_pdf_from_page( + '.abstract-buttons a.button-link:-soup-contains("Download")' + ), + "yjolt.org": get_pdf_from_page("span.file a"), } def item_metadata(url) -> Dict[str, str]: - domain = urlparse(url).netloc.lstrip('www.') - res = fetch(url, 'head') - content_type = {item.strip() for item in res.headers.get('Content-Type').split(';')} + domain = urlparse(url).netloc.lstrip("www.") + res = fetch(url, "head") + content_type = {item.strip() for item in res.headers.get("Content-Type").split(";")} - if content_type & {'text/html', 'text/xml'}: + if content_type & {"text/html", "text/xml"}: # If the url points to a html webpage, then it either contains the text as html, or # there is a link to a pdf on it if parser := HTML_PARSERS.get(domain): if res := parser(url): # Proper contents were found on the page, so use them - return {'source_url': url, 'data_source': 'html'} + return {"source_url": url, "data_source": "html"} if parser := PDF_PARSERS.get(domain): if res := parser(url): @@ -278,17 +339,19 @@ def item_metadata(url) -> Dict[str, str]: return res if parser := UNIMPLEMENTED_PARSERS.get(domain): - return {'error': parser(url)} - - if domain not in (HTML_PARSERS.keys() | PDF_PARSERS.keys() | UNIMPLEMENTED_PARSERS.keys()): - return {'error': 'No domain handler defined'} - return {'error': 'could not parse url'} - elif content_type & {'application/octet-stream', 'application/pdf'}: + return {"error": parser(url)} + + if domain not in ( + HTML_PARSERS.keys() | PDF_PARSERS.keys() | UNIMPLEMENTED_PARSERS.keys() + ): + return 
{"error": "No domain handler defined"} + return {"error": "could not parse url"} + elif content_type & {"application/octet-stream", "application/pdf"}: # this looks like it could be a pdf - try to download it as one return fetch_pdf(url) - elif content_type & {'application/epub+zip', 'application/epub'}: + elif content_type & {"application/epub+zip", "application/epub"}: # it looks like an ebook. Assume it's fine. # TODO: validate that the ebook is readable - return {'source_url': url, 'data_source': 'ebook'} + return {"source_url": url, "data_source": "ebook"} else: - return {'error': f'Unhandled content type: {content_type}'} + return {"error": f"Unhandled content type: {content_type}"} diff --git a/align_data/sources/articles/pdf.py b/align_data/sources/articles/pdf.py index f6a020ce..2120dc56 100644 --- a/align_data/sources/articles/pdf.py +++ b/align_data/sources/articles/pdf.py @@ -22,21 +22,21 @@ def sci_hub_pdf(identifier): large file containing multiple articles, e.g. a whole journal or book, in which case this function will ignore the result. """ - elem = fetch_element(f'https://sci-hub.st/{identifier}', 'embed') + elem = fetch_element(f"https://sci-hub.st/{identifier}", "embed") if not elem: return None - src = elem.get('src').strip() - if src.startswith('//'): - src = 'https:' + src - elif src.startswith('/'): - src = f'https://sci-hub.st/{src}' + src = elem.get("src").strip() + if src.startswith("//"): + src = "https:" + src + elif src.startswith("/"): + src = f"https://sci-hub.st/{src}" return src def read_pdf(filename): try: pdf_reader = PdfReader(filename) - return '\n'.join(page.extract_text() for page in pdf_reader.pages) + return "\n".join(page.extract_text() for page in pdf_reader.pages) except PdfReadError as e: logger.error(e) return None @@ -50,36 +50,45 @@ def fetch_pdf(link): :returns: the contents of the pdf file as markdown.""" res = fetch(link) if res.status_code >= 400: - logger.error('Could not fetch the pdf file at %s - are you sure that link is correct?', link) + logger.error( + "Could not fetch the pdf file at %s - are you sure that link is correct?", + link, + ) - content_type = {c_type.strip().lower() for c_type in res.headers.get('Content-Type').split(';')} - if not content_type & {'application/octet-stream', 'application/pdf'}: + content_type = { + c_type.strip().lower() for c_type in res.headers.get("Content-Type").split(";") + } + if not content_type & {"application/octet-stream", "application/pdf"}: return { - 'error': f'Wrong content type retrieved: {content_type} - {link}', - 'contents': res.content, + "error": f"Wrong content type retrieved: {content_type} - {link}", + "contents": res.content, } try: pdf_reader = PdfReader(io.BytesIO(res.content)) return { - 'source_url': link, - 'text': '\n'.join(page.extract_text() for page in pdf_reader.pages), - 'data_source': 'pdf', + "source_url": link, + "text": "\n".join(page.extract_text() for page in pdf_reader.pages), + "data_source": "pdf", } except PdfReadError as e: - logger.error('Could not read PDF file: %s', e) - return {'error': str(e)} + logger.error("Could not read PDF file: %s", e) + return {"error": str(e)} filenames = [ - i.strip().split('=')[1] - for i in res.headers.get('Content-Disposition', '').split(';') - if 'filename' in i + i.strip().split("=")[1] + for i in res.headers.get("Content-Disposition", "").split(";") + if "filename" in i ] - if filenames and 'pdf' not in filenames[0].lower(): - logger.error('Are you sure %s points to a pdf file? 
The response says the file should be called %s', link, filenames[0]) - error = f'Probably bad file type: {filenames[0]} - {link}' + if filenames and "pdf" not in filenames[0].lower(): + logger.error( + "Are you sure %s points to a pdf file? The response says the file should be called %s", + link, + filenames[0], + ) + error = f"Probably bad file type: {filenames[0]} - {link}" - return {'error': error} + return {"error": error} def get_arxiv_link(doi): @@ -88,14 +97,16 @@ def get_arxiv_link(doi): if res.status_code != 200: return None - vals = [i for i in response.json().get('values') if i.get('type', '').upper() == 'URL'] + vals = [ + i for i in response.json().get("values") if i.get("type", "").upper() == "URL" + ] if not vals: return None return vals[0]["data"]["value"].replace("/abs/", "/pdf/") + ".pdf" def get_arxiv_pdf(link): - return fetch_pdf(link.replace('/abs/', '/pdf/')) + return fetch_pdf(link.replace("/abs/", "/pdf/")) def get_doi(doi): @@ -104,23 +115,23 @@ def get_doi(doi): This will look for it in sci-hub and arxiv (if applicable), as those are likely the most comprehensive sources of pdfs. """ - if 'arXiv' in doi: + if "arXiv" in doi: link = get_arxiv_link(doi) - pdf = (link and fetch_pdf(link)) - if pdf and 'text' in pdf: - pdf['downloaded_from'] = 'arxiv' + pdf = link and fetch_pdf(link) + if pdf and "text" in pdf: + pdf["downloaded_from"] = "arxiv" return pdf if link := sci_hub_pdf(doi): if pdf := fetch_pdf(link): - pdf['downloaded_from'] = 'scihub' + pdf["downloaded_from"] = "scihub" return pdf - return {'error': 'Could not find pdf of article by DOI'} + return {"error": "Could not find pdf of article by DOI"} def doi_getter(url): """Extract the DOI from the given `url` and fetch the contents of its article.""" - return get_doi(urlparse(url).path.lstrip('/')) + return get_doi(urlparse(url).path.lstrip("/")) def get_pdf_from_page(*link_selectors): @@ -133,34 +144,38 @@ def get_pdf_from_page(*link_selectors): :param List[str] link_selectors: CSS selector used to find the final download link :returns: the contents of the pdf file as a string """ + def getter(url): link = url for selector in link_selectors: elem = fetch_element(link, selector) if not elem: - return {'error': f'Could not find pdf download link for {link} using \'{selector}\''} + return { + "error": f"Could not find pdf download link for {link} using '{selector}'" + } - link = elem.get('href') - if not link.startswith('http') or not link.startswith('//'): + link = elem.get("href") + if not link.startswith("http") or not link.startswith("//"): link = urljoin(url, link) # Some pages keep link to google drive previews of pdf files, which need to be # mangled to get the URL of the actual pdf file - if 'drive.google.com' in link and '/view' in link: + if "drive.google.com" in link and "/view" in link: return extract_gdrive_contents(link) if pdf := fetch_pdf(link): return pdf - return {'error': f'Could not fetch pdf from {link}'} + return {"error": f"Could not fetch pdf from {link}"} + return getter def parse_vanity(url): - contents = fetch_element(url, 'article') + contents = fetch_element(url, "article") if not contents: return None - if title := contents.select_one('h1.ltx_title'): + if title := contents.select_one("h1.ltx_title"): title = title.text def get_first_child(item): @@ -170,24 +185,28 @@ def get_first_child(item): if not isinstance(child, str): child = child.text - return child.split(',') + return child.split(",") authors = [ - a.strip() for item in contents.select('div.ltx_authors .ltx_personname') 
for a in get_first_child(item) + a.strip() + for item in contents.select("div.ltx_authors .ltx_personname") + for a in get_first_child(item) ] - if date_published := contents.select_one('div.ltx_dates'): - date_published = date_published.text.strip('()') + if date_published := contents.select_one("div.ltx_dates"): + date_published = date_published.text.strip("()") - text = '\n\n'.join([ - MarkdownConverter().convert_soup(elem).strip() - for elem in contents.select('section.ltx_section') - ]) + text = "\n\n".join( + [ + MarkdownConverter().convert_soup(elem).strip() + for elem in contents.select("section.ltx_section") + ] + ) return { - 'title': title, - 'authors': authors, - 'text': text, - 'date_published': date_published, - 'data_source': 'html', + "title": title, + "authors": authors, + "text": text, + "date_published": date_published, + "data_source": "html", } diff --git a/align_data/sources/arxiv_papers/__init__.py b/align_data/sources/arxiv_papers/__init__.py index f9bc2080..29258480 100644 --- a/align_data/sources/arxiv_papers/__init__.py +++ b/align_data/sources/arxiv_papers/__init__.py @@ -3,7 +3,7 @@ ARXIV_REGISTRY = [ ArxivPapers( name="arxiv", - spreadsheet_id='1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI', - sheet_id='655836697' + spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI", + sheet_id="655836697", ) ] diff --git a/align_data/sources/arxiv_papers/arxiv_papers.py b/align_data/sources/arxiv_papers/arxiv_papers.py index ae9b7cb9..1a61ecc6 100644 --- a/align_data/sources/arxiv_papers/arxiv_papers.py +++ b/align_data/sources/arxiv_papers/arxiv_papers.py @@ -11,7 +11,7 @@ @dataclass class ArxivPapers(SpreadsheetDataset): - summary_key: str = 'summary' + summary_key: str = "summary" COOLDOWN: int = 1 done_key = "url" batch_size = 1 @@ -28,45 +28,52 @@ def _get_arxiv_metadata(self, paper_id) -> arxiv.Result: return None def get_id(self, item): - if res := re.search(r'https://arxiv.org/abs/(.*?)/?$', item.url): + if res := re.search(r"https://arxiv.org/abs/(.*?)/?$", item.url): return res.group(1) def get_contents(self, item) -> dict: paper_id = self.get_id(item) - for link in [f"https://www.arxiv-vanity.com/papers/{paper_id}", f"https://ar5iv.org/abs/{paper_id}"]: + for link in [ + f"https://www.arxiv-vanity.com/papers/{paper_id}", + f"https://ar5iv.org/abs/{paper_id}", + ]: if contents := parse_vanity(link): return contents - return fetch_pdf(f'https://arxiv.org/pdf/{paper_id}.pdf') + return fetch_pdf(f"https://arxiv.org/pdf/{paper_id}.pdf") def process_entry(self, item) -> None: logger.info(f"Processing {item.title}") paper = self.get_contents(item) - if not paper or not paper.get('text'): + if not paper or not paper.get("text"): return None metadata = self._get_arxiv_metadata(self.get_id(item)) if self.is_val(item.authors) and item.authors.strip(): - authors = item.authors.split(',') + authors = item.authors.split(",") elif metadata and metadata.authors: authors = metadata.authors else: - authors = paper.get('authors') or [] + authors = paper.get("authors") or [] authors = [str(a).strip() for a in authors] - return self.make_data_entry({ - "url": self.get_item_key(item), - "source": self.name, - "source_type": paper['data_source'], - "title": self.is_val(item.title) or paper.get('title'), - "authors": authors, - "date_published": self._get_published_date(self.is_val(item.date_published) or paper.get('date_published')), - "data_last_modified": str(metadata.updated), - "summary": metadata.summary.replace("\n", " "), - "author_comment": metadata.comment, - 
"journal_ref": metadata.journal_ref, - "doi": metadata.doi, - "primary_category": metadata.primary_category, - "categories": metadata.categories, - "text": paper['text'], - }) + return self.make_data_entry( + { + "url": self.get_item_key(item), + "source": self.name, + "source_type": paper["data_source"], + "title": self.is_val(item.title) or paper.get("title"), + "authors": authors, + "date_published": self._get_published_date( + self.is_val(item.date_published) or paper.get("date_published") + ), + "data_last_modified": str(metadata.updated), + "summary": metadata.summary.replace("\n", " "), + "author_comment": metadata.comment, + "journal_ref": metadata.journal_ref, + "doi": metadata.doi, + "primary_category": metadata.primary_category, + "categories": metadata.categories, + "text": paper["text"], + } + ) diff --git a/align_data/sources/blogs/__init__.py b/align_data/sources/blogs/__init__.py index 7021c994..ed55dc81 100644 --- a/align_data/sources/blogs/__init__.py +++ b/align_data/sources/blogs/__init__.py @@ -2,7 +2,12 @@ from align_data.sources.blogs.medium_blog import MediumBlog from align_data.sources.blogs.gwern_blog import GwernBlog from align_data.sources.blogs.blogs import ( - ColdTakes, GenerativeInk, CaradoMoe, EleutherAI, OpenAIResearch, DeepMindTechnicalBlog + ColdTakes, + GenerativeInk, + CaradoMoe, + EleutherAI, + OpenAIResearch, + DeepMindTechnicalBlog, ) from align_data.sources.blogs.substack_blog import SubstackBlog @@ -14,34 +19,43 @@ WordpressBlog(name="jsteinhardt_blog", url="https://jsteinhardt.wordpress.com"), WordpressBlog(name="vkrakovna_blog", url="https://vkrakovna.wordpress.com"), WordpressBlog(name="yudkowsky_blog", url="https://yudkowsky.net"), - MediumBlog(name="deepmind_blog", url="https://deepmindsafetyresearch.medium.com/", authors=["DeepMind Safety Research"]), - GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]), + MediumBlog( + name="deepmind_blog", + url="https://deepmindsafetyresearch.medium.com/", + authors=["DeepMind Safety Research"], + ), + GwernBlog( + name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] + ), ColdTakes( name="cold_takes", url="https://www.cold-takes.com/", - authors=['Holden Karnofsky'], + authors=["Holden Karnofsky"], ), GenerativeInk( name="generative.ink", url="https://generative.ink/posts/", - authors=['janus'], + authors=["janus"], ), CaradoMoe( name="carado.moe", - url='https://carado.moe', - authors=['Tamsin Leake'], + url="https://carado.moe", + authors=["Tamsin Leake"], ), SubstackBlog( name="importai", url="https://importai.substack.com", - id_fields=['url', 'title', 'source'] + id_fields=["url", "title", "source"], ), SubstackBlog( name="ml_safety_newsletter", url="https://newsletter.mlsafety.org", - id_fields=['url', 'title', 'source'] + id_fields=["url", "title", "source"], + ), + EleutherAI(name="eleuther.ai", url="https://blog.eleuther.ai/"), + OpenAIResearch(name="openai.research", url="https://openai.com/research"), + DeepMindTechnicalBlog( + name="deepmind_technical_blog", + url="https://www.deepmind.com/blog-categories/technical-blogs", ), - EleutherAI(name='eleuther.ai', url='https://blog.eleuther.ai/'), - OpenAIResearch(name='openai.research', url='https://openai.com/research'), - DeepMindTechnicalBlog(name='deepmind_technical_blog', url='https://www.deepmind.com/blog-categories/technical-blogs'), ] diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py index 65ad95e6..ae4439a9 100644 --- 
a/align_data/sources/blogs/blogs.py +++ b/align_data/sources/blogs/blogs.py @@ -9,82 +9,89 @@ logger = logging.getLogger(__name__) + class ColdTakes(HTMLDataset): - item_selector = 'div.post-feed article' + item_selector = "div.post-feed article" - ignored_selectors = ['center', 'div[style*="display:flex"]', 'footer'] + ignored_selectors = ["center", 'div[style*="display:flex"]', "footer"] def _get_published_date(self, contents): - header = contents.select_one('article header').extract() - date = header.find('time').get('datetime') + header = contents.select_one("article header").extract() + date = header.find("time").get("datetime") return super()._get_published_date(date) class GenerativeInk(HTMLDataset): - item_selector = 'div.post.on-list' + item_selector = "div.post.on-list" def _get_published_date(self, contents): possible_date_elements = [ - elem for info in contents.select('div.post-info') - for elem in info.children + elem for info in contents.select("div.post-info") for elem in info.children ] return self._find_date(possible_date_elements) class CaradoMoe(RSSDataset): def _get_text(self, item): - contents = item['soup'] - meta = contents.find('p', {'class': 'postmeta'}) - return self._extract_markdown(meta.find_next_sibling('div')) + contents = item["soup"] + meta = contents.find("p", {"class": "postmeta"}) + return self._extract_markdown(meta.find_next_sibling("div")) class EleutherAI(HTMLDataset): - - item_selector = 'div.archive-entry' - text_selector = 'div.post-content' + item_selector = "div.archive-entry" + text_selector = "div.post-content" def _get_published_date(self, contents): try: - date = contents.select_one('header .post-meta').text.split('·')[0].strip() + date = contents.select_one("header .post-meta").text.split("·")[0].strip() return super()._get_published_date(date) except (ValueError, ParserError): - return '' + return "" def extract_authors(self, article): - return article.select_one('header .post-meta').text.split('·')[1].strip().split(', ') + return ( + article.select_one("header .post-meta") + .text.split("·")[1] + .strip() + .split(", ") + ) class OpenAIResearch(HTMLDataset): - - item_selector = 'li.group-item' - title_selector = '.container h1' + item_selector = "li.group-item" + title_selector = ".container h1" def _get_published_date(self, contents): - if date := contents.select_one('.container .f-meta-2'): + if date := contents.select_one(".container .f-meta-2"): return super()._get_published_date(date.text) - return '' + return "" def _get_text(self, contents): - if paper_link := contents.select_one('.container .cols-container a.ui-link:-soup-contains("Read paper")'): - return item_metadata(paper_link.get('href')).get('text') + if paper_link := contents.select_one( + '.container .cols-container a.ui-link:-soup-contains("Read paper")' + ): + return item_metadata(paper_link.get("href")).get("text") def extract_authors(self, article): - authors = ( - article.select_one('div:-soup-contains("Authors") + div .f-body-1') or - article.select_one('div:-soup-contains("Acknowledgments") + div .f-body-1') - ) + authors = article.select_one( + 'div:-soup-contains("Authors") + div .f-body-1' + ) or article.select_one('div:-soup-contains("Acknowledgments") + div .f-body-1') if not authors: return [] - return [i.split('(')[0].strip() for i in authors.select_one('p').children if not i.name] + return [ + i.split("(")[0].strip() + for i in authors.select_one("p").children + if not i.name + ] class DeepMindTechnicalBlog(HTMLDataset): - - item_selector = 'div.w-dyn-item 
.c_card_list__item__blog' - title_selector = '.c_banner__blog__card h2' - text_selector = '.c_rich-text__cms' - ignored_selectors = ['.article-gtag-buttons'] + item_selector = "div.w-dyn-item .c_card_list__item__blog" + title_selector = ".c_banner__blog__card h2" + text_selector = ".c_rich-text__cms" + ignored_selectors = [".article-gtag-buttons"] @property def items_list(self): @@ -93,7 +100,9 @@ def items_list(self): with tqdm(desc=f"Loading {self.name} pages") as pbar: while True: logger.info(f"Fetching entries from {self.url}") - response = requests.get(self.url, allow_redirects=True, params={'73df3071_page': page}) + response = requests.get( + self.url, allow_redirects=True, params={"73df3071_page": page} + ) soup = BeautifulSoup(response.content, "html.parser") items = soup.select(self.item_selector) if not items: @@ -103,18 +112,22 @@ def items_list(self): page += 1 # update the tqdm progress bar - pbar.set_postfix_str(f"page {page}", refresh=True) # Set postfix to "page X" + pbar.set_postfix_str( + f"page {page}", refresh=True + ) # Set postfix to "page X" pbar.update() # Here we increment the progress bar by 1 - logger.info('Got %s pages', len(articles)) + logger.info("Got %s pages", len(articles)) return articles def _get_published_date(self, contents): - if date := contents.select_one('.c_banner__blog__card__meta'): + if date := contents.select_one(".c_banner__blog__card__meta"): return super()._get_published_date(date.text) - return '' + return "" def extract_authors(self, article): - if div := article.select_one('.c_cms_content__meta__wrapper div:-soup-contains("Authors") + div'): - return [author.strip() for author in div.text.split(',')] + if div := article.select_one( + '.c_cms_content__meta__wrapper div:-soup-contains("Authors") + div' + ): + return [author.strip() for author in div.text.split(",")] return [] diff --git a/align_data/sources/blogs/gwern_blog.py b/align_data/sources/blogs/gwern_blog.py index 325bd7d3..9328d874 100644 --- a/align_data/sources/blogs/gwern_blog.py +++ b/align_data/sources/blogs/gwern_blog.py @@ -22,48 +22,50 @@ def get_item_key(self, item): @property def items_list(self): return [ - 'https://www.gwern.net/Scaling-hypothesis.page', - 'https://www.gwern.net/Tanks.page', - 'https://www.gwern.net/Clippy.page', - 'https://www.gwern.net/complexity.page', - 'https://www.gwern.net/Tool-AI.page', - 'https://www.gwern.net/Backstop.page', - 'https://www.gwern.net/Hyperbolic-Time-Chamber.page' + "https://www.gwern.net/Scaling-hypothesis.page", + "https://www.gwern.net/Tanks.page", + "https://www.gwern.net/Clippy.page", + "https://www.gwern.net/complexity.page", + "https://www.gwern.net/Tool-AI.page", + "https://www.gwern.net/Backstop.page", + "https://www.gwern.net/Hyperbolic-Time-Chamber.page", ] def process_entry(self, post_href): article = self._get_article(post_href) if article.status_code != 200: - logger.error(f'Could not fetch {post_href}') + logger.error(f"Could not fetch {post_href}") return None # Some pages are returned as markdown, some as HTML, so handle both - if 'text/html' in article.headers.get('Content-Type', ''): + if "text/html" in article.headers.get("Content-Type", ""): return super().process_entry(post_href) return self._process_markdown(post_href, article) def _process_markdown(self, post_href, article): - parts = article.text.split('...') + parts = article.text.split("...") metadata = self._get_metadata(parts[0]) - text = self._extract_markdown('...'.join(parts[1:])) - - return self.make_data_entry({ - "source": self.name, - 
"source_type": self.source_type, - "url": post_href, - "title": metadata.get('title'), - "authors": self.authors, - "date_published": self._get_published_date(metadata), - "text": text, - }) + text = self._extract_markdown("...".join(parts[1:])) + + return self.make_data_entry( + { + "source": self.name, + "source_type": self.source_type, + "url": post_href, + "title": metadata.get("title"), + "authors": self.authors, + "date_published": self._get_published_date(metadata), + "text": text, + } + ) @staticmethod def _get_metadata(header): def extract(item): - parts = item.split(': ') + parts = item.split(": ") if len(parts) > 1: - return (parts[0].strip(), ': '.join(parts[1:])) + return (parts[0].strip(), ": ".join(parts[1:])) return None return dict(filter(None, map(extract, header.splitlines()))) @@ -74,17 +76,17 @@ def _get_article(self, url): @staticmethod def _get_title(contents): - return contents.find('header').find('h1').text + return contents.find("header").find("h1").text def _get_published_date(self, contents): if isinstance(contents, dict): - date_published = contents.get('modified') or contents.get('created') + date_published = contents.get("modified") or contents.get("created") else: date_published = ( - contents.select_one('.page-date-range .page-modified') or - contents.select_one('.page-date-range .page-created') + contents.select_one(".page-date-range .page-modified") + or contents.select_one(".page-date-range .page-created") ).text.strip() return super()._get_published_date(date_published) def _get_text(self, contents): - return self._extract_markdown(contents.select_one('div#markdownBody')) + return self._extract_markdown(contents.select_one("div#markdownBody")) diff --git a/align_data/sources/blogs/medium_blog.py b/align_data/sources/blogs/medium_blog.py index 9d57e2ab..5d80dfee 100644 --- a/align_data/sources/blogs/medium_blog.py +++ b/align_data/sources/blogs/medium_blog.py @@ -5,6 +5,7 @@ logger = logging.getLogger(__name__) + @dataclass class MediumBlog(HTMLDataset): """ @@ -27,8 +28,8 @@ class MediumBlog(HTMLDataset): """ source_type = "medium_blog" - ignored_selectors = ['div:first-child span'] + ignored_selectors = ["div:first-child span"] def _get_published_date(self, contents): - possible_date_elements = contents.select('article div:first-child span') + possible_date_elements = contents.select("article div:first-child span") return self._find_date(possible_date_elements) diff --git a/align_data/sources/blogs/substack_blog.py b/align_data/sources/blogs/substack_blog.py index ec6aa481..526bbb4f 100644 --- a/align_data/sources/blogs/substack_blog.py +++ b/align_data/sources/blogs/substack_blog.py @@ -3,8 +3,8 @@ class SubstackBlog(RSSDataset): source_type = "substack" - date_format = '%a, %d %b %Y %H:%M:%S %Z' + date_format = "%a, %d %b %Y %H:%M:%S %Z" @property def feed_url(self): - return self.url + '/feed' + return self.url + "/feed" diff --git a/align_data/sources/blogs/wp_blog.py b/align_data/sources/blogs/wp_blog.py index 197cc078..c0132301 100644 --- a/align_data/sources/blogs/wp_blog.py +++ b/align_data/sources/blogs/wp_blog.py @@ -11,7 +11,7 @@ @dataclass class WordpressBlog(RSSDataset): - summary_key = 'summary' + summary_key = "summary" @property def feed_url(self): @@ -31,19 +31,21 @@ def items_list(self): logging.info(f"Fetching {paged_url}") feed = feedparser.parse(paged_url) - title = feed.get('feed', {}).get('title') + title = feed.get("feed", {}).get("title") if not title or title == prev_title: break prev_title = feed["feed"]["title"] page_number 
+= 1 - for item in feed['entries']: - self.items[item['link']] = item + for item in feed["entries"]: + self.items[item["link"]] = item # update the tqdm progress bar - pbar.set_postfix_str(f"page {page_number}", refresh=True) # Set postfix to "page X" + pbar.set_postfix_str( + f"page {page_number}", refresh=True + ) # Set postfix to "page X" pbar.update() # Here we increment the progress bar by 1 - logger.info(f'Got {len(self.items)} pages') + logger.info(f"Got {len(self.items)} pages") return list(self.items.keys()) diff --git a/align_data/sources/distill/__init__.py b/align_data/sources/distill/__init__.py index 80e40f24..66684922 100644 --- a/align_data/sources/distill/__init__.py +++ b/align_data/sources/distill/__init__.py @@ -3,7 +3,7 @@ DISTILL_REGISTRY = [ Distill( - name = "distill", - url='https://distill.pub', + name="distill", + url="https://distill.pub", ), ] diff --git a/align_data/sources/distill/distill.py b/align_data/sources/distill/distill.py index 4a9ea388..f54fb554 100644 --- a/align_data/sources/distill/distill.py +++ b/align_data/sources/distill/distill.py @@ -5,29 +5,30 @@ @dataclass class Distill(RSSDataset): - source_type = 'html' - done_key = 'url' - summary_key = 'summary' + source_type = "html" + done_key = "url" + summary_key = "summary" def extract_authors(self, item): - return [a.text for a in item['soup'].select('.authors-affiliations p.author a')] + return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] def _get_text(self, item): - article = item['soup'].find('d-article') or item['soup'].find('dt-article') + article = item["soup"].find("d-article") or item["soup"].find("dt-article") return self._extract_markdown(article) def _extra_values(self, item): - soup = item['soup'] + soup = item["soup"] - doi_elem = soup.find('h3', string='DOI') - doi_elem = doi_elem and doi_elem.find_next_sibling('p') + doi_elem = soup.find("h3", string="DOI") + doi_elem = doi_elem and doi_elem.find_next_sibling("p") return { - 'doi': doi_elem and doi_elem.text, - 'summary': item['summary'], - 'journal_ref': 'distill-pub', - 'bibliography': [ - {'title': el.find('span').text, 'link': el.find('a').get('href')} - for el in soup.select('.references li') if el.find('a') - ] + "doi": doi_elem and doi_elem.text, + "summary": item["summary"], + "journal_ref": "distill-pub", + "bibliography": [ + {"title": el.find("span").text, "link": el.find("a").get("href")} + for el in soup.select(".references li") + if el.find("a") + ], } diff --git a/align_data/sources/ebooks/__init__.py b/align_data/sources/ebooks/__init__.py index 7334c938..0055f5e0 100644 --- a/align_data/sources/ebooks/__init__.py +++ b/align_data/sources/ebooks/__init__.py @@ -2,7 +2,6 @@ EBOOK_REGISTRY = [ AgentModels( - name='agentmodels', - repo='https://github.com/agentmodels/agentmodels.org.git' + name="agentmodels", repo="https://github.com/agentmodels/agentmodels.org.git" ), ] diff --git a/align_data/sources/ebooks/agentmodels.py b/align_data/sources/ebooks/agentmodels.py index ee9593bb..cfd68a79 100644 --- a/align_data/sources/ebooks/agentmodels.py +++ b/align_data/sources/ebooks/agentmodels.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + @dataclass class AgentModels(AlignmentDataset): """ @@ -13,30 +14,39 @@ class AgentModels(AlignmentDataset): John Salvatier, and Daniel Filan as .md from GitHub """ - repo: str = 'https://github.com/agentmodels/agentmodels.org.git' + repo: str = "https://github.com/agentmodels/agentmodels.org.git" done_key = "filename" def setup(self): 
super().setup() - self.base_dir = self.raw_data_path / 'agentmodels.org' - if not self.base_dir.exists() or not list(self.base_dir.glob('*')): + self.base_dir = self.raw_data_path / "agentmodels.org" + if not self.base_dir.exists() or not list(self.base_dir.glob("*")): logger.info("Cloning repo") Repo.clone_from(self.repo, self.base_dir) self.repository = Repo(self.base_dir) - self.files_path = self.base_dir / 'chapters' + self.files_path = self.base_dir / "chapters" def _get_published_date(self, filename): - last_commit = next(self.repository.iter_commits(paths=f'chapters/{filename.name}')) + last_commit = next( + self.repository.iter_commits(paths=f"chapters/{filename.name}") + ) return last_commit.committed_datetime.astimezone(timezone.utc) def process_entry(self, filename): - return self.make_data_entry({ - 'source': self.name, - 'source_type': 'markdown', - 'authors': ['Owain Evans', 'Andreas Stuhlmüller', 'John Salvatier', 'Daniel Filan'], - 'date_published': self._get_published_date(filename), - 'title': 'Modeling Agents with Probabilistic Programs', - 'url': f'https://agentmodels.org/chapters/{filename.stem}.html', - 'filename': filename.name, - 'text': filename.read_text(encoding='utf-8'), - }) + return self.make_data_entry( + { + "source": self.name, + "source_type": "markdown", + "authors": [ + "Owain Evans", + "Andreas Stuhlmüller", + "John Salvatier", + "Daniel Filan", + ], + "date_published": self._get_published_date(filename), + "title": "Modeling Agents with Probabilistic Programs", + "url": f"https://agentmodels.org/chapters/{filename.stem}.html", + "filename": filename.name, + "text": filename.read_text(encoding="utf-8"), + } + ) diff --git a/align_data/sources/greaterwrong/__init__.py b/align_data/sources/greaterwrong/__init__.py index 8f4a079c..300fd00b 100644 --- a/align_data/sources/greaterwrong/__init__.py +++ b/align_data/sources/greaterwrong/__init__.py @@ -3,23 +3,23 @@ GREATERWRONG_REGISTRY = [ GreaterWrong( name="lesswrong", - base_url='https://www.lesswrong.com', + base_url="https://www.lesswrong.com", start_year=2005, min_karma=1, af=False, ), GreaterWrong( name="alignmentforum", - base_url='https://www.alignmentforum.org', + base_url="https://www.alignmentforum.org", start_year=2009, min_karma=1, af=True, ), GreaterWrong( name="eaforum", - base_url='https://forum.effectivealtruism.org', + base_url="https://forum.effectivealtruism.org", start_year=2011, min_karma=1, af=False, - ) + ), ] diff --git a/align_data/sources/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py index f00e9d06..f746e552 100644 --- a/align_data/sources/greaterwrong/greaterwrong.py +++ b/align_data/sources/greaterwrong/greaterwrong.py @@ -16,33 +16,37 @@ def fetch_LW_tags(url): res = requests.get( - url + '/tag/ai', - headers={'User-Agent': 'Mozilla /5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0'}, + url + "/tag/ai", + headers={ + "User-Agent": "Mozilla /5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0" + }, ) soup = BeautifulSoup(res.content, "html.parser") - tags = soup.select('div.TagPage-description .table a') - return {a.text.strip() for a in tags if '/tag/' in a.get('href')} + tags = soup.select("div.TagPage-description .table a") + return {a.text.strip() for a in tags if "/tag/" in a.get("href")} def fetch_ea_forum_topics(url): - res = requests.get(url + '/topics/ai-safety') + res = requests.get(url + "/topics/ai-safety") soup = BeautifulSoup(res.content, "html.parser") - links = 
soup.select('div.SidebarSubtagsBox-root a') - return {a.text.strip() for a in links if '/topics/' in a.get('href', '')} + links = soup.select("div.SidebarSubtagsBox-root a") + return {a.text.strip() for a in links if "/topics/" in a.get("href", "")} def get_allowed_tags(url, name): - if name == 'alignmentforum': + if name == "alignmentforum": return set() try: - if name == 'lesswrong': + if name == "lesswrong": return fetch_LW_tags(url) - if name == 'eaforum': + if name == "eaforum": return fetch_ea_forum_topics(url) except Exception: - raise ValueError('Could not fetch tags! Please retry') + raise ValueError("Could not fetch tags! Please retry") - raise ValueError(f'Could not fetch tags for unknown datasource: "{name}". Must be one of alignmentforum|lesswrong|eaforum') + raise ValueError( + f'Could not fetch tags for unknown datasource: "{name}". Must be one of alignmentforum|lesswrong|eaforum' + ) @dataclass @@ -61,8 +65,8 @@ class GreaterWrong(AlignmentDataset): """Whether alignment forum posts should be returned""" limit = 50 - COOLDOWN_TIME : float = 0.5 - summary_key: str = 'summary' + COOLDOWN_TIME: float = 0.5 + summary_key: str = "summary" done_key = "url" lazy_eval = True @@ -73,26 +77,30 @@ def setup(self): self.ai_tags = get_allowed_tags(self.base_url, self.name) def tags_ok(self, post): - return not self.ai_tags or {t['name'] for t in post['tags'] if t.get('name')} & self.ai_tags + return ( + not self.ai_tags + or {t["name"] for t in post["tags"] if t.get("name")} & self.ai_tags + ) def get_item_key(self, item): - return item['pageUrl'] + return item["pageUrl"] def _get_published_date(self, item): - return super()._get_published_date(item.get('postedAt')) + return super()._get_published_date(item.get("postedAt")) def make_query(self, after: str): - return """{ + return ( + """{ posts(input: { terms: { excludeEvents: true view: "old" - """ \ - f" af: {self.af}\n" \ - f" limit: {self.limit}\n" \ - f" karmaThreshold: {self.min_karma}\n" \ - f' after: "{after}"\n' \ - """ filter: "tagged" + """ + f" af: {self.af}\n" + f" limit: {self.limit}\n" + f" karmaThreshold: {self.min_karma}\n" + f' after: "{after}"\n' + """ filter: "tagged" } }) { totalCount @@ -123,60 +131,65 @@ def make_query(self, after: str): } } }""" + ) def fetch_posts(self, query: str): res = requests.post( - f'{self.base_url}/graphql', + f"{self.base_url}/graphql", # The GraphQL endpoint returns a 403 if the user agent isn't set... 
Makes sense, but is annoying - headers={'User-Agent': 'Mozilla /5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0'}, - json={'query': query} + headers={ + "User-Agent": "Mozilla /5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) Gecko/20100101 Firefox/113.0" + }, + json={"query": query}, ) - return res.json()['data']['posts'] + return res.json()["data"]["posts"] @property def last_date_published(self): try: prev_item = next(self.read_entries(sort_by=Article.date_published.desc())) if prev_item and prev_item.date_published: - return prev_item.date_published.isoformat() + 'Z' + return prev_item.date_published.isoformat() + "Z" except StopIteration: pass - return datetime(self.start_year, 1, 1).isoformat() + 'Z' + return datetime(self.start_year, 1, 1).isoformat() + "Z" @property def items_list(self): next_date = self.last_date_published - logger.info('Starting from %s', next_date) + logger.info("Starting from %s", next_date) while next_date: posts = self.fetch_posts(self.make_query(next_date)) - if not posts['results']: + if not posts["results"]: return - for post in posts['results']: - if post['htmlBody'] and self.tags_ok(post): + for post in posts["results"]: + if post["htmlBody"] and self.tags_ok(post): yield post - next_date = posts['results'][-1]['postedAt'] + next_date = posts["results"][-1]["postedAt"] time.sleep(self.COOLDOWN) def process_entry(self, item): - authors = item['coauthors'] - if item['user']: - authors = [item['user']] + authors - authors = [a['displayName'] for a in authors] - return self.make_data_entry({ - 'title': item['title'], - 'text': markdownify(item['htmlBody']).strip(), - 'url': item['pageUrl'], - 'date_published': self._get_published_date(item), - 'modified_at': item['modifiedAt'], - "source": self.name, - "source_type": "GreaterWrong", - 'votes': item['voteCount'], - 'karma': item['baseScore'], - 'tags': [t['name'] for t in item['tags']], - 'words': item['wordCount'], - 'comment_count': item['commentCount'], - # Some posts don't have authors, for some reaason - 'authors': authors, - }) + authors = item["coauthors"] + if item["user"]: + authors = [item["user"]] + authors + authors = [a["displayName"] for a in authors] + return self.make_data_entry( + { + "title": item["title"], + "text": markdownify(item["htmlBody"]).strip(), + "url": item["pageUrl"], + "date_published": self._get_published_date(item), + "modified_at": item["modifiedAt"], + "source": self.name, + "source_type": "GreaterWrong", + "votes": item["voteCount"], + "karma": item["baseScore"], + "tags": [t["name"] for t in item["tags"]], + "words": item["wordCount"], + "comment_count": item["commentCount"], + # Some posts don't have authors, for some reason + "authors": authors, + } + ) diff --git a/align_data/sources/stampy/__init__.py b/align_data/sources/stampy/__init__.py index a2a31645..cb0ce2d8 100644 --- a/align_data/sources/stampy/__init__.py +++ b/align_data/sources/stampy/__init__.py @@ -1,5 +1,5 @@ from .stampy import Stampy STAMPY_REGISTRY = [ - Stampy(name='aisafety.info', id_fields=['url']), + Stampy(name="aisafety.info", id_fields=["url"]), ] diff --git a/align_data/sources/stampy/stampy.py b/align_data/sources/stampy/stampy.py index 025e4658..88a7149e 100644 --- a/align_data/sources/stampy/stampy.py +++ b/align_data/sources/stampy/stampy.py @@ -16,12 +16,13 @@ @dataclass class Stampy(AlignmentDataset): - done_key = "title" def setup(self): if not CODA_TOKEN: - print(f'No CODA_TOKEN found!
Please provide a valid Read token for the {CODA_DOC_ID} table') + print( + f"No CODA_TOKEN found! Please provide a valid Read token for the {CODA_DOC_ID} table" + ) sys.exit(1) super().setup() @@ -30,34 +31,40 @@ def setup(self): def items_list(self): coda = Coda(CODA_TOKEN) doc = Document(CODA_DOC_ID, coda=coda) - logger.info('Fetching table: %s', CODA_DOC_ID) + logger.info("Fetching table: %s", CODA_DOC_ID) table = doc.get_table(ON_SITE_TABLE) - return table.to_dict() # a list of dicts + return table.to_dict() # a list of dicts def get_item_key(self, entry): - return html.unescape(entry['Question']) + return html.unescape(entry["Question"]) def _get_published_date(self, entry): - date_published = entry['Doc Last Edited'] + date_published = entry["Doc Last Edited"] return super()._get_published_date(date_published) def process_entry(self, entry): def clean_text(text): text = html.unescape(text) - return re.sub(r'\(/\?state=(\w+)\)', r'(http://aisafety.info?state=\1)', text) + return re.sub( + r"\(/\?state=(\w+)\)", r"(http://aisafety.info?state=\1)", text + ) - question = clean_text(entry['Question']) # raise an error if the entry has no question - answer = clean_text(entry['Rich Text']) - url = 'https://aisafety.info?state=' + entry['UI ID'] + question = clean_text( + entry["Question"] + ) # raise an error if the entry has no question + answer = clean_text(entry["Rich Text"]) + url = "https://aisafety.info?state=" + entry["UI ID"] logger.info(f"Processing {question}") - return self.make_data_entry({ - "source": self.name, - "source_type": "markdown", - "url": url, - "title": question, - "authors": ['Stampy aisafety.info'], - "date_published": self._get_published_date(entry), - "text": answer, - }) + return self.make_data_entry( + { + "source": self.name, + "source_type": "markdown", + "url": url, + "title": question, + "authors": ["Stampy aisafety.info"], + "date_published": self._get_published_date(entry), + "text": answer, + } + ) diff --git a/align_data/sources/youtube/__init__.py b/align_data/sources/youtube/__init__.py index fd393cc5..06c8defe 100644 --- a/align_data/sources/youtube/__init__.py +++ b/align_data/sources/youtube/__init__.py @@ -1,39 +1,42 @@ -from align_data.sources.youtube.youtube import YouTubeChannelDataset, YouTubePlaylistDataset +from align_data.sources.youtube.youtube import ( + YouTubeChannelDataset, + YouTubePlaylistDataset, +) YOUTUBE_REGISTRY = [ YouTubeChannelDataset( - name='rob_miles_ai_safety', - channel_id='UCLB7AzTwc6VFZrBsO2ucBMg', - authors=['Rob Miles'], + name="rob_miles_ai_safety", + channel_id="UCLB7AzTwc6VFZrBsO2ucBMg", + authors=["Rob Miles"], ), YouTubeChannelDataset( - name='ai_safety_talks', - channel_id='UCXowyqjXvFS-tMKF1GwhpkA', - authors=['Evan Hubinger'], + name="ai_safety_talks", + channel_id="UCXowyqjXvFS-tMKF1GwhpkA", + authors=["Evan Hubinger"], ), YouTubeChannelDataset( - name='ai_safety_reading_group', - channel_id='UC-C23F-9rK2gtRiJZMWsTzQ', + name="ai_safety_reading_group", + channel_id="UC-C23F-9rK2gtRiJZMWsTzQ", authors=[], ), YouTubeChannelDataset( - name='ai_tech_tu_delft', - channel_id='UCPK-Ell2WYxyfP5UYzRzjAA', + name="ai_tech_tu_delft", + channel_id="UCPK-Ell2WYxyfP5UYzRzjAA", authors=[], ), YouTubeChannelDataset( - name='ai_explained', - channel_id='UCNJ1Ymd5yFuUPtn21xtRbbw', + name="ai_explained", + channel_id="UCNJ1Ymd5yFuUPtn21xtRbbw", authors=[], ), YouTubePlaylistDataset( - name='ai_alignment_playlist', + name="ai_alignment_playlist", playlist_ids=[ - 'PLqYmG7hTraZCRwoyGxvQkqVrZgDQi4m-5', - 
'PLqYmG7hTraZBiUr6_Qf8YTS2Oqy3OGZEj', - 'PLAPVC5uNprwY0q4_nyeeHqIT07wZqwjGO', - 'PLCRVRLd2RhZTpdUdEzJjo3qhmX3y3skWA', - 'PLTYHZYmxohXpn5uf8JZ2OouB1PsDJAk-x', - ] + "PLqYmG7hTraZCRwoyGxvQkqVrZgDQi4m-5", + "PLqYmG7hTraZBiUr6_Qf8YTS2Oqy3OGZEj", + "PLAPVC5uNprwY0q4_nyeeHqIT07wZqwjGO", + "PLCRVRLd2RhZTpdUdEzJjo3qhmX3y3skWA", + "PLTYHZYmxohXpn5uf8JZ2OouB1PsDJAk-x", + ], ), ] diff --git a/align_data/sources/youtube/youtube.py b/align_data/sources/youtube/youtube.py index e5912dc2..8670b691 100644 --- a/align_data/sources/youtube/youtube.py +++ b/align_data/sources/youtube/youtube.py @@ -5,7 +5,11 @@ from googleapiclient.discovery import build from youtube_transcript_api import YouTubeTranscriptApi -from youtube_transcript_api._errors import NoTranscriptFound, VideoUnavailable, TranscriptsDisabled +from youtube_transcript_api._errors import ( + NoTranscriptFound, + VideoUnavailable, + TranscriptsDisabled, +) from align_data.settings import YOUTUBE_API_KEY from align_data.common.alignment_dataset import AlignmentDataset @@ -15,8 +19,7 @@ class YouTubeDataset(AlignmentDataset): - - done_key = 'url' + done_key = "url" batch_size = 1 # COOLDOWN = 2 authors = None @@ -25,34 +28,34 @@ class YouTubeDataset(AlignmentDataset): def setup(self): super().setup() if not YOUTUBE_API_KEY: - raise ValueError('No YOUTUBE_API_KEY provided!') - self.youtube = build('youtube', 'v3', developerKey=YOUTUBE_API_KEY) + raise ValueError("No YOUTUBE_API_KEY provided!") + self.youtube = build("youtube", "v3", developerKey=YOUTUBE_API_KEY) def next_page(self, collection_id, next_page_token): - return {'items': []} + return {"items": []} @staticmethod def _get_id(item): - if item.get('kind') == 'youtube#searchResult': - resource = item['id'] - elif item.get('kind') == 'youtube#playlistItem': - resource = item['snippet']['resourceId'] + if item.get("kind") == "youtube#searchResult": + resource = item["id"] + elif item.get("kind") == "youtube#playlistItem": + resource = item["snippet"]["resourceId"] else: return None - if resource['kind'] == 'youtube#video': - return resource['videoId'] + if resource["kind"] == "youtube#video": + return resource["videoId"] def fetch_videos(self, collection_id): next_page_token = None while True: videos_response = self.next_page(collection_id, next_page_token) - for item in videos_response.get('items'): + for item in videos_response.get("items"): if self._get_id(item): yield item - next_page_token = videos_response.get('nextPageToken') + next_page_token = videos_response.get("nextPageToken") if not next_page_token: return @@ -66,23 +69,29 @@ def items_list(self): def get_item_key(self, item): video_id = self._get_id(item) - return f'https://www.youtube.com/watch?v={video_id}' + return f"https://www.youtube.com/watch?v={video_id}" def _get_contents(self, video): video_id = self._get_id(video) try: - transcript = YouTubeTranscriptApi.list_transcripts(video_id).find_transcript(['en', 'en-GB']).fetch() - return '\n'.join([i['text'] for i in transcript]) + transcript = ( + YouTubeTranscriptApi.list_transcripts(video_id) + .find_transcript(["en", "en-GB"]) + .fetch() + ) + return "\n".join([i["text"] for i in transcript]) except (NoTranscriptFound, VideoUnavailable): return None except TranscriptsDisabled: - logger.error(f'Transcripts disabled for https://www.youtube.com/watch?v={video_id} - skipping') + logger.error( + f"Transcripts disabled for https://www.youtube.com/watch?v={video_id} - skipping" + ) return None def extract_authors(self, video): if self.authors: return self.authors - return 
[video['snippet']['channelTitle'].strip()] + return [video["snippet"]["channelTitle"].strip()] def process_entry(self, video): video_url = self.get_item_key(video) @@ -91,20 +100,21 @@ def process_entry(self, video): if not contents: return None - return self.make_data_entry({ - "text": contents, - "url": video_url, - "title": video['snippet']['title'], - "source": self.name, - "source_type": "youtube", - "date_published": self._get_published_date(video), - "authors": self.extract_authors(video), - }) + return self.make_data_entry( + { + "text": contents, + "url": video_url, + "title": video["snippet"]["title"], + "source": self.name, + "source_type": "youtube", + "date_published": self._get_published_date(video), + "authors": self.extract_authors(video), + } + ) @dataclass class YouTubeChannelDataset(YouTubeDataset): - channel_id: str authors: List[str] @@ -113,20 +123,23 @@ def collection_ids(self): return [self.channel_id] def next_page(self, collection_id, next_page_token): - return self.youtube.search().list( - part='snippet', - channelId=collection_id, - maxResults=50, - pageToken=next_page_token - ).execute() + return ( + self.youtube.search() + .list( + part="snippet", + channelId=collection_id, + maxResults=50, + pageToken=next_page_token, + ) + .execute() + ) def _get_published_date(self, video): - return super()._get_published_date(video['snippet']['publishTime']) + return super()._get_published_date(video["snippet"]["publishTime"]) @dataclass class YouTubePlaylistDataset(YouTubeDataset): - playlist_ids: str @property @@ -134,12 +147,16 @@ def collection_ids(self): return self.playlist_ids def next_page(self, collection_id, next_page_token): - return self.youtube.playlistItems().list( - part='snippet', - playlistId=collection_id, - maxResults=50, - pageToken=next_page_token, - ).execute() + return ( + self.youtube.playlistItems() + .list( + part="snippet", + playlistId=collection_id, + maxResults=50, + pageToken=next_page_token, + ) + .execute() + ) def _get_published_date(self, video): - return super()._get_published_date(video['snippet']['publishedAt']) + return super()._get_published_date(video["snippet"]["publishedAt"]) diff --git a/main.py b/main.py index 4ea6bf3d..ae11b641 100644 --- a/main.py +++ b/main.py @@ -10,7 +10,9 @@ from align_data.sources.articles.articles import update_new_items, check_new_articles from align_data.pinecone.update_pinecone import PineconeUpdater from align_data.settings import ( - METADATA_OUTPUT_SPREADSHEET, METADATA_SOURCE_SHEET, METADATA_SOURCE_SPREADSHEET + METADATA_OUTPUT_SPREADSHEET, + METADATA_SOURCE_SHEET, + METADATA_SOURCE_SPREADSHEET, ) @@ -19,7 +21,6 @@ @dataclass class AlignmentDataset: - out_path: str = "data" """The path to the directory where the data will be downloaded, defaults to data""" @@ -34,13 +35,13 @@ def fetch(self, *names) -> None: :param str name: The name of the dataset to fetch :return: The path to the file that was written to. 
""" - if names == ('all',): + if names == ("all",): names = ALL_DATASETS missing = {name for name in names if name not in ALL_DATASETS} assert not missing, f"{missing} are not valid dataset names" for name in names: dataset = get_dataset(name) - + dataset.add_entries(dataset.fetch_entries()) def fetch_all(self, *skip) -> None: @@ -62,7 +63,7 @@ def generate_jsonl_files(self, *names): :param List[str] names: The names of the datasets to generate """ - if names == ('all',): + if names == ("all",): names = ALL_DATASETS missing = {name for name in names if name not in ALL_DATASETS} assert not missing, f"{missing} are not valid dataset names" @@ -75,12 +76,16 @@ def count_tokens(self, merged_dataset_path: str) -> None: This function counts the number of tokens, words, and characters in the dataset :return: None """ - assert os.path.exists(merged_dataset_path), "The path to the merged dataset does not exist" + assert os.path.exists( + merged_dataset_path + ), "The path to the merged dataset does not exist" count_token(merged_dataset_path) def update_metadata( - self, source_spreadsheet=METADATA_SOURCE_SPREADSHEET, - source_sheet=METADATA_SOURCE_SHEET, output_spreadsheet=METADATA_OUTPUT_SPREADSHEET + self, + source_spreadsheet=METADATA_SOURCE_SPREADSHEET, + source_sheet=METADATA_SOURCE_SHEET, + output_spreadsheet=METADATA_OUTPUT_SPREADSHEET, ): """Go through all unprocessed items from the source worksheet, updating the appropriate metadata in the output one. @@ -90,7 +95,11 @@ def update_metadata( """ return update_new_items(source_spreadsheet, source_sheet, output_spreadsheet) - def fetch_new_articles(self, source_spreadsheet=METADATA_SOURCE_SPREADSHEET, source_sheet=METADATA_SOURCE_SHEET): + def fetch_new_articles( + self, + source_spreadsheet=METADATA_SOURCE_SPREADSHEET, + source_sheet=METADATA_SOURCE_SHEET, + ): """Look for unseen articles in the special indices, adding any that are found to the provided spreadsheet. :param str source_spreadsheet: The id of the google docs spreadsheet containing the items to be processed @@ -102,12 +111,12 @@ def pinecone_update(self, *names) -> None: """ This function updates the Pinecone vector DB. """ - if names == ('all',): + if names == ("all",): names = ALL_DATASETS missing = {name for name in names if name not in ALL_DATASETS} assert not missing, f"{missing} are not valid dataset names" PineconeUpdater().update(names) - + def pinecone_update_all(self, *skip) -> None: """ This function updates the Pinecone vector DB. 
@@ -117,4 +126,4 @@ def pinecone_update_all(self, *skip) -> None: if __name__ == "__main__": - fire.Fire(AlignmentDataset) \ No newline at end of file + fire.Fire(AlignmentDataset) diff --git a/migrations/env.py b/migrations/env.py index 838bfb97..1d07d5d2 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -15,13 +15,15 @@ fileConfig(config.config_file_name) from align_data.settings import DB_CONNECTION_URI -config.set_main_option('sqlalchemy.url', DB_CONNECTION_URI) + +config.set_main_option("sqlalchemy.url", DB_CONNECTION_URI) # add your model's MetaData object here # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata from align_data.db.models import Base + target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, @@ -68,9 +70,7 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) + context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() diff --git a/migrations/versions/0a0041c28458_confidence_column.py b/migrations/versions/0a0041c28458_confidence_column.py index 1c39b30f..53c6eeb9 100644 --- a/migrations/versions/0a0041c28458_confidence_column.py +++ b/migrations/versions/0a0041c28458_confidence_column.py @@ -10,15 +10,15 @@ # revision identifiers, used by Alembic. -revision = '0a0041c28458' -down_revision = '983b5bdef5f6' +revision = "0a0041c28458" +down_revision = "983b5bdef5f6" branch_labels = None depends_on = None def upgrade() -> None: - op.add_column('articles', sa.Column('confidence', sa.Float(), nullable=True)) + op.add_column("articles", sa.Column("confidence", sa.Float(), nullable=True)) def downgrade() -> None: - op.drop_column('articles', 'confidence') + op.drop_column("articles", "confidence") diff --git a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py index 74792469..7a8485fe 100644 --- a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py +++ b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py @@ -10,19 +10,21 @@ # revision identifiers, used by Alembic. -revision = '59ac3cb671e3' -down_revision = '0a0041c28458' +revision = "59ac3cb671e3" +down_revision = "0a0041c28458" branch_labels = None depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.add_column('articles', sa.Column('pinecone_update_required', sa.Boolean(), nullable=False)) + op.add_column( + "articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False) + ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_column('articles', 'pinecone_update_required') + op.drop_column("articles", "pinecone_update_required") # ### end Alembic commands ### diff --git a/migrations/versions/983b5bdef5f6_initial_structure.py b/migrations/versions/983b5bdef5f6_initial_structure.py index ff1ef321..947a37c1 100644 --- a/migrations/versions/983b5bdef5f6_initial_structure.py +++ b/migrations/versions/983b5bdef5f6_initial_structure.py @@ -10,7 +10,7 @@ from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. 
-revision = '983b5bdef5f6' +revision = "983b5bdef5f6" down_revision = None branch_labels = None depends_on = None @@ -18,33 +18,36 @@ def upgrade() -> None: op.create_table( - 'articles', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('hash_id', sa.String(length=32), nullable=False), - sa.Column('title', sa.String(length=1028), nullable=True), - sa.Column('url', sa.String(length=1028), nullable=True), - sa.Column('source', sa.String(length=128), nullable=True), - sa.Column('source_type', sa.String(length=128), nullable=True), - sa.Column('authors', sa.String(length=1024), nullable=False), - sa.Column('text', mysql.LONGTEXT(), nullable=True), - sa.Column('date_published', sa.DateTime(), nullable=True), - sa.Column('metadata', sa.JSON(), nullable=True), - sa.Column('date_created', sa.DateTime(), nullable=False), - sa.Column('date_updated', sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('hash_id') + "articles", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("hash_id", sa.String(length=32), nullable=False), + sa.Column("title", sa.String(length=1028), nullable=True), + sa.Column("url", sa.String(length=1028), nullable=True), + sa.Column("source", sa.String(length=128), nullable=True), + sa.Column("source_type", sa.String(length=128), nullable=True), + sa.Column("authors", sa.String(length=1024), nullable=False), + sa.Column("text", mysql.LONGTEXT(), nullable=True), + sa.Column("date_published", sa.DateTime(), nullable=True), + sa.Column("metadata", sa.JSON(), nullable=True), + sa.Column("date_created", sa.DateTime(), nullable=False), + sa.Column("date_updated", sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("hash_id"), ) op.create_table( - 'summaries', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('text', sa.Text(), nullable=False), - sa.Column('source', sa.String(length=256), nullable=True), - sa.Column('article_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['article_id'], ['articles.id'], ), - sa.PrimaryKeyConstraint('id') + "summaries", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("text", sa.Text(), nullable=False), + sa.Column("source", sa.String(length=256), nullable=True), + sa.Column("article_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["article_id"], + ["articles.id"], + ), + sa.PrimaryKeyConstraint("id"), ) def downgrade() -> None: - op.drop_table('summaries') - op.drop_table('articles') + op.drop_table("summaries") + op.drop_table("articles") diff --git a/setup.py b/setup.py index a1106c58..871c1d32 100644 --- a/setup.py +++ b/setup.py @@ -4,13 +4,13 @@ long_description = fh.read() setuptools.setup( - name='align_data', - version='0.0.1', + name="align_data", + version="0.0.1", description="A framework for constructing a dataset for alignment research", long_description=long_description, long_description_content_type="text/markdown", packages=setuptools.find_packages(), - python_requires='>=3.6', + python_requires=">=3.6", install_requires=[ "bs4==0.0.1", "python-dateutil==2.8.2", @@ -21,5 +21,5 @@ "GitPython", "gdown", "pypandoc", - ] + ], ) diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py index 48340000..36c47071 100644 --- a/tests/align_data/articles/test_datasets.py +++ b/tests/align_data/articles/test_datasets.py @@ -3,245 +3,288 @@ import pandas as pd import pytest from align_data.sources.articles.datasets import ( - EbookArticles, DocArticles, 
HTMLArticles, MarkdownArticles, PDFArticles, SpreadsheetDataset, XMLArticles + EbookArticles, + DocArticles, + HTMLArticles, + MarkdownArticles, + PDFArticles, + SpreadsheetDataset, + XMLArticles, ) @pytest.fixture def articles(): - source_type = 'something' + source_type = "something" articles = [ { - 'source_url': f'http://example.com/source_url/{i}', - 'url': f'http://example.com/item/{i}', - 'title': f'article no {i}', - 'source_type': source_type, - 'date_published': f'2023/01/0{i + 1} 12:32:11', - 'authors': f'John Snow, mr Blobby', - 'summary': f'the summary of article {i}', - 'file_id': str(i), - } for i in range(5) + "source_url": f"http://example.com/source_url/{i}", + "url": f"http://example.com/item/{i}", + "title": f"article no {i}", + "source_type": source_type, + "date_published": f"2023/01/0{i + 1} 12:32:11", + "authors": f"John Snow, mr Blobby", + "summary": f"the summary of article {i}", + "file_id": str(i), + } + for i in range(5) ] return pd.DataFrame(articles) def test_spreadsheet_dataset_items_list(articles): - dataset = SpreadsheetDataset(name='bla', spreadsheet_id='123', sheet_id='456') + dataset = SpreadsheetDataset(name="bla", spreadsheet_id="123", sheet_id="456") df = pd.concat( - [articles, pd.DataFrame([{'title': None}, {'summary': 'bla'}])], - ignore_index=True + [articles, pd.DataFrame([{"title": None}, {"summary": "bla"}])], + ignore_index=True, ) - with patch('pandas.read_csv', return_value=df): + with patch("pandas.read_csv", return_value=df): assert list(dataset.items_list) == list(pd.DataFrame(articles).itertuples()) def test_spreadsheet_dataset_get_item_key(): - dataset = SpreadsheetDataset(name='bla', spreadsheet_id='123', sheet_id='456') - assert dataset.get_item_key(Mock(bla='ble', title='the key')) == 'the key' - - -@pytest.mark.parametrize('authors, expected', ( - ('', []), - (' \n \n \t', []), - ('John Snow', ['John Snow']), - ('John Snow, mr. Blobby', ['John Snow', 'mr. Blobby']), -)) + dataset = SpreadsheetDataset(name="bla", spreadsheet_id="123", sheet_id="456") + assert dataset.get_item_key(Mock(bla="ble", title="the key")) == "the key" + + +@pytest.mark.parametrize( + "authors, expected", + ( + ("", []), + (" \n \n \t", []), + ("John Snow", ["John Snow"]), + ("John Snow, mr. Blobby", ["John Snow", "mr. 
Blobby"]), + ), +) def test_spreadsheet_dataset_extract_authors(authors, expected): - dataset = SpreadsheetDataset(name='bla', spreadsheet_id='123', sheet_id='456') + dataset = SpreadsheetDataset(name="bla", spreadsheet_id="123", sheet_id="456") assert dataset.extract_authors(Mock(authors=authors)) == expected def test_pdf_articles_get_text(): - dataset = PDFArticles(name='bla', spreadsheet_id='123', sheet_id='456') - item = Mock(file_id='23423', title='bla bla bla') + dataset = PDFArticles(name="bla", spreadsheet_id="123", sheet_id="456") + item = Mock(file_id="23423", title="bla bla bla") def check_downloads(output, id): - assert output == str(dataset.files_path / 'bla bla bla.pdf') - assert id == '23423' + assert output == str(dataset.files_path / "bla bla bla.pdf") + assert id == "23423" return output def read_pdf(filename): - assert filename == dataset.files_path / 'bla bla bla.pdf' - return 'pdf contents' + assert filename == dataset.files_path / "bla bla bla.pdf" + return "pdf contents" - with patch('align_data.sources.articles.datasets.download', check_downloads): - with patch('align_data.sources.articles.datasets.read_pdf', read_pdf): - assert dataset._get_text(item) == 'pdf contents' + with patch("align_data.sources.articles.datasets.download", check_downloads): + with patch("align_data.sources.articles.datasets.read_pdf", read_pdf): + assert dataset._get_text(item) == "pdf contents" def test_pdf_articles_process_item(articles): - dataset = PDFArticles(name='bla', spreadsheet_id='123', sheet_id='456') - with patch('pandas.read_csv', return_value=articles): + dataset = PDFArticles(name="bla", spreadsheet_id="123", sheet_id="456") + with patch("pandas.read_csv", return_value=articles): item = list(dataset.items_list)[0] - with patch('align_data.sources.articles.datasets.download'): - with patch('align_data.sources.articles.datasets.read_pdf', return_value='pdf contents bla'): + with patch("align_data.sources.articles.datasets.download"): + with patch( + "align_data.sources.articles.datasets.read_pdf", + return_value='pdf contents bla', + ): assert dataset.process_entry(item).to_dict() == { - 'authors': ['John Snow', 'mr Blobby'], - 'date_published': '2023-01-01T12:32:11Z', - 'id': None, - 'source': 'bla', - 'source_filetype': 'pdf', - 'source_type': 'something', - 'summaries': ['the summary of article 0'], - 'text': 'pdf contents [bla](asd.com)', - 'title': 'article no 0', - 'url': 'http://example.com/item/0', + "authors": ["John Snow", "mr Blobby"], + "date_published": "2023-01-01T12:32:11Z", + "id": None, + "source": "bla", + "source_filetype": "pdf", + "source_type": "something", + "summaries": ["the summary of article 0"], + "text": "pdf contents [bla](asd.com)", + "title": "article no 0", + "url": "http://example.com/item/0", } def test_html_articles_get_text(): def parser(url): - assert url == 'http://example.org/bla.bla' - return 'html contents' + assert url == "http://example.org/bla.bla" + return "html contents" - with patch('align_data.sources.articles.datasets.HTML_PARSERS', {'example.org': parser}): - assert HTMLArticles._get_text(Mock(source_url='http://example.org/bla.bla')) == 'html contents' + with patch( + "align_data.sources.articles.datasets.HTML_PARSERS", {"example.org": parser} + ): + assert ( + HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) + == "html contents" + ) def test_html_articles_get_text_no_parser(): - with patch('align_data.sources.articles.datasets.HTML_PARSERS', {}): - assert 
HTMLArticles._get_text(Mock(source_url='http://example.org/bla.bla')) is None + with patch("align_data.sources.articles.datasets.HTML_PARSERS", {}): + assert ( + HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) + is None + ) def test_html_articles_process_entry(articles): - dataset = HTMLArticles(name='bla', spreadsheet_id='123', sheet_id='456') - with patch('pandas.read_csv', return_value=articles): + dataset = HTMLArticles(name="bla", spreadsheet_id="123", sheet_id="456") + with patch("pandas.read_csv", return_value=articles): item = list(dataset.items_list)[0] - parsers = {'example.com': lambda _: ' html contents with proper elements ble ble '} - with patch('align_data.sources.articles.datasets.HTML_PARSERS', parsers): + parsers = { + "example.com": lambda _: ' html contents with proper elements ble ble ' + } + with patch("align_data.sources.articles.datasets.HTML_PARSERS", parsers): assert dataset.process_entry(item).to_dict() == { - 'authors': ['John Snow', 'mr Blobby'], - 'date_published': '2023-01-01T12:32:11Z', - 'id': None, - 'source': 'bla', - 'source_filetype': 'html', - 'source_type': 'something', - 'summaries': ['the summary of article 0'], - 'text': 'html contents with [proper elements](bla.com) ble ble', - 'title': 'article no 0', - 'url': 'http://example.com/item/0', + "authors": ["John Snow", "mr Blobby"], + "date_published": "2023-01-01T12:32:11Z", + "id": None, + "source": "bla", + "source_filetype": "html", + "source_type": "something", + "summaries": ["the summary of article 0"], + "text": "html contents with [proper elements](bla.com) ble ble", + "title": "article no 0", + "url": "http://example.com/item/0", } def test_ebook_articles_get_text(): - dataset = EbookArticles(name='bla', spreadsheet_id='123', sheet_id='456') + dataset = EbookArticles(name="bla", spreadsheet_id="123", sheet_id="456") item = Mock( - source_url='https://drive.google.com/file/d/123456/view?usp=drive_link', - title='bla bla bla' + source_url="https://drive.google.com/file/d/123456/view?usp=drive_link", + title="bla bla bla", ) def check_downloads(output, id): - assert output == str(dataset.files_path / 'bla bla bla.epub') - assert id == '123456' + assert output == str(dataset.files_path / "bla bla bla.epub") + assert id == "123456" return output def read_ebook(filename, *args, **kwargs): - return 'ebook contents' + return "ebook contents" - with patch('align_data.sources.articles.datasets.download', check_downloads): - with patch('align_data.sources.articles.datasets.convert_file', read_ebook): - assert dataset._get_text(item) == 'ebook contents' + with patch("align_data.sources.articles.datasets.download", check_downloads): + with patch("align_data.sources.articles.datasets.convert_file", read_ebook): + assert dataset._get_text(item) == "ebook contents" def test_ebook_articles_process_entry(articles): - dataset = EbookArticles(name='bla', spreadsheet_id='123', sheet_id='456') - with patch('pandas.read_csv', return_value=articles): + dataset = EbookArticles(name="bla", spreadsheet_id="123", sheet_id="456") + with patch("pandas.read_csv", return_value=articles): item = list(dataset.items_list)[0] contents = ' html contents with proper elements ble ble ' - with patch('align_data.sources.articles.datasets.download'): - with patch('align_data.sources.articles.datasets.convert_file', return_value=contents): + with patch("align_data.sources.articles.datasets.download"): + with patch( + "align_data.sources.articles.datasets.convert_file", return_value=contents + ): assert 
dataset.process_entry(item).to_dict() == { - 'authors': ['John Snow', 'mr Blobby'], - 'date_published': '2023-01-01T12:32:11Z', - 'id': None, - 'source': 'bla', - 'source_filetype': 'epub', - 'source_type': 'something', - 'summaries': ['the summary of article 0'], - 'text': 'html contents with [proper elements](bla.com) ble ble', - 'title': 'article no 0', - 'url': 'http://example.com/item/0', + "authors": ["John Snow", "mr Blobby"], + "date_published": "2023-01-01T12:32:11Z", + "id": None, + "source": "bla", + "source_filetype": "epub", + "source_type": "something", + "summaries": ["the summary of article 0"], + "text": "html contents with [proper elements](bla.com) ble ble", + "title": "article no 0", + "url": "http://example.com/item/0", } def test_xml_articles_get_text(): - dataset = XMLArticles(name='bla', spreadsheet_id='123', sheet_id='456') - with patch('align_data.sources.articles.datasets.extract_gdrive_contents', return_value={'text': 'bla bla'}): - assert dataset._get_text(Mock(source_url='bla.com')) == 'bla bla' + dataset = XMLArticles(name="bla", spreadsheet_id="123", sheet_id="456") + with patch( + "align_data.sources.articles.datasets.extract_gdrive_contents", + return_value={"text": "bla bla"}, + ): + assert dataset._get_text(Mock(source_url="bla.com")) == "bla bla" def test_xml_articles_process_entry(articles): - dataset = XMLArticles(name='bla', spreadsheet_id='123', sheet_id='456') - with patch('pandas.read_csv', return_value=articles): + dataset = XMLArticles(name="bla", spreadsheet_id="123", sheet_id="456") + with patch("pandas.read_csv", return_value=articles): item = list(dataset.items_list)[0] - with patch('align_data.sources.articles.datasets.extract_gdrive_contents', return_value={'text': 'bla bla'}): + with patch( + "align_data.sources.articles.datasets.extract_gdrive_contents", + return_value={"text": "bla bla"}, + ): assert dataset.process_entry(item).to_dict() == { - 'authors': ['John Snow', 'mr Blobby'], - 'date_published': '2023-01-01T12:32:11Z', - 'id': None, - 'source': 'bla', - 'source_filetype': 'xml', - 'source_type': 'something', - 'summaries': ['the summary of article 0'], - 'text': 'bla bla', - 'title': 'article no 0', - 'url': 'http://example.com/item/0', + "authors": ["John Snow", "mr Blobby"], + "date_published": "2023-01-01T12:32:11Z", + "id": None, + "source": "bla", + "source_filetype": "xml", + "source_type": "something", + "summaries": ["the summary of article 0"], + "text": "bla bla", + "title": "article no 0", + "url": "http://example.com/item/0", } def test_markdown_articles_get_text(): - dataset = MarkdownArticles(name='bla', spreadsheet_id='123', sheet_id='456') - with patch('align_data.sources.articles.datasets.fetch_markdown', return_value={'text': 'bla bla'}): - assert dataset._get_text(Mock(source_url='bla.com/bla/123/bla')) == 'bla bla' + dataset = MarkdownArticles(name="bla", spreadsheet_id="123", sheet_id="456") + with patch( + "align_data.sources.articles.datasets.fetch_markdown", + return_value={"text": "bla bla"}, + ): + assert dataset._get_text(Mock(source_url="bla.com/bla/123/bla")) == "bla bla" def test_markdown_articles_process_entry(articles): - dataset = MarkdownArticles(name='bla', spreadsheet_id='123', sheet_id='456') - with patch('pandas.read_csv', return_value=articles): + dataset = MarkdownArticles(name="bla", spreadsheet_id="123", sheet_id="456") + with patch("pandas.read_csv", return_value=articles): item = list(dataset.items_list)[0] - with patch('align_data.sources.articles.datasets.fetch_markdown', 
return_value={'text': 'bla bla'}): + with patch( + "align_data.sources.articles.datasets.fetch_markdown", + return_value={"text": "bla bla"}, + ): assert dataset.process_entry(item).to_dict() == { - 'authors': ['John Snow', 'mr Blobby'], - 'date_published': '2023-01-01T12:32:11Z', - 'id': None, - 'source': 'bla', - 'source_filetype': 'md', - 'source_type': 'something', - 'summaries': ['the summary of article 0'], - 'text': 'bla bla', - 'title': 'article no 0', - 'url': 'http://example.com/item/0', + "authors": ["John Snow", "mr Blobby"], + "date_published": "2023-01-01T12:32:11Z", + "id": None, + "source": "bla", + "source_filetype": "md", + "source_type": "something", + "summaries": ["the summary of article 0"], + "text": "bla bla", + "title": "article no 0", + "url": "http://example.com/item/0", } def test_doc_articles_get_text(): - dataset = DocArticles(name='bla', spreadsheet_id='123', sheet_id='456') - with patch('align_data.sources.articles.datasets.fetch_file'): - with patch('align_data.sources.articles.datasets.convert_file', return_value='bla bla'): - assert dataset._get_text(Mock(source_url='bla.com/bla/123/bla')) == 'bla bla' + dataset = DocArticles(name="bla", spreadsheet_id="123", sheet_id="456") + with patch("align_data.sources.articles.datasets.fetch_file"): + with patch( + "align_data.sources.articles.datasets.convert_file", return_value="bla bla" + ): + assert ( + dataset._get_text(Mock(source_url="bla.com/bla/123/bla")) == "bla bla" + ) def test_doc_articles_process_entry(articles): - dataset = DocArticles(name='bla', spreadsheet_id='123', sheet_id='456') - with patch('pandas.read_csv', return_value=articles): + dataset = DocArticles(name="bla", spreadsheet_id="123", sheet_id="456") + with patch("pandas.read_csv", return_value=articles): item = list(dataset.items_list)[0] - with patch('align_data.sources.articles.datasets.fetch_file'): - with patch('align_data.sources.articles.datasets.convert_file', return_value='bla bla'): + with patch("align_data.sources.articles.datasets.fetch_file"): + with patch( + "align_data.sources.articles.datasets.convert_file", return_value="bla bla" + ): assert dataset.process_entry(item).to_dict() == { - 'authors': ['John Snow', 'mr Blobby'], - 'date_published': '2023-01-01T12:32:11Z', - 'id': None, - 'source': 'bla', - 'source_filetype': 'docx', - 'source_type': 'something', - 'summaries': ['the summary of article 0'], - 'text': 'bla bla', - 'title': 'article no 0', - 'url': 'http://example.com/item/0', + "authors": ["John Snow", "mr Blobby"], + "date_published": "2023-01-01T12:32:11Z", + "id": None, + "source": "bla", + "source_filetype": "docx", + "source_type": "something", + "summaries": ["the summary of article 0"], + "text": "bla bla", + "title": "article no 0", + "url": "http://example.com/item/0", } diff --git a/tests/align_data/articles/test_parsers.py b/tests/align_data/articles/test_parsers.py index 9abdb43b..9f43c231 100644 --- a/tests/align_data/articles/test_parsers.py +++ b/tests/align_data/articles/test_parsers.py @@ -5,7 +5,11 @@ from bs4 import BeautifulSoup from align_data.sources.articles.parsers import ( - google_doc, medium_blog, parse_grobid, get_content_type, extract_gdrive_contents + google_doc, + medium_blog, + parse_grobid, + get_content_type, + extract_gdrive_contents, ) @@ -47,10 +51,15 @@ """ + def test_google_doc(): def fetcher(url, *args, **kwargs): - assert url == 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html' - return Mock(content=""" + assert ( + url + 
== "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html" + ) + return Mock( + content="""
bla bla bla
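The hunks in this area reformat test_google_doc, which pins down how google_doc() turns a Google Docs /edit link into an HTML export link before fetching it. Below is a hedged sketch of that rewrite, inferred only from the URLs asserted in these tests; the helper name to_export_url is hypothetical and not part of the patch.

import re
from typing import Optional

def to_export_url(doc_url: str) -> Optional[str]:
    # Mirror the URL pair asserted in test_google_doc:
    # .../document/d/<id>/edit  ->  .../document/d/<id>/export?format=html
    match = re.match(r"https://docs.google.com/document/d/([^/?#]+)", doc_url)
    if not match:
        # e.g. https://docs.google.com/bla/bla, which test_google_doc_bad_url expects to yield None
        return None
    return f"https://docs.google.com/document/d/{match.group(1)}/export?format=html"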
@@ -58,23 +67,37 @@ def fetcher(url, *args, **kwargs): - """) + """ + ) - with patch('requests.get', fetcher): - assert google_doc('https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit') == 'ble ble [a link](bla.com)' + with patch("requests.get", fetcher): + assert ( + google_doc( + "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit" + ) + == "ble ble [a link](bla.com)" + ) def test_google_doc_no_body(): def fetcher(url, *args, **kwargs): - assert url == 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html' + assert ( + url + == "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html" + ) return Mock(content="
bla bla bla
") - with patch('requests.get', fetcher): - assert google_doc('https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit') is None + with patch("requests.get", fetcher): + assert ( + google_doc( + "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit" + ) + is None + ) def test_google_doc_bad_url(): - assert google_doc('https://docs.google.com/bla/bla') is None + assert google_doc("https://docs.google.com/bla/bla") is None def test_medium_blog(): @@ -94,8 +117,8 @@ def test_medium_blog(): """ - with patch('requests.get', return_value=Mock(content=html)): - assert medium_blog('bla.com') == "bla bla bla [a link](http://ble.com) bla bla" + with patch("requests.get", return_value=Mock(content=html)): + assert medium_blog("bla.com") == "bla bla bla [a link](http://ble.com) bla bla" def test_medium_blog_no_title(): @@ -109,8 +132,11 @@ def test_medium_blog_no_title(): """ - with patch('requests.get', return_value=Mock(content=html)): - assert medium_blog('bla.com') == "Some random thing\n\n\n bla bla bla [a link](http://ble.com) bla bla" + with patch("requests.get", return_value=Mock(content=html)): + assert ( + medium_blog("bla.com") + == "Some random thing\n\n\n bla bla bla [a link](http://ble.com) bla bla" + ) def test_medium_blog_no_contents(): @@ -124,17 +150,17 @@ def test_medium_blog_no_contents(): """ - with patch('requests.get', return_value=Mock(content=html)): - assert medium_blog('bla.com') is None + with patch("requests.get", return_value=Mock(content=html)): + assert medium_blog("bla.com") is None def test_parse_grobid(): assert parse_grobid(SAMPLE_XML) == { - 'abstract': 'this is the abstract', - 'authors': ['Cullen Oâ\x80\x99Keefe'], - 'text': 'This is the contents', - 'title': 'The title!!', - 'data_source': 'xml', + "abstract": "this is the abstract", + "authors": ["Cullen Oâ\x80\x99Keefe"], + "text": "This is the contents", + "title": "The title!!", + "data_source": "xml", } @@ -156,74 +182,94 @@ def test_parse_grobid_no_body(): """ - assert parse_grobid(xml) == {'error': 'No contents in XML file', 'data_source': 'xml'} - + assert parse_grobid(xml) == { + "error": "No contents in XML file", + "data_source": "xml", + } -@pytest.mark.parametrize('header, expected', ( - (None, set()), - ('', set()), - ('text/html', {'text/html'}), - ('text/html; bla=asdas; fewwe=fe', {'text/html', 'bla=asdas', 'fewwe=fe'}), -)) +@pytest.mark.parametrize( + "header, expected", + ( + (None, set()), + ("", set()), + ("text/html", {"text/html"}), + ("text/html; bla=asdas; fewwe=fe", {"text/html", "bla=asdas", "fewwe=fe"}), + ), +) def test_get_content_type(header, expected): - assert get_content_type(Mock(headers={'Content-Type': header})) == expected - - -@pytest.mark.parametrize('headers', ( - {}, - {'Content-Type': None}, - {'Content-Type': ''}, - {'Content-Type': ' '}, - {'Content-Type': ' ; ;; '}, -)) + assert get_content_type(Mock(headers={"Content-Type": header})) == expected + + +@pytest.mark.parametrize( + "headers", + ( + {}, + {"Content-Type": None}, + {"Content-Type": ""}, + {"Content-Type": " "}, + {"Content-Type": " ; ;; "}, + ), +) def test_extract_gdrive_contents_no_contents(headers): - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=Mock(headers=headers, status_code=200)): + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch("requests.head", return_value=Mock(headers=headers, 
status_code=200)): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'error': 'no content type' + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "error": "no content type", } -@pytest.mark.parametrize('header', ( - 'application/octet-stream', - 'application/pdf', - 'application/pdf; filename=bla.pdf' -)) +@pytest.mark.parametrize( + "header", + ( + "application/octet-stream", + "application/pdf", + "application/pdf; filename=bla.pdf", + ), +) def test_extract_gdrive_contents_pdf(header): - res = Mock(headers={'Content-Type': header}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=res): - with patch('align_data.sources.articles.parsers.fetch_pdf', return_value={'text': 'bla'}): + res = Mock(headers={"Content-Type": header}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch("requests.head", return_value=res): + with patch( + "align_data.sources.articles.parsers.fetch_pdf", + return_value={"text": "bla"}, + ): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'text': 'bla', + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "text": "bla", } -@pytest.mark.parametrize('header', ( - 'application/epub', - 'application/epub+zip', - 'application/epub; filename=bla.epub', -)) +@pytest.mark.parametrize( + "header", + ( + "application/epub", + "application/epub+zip", + "application/epub; filename=bla.epub", + ), +) def test_extract_gdrive_contents_ebook(header): - res = Mock(headers={'Content-Type': header}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=res): + res = Mock(headers={"Content-Type": header}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch("requests.head", return_value=res): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'data_source': 'ebook', + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "data_source": "ebook", } def test_extract_gdrive_contents_html(): - res = Mock(headers={'Content-Type': 'text/html'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)): + res = Mock(headers={"Content-Type": "text/html"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch( + "requests.head", + return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200), + ): html = """
bleee
@@ -231,45 +277,48 @@ def test_extract_gdrive_contents_html(): """ res = Mock( - headers={'Content-Type': 'text/html'}, + headers={"Content-Type": "text/html"}, status_code=200, content=html, text=html, ) - with patch('requests.get', return_value=res): + with patch("requests.get", return_value=res): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'text': 'bla bla', - 'data_source': 'html', + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "text": "bla bla", + "data_source": "html", } def test_extract_gdrive_contents_xml(): - res = Mock(headers={'Content-Type': 'text/html'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)): + res = Mock(headers={"Content-Type": "text/html"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch( + "requests.head", + return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200), + ): res = Mock( - headers={'Content-Type': 'text/xml'}, + headers={"Content-Type": "text/xml"}, status_code=200, content=SAMPLE_XML, text=SAMPLE_XML, ) - with patch('requests.get', return_value=res): + with patch("requests.get", return_value=res): assert extract_gdrive_contents(url) == { - 'abstract': 'this is the abstract', - 'authors': ['Cullen Oâ\x80\x99Keefe'], - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'text': 'This is the contents', - 'title': 'The title!!', - 'data_source': 'xml', + "abstract": "this is the abstract", + "authors": ["Cullen Oâ\x80\x99Keefe"], + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "text": "This is the contents", + "title": "The title!!", + "data_source": "xml", } def test_extract_gdrive_contents_xml_with_confirm(): - res = Mock(headers={'Content-Type': 'text/html'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' + res = Mock(headers={"Content-Type": "text/html"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" def fetcher(link, *args, **kwargs): # The first request should get the google drive warning page @@ -280,27 +329,37 @@ def fetcher(link, *args, **kwargs):
""" - return Mock(headers={'Content-Type': 'text/html'}, status_code=200, text=html, content=html) + return Mock( + headers={"Content-Type": "text/html"}, + status_code=200, + text=html, + content=html, + ) # The second one returns the actual contents - return Mock(headers={'Content-Type': 'text/xml'}, status_code=200, content=SAMPLE_XML) + return Mock( + headers={"Content-Type": "text/xml"}, status_code=200, content=SAMPLE_XML + ) - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)): - with patch('requests.get', fetcher): + with patch( + "requests.head", + return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200), + ): + with patch("requests.get", fetcher): assert extract_gdrive_contents(url) == { - 'abstract': 'this is the abstract', - 'authors': ['Cullen Oâ\x80\x99Keefe'], - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'text': 'This is the contents', - 'title': 'The title!!', - 'data_source': 'xml', + "abstract": "this is the abstract", + "authors": ["Cullen Oâ\x80\x99Keefe"], + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "text": "This is the contents", + "title": "The title!!", + "data_source": "xml", } def test_extract_gdrive_contents_warning_with_unknown(): - res = Mock(headers={'Content-Type': 'text/html'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' + res = Mock(headers={"Content-Type": "text/html"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" def fetcher(link, *args, **kwargs): # The first request should get the google drive warning page @@ -311,26 +370,34 @@ def fetcher(link, *args, **kwargs):
""" - return Mock(headers={'Content-Type': 'text/html'}, status_code=200, text=html, content=html) + return Mock( + headers={"Content-Type": "text/html"}, + status_code=200, + text=html, + content=html, + ) # The second one returns the actual contents, with an unhandled content type - return Mock(headers={'Content-Type': 'text/bla bla'}, status_code=200) + return Mock(headers={"Content-Type": "text/bla bla"}, status_code=200) - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)): - with patch('requests.get', fetcher): + with patch( + "requests.head", + return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200), + ): + with patch("requests.get", fetcher): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'error': "unknown content type: {'text/bla bla'}", - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', + "downloaded_from": "google drive", + "error": "unknown content type: {'text/bla bla'}", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", } def test_extract_gdrive_contents_unknown_content_type(): - res = Mock(headers={'Content-Type': 'bla bla'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=res): + res = Mock(headers={"Content-Type": "bla bla"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch("requests.head", return_value=res): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'error': "unknown content type: {'bla bla'}", + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "error": "unknown content type: {'bla bla'}", } diff --git a/tests/align_data/common/test_alignment_dataset.py b/tests/align_data/common/test_alignment_dataset.py index b04a5429..fcefeacd 100644 --- a/tests/align_data/common/test_alignment_dataset.py +++ b/tests/align_data/common/test_alignment_dataset.py @@ -11,15 +11,18 @@ @pytest.fixture def data_entries(): - dataset = AlignmentDataset(name='blaa') + dataset = AlignmentDataset(name="blaa") entries = [ - dataset.make_data_entry({ - 'text': f'line {i}', - 'date_published': f'day {i}', - 'source': f'source {i}', - 'title': str(i), - 'url': f'http://bla.bla.bla?page={i}', - }) for i in range(5) + dataset.make_data_entry( + { + "text": f"line {i}", + "date_published": f"day {i}", + "source": f"source {i}", + "title": str(i), + "url": f"http://bla.bla.bla?page={i}", + } + ) + for i in range(5) ] for entry in entries: Article.before_write(None, None, entry) @@ -28,163 +31,260 @@ def data_entries(): @pytest.fixture def dataset(): - return AlignmentDataset(name='blaa') + return AlignmentDataset(name="blaa") def test_data_entry_default_fields(): - dataset = AlignmentDataset(name='blaa') + dataset = AlignmentDataset(name="blaa") entry = dataset.make_data_entry({}) assert entry.to_dict() == { - 'date_published': None, - 'source': None, - 'source_type': None, - 'title': None, - 'url': None, - 'id': None, - 'text': None, - 'summaries': [], - 'authors': [], + "date_published": None, + "source": None, + "source_type": None, + "title": None, + "url": None, + "id": None, + "text": None, + 
"summaries": [], + "authors": [], } + def test_data_entry_id_from_urls_and_title(): - data = {'key1': 12, 'key2': 312, 'url': 'www.arbital.org', 'title': 'once upon a time'} - dataset = AlignmentDataset(name='blaa') + data = { + "key1": 12, + "key2": 312, + "url": "www.arbital.org", + "title": "once upon a time", + } + dataset = AlignmentDataset(name="blaa") entry = dataset.make_data_entry(data) Article.before_write(None, None, entry) - assert entry.to_dict() == dict({ - 'date_published': None, - 'id': '770fe57c8c2130eda08dc392b8696f97', - 'source': None, - 'source_type': None, - 'text': None, - 'summaries': [], - 'authors': [], - }, **data + assert entry.to_dict() == dict( + { + "date_published": None, + "id": "770fe57c8c2130eda08dc392b8696f97", + "source": None, + "source_type": None, + "text": None, + "summaries": [], + "authors": [], + }, + **data, ) def test_data_entry_no_url_and_title(): - dataset = AlignmentDataset(name='blaa') - entry = dataset.make_data_entry({'key1': 12, 'key2': 312}) - with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url', 'title'\\]"): + dataset = AlignmentDataset(name="blaa") + entry = dataset.make_data_entry({"key1": 12, "key2": 312}) + with pytest.raises( + AssertionError, + match="Entry is missing the following fields: \\['url', 'title'\\]", + ): Article.before_write(None, None, entry) def test_data_entry_no_url(): - dataset = AlignmentDataset(name='blaa') - entry = dataset.make_data_entry({'key1': 12, 'key2': 312, 'title': 'wikipedia goes to war on porcupines'}) - with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url'\\]"): + dataset = AlignmentDataset(name="blaa") + entry = dataset.make_data_entry( + {"key1": 12, "key2": 312, "title": "wikipedia goes to war on porcupines"} + ) + with pytest.raises( + AssertionError, match="Entry is missing the following fields: \\['url'\\]" + ): Article.before_write(None, None, entry) def test_data_entry_none_url(): - dataset = AlignmentDataset(name='blaa') - entry = dataset.make_data_entry({'key1': 12, 'key2': 312, 'url': None}) - with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url', 'title'\\]"): + dataset = AlignmentDataset(name="blaa") + entry = dataset.make_data_entry({"key1": 12, "key2": 312, "url": None}) + with pytest.raises( + AssertionError, + match="Entry is missing the following fields: \\['url', 'title'\\]", + ): Article.before_write(None, None, entry) def test_data_entry_none_title(): - dataset = AlignmentDataset(name='blaa') - entry = dataset.make_data_entry({'key1': 12, 'key2': 312, 'url': 'www.wikipedia.org', 'title': None}) - with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['title'\\]"): + dataset = AlignmentDataset(name="blaa") + entry = dataset.make_data_entry( + {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": None} + ) + with pytest.raises( + AssertionError, match="Entry is missing the following fields: \\['title'\\]" + ): Article.before_write(None, None, entry) def test_data_entry_empty_url_and_title(): - dataset = AlignmentDataset(name='blaa') - entry = dataset.make_data_entry({'key1': 12, 'key2': 312, 'url': '', 'title': ''}) - with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url', 'title'\\]"): + dataset = AlignmentDataset(name="blaa") + entry = dataset.make_data_entry({"key1": 12, "key2": 312, "url": "", "title": ""}) + with pytest.raises( + AssertionError, + match="Entry is missing the following fields: 
\\['url', 'title'\\]", + ): Article.before_write(None, None, entry) def test_data_entry_empty_url_only(): - dataset = AlignmentDataset(name='blaa') - entry = dataset.make_data_entry({'key1': 12, 'key2': 312, 'url': '', 'title': 'once upon a time'}) - with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url'\\]"): + dataset = AlignmentDataset(name="blaa") + entry = dataset.make_data_entry( + {"key1": 12, "key2": 312, "url": "", "title": "once upon a time"} + ) + with pytest.raises( + AssertionError, match="Entry is missing the following fields: \\['url'\\]" + ): Article.before_write(None, None, entry) def test_data_entry_empty_title_only(): - dataset = AlignmentDataset(name='blaa') - entry = dataset.make_data_entry({'key1': 12, 'key2': 312, 'url': 'www.wikipedia.org', 'title':''}) - with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['title'\\]"): + dataset = AlignmentDataset(name="blaa") + entry = dataset.make_data_entry( + {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": ""} + ) + with pytest.raises( + AssertionError, match="Entry is missing the following fields: \\['title'\\]" + ): Article.before_write(None, None, entry) def test_data_entry_verify_id_passes(): - dataset = AlignmentDataset(name='blaa') - entry = dataset.make_data_entry({'source': 'arbital', 'text': 'once upon a time', 'url': 'www.arbital.org', 'title': 'once upon a time', 'id': '770fe57c8c2130eda08dc392b8696f97'}) + dataset = AlignmentDataset(name="blaa") + entry = dataset.make_data_entry( + { + "source": "arbital", + "text": "once upon a time", + "url": "www.arbital.org", + "title": "once upon a time", + "id": "770fe57c8c2130eda08dc392b8696f97", + } + ) entry.verify_id() def test_data_entry_verify_id_fails(): - dataset = AlignmentDataset(name='blaa') - entry = dataset.make_data_entry({'url': 'www.arbital.org', 'title': 'once upon a time', 'id': 'f2b4e02fc1dd8ae43845e4f930f2d84f'}) - with pytest.raises(AssertionError, match='Entry id does not match id_fields'): + dataset = AlignmentDataset(name="blaa") + entry = dataset.make_data_entry( + { + "url": "www.arbital.org", + "title": "once upon a time", + "id": "f2b4e02fc1dd8ae43845e4f930f2d84f", + } + ) + with pytest.raises(AssertionError, match="Entry id does not match id_fields"): entry.verify_id() def test_data_entry_id_fields_url_no_url(): - dataset = AlignmentDataset(name='blaa', id_fields=['url']) - entry = dataset.make_data_entry({'source': 'arbital', 'text': 'once upon a time'}) - with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url'\\]"): + dataset = AlignmentDataset(name="blaa", id_fields=["url"]) + entry = dataset.make_data_entry({"source": "arbital", "text": "once upon a time"}) + with pytest.raises( + AssertionError, match="Entry is missing the following fields: \\['url'\\]" + ): Article.before_write(None, None, entry) def test_data_entry_id_fields_url_empty_url(): - dataset = AlignmentDataset(name='blaa', id_fields=['url']) - entry = dataset.make_data_entry({'url': ''}) - with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url'\\]"): + dataset = AlignmentDataset(name="blaa", id_fields=["url"]) + entry = dataset.make_data_entry({"url": ""}) + with pytest.raises( + AssertionError, match="Entry is missing the following fields: \\['url'\\]" + ): Article.before_write(None, None, entry) def test_data_entry_id_fields_url(): - dataset = AlignmentDataset(name='blaa', id_fields=['url']) - entry = dataset.make_data_entry({'url': 
'https://www.google.ca/once_upon_a_time'}) + dataset = AlignmentDataset(name="blaa", id_fields=["url"]) + entry = dataset.make_data_entry({"url": "https://www.google.ca/once_upon_a_time"}) Article.before_write(None, None, entry) assert entry.id def test_data_entry_id_fields_url_verify_id_passes(): - dataset = AlignmentDataset(name='blaa', id_fields=['url']) - entry = dataset.make_data_entry({'url': 'arbitalonce upon a time', 'id':'809d336a0b9b38c4f585e862317e667d'}) + dataset = AlignmentDataset(name="blaa", id_fields=["url"]) + entry = dataset.make_data_entry( + {"url": "arbitalonce upon a time", "id": "809d336a0b9b38c4f585e862317e667d"} + ) entry.verify_id() def test_data_entry_different_id_from_different_url(): - dataset = AlignmentDataset(name='blaa', id_fields=['url']) - entry1 = dataset.make_data_entry({'url': ' https://aisafety.info?state=6478'}) - entry2 = dataset.make_data_entry({'source': 'arbital', 'text': 'once upon a time', 'url': ' https://aisafety.info?state=6479'}) + dataset = AlignmentDataset(name="blaa", id_fields=["url"]) + entry1 = dataset.make_data_entry({"url": " https://aisafety.info?state=6478"}) + entry2 = dataset.make_data_entry( + { + "source": "arbital", + "text": "once upon a time", + "url": " https://aisafety.info?state=6479", + } + ) assert entry1.generate_id_string() != entry2.generate_id_string() -@pytest.mark.parametrize('data, error', ( - ({'text': 'bla bla bla'}, "Entry is missing id"), - ({'text': 'bla bla bla', 'id': None}, "Entry is missing id"), - - ({'id': '123', 'url':'www.google.com/winter_wonderland','title': 'winter wonderland'}, 'Entry id 123 does not match id from id_fields, [0-9a-fA-F]{32}'), - ({'id': '457c21e0ecabebcb85c12022d481d9f4', 'url':'www.google.com', 'title': 'winter wonderland'}, 'Entry id [0-9a-fA-F]{32} does not match id from id_fields, [0-9a-fA-F]{32}'), - ({'id': '457c21e0ecabebcb85c12022d481d9f4', 'url':'www.google.com', 'title': 'Once upon a time'}, 'Entry id [0-9a-fA-F]{32} does not match id from id_fields, [0-9a-fA-F]{32}'), -)) +@pytest.mark.parametrize( + "data, error", + ( + ({"text": "bla bla bla"}, "Entry is missing id"), + ({"text": "bla bla bla", "id": None}, "Entry is missing id"), + ( + { + "id": "123", + "url": "www.google.com/winter_wonderland", + "title": "winter wonderland", + }, + "Entry id 123 does not match id from id_fields, [0-9a-fA-F]{32}", + ), + ( + { + "id": "457c21e0ecabebcb85c12022d481d9f4", + "url": "www.google.com", + "title": "winter wonderland", + }, + "Entry id [0-9a-fA-F]{32} does not match id from id_fields, [0-9a-fA-F]{32}", + ), + ( + { + "id": "457c21e0ecabebcb85c12022d481d9f4", + "url": "www.google.com", + "title": "Once upon a time", + }, + "Entry id [0-9a-fA-F]{32} does not match id from id_fields, [0-9a-fA-F]{32}", + ), + ), +) def test_data_entry_verify_id_fails(data, error): - dataset = AlignmentDataset(name='blaa', id_fields=['url', 'title']) + dataset = AlignmentDataset(name="blaa", id_fields=["url", "title"]) entry = dataset.make_data_entry(data) with pytest.raises(AssertionError, match=error): entry.verify_id() -@pytest.mark.parametrize('data, error', ( - ({'id': '123'}, "Entry is missing the following fields: \\['url', 'title'\\]"), - ({'id': '123', 'url': None}, "Entry is missing the following fields: \\['url', 'title'\\]"), - ({'id': '123', 'url': 'www.google.com/'}, "Entry is missing the following fields: \\['title'\\]"), - ({'id': '123', 'url': 'google', 'text': None}, "Entry is missing the following fields: \\['title'\\]"), - ({'id': '123', 'url': '', 'title': ''}, "Entry is 
missing the following fields: \\['url', 'title'\\]"), -)) +@pytest.mark.parametrize( + "data, error", + ( + ({"id": "123"}, "Entry is missing the following fields: \\['url', 'title'\\]"), + ( + {"id": "123", "url": None}, + "Entry is missing the following fields: \\['url', 'title'\\]", + ), + ( + {"id": "123", "url": "www.google.com/"}, + "Entry is missing the following fields: \\['title'\\]", + ), + ( + {"id": "123", "url": "google", "text": None}, + "Entry is missing the following fields: \\['title'\\]", + ), + ( + {"id": "123", "url": "", "title": ""}, + "Entry is missing the following fields: \\['url', 'title'\\]", + ), + ), +) def test_data_entry_verify_fields_fails(data, error): - dataset = AlignmentDataset(name='blaa', id_fields=['url', 'title']) + dataset = AlignmentDataset(name="blaa", id_fields=["url", "title"]) entry = dataset.make_data_entry(data) with pytest.raises(AssertionError, match=error): entry.verify_fields() @@ -193,10 +293,11 @@ def test_data_entry_verify_fields_fails(data, error): @pytest.fixture def numbers_dataset(): """Make a dataset that raises its items to the power of 2.""" + @dataclass class NumbersDataset(AlignmentDataset): nums: List[int] - done_key = 'number' + done_key = "number" @property def items_list(self): @@ -206,49 +307,56 @@ def get_item_key(self, item): return item def process_entry(self, item): - return self.make_data_entry({ - 'text': f'line {item}', - 'date_published': f'day {item}', - 'source': f'source {item}', - 'title': str(item), - 'url': f'http://bla.bla.bla?page={item}', - 'number': item, - 'value': item ** 2, - 'authors': [], - }) - - return NumbersDataset(name='numbers', nums=list(range(10))) + return self.make_data_entry( + { + "text": f"line {item}", + "date_published": f"day {item}", + "source": f"source {item}", + "title": str(item), + "url": f"http://bla.bla.bla?page={item}", + "number": item, + "value": item**2, + "authors": [], + } + ) + + return NumbersDataset(name="numbers", nums=list(range(10))) def test_unprocessed_items_fresh(numbers_dataset): """Getting the unprocessed items from a dataset that hasn't written anything should get all items.""" seen = set() - with patch.object(numbers_dataset, '_load_outputted_items', return_value=seen): + with patch.object(numbers_dataset, "_load_outputted_items", return_value=seen): assert list(numbers_dataset.unprocessed_items()) == list(range(10)) def test_unprocessed_items_all_done(numbers_dataset): """Getting the unprocessed items from a dataset that has already processed everything should return an empty list.""" seen = set(range(0, 10)) - with patch.object(numbers_dataset, '_load_outputted_items', return_value=seen): + with patch.object(numbers_dataset, "_load_outputted_items", return_value=seen): assert list(numbers_dataset.unprocessed_items()) == [] def test_unprocessed_items_some_done(numbers_dataset): """Getting the uprocessed items from a dataset that has partially completed should return the items that haven't been processed.""" seen = set(range(0, 10, 2)) - with patch.object(numbers_dataset, '_load_outputted_items', return_value=seen): + with patch.object(numbers_dataset, "_load_outputted_items", return_value=seen): assert list(numbers_dataset.unprocessed_items()) == list(range(1, 10, 2)) def test_fetch_entries(numbers_dataset): - assert [i.meta['value'] for i in numbers_dataset.fetch_entries()] == [i**2 for i in range(10)] + assert [i.meta["value"] for i in numbers_dataset.fetch_entries()] == [ + i**2 for i in range(10) + ] def test_format_datatime(dataset): - assert 
dataset._format_datetime(datetime(2022, 1, 1, 12, 23, 43)) == '2022-01-01T12:23:43Z' + assert ( + dataset._format_datetime(datetime(2022, 1, 1, 12, 23, 43)) + == "2022-01-01T12:23:43Z" + ) def test_format_datatime_ignore_timezone(dataset): - dt = datetime.fromisoformat('2022-01-01T00:00:00+04:00') - assert dataset._format_datetime(dt) == '2022-01-01T00:00:00Z' + dt = datetime.fromisoformat("2022-01-01T00:00:00+04:00") + assert dataset._format_datetime(dt) == "2022-01-01T00:00:00Z" diff --git a/tests/align_data/common/test_html_dataset.py b/tests/align_data/common/test_html_dataset.py index e279873e..a362124a 100644 --- a/tests/align_data/common/test_html_dataset.py +++ b/tests/align_data/common/test_html_dataset.py @@ -9,7 +9,9 @@ @pytest.fixture def html_dataset(): - dataset = HTMLDataset(name='bla', url='http://example.com', authors=['John Smith', 'Your momma']) + dataset = HTMLDataset( + name="bla", url="http://example.com", authors=["John Smith", "Your momma"] + ) return dataset @@ -31,8 +33,12 @@ def html_dataset(): """ + def test_html_dataset_extract_authors(html_dataset): - assert html_dataset.extract_authors('dummy variable') == ['John Smith', 'Your momma'] + assert html_dataset.extract_authors("dummy variable") == [ + "John Smith", + "Your momma", + ] def test_html_dataset_get_title(html_dataset): @@ -44,11 +50,11 @@ def test_html_dataset_get_title(html_dataset): """ soup = BeautifulSoup(item, "html.parser") - assert html_dataset._get_title(soup) == 'This is the title' + assert html_dataset._get_title(soup) == "This is the title" def test_html_dataset_get_title_missing(html_dataset): - soup = BeautifulSoup('', "html.parser") + soup = BeautifulSoup("", "html.parser") assert html_dataset._get_title(soup) is None @@ -60,7 +66,7 @@ def test_html_dataset_get_item_key(html_dataset): """ soup = BeautifulSoup(item, "html.parser") - assert html_dataset.get_item_key(soup) == 'http://example.com/path/to/article' + assert html_dataset.get_item_key(soup) == "http://example.com/path/to/article" def test_html_dataset_items_list(html_dataset): @@ -73,24 +79,28 @@ def test_html_dataset_items_list(html_dataset):
article 5
""" - with patch('requests.get', return_value=Mock(content=text)): + with patch("requests.get", return_value=Mock(content=text)): assert [i.text for i in html_dataset.items_list] == [ - 'article 1', - 'article 2', - 'article 3', - 'article 4', - 'article 5', + "article 1", + "article 2", + "article 3", + "article 4", + "article 5", ] def test_html_dataset_get_contents(html_dataset): - with patch('requests.get', return_value=Mock(content=SAMPLE_HTML)): - assert html_dataset._get_contents('url') == BeautifulSoup(SAMPLE_HTML, "html.parser") + with patch("requests.get", return_value=Mock(content=SAMPLE_HTML)): + assert html_dataset._get_contents("url") == BeautifulSoup( + SAMPLE_HTML, "html.parser" + ) def test_html_dataset_get_text(html_dataset): - soup = BeautifulSoup(f'
{SAMPLE_CONTENTS}
', "html.parser") - assert html_dataset._get_text(soup) == 'bla bla bla [a link](http://ble.com) bla bla' + soup = BeautifulSoup(f"
{SAMPLE_CONTENTS}
", "html.parser") + assert ( + html_dataset._get_text(soup) == "bla bla bla [a link](http://ble.com) bla bla" + ) def test_html_dataset_find_date(html_dataset): @@ -104,15 +114,21 @@ def test_html_dataset_find_date(html_dataset): """ soup = BeautifulSoup(text, "html.parser") - assert html_dataset._find_date(soup.select('span')) == parse('2023-10-07T00:00:00Z') + assert html_dataset._find_date(soup.select("span")) == parse("2023-10-07T00:00:00Z") -@pytest.mark.parametrize('text', ( - SAMPLE_CONTENTS, - BeautifulSoup(SAMPLE_CONTENTS, "html.parser"), -)) +@pytest.mark.parametrize( + "text", + ( + SAMPLE_CONTENTS, + BeautifulSoup(SAMPLE_CONTENTS, "html.parser"), + ), +) def test_html_dataset_extract_metadata(html_dataset, text): - assert html_dataset._extract_markdown(text) == 'bla bla bla [a link](http://ble.com) bla bla' + assert ( + html_dataset._extract_markdown(text) + == "bla bla bla [a link](http://ble.com) bla bla" + ) def test_html_dataset_process_entry(html_dataset): @@ -124,17 +140,17 @@ def test_html_dataset_process_entry(html_dataset): """ article = BeautifulSoup(item, "html.parser") - with patch('requests.get', return_value=Mock(content=SAMPLE_HTML)): + with patch("requests.get", return_value=Mock(content=SAMPLE_HTML)): assert html_dataset.process_entry(article).to_dict() == { - 'authors': ['John Smith', 'Your momma'], - 'date_published': None, - 'id': None, - 'source': 'bla', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla [a link](http://ble.com) bla bla', - 'title': 'This is the title', - 'url': 'http://example.com/path/to/article', + "authors": ["John Smith", "Your momma"], + "date_published": None, + "id": None, + "source": "bla", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla [a link](http://ble.com) bla bla", + "title": "This is the title", + "url": "http://example.com/path/to/article", } @@ -142,94 +158,111 @@ def test_html_dataset_process_entry_no_text(html_dataset): item = f'
click to read more
' article = BeautifulSoup(item, "html.parser") - with patch('requests.get', return_value=Mock(content='')): + with patch("requests.get", return_value=Mock(content="")): assert html_dataset.process_entry(article) is None -@pytest.mark.parametrize('item, authors', ( - ({}, ['default author']), - ({'bla': 123123}, ['default author']), - - ({'authors': []}, []), - ({'authors': [{}, {'bla': 'asd'}, {'name': None}, {'name': ''}]}, []), - - ({'authors': [{'name': 'John Smith'}, {'name': 'your momma'}]}, ['John Smith', 'your momma']), -)) +@pytest.mark.parametrize( + "item, authors", + ( + ({}, ["default author"]), + ({"bla": 123123}, ["default author"]), + ({"authors": []}, []), + ({"authors": [{}, {"bla": "asd"}, {"name": None}, {"name": ""}]}, []), + ( + {"authors": [{"name": "John Smith"}, {"name": "your momma"}]}, + ["John Smith", "your momma"], + ), + ), +) def test_rss_dataset_extract_authors(item, authors): - dataset = RSSDataset(name='bla', url='http://example.org', authors=['default author']) + dataset = RSSDataset( + name="bla", url="http://example.org", authors=["default author"] + ) assert dataset.extract_authors(item) == authors def test_rss_dataset_get_title(): - assert RSSDataset._get_title({'title': 'title'}) == 'title' + assert RSSDataset._get_title({"title": "title"}) == "title" -@pytest.mark.parametrize('item, date', ( - ({'published': '2012/01/02 12:32'}, parse('2012-01-02T12:32:00Z')), - ({'pubDate': '2012/01/02 12:32'}, parse('2012-01-02T12:32:00Z')), - ({ - 'pubDate': '2032/01/02 12:32', - 'published': '2012/01/02 12:32', - }, parse('2012-01-02T12:32:00Z')), - - ({'bla': 'bla'}, None), -)) +@pytest.mark.parametrize( + "item, date", + ( + ({"published": "2012/01/02 12:32"}, parse("2012-01-02T12:32:00Z")), + ({"pubDate": "2012/01/02 12:32"}, parse("2012-01-02T12:32:00Z")), + ( + { + "pubDate": "2032/01/02 12:32", + "published": "2012/01/02 12:32", + }, + parse("2012-01-02T12:32:00Z"), + ), + ({"bla": "bla"}, None), + ), +) def test_rss_dataset_get_published_date(item, date): - dataset = RSSDataset(name='bla', url='http://example.org', authors=['default author']) + dataset = RSSDataset( + name="bla", url="http://example.org", authors=["default author"] + ) assert dataset._get_published_date(item) == date -@pytest.mark.parametrize('item', ( - {}, - {'content': None}, - {'content': ''}, - - {'content': []}, - {'content': [{}]}, - {'content': [{'bla': 'asd'}]}, -)) +@pytest.mark.parametrize( + "item", + ( + {}, + {"content": None}, + {"content": ""}, + {"content": []}, + {"content": [{}]}, + {"content": [{"bla": "asd"}]}, + ), +) def test_rss_dataset_get_text_missing(item): - dataset = RSSDataset(name='bla', url='http://example.org') + dataset = RSSDataset(name="bla", url="http://example.org") assert not dataset._get_text(item) def test_rss_dataset_get_text(): - dataset = RSSDataset(name='bla', url='http://example.org') - assert dataset._get_text({'content': [{'value': SAMPLE_CONTENTS}]}) == 'bla bla bla [a link](http://ble.com) bla bla' + dataset = RSSDataset(name="bla", url="http://example.org") + assert ( + dataset._get_text({"content": [{"value": SAMPLE_CONTENTS}]}) + == "bla bla bla [a link](http://ble.com) bla bla" + ) def test_rss_dataset_get_contents_with_contents(): - dataset = RSSDataset(name='bla', url='http://example.org') - dataset.items = { - 'http://bla.bla': { - 'content': 'contents' - } - } + dataset = RSSDataset(name="bla", url="http://example.org") + dataset.items = {"http://bla.bla": {"content": "contents"}} - assert dataset._get_contents('http://bla.bla') == 
{'content': 'contents'} + assert dataset._get_contents("http://bla.bla") == {"content": "contents"} def test_rss_dataset_get_contents_no_contents(): - dataset = RSSDataset(name='bla', url='http://example.org') - dataset.items = {'http://bla.bla': {}} + dataset = RSSDataset(name="bla", url="http://example.org") + dataset.items = {"http://bla.bla": {}} - contents = '
bla
' - with patch('requests.get', return_value=Mock(content=contents)): - assert dataset._get_contents('http://bla.bla') == { - 'soup': BeautifulSoup(contents, "html.parser") + contents = "
bla
" + with patch("requests.get", return_value=Mock(content=contents)): + assert dataset._get_contents("http://bla.bla") == { + "soup": BeautifulSoup(contents, "html.parser") } def test_rss_dataset_items_list(): - dataset = RSSDataset(name='bla', url='http://example.org') + dataset = RSSDataset(name="bla", url="http://example.org") contents = { - 'entries': [ + "entries": [ { - 'link': f'http://example.org/article-{i}', - 'title': f'Article no {i}', - } for i in range(5) + "link": f"http://example.org/article-{i}", + "title": f"Article no {i}", + } + for i in range(5) ] } - with patch('feedparser.parse', return_value=contents): - assert dataset.items_list == [f'http://example.org/article-{i}' for i in range(5)] + with patch("feedparser.parse", return_value=contents): + assert dataset.items_list == [ + f"http://example.org/article-{i}" for i in range(5) + ] diff --git a/tests/align_data/test_alignment_newsletter.py b/tests/align_data/test_alignment_newsletter.py index 249f77e9..ffd5e31b 100644 --- a/tests/align_data/test_alignment_newsletter.py +++ b/tests/align_data/test_alignment_newsletter.py @@ -7,7 +7,7 @@ @pytest.fixture(scope="module") def dataset(): - dataset = AlignmentNewsletter(name='text') + dataset = AlignmentNewsletter(name="text") dataset.setup() return dataset @@ -19,21 +19,33 @@ def test_xlsx_file_loaded(dataset): def test_get_item_key(dataset): items = list(dataset.items_list) - assert dataset.get_item_key(items[0]) == 'http://gradientscience.org/adv/' + assert dataset.get_item_key(items[0]) == "http://gradientscience.org/adv/" def test_process_entry_no_summary(dataset): - items = pd.DataFrame([ - {'Url': 'http://bla.bla/3', 'Title': 'An item without a summary field'}, - {'Url': 'http://bla.bla/2', 'Title': 'An item with a None summary field', 'Summary': None}, - {'Url': 'http://bla.bla/1', 'Title': 'An item with an invalid summary field', 'Summary': pd.NA}, - ]) + items = pd.DataFrame( + [ + {"Url": "http://bla.bla/3", "Title": "An item without a summary field"}, + { + "Url": "http://bla.bla/2", + "Title": "An item with a None summary field", + "Summary": None, + }, + { + "Url": "http://bla.bla/1", + "Title": "An item with an invalid summary field", + "Summary": pd.NA, + }, + ] + ) for item in items.itertuples(): assert dataset.process_entry(item) is None def test_format_datatime(dataset): - assert dataset._get_published_date(2022) == datetime(2022, 1, 1, tzinfo=timezone.utc) + assert dataset._get_published_date(2022) == datetime( + 2022, 1, 1, tzinfo=timezone.utc + ) def test_process_entry(dataset): @@ -41,81 +53,85 @@ def test_process_entry(dataset): # of a bother to keep up to date, then it can be deleted items = list(dataset.items_list) assert dataset.process_entry(items[0]).to_dict() == { - 'authors': ['Andrew Ilyas*', - 'Shibani Santurkar*', - 'Dimitris Tsipras*', - 'Logan Engstrom*', - 'Brandon Tran', - 'Aleksander Madry'], - 'converted_with': 'python', - 'date_published': '2019-01-01T00:00:00Z', - 'highlight': True, - 'id': None, - 'newsletter_category': 'Adversarial examples', - 'newsletter_number': 'AN #62', - 'opinion': ( - 'I buy this hypothesis. It explains why adversarial examples occur ' + "authors": [ + "Andrew Ilyas*", + "Shibani Santurkar*", + "Dimitris Tsipras*", + "Logan Engstrom*", + "Brandon Tran", + "Aleksander Madry", + ], + "converted_with": "python", + "date_published": "2019-01-01T00:00:00Z", + "highlight": True, + "id": None, + "newsletter_category": "Adversarial examples", + "newsletter_number": "AN #62", + "opinion": ( + "I buy this hypothesis. 
It explains why adversarial examples occur " '("because they are useful to reduce loss"), and why they transfer ' 'across models ("because different models can learn the same ' 'non-robust features"). In fact, the paper shows that ' - 'architectures that did worse in ExpWrongLabels (and so presumably ' - 'are bad at learning non-robust features) are also the ones to ' + "architectures that did worse in ExpWrongLabels (and so presumably " + "are bad at learning non-robust features) are also the ones to " "which adversarial examples transfer the least. I'll leave the " - 'rest of my opinion to the opinions on the responses.' + "rest of my opinion to the opinions on the responses." ), - 'prerequisites': '', - 'read_more': '[Paper](https://arxiv.org/abs/1905.02175) and [Author response](https://distill.pub/2019/advex-bugs-discussion/original-authors/)', - 'source': 'text', - 'source_type': 'google-sheets', - 'summarizer': 'Rohin', - 'summaries': [( - '_Distill published a discussion of this paper. This highlights ' - 'section will cover the full discussion; all of these summaries and ' - 'opinions are meant to be read together._\n' - '\n' - 'Consider two possible explanations of adversarial examples. First, ' - 'they could be caused because the model "hallucinates" a signal that ' - 'is not useful for classification, and it becomes very sensitive to ' - 'this feature. We could call these "bugs", since they don\'t ' - 'generalize well. Second, they could be caused by features that _do_ ' - 'generalize to the test set, but _can_ be modified by an adversarial ' - 'perturbation. We could call these "non-robust features" (as opposed ' - 'to "robust features", which can\'t be changed by an adversarial ' - 'perturbation). The authors argue that at least some adversarial ' - 'perturbations fall into the second category of being informative but ' - 'sensitive features, based on two experiments.\n' - '\n' - 'If the "hallucination" explanation were true, the hallucinations ' - 'would presumably be caused by the training process, the choice of ' - 'architecture, the size of the dataset, **but not by the type of ' - 'data**. So one thing to do would be to see if we can construct a ' - 'dataset such that a model trained on that dataset is _already_ ' - 'robust, without adversarial training. The authors do this in the ' - 'first experiment. They take an adversarially trained robust ' - 'classifier, and create images whose features (final-layer ' - 'activations of the robust classifier) match the features of some ' - 'unmodified input. The generated images only have robust features ' - 'because the original classifier was robust, and in fact models ' - 'trained on this dataset are automatically robust.\n' - '\n' - 'If the "non-robust features" explanation were true, then it should ' - 'be possible for a model to learn on a dataset containing only ' - 'non-robust features (which will look nonsensical to humans) and ' - '**still generalize to a normal-looking test set**. In the second ' - 'experiment (henceforth WrongLabels), the authors construct such a ' - 'dataset. Their hypothesis is that adversarial perturbations work by ' - 'introducing non-robust features of the target class. So, to ' - 'construct their dataset, they take an image x with original label y, ' - "adversarially perturb it towards some class y' to get image x', and " - "then add (x', y') to their dataset (even though to a human x' looks " - 'like class y). 
They have two versions of this: in RandLabels, the ' - "target class y' is chosen randomly, whereas in DetLabels, y' is " - 'chosen to be y + 1. For both datasets, if you train a new model on ' - 'the dataset, you get good performance **on the original test set**, ' - 'showing that the "non-robust features" do generalize.' - )], - 'title': 'Adversarial Examples Are Not Bugs, They Are Features', - 'url': 'http://gradientscience.org/adv/', - 'venue': 'arXiv', - 'text': None, -} + "prerequisites": "", + "read_more": "[Paper](https://arxiv.org/abs/1905.02175) and [Author response](https://distill.pub/2019/advex-bugs-discussion/original-authors/)", + "source": "text", + "source_type": "google-sheets", + "summarizer": "Rohin", + "summaries": [ + ( + "_Distill published a discussion of this paper. This highlights " + "section will cover the full discussion; all of these summaries and " + "opinions are meant to be read together._\n" + "\n" + "Consider two possible explanations of adversarial examples. First, " + 'they could be caused because the model "hallucinates" a signal that ' + "is not useful for classification, and it becomes very sensitive to " + 'this feature. We could call these "bugs", since they don\'t ' + "generalize well. Second, they could be caused by features that _do_ " + "generalize to the test set, but _can_ be modified by an adversarial " + 'perturbation. We could call these "non-robust features" (as opposed ' + 'to "robust features", which can\'t be changed by an adversarial ' + "perturbation). The authors argue that at least some adversarial " + "perturbations fall into the second category of being informative but " + "sensitive features, based on two experiments.\n" + "\n" + 'If the "hallucination" explanation were true, the hallucinations ' + "would presumably be caused by the training process, the choice of " + "architecture, the size of the dataset, **but not by the type of " + "data**. So one thing to do would be to see if we can construct a " + "dataset such that a model trained on that dataset is _already_ " + "robust, without adversarial training. The authors do this in the " + "first experiment. They take an adversarially trained robust " + "classifier, and create images whose features (final-layer " + "activations of the robust classifier) match the features of some " + "unmodified input. The generated images only have robust features " + "because the original classifier was robust, and in fact models " + "trained on this dataset are automatically robust.\n" + "\n" + 'If the "non-robust features" explanation were true, then it should ' + "be possible for a model to learn on a dataset containing only " + "non-robust features (which will look nonsensical to humans) and " + "**still generalize to a normal-looking test set**. In the second " + "experiment (henceforth WrongLabels), the authors construct such a " + "dataset. Their hypothesis is that adversarial perturbations work by " + "introducing non-robust features of the target class. So, to " + "construct their dataset, they take an image x with original label y, " + "adversarially perturb it towards some class y' to get image x', and " + "then add (x', y') to their dataset (even though to a human x' looks " + "like class y). They have two versions of this: in RandLabels, the " + "target class y' is chosen randomly, whereas in DetLabels, y' is " + "chosen to be y + 1. 
For both datasets, if you train a new model on " + "the dataset, you get good performance **on the original test set**, " + 'showing that the "non-robust features" do generalize.' + ) + ], + "title": "Adversarial Examples Are Not Bugs, They Are Features", + "url": "http://gradientscience.org/adv/", + "venue": "arXiv", + "text": None, + } diff --git a/tests/align_data/test_arbital.py b/tests/align_data/test_arbital.py index d1f6724b..304c5398 100644 --- a/tests/align_data/test_arbital.py +++ b/tests/align_data/test_arbital.py @@ -4,65 +4,114 @@ import pytest from dateutil.parser import parse -from align_data.sources.arbital.arbital import Arbital, extract_text, flatten, parse_arbital_link - - -@pytest.mark.parametrize('contents, expected', ( - (['[', '123'], '[https://arbital.com/p/123](https://arbital.com/p/123)'), - (['[', '123 Some title'], '[Some title](https://arbital.com/p/123)'), - (['[', '123 Some title with multiple words'], '[Some title with multiple words](https://arbital.com/p/123)'), -)) +from align_data.sources.arbital.arbital import ( + Arbital, + extract_text, + flatten, + parse_arbital_link, +) + + +@pytest.mark.parametrize( + "contents, expected", + ( + (["[", "123"], "[https://arbital.com/p/123](https://arbital.com/p/123)"), + (["[", "123 Some title"], "[Some title](https://arbital.com/p/123)"), + ( + ["[", "123 Some title with multiple words"], + "[Some title with multiple words](https://arbital.com/p/123)", + ), + ), +) def test_parse_arbital_link(contents, expected): - assert parse_arbital_link(contents) == expected - - -@pytest.mark.parametrize('input, expected', ( - ([1, 2, 3], [1, 2, 3]), - ([1, [2, [3], 4]], [1, 2, 3, 4]), - ((1, (2, 3), 4), [1, 2, 3, 4]), - ([], []), - ([5], [5]), - ([1, 'a', [2, ['b'], 3]], [1, 'a', 2, 'b', 3]), - ([1, None, [2, [3], None]], [1, None, 2, 3, None]), -)) + assert parse_arbital_link(contents) == expected + + +@pytest.mark.parametrize( + "input, expected", + ( + ([1, 2, 3], [1, 2, 3]), + ([1, [2, [3], 4]], [1, 2, 3, 4]), + ((1, (2, 3), 4), [1, 2, 3, 4]), + ([], []), + ([5], [5]), + ([1, "a", [2, ["b"], 3]], [1, "a", 2, "b", 3]), + ([1, None, [2, [3], None]], [1, None, 2, 3, None]), + ), +) def test_flatten(input, expected): assert flatten(input) == expected -@pytest.mark.parametrize('text', ( - '' - 'asdasd asd asd as', - 'Stuff that is in parenthesizes (like this) should be left alone' - 'Markdown links [like this](https://bla.bla.com) should not be changed', -)) +@pytest.mark.parametrize( + "text", + ( + "" "asdasd asd asd as", + "Stuff that is in parenthesizes (like this) should be left alone" + "Markdown links [like this](https://bla.bla.com) should not be changed", + ), +) def test_markdownify_text_contents_basic_markdown(text): _, result = extract_text(text) assert result == text -@pytest.mark.parametrize('text, expected', ( - ('Arbital links [123 like this] should be transformed', 'Arbital links [like this](https://arbital.com/p/123) should be transformed'), - ('[summary: summaries should be removed] bla bla bla', 'bla bla bla'), - - (' \n \t \n contents get stripped of whitespace \t \n', 'contents get stripped of whitespace'), - ('malformed [links](http://bla.bla are handled somewhat', 'malformed [links](http://bla.bla) are handled somewhat') -)) +@pytest.mark.parametrize( + "text, expected", + ( + ( + "Arbital links [123 like this] should be transformed", + "Arbital links [like this](https://arbital.com/p/123) should be transformed", + ), + ("[summary: summaries should be removed] bla bla bla", "bla bla bla"), + ( + " \n \t 
\n contents get stripped of whitespace \t \n", + "contents get stripped of whitespace", + ), + ( + "malformed [links](http://bla.bla are handled somewhat", + "malformed [links](http://bla.bla) are handled somewhat", + ), + ), +) def test_markdownify_text_contents_arbital_markdown(text, expected): _, result = extract_text(text) assert result == expected -@pytest.mark.parametrize('text, expected', ( - ('[summary: summaries should be extracted] bla bla bla', 'summaries should be extracted'), - ('[summary: \n whitespace should be stripped \n] bla bla bla', 'whitespace should be stripped'), - - ('[summary(Bold): special summaries should be extracted] bla bla bla', 'special summaries should be extracted'), - ('[summary(Markdown): special summaries should be extracted] bla bla bla', 'special summaries should be extracted'), - ('[summary(BLEEEE): special summaries should be extracted] bla bla bla', 'special summaries should be extracted'), - - ('[summary: markdown is handled: [bla](https://bla.bla)] bla bla bla', 'markdown is handled: [bla](https://bla.bla)'), - ('[summary: markdown is handled: [123 ble ble]] bla bla bla', 'markdown is handled: [ble ble](https://arbital.com/p/123)'), -)) +@pytest.mark.parametrize( + "text, expected", + ( + ( + "[summary: summaries should be extracted] bla bla bla", + "summaries should be extracted", + ), + ( + "[summary: \n whitespace should be stripped \n] bla bla bla", + "whitespace should be stripped", + ), + ( + "[summary(Bold): special summaries should be extracted] bla bla bla", + "special summaries should be extracted", + ), + ( + "[summary(Markdown): special summaries should be extracted] bla bla bla", + "special summaries should be extracted", + ), + ( + "[summary(BLEEEE): special summaries should be extracted] bla bla bla", + "special summaries should be extracted", + ), + ( + "[summary: markdown is handled: [bla](https://bla.bla)] bla bla bla", + "markdown is handled: [bla](https://bla.bla)", + ), + ( + "[summary: markdown is handled: [123 ble ble]] bla bla bla", + "markdown is handled: [ble ble](https://arbital.com/p/123)", + ), + ), +) def test_markdownify_text_summary(text, expected): summary, _ = extract_text(text) assert summary == expected @@ -70,129 +119,132 @@ def test_markdownify_text_summary(text, expected): @pytest.fixture def dataset(): - dataset = Arbital(name='arbital') - dataset.titles_map = {} - - def post(url, *args, **kwargs): - response = Mock() - page = json.loads(kwargs.get('data', '{}')).get('pageAlias') - - if 'json/explore' in url: - response.json.return_value = {'pages': {f'{page}-{i}': i for i in range(10)}} - elif 'json/primaryPage' in url: - response.json.return_value = { - 'pages': { - page: { - 'title': f'{page}-title', - } + dataset = Arbital(name="arbital") + dataset.titles_map = {} + + def post(url, *args, **kwargs): + response = Mock() + page = json.loads(kwargs.get("data", "{}")).get("pageAlias") + + if "json/explore" in url: + response.json.return_value = { + "pages": {f"{page}-{i}": i for i in range(10)} + } + elif "json/primaryPage" in url: + response.json.return_value = { + "pages": { + page: { + "title": f"{page}-title", + } + } } - } - else: - response.json.return_value = {} - return response + else: + response.json.return_value = {} + return response - with patch('requests.post', post): - yield dataset + with patch("requests.post", post): + yield dataset def test_items_list(dataset): - assert dataset.items_list == [ - f'{page}-{i}' for page in dataset.ARBITAL_SUBSPACES for i in range(10) - ] + assert 
dataset.items_list == [ + f"{page}-{i}" for page in dataset.ARBITAL_SUBSPACES for i in range(10) + ] def test_get_title_no_items(dataset): - assert dataset.get_title('bla') == 'bla-title' + assert dataset.get_title("bla") == "bla-title" def test_get_title_cached(dataset): - dataset.titles_map['bla'] = 'ble ble ble' - assert dataset.get_title('bla') == 'ble ble ble' - - -@pytest.mark.parametrize('side_effect, return_value', ( - # The request was successful but no title present - (None, {'pages': {'bla': {}}}), - # The request was successful but the title is empty - (None, {'pages': {'bla': {'title': ''}}}), - - # The request failed - (ValueError('Oh noes!!'), None) -)) + dataset.titles_map["bla"] = "ble ble ble" + assert dataset.get_title("bla") == "ble ble ble" + + +@pytest.mark.parametrize( + "side_effect, return_value", + ( + # The request was successful but no title present + (None, {"pages": {"bla": {}}}), + # The request was successful but the title is empty + (None, {"pages": {"bla": {"title": ""}}}), + # The request failed + (ValueError("Oh noes!!"), None), + ), +) def test_get_title_error(dataset, side_effect, return_value): - titles_map = { - 'a random entry': 'to check that nothing gets changed' - } - dataset.titles_map = titles_map + titles_map = {"a random entry": "to check that nothing gets changed"} + dataset.titles_map = titles_map - with patch('requests.post', return_value=return_value, side_effect=side_effect): - assert dataset.get_title('bla') is None - # Make sure that errors don't change the titles map - assert dataset.titles_map == titles_map + with patch("requests.post", return_value=return_value, side_effect=side_effect): + assert dataset.get_title("bla") is None + # Make sure that errors don't change the titles map + assert dataset.titles_map == titles_map def test_extract_authors(dataset): - authors = ['John Snow', 'mr. blobby'] + authors = ["John Snow", "mr. blobby"] - def post(url, data, **kwargs): - pageAlias = json.loads(data).get('pageAlias') - resp = Mock() - resp.json.return_value = { - 'pages': { - pageAlias: {'title': pageAlias} - } - } - return resp + def post(url, data, **kwargs): + pageAlias = json.loads(data).get("pageAlias") + resp = Mock() + resp.json.return_value = {"pages": {pageAlias: {"title": pageAlias}}} + return resp - with patch('requests.post', post): - page = {'changeLogs': [{'userId': author} for author in authors]} - assert sorted(dataset.extract_authors(page)) == sorted(authors) + with patch("requests.post", post): + page = {"changeLogs": [{"userId": author} for author in authors]} + assert sorted(dataset.extract_authors(page)) == sorted(authors) def test_extract_authors_ignore_missing(dataset): - authors = [ - '', None, 'John Snow', None, None, 'mr. blobby', '', '' - ] - page = {'changeLogs': [{'userId': author} for author in authors]} - - with patch.object(dataset, 'get_title', lambda author: author): - assert sorted(dataset.extract_authors(page)) == sorted(['John Snow', 'mr. blobby']) - - -@pytest.mark.parametrize('page, expected', ( - ({'editCreatedAt': '2021-02-01T01:23:45Z'}, parse('2021-02-01T01:23:45Z')), - ({'pageCreatedAt': '2021-02-01T01:23:45Z'}, parse('2021-02-01T01:23:45Z')), - ({ - 'editCreatedAt': '2021-02-01T01:23:45Z', - 'pageCreatedAt': '2024-02-01T01:23:45Z', - }, parse('2021-02-01T01:23:45Z')), - - ({}, None), - ({'bla': 'asdasd'}, None), -)) + authors = ["", None, "John Snow", None, None, "mr. 
blobby", "", ""] + page = {"changeLogs": [{"userId": author} for author in authors]} + + with patch.object(dataset, "get_title", lambda author: author): + assert sorted(dataset.extract_authors(page)) == sorted( + ["John Snow", "mr. blobby"] + ) + + +@pytest.mark.parametrize( + "page, expected", + ( + ({"editCreatedAt": "2021-02-01T01:23:45Z"}, parse("2021-02-01T01:23:45Z")), + ({"pageCreatedAt": "2021-02-01T01:23:45Z"}, parse("2021-02-01T01:23:45Z")), + ( + { + "editCreatedAt": "2021-02-01T01:23:45Z", + "pageCreatedAt": "2024-02-01T01:23:45Z", + }, + parse("2021-02-01T01:23:45Z"), + ), + ({}, None), + ({"bla": "asdasd"}, None), + ), +) def test_get_published_date(dataset, page, expected): - assert dataset._get_published_date(page) == expected + assert dataset._get_published_date(page) == expected def test_process_entry(dataset): - page = { - 'title': 'test article', - 'text': 'bla bla bla', - 'editCreatedAt': '2001-02-03T12:34:45Z', - 'alias': 'blee', - 'tagIds': [], - } - with patch.object(dataset, 'get_page', return_value=page): - assert dataset.process_entry('bla').to_dict() == { - 'alias': 'bla', - 'authors': [], - 'date_published': '2001-02-03T12:34:45Z', - 'id': None, - 'source': 'arbital', - 'source_type': 'text', - 'summaries': [], - 'tags': [], - 'text': 'bla bla bla', - 'title': 'test article', - 'url': 'https://arbital.com/p/blee', - } + page = { + "title": "test article", + "text": "bla bla bla", + "editCreatedAt": "2001-02-03T12:34:45Z", + "alias": "blee", + "tagIds": [], + } + with patch.object(dataset, "get_page", return_value=page): + assert dataset.process_entry("bla").to_dict() == { + "alias": "bla", + "authors": [], + "date_published": "2001-02-03T12:34:45Z", + "id": None, + "source": "arbital", + "source_type": "text", + "summaries": [], + "tags": [], + "text": "bla bla bla", + "title": "test article", + "url": "https://arbital.com/p/blee", + } diff --git a/tests/align_data/test_arxiv.py b/tests/align_data/test_arxiv.py index 30717d9e..5817fd2c 100644 --- a/tests/align_data/test_arxiv.py +++ b/tests/align_data/test_arxiv.py @@ -4,29 +4,32 @@ from align_data.sources.arxiv_papers.arxiv_papers import ArxivPapers -@pytest.mark.parametrize('url, expected', ( - ('https://arxiv.org/abs/2001.11038', '2001.11038'), - ('https://arxiv.org/abs/2001.11038/', '2001.11038'), - ('https://bla.bla/2001.11038/', None), -)) +@pytest.mark.parametrize( + "url, expected", + ( + ("https://arxiv.org/abs/2001.11038", "2001.11038"), + ("https://arxiv.org/abs/2001.11038/", "2001.11038"), + ("https://bla.bla/2001.11038/", None), + ), +) def test_get_id(url, expected): - dataset = ArxivPapers(name='asd', spreadsheet_id='ad', sheet_id='da') - assert dataset.get_id(Mock(url='https://arxiv.org/abs/2001.11038')) == '2001.11038' + dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da") + assert dataset.get_id(Mock(url="https://arxiv.org/abs/2001.11038")) == "2001.11038" def test_process_entry(): - dataset = ArxivPapers(name='asd', spreadsheet_id='ad', sheet_id='da') + dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da") item = Mock( - title='this is the title', - url='https://arxiv.org/abs/2001.11038', - authors='', - date_published='2020-01-29', + title="this is the title", + url="https://arxiv.org/abs/2001.11038", + authors="", + date_published="2020-01-29", ) contents = { - 'text': 'this is the text', - 'date_published': "December 12, 2021", - 'authors': ['mr blobby'], - 'data_source': 'html', + "text": "this is the text", + "date_published": "December 12, 2021", + 
"authors": ["mr blobby"], + "data_source": "html", } metadata = Mock( summary="abstract bla bla", @@ -34,29 +37,31 @@ def test_process_entry(): categories="wut", updated="2023-01-01", authors=[], - doi='123', - journal_ref='sdf', - primary_category='cat', + doi="123", + journal_ref="sdf", + primary_category="cat", ) arxiv = Mock() arxiv.Search.return_value.results.return_value = iter([metadata]) - with patch('align_data.arxiv_papers.arxiv_papers.parse_vanity', return_value=contents): - with patch('align_data.arxiv_papers.arxiv_papers.arxiv', arxiv): + with patch( + "align_data.arxiv_papers.arxiv_papers.parse_vanity", return_value=contents + ): + with patch("align_data.arxiv_papers.arxiv_papers.arxiv", arxiv): assert dataset.process_entry(item).to_dict() == { - 'author_comment': 'no comment', - 'authors': ['mr blobby'], - 'categories': 'wut', - 'data_last_modified': '2023-01-01', - 'date_published': '2020-01-29T00:00:00Z', - 'doi': '123', - 'id': None, - 'journal_ref': 'sdf', - 'primary_category': 'cat', - 'source': 'asd', - 'source_type': 'html', - 'summaries': ['abstract bla bla'], - 'text': 'this is the text', - 'title': 'this is the title', - 'url': 'https://arxiv.org/abs/2001.11038', + "author_comment": "no comment", + "authors": ["mr blobby"], + "categories": "wut", + "data_last_modified": "2023-01-01", + "date_published": "2020-01-29T00:00:00Z", + "doi": "123", + "id": None, + "journal_ref": "sdf", + "primary_category": "cat", + "source": "asd", + "source_type": "html", + "summaries": ["abstract bla bla"], + "text": "this is the text", + "title": "this is the title", + "url": "https://arxiv.org/abs/2001.11038", } diff --git a/tests/align_data/test_blogs.py b/tests/align_data/test_blogs.py index f2dc7d21..d09edd16 100644 --- a/tests/align_data/test_blogs.py +++ b/tests/align_data/test_blogs.py @@ -5,8 +5,15 @@ from dateutil.parser import parse from align_data.sources.blogs import ( - CaradoMoe, ColdTakes, GenerativeInk, GwernBlog, MediumBlog, SubstackBlog, WordpressBlog, - OpenAIResearch, DeepMindTechnicalBlog + CaradoMoe, + ColdTakes, + GenerativeInk, + GwernBlog, + MediumBlog, + SubstackBlog, + WordpressBlog, + OpenAIResearch, + DeepMindTechnicalBlog, ) from align_data.sources.blogs.blogs import EleutherAI @@ -17,11 +24,12 @@ """ + def test_cold_takes_published_date(): dataset = ColdTakes( name="cold_takes", url="https://www.cold-takes.com/", - authors=['Holden Karnofsky'], + authors=["Holden Karnofsky"], ) contents = """ @@ -32,14 +40,14 @@ def test_cold_takes_published_date(): """ soup = BeautifulSoup(contents, "html.parser") - assert dataset._get_published_date(soup) == parse('2001-02-03T00:00:00Z') + assert dataset._get_published_date(soup) == parse("2001-02-03T00:00:00Z") def test_cold_takes_process_entry(): dataset = ColdTakes( name="cold_takes", url="https://www.cold-takes.com/", - authors=['Holden Karnofsky'], + authors=["Holden Karnofsky"], ) item = """ @@ -70,17 +78,17 @@ def test_cold_takes_process_entry(): """ - with patch('requests.get', return_value=Mock(content=article)): + with patch("requests.get", return_value=Mock(content=article)): assert dataset.process_entry(BeautifulSoup(item, "html.parser")).to_dict() == { - 'authors': ['Holden Karnofsky'], - 'date_published': '2023-02-28T00:00:00Z', - 'id': None, - 'source': 'cold_takes', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla', - 'title': 'What does Bing Chat tell us about AI risk?', - 'url': 'https://www.cold-takes.com/how-governments-can-help-with-the-most-important-century/', + "authors": 
["Holden Karnofsky"], + "date_published": "2023-02-28T00:00:00Z", + "id": None, + "source": "cold_takes", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla", + "title": "What does Bing Chat tell us about AI risk?", + "url": "https://www.cold-takes.com/how-governments-can-help-with-the-most-important-century/", } @@ -99,22 +107,23 @@ def test_cold_takes_process_entry(): """ + def test_generative_ink_published_date(): dataset = GenerativeInk( name="generative.ink", url="https://generative.ink/posts/", - authors=['janus'], + authors=["janus"], ) soup = BeautifulSoup(GENERITIVE_INK_HTML, "html.parser") - assert dataset._get_published_date(soup) == parse('2023-02-09T00:00:00Z') + assert dataset._get_published_date(soup) == parse("2023-02-09T00:00:00Z") def test_generative_ink_process_entry(): dataset = GenerativeInk( name="generative.ink", url="https://generative.ink/posts/", - authors=['janus'], + authors=["janus"], ) item = """ @@ -125,25 +134,25 @@ def test_generative_ink_process_entry(): """ - with patch('requests.get', return_value=Mock(content=GENERITIVE_INK_HTML)): + with patch("requests.get", return_value=Mock(content=GENERITIVE_INK_HTML)): assert dataset.process_entry(BeautifulSoup(item, "html.parser")).to_dict() == { - 'authors': ['janus'], - 'date_published': '2023-02-09T00:00:00Z', - 'id': None, - 'source': 'generative.ink', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla', - 'title': 'Anomalous tokens reveal the original identities of Instruct models', - 'url': 'https://generative.ink/posts/simulators/', + "authors": ["janus"], + "date_published": "2023-02-09T00:00:00Z", + "id": None, + "source": "generative.ink", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla", + "title": "Anomalous tokens reveal the original identities of Instruct models", + "url": "https://generative.ink/posts/simulators/", } def test_caradomoe_text(): dataset = CaradoMoe( name="carado.moe", - url='https://carado.moe', - authors=['Tamsin Leake'], + url="https://carado.moe", + authors=["Tamsin Leake"], ) contents = f"""
@@ -152,38 +161,41 @@ def test_caradomoe_text():
""" soup = BeautifulSoup(contents, "html.parser") - assert dataset._get_text({'soup': soup}) == 'bla bla bla [a link](http://ble.com) bla bla' + assert ( + dataset._get_text({"soup": soup}) + == "bla bla bla [a link](http://ble.com) bla bla" + ) def test_caradomoe_process_entry(): dataset = CaradoMoe( name="carado.moe", - url='https://carado.moe', - authors=['Tamsin Leake'], + url="https://carado.moe", + authors=["Tamsin Leake"], ) item = { - 'pubDate': 'Sat, 10 Jun 2023 07:00:00 -0000', - 'title': 'the title', - 'link': 'http://example.com/bla' + "pubDate": "Sat, 10 Jun 2023 07:00:00 -0000", + "title": "the title", + "link": "http://example.com/bla", } - dataset.items = {item['link']: item} + dataset.items = {item["link"]: item} contents = f"""

" {SAMPLE_HTML}
""" - with patch('requests.get', return_value=Mock(content=contents)): - assert dataset.process_entry(item['link']).to_dict() == { - 'authors': ['Tamsin Leake'], - 'date_published': '2023-06-10T07:00:00Z', - 'id': None, - 'source': 'carado.moe', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla [a link](http://ble.com) bla bla', - 'title': 'the title', - 'url': 'http://example.com/bla' + with patch("requests.get", return_value=Mock(content=contents)): + assert dataset.process_entry(item["link"]).to_dict() == { + "authors": ["Tamsin Leake"], + "date_published": "2023-06-10T07:00:00Z", + "id": None, + "source": "carado.moe", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla [a link](http://ble.com) bla bla", + "title": "the title", + "url": "http://example.com/bla", } @@ -214,31 +226,44 @@ def test_caradomoe_process_entry(): """ + + def test_gwern_get_text(): - dataset = GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]) + dataset = GwernBlog( + name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] + ) soup = BeautifulSoup(GWERN_CONTENTS, "html.parser") - assert dataset._get_text(soup) == 'bla bla bla [a link](http://ble.com) bla bla' - + assert dataset._get_text(soup) == "bla bla bla [a link](http://ble.com) bla bla" -@pytest.mark.parametrize('metadata, date', ( - ({'modified': '2022-01-02'}, parse('2022-01-02T00:00:00Z')), - ({'created': '2022-01-02'}, parse('2022-01-02T00:00:00Z')), - ({'created': '2000-01-01', 'modified': '2022-01-02'}, parse('2022-01-02T00:00:00Z')), - ({}, None), - ({'bla': 'asda'}, None) -)) +@pytest.mark.parametrize( + "metadata, date", + ( + ({"modified": "2022-01-02"}, parse("2022-01-02T00:00:00Z")), + ({"created": "2022-01-02"}, parse("2022-01-02T00:00:00Z")), + ( + {"created": "2000-01-01", "modified": "2022-01-02"}, + parse("2022-01-02T00:00:00Z"), + ), + ({}, None), + ({"bla": "asda"}, None), + ), +) def test_gwern_get_published_date(metadata, date): - dataset = GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]) + dataset = GwernBlog( + name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] + ) assert dataset._get_published_date(metadata) == date def test_gwern_get_article(): - dataset = GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]) - with patch('requests.get', return_value='article contents'): - assert dataset._get_article('http://bla.com') == 'article contents' + dataset = GwernBlog( + name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] + ) + with patch("requests.get", return_value="article contents"): + assert dataset._get_article("http://bla.com") == "article contents" def test_gwern_get_metadata(): @@ -256,16 +281,16 @@ def test_gwern_get_metadata(): cssExtension: drop-caps-kanzlei """ assert GwernBlog._get_metadata(text) == { - 'confidence': 'likely', - 'created': '2020-05-28', - 'cssExtension': 'drop-caps-kanzlei', - 'importance': '10', - 'modified': '2022-01-02', - 'next': '/fiction/clippy', - 'previous': '/newsletter/2020/05', - 'status': 'finished', - 'thumbnail': '/doc/ai/nn/transformer/gpt/2020-brown-gpt3-figure13-meanperformancescalingcurve.png', - 'title': '"The Scaling Hypothesis"', + "confidence": "likely", + "created": "2020-05-28", + "cssExtension": "drop-caps-kanzlei", + "importance": "10", + "modified": "2022-01-02", + "next": "/fiction/clippy", + "previous": "/newsletter/2020/05", + "status": "finished", + "thumbnail": 
"/doc/ai/nn/transformer/gpt/2020-brown-gpt3-figure13-meanperformancescalingcurve.png", + "title": '"The Scaling Hypothesis"', } @@ -277,18 +302,22 @@ def test_gwern_process_markdown(): ... {SAMPLE_HTML} """ - dataset = GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]) - - assert dataset._process_markdown('http://article.url', Mock(text=text)).to_dict() == { - 'authors': ['Gwern Branwen'], - 'date_published': '2020-05-28T00:00:00Z', - 'id': None, - 'source': 'gwern_blog', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla [a link](http://ble.com) bla bla', - 'title': '"The Scaling Hypothesis"', - 'url': 'http://article.url', + dataset = GwernBlog( + name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] + ) + + assert dataset._process_markdown( + "http://article.url", Mock(text=text) + ).to_dict() == { + "authors": ["Gwern Branwen"], + "date_published": "2020-05-28T00:00:00Z", + "id": None, + "source": "gwern_blog", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla [a link](http://ble.com) bla bla", + "title": '"The Scaling Hypothesis"', + "url": "http://article.url", } @@ -300,44 +329,59 @@ def test_gwern_process_entry_markdown(): ... {SAMPLE_HTML} """ - dataset = GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]) - - with patch('requests.get', return_value=Mock(text=text, status_code=200, headers={})): - assert dataset.process_entry('http://article.url').to_dict() == { - 'authors': ['Gwern Branwen'], - 'date_published': '2020-05-28T00:00:00Z', - 'id': None, - 'source': 'gwern_blog', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla [a link](http://ble.com) bla bla', - 'title': '"The Scaling Hypothesis"', - 'url': 'http://article.url', + dataset = GwernBlog( + name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] + ) + + with patch( + "requests.get", return_value=Mock(text=text, status_code=200, headers={}) + ): + assert dataset.process_entry("http://article.url").to_dict() == { + "authors": ["Gwern Branwen"], + "date_published": "2020-05-28T00:00:00Z", + "id": None, + "source": "gwern_blog", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla [a link](http://ble.com) bla bla", + "title": '"The Scaling Hypothesis"', + "url": "http://article.url", } def test_gwern_process_entry_html(): - dataset = GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]) - - with patch('requests.get', return_value=Mock(content=GWERN_CONTENTS, status_code=200, headers={'Content-Type': 'text/html'})): - assert dataset.process_entry('http://article.url').to_dict() == { - 'authors': ['Gwern Branwen'], - 'date_published': '2023-01-01T00:00:00Z', - 'id': None, - 'source': 'gwern_blog', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla [a link](http://ble.com) bla bla', - 'title': 'The title of the article', - 'url': 'http://article.url', + dataset = GwernBlog( + name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] + ) + + with patch( + "requests.get", + return_value=Mock( + content=GWERN_CONTENTS, + status_code=200, + headers={"Content-Type": "text/html"}, + ), + ): + assert dataset.process_entry("http://article.url").to_dict() == { + "authors": ["Gwern Branwen"], + "date_published": "2023-01-01T00:00:00Z", + "id": None, + "source": "gwern_blog", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla [a link](http://ble.com) bla bla", + "title": "The title of 
the article", + "url": "http://article.url", } def test_gwern_process_entry_erro(): - dataset = GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]) + dataset = GwernBlog( + name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] + ) - with patch('requests.get', return_value=Mock(status_code=404)): - assert dataset.process_entry('http://article.url') is None + with patch("requests.get", return_value=Mock(status_code=404)): + assert dataset.process_entry("http://article.url") is None MEDIUM_HTML = f""" @@ -354,23 +398,31 @@ def test_gwern_process_entry_erro(): {SAMPLE_HTML} """ + + def test_medium_get_published_date(): - dataset = MediumBlog(name="deepmind_blog", url="https://bla.medium.com/", authors=["mr Blobby"]) + dataset = MediumBlog( + name="deepmind_blog", url="https://bla.medium.com/", authors=["mr Blobby"] + ) soup = BeautifulSoup(MEDIUM_HTML, "html.parser") - assert dataset._get_published_date(soup) == parse('2023-10-07T00:00:00Z') + assert dataset._get_published_date(soup) == parse("2023-10-07T00:00:00Z") def test_medium_get_text(): - dataset = MediumBlog(name="deepmind_blog", url="https://bla.medium.com/", authors=["mr Blobby"]) + dataset = MediumBlog( + name="deepmind_blog", url="https://bla.medium.com/", authors=["mr Blobby"] + ) soup = BeautifulSoup(MEDIUM_HTML, "html.parser") - soup.find('h1').extract() - assert dataset._get_text(soup) == 'bla bla bla [a link](http://ble.com) bla bla' + soup.find("h1").extract() + assert dataset._get_text(soup) == "bla bla bla [a link](http://ble.com) bla bla" def test_medium_process_entry(): - dataset = MediumBlog(name="deepmind_blog", url="https://bla.medium.com/", authors=["mr Blobby"]) + dataset = MediumBlog( + name="deepmind_blog", url="https://bla.medium.com/", authors=["mr Blobby"] + ) item = """
@@ -378,50 +430,51 @@ def test_medium_process_entry():

Discovering when an agent is present in a system

""" - with patch('requests.get', return_value=Mock(content=MEDIUM_HTML)): + with patch("requests.get", return_value=Mock(content=MEDIUM_HTML)): assert dataset.process_entry(BeautifulSoup(item, "html.parser")).to_dict() == { - 'authors': ['mr Blobby'], - 'date_published': '2023-10-07T00:00:00Z', - 'id': None, - 'source': 'deepmind_blog', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla [a link](http://ble.com) bla bla', - 'title': 'This is the title', - 'url': 'https://bla.medium.com/discovering-when-an-agent-is-present-in-a-system-41154de11e7b', + "authors": ["mr Blobby"], + "date_published": "2023-10-07T00:00:00Z", + "id": None, + "source": "deepmind_blog", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla [a link](http://ble.com) bla bla", + "title": "This is the title", + "url": "https://bla.medium.com/discovering-when-an-agent-is-present-in-a-system-41154de11e7b", } def test_substack_blog_process_entry(): dataset = SubstackBlog(name="blog", url="https://blog.substack.com") contents = { - 'entries': [ + "entries": [ { - 'link': 'http://example.org/bla', - 'title': 'the article title', - 'pubDate': 'Mon, 26 Jun 2023 13:40:01 GMT', - 'description': 'the articles description', - 'content': [{'value': SAMPLE_HTML}], - 'authors': [{'name': 'mr Blobby'}], + "link": "http://example.org/bla", + "title": "the article title", + "pubDate": "Mon, 26 Jun 2023 13:40:01 GMT", + "description": "the articles description", + "content": [{"value": SAMPLE_HTML}], + "authors": [{"name": "mr Blobby"}], } ] } # Setup the items list contents - with patch('feedparser.parse', return_value=contents): + with patch("feedparser.parse", return_value=contents): dataset.items_list - assert dataset.process_entry('http://example.org/bla').to_dict() == { - 'authors': ['mr Blobby'], - 'date_published': '2023-06-26T13:40:01Z', - 'id': None, - 'source': 'blog', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla [a link](http://ble.com) bla bla', - 'title': 'the article title', - 'url': 'http://example.org/bla', + assert dataset.process_entry("http://example.org/bla").to_dict() == { + "authors": ["mr Blobby"], + "date_published": "2023-06-26T13:40:01Z", + "id": None, + "source": "blog", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla [a link](http://ble.com) bla bla", + "title": "the article title", + "url": "http://example.org/bla", } + WORDPRESS_FEED = { "entries": [ { @@ -438,63 +491,68 @@ def test_substack_blog_process_entry(): "link": "https://www.yudkowsky.net", }, "headers": { - "link": "; rel=\"https://api.w.org/\"" - } + "link": '; rel="https://api.w.org/"' + }, } def test_wordpress_blog_setup(): blog = WordpressBlog( - name='blog_name', + name="blog_name", url="https://www.bla.yudkowsky.net", ) - assert blog.feed_url == 'https://www.bla.yudkowsky.net/feed' + assert blog.feed_url == "https://www.bla.yudkowsky.net/feed" assert blog.name == "blog_name" - -@patch('feedparser.parse', return_value=WORDPRESS_FEED) +@patch("feedparser.parse", return_value=WORDPRESS_FEED) def test_wordpress_blog_items_list(feedparser_parse): - blog = WordpressBlog(name='blog', url="https://www.bla.yudkowsky.net") - assert blog.items_list == ['https://www.yudkowsky.net/other/fiction/prospiracy-theory'] + blog = WordpressBlog(name="blog", url="https://www.bla.yudkowsky.net") + assert blog.items_list == [ + "https://www.yudkowsky.net/other/fiction/prospiracy-theory" + ] def test_wordpress_blog_get_item_key(): blog = WordpressBlog( - name='blog', + name="blog", 
url="https://www.bla.yudkowsky.net", ) - item = {'title': 'Test Entry'} - assert item == blog.get_item_key(item) + item = {"title": "Test Entry"} + assert item == blog.get_item_key(item) def test_wordpress_blog_get_published_date(): blog = WordpressBlog( - name='blog', + name="blog", url="https://www.bla.yudkowsky.net", ) - date_published = blog._get_published_date({'published': "Mon, 26 Jun 2023 13:40:01 +0000"}) - assert date_published == parse('2023-06-26T13:40:01Z') + date_published = blog._get_published_date( + {"published": "Mon, 26 Jun 2023 13:40:01 +0000"} + ) + assert date_published == parse("2023-06-26T13:40:01Z") -@patch('feedparser.parse', return_value=WORDPRESS_FEED) +@patch("feedparser.parse", return_value=WORDPRESS_FEED) def test_wordpress_blog_process_entry(feedparser_parse): blog = WordpressBlog( - name='blog_name', + name="blog_name", url="https://www.bla.yudkowsky.net", ) - blog.items = {i['link']: i for i in WORDPRESS_FEED['entries']} - entry = blog.process_entry('https://www.yudkowsky.net/other/fiction/prospiracy-theory') + blog.items = {i["link"]: i for i in WORDPRESS_FEED["entries"]} + entry = blog.process_entry( + "https://www.yudkowsky.net/other/fiction/prospiracy-theory" + ) assert entry.to_dict() == { - 'authors': ['Eliezer S. Yudkowsky'], - 'date_published': '2020-09-04T04:11:23Z', - 'id': None, - 'source': 'blog_name', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla [a link](http://ble.com) bla bla', - 'title': 'Prospiracy Theory', - 'url': 'https://www.yudkowsky.net/other/fiction/prospiracy-theory', + "authors": ["Eliezer S. Yudkowsky"], + "date_published": "2020-09-04T04:11:23Z", + "id": None, + "source": "blog_name", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla [a link](http://ble.com) bla bla", + "title": "Prospiracy Theory", + "url": "https://www.yudkowsky.net/other/fiction/prospiracy-theory", } @@ -512,35 +570,48 @@ def test_wordpress_blog_process_entry(feedparser_parse): """ + def test_eleutherai_get_published_date(): - dataset = EleutherAI(name='eleuther', url='http://bla.bla') + dataset = EleutherAI(name="eleuther", url="http://bla.bla") soup = BeautifulSoup(ELEUTHER_HTML, "html.parser") assert dataset._get_published_date(soup) == parse("2023-07-08T00:00:00Z") def test_eleutherai_extract_authors(): - dataset = EleutherAI(name='eleuther', url='http://bla.bla') + dataset = EleutherAI(name="eleuther", url="http://bla.bla") soup = BeautifulSoup(ELEUTHER_HTML, "html.parser") - assert dataset.extract_authors(soup) == ['Curtis Huebner', 'Robert Klassert', 'Stepan Shabalin', 'Edwin Fennell', 'Delta Hessler'] + assert dataset.extract_authors(soup) == [ + "Curtis Huebner", + "Robert Klassert", + "Stepan Shabalin", + "Edwin Fennell", + "Delta Hessler", + ] def test_eleutherai_process_entry(): - dataset = EleutherAI(name='eleuther', url='http://bla.bla') + dataset = EleutherAI(name="eleuther", url="http://bla.bla") article = BeautifulSoup('', "html.parser") - with patch('requests.get', return_value=Mock(content=ELEUTHER_HTML)): + with patch("requests.get", return_value=Mock(content=ELEUTHER_HTML)): assert dataset.process_entry(article).to_dict() == { - 'authors': ['Curtis Huebner', 'Robert Klassert', 'Stepan Shabalin', 'Edwin Fennell', 'Delta Hessler'], - 'date_published': '2023-07-08T00:00:00Z', - 'id': None, - 'source': 'eleuther', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla', - 'title': 'Minetester: A fully open RL environment built on Minetest', - 'url': 'http://bla.bla/bla.bla', + "authors": [ + 
"Curtis Huebner", + "Robert Klassert", + "Stepan Shabalin", + "Edwin Fennell", + "Delta Hessler", + ], + "date_published": "2023-07-08T00:00:00Z", + "id": None, + "source": "eleuther", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla", + "title": "Minetester: A fully open RL environment built on Minetest", + "url": "http://bla.bla/bla.bla", } @@ -558,95 +629,109 @@ def test_eleutherai_process_entry(): """ + + def test_openai_research_get_published_date(): - dataset = OpenAIResearch(name='openai', url='bla.bla') + dataset = OpenAIResearch(name="openai", url="bla.bla") soup = BeautifulSoup(OPENAI_HTML, "html.parser") - assert dataset._get_published_date(soup) == parse('2023-07-06T00:00:00Z') + assert dataset._get_published_date(soup) == parse("2023-07-06T00:00:00Z") def test_openai_research_get_text(): - dataset = OpenAIResearch(name='openai', url='bla.bla') + dataset = OpenAIResearch(name="openai", url="bla.bla") soup = BeautifulSoup(OPENAI_HTML, "html.parser") - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'})): - with patch('align_data.articles.pdf.fetch_pdf', return_value={'text': 'bla bla bla'}): - assert dataset._get_text(soup) == 'bla bla bla' + with patch( + "requests.head", return_value=Mock(headers={"Content-Type": "text/html"}) + ): + with patch( + "align_data.articles.pdf.fetch_pdf", return_value={"text": "bla bla bla"} + ): + assert dataset._get_text(soup) == "bla bla bla" -@pytest.mark.parametrize('html, expected', ( +@pytest.mark.parametrize( + "html, expected", ( - """
+ ( + """
Authors

Mr. Blobby
John Snow (Westeros)

""", - ["Mr. Blobby", "John Snow"] - ), - ( - """
+ ["Mr. Blobby", "John Snow"], + ), + ( + """
Acknowledgments

Mr. Blobby
John Snow (Westeros)

""", - ["Mr. Blobby", "John Snow"] - ), - ( - """
+ ["Mr. Blobby", "John Snow"], + ), + ( + """
Bla Bla Bla

Mr. Blobby
John Snow (Westeros)

""", - [] + [], + ), ), -)) +) def test_openai_research_extract_authors(html, expected): - dataset = OpenAIResearch(name='openai', url='bla.bla') + dataset = OpenAIResearch(name="openai", url="bla.bla") soup = BeautifulSoup(html, "html.parser") assert dataset.extract_authors(soup) == expected def test_openai_research_process_entry(): - dataset = OpenAIResearch(name='openai', url='bla.bla') + dataset = OpenAIResearch(name="openai", url="bla.bla") soup = BeautifulSoup(OPENAI_HTML, "html.parser") - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'})): - with patch('requests.get', return_value=Mock(content=OPENAI_HTML)): - with patch('align_data.articles.pdf.fetch_pdf', return_value={'text': 'bla bla bla'}): + with patch( + "requests.head", return_value=Mock(headers={"Content-Type": "text/html"}) + ): + with patch("requests.get", return_value=Mock(content=OPENAI_HTML)): + with patch( + "align_data.articles.pdf.fetch_pdf", + return_value={"text": "bla bla bla"}, + ): assert dataset.process_entry(soup).to_dict() == { - 'authors': ['Mr. Blobby', 'John Snow'], - 'date_published': '2023-07-06T00:00:00Z', - 'id': None, - 'source': 'openai', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla', - 'title': None, - 'url': 'https://arxiv.org', + "authors": ["Mr. Blobby", "John Snow"], + "date_published": "2023-07-06T00:00:00Z", + "id": None, + "source": "openai", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla", + "title": None, + "url": "https://arxiv.org", } def test_deepmind_technical_items_list(): - dataset = DeepMindTechnicalBlog(name='bla', url='http://bla.com') + dataset = DeepMindTechnicalBlog(name="bla", url="http://bla.com") def getter(url, *args, **params): - page = params.get('params')['73df3071_page'] + page = params.get("params")["73df3071_page"] if page < 3: - html = ''.join( + html = "".join( f'
{i}
' for i in range(page * 10 - 10, page * 10) ) - return Mock(content=f'
{html}
') - return Mock(content='') + return Mock(content=f"
{html}
") + return Mock(content="") - with patch('requests.get', getter): + with patch("requests.get", getter): assert [str(i) for i in dataset.items_list] == [ f'
{i}
' for i in range(0, 20) ] @@ -670,30 +755,32 @@ def getter(url, *args, **params):
""" + + def test_deepmind_technical_get_published_date(): - dataset = DeepMindTechnicalBlog(name='bla', url='http://bla.com') + dataset = DeepMindTechnicalBlog(name="bla", url="http://bla.com") soup = BeautifulSoup(DEEPMIND_HTML, "html.parser") - assert dataset._get_published_date(soup) == parse('2023-07-11T00:00:00Z') + assert dataset._get_published_date(soup) == parse("2023-07-11T00:00:00Z") def test_deepmind_technical_extract_authors(): - dataset = DeepMindTechnicalBlog(name='bla', url='http://bla.com') + dataset = DeepMindTechnicalBlog(name="bla", url="http://bla.com") soup = BeautifulSoup(DEEPMIND_HTML, "html.parser") - assert dataset.extract_authors(soup) == ['Mr. Blobby', 'John Snow'] + assert dataset.extract_authors(soup) == ["Mr. Blobby", "John Snow"] def test_deepmind_technical_proces_entry(): - dataset = DeepMindTechnicalBlog(name='bla', url='http://bla.com') + dataset = DeepMindTechnicalBlog(name="bla", url="http://bla.com") soup = BeautifulSoup('
', "html.parser") - with patch('requests.get', return_value=Mock(content=DEEPMIND_HTML)): + with patch("requests.get", return_value=Mock(content=DEEPMIND_HTML)): assert dataset.process_entry(soup).to_dict() == { - 'authors': ['Mr. Blobby', 'John Snow'], - 'date_published': '2023-07-11T00:00:00Z', - 'id': None, - 'source': 'bla', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla bla', - 'title': 'title!', - 'url': 'http://bla.bl', + "authors": ["Mr. Blobby", "John Snow"], + "date_published": "2023-07-11T00:00:00Z", + "id": None, + "source": "bla", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla", + "title": "title!", + "url": "http://bla.bl", } diff --git a/tests/align_data/test_distill.py b/tests/align_data/test_distill.py index b94b5bda..02a45d98 100644 --- a/tests/align_data/test_distill.py +++ b/tests/align_data/test_distill.py @@ -7,7 +7,7 @@ def test_extract_authors(): - dataset = Distill(name='distill', url='bla.bla') + dataset = Distill(name="distill", url="bla.bla") contents = """
@@ -23,22 +23,29 @@ def test_extract_authors():
""" soup = BeautifulSoup(contents, "html.parser") - assert dataset.extract_authors({'soup': soup}) == ['Ameya Daigavane', 'Balaraman Ravindran', 'Gaurav Aggarwal'] - - -@pytest.mark.parametrize('text', ( - ' bla bla a link ble \n\n', - ' bla bla a link ble \n\n', -)) + assert dataset.extract_authors({"soup": soup}) == [ + "Ameya Daigavane", + "Balaraman Ravindran", + "Gaurav Aggarwal", + ] + + +@pytest.mark.parametrize( + "text", + ( + ' bla bla a link ble \n\n', + ' bla bla a link ble \n\n', + ), +) def test_get_text(text): - dataset = Distill(name='distill', url='bla.bla') + dataset = Distill(name="distill", url="bla.bla") soup = BeautifulSoup(text, "html.parser") - assert dataset._get_text({'soup': soup}) == "bla bla [a link](bla.com) ble" + assert dataset._get_text({"soup": soup}) == "bla bla [a link](bla.com) ble" def test_extra_values(): - dataset = Distill(name='distill', url='bla.bla') + dataset = Distill(name="distill", url="bla.bla") contents = """
@@ -62,24 +69,27 @@ def test_extra_values(): """ soup = BeautifulSoup(contents, "html.parser") - assert dataset._extra_values({'soup': soup, 'summary': 'A wild summary has appeared!'}) == { - 'bibliography': [ + assert dataset._extra_values( + {"soup": soup, "summary": "A wild summary has appeared!"} + ) == { + "bibliography": [ { - 'link': 'https://doi.org/10.23915/distill.00033', - 'title': 'A Gentle Introduction to Graph Neural Networks' - }, { - 'link': 'http://jmlr.org/papers/v11/vishwanathan10a.html', - 'title': 'Graph Kernels' - } + "link": "https://doi.org/10.23915/distill.00033", + "title": "A Gentle Introduction to Graph Neural Networks", + }, + { + "link": "http://jmlr.org/papers/v11/vishwanathan10a.html", + "title": "Graph Kernels", + }, ], - 'doi': '10.23915/distill.00032', - 'journal_ref': 'distill-pub', - 'summary': 'A wild summary has appeared!', + "doi": "10.23915/distill.00032", + "journal_ref": "distill-pub", + "summary": "A wild summary has appeared!", } def test_process_entry(): - dataset = Distill(name='distill', url='bla.bla') + dataset = Distill(name="distill", url="bla.bla") contents = """
@@ -114,39 +124,40 @@ def test_process_entry(): """ items = { - 'entries': [ + "entries": [ { - 'link': 'http://example.org/bla', - 'title': 'the article title', - 'pubDate': 'Mon, 26 Jun 2023 13:40:01 GMT', - 'summary': 'A wild summary has appeared!', + "link": "http://example.org/bla", + "title": "the article title", + "pubDate": "Mon, 26 Jun 2023 13:40:01 GMT", + "summary": "A wild summary has appeared!", } ] } # Setup the items list contents - with patch('feedparser.parse', return_value=items): + with patch("feedparser.parse", return_value=items): dataset.items_list - with patch('requests.get', return_value=Mock(content=contents)): - assert dataset.process_entry('http://example.org/bla').to_dict() == { - 'authors': ['Ameya Daigavane', 'Balaraman Ravindran', 'Gaurav Aggarwal'], - 'bibliography': [ + with patch("requests.get", return_value=Mock(content=contents)): + assert dataset.process_entry("http://example.org/bla").to_dict() == { + "authors": ["Ameya Daigavane", "Balaraman Ravindran", "Gaurav Aggarwal"], + "bibliography": [ + { + "link": "https://doi.org/10.23915/distill.00033", + "title": "A Gentle Introduction to Graph Neural Networks", + }, { - 'link': 'https://doi.org/10.23915/distill.00033', - 'title': 'A Gentle Introduction to Graph Neural Networks' - }, { - 'link': 'http://jmlr.org/papers/v11/vishwanathan10a.html', - 'title': 'Graph Kernels' - } + "link": "http://jmlr.org/papers/v11/vishwanathan10a.html", + "title": "Graph Kernels", + }, ], - 'date_published': '2023-06-26T13:40:01Z', - 'doi': '10.23915/distill.00032', - 'id': None, - 'journal_ref': 'distill-pub', - 'source': 'distill', - 'source_type': 'blog', - 'summaries': ['A wild summary has appeared!'], - 'text': 'bla bla [a link](bla.com) ble', - 'title': 'the article title', - 'url': 'http://example.org/bla', + "date_published": "2023-06-26T13:40:01Z", + "doi": "10.23915/distill.00032", + "id": None, + "journal_ref": "distill-pub", + "source": "distill", + "source_type": "blog", + "summaries": ["A wild summary has appeared!"], + "text": "bla bla [a link](bla.com) ble", + "title": "the article title", + "url": "http://example.org/bla", } diff --git a/tests/align_data/test_greater_wrong.py b/tests/align_data/test_greater_wrong.py index e9f2dc85..2f89bac0 100644 --- a/tests/align_data/test_greater_wrong.py +++ b/tests/align_data/test_greater_wrong.py @@ -6,7 +6,9 @@ import pytest from align_data.sources.greaterwrong.greaterwrong import ( - fetch_LW_tags, fetch_ea_forum_topics, GreaterWrong + fetch_LW_tags, + fetch_ea_forum_topics, + GreaterWrong, ) @@ -21,8 +23,8 @@ def test_fetch_LW_tags():
""" - with patch('requests.get', return_value=Mock(content=contents)): - assert fetch_LW_tags('http://url.com') == {'tag3', 'tag2', 'tag1'} + with patch("requests.get", return_value=Mock(content=contents)): + assert fetch_LW_tags("http://url.com") == {"tag3", "tag2", "tag1"} def test_fetch_ea_forum_topics(): @@ -34,43 +36,57 @@ def test_fetch_ea_forum_topics(): ignored
""" - with patch('requests.get', return_value=Mock(content=contents)): - assert fetch_ea_forum_topics('http://url.com') == {'tag3', 'tag2', 'tag1'} + with patch("requests.get", return_value=Mock(content=contents)): + assert fetch_ea_forum_topics("http://url.com") == {"tag3", "tag2", "tag1"} @pytest.fixture def dataset(tmp_path): - return GreaterWrong(name='bla', base_url='http://example.com', start_year=2013, min_karma=0, af=False) - - -@pytest.mark.parametrize('tags', ( - [{'name': 'tag1'}], - [{'name': 'tag1'}, {'name': 'other tag'}], - [{'name': 'tag1'}, {'name': 'tag2'}], - [{'name': 'tag2'}, {'name': 'bla'}], -)) + return GreaterWrong( + name="bla", + base_url="http://example.com", + start_year=2013, + min_karma=0, + af=False, + ) + + +@pytest.mark.parametrize( + "tags", + ( + [{"name": "tag1"}], + [{"name": "tag1"}, {"name": "other tag"}], + [{"name": "tag1"}, {"name": "tag2"}], + [{"name": "tag2"}, {"name": "bla"}], + ), +) def test_greaterwrong_tags_ok(dataset, tags): - dataset.ai_tags = {'tag1', 'tag2'} - assert dataset.tags_ok({'tags': tags}) - - -@pytest.mark.parametrize('tags', ( - [], - [{'title': 'tag1'}], - [{'name': 'tag3'}, {'name': 'tag5'}], - [{'name': 'bla'}], -)) + dataset.ai_tags = {"tag1", "tag2"} + assert dataset.tags_ok({"tags": tags}) + + +@pytest.mark.parametrize( + "tags", + ( + [], + [{"title": "tag1"}], + [{"name": "tag3"}, {"name": "tag5"}], + [{"name": "bla"}], + ), +) def test_greaterwrong_tags_ok_missing(dataset, tags): - dataset.ai_tags = {'tag1', 'tag2'} - assert not dataset.tags_ok({'tags': tags}) + dataset.ai_tags = {"tag1", "tag2"} + assert not dataset.tags_ok({"tags": tags}) def test_greaterwrong_get_item_key(dataset): - assert dataset.get_item_key({'pageUrl': 'item key'}) == 'item key' + assert dataset.get_item_key({"pageUrl": "item key"}) == "item key" def test_greaterwrong_get_published_date(dataset): - assert dataset._get_published_date({'postedAt': '2021/02/01'}) == parse('2021-02-01T00:00:00Z') + assert dataset._get_published_date({"postedAt": "2021/02/01"}) == parse( + "2021-02-01T00:00:00Z" + ) def test_greaterwrong_get_published_date_missing(dataset): @@ -78,12 +94,13 @@ def test_greaterwrong_get_published_date_missing(dataset): def test_items_list_no_previous(dataset): - dataset.ai_tags = {'tag1', 'tag2'} + dataset.ai_tags = {"tag1", "tag2"} def make_item(date): return { - 'htmlBody': f'body {date.isoformat()}', 'tags': [{'name': 'tag1'}], - 'postedAt': date.isoformat() + "htmlBody": f"body {date.isoformat()}", + "tags": [{"name": "tag1"}], + "postedAt": date.isoformat(), } # Pretend that a new post drops every month @@ -91,30 +108,34 @@ def fetcher(next_date): results = [] date = parse(next_date).replace(tzinfo=pytz.UTC) - if date < parse('2015-01-01 00:00:00+00:00'): + if date < parse("2015-01-01 00:00:00+00:00"): # Pretend that graphql returns 3 items at once results = [ make_item(date + timedelta(days=30)), make_item(date + timedelta(days=60)), make_item(date + timedelta(days=90)), ] - return {'results': results} + return {"results": results} - with patch.object(dataset, 'fetch_posts', fetcher): - with patch.object(dataset, 'make_query', lambda next_date: next_date): + with patch.object(dataset, "fetch_posts", fetcher): + with patch.object(dataset, "make_query", lambda next_date: next_date): assert list(dataset.items_list) == [ - make_item(datetime(dataset.start_year, 1, 1).replace(tzinfo=pytz.UTC) + timedelta(days=i*30)) + make_item( + datetime(dataset.start_year, 1, 1).replace(tzinfo=pytz.UTC) + + timedelta(days=i * 30) + ) for i in 
range(1, 28) ] def test_items_list_with_previous_items(dataset): - dataset.ai_tags = {'tag1', 'tag2'} + dataset.ai_tags = {"tag1", "tag2"} def make_item(date): return { - 'htmlBody': f'body {date.isoformat()}', 'tags': [{'name': 'tag1'}], - 'postedAt': date.isoformat() + "htmlBody": f"body {date.isoformat()}", + "tags": [{"name": "tag1"}], + "postedAt": date.isoformat(), } # Pretend that a new post drops every month @@ -122,55 +143,60 @@ def fetcher(next_date): results = [] date = parse(next_date).replace(tzinfo=pytz.UTC) - if date < parse('2015-01-01 00:00:00+00:00'): + if date < parse("2015-01-01 00:00:00+00:00"): # Pretend that graphql returns 3 items at once results = [ make_item(date + timedelta(days=30)), make_item(date + timedelta(days=60)), make_item(date + timedelta(days=90)), ] - return {'results': results} - - mock_items = (i for i in [Mock(date_published=datetime.fromisoformat('2014-12-12T01:23:45'))]) - with patch.object(dataset, 'fetch_posts', fetcher): - with patch.object(dataset, 'make_query', lambda next_date: next_date): - with patch.object(dataset, 'read_entries', return_value=mock_items): + return {"results": results} + + mock_items = ( + i for i in [Mock(date_published=datetime.fromisoformat("2014-12-12T01:23:45"))] + ) + with patch.object(dataset, "fetch_posts", fetcher): + with patch.object(dataset, "make_query", lambda next_date: next_date): + with patch.object(dataset, "read_entries", return_value=mock_items): # All items that are older than the newest item in the jsonl file are ignored assert list(dataset.items_list) == [ - make_item(datetime(2014, 12, 12, 1, 23, 45).replace(tzinfo=pytz.UTC) + timedelta(days=i*30)) + make_item( + datetime(2014, 12, 12, 1, 23, 45).replace(tzinfo=pytz.UTC) + + timedelta(days=i * 30) + ) for i in range(1, 4) - ] + ] def test_process_entry(dataset): entry = { - 'coauthors': [{'displayName': 'John Snow'}, {'displayName': 'Mr Blobby'}], - 'user': {'displayName': 'Me'}, - 'title': 'The title', - 'pageUrl': 'http://example.com/bla', - 'modifiedAt': '2001-02-10', - 'postedAt': '2012/02/01 12:23:34', - 'htmlBody': '\n\n bla bla a link ', - 'voteCount': 12, - 'baseScore': 32, - 'tags': [{'name': 'tag1'}, {'name': 'tag2'}], - 'wordCount': 123, - 'commentCount': 423, + "coauthors": [{"displayName": "John Snow"}, {"displayName": "Mr Blobby"}], + "user": {"displayName": "Me"}, + "title": "The title", + "pageUrl": "http://example.com/bla", + "modifiedAt": "2001-02-10", + "postedAt": "2012/02/01 12:23:34", + "htmlBody": '\n\n bla bla a link ', + "voteCount": 12, + "baseScore": 32, + "tags": [{"name": "tag1"}, {"name": "tag2"}], + "wordCount": 123, + "commentCount": 423, } assert dataset.process_entry(entry).to_dict() == { - 'authors': ['Me', 'John Snow', 'Mr Blobby'], - 'comment_count': 423, - 'date_published': '2012-02-01T12:23:34Z', - 'id': None, - 'karma': 32, - 'modified_at': '2001-02-10', - 'source': 'bla', - 'source_type': 'GreaterWrong', - 'summaries': [], - 'tags': ['tag1', 'tag2'], - 'text': 'bla bla [a link](bla.com)', - 'title': 'The title', - 'url': 'http://example.com/bla', - 'votes': 12, - 'words': 123, + "authors": ["Me", "John Snow", "Mr Blobby"], + "comment_count": 423, + "date_published": "2012-02-01T12:23:34Z", + "id": None, + "karma": 32, + "modified_at": "2001-02-10", + "source": "bla", + "source_type": "GreaterWrong", + "summaries": [], + "tags": ["tag1", "tag2"], + "text": "bla bla [a link](bla.com)", + "title": "The title", + "url": "http://example.com/bla", + "votes": 12, + "words": 123, } diff --git 
a/tests/align_data/test_stampy.py b/tests/align_data/test_stampy.py index c3694086..5d4500b5 100644 --- a/tests/align_data/test_stampy.py +++ b/tests/align_data/test_stampy.py @@ -5,44 +5,49 @@ def test_validate_coda_token(): - dataset = Stampy(name='bla') - with patch('align_data.stampy.stampy.CODA_TOKEN', None): - with patch('sys.exit') as mock: + dataset = Stampy(name="bla") + with patch("align_data.stampy.stampy.CODA_TOKEN", None): + with patch("sys.exit") as mock: dataset.setup() assert mock.called_once_with(1) def test_get_item_key(): - dataset = Stampy(name='bla') - assert dataset.get_item_key({'Question': 'Why not just?'}) == 'Why\nnot just?' + dataset = Stampy(name="bla") + assert ( + dataset.get_item_key({"Question": "Why not just?"}) + == "Why\nnot just?" + ) def test_get_published_date(): - dataset = Stampy(name='bla') - assert dataset._get_published_date({'Doc Last Edited': '2012/01/03 12:23:32'}) == parse('2012-01-03T12:23:32Z') + dataset = Stampy(name="bla") + assert dataset._get_published_date( + {"Doc Last Edited": "2012/01/03 12:23:32"} + ) == parse("2012-01-03T12:23:32Z") def test_get_published_date_missing(): - dataset = Stampy(name='bla') - assert dataset._get_published_date({'Doc Last Edited': ''}) == None + dataset = Stampy(name="bla") + assert dataset._get_published_date({"Doc Last Edited": ""}) == None def test_process_entry(): - dataset = Stampy(name='bla') + dataset = Stampy(name="bla") entry = { - 'Question': 'Why not just?', - 'Rich Text': 'bla bla bla', - 'UI ID': '1234', - 'Doc Last Edited': '2012-02-03', + "Question": "Why not just?", + "Rich Text": "bla bla bla", + "UI ID": "1234", + "Doc Last Edited": "2012-02-03", } assert dataset.process_entry(entry).to_dict() == { - 'authors': ['Stampy aisafety.info'], - 'date_published': '2012-02-03T00:00:00Z', - 'id': None, - 'source': 'bla', - 'source_type': 'markdown', - 'summaries': [], - 'text': 'bla bla bla', - 'title': 'Why\nnot just?', - 'url': 'https://aisafety.info?state=1234', + "authors": ["Stampy aisafety.info"], + "date_published": "2012-02-03T00:00:00Z", + "id": None, + "source": "bla", + "source_type": "markdown", + "summaries": [], + "text": "bla bla bla", + "title": "Why\nnot just?", + "url": "https://aisafety.info?state=1234", } diff --git a/tests/align_data/test_youtube.py b/tests/align_data/test_youtube.py index 093b0f2c..bcb720e8 100644 --- a/tests/align_data/test_youtube.py +++ b/tests/align_data/test_youtube.py @@ -1,84 +1,83 @@ from datetime import datetime from unittest.mock import patch, Mock import pytest -from align_data.sources.youtube.youtube import YouTubeDataset, YouTubeChannelDataset, YouTubePlaylistDataset -from youtube_transcript_api._errors import NoTranscriptFound, VideoUnavailable, TranscriptsDisabled +from align_data.sources.youtube.youtube import ( + YouTubeDataset, + YouTubeChannelDataset, + YouTubePlaylistDataset, +) +from youtube_transcript_api._errors import ( + NoTranscriptFound, + VideoUnavailable, + TranscriptsDisabled, +) @pytest.fixture def transcriber(): transcriber = Mock() transcriber.list_transcripts.return_value.find_transcript.return_value.fetch.return_value = [ - {'text': 'bla bla'}, - {'text': 'second line'}, - {'text': 'ble ble'}, + {"text": "bla bla"}, + {"text": "second line"}, + {"text": "ble ble"}, ] - with patch('align_data.sources.youtube.youtube.YouTubeTranscriptApi', transcriber): + with patch("align_data.sources.youtube.youtube.YouTubeTranscriptApi", transcriber): yield def test_key_required(): - dataset = YouTubeDataset(name='asd') + dataset = 
YouTubeDataset(name="asd") with pytest.raises(ValueError, match="No YOUTUBE_API_KEY provided"): dataset.setup() def test_next_page_empty_by_default(): - dataset = YouTubeDataset(name='asd') - assert not dataset.next_page('collection id', 'token')['items'] - - -@pytest.mark.parametrize('item', ( - { - 'kind': 'youtube#searchResult', - 'id': { - 'kind': 'youtube#video', - 'videoId': 'your_video_id' - } - }, - { - 'kind': 'youtube#playlistItem', - 'snippet': { - 'resourceId': { - 'kind': 'youtube#video', - 'videoId': 'your_video_id' - } - } - } -)) + dataset = YouTubeDataset(name="asd") + assert not dataset.next_page("collection id", "token")["items"] + + +@pytest.mark.parametrize( + "item", + ( + { + "kind": "youtube#searchResult", + "id": {"kind": "youtube#video", "videoId": "your_video_id"}, + }, + { + "kind": "youtube#playlistItem", + "snippet": { + "resourceId": {"kind": "youtube#video", "videoId": "your_video_id"} + }, + }, + ), +) def test_get_id_with_id(item): dataset = YouTubeDataset(name="bla") result = dataset._get_id(item) - assert result == 'your_video_id' - - -@pytest.mark.parametrize('item', ( - {'bla': 'bla'}, - { - 'kind': 'invalid_kind', - 'id': { - 'kind': 'youtube#video', - 'videoId': 'your_video_id' - } - }, - { - 'kind': 'youtube#searchResult', - 'id': { - 'kind': 'bla bla bla', - 'videoId': 'your_video_id' - } - }, - { - 'kind': 'youtube#playlistItem', - 'snippet': { - 'resourceId': { - 'kind': 'invalid_kind', - 'videoId': 'your_video_id' - } - } - } -)) + assert result == "your_video_id" + + +@pytest.mark.parametrize( + "item", + ( + {"bla": "bla"}, + { + "kind": "invalid_kind", + "id": {"kind": "youtube#video", "videoId": "your_video_id"}, + }, + { + "kind": "youtube#searchResult", + "id": {"kind": "bla bla bla", "videoId": "your_video_id"}, + }, + { + "kind": "youtube#playlistItem", + "snippet": { + "resourceId": {"kind": "invalid_kind", "videoId": "your_video_id"} + }, + }, + ), +) def test_get_id_with_invalid_id(item): dataset = YouTubeDataset(name="bla") result = dataset._get_id(item) @@ -87,7 +86,7 @@ def test_get_id_with_invalid_id(item): def test_fetch_videos_default(): dataset = YouTubeDataset(name="bla") - assert list(dataset.fetch_videos('collection')) == [] + assert list(dataset.fetch_videos("collection")) == [] def test_fetch_videos_with_next_page_token(): @@ -95,20 +94,43 @@ def test_fetch_videos_with_next_page_token(): items = [ { - 'items': [{'kind': 'youtube#searchResult', 'id': {'kind': 'youtube#video', 'videoId': str(i)}} for i in range(0, 3)], - 'nextPageToken': "1" - }, { - 'items': [{'kind': 'youtube#searchResult', 'id': {'kind': 'youtube#video', 'videoId': str(i)}} for i in range(3, 6)], - 'nextPageToken': "2" - }, { - 'items': [{'kind': 'youtube#searchResult', 'id': {'kind': 'youtube#video', 'videoId': str(i)}} for i in range(6, 9)], - 'nextPageToken': None + "items": [ + { + "kind": "youtube#searchResult", + "id": {"kind": "youtube#video", "videoId": str(i)}, + } + for i in range(0, 3) + ], + "nextPageToken": "1", + }, + { + "items": [ + { + "kind": "youtube#searchResult", + "id": {"kind": "youtube#video", "videoId": str(i)}, + } + for i in range(3, 6) + ], + "nextPageToken": "2", + }, + { + "items": [ + { + "kind": "youtube#searchResult", + "id": {"kind": "youtube#video", "videoId": str(i)}, + } + for i in range(6, 9) + ], + "nextPageToken": None, }, ] - with patch.object(dataset, 'next_page', side_effect=items): - assert list(dataset.fetch_videos('collection_id')) == [ - {'id': {'kind': 'youtube#video', 'videoId': str(i)}, 'kind': 
'youtube#searchResult'} + with patch.object(dataset, "next_page", side_effect=items): + assert list(dataset.fetch_videos("collection_id")) == [ + { + "id": {"kind": "youtube#video", "videoId": str(i)}, + "kind": "youtube#searchResult", + } for i in range(9) ] @@ -118,107 +140,143 @@ def test_fetch_videos_stops_when_no_next_page_token(): items = [ { - 'items': [{'kind': 'youtube#searchResult', 'id': {'kind': 'youtube#video', 'videoId': str(i)}} for i in range(0, 3)], - 'nextPageToken': None - }, { - 'items': [{'kind': 'youtube#searchResult', 'id': {'kind': 'youtube#video', 'videoId': str(i)}} for i in range(3, 6)], - 'nextPageToken': "2" - }, { - 'items': [{'kind': 'youtube#searchResult', 'id': {'kind': 'youtube#video', 'videoId': str(i)}} for i in range(6, 9)], - 'nextPageToken': "3" + "items": [ + { + "kind": "youtube#searchResult", + "id": {"kind": "youtube#video", "videoId": str(i)}, + } + for i in range(0, 3) + ], + "nextPageToken": None, + }, + { + "items": [ + { + "kind": "youtube#searchResult", + "id": {"kind": "youtube#video", "videoId": str(i)}, + } + for i in range(3, 6) + ], + "nextPageToken": "2", + }, + { + "items": [ + { + "kind": "youtube#searchResult", + "id": {"kind": "youtube#video", "videoId": str(i)}, + } + for i in range(6, 9) + ], + "nextPageToken": "3", }, ] - with patch.object(dataset, 'next_page', side_effect=items): - assert list(dataset.fetch_videos('collection_id')) == [ - {'id': {'kind': 'youtube#video', 'videoId': str(i)}, 'kind': 'youtube#searchResult'} + with patch.object(dataset, "next_page", side_effect=items): + assert list(dataset.fetch_videos("collection_id")) == [ + { + "id": {"kind": "youtube#video", "videoId": str(i)}, + "kind": "youtube#searchResult", + } for i in range(3) ] def test_items_list(): dataset = YouTubeDataset(name="bla") - dataset.collection_ids = ['collection_id_1', 'collection_id_2'] + dataset.collection_ids = ["collection_id_1", "collection_id_2"] def fetcher(collection_id): return [ - {'id': {'kind': 'youtube#video', 'videoId': f'{collection_id}_{i}'}} + {"id": {"kind": "youtube#video", "videoId": f"{collection_id}_{i}"}} for i in range(3) ] - with patch.object(dataset, 'fetch_videos', fetcher): + with patch.object(dataset, "fetch_videos", fetcher): assert list(dataset.items_list) == [ - {'id': {'kind': 'youtube#video', 'videoId': f'collection_id_1_0'}}, - {'id': {'kind': 'youtube#video', 'videoId': f'collection_id_1_1'}}, - {'id': {'kind': 'youtube#video', 'videoId': f'collection_id_1_2'}}, - {'id': {'kind': 'youtube#video', 'videoId': f'collection_id_2_0'}}, - {'id': {'kind': 'youtube#video', 'videoId': f'collection_id_2_1'}}, - {'id': {'kind': 'youtube#video', 'videoId': f'collection_id_2_2'}}, + {"id": {"kind": "youtube#video", "videoId": f"collection_id_1_0"}}, + {"id": {"kind": "youtube#video", "videoId": f"collection_id_1_1"}}, + {"id": {"kind": "youtube#video", "videoId": f"collection_id_1_2"}}, + {"id": {"kind": "youtube#video", "videoId": f"collection_id_2_0"}}, + {"id": {"kind": "youtube#video", "videoId": f"collection_id_2_1"}}, + {"id": {"kind": "youtube#video", "videoId": f"collection_id_2_2"}}, ] def test_get_item_key(): dataset = YouTubeDataset(name="bla") - video = {'id': {'kind': 'youtube#video', 'videoId': 'your_video_id'}, 'kind': 'youtube#searchResult'} - assert dataset.get_item_key(video) == 'https://www.youtube.com/watch?v=your_video_id' - - -@pytest.mark.parametrize('error', ( - NoTranscriptFound('video_id', 'language_codes', 'transcript_data'), - VideoUnavailable('video_id'), - 
TranscriptsDisabled('video_id'), -)) + video = { + "id": {"kind": "youtube#video", "videoId": "your_video_id"}, + "kind": "youtube#searchResult", + } + assert ( + dataset.get_item_key(video) == "https://www.youtube.com/watch?v=your_video_id" + ) + + +@pytest.mark.parametrize( + "error", + ( + NoTranscriptFound("video_id", "language_codes", "transcript_data"), + VideoUnavailable("video_id"), + TranscriptsDisabled("video_id"), + ), +) def test_get_contents_with_no_transcript_found(error): dataset = YouTubeDataset(name="bla") - video = {'id': {'kind': 'youtube#video', 'videoId': "bla_bla"}, 'kind': 'youtube#searchResult'} + video = { + "id": {"kind": "youtube#video", "videoId": "bla_bla"}, + "kind": "youtube#searchResult", + } transcriber = Mock() - transcriber.list_transcripts.return_value.find_transcript.return_value.fetch.side_effect = error + transcriber.list_transcripts.return_value.find_transcript.return_value.fetch.side_effect = ( + error + ) - with patch('align_data.sources.youtube.youtube.YouTubeTranscriptApi', transcriber): + with patch("align_data.sources.youtube.youtube.YouTubeTranscriptApi", transcriber): assert dataset._get_contents(video) is None def test_get_contents(): dataset = YouTubeDataset(name="bla") - video = {'id': {'kind': 'youtube#video', 'videoId': "bla_bla"}, 'kind': 'youtube#searchResult'} + video = { + "id": {"kind": "youtube#video", "videoId": "bla_bla"}, + "kind": "youtube#searchResult", + } transcriber = Mock() transcriber.list_transcripts.return_value.find_transcript.return_value.fetch.return_value = [ - {'text': 'bla bla'}, - {'text': 'second line'}, - {'text': 'ble ble'}, + {"text": "bla bla"}, + {"text": "second line"}, + {"text": "ble ble"}, ] - with patch('align_data.sources.youtube.youtube.YouTubeTranscriptApi', transcriber): - assert dataset._get_contents(video) == 'bla bla\nsecond line\nble ble' + with patch("align_data.sources.youtube.youtube.YouTubeTranscriptApi", transcriber): + assert dataset._get_contents(video) == "bla bla\nsecond line\nble ble" def test_extract_authors_with_authors_defined(): dataset = YouTubeDataset(name="bla") - video = {'snippet': {'channelTitle': 'channel_title'}} + video = {"snippet": {"channelTitle": "channel_title"}} - dataset.authors = ['author_1', 'author_2'] - assert dataset.extract_authors(video) == ['author_1', 'author_2'] + dataset.authors = ["author_1", "author_2"] + assert dataset.extract_authors(video) == ["author_1", "author_2"] def test_extract_authors_with_no_authors_defined(): dataset = YouTubeDataset(name="bla") - video = {'snippet': {'channelTitle': 'channel title'}} + video = {"snippet": {"channelTitle": "channel title"}} dataset.authors = [] - assert dataset.extract_authors(video) == ['channel title'] + assert dataset.extract_authors(video) == ["channel title"] def test_process_entry_with_valid_entry(transcriber): dataset = YouTubeDataset(name="bla") video = { - 'kind': 'youtube#searchResult', - 'id': {'kind': 'youtube#video', 'videoId': "bla_bla"}, - 'snippet': { - 'title': 'bla bla title', - 'channelTitle': 'This is a pen!' 
- } + "kind": "youtube#searchResult", + "id": {"kind": "youtube#video", "videoId": "bla_bla"}, + "snippet": {"title": "bla bla title", "channelTitle": "This is a pen!"}, } assert dataset.process_entry(video).to_dict() == { @@ -235,86 +293,92 @@ def test_process_entry_with_valid_entry(transcriber): def test_channel_collection_ids(): - dataset = YouTubeChannelDataset(name='bla', channel_id='a channel id') - assert dataset.collection_ids == ['a channel id'] + dataset = YouTubeChannelDataset(name="bla", channel_id="a channel id") + assert dataset.collection_ids == ["a channel id"] def test_channel_published_date(): - dataset = YouTubeChannelDataset(name='bla', channel_id='a channel id') + dataset = YouTubeChannelDataset(name="bla", channel_id="a channel id") video = { - 'kind': 'youtube#searchResult', - 'id': {'kind': 'youtube#video', 'videoId': "bla_bla"}, - 'snippet': { - 'title': 'bla bla title', - 'channelTitle': 'This is a pen!', - 'publishTime': '2022-01-02T03:04:05Z', - } + "kind": "youtube#searchResult", + "id": {"kind": "youtube#video", "videoId": "bla_bla"}, + "snippet": { + "title": "bla bla title", + "channelTitle": "This is a pen!", + "publishTime": "2022-01-02T03:04:05Z", + }, } - assert dataset._get_published_date(video).isoformat() == '2022-01-02T03:04:05+00:00' + assert dataset._get_published_date(video).isoformat() == "2022-01-02T03:04:05+00:00" def test_channel_process_item(transcriber): - dataset = YouTubeChannelDataset(name='bla', channel_id='a channel id') + dataset = YouTubeChannelDataset(name="bla", channel_id="a channel id") video = { - 'kind': 'youtube#searchResult', - 'id': {'kind': 'youtube#video', 'videoId': "bla_bla"}, - 'snippet': { - 'title': 'bla bla title', - 'channelTitle': 'This is a pen!', - 'publishTime': '2022-01-02T03:04:05Z', - } + "kind": "youtube#searchResult", + "id": {"kind": "youtube#video", "videoId": "bla_bla"}, + "snippet": { + "title": "bla bla title", + "channelTitle": "This is a pen!", + "publishTime": "2022-01-02T03:04:05Z", + }, } assert dataset.process_entry(video).to_dict() == { - 'authors': ['This is a pen!'], - 'date_published': '2022-01-02T03:04:05Z', - 'id': None, - 'source': 'bla', - 'source_type': 'youtube', - 'summaries': [], - 'text': 'bla bla\nsecond line\nble ble', - 'title': 'bla bla title', - 'url': 'https://www.youtube.com/watch?v=bla_bla' + "authors": ["This is a pen!"], + "date_published": "2022-01-02T03:04:05Z", + "id": None, + "source": "bla", + "source_type": "youtube", + "summaries": [], + "text": "bla bla\nsecond line\nble ble", + "title": "bla bla title", + "url": "https://www.youtube.com/watch?v=bla_bla", } def test_playlist_collection_ids(): - dataset = YouTubePlaylistDataset(name='bla', playlist_ids=['a list id', 'another id']) - assert dataset.collection_ids == ['a list id', 'another id'] + dataset = YouTubePlaylistDataset( + name="bla", playlist_ids=["a list id", "another id"] + ) + assert dataset.collection_ids == ["a list id", "another id"] def test_playlist_published_date(): - dataset = YouTubePlaylistDataset(name='bla', playlist_ids=['a list id', 'another id']) + dataset = YouTubePlaylistDataset( + name="bla", playlist_ids=["a list id", "another id"] + ) video = { - 'kind': 'youtube#playlistItem', - 'snippet': { - 'resourceId': {'kind': 'youtube#video', 'videoId': "bla_bla"}, - 'title': 'bla bla title', - 'channelTitle': 'This is a pen!', - 'publishedAt': '2022-01-02T03:04:05Z', - } + "kind": "youtube#playlistItem", + "snippet": { + "resourceId": {"kind": "youtube#video", "videoId": "bla_bla"}, + "title": 
"bla bla title", + "channelTitle": "This is a pen!", + "publishedAt": "2022-01-02T03:04:05Z", + }, } - assert dataset._get_published_date(video).isoformat() == '2022-01-02T03:04:05+00:00' + assert dataset._get_published_date(video).isoformat() == "2022-01-02T03:04:05+00:00" def test_channel_process_item(transcriber): - dataset = YouTubePlaylistDataset(name='bla', playlist_ids=['a list id', 'another id']) + dataset = YouTubePlaylistDataset( + name="bla", playlist_ids=["a list id", "another id"] + ) video = { - 'kind': 'youtube#playlistItem', - 'snippet': { - 'resourceId': {'kind': 'youtube#video', 'videoId': "bla_bla"}, - 'title': 'bla bla title', - 'channelTitle': 'This is a pen!', - 'publishedAt': '2022-01-02T03:04:05Z', - } + "kind": "youtube#playlistItem", + "snippet": { + "resourceId": {"kind": "youtube#video", "videoId": "bla_bla"}, + "title": "bla bla title", + "channelTitle": "This is a pen!", + "publishedAt": "2022-01-02T03:04:05Z", + }, } assert dataset.process_entry(video).to_dict() == { - 'authors': ['This is a pen!'], - 'date_published': '2022-01-02T03:04:05Z', - 'id': None, - 'source': 'bla', - 'source_type': 'youtube', - 'summaries': [], - 'text': 'bla bla\nsecond line\nble ble', - 'title': 'bla bla title', - 'url': 'https://www.youtube.com/watch?v=bla_bla' + "authors": ["This is a pen!"], + "date_published": "2022-01-02T03:04:05Z", + "id": None, + "source": "bla", + "source_type": "youtube", + "summaries": [], + "text": "bla bla\nsecond line\nble ble", + "title": "bla bla title", + "url": "https://www.youtube.com/watch?v=bla_bla", } diff --git a/tests/conftest.py b/tests/conftest.py index 5373b4ec..43545757 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,8 @@ from align_data.common.alignment_dataset import make_session -@pytest.fixture(autouse=True, scope='session') +@pytest.fixture(autouse=True, scope="session") def mock_db(): # This just mocks out all db calls, nothing more - with patch('align_data.common.alignment_dataset.make_session'): + with patch("align_data.common.alignment_dataset.make_session"): yield diff --git a/tests/print_date_published.py b/tests/print_date_published.py index deed00f8..d3ba46fd 100644 --- a/tests/print_date_published.py +++ b/tests/print_date_published.py @@ -1,32 +1,37 @@ import json from dateutil.parser import parse + def print_date_published(file_path, n=10): - with open(file_path, 'r') as file: + with open(file_path, "r") as file: for i, line in enumerate(file): if i >= n: break entry = json.loads(line) - print(entry.get('date_published')) + print(entry.get("date_published")) + def validate_date_format(file_path, keys_to_print): - with open(file_path, 'r') as file: + with open(file_path, "r") as file: for i, line in enumerate(file): entry = json.loads(line) - date_published = entry.get('date_published') + date_published = entry.get("date_published") try: # Try to parse the date_published string into a datetime object parse(date_published) except ValueError: - print(f'Row {i}: date_published is NOT in a valid format: {date_published}') + print( + f"Row {i}: date_published is NOT in a valid format: {date_published}" + ) for key in keys_to_print: - print(f' {key}: {entry.get(key)}') + print(f" {key}: {entry.get(key)}") + # replace with your file path file_path = "data/distill.jsonl" # list of keys to print when an invalid date format is found -keys_to_print = ['url', 'title', 'id'] +keys_to_print = ["url", "title", "id"] # uncomment to print date_published for the first 10 entries print_date_published(file_path) diff --git 
a/upload_to_huggingface.py b/upload_to_huggingface.py
index 86fb6c48..c8711406 100644
--- a/upload_to_huggingface.py
+++ b/upload_to_huggingface.py
@@ -10,17 +10,19 @@
 from huggingface_hub import HfApi
 
 
-GDOCS_FOLDER = 'https://drive.google.com/drive/folders/1n4i0J4CuSfNmrUkKPyTFKJU0XWYLtRF8'
-PRIVATE_FILES = ['ebooks.jsonl']
+GDOCS_FOLDER = (
+    "https://drive.google.com/drive/folders/1n4i0J4CuSfNmrUkKPyTFKJU0XWYLtRF8"
+)
+PRIVATE_FILES = ["ebooks.jsonl"]
 
 
 def upload(api, filename, repo_name):
-    print(f'Uploading {filename} as {repo_name}/{filename.name}')
+    print(f"Uploading {filename} as {repo_name}/{filename.name}")
     api.upload_file(
         path_or_fileobj=filename,
         path_in_repo=filename.name,
-        repo_id=f'StampyAI/{repo_name}',
-        repo_type='dataset'
+        repo_id=f"StampyAI/{repo_name}",
+        repo_type="dataset",
     )
 
 
@@ -36,7 +38,11 @@ def get_gdoc_names(url):
         return None
 
     _, id_name_type_iter = _parse_google_drive_file(url=url, content=res.text)
-    return [(id, name) for id, name, filetype in id_name_type_iter if name.endswith('.jsonl')]
+    return [
+        (id, name)
+        for id, name, filetype in id_name_type_iter
+        if name.endswith(".jsonl")
+    ]
 
 
 def upload_data_file(api, name, id, repo_name):
@@ -44,14 +50,16 @@ def upload_data_file(api, name, id, repo_name):
     If the file already exists locally, it will be used. Otherwise it will first be fetched
     from the GDrive.
     """
-    data = Path('data/')
+    data = Path("data/")
     filename = data / name
 
     # Don't download it if it exists locally
     if not filename.exists():
-        gdown.download(f'https://drive.google.com/uc?id={id}', str(filename), quiet=False)
+        gdown.download(
+            f"https://drive.google.com/uc?id={id}", str(filename), quiet=False
+        )
     else:
-        print(f'Using local file at {filename}')
+        print(f"Using local file at {filename}")
 
     try:
         # Check that the dowloaded file really contains json lines
@@ -64,12 +72,14 @@ def upload_data_file(api, name, id, repo_name):
 
 
 def download_file(repo_name, filename, api):
-    headers = {'Authorization': f'Bearer {api.token}'}
-    url = f'https://huggingface.co/datasets/StampyAI/{repo_name}/raw/main/{filename.name}'
+    headers = {"Authorization": f"Bearer {api.token}"}
+    url = (
+        f"https://huggingface.co/datasets/StampyAI/{repo_name}/raw/main/{filename.name}"
+    )
     response = requests.get(url, headers=headers)
 
     if response.status_code == 200:
-        with open(filename, 'wb') as file:
+        with open(filename, "wb") as file:
             file.write(response.content)
 
 
@@ -84,25 +94,25 @@ def update_readme(api, files, repo_name):
     repo.mkdir(exist_ok=True)
 
     # Fetch the current README and dataset script
-    for filename in ['README.md', f'{repo_name}.py']:
+    for filename in ["README.md", f"{repo_name}.py"]:
         download_file(repo_name, repo / filename, api)
 
     # Copy over all jsonl files that have been updated, and update the README to have the
     # current metadata
     for filename in files:
-        target = Path('data') / filename
+        target = Path("data") / filename
         (repo / filename).write_text(target.read_text())
 
-        output = subprocess.check_output([
-            'datasets-cli', 'test', repo_name, '--save_info', f'--name={target.stem}'
-        ])
+        output = subprocess.check_output(
+            ["datasets-cli", "test", repo_name, "--save_info", f"--name={target.stem}"]
+        )
 
     # Now upload the updated README
-    upload(api, repo / 'README.md', repo_name)
+    upload(api, repo / "README.md", repo_name)
 
 
 if __name__ == "__main__":
     if len(sys.argv) < 2 or not sys.argv[1]:
-        print('Usage: python upload_to_huggingface ')
+        print("Usage: python upload_to_huggingface ")
        sys.exit(2)
 
     token = sys.argv[1]
@@ -110,19 +120,20 @@ def update_readme(api, files, repo_name):
     api = HfApi(token=token)
     files = get_gdoc_names(GDOCS_FOLDER)
-    if len(sys.argv) > 2 and sys.argv[2] != 'all':
-        files = [item for item in files if item[1] == sys.argv[2] + '.jsonl']
+    if len(sys.argv) > 2 and sys.argv[2] != "all":
+        files = [item for item in files if item[1] == sys.argv[2] + ".jsonl"]
 
-    data = Path('data/')
+    data = Path("data/")
     for id, name in files:
-        upload_data_file(api, name, id, 'ard-private')
+        upload_data_file(api, name, id, "ard-private")
         if name not in PRIVATE_FILES:
-            upload_data_file(api, name, id, 'alignment-research-dataset')
+            upload_data_file(api, name, id, "alignment-research-dataset")
 
     update_readme(
-        api, [name for _, name in files if name not in PRIVATE_FILES],
-        'alignment-research-dataset'
+        api,
+        [name for _, name in files if name not in PRIVATE_FILES],
+        "alignment-research-dataset",
     )
-    update_readme(api, [name for _, name in files], 'ard-private')
+    update_readme(api, [name for _, name in files], "ard-private")
 
-    print('done')
+    print("done")