diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..7636d2da --- /dev/null +++ b/.env.example @@ -0,0 +1,10 @@ +CODA_TOKEN="" +ARD_DB_USER="user" +ARD_DB_PASSWORD="we all live in a yellow submarine" +ARD_DB_HOST="127.0.0.1" +ARD_DB_PORT="3306" +ARD_DB_NAME="alignment_research_dataset" +OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" +PINECONE_INDEX_NAME="stampy-chat-ard" +PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" +PINECONE_ENVIRONMENT="xx-xxxxx-gcp" diff --git a/align_data/__init__.py b/align_data/__init__.py index 7c2b55e2..4028ab5b 100644 --- a/align_data/__init__.py +++ b/align_data/__init__.py @@ -1,15 +1,15 @@ -import align_data.arbital as arbital -import align_data.articles as articles -import align_data.blogs as blogs -import align_data.ebooks as ebooks -import align_data.arxiv_papers as arxiv_papers -import align_data.reports as reports -import align_data.greaterwrong as greaterwrong -import align_data.stampy as stampy -import align_data.audio_transcripts as audio_transcripts -import align_data.alignment_newsletter as alignment_newsletter -import align_data.distill as distill -import align_data.gdocs as gdocs +import align_data.sources.arbital as arbital +import align_data.sources.articles as articles +import align_data.sources.blogs as blogs +import align_data.sources.ebooks as ebooks +import align_data.sources.arxiv_papers as arxiv_papers +import align_data.sources.reports as reports +import align_data.sources.greaterwrong as greaterwrong +import align_data.sources.stampy as stampy +import align_data.sources.audio_transcripts as audio_transcripts +import align_data.sources.alignment_newsletter as alignment_newsletter +import align_data.sources.distill as distill +import align_data.sources.gdocs as gdocs DATASET_REGISTRY = ( arbital.ARBITAL_REGISTRY diff --git a/align_data/pinecone/__init__.py b/align_data/pinecone/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/align_data/pinecone/pinecone_db_handler.py b/align_data/pinecone/pinecone_db_handler.py new file mode 100644 index 00000000..4ee2ce7e --- /dev/null +++ b/align_data/pinecone/pinecone_db_handler.py @@ -0,0 +1,106 @@ +# dataset/pinecone_db_handler.py + +import pinecone + +from align_data.settings import PINECONE_INDEX_NAME, PINECONE_VALUES_DIMS, PINECONE_METRIC, PINECONE_METADATA_ENTRIES, PINECONE_API_KEY, PINECONE_ENVIRONMENT + +import logging +logger = logging.getLogger(__name__) + + +class PineconeDB: + def __init__( + self, + index_name: str = PINECONE_INDEX_NAME, + values_dims: int = PINECONE_VALUES_DIMS, + metric: str = PINECONE_METRIC, + metadata_entries: list = PINECONE_METADATA_ENTRIES, + create_index: bool = False, + log_index_stats: bool = True, + ): + self.index_name = index_name + self.values_dims = values_dims + self.metric = metric + self.metadata_entries = metadata_entries + + pinecone.init( + api_key = PINECONE_API_KEY, + environment = PINECONE_ENVIRONMENT, + ) + + if create_index: + self.create_index() + + self.index = pinecone.Index(index_name=self.index_name) + + if log_index_stats: + index_stats_response = self.index.describe_index_stats() + logger.info(f"{self.index_name}:\n{index_stats_response}") + + def upsert_entry(self, entry, chunks, embeddings, upsert_size=100): + self.index.upsert( + vectors=list( + zip( + [f"{entry['id']}_{str(i).zfill(6)}" for i in range(len(chunks))], + embeddings.tolist(), + [ + { + 'entry_id': entry['id'], + 'source': entry['source'], + 'title': entry['title'], + 'authors': entry['authors'], 
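+ # the raw chunk text is stored in the vector's metadata alongside the entry fields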
+ 'text': chunk, + } for chunk in chunks + ] + ) + ), + batch_size=upsert_size + ) + + def upsert_entries(self, entries_batch, chunks_batch, chunks_ids_batch, embeddings, upsert_size=100): + self.index.upsert( + vectors=list( + zip( + chunks_ids_batch, + embeddings.tolist(), + [ + { + 'entry_id': entry['id'], + 'source': entry['source'], + 'title': entry['title'], + 'authors': entry['authors'], + 'text': chunk, + } + for entry in entries_batch + for chunk in chunks_batch + ] + ) + ), + batch_size=upsert_size + ) + + def delete_entry(self, id): + self.index.delete( + filter={"entry_id": {"$eq": id}} + ) + + def delete_entries(self, ids): + self.index.delete( + filter={"entry_id": {"$in": ids}} + ) + + def create_index(self, replace_current_index: bool = True): + if replace_current_index: + self.delete_index() + + pinecone.create_index( + name=self.index_name, + dimension=self.values_dims, + metric=self.metric, + metadata_config = {"indexed": self.metadata_entries}, + ) + + def delete_index(self): + if self.index_name in pinecone.list_indexes(): + logger.info(f"Deleting index '{self.index_name}'.") + pinecone.delete_index(self.index_name) \ No newline at end of file diff --git a/align_data/pinecone/text_splitter.py b/align_data/pinecone/text_splitter.py new file mode 100644 index 00000000..8e5dc0b8 --- /dev/null +++ b/align_data/pinecone/text_splitter.py @@ -0,0 +1,102 @@ +# dataset/text_splitter.py + +from typing import List, Callable, Any +from langchain.text_splitter import TextSplitter +from nltk.tokenize import sent_tokenize + + +class ParagraphSentenceUnitTextSplitter(TextSplitter): + """A custom TextSplitter that breaks text by paragraphs, sentences, and then units (chars/words/tokens/etc). + + @param min_chunk_size: The minimum number of units in a chunk. + @param max_chunk_size: The maximum number of units in a chunk. + @param length_function: A function that returns the length of a string in units. + @param truncate_function: A function that truncates a string to a given unit length. 
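+ The default truncate_function slices the raw string, so the default unit is a character.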
+ """ + + DEFAULT_MIN_CHUNK_SIZE = 900 + DEFAULT_MAX_CHUNK_SIZE = 1100 + DEFAULT_TRUNCATE_FUNCTION = lambda string, length, from_end=False: string[-length:] if from_end else string[:length] + + def __init__( + self, + min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, + max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, + truncate_function: Callable[[str, int], str] = DEFAULT_TRUNCATE_FUNCTION, + **kwargs: Any + ): + super().__init__(**kwargs) + self.min_chunk_size = min_chunk_size + self.max_chunk_size = max_chunk_size + + self._truncate_function = truncate_function + + def split_text(self, text: str) -> List[str]: + blocks = [] + current_block = "" + + paragraphs = text.split("\n\n") + for paragraph in paragraphs: + current_block += "\n\n" + paragraph + block_length = self._length_function(current_block) + + if block_length > self.max_chunk_size: # current block is too large, truncate it + current_block = self._handle_large_paragraph(current_block, blocks, paragraph) + elif block_length >= self.min_chunk_size: + blocks.append(current_block) + current_block = "" + else: # current block is too small, continue appending to it + continue + + blocks = self._handle_remaining_text(current_block, blocks) + + return [block.strip() for block in blocks] + + def _handle_large_paragraph(self, current_block, blocks, paragraph): + # Undo adding the whole paragraph + current_block = current_block[:-(len(paragraph)+2)] # +2 accounts for "\n\n" + + sentences = sent_tokenize(paragraph) + for sentence in sentences: + current_block += f" {sentence}" + + block_length = self._length_function(current_block) + if block_length < self.min_chunk_size: + continue + elif block_length <= self.max_chunk_size: + blocks.append(current_block) + current_block = "" + else: + current_block = self._truncate_large_block(current_block, blocks, sentence) + + return current_block + + def _truncate_large_block(self, current_block, blocks, sentence): + while self._length_function(current_block) > self.max_chunk_size: + # Truncate current_block to max size, set remaining sentence as next sentence + truncated_block = self._truncate_function(current_block, self.max_chunk_size) + blocks.append(truncated_block) + + remaining_sentence = current_block[len(truncated_block):].lstrip() + current_block = sentence = remaining_sentence + + return current_block + + def _handle_remaining_text(self, current_block, blocks): + if blocks == []: # no blocks were added + return [current_block] + elif current_block: # any leftover text + len_current_block = self._length_function(current_block) + if len_current_block < self.min_chunk_size: + # it needs to take the last min_chunk_size-len_current_block units from the previous block + previous_block = blocks[-1] + required_units = self.min_chunk_size - len_current_block # calculate the required units + + part_prev_block = self._truncate_function(previous_block, required_units, from_end=True) # get the required units from the previous block + last_block = part_prev_block + current_block + + blocks.append(last_block) + else: + blocks.append(current_block) + + return blocks \ No newline at end of file diff --git a/align_data/pinecone/update_pinecone.py b/align_data/pinecone/update_pinecone.py new file mode 100644 index 00000000..e3a27540 --- /dev/null +++ b/align_data/pinecone/update_pinecone.py @@ -0,0 +1,190 @@ +import os +from typing import Dict, List, Union +import numpy as np +import openai + +from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter +from 
align_data.pinecone.pinecone_db_handler import PineconeDB + +from align_data.settings import USE_OPENAI_EMBEDDINGS, OPENAI_EMBEDDINGS_MODEL, \ + OPENAI_EMBEDDINGS_DIMS, OPENAI_EMBEDDINGS_RATE_LIMIT, \ + SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS, \ + CHUNK_SIZE, MAX_NUM_AUTHORS_IN_SIGNATURE, EMBEDDING_LENGTH_BIAS + +import logging +logger = logging.getLogger(__name__) + + +class ARDUpdater: + def __init__( + self, + min_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MIN_CHUNK_SIZE, + max_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MAX_CHUNK_SIZE, + ): + self.text_splitter = ParagraphSentenceUnitTextSplitter( + min_chunk_size=min_chunk_size, + max_chunk_size=max_chunk_size, + ) + + self.pinecone_db = PineconeDB() + + if USE_OPENAI_EMBEDDINGS: + import openai + openai.api_key = os.environ['OPENAI_API_KEY'] + else: + import torch + from langchain.embeddings import HuggingFaceEmbeddings + + self.hf_embeddings = HuggingFaceEmbeddings( + model_name=SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, + model_kwargs={'device': "cuda" if torch.cuda.is_available() else "cpu"}, + encode_kwargs={'show_progress_bar': False} + ) + + def update(self, custom_sources: List[str] = ['all']): + """ + Update the given sources. If no sources are provided, updates all sources. + + :param custom_sources: List of sources to update. + """ + + for source in custom_sources: + self.update_source(source) + + def update_source(self, source: str): + """ + Updates the entries from the given source. + + :param source: The name of the source to update. + """ + + logger.info(f"Updating {source} entries...") + + # TODO: loop through mysql entries and update the pinecone db + + logger.info(f"Successfully updated {source} entries.") + + def batchify(self, iterable): + """ + Divides the iterable into batches of size ~CHUNK_SIZE. + + :param iterable: The iterable to divide into batches. + :returns: A generator that yields batches from the iterable. + """ + + entries_batch = [] + chunks_batch = [] + chunks_ids_batch = [] + sources_batch = [] + + for entry in iterable: + chunks, chunks_ids = self.create_chunk_ids_and_authors(entry) + + entries_batch.append(entry) + chunks_batch.extend(chunks) + chunks_ids_batch.extend(chunks_ids) + sources_batch.extend([entry['source']] * len(chunks)) + + # If this batch is large enough, yield it and start a new one. + if len(chunks_batch) >= CHUNK_SIZE: + yield self._create_batch(entries_batch, chunks_batch, chunks_ids_batch, sources_batch) + + entries_batch = [] + chunks_batch = [] + chunks_ids_batch = [] + sources_batch = [] + + # Yield any remaining items. 
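+ # (the final batch may hold fewer than CHUNK_SIZE chunks)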
+ if entries_batch: + yield self._create_batch(entries_batch, chunks_batch, chunks_ids_batch, sources_batch) + + def create_chunk_ids_and_authors(self, entry): + signature = f"Title: {entry['title']}, Author(s): {self.get_authors_str(entry['authors'])}" + chunks = self.text_splitter.split_text(entry['text']) + chunks = [f"- {signature}\n\n{chunk}" for chunk in chunks] + chunks_ids = [f"{entry['id']}_{str(i).zfill(6)}" for i in range(len(chunks))] + return chunks, chunks_ids + + def _create_batch(self, entries_batch, chunks_batch, chunks_ids_batch, sources_batch): + return {'entries_batch': entries_batch, 'chunks_batch': chunks_batch, 'chunks_ids_batch': chunks_ids_batch, 'sources_batch': sources_batch} + + def is_sql_entry_upserted(self, entry): + """Upserts an entry to the SQL database and returns the success status""" + return self.sql_db.upsert_entry(entry) + + def extract_embeddings(self, chunks_batch, sources_batch): + if USE_OPENAI_EMBEDDINGS: + return self.get_openai_embeddings(chunks_batch, sources_batch) + else: + return np.array(self.hf_embeddings.embed_documents(chunks_batch, sources_batch)) + + def reset_dbs(self): + self.sql_db.create_tables(True) + self.pinecone_db.create_index(True) + + @staticmethod + def preprocess_and_validate(entry): + """Preprocesses and validates the entry data""" + try: + ARDUpdater.validate_entry(entry) + + return { + 'id': entry['id'], + 'source': entry['source'], + 'title': entry['title'], + 'text': entry['text'], + 'url': entry['url'], + 'date_published': entry['date_published'], + 'authors': entry['authors'] + } + except ValueError as e: + logger.error(f"Entry validation failed: {str(e)}", exc_info=True) + return None + + @staticmethod + def validate_entry(entry: Dict[str, Union[str, list]], char_len_lower_limit: int = 0): + metadata_types = { + 'id': str, + 'source': str, + 'title': str, + 'text': str, + 'url': str, + 'date_published': str, + 'authors': list + } + + for metadata_type, metadata_type_type in metadata_types.items(): + if not isinstance(entry.get(metadata_type), metadata_type_type): + raise ValueError(f"Entry metadata '{metadata_type}' is not of type '{metadata_type_type}' or is missing.") + + if len(entry['text']) < char_len_lower_limit: + raise ValueError(f"Entry text is too short (< {char_len_lower_limit} characters).") + + @staticmethod + def is_valid_entry(entry): + """Checks if the entry is valid""" + return entry is not None + + @staticmethod + def get_openai_embeddings(chunks, sources=''): + embeddings = np.zeros((len(chunks), OPENAI_EMBEDDINGS_DIMS)) + + openai_output = openai.Embedding.create( + model=OPENAI_EMBEDDINGS_MODEL, + input=chunks + )['data'] + + for i, (embedding, source) in enumerate(zip(openai_output, sources)): + bias = EMBEDDING_LENGTH_BIAS.get(source, 1.0) + embeddings[i] = bias * np.array(embedding['embedding']) + + return embeddings + + @staticmethod + def get_authors_str(authors_lst: List[str]) -> str: + if authors_lst == []: return 'n/a' + if len(authors_lst) == 1: return authors_lst[0] + else: + authors_lst = authors_lst[:MAX_NUM_AUTHORS_IN_SIGNATURE] + authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}" + return authors_str \ No newline at end of file diff --git a/align_data/settings.py b/align_data/settings.py index 49be5460..8e0da373 100644 --- a/align_data/settings.py +++ b/align_data/settings.py @@ -2,21 +2,48 @@ from dotenv import load_dotenv load_dotenv() - +### CODA ### CODA_TOKEN = os.environ.get("CODA_TOKEN") CODA_DOC_ID = os.environ.get("CODA_DOC_ID", "fau7sl2hmG") 
ON_SITE_TABLE = os.environ.get('CODA_ON_SITE_TABLE', 'table-aOTSHIz_mN') +### GOOGLE DRIVE ### PDFS_FOLDER_ID = os.environ.get('PDF_FOLDER_ID', '1etWiXPRl0QqdgYzivVIj6wCv9xj5VYoN') +### GOOGLE SHEETS ### METADATA_SOURCE_SPREADSHEET = os.environ.get('METADATA_SOURCE_SPREADSHEET', '1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI') METADATA_SOURCE_SHEET = os.environ.get('METADATA_SOURCE_SHEET', 'special_docs.csv') METADATA_OUTPUT_SPREADSHEET = os.environ.get('METADATA_OUTPUT_SPREADSHEET', '1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4') - +### MYSQL ### user = os.environ.get('ARD_DB_USER', 'user') password = os.environ.get('ARD_DB_PASSWORD', 'we all live in a yellow submarine') host = os.environ.get('ARD_DB_HOST', '127.0.0.1') port = os.environ.get('ARD_DB_PORT', '3306') db_name = os.environ.get('ARD_DB_NAME', 'alignment_research_dataset') DB_CONNECTION_URI = f'mysql+mysqldb://{user}:{password}@{host}:{port}/{db_name}' + +### EMBEDDINGS ### +USE_OPENAI_EMBEDDINGS = True # If false, SentenceTransformer embeddings will be used. +EMBEDDING_LENGTH_BIAS = { + "aisafety.info": 1.05, # In search, favor AISafety.info entries. +} + +OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002" +OPENAI_EMBEDDINGS_DIMS = 1536 +OPENAI_EMBEDDINGS_RATE_LIMIT = 3500 + +SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1" +SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768 + +### PINECONE ### +PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME", "stampy-chat-ard") +PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None) +PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None) +PINECONE_VALUES_DIMS = OPENAI_EMBEDDINGS_DIMS if USE_OPENAI_EMBEDDINGS else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS +PINECONE_METRIC = "dotproduct" +PINECONE_METADATA_ENTRIES = ["entry_id", "source", "title", "authors", "text"] + +### MISCELLANEOUS ### +CHUNK_SIZE = 1750 +MAX_NUM_AUTHORS_IN_SIGNATURE = 3 \ No newline at end of file diff --git a/align_data/alignment_newsletter/__init__.py b/align_data/sources/alignment_newsletter/__init__.py similarity index 100% rename from align_data/alignment_newsletter/__init__.py rename to align_data/sources/alignment_newsletter/__init__.py diff --git a/align_data/alignment_newsletter/alignment_newsletter.py b/align_data/sources/alignment_newsletter/alignment_newsletter.py similarity index 100% rename from align_data/alignment_newsletter/alignment_newsletter.py rename to align_data/sources/alignment_newsletter/alignment_newsletter.py diff --git a/align_data/arbital/__init__.py b/align_data/sources/arbital/__init__.py similarity index 100% rename from align_data/arbital/__init__.py rename to align_data/sources/arbital/__init__.py diff --git a/align_data/arbital/arbital.py b/align_data/sources/arbital/arbital.py similarity index 100% rename from align_data/arbital/arbital.py rename to align_data/sources/arbital/arbital.py diff --git a/align_data/articles/__init__.py b/align_data/sources/articles/__init__.py similarity index 85% rename from align_data/articles/__init__.py rename to align_data/sources/articles/__init__.py index cd93345b..04664fd6 100644 --- a/align_data/articles/__init__.py +++ b/align_data/sources/articles/__init__.py @@ -1,4 +1,4 @@ -from align_data.articles.datasets import PDFArticles, HTMLArticles, EbookArticles, XMLArticles +from align_data.sources.articles.datasets import PDFArticles, HTMLArticles, EbookArticles, XMLArticles ARTICLES_REGISTRY = [ PDFArticles( diff --git a/align_data/articles/articles.py b/align_data/sources/articles/articles.py 
similarity index 94% rename from align_data/articles/articles.py rename to align_data/sources/articles/articles.py index 9b670e74..f32c3e64 100644 --- a/align_data/articles/articles.py +++ b/align_data/sources/articles/articles.py @@ -3,9 +3,9 @@ from tqdm import tqdm -from align_data.articles.google_cloud import iterate_rows, get_spreadsheet, get_sheet, upload_file, OK, with_retry -from align_data.articles.parsers import item_metadata, fetch -from align_data.articles.indices import fetch_all +from align_data.sources.articles.google_cloud import iterate_rows, get_spreadsheet, get_sheet, upload_file, OK, with_retry +from align_data.sources.articles.parsers import item_metadata, fetch +from align_data.sources.articles.indices import fetch_all from align_data.settings import PDFS_FOLDER_ID diff --git a/align_data/articles/datasets.py b/align_data/sources/articles/datasets.py similarity index 95% rename from align_data/articles/datasets.py rename to align_data/sources/articles/datasets.py index aba2a8e8..fff7f051 100644 --- a/align_data/articles/datasets.py +++ b/align_data/sources/articles/datasets.py @@ -10,8 +10,8 @@ from gdown.download import download from markdownify import markdownify -from align_data.articles.pdf import fetch_pdf, read_pdf, fetch -from align_data.articles.parsers import HTML_PARSERS, extract_gdrive_contents +from align_data.sources.articles.pdf import fetch_pdf, read_pdf, fetch +from align_data.sources.articles.parsers import HTML_PARSERS, extract_gdrive_contents from align_data.common.alignment_dataset import AlignmentDataset logger = logging.getLogger(__name__) diff --git a/align_data/articles/google_cloud.py b/align_data/sources/articles/google_cloud.py similarity index 100% rename from align_data/articles/google_cloud.py rename to align_data/sources/articles/google_cloud.py diff --git a/align_data/articles/html.py b/align_data/sources/articles/html.py similarity index 100% rename from align_data/articles/html.py rename to align_data/sources/articles/html.py diff --git a/align_data/articles/indices.py b/align_data/sources/articles/indices.py similarity index 98% rename from align_data/articles/indices.py rename to align_data/sources/articles/indices.py index 5da56e31..6eb5761c 100644 --- a/align_data/articles/indices.py +++ b/align_data/sources/articles/indices.py @@ -1,7 +1,7 @@ import re from collections import defaultdict -from align_data.articles.html import fetch, fetch_element +from align_data.sources.articles.html import fetch, fetch_element from align_data.common.alignment_dataset import AlignmentDataset from dateutil.parser import ParserError, parse from markdownify import MarkdownConverter diff --git a/align_data/articles/parsers.py b/align_data/sources/articles/parsers.py similarity index 98% rename from align_data/articles/parsers.py rename to align_data/sources/articles/parsers.py index 4b8faaa7..62426335 100644 --- a/align_data/articles/parsers.py +++ b/align_data/sources/articles/parsers.py @@ -4,8 +4,8 @@ import grobid_tei_xml import regex as re -from align_data.articles.html import element_extractor, fetch, fetch_element -from align_data.articles.pdf import doi_getter, fetch_pdf, get_pdf_from_page, get_arxiv_pdf +from align_data.sources.articles.html import element_extractor, fetch, fetch_element +from align_data.sources.articles.pdf import doi_getter, fetch_pdf, get_pdf_from_page, get_arxiv_pdf from markdownify import MarkdownConverter from bs4 import BeautifulSoup from markdownify import MarkdownConverter diff --git a/align_data/articles/pdf.py 
b/align_data/sources/articles/pdf.py similarity index 98% rename from align_data/articles/pdf.py rename to align_data/sources/articles/pdf.py index 7be755f1..ae4492f6 100644 --- a/align_data/articles/pdf.py +++ b/align_data/sources/articles/pdf.py @@ -9,7 +9,7 @@ from PyPDF2 import PdfReader from PyPDF2.errors import PdfReadError -from align_data.articles.html import fetch, fetch_element +from align_data.sources.articles.html import fetch, fetch_element logger = logging.getLogger(__name__) diff --git a/align_data/arxiv_papers/__init__.py b/align_data/sources/arxiv_papers/__init__.py similarity index 100% rename from align_data/arxiv_papers/__init__.py rename to align_data/sources/arxiv_papers/__init__.py diff --git a/align_data/arxiv_papers/arxiv_papers.py b/align_data/sources/arxiv_papers/arxiv_papers.py similarity index 100% rename from align_data/arxiv_papers/arxiv_papers.py rename to align_data/sources/arxiv_papers/arxiv_papers.py diff --git a/align_data/audio_transcripts/__init__.py b/align_data/sources/audio_transcripts/__init__.py similarity index 100% rename from align_data/audio_transcripts/__init__.py rename to align_data/sources/audio_transcripts/__init__.py diff --git a/align_data/audio_transcripts/audio_transcripts.py b/align_data/sources/audio_transcripts/audio_transcripts.py similarity index 100% rename from align_data/audio_transcripts/audio_transcripts.py rename to align_data/sources/audio_transcripts/audio_transcripts.py diff --git a/align_data/blogs/__init__.py b/align_data/sources/blogs/__init__.py similarity index 86% rename from align_data/blogs/__init__.py rename to align_data/sources/blogs/__init__.py index 3c2c7c0c..8f1d5fc1 100644 --- a/align_data/blogs/__init__.py +++ b/align_data/sources/blogs/__init__.py @@ -1,10 +1,10 @@ -from align_data.blogs.wp_blog import WordpressBlog -from align_data.blogs.medium_blog import MediumBlog -from align_data.blogs.gwern_blog import GwernBlog -from align_data.blogs.blogs import ( +from align_data.sources.blogs.wp_blog import WordpressBlog +from align_data.sources.blogs.medium_blog import MediumBlog +from align_data.sources.blogs.gwern_blog import GwernBlog +from align_data.sources.blogs.blogs import ( ColdTakes, GenerativeInk, CaradoMoe, EleutherAI, OpenAIResearch, DeepMindTechnicalBlog ) -from align_data.blogs.substack_blog import SubstackBlog +from align_data.sources.blogs.substack_blog import SubstackBlog BLOG_REGISTRY = [ diff --git a/align_data/blogs/blogs.py b/align_data/sources/blogs/blogs.py similarity index 98% rename from align_data/blogs/blogs.py rename to align_data/sources/blogs/blogs.py index 9c6b962b..65ad95e6 100644 --- a/align_data/blogs/blogs.py +++ b/align_data/sources/blogs/blogs.py @@ -1,7 +1,7 @@ import logging import requests -from align_data.articles.parsers import item_metadata +from align_data.sources.articles.parsers import item_metadata from align_data.common.html_dataset import HTMLDataset, RSSDataset from bs4 import BeautifulSoup from dateutil.parser import ParserError diff --git a/align_data/blogs/gwern_blog.py b/align_data/sources/blogs/gwern_blog.py similarity index 100% rename from align_data/blogs/gwern_blog.py rename to align_data/sources/blogs/gwern_blog.py diff --git a/align_data/blogs/medium_blog.py b/align_data/sources/blogs/medium_blog.py similarity index 100% rename from align_data/blogs/medium_blog.py rename to align_data/sources/blogs/medium_blog.py diff --git a/align_data/blogs/substack_blog.py b/align_data/sources/blogs/substack_blog.py similarity index 100% rename from 
align_data/blogs/substack_blog.py rename to align_data/sources/blogs/substack_blog.py diff --git a/align_data/blogs/wp_blog.py b/align_data/sources/blogs/wp_blog.py similarity index 100% rename from align_data/blogs/wp_blog.py rename to align_data/sources/blogs/wp_blog.py diff --git a/align_data/distill/__init__.py b/align_data/sources/distill/__init__.py similarity index 100% rename from align_data/distill/__init__.py rename to align_data/sources/distill/__init__.py diff --git a/align_data/distill/distill.py b/align_data/sources/distill/distill.py similarity index 100% rename from align_data/distill/distill.py rename to align_data/sources/distill/distill.py diff --git a/align_data/ebooks/__init__.py b/align_data/sources/ebooks/__init__.py similarity index 100% rename from align_data/ebooks/__init__.py rename to align_data/sources/ebooks/__init__.py diff --git a/align_data/ebooks/agentmodels.py b/align_data/sources/ebooks/agentmodels.py similarity index 100% rename from align_data/ebooks/agentmodels.py rename to align_data/sources/ebooks/agentmodels.py diff --git a/align_data/ebooks/gdrive_ebooks.py b/align_data/sources/ebooks/gdrive_ebooks.py similarity index 100% rename from align_data/ebooks/gdrive_ebooks.py rename to align_data/sources/ebooks/gdrive_ebooks.py diff --git a/align_data/ebooks/mdebooks.py b/align_data/sources/ebooks/mdebooks.py similarity index 100% rename from align_data/ebooks/mdebooks.py rename to align_data/sources/ebooks/mdebooks.py diff --git a/align_data/gdocs/__init__.py b/align_data/sources/gdocs/__init__.py similarity index 100% rename from align_data/gdocs/__init__.py rename to align_data/sources/gdocs/__init__.py diff --git a/align_data/gdocs/gdocs.py b/align_data/sources/gdocs/gdocs.py similarity index 100% rename from align_data/gdocs/gdocs.py rename to align_data/sources/gdocs/gdocs.py diff --git a/align_data/greaterwrong/__init__.py b/align_data/sources/greaterwrong/__init__.py similarity index 100% rename from align_data/greaterwrong/__init__.py rename to align_data/sources/greaterwrong/__init__.py diff --git a/align_data/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py similarity index 100% rename from align_data/greaterwrong/greaterwrong.py rename to align_data/sources/greaterwrong/greaterwrong.py diff --git a/align_data/reports/__init__.py b/align_data/sources/reports/__init__.py similarity index 100% rename from align_data/reports/__init__.py rename to align_data/sources/reports/__init__.py diff --git a/align_data/reports/reports.py b/align_data/sources/reports/reports.py similarity index 100% rename from align_data/reports/reports.py rename to align_data/sources/reports/reports.py diff --git a/align_data/stampy/__init__.py b/align_data/sources/stampy/__init__.py similarity index 100% rename from align_data/stampy/__init__.py rename to align_data/sources/stampy/__init__.py diff --git a/align_data/stampy/stampy.py b/align_data/sources/stampy/stampy.py similarity index 100% rename from align_data/stampy/stampy.py rename to align_data/sources/stampy/stampy.py diff --git a/main.py b/main.py index 4e3f35ac..e1bce889 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,8 @@ from align_data import ALL_DATASETS, DATASET_REGISTRY, get_dataset from align_data.analysis.count_tokens import count_token -from align_data.articles.articles import update_new_items, check_new_articles +from align_data.sources.articles.articles import update_new_items, check_new_articles +from align_data.pinecone.update_pinecone import ARDUpdater from 
align_data.settings import ( METADATA_OUTPUT_SPREADSHEET, METADATA_SOURCE_SHEET, METADATA_SOURCE_SPREADSHEET ) @@ -121,6 +122,12 @@ def fetch_new_articles(self, source_spreadsheet=METADATA_SOURCE_SPREADSHEET, sou """ return check_new_articles(source_spreadsheet, source_sheet) + def update_pinecone(self): + """ + This function updates the Pinecone vector DB. + """ + updater = ARDUpdater() + updater.update() if __name__ == "__main__": fire.Fire(AlignmentDataset) diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py index d85c3d13..1c5f7a2d 100644 --- a/tests/align_data/articles/test_datasets.py +++ b/tests/align_data/articles/test_datasets.py @@ -2,7 +2,7 @@ import pandas as pd import pytest -from align_data.articles.datasets import EbookArticles, HTMLArticles, PDFArticles, SpreadsheetDataset, XMLArticles +from align_data.sources.articles.datasets import EbookArticles, HTMLArticles, PDFArticles, SpreadsheetDataset, XMLArticles @pytest.fixture diff --git a/tests/align_data/articles/test_parsers.py b/tests/align_data/articles/test_parsers.py index 7e480b22..f062343a 100644 --- a/tests/align_data/articles/test_parsers.py +++ b/tests/align_data/articles/test_parsers.py @@ -4,7 +4,7 @@ import pytest from bs4 import BeautifulSoup -from align_data.articles.parsers import ( +from align_data.sources.articles.parsers import ( google_doc, medium_blog, parse_grobid, get_content_type, extract_gdrive_contents ) diff --git a/tests/align_data/test_alignment_newsletter.py b/tests/align_data/test_alignment_newsletter.py index a0f039d7..0e9db7a4 100644 --- a/tests/align_data/test_alignment_newsletter.py +++ b/tests/align_data/test_alignment_newsletter.py @@ -2,7 +2,7 @@ import pytest import pandas as pd -from align_data.alignment_newsletter import AlignmentNewsletter +from align_data.sources.alignment_newsletter import AlignmentNewsletter @pytest.fixture(scope="module") diff --git a/tests/align_data/test_arbital.py b/tests/align_data/test_arbital.py index f2f4cd98..d0c454dd 100644 --- a/tests/align_data/test_arbital.py +++ b/tests/align_data/test_arbital.py @@ -4,7 +4,7 @@ import pytest from dateutil.parser import parse -from align_data.arbital.arbital import Arbital, extract_text, flatten, parse_arbital_link +from align_data.sources.arbital.arbital import Arbital, extract_text, flatten, parse_arbital_link @pytest.mark.parametrize('contents, expected', ( diff --git a/tests/align_data/test_blogs.py b/tests/align_data/test_blogs.py index 62de36ae..f2dc7d21 100644 --- a/tests/align_data/test_blogs.py +++ b/tests/align_data/test_blogs.py @@ -4,11 +4,11 @@ from bs4 import BeautifulSoup from dateutil.parser import parse -from align_data.blogs import ( +from align_data.sources.blogs import ( CaradoMoe, ColdTakes, GenerativeInk, GwernBlog, MediumBlog, SubstackBlog, WordpressBlog, OpenAIResearch, DeepMindTechnicalBlog ) -from align_data.blogs.blogs import EleutherAI +from align_data.sources.blogs.blogs import EleutherAI SAMPLE_HTML = """ diff --git a/tests/align_data/test_distill.py b/tests/align_data/test_distill.py index e0c8943f..b94b5bda 100644 --- a/tests/align_data/test_distill.py +++ b/tests/align_data/test_distill.py @@ -3,7 +3,7 @@ import pytest from bs4 import BeautifulSoup -from align_data.distill import Distill +from align_data.sources.distill import Distill def test_extract_authors(): diff --git a/tests/align_data/test_greater_wrong.py b/tests/align_data/test_greater_wrong.py index 72b32b83..f7c273c2 100644 --- a/tests/align_data/test_greater_wrong.py +++ 
b/tests/align_data/test_greater_wrong.py @@ -5,7 +5,7 @@ import pytest -from align_data.greaterwrong.greaterwrong import ( +from align_data.sources.greaterwrong.greaterwrong import ( fetch_LW_tags, fetch_ea_forum_topics, GreaterWrong ) diff --git a/tests/align_data/test_stampy.py b/tests/align_data/test_stampy.py index 1381941a..c3694086 100644 --- a/tests/align_data/test_stampy.py +++ b/tests/align_data/test_stampy.py @@ -1,7 +1,7 @@ from unittest.mock import patch from dateutil.parser import parse -from align_data.stampy import Stampy +from align_data.sources.stampy import Stampy def test_validate_coda_token():
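A minimal usage sketch of the new wiring, not a definitive workflow: the update_pinecone method added to main.py should be invocable through the existing fire CLI as `python main.py update_pinecone`, which builds an ARDUpdater and calls update(). The snippet below drives the same classes directly; it assumes the PINECONE_API_KEY, PINECONE_ENVIRONMENT and OPENAI_API_KEY values from .env.example are set, and since ARDUpdater.update_source is still a TODO stub in this diff, no vectors are written yet.

    # sketch: exercise the new Pinecone components directly (assumes the env vars above are set)
    from align_data.pinecone.update_pinecone import ARDUpdater

    updater = ARDUpdater()     # connects to Pinecone and picks the embedding backend per settings.py
    updater.update(['all'])    # same call the `update_pinecone` CLI command makes

    # PineconeDB(create_index=True) would drop and recreate the index defined in settings.py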