Merge pull request #106 from StampyAI/align-data-structure
Restructures align_data by moving the data sources into a sources/ folder.
henri123lemoine authored Jul 25, 2023
2 parents e33e607 + b9a735f commit 0fbb8fb
Showing 52 changed files with 482 additions and 40 deletions.
10 changes: 10 additions & 0 deletions .env.example
@@ -0,0 +1,10 @@
CODA_TOKEN=""
ARD_DB_USER="user"
ARD_DB_PASSWORD="we all live in a yellow submarine"
ARD_DB_HOST="127.0.0.1"
ARD_DB_PORT="3306"
ARD_DB_NAME="alignment_research_dataset"
OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
PINECONE_INDEX_NAME="stampy-chat-ard"
PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
PINECONE_ENVIRONMENT="xx-xxxxx-gcp"
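
These variables are presumably consumed by align_data/settings.py, from which pinecone_db_handler.py below imports its PINECONE_* constants. The following is a minimal sketch of how such a settings module might read them — the os.environ pattern and the concrete PINECONE_VALUES_DIMS, PINECONE_METRIC and PINECONE_METADATA_ENTRIES values are assumptions for illustration, not part of this commit.

# Hypothetical sketch of align_data/settings.py; defaults and the last three values are assumptions.
import os

CODA_TOKEN = os.environ.get("CODA_TOKEN", "")

ARD_DB_USER = os.environ.get("ARD_DB_USER", "user")
ARD_DB_PASSWORD = os.environ.get("ARD_DB_PASSWORD", "")
ARD_DB_HOST = os.environ.get("ARD_DB_HOST", "127.0.0.1")
ARD_DB_PORT = os.environ.get("ARD_DB_PORT", "3306")
ARD_DB_NAME = os.environ.get("ARD_DB_NAME", "alignment_research_dataset")

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME", "stampy-chat-ard")
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT")

# Imported by pinecone_db_handler.py below; the concrete values here are placeholders.
PINECONE_VALUES_DIMS = 1536  # e.g. the dimensionality of the embedding model in use
PINECONE_METRIC = "dotproduct"
PINECONE_METADATA_ENTRIES = ["entry_id", "source", "title", "authors", "text"]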
24 changes: 12 additions & 12 deletions align_data/__init__.py
@@ -1,15 +1,15 @@
-import align_data.arbital as arbital
-import align_data.articles as articles
-import align_data.blogs as blogs
-import align_data.ebooks as ebooks
-import align_data.arxiv_papers as arxiv_papers
-import align_data.reports as reports
-import align_data.greaterwrong as greaterwrong
-import align_data.stampy as stampy
-import align_data.audio_transcripts as audio_transcripts
-import align_data.alignment_newsletter as alignment_newsletter
-import align_data.distill as distill
-import align_data.gdocs as gdocs
+import align_data.sources.arbital as arbital
+import align_data.sources.articles as articles
+import align_data.sources.blogs as blogs
+import align_data.sources.ebooks as ebooks
+import align_data.sources.arxiv_papers as arxiv_papers
+import align_data.sources.reports as reports
+import align_data.sources.greaterwrong as greaterwrong
+import align_data.sources.stampy as stampy
+import align_data.sources.audio_transcripts as audio_transcripts
+import align_data.sources.alignment_newsletter as alignment_newsletter
+import align_data.sources.distill as distill
+import align_data.sources.gdocs as gdocs

DATASET_REGISTRY = (
arbital.ARBITAL_REGISTRY
…
Empty file added align_data/pinecone/__init__.py
106 changes: 106 additions & 0 deletions align_data/pinecone/pinecone_db_handler.py
@@ -0,0 +1,106 @@
# dataset/pinecone_db_handler.py

import pinecone

from align_data.settings import PINECONE_INDEX_NAME, PINECONE_VALUES_DIMS, PINECONE_METRIC, PINECONE_METADATA_ENTRIES, PINECONE_API_KEY, PINECONE_ENVIRONMENT

import logging
logger = logging.getLogger(__name__)


class PineconeDB:
    def __init__(
        self,
        index_name: str = PINECONE_INDEX_NAME,
        values_dims: int = PINECONE_VALUES_DIMS,
        metric: str = PINECONE_METRIC,
        metadata_entries: list = PINECONE_METADATA_ENTRIES,
        create_index: bool = False,
        log_index_stats: bool = True,
    ):
        self.index_name = index_name
        self.values_dims = values_dims
        self.metric = metric
        self.metadata_entries = metadata_entries

        pinecone.init(
            api_key=PINECONE_API_KEY,
            environment=PINECONE_ENVIRONMENT,
        )

        if create_index:
            self.create_index()

        self.index = pinecone.Index(index_name=self.index_name)

        if log_index_stats:
            index_stats_response = self.index.describe_index_stats()
            logger.info(f"{self.index_name}:\n{index_stats_response}")

    def upsert_entry(self, entry, chunks, embeddings, upsert_size=100):
        self.index.upsert(
            vectors=list(
                zip(
                    [f"{entry['id']}_{str(i).zfill(6)}" for i in range(len(chunks))],
                    embeddings.tolist(),
                    [
                        {
                            'entry_id': entry['id'],
                            'source': entry['source'],
                            'title': entry['title'],
                            'authors': entry['authors'],
                            'text': chunk,
                        } for chunk in chunks
                    ]
                )
            ),
            batch_size=upsert_size
        )

    def upsert_entries(self, entries_batch, chunks_batch, chunks_ids_batch, embeddings, upsert_size=100):
        self.index.upsert(
            vectors=list(
                zip(
                    chunks_ids_batch,
                    embeddings.tolist(),
                    [
                        {
                            'entry_id': entry['id'],
                            'source': entry['source'],
                            'title': entry['title'],
                            'authors': entry['authors'],
                            'text': chunk,
                        }
                        for entry in entries_batch
                        for chunk in chunks_batch
                    ]
                )
            ),
            batch_size=upsert_size
        )

    def delete_entry(self, id):
        self.index.delete(
            filter={"entry_id": {"$eq": id}}
        )

    def delete_entries(self, ids):
        self.index.delete(
            filter={"entry_id": {"$in": ids}}
        )

    def create_index(self, replace_current_index: bool = True):
        if replace_current_index:
            self.delete_index()

        pinecone.create_index(
            name=self.index_name,
            dimension=self.values_dims,
            metric=self.metric,
            metadata_config={"indexed": self.metadata_entries},
        )

    def delete_index(self):
        if self.index_name in pinecone.list_indexes():
            logger.info(f"Deleting index '{self.index_name}'.")
            pinecone.delete_index(self.index_name)
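
A hedged usage sketch of PineconeDB as defined above. The entry dict, chunk list, and zero-valued embeddings array are placeholders, and the 1536-column shape is an assumption — it must match PINECONE_VALUES_DIMS and the embedding model actually used.

# Hypothetical usage sketch; entry, chunks, and the embeddings array are placeholders.
import numpy as np

from align_data.pinecone.pinecone_db_handler import PineconeDB

db = PineconeDB(create_index=False)  # pass create_index=True to (re)create the index first

entry = {
    "id": "abc123",
    "source": "example-blog",
    "title": "An example post",
    "authors": "Jane Doe",
}
chunks = ["first chunk of text...", "second chunk of text..."]

# Embeddings must be array-like with .tolist(): one row per chunk,
# each row PINECONE_VALUES_DIMS wide (1536 here is an assumption).
embeddings = np.zeros((len(chunks), 1536), dtype=np.float32)

db.upsert_entry(entry, chunks, embeddings)

# Remove all vectors belonging to that entry again (metadata-filtered delete).
db.delete_entry(entry["id"])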
102 changes: 102 additions & 0 deletions align_data/pinecone/text_splitter.py
@@ -0,0 +1,102 @@
# dataset/text_splitter.py

from typing import List, Callable, Any
from langchain.text_splitter import TextSplitter
from nltk.tokenize import sent_tokenize


class ParagraphSentenceUnitTextSplitter(TextSplitter):
    """A custom TextSplitter that breaks text by paragraphs, sentences, and then units (chars/words/tokens/etc).

    @param min_chunk_size: The minimum number of units in a chunk.
    @param max_chunk_size: The maximum number of units in a chunk.
    @param length_function: A function that returns the length of a string in units.
    @param truncate_function: A function that truncates a string to a given unit length.
    """

    DEFAULT_MIN_CHUNK_SIZE = 900
    DEFAULT_MAX_CHUNK_SIZE = 1100
    DEFAULT_TRUNCATE_FUNCTION = lambda string, length, from_end=False: string[-length:] if from_end else string[:length]

    def __init__(
        self,
        min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE,
        max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE,
        truncate_function: Callable[[str, int], str] = DEFAULT_TRUNCATE_FUNCTION,
        **kwargs: Any
    ):
        super().__init__(**kwargs)
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size

        self._truncate_function = truncate_function

    def split_text(self, text: str) -> List[str]:
        blocks = []
        current_block = ""

        paragraphs = text.split("\n\n")
        for paragraph in paragraphs:
            current_block += "\n\n" + paragraph
            block_length = self._length_function(current_block)

            if block_length > self.max_chunk_size:  # current block is too large, truncate it
                current_block = self._handle_large_paragraph(current_block, blocks, paragraph)
            elif block_length >= self.min_chunk_size:
                blocks.append(current_block)
                current_block = ""
            else:  # current block is too small, continue appending to it
                continue

        blocks = self._handle_remaining_text(current_block, blocks)

        return [block.strip() for block in blocks]

    def _handle_large_paragraph(self, current_block, blocks, paragraph):
        # Undo adding the whole paragraph
        current_block = current_block[:-(len(paragraph) + 2)]  # +2 accounts for "\n\n"

        sentences = sent_tokenize(paragraph)
        for sentence in sentences:
            current_block += f" {sentence}"

            block_length = self._length_function(current_block)
            if block_length < self.min_chunk_size:
                continue
            elif block_length <= self.max_chunk_size:
                blocks.append(current_block)
                current_block = ""
            else:
                current_block = self._truncate_large_block(current_block, blocks, sentence)

        return current_block

    def _truncate_large_block(self, current_block, blocks, sentence):
        while self._length_function(current_block) > self.max_chunk_size:
            # Truncate current_block to max size, set remaining sentence as next sentence
            truncated_block = self._truncate_function(current_block, self.max_chunk_size)
            blocks.append(truncated_block)

            remaining_sentence = current_block[len(truncated_block):].lstrip()
            current_block = sentence = remaining_sentence

        return current_block

    def _handle_remaining_text(self, current_block, blocks):
        if blocks == []:  # no blocks were added
            return [current_block]
        elif current_block:  # any leftover text
            len_current_block = self._length_function(current_block)
            if len_current_block < self.min_chunk_size:
                # it needs to take the last min_chunk_size - len_current_block units from the previous block
                previous_block = blocks[-1]
                required_units = self.min_chunk_size - len_current_block  # calculate the required units

                part_prev_block = self._truncate_function(previous_block, required_units, from_end=True)  # get the required units from the previous block
                last_block = part_prev_block + current_block

                blocks.append(last_block)
            else:
                blocks.append(current_block)

        return blocks
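
A hedged usage sketch of the splitter above. It relies on the base langchain TextSplitter accepting a length_function kwarg (exposed as self._length_function, which split_text uses); counting units as characters here is an illustrative choice, and a token-based counter could be passed instead.

# Hypothetical usage sketch; requires the NLTK "punkt" tokenizer data for sent_tokenize.
from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter

splitter = ParagraphSentenceUnitTextSplitter(
    min_chunk_size=900,
    max_chunk_size=1100,
    length_function=len,  # count units as characters for this example
)

long_text = "First paragraph.\n\nSecond paragraph.\n\n" + "A sentence. " * 500
chunks = splitter.split_text(long_text)
print(len(chunks), [len(c) for c in chunks[:3]])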