From 6b024d682431c890dfeff42f99579f240d211033 Mon Sep 17 00:00:00 2001
From: Henri Lemoine
Date: Tue, 8 Aug 2023 22:45:09 -0400
Subject: [PATCH] post-black reformatting

---
 align_data/pinecone/text_splitter.py     | 23 ++++++------
 align_data/pinecone/update_pinecone.py   | 13 +++----
 align_data/postprocess/postprocess.py    | 32 ++++----------
 align_data/settings.py                   |  6 ++--
 align_data/sources/articles/articles.py  |  4 +--
 align_data/sources/articles/datasets.py  | 15 ++++----
 align_data/sources/articles/indices.py   |  4 +--
 align_data/sources/articles/pdf.py       | 11 +++---
 .../sources/arxiv_papers/arxiv_papers.py | 36 +++++++++----------
 align_data/sources/blogs/blogs.py        | 11 +++---
 10 files changed, 70 insertions(+), 85 deletions(-)

diff --git a/align_data/pinecone/text_splitter.py b/align_data/pinecone/text_splitter.py
index c732c99c..03c74b57 100644
--- a/align_data/pinecone/text_splitter.py
+++ b/align_data/pinecone/text_splitter.py
@@ -5,6 +5,10 @@
 from nltk.tokenize import sent_tokenize
 
 
