post-black reformatting
henri123lemoine committed Aug 9, 2023
1 parent 0f1b1c8 commit 6b024d6
Showing 10 changed files with 70 additions and 85 deletions.
23 changes: 10 additions & 13 deletions align_data/pinecone/text_splitter.py
@@ -5,6 +5,10 @@
from nltk.tokenize import sent_tokenize


def default_truncate_function(string: str, length: int, from_end: bool = False) -> str:
return string[-length:] if from_end else string[:length]


class ParagraphSentenceUnitTextSplitter(TextSplitter):
"""A custom TextSplitter that breaks text by paragraphs, sentences, and then units (chars/words/tokens/etc).
@@ -17,12 +21,8 @@ class ParagraphSentenceUnitTextSplitter:
DEFAULT_MIN_CHUNK_SIZE = 900
DEFAULT_MAX_CHUNK_SIZE = 1100
DEFAULT_LENGTH_FUNCTION = lambda string: len(string)
DEFAULT_TRUNCATE_FUNCTION = (
lambda string, length, from_end=False: string[-length:]
if from_end
else string[:length]
)

DEFAULT_TRUNCATE_FUNCTION = default_truncate_function

def __init__(
self,
min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE,
@@ -47,12 +47,8 @@ def split_text(self, text: str) -> List[str]:
current_block += "\n\n" + paragraph
block_length = self._length_function(current_block)

if (
block_length > self.max_chunk_size
): # current block is too large, truncate it
current_block = self._handle_large_paragraph(
current_block, blocks, paragraph
)
if block_length > self.max_chunk_size:
current_block = self._handle_large_paragraph(current_block, blocks, paragraph)
elif block_length >= self.min_chunk_size:
blocks.append(current_block)
current_block = ""
@@ -65,7 +61,8 @@ def split_text(self, text: str) -> List[str]:

def _handle_large_paragraph(self, current_block, blocks, paragraph):
# Undo adding the whole paragraph
current_block = current_block[: -(len(paragraph) + 2)] # +2 accounts for "\n\n"
offset = len(paragraph) + 2 # +2 accounts for "\n\n"
current_block = current_block[:-offset]

sentences = sent_tokenize(paragraph)
for sentence in sentences:
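The module-level `default_truncate_function` replaces the old class-attribute lambda with the same slicing behaviour. A minimal standalone sketch (the assertions are illustrative, not part of the commit):

```python
def default_truncate_function(string: str, length: int, from_end: bool = False) -> str:
    # Keep the trailing `length` characters when truncating from the end,
    # otherwise keep the leading `length` characters.
    return string[-length:] if from_end else string[:length]

assert default_truncate_function("abcdefgh", 3) == "abc"
assert default_truncate_function("abcdefgh", 3, from_end=True) == "fgh"
```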
13 changes: 7 additions & 6 deletions align_data/pinecone/update_pinecone.py
@@ -27,6 +27,11 @@
logger = logging.getLogger(__name__)


# Define type aliases for the Callables
LengthFunctionType = Callable[[str], int]
TruncateFunctionType = Callable[[str, int], str]


class PineconeEntry(BaseModel):
id: str
source: str
@@ -65,12 +70,8 @@ def __init__(
self,
min_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MIN_CHUNK_SIZE,
max_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MAX_CHUNK_SIZE,
length_function: Callable[
[str], int
] = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION,
truncate_function: Callable[
[str, int], str
] = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION,
length_function: LengthFunctionType = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION,
truncate_function: TruncateFunctionType = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION,
):
self.min_chunk_size = min_chunk_size
self.max_chunk_size = max_chunk_size
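The verbose `Callable[...]` annotations in `__init__` are collapsed into the `LengthFunctionType` and `TruncateFunctionType` aliases defined above. A self-contained sketch of the pattern, using a hypothetical `make_splitter` stand-in rather than the project's class:

```python
from typing import Callable

LengthFunctionType = Callable[[str], int]
TruncateFunctionType = Callable[[str, int], str]


def make_splitter(
    length_function: LengthFunctionType = len,
    truncate_function: TruncateFunctionType = lambda s, n: s[:n],
) -> None:
    # The aliases keep each parameter annotation on a single line,
    # which is what the reformatted signature above relies on.
    ...
```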
32 changes: 11 additions & 21 deletions align_data/postprocess/postprocess.py
@@ -1,13 +1,13 @@
# %%
from collections import defaultdict, Counter
from dataclasses import dataclass
import jsonlines
from tqdm import tqdm
import logging
from path import Path

import pylab as plt

# import seaborn as sns
import seaborn as sns
import pandas as pd

logger = logging.getLogger(__name__)
@@ -24,30 +24,20 @@ class PostProcesser:
def __init__(self) -> None:
self.jsonl_list = sorted(self.jsonl_path.files("*.jsonl"))
self.source_list = [path.name.split(".jsonl")[0] for path in self.jsonl_list]
self.all_stats = defaultdict(Counter)

def compute_statistics(self) -> None:
self.all_stats = {key: {} for key in self.source_list}
for source_name, path in tqdm(zip(self.source_list, self.jsonl_list)):
with jsonlines.open(path) as reader:
for obj in reader:
self.all_stats[source_name]["num_entries"] = (
self.all_stats[source_name].get("num_entries", 0) + 1
)
self.all_stats[source_name]["num_tokens"] = self.all_stats[
source_name
].get("num_tokens", 0) + len(obj["text"].split())
self.all_stats[source_name]["num_chars"] = self.all_stats[
source_name
].get("num_chars", 0) + len(obj["text"])
self.all_stats[source_name]["num_words"] = self.all_stats[
source_name
].get("num_words", 0) + len(obj["text"].split())
self.all_stats[source_name]["num_sentences"] = self.all_stats[
source_name
].get("num_sentences", 0) + len(obj["text"].split("."))
self.all_stats[source_name]["num_paragraphs"] = self.all_stats[
source_name
].get("num_paragraphs", 0) + len(obj["text"].splitlines())
text = obj['text']
source_stats = self.all_stats[source_name]
source_stats["num_entries"] += 1
source_stats["num_tokens"] += len(text.split()) # TODO: Use tokenizer
source_stats["num_chars"] += len(text)
source_stats["num_words"] += len(text.split())
source_stats["num_sentences"] += len(text.split(".")) # TODO: Use NLTK/Spacy or similar
source_stats["num_paragraphs"] += len(text.splitlines())

def plot_statistics(self) -> None:
all_df = pd.DataFrame(self.all_stats).T
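The per-source statistics now accumulate into a `defaultdict(Counter)`, which is why the repeated `.get(key, 0)` bookkeeping could be dropped: missing sources and missing stat names both default to zero. A minimal sketch with made-up data:

```python
from collections import Counter, defaultdict

all_stats = defaultdict(Counter)
for source_name, text in [("arxiv", "Two words."), ("arxiv", "More text here.")]:
    source_stats = all_stats[source_name]  # a fresh Counter on first access
    source_stats["num_entries"] += 1       # absent keys start at 0
    source_stats["num_words"] += len(text.split())

print(dict(all_stats["arxiv"]))  # {'num_entries': 2, 'num_words': 5}
```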
6 changes: 2 additions & 4 deletions align_data/settings.py
@@ -9,7 +9,7 @@
ON_SITE_TABLE = os.environ.get("CODA_ON_SITE_TABLE", "table-aOTSHIz_mN")

### GOOGLE DRIVE ###
PDFS_FOLDER_ID = os.environ.get("PDF_FOLDER_ID", "1etWiXPRl0QqdgYzivVIj6wCv9xj5VYoN")
PDFS_FOLDER_ID = os.environ.get("PDFS_FOLDER_ID", "1etWiXPRl0QqdgYzivVIj6wCv9xj5VYoN")

### GOOGLE SHEETS ###
METADATA_SOURCE_SPREADSHEET = os.environ.get(
@@ -41,9 +41,7 @@
OPENAI_EMBEDDINGS_DIMS = 1536
OPENAI_EMBEDDINGS_RATE_LIMIT = 3500

SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = (
"sentence-transformers/multi-qa-mpnet-base-cos-v1"
)
SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1"
SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768

### PINECONE ###
4 changes: 2 additions & 2 deletions align_data/sources/articles/articles.py
@@ -124,8 +124,8 @@ def check_new_articles(source_spreadsheet, source_sheet):
seen_urls = {
url
for item in current.values()
for url in [item.get("url"), item.get("source_url")]
if url
for key in ("url", "source_url")
if (url := item.get(key)) is not None
}

indices_items = fetch_all()
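The `seen_urls` comprehension now uses an assignment expression so each `item.get(key)` is evaluated once and filtered in the same pass (note the filter also tightens from a truthiness check to `is not None`, so empty strings would now be kept). A standalone sketch with invented records:

```python
current = {
    1: {"url": "https://example.org/a", "source_url": None},
    2: {"source_url": "https://example.org/b"},
}

seen_urls = {
    url
    for item in current.values()
    for key in ("url", "source_url")
    if (url := item.get(key)) is not None
}

print(seen_urls)  # {'https://example.org/a', 'https://example.org/b'}
```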
15 changes: 8 additions & 7 deletions align_data/sources/articles/datasets.py
@@ -39,14 +39,15 @@ def is_val(val):

@property
def items_list(self):
logger.info(
f"Fetching https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=CS&gid={self.sheet_id}"
)
df = pd.read_csv(
f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}"
)
fetch_url = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}"
log_url = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=CS&gid={self.sheet_id}"

logger.info(f"Fetching {log_url}")

sheet_data = pd.read_csv(fetch_url)

return (
item for item in df.itertuples() if not pd.isna(self.get_item_key(item))
item for item in sheet_data.itertuples() if not pd.isna(self.get_item_key(item))
)

def get_item_key(self, item):
4 changes: 1 addition & 3 deletions align_data/sources/articles/indices.py
@@ -97,9 +97,7 @@ def format_far_ai(item):
return {
"title": get_text(item, ".article-title"),
"url": f'https://www.safe.ai/research{item.select_one(".article-title a").get("href")}',
"source_url": item.select_one('div.btn-links a:-soup-contains("PDF")').get(
"href"
),
"source_url": item.select_one('div.btn-links a:-soup-contains("PDF")').get("href"),
"authors": ", ".join(i.text for i in item.select(".article-metadata a")),
}

11 changes: 6 additions & 5 deletions align_data/sources/articles/pdf.py
@@ -98,8 +98,11 @@ def get_arxiv_link(doi):
return None

vals = [
i for i in response.json().get("values") if i.get("type", "").upper() == "URL"
val
for val in response.json().get("values")
if val.get("type", "").upper() == "URL"
]

if not vals:
return None
return vals[0]["data"]["value"].replace("/abs/", "/pdf/") + ".pdf"
@@ -197,10 +200,8 @@ def get_first_child(item):
date_published = date_published.text.strip("()")

text = "\n\n".join(
[
MarkdownConverter().convert_soup(elem).strip()
for elem in contents.select("section.ltx_section")
]
MarkdownConverter().convert_soup(elem).strip()
for elem in contents.select("section.ltx_section")
)

return {
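Both hunks above are pure restructurings: the `vals` filter gains a descriptive loop variable, and the list brackets around the `join` argument are dropped because `str.join` accepts a generator expression directly. A trivial sketch of the latter with placeholder strings:

```python
sections = ["  First section.  ", "Second section."]

# Same result with or without the surrounding list brackets;
# the brackets in the old code were only cosmetic.
text = "\n\n".join(s.strip() for s in sections)
print(text)
```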
36 changes: 17 additions & 19 deletions align_data/sources/arxiv_papers/arxiv_papers.py
@@ -56,24 +56,22 @@ def process_entry(self, item) -> None:
else:
authors = paper.get("authors") or []
authors = [str(a).strip() for a in authors]

return self.make_data_entry(
{
"url": self.get_item_key(item),
"source": self.name,
"source_type": paper["data_source"],
"title": self.is_val(item.title) or paper.get("title"),
"authors": authors,
"date_published": self._get_published_date(
self.is_val(item.date_published) or paper.get("date_published")
),
"data_last_modified": str(metadata.updated),
"summary": metadata.summary.replace("\n", " "),
"author_comment": metadata.comment,
"journal_ref": metadata.journal_ref,
"doi": metadata.doi,
"primary_category": metadata.primary_category,
"categories": metadata.categories,
"text": paper["text"],
}
url=self.get_item_key(item),
source=self.name,
source_type=paper["data_source"],
title=self.is_val(item.title) or paper.get("title"),
authors=authors,
date_published=self._get_published_date(
self.is_val(item.date_published) or paper.get("date_published")
),
data_last_modified=str(metadata.updated),
summary=metadata.summary.replace("\n", " "),
author_comment=metadata.comment,
journal_ref=metadata.journal_ref,
doi=metadata.doi,
primary_category=metadata.primary_category,
categories=metadata.categories,
text=paper["text"],
)
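The `make_data_entry` call now passes keyword arguments instead of a single dict literal. A hypothetical stand-in (not the project's implementation) showing that the keyword form builds the same mapping as the old dict form, with placeholder values:

```python
def make_data_entry(**fields):
    # Stand-in that only collects the fields it is given.
    return dict(fields)

as_dict = make_data_entry(**{"url": "https://example.org/paper", "source": "arxiv"})
as_kwargs = make_data_entry(url="https://example.org/paper", source="arxiv")
assert as_dict == as_kwargs
```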
11 changes: 6 additions & 5 deletions align_data/sources/blogs/blogs.py
@@ -74,15 +74,16 @@ def _get_text(self, contents):
return item_metadata(paper_link.get("href")).get("text")

def extract_authors(self, article):
authors = article.select_one(
'div:-soup-contains("Authors") + div .f-body-1'
) or article.select_one('div:-soup-contains("Acknowledgments") + div .f-body-1')
if not authors:
author_selector = 'div:-soup-contains("Authors") + div .f-body-1'
ack_selector = 'div:-soup-contains("Acknowledgments") + div .f-body-1'

authors_div = article.select_one(author_selector) or article.select_one(ack_selector)
if not authors_div:
return []

return [
i.split("(")[0].strip()
for i in authors.select_one("p").children
for i in authors_div.select_one("p").children
if not i.name
]

