Merge remote-tracking branch 'origin' into finetune-embeddings
henri123lemoine committed Aug 16, 2023
2 parents 1e444cd + 09e19eb commit 03d0d54
Showing 28 changed files with 967 additions and 412 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/update-metadata.yml
@@ -0,0 +1,34 @@
name: Update metadata
on:
workflow_dispatch:
inputs:
csv_url:
description: 'URL of CSV'
required: true
delimiter:
description: 'The column delimiter'
default: ','

jobs:
update:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'

- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Download CSV
id: download
run: curl -L "${{ inputs.csv_url }}" -o data.csv

- name: Run Script
run: python main.py update data.csv ${{ inputs.delimiter }}
5 changes: 1 addition & 4 deletions .github/workflows/upload-to-huggingface.yml
@@ -28,22 +28,19 @@ jobs:
- distill
- eaforum
- eleuther.ai
- gdocs
- generative.ink
- gwern_blog
- html_articles
- importai
- jsteinhardt_blog
- lesswrong
- markdown
- miri
- ml_safety_newsletter
- openai.research
- pdfs
- rob_miles_ai_safety
- special_docs
- vkrakovna_blog
- yudkowsky_blog
- xmls

uses: ./.github/workflows/push-dataset.yml
with:
2 changes: 0 additions & 2 deletions align_data/__init__.py
@@ -2,7 +2,6 @@
import align_data.sources.articles as articles
import align_data.sources.blogs as blogs
import align_data.sources.ebooks as ebooks
import align_data.sources.arxiv_papers as arxiv_papers
import align_data.sources.greaterwrong as greaterwrong
import align_data.sources.stampy as stampy
import align_data.sources.alignment_newsletter as alignment_newsletter
@@ -14,7 +13,6 @@
+ articles.ARTICLES_REGISTRY
+ blogs.BLOG_REGISTRY
+ ebooks.EBOOK_REGISTRY
+ arxiv_papers.ARXIV_REGISTRY
+ greaterwrong.GREATERWRONG_REGISTRY
+ stampy.STAMPY_REGISTRY
+ distill.DISTILL_REGISTRY
27 changes: 15 additions & 12 deletions align_data/common/alignment_dataset.py
Expand Up @@ -24,6 +24,9 @@
"title": None,
"url": None,
"authors": lambda: [],
"source_type": None,
"status": None,
"comments": None,
}

logger = logging.getLogger(__name__)
@@ -85,7 +88,7 @@ def make_data_entry(self, data, **kwargs) -> Article:

article = Article(
id_fields=self.id_fields,
meta={k: v for k, v in data.items() if k not in INIT_DICT},
meta={k: v for k, v in data.items() if k not in INIT_DICT and v is not None},
**{k: v for k, v in data.items() if k in INIT_DICT},
)
self._add_authors(article, authors)
@@ -106,13 +109,17 @@ def to_jsonl(self, out_path=None, filename=None) -> Path:
jsonl_writer.write(article.to_dict())
return filename.resolve()

@property
def _query_items(self):
return select(Article).where(Article.source == self.name)

def read_entries(self, sort_by=None):
"""Iterate through all the saved entries."""
with make_session() as session:
query = select(Article).where(Article.source == self.name)
query = self._query_items.options(joinedload(Article.summaries))
if sort_by is not None:
query = query.order_by(sort_by)
for item in session.scalars(query):
for item in session.scalars(query).unique():
yield item

def _add_batch(self, session, batch):
@@ -204,15 +211,16 @@ def fetch_entries(self):
if self.COOLDOWN:
time.sleep(self.COOLDOWN)

def process_entry(self, entry):
def process_entry(self, entry) -> Optional[Article]:
"""Process a single entry."""
raise NotImplementedError

@staticmethod
def _format_datetime(date) -> str:
return date.strftime("%Y-%m-%dT%H:%M:%SZ")

def _get_published_date(self, date) -> Optional[datetime]:
@staticmethod
def _get_published_date(date) -> Optional[datetime]:
try:
# Totally ignore any timezone info, forcing everything to UTC
return parse(str(date)).replace(tzinfo=pytz.UTC)
@@ -228,13 +236,8 @@ def unprocessed_items(self, items=None) -> Iterable:

urls = map(self.get_item_key, items)
with make_session() as session:
self.articles = {
a.url: a
for a in session.query(Article)
.options(joinedload(Article.summaries))
.filter(Article.url.in_(urls))
if a.url
}
articles = session.query(Article).options(joinedload(Article.summaries)).filter(Article.url.in_(urls))
self.articles = {a.url: a for a in articles if a.url}

return items
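
A minimal sketch of driving the reworked read_entries/_query_items path above. It assumes the MySQL database configured in align_data.settings is reachable and already populated; the ArxivPapers registration is copied from the articles registry later in this diff:

    from align_data.db.models import Article
    from align_data.sources.articles.datasets import ArxivPapers

    dataset = ArxivPapers(
        name="arxiv",
        spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI",
        sheet_id="655836697",
    )
    # Iterates the saved Article rows for this source, with summaries eagerly loaded.
    for article in dataset.read_entries(sort_by=Article.date_published):
        print(article.title, article.url)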

38 changes: 29 additions & 9 deletions align_data/db/models.py
@@ -11,11 +11,10 @@
String,
Boolean,
Text,
Float,
func,
event,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.dialects.mysql import LONGTEXT
from align_data.settings import PINECONE_METADATA_KEYS

@@ -58,6 +57,8 @@ class Article(Base):
date_updated: Mapped[Optional[datetime]] = mapped_column(
DateTime, onupdate=func.current_timestamp()
)
status: Mapped[Optional[str]] = mapped_column(String(256))
comments: Mapped[Optional[str]] = mapped_column(LONGTEXT) # Editor comments. Can be anything

pinecone_update_required: Mapped[bool] = mapped_column(Boolean, default=False)

@@ -90,9 +91,10 @@ def generate_id_string(self) -> str:
"utf-8"
)

def verify_fields(self):
missing = [field for field in self.__id_fields if not getattr(self, field)]
assert not missing, f"Entry is missing the following fields: {missing}"
@property
def missing_fields(self):
fields = set(self.__id_fields) | {'text', 'title', 'url', 'source', 'date_published'}
return sorted([field for field in fields if not getattr(self, field, None)])

def verify_id(self):
assert self.id is not None, "Entry is missing id"
@@ -101,13 +103,17 @@ def verify_id(self):
id_from_fields = hashlib.md5(id_string).hexdigest()
assert (
self.id == id_from_fields
), f"Entry id {self.id} does not match id from id_fields, {id_from_fields}"
), f"Entry id {self.id} does not match id from id_fields: {id_from_fields}"

def verify_id_fields(self):
missing = [field for field in self.__id_fields if not getattr(self, field)]
assert not missing, f"Entry is missing the following fields: {missing}"

def update(self, other):
for field in self.__table__.columns.keys():
if field not in ["id", "hash_id", "metadata"] and getattr(other, field):
setattr(self, field, getattr(other, field))
self.meta.update({k: v for k, v in other.meta.items() if k and v})
self.meta = dict((self.meta or {}), **{k: v for k, v in other.meta.items() if k and v})

if other._id:
self._id = other._id
@@ -118,14 +124,28 @@ def _set_id(self):
id_string = self.generate_id_string()
self.id = hashlib.md5(id_string).hexdigest()

def add_meta(self, key, val):
if self.meta is None:
self.meta = {}
self.meta[key] = val

@classmethod
def before_write(cls, mapper, connection, target):
target.verify_fields()
target.verify_id_fields()

if not target.status and target.missing_fields:
target.status = 'Missing fields'
target.comments = f'missing fields: {", ".join(target.missing_fields)}'

if target.id:
target.verify_id()
else:
target._set_id()

# This assumes that status pretty much just notes down that an entry is invalid. If it has
# all fields set and is being written to the database, then it must have been modified, ergo
# should be also updated in pinecone
if not target.status:
target.pinecone_update_required = True

def to_dict(self):
@@ -147,7 +167,7 @@ def to_dict(self):
"date_published": date,
"authors": authors,
"summaries": [s.text for s in (self.summaries or [])],
**(self.meta or {}),
**(meta or {}),
}
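
A rough sketch of the new missing_fields/add_meta helpers in isolation. The constructor call mirrors make_data_entry above; the field values and the chosen id_fields are made up for illustration:

    from align_data.db.models import Article

    article = Article(
        id_fields=["url", "title"],
        title="An example post",
        url="https://example.com/post",
        source="example_blog",
        meta={},
    )
    article.add_meta("note", "illustrative value only")  # stored in the JSON meta column
    print(article.missing_fields)  # e.g. ['date_published', 'text'] until those fields are set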


80 changes: 40 additions & 40 deletions align_data/pinecone/text_splitter.py
@@ -4,6 +4,15 @@
from langchain.text_splitter import TextSplitter
from nltk.tokenize import sent_tokenize

# TODO: Fix this.
# sent_tokenize has strange behavior sometimes: 'The units could be anything (characters, words, sentences, etc.), depending on how you want to chunk your text.'
# splits into ['The units could be anything (characters, words, sentences, etc.', '), depending on how you want to chunk your text.']

StrToIntFunction = Callable[[str], int]
StrIntBoolToStrFunction = Callable[[str, int, bool], str]

def default_truncate_function(string: str, length: int, from_end: bool = False) -> str:
return string[-length:] if from_end else string[:length]

def default_truncate_function(string: str, length: int, from_end: bool = False) -> str:
return string[-length:] if from_end else string[:length]
@@ -14,22 +23,22 @@ class ParagraphSentenceUnitTextSplitter(TextSplitter):
@param min_chunk_size: The minimum number of units in a chunk.
@param max_chunk_size: The maximum number of units in a chunk.
@param length_function: A function that returns the length of a string in units.
@param length_function: A function that returns the length of a string in units. Defaults to len().
@param truncate_function: A function that truncates a string to a given unit length.
"""

DEFAULT_MIN_CHUNK_SIZE = 900
DEFAULT_MAX_CHUNK_SIZE = 1100
DEFAULT_LENGTH_FUNCTION = lambda string: len(string)
DEFAULT_TRUNCATE_FUNCTION = default_truncate_function

DEFAULT_MIN_CHUNK_SIZE: int = 900
DEFAULT_MAX_CHUNK_SIZE: int = 1100
DEFAULT_LENGTH_FUNCTION: StrToIntFunction = len
DEFAULT_TRUNCATE_FUNCTION: StrIntBoolToStrFunction = default_truncate_function

def __init__(
self,
min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE,
max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE,
length_function: Callable[[str], int] = DEFAULT_LENGTH_FUNCTION,
truncate_function: Callable[[str, int], str] = DEFAULT_TRUNCATE_FUNCTION,
**kwargs: Any,
length_function: StrToIntFunction = DEFAULT_LENGTH_FUNCTION,
truncate_function: StrIntBoolToStrFunction = DEFAULT_TRUNCATE_FUNCTION,
**kwargs: Any
):
super().__init__(**kwargs)
self.min_chunk_size = min_chunk_size
@@ -39,8 +48,9 @@ def __init__(
self._truncate_function = truncate_function

def split_text(self, text: str) -> List[str]:
blocks = []
current_block = ""
"""Split text into chunks of length between min_chunk_size and max_chunk_size."""
blocks: List[str] = []
current_block: str = ""

paragraphs = text.split("\n\n")
for paragraph in paragraphs:
@@ -56,10 +66,9 @@
continue

blocks = self._handle_remaining_text(current_block, blocks)

return [block.strip() for block in blocks]

def _handle_large_paragraph(self, current_block, blocks, paragraph):
def _handle_large_paragraph(self, current_block: str, blocks: List[str], paragraph: str) -> str:
# Undo adding the whole paragraph
offset = len(paragraph) + 2 # +2 accounts for "\n\n"
current_block = current_block[:-offset]
@@ -75,44 +84,35 @@ def _handle_large_paragraph(self, current_block, blocks, paragraph):
blocks.append(current_block)
current_block = ""
else:
current_block = self._truncate_large_block(
current_block, blocks, sentence
)

current_block = self._truncate_large_block(current_block, blocks)
return current_block

def _truncate_large_block(self, current_block, blocks, sentence):
def _truncate_large_block(self, current_block: str, blocks: List[str]) -> str:
while self._length_function(current_block) > self.max_chunk_size:
# Truncate current_block to max size, set remaining sentence as next sentence
# Truncate current_block to max size, set remaining text as current_block
truncated_block = self._truncate_function(
current_block, self.max_chunk_size
current_block, self.max_chunk_size
)
blocks.append(truncated_block)

remaining_sentence = current_block[len(truncated_block) :].lstrip()
current_block = sentence = remaining_sentence

current_block = current_block[len(truncated_block):].lstrip()

return current_block

def _handle_remaining_text(self, current_block, blocks):
def _handle_remaining_text(self, last_block: str, blocks: List[str]) -> List[str]:
if blocks == []: # no blocks were added
return [current_block]
elif current_block: # any leftover text
len_current_block = self._length_function(current_block)
if len_current_block < self.min_chunk_size:
# it needs to take the last min_chunk_size-len_current_block units from the previous block
previous_block = blocks[-1]
required_units = (
self.min_chunk_size - len_current_block
) # calculate the required units

return [last_block]
elif last_block: # any leftover text
len_last_block = self._length_function(last_block)
if self.min_chunk_size - len_last_block > 0:
# Add text from previous block to last block if last_block is too short
part_prev_block = self._truncate_function(
previous_block, required_units, from_end=True
) # get the required units from the previous block
last_block = part_prev_block + current_block
string=blocks[-1],
length=self.min_chunk_size - len_last_block,
from_end=True
)
last_block = part_prev_block + last_block

blocks.append(last_block)
else:
blocks.append(current_block)
blocks.append(last_block)

return blocks
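
A usage sketch of the splitter as refactored above, assuming the import path from the file header and that nltk's punkt tokenizer data is available; the sample text is illustrative and the chunk sizes are the class defaults:

    import nltk
    from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter

    nltk.download("punkt", quiet=True)  # sent_tokenize above relies on the punkt model

    # Five paragraphs of ~1,400 characters each, so both the large-paragraph and
    # sentence-splitting branches get exercised.
    text = "\n\n".join(["A short example sentence. " * 55] * 5)

    splitter = ParagraphSentenceUnitTextSplitter(
        min_chunk_size=900,
        max_chunk_size=1100,
        length_function=len,  # chunk lengths measured in characters
    )
    for chunk in splitter.split_text(text):
        print(len(chunk))
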
8 changes: 7 additions & 1 deletion align_data/sources/articles/__init__.py
@@ -1,5 +1,6 @@
from align_data.sources.articles.datasets import (
EbookArticles, DocArticles, HTMLArticles, MarkdownArticles, PDFArticles, SpecialDocs, XMLArticles
ArxivPapers, EbookArticles, DocArticles, HTMLArticles,
MarkdownArticles, PDFArticles, SpecialDocs, XMLArticles
)
from align_data.sources.articles.indices import IndicesDataset

@@ -39,5 +40,10 @@
spreadsheet_id='1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI',
sheet_id='980957638',
),
ArxivPapers(
name="arxiv",
spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI",
sheet_id="655836697",
),
IndicesDataset('indices'),
]