+def default_truncate_function(string: str, length: int, from_end: bool = False) -> str:
+    return string[-length:] if from_end else string[:length]
+
+
 class ParagraphSentenceUnitTextSplitter(TextSplitter):
     """A custom TextSplitter that breaks text by paragraphs, sentences, and then units (chars/words/tokens/etc).
 
@@ -17,12 +21,8 @@ class ParagraphSentenceUnitTextSplitter(TextSplitter):
     DEFAULT_MIN_CHUNK_SIZE = 900
     DEFAULT_MAX_CHUNK_SIZE = 1100
     DEFAULT_LENGTH_FUNCTION = lambda string: len(string)
-    DEFAULT_TRUNCATE_FUNCTION = (
-        lambda string, length, from_end=False: string[-length:]
-        if from_end
-        else string[:length]
-    )
-
+    DEFAULT_TRUNCATE_FUNCTION = default_truncate_function
+
     def __init__(
         self,
         min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE,
@@ -47,12 +47,8 @@ def split_text(self, text: str) -> List[str]:
             current_block += "\n\n" + paragraph
             block_length = self._length_function(current_block)
 
-            if (
-                block_length > self.max_chunk_size
-            ):  # current block is too large, truncate it
-                current_block = self._handle_large_paragraph(
-                    current_block, blocks, paragraph
-                )
+            if block_length > self.max_chunk_size:
+                current_block = self._handle_large_paragraph(current_block, blocks, paragraph)
             elif block_length >= self.min_chunk_size:
                 blocks.append(current_block)
                 current_block = ""
@@ -65,7 +61,8 @@ def split_text(self, text: str) -> List[str]:
 
     def _handle_large_paragraph(self, current_block, blocks, paragraph):
         # Undo adding the whole paragraph
-        current_block = current_block[: -(len(paragraph) + 2)]  # +2 accounts for "\n\n"
+        offset = len(paragraph) + 2  # +2 accounts for "\n\n"
+        current_block = current_block[:-offset]
 
         sentences = sent_tokenize(paragraph)
         for sentence in sentences:
diff --git a/align_data/pinecone/update_pinecone.py b/align_data/pinecone/update_pinecone.py
index 649a7a2a..e3a91f90 100644
--- a/align_data/pinecone/update_pinecone.py
+++ b/align_data/pinecone/update_pinecone.py
@@ -27,6 +27,11 @@
 logger = logging.getLogger(__name__)
 
 
+# Define type aliases for the Callables
+LengthFunctionType = Callable[[str], int]
+TruncateFunctionType = Callable[[str, int], str]
+
+
 class PineconeEntry(BaseModel):
     id: str
     source: str
@@ -65,12 +70,8 @@ def __init__(
         self,
         min_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MIN_CHUNK_SIZE,
         max_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MAX_CHUNK_SIZE,
-        length_function: Callable[
-            [str], int
-        ] = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION,
-        truncate_function: Callable[
-            [str, int], str
-        ] = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION,
+        length_function: LengthFunctionType = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION,
+        truncate_function: TruncateFunctionType = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION,
     ):
         self.min_chunk_size = min_chunk_size
         self.max_chunk_size = max_chunk_size
diff --git a/align_data/postprocess/postprocess.py b/align_data/postprocess/postprocess.py
index 1366bc11..cb16b7a5 100644
--- a/align_data/postprocess/postprocess.py
+++ b/align_data/postprocess/postprocess.py
@@ -1,4 +1,5 @@
 # %%
+from collections import defaultdict, Counter
 from dataclasses import dataclass
 import jsonlines
 from tqdm import tqdm
@@ -6,8 +7,7 @@
 from path import Path
 import pylab as plt
 
-
-# import seaborn as sns
+import seaborn as sns
 import pandas as pd
 
 logger = logging.getLogger(__name__)
@@ -24,30 +24,20 @@ class PostProcesser:
     def __init__(self) -> None:
         self.jsonl_list = sorted(self.jsonl_path.files("*.jsonl"))
         self.source_list = [path.name.split(".jsonl")[0] for path in self.jsonl_list]
 
     def compute_statistics(self) -> None:
-        self.all_stats = {key: {} for key in self.source_list}
+        self.all_stats = defaultdict(Counter)  # reset here so repeated calls don't double-count
         for source_name, path in tqdm(zip(self.source_list, self.jsonl_list)):
             with jsonlines.open(path) as reader:
                 for obj in reader:
-                    self.all_stats[source_name]["num_entries"] = (
-                        self.all_stats[source_name].get("num_entries", 0) + 1
-                    )
-                    self.all_stats[source_name]["num_tokens"] = self.all_stats[
-                        source_name
-                    ].get("num_tokens", 0) + len(obj["text"].split())
-                    self.all_stats[source_name]["num_chars"] = self.all_stats[
-                        source_name
-                    ].get("num_chars", 0) + len(obj["text"])
-                    self.all_stats[source_name]["num_words"] = self.all_stats[
-                        source_name
-                    ].get("num_words", 0) + len(obj["text"].split())
-                    self.all_stats[source_name]["num_sentences"] = self.all_stats[
-                        source_name
-                    ].get("num_sentences", 0) + len(obj["text"].split("."))
-                    self.all_stats[source_name]["num_paragraphs"] = self.all_stats[
-                        source_name
-                    ].get("num_paragraphs", 0) + len(obj["text"].splitlines())
+                    text = obj["text"]
+                    source_stats = self.all_stats[source_name]
+                    source_stats["num_entries"] += 1
+                    source_stats["num_tokens"] += len(text.split())  # TODO: Use tokenizer
+                    source_stats["num_chars"] += len(text)
+                    source_stats["num_words"] += len(text.split())
+                    source_stats["num_sentences"] += len(text.split("."))  # TODO: Use NLTK/Spacy or similar
+                    source_stats["num_paragraphs"] += len(text.splitlines())
 
     def plot_statistics(self) -> None:
         all_df = pd.DataFrame(self.all_stats).T
diff --git a/align_data/settings.py b/align_data/settings.py
index 1244861f..07c8e17f 100644
--- a/align_data/settings.py
+++ b/align_data/settings.py
@@ -9,7 +9,7 @@
 ON_SITE_TABLE = os.environ.get("CODA_ON_SITE_TABLE", "table-aOTSHIz_mN")
 
 ### GOOGLE DRIVE ###
-PDFS_FOLDER_ID = os.environ.get("PDF_FOLDER_ID", "1etWiXPRl0QqdgYzivVIj6wCv9xj5VYoN")
+PDFS_FOLDER_ID = os.environ.get("PDFS_FOLDER_ID", "1etWiXPRl0QqdgYzivVIj6wCv9xj5VYoN")
 
 ### GOOGLE SHEETS ###
 METADATA_SOURCE_SPREADSHEET = os.environ.get(
@@ -41,9 +41,7 @@
 OPENAI_EMBEDDINGS_DIMS = 1536
 OPENAI_EMBEDDINGS_RATE_LIMIT = 3500
 
-SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = (
-    "sentence-transformers/multi-qa-mpnet-base-cos-v1"
-)
+SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1"
 SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768
 
 ### PINECONE ###
diff --git a/align_data/sources/articles/articles.py b/align_data/sources/articles/articles.py
index 941de037..3a5163dd 100644
--- a/align_data/sources/articles/articles.py
+++ b/align_data/sources/articles/articles.py
@@ -124,8 +124,8 @@ def check_new_articles(source_spreadsheet, source_sheet):
     seen_urls = {
         url
         for item in current.values()
-        for url in [item.get("url"), item.get("source_url")]
-        if url
+        for key in ("url", "source_url")
+        if (url := item.get(key))
     }
 
     indices_items = fetch_all()
diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py
index 2f25a606..1a75db20 100644
--- a/align_data/sources/articles/datasets.py
+++ b/align_data/sources/articles/datasets.py
@@ -39,14 +39,13 @@ def is_val(val):
 
     @property
     def items_list(self):
-        logger.info(
-            f"Fetching https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=CS&gid={self.sheet_id}"
-        )
-        df = pd.read_csv(
-            f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}"
-        )
+        url = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}"
+        logger.info(f"Fetching {url}")
+        sheet_data = pd.read_csv(url)
         return (
-            item for item in df.itertuples() if not pd.isna(self.get_item_key(item))
+            item
+            for item in sheet_data.itertuples()
+            if not pd.isna(self.get_item_key(item))
         )
 
     def get_item_key(self, item):
diff --git a/align_data/sources/articles/indices.py b/align_data/sources/articles/indices.py
index 0cf2c45a..fde99a72 100644
--- a/align_data/sources/articles/indices.py
+++ b/align_data/sources/articles/indices.py
@@ -97,9 +97,7 @@ def format_far_ai(item):
     return {
         "title": get_text(item, ".article-title"),
         "url": f'https://www.safe.ai/research{item.select_one(".article-title a").get("href")}',
-        "source_url": item.select_one('div.btn-links a:-soup-contains("PDF")').get(
-            "href"
-        ),
+        "source_url": item.select_one('div.btn-links a:-soup-contains("PDF")').get("href"),
         "authors": ", ".join(i.text for i in item.select(".article-metadata a")),
     }
 
diff --git a/align_data/sources/articles/pdf.py b/align_data/sources/articles/pdf.py
index 2120dc56..c0cc81bf 100644
--- a/align_data/sources/articles/pdf.py
+++ b/align_data/sources/articles/pdf.py
@@ -98,8 +98,11 @@ def get_arxiv_link(doi):
         return None
 
     vals = [
-        i for i in response.json().get("values") if i.get("type", "").upper() == "URL"
+        val
+        for val in response.json().get("values") or []
+        if val.get("type", "").upper() == "URL"
     ]
+
     if not vals:
         return None
     return vals[0]["data"]["value"].replace("/abs/", "/pdf/") + ".pdf"
@@ -197,10 +200,8 @@ def get_first_child(item):
         date_published = date_published.text.strip("()")
 
     text = "\n\n".join(
-        [
-            MarkdownConverter().convert_soup(elem).strip()
-            for elem in contents.select("section.ltx_section")
-        ]
+        MarkdownConverter().convert_soup(elem).strip()
+        for elem in contents.select("section.ltx_section")
    )
 
     return {
diff --git a/align_data/sources/arxiv_papers/arxiv_papers.py b/align_data/sources/arxiv_papers/arxiv_papers.py
index 1a61ecc6..eef00d84 100644
--- a/align_data/sources/arxiv_papers/arxiv_papers.py
+++ b/align_data/sources/arxiv_papers/arxiv_papers.py
@@ -56,24 +56,22 @@ def process_entry(self, item) -> None:
         else:
             authors = paper.get("authors") or []
             authors = [str(a).strip() for a in authors]
-
+
         return self.make_data_entry(
-            {
-                "url": self.get_item_key(item),
-                "source": self.name,
-                "source_type": paper["data_source"],
-                "title": self.is_val(item.title) or paper.get("title"),
"authors": authors, - "date_published": self._get_published_date( - self.is_val(item.date_published) or paper.get("date_published") - ), - "data_last_modified": str(metadata.updated), - "summary": metadata.summary.replace("\n", " "), - "author_comment": metadata.comment, - "journal_ref": metadata.journal_ref, - "doi": metadata.doi, - "primary_category": metadata.primary_category, - "categories": metadata.categories, - "text": paper["text"], - } + url=self.get_item_key(item), + source=self.name, + source_type=paper["data_source"], + title=self.is_val(item.title) or paper.get("title"), + authors=authors, + date_published=self._get_published_date( + self.is_val(item.date_published) or paper.get("date_published") + ), + data_last_modified=str(metadata.updated), + summary=metadata.summary.replace("\n", " "), + author_comment=metadata.comment, + journal_ref=metadata.journal_ref, + doi=metadata.doi, + primary_category=metadata.primary_category, + categories=metadata.categories, + text=paper["text"], ) diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py index ae4439a9..294ee205 100644 --- a/align_data/sources/blogs/blogs.py +++ b/align_data/sources/blogs/blogs.py @@ -74,15 +74,16 @@ def _get_text(self, contents): return item_metadata(paper_link.get("href")).get("text") def extract_authors(self, article): - authors = article.select_one( - 'div:-soup-contains("Authors") + div .f-body-1' - ) or article.select_one('div:-soup-contains("Acknowledgments") + div .f-body-1') - if not authors: + author_selector = 'div:-soup-contains("Authors") + div .f-body-1' + ack_selector = 'div:-soup-contains("Acknowledgments") + div .f-body-1' + + authors_div = article.select_one(author_selector) or article.select_one(ack_selector) + if not authors_div: return [] return [ i.split("(")[0].strip() - for i in authors.select_one("p").children + for i in authors_div.select_one("p").children if not i.name ]