Merge pull request #106 from StampyAI/align-data-structure
Restructures align_data by putting sources in a sources folder
Showing 52 changed files with 482 additions and 40 deletions.
@@ -0,0 +1,10 @@

```
CODA_TOKEN=""
ARD_DB_USER="user"
ARD_DB_PASSWORD="we all live in a yellow submarine"
ARD_DB_HOST="127.0.0.1"
ARD_DB_PORT="3306"
ARD_DB_NAME="alignment_research_dataset"
OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
PINECONE_INDEX_NAME="stampy-chat-ard"
PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
PINECONE_ENVIRONMENT="xx-xxxxx-gcp"
```
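For context, the Pinecone handler below imports these settings from `align_data.settings`, which is not shown in this diff. A minimal sketch of how such a module could load them, assuming python-dotenv; the concrete values for the constants not present in the env template are placeholders, not taken from this commit:

```python
# Hypothetical sketch of align_data/settings.py; the real module is not in this diff.
import os

from dotenv import load_dotenv

load_dotenv()  # read a local .env file into the process environment

PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME", "stampy-chat-ard")
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", "")
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", "")

# Assumed values for the remaining constants the handler imports:
PINECONE_VALUES_DIMS = 1536  # e.g. the width of OpenAI ada-002 embeddings
PINECONE_METRIC = "dotproduct"
PINECONE_METADATA_ENTRIES = ["entry_id", "source", "title", "authors", "text"]
```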
Empty file.
@@ -0,0 +1,106 @@

```python
# dataset/pinecone_db_handler.py
import logging

import pinecone

from align_data.settings import (
    PINECONE_INDEX_NAME,
    PINECONE_VALUES_DIMS,
    PINECONE_METRIC,
    PINECONE_METADATA_ENTRIES,
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
)

logger = logging.getLogger(__name__)


class PineconeDB:
    def __init__(
        self,
        index_name: str = PINECONE_INDEX_NAME,
        values_dims: int = PINECONE_VALUES_DIMS,
        metric: str = PINECONE_METRIC,
        metadata_entries: list = PINECONE_METADATA_ENTRIES,
        create_index: bool = False,
        log_index_stats: bool = True,
    ):
        self.index_name = index_name
        self.values_dims = values_dims
        self.metric = metric
        self.metadata_entries = metadata_entries

        pinecone.init(
            api_key=PINECONE_API_KEY,
            environment=PINECONE_ENVIRONMENT,
        )

        if create_index:
            self.create_index()

        self.index = pinecone.Index(index_name=self.index_name)

        if log_index_stats:
            index_stats_response = self.index.describe_index_stats()
            logger.info(f"{self.index_name}:\n{index_stats_response}")

    def upsert_entry(self, entry, chunks, embeddings, upsert_size=100):
        # One vector per chunk: the id is "<entry_id>_<chunk index>", zero-padded
        # so the ids sort in chunk order.
        self.index.upsert(
            vectors=list(
                zip(
                    [f"{entry['id']}_{str(i).zfill(6)}" for i in range(len(chunks))],
                    embeddings.tolist(),
                    [
                        {
                            'entry_id': entry['id'],
                            'source': entry['source'],
                            'title': entry['title'],
                            'authors': entry['authors'],
                            'text': chunk,
                        }
                        for chunk in chunks
                    ],
                )
            ),
            batch_size=upsert_size,
        )

    def upsert_entries(self, entries_batch, chunks_batch, chunks_ids_batch, embeddings, upsert_size=100):
        # Expects parallel lists: entries_batch[i] is the parent entry of
        # chunks_batch[i], so the ids, embeddings, and metadata stay aligned.
        self.index.upsert(
            vectors=list(
                zip(
                    chunks_ids_batch,
                    embeddings.tolist(),
                    [
                        {
                            'entry_id': entry['id'],
                            'source': entry['source'],
                            'title': entry['title'],
                            'authors': entry['authors'],
                            'text': chunk,
                        }
                        for entry, chunk in zip(entries_batch, chunks_batch)
                    ],
                )
            ),
            batch_size=upsert_size,
        )

    def delete_entry(self, id):
        self.index.delete(
            filter={"entry_id": {"$eq": id}}
        )

    def delete_entries(self, ids):
        self.index.delete(
            filter={"entry_id": {"$in": ids}}
        )

    def create_index(self, replace_current_index: bool = True):
        if replace_current_index:
            self.delete_index()

        pinecone.create_index(
            name=self.index_name,
            dimension=self.values_dims,
            metric=self.metric,
            metadata_config={"indexed": self.metadata_entries},
        )

    def delete_index(self):
        if self.index_name in pinecone.list_indexes():
            logger.info(f"Deleting index '{self.index_name}'.")
            pinecone.delete_index(self.index_name)
```
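For orientation, a usage sketch of the handler above. The entry dict shape follows the metadata fields the class writes; the import path and embedding dimensionality are assumptions, not part of this commit:

```python
import numpy as np

from align_data.pinecone.pinecone_db_handler import PineconeDB  # path assumed

db = PineconeDB()  # connects to the existing index; pass create_index=True to rebuild it

entry = {
    "id": "abc123",
    "source": "alignmentforum",
    "title": "An example post",
    "authors": "A. Author",
}
chunks = ["First chunk of the post...", "Second chunk of the post..."]
embeddings = np.random.rand(len(chunks), 1536)  # stand-in for real embedding vectors

db.upsert_entry(entry, chunks, embeddings)  # ids become "abc123_000000", "abc123_000001"
db.delete_entry("abc123")                   # removes every vector whose entry_id matches
```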
@@ -0,0 +1,102 @@

```python
# dataset/text_splitter.py
from typing import Any, Callable, List

from langchain.text_splitter import TextSplitter
from nltk.tokenize import sent_tokenize


class ParagraphSentenceUnitTextSplitter(TextSplitter):
    """A custom TextSplitter that breaks text by paragraphs, then sentences, then units (chars/words/tokens/etc.).

    @param min_chunk_size: The minimum number of units in a chunk.
    @param max_chunk_size: The maximum number of units in a chunk.
    @param length_function: A function that returns the length of a string in units.
    @param truncate_function: A function that truncates a string to a given unit length.
    """

    DEFAULT_MIN_CHUNK_SIZE = 900
    DEFAULT_MAX_CHUNK_SIZE = 1100
    DEFAULT_TRUNCATE_FUNCTION = lambda string, length, from_end=False: string[-length:] if from_end else string[:length]

    def __init__(
        self,
        min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE,
        max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE,
        truncate_function: Callable[[str, int], str] = DEFAULT_TRUNCATE_FUNCTION,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size

        self._truncate_function = truncate_function

    def split_text(self, text: str) -> List[str]:
        blocks = []
        current_block = ""

        paragraphs = text.split("\n\n")
        for paragraph in paragraphs:
            current_block += "\n\n" + paragraph
            block_length = self._length_function(current_block)

            if block_length > self.max_chunk_size:  # current block is too large, truncate it
                current_block = self._handle_large_paragraph(current_block, blocks, paragraph)
            elif block_length >= self.min_chunk_size:
                blocks.append(current_block)
                current_block = ""
            else:  # current block is too small, keep appending to it
                continue

        blocks = self._handle_remaining_text(current_block, blocks)

        return [block.strip() for block in blocks]

    def _handle_large_paragraph(self, current_block, blocks, paragraph):
        # Undo adding the whole paragraph; +2 accounts for the "\n\n" separator.
        current_block = current_block[:-(len(paragraph) + 2)]

        sentences = sent_tokenize(paragraph)
        for sentence in sentences:
            current_block += f" {sentence}"

            block_length = self._length_function(current_block)
            if block_length < self.min_chunk_size:
                continue
            elif block_length <= self.max_chunk_size:
                blocks.append(current_block)
                current_block = ""
            else:
                current_block = self._truncate_large_block(current_block, blocks)

        return current_block

    def _truncate_large_block(self, current_block, blocks):
        while self._length_function(current_block) > self.max_chunk_size:
            # Truncate current_block to the max size and carry the cut-off
            # remainder over as the start of the next block.
            truncated_block = self._truncate_function(current_block, self.max_chunk_size)
            blocks.append(truncated_block)

            current_block = current_block[len(truncated_block):].lstrip()

        return current_block

    def _handle_remaining_text(self, current_block, blocks):
        if blocks == []:  # no blocks were added
            return [current_block]
        elif current_block:  # any leftover text
            len_current_block = self._length_function(current_block)
            if len_current_block < self.min_chunk_size:
                # Pad the leftover text with the tail of the previous block so
                # the final chunk still reaches min_chunk_size units.
                previous_block = blocks[-1]
                required_units = self.min_chunk_size - len_current_block

                part_prev_block = self._truncate_function(previous_block, required_units, from_end=True)
                last_block = part_prev_block + current_block

                blocks.append(last_block)
            else:
                blocks.append(current_block)

        return blocks
```
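A quick usage sketch, assuming plain character counts as the unit (langchain's `TextSplitter` accepts a `length_function` kwarg, which this subclass forwards via `**kwargs`) and that NLTK's punkt data is available:

```python
import nltk

nltk.download("punkt")  # sent_tokenize needs the punkt sentence model

splitter = ParagraphSentenceUnitTextSplitter(
    min_chunk_size=900,
    max_chunk_size=1100,
    length_function=len,  # measure chunks in characters
)

text = "\n\n".join("Sentence one. Sentence two. Sentence three. " * 10 for _ in range(8))
chunks = splitter.split_text(text)
print([len(chunk) for chunk in chunks])  # lengths should cluster around 900-1100
```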