diff --git a/.env.example b/.env.example index 057b5ba4..1c29520b 100644 --- a/.env.example +++ b/.env.example @@ -9,3 +9,4 @@ OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" PINECONE_INDEX_NAME="stampy-chat-ard" PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" PINECONE_ENVIRONMENT="xx-xxxxx-gcp" +YOUTUBE_API_KEY="" \ No newline at end of file diff --git a/.github/workflows/fetch-daily.yml b/.github/workflows/fetch-daily.yml index 9254f461..f4a07050 100644 --- a/.github/workflows/fetch-daily.yml +++ b/.github/workflows/fetch-daily.yml @@ -14,6 +14,7 @@ jobs: with: datasource: ${{ matrix.datasource }} coda_token: ${{ inputs.coda_token }} + airtable_api_key: ${{ inputs.airtable_api_key }} youtube_api_key: ${{ inputs.youtube_api_key }} db_user: ${{ inputs.db_user }} db_password: ${{ inputs.db_password }} diff --git a/.github/workflows/fetch-dataset.yml b/.github/workflows/fetch-dataset.yml index b65a91dd..39dc3faa 100644 --- a/.github/workflows/fetch-dataset.yml +++ b/.github/workflows/fetch-dataset.yml @@ -9,6 +9,9 @@ on: coda_token: type: string required: true + airtable_api_key: + type: string + required: true youtube_api_key: type: string required: true @@ -28,44 +31,19 @@ on: type: choice options: - agentmodels - - aiimpacts - - aisafety.camp + - agisf - aisafety.info - - ai_alignment_playlist - - ai_explained - - ai_safety_talks - - ai_safety_reading_group - - ai_tech_tu_delft - - alignmentforum - alignment_newsletter + - alignmentforum - arbital - arxiv - - carado.moe - - cold_takes - - deepmind_blog - - deepmind_technical_blog + - blogs - distill - eaforum - - ebooks - - eleuther.ai - - gdocs - - generative.ink - - gwern_blog - - html_articles - - importai - indices - - jsteinhardt_blog - lesswrong - - markdown - - miri - - ml_safety_newsletter - - openai.research - - pdfs - - rob_miles_ai_safety - special_docs - - vkrakovna_blog - - yudkowsky_blog - - xmls + - youtube jobs: build-dataset: @@ -93,6 +71,7 @@ jobs: - name: Process dataset env: CODA_TOKEN: ${{ secrets.CODA_TOKEN || inputs.coda_token }} + AIRTABLE_API_KEY: ${{ secrets.AIRTABLE_API_KEY || inputs.airtable_api_key }} YOUTUBE_API_KEY: ${{ secrets.YOUTUBE_API_KEY || inputs.youtube_api_key }} ARD_DB_USER: ${{ secrets.ARD_DB_USER || inputs.db_user }} ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD || inputs.db_password }} diff --git a/.github/workflows/fetch-weekly.yml b/.github/workflows/fetch-weekly.yml index e8850ec4..64d17c07 100644 --- a/.github/workflows/fetch-weekly.yml +++ b/.github/workflows/fetch-weekly.yml @@ -14,6 +14,7 @@ jobs: with: datasource: ${{ matrix.datasource }} coda_token: ${{ inputs.coda_token }} + airtable_api_key: ${{ inputs.airtable_api_key }} youtube_api_key: ${{ inputs.youtube_api_key }} db_user: ${{ inputs.db_user }} db_password: ${{ inputs.db_password }} diff --git a/.github/workflows/push-dataset.yml b/.github/workflows/push-dataset.yml index 2ecd539d..1b48a3b4 100644 --- a/.github/workflows/push-dataset.yml +++ b/.github/workflows/push-dataset.yml @@ -24,37 +24,17 @@ on: options: - all - agentmodels - - aiimpacts - - aisafety.camp + - agisf - aisafety.info - - ai_alignment_playlist - - ai_explained - - ai_safety_talks - - ai_safety_reading_group - - ai_tech_tu_delft - alignmentforum - arbital - arxiv - - carado.moe - - cold_takes - - deepmind_blog - - deepmind_technical_blog + - blogs - distill - eaforum - - eleuther.ai - - gdocs - - generative.ink - - gwern_blog - - importai - - jsteinhardt_blog - lesswrong - - miri - - ml_safety_newsletter - - openai.research - - 
rob_miles_ai_safety - special_docs - - vkrakovna_blog - - yudkowsky_blog + - youtube jobs: generate-dataset: diff --git a/.github/workflows/update-metadata.yml b/.github/workflows/update-metadata.yml index 70e2ddd5..fe95f002 100644 --- a/.github/workflows/update-metadata.yml +++ b/.github/workflows/update-metadata.yml @@ -31,4 +31,9 @@ jobs: run: curl -L "${{ inputs.csv_url }}" -o data.csv - name: Run Script + env: + ARD_DB_USER: ${{ secrets.ARD_DB_USER }} + ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD }} + ARD_DB_HOST: ${{ secrets.ARD_DB_HOST }} + ARD_DB_NAME: alignment_research_dataset run: python main.py update data.csv ${{ inputs.delimiter }} diff --git a/.github/workflows/update-pinecone.yml b/.github/workflows/update-pinecone.yml index 0cb7924b..b87d9ed7 100644 --- a/.github/workflows/update-pinecone.yml +++ b/.github/workflows/update-pinecone.yml @@ -32,43 +32,18 @@ on: options: - all - agentmodels - - aiimpacts - - aisafety.camp + - agisf - aisafety.info - - ai_alignment_playlist - - ai_explained - - ai_safety_talks - - ai_safety_reading_group - - ai_tech_tu_delft - alignmentforum - arbital - arxiv - - carado.moe - - cold_takes - - deepmind_blog - - deepmind_technical_blog + - blogs - distill - eaforum - - ebooks - - eleuther.ai - - gdocs - - generative.ink - - gwern_blog - - html_articles - - importai - indices - - jsteinhardt_blog - lesswrong - - markdown - - miri - - ml_safety_newsletter - - openai.research - - pdfs - - rob_miles_ai_safety - special_docs - - vkrakovna_blog - - yudkowsky_blog - - xmls + - youtube jobs: build-dataset: @@ -78,28 +53,28 @@ jobs: - name: Checkout repository uses: actions/checkout@v2 - - name: Setup Python environment - uses: actions/setup-python@v2 - with: - python-version: '3.x' + # - name: Setup Python environment + # uses: actions/setup-python@v2 + # with: + # python-version: '3.x' - - name: Install dependencies - run: | - pip install -r requirements.txt; - python -c 'import nltk; nltk.download("punkt")' + # - name: Install dependencies + # run: | + # pip install -r requirements.txt; + # python -c 'import nltk; nltk.download("punkt")' - - name: Process dataset - env: - ARD_DB_USER: ${{ secrets.ARD_DB_USER || inputs.db_user }} - ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD || inputs.db_password }} - ARD_DB_HOST: ${{ secrets.ARD_DB_HOST || inputs.db_host }} - ARD_DB_NAME: alignment_research_dataset - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY || inputs.openai_api_key }} - PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY || inputs.pinecone_api_key }} - PINECONE_ENVIRONMENT: ${{ secrets.PINECONE_ENVIRONMENT || inputs.pinecone_environment }} - run: | - if [ "${{ inputs.datasource }}" = "all" ]; then - python main.py pinecone_update_all - else - python main.py pinecone_update ${{ inputs.datasource }} - fi + # - name: Process dataset + # env: + # ARD_DB_USER: ${{ secrets.ARD_DB_USER || inputs.db_user }} + # ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD || inputs.db_password }} + # ARD_DB_HOST: ${{ secrets.ARD_DB_HOST || inputs.db_host }} + # ARD_DB_NAME: alignment_research_dataset + # OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY || inputs.openai_api_key }} + # PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY || inputs.pinecone_api_key }} + # PINECONE_ENVIRONMENT: ${{ secrets.PINECONE_ENVIRONMENT || inputs.pinecone_environment }} + # run: | + # if [ "${{ inputs.datasource }}" = "all" ]; then + # python main.py pinecone_update_all + # else + # python main.py pinecone_update ${{ inputs.datasource }} + # fi diff --git a/.pre-commit-config.yaml 
b/.pre-commit-config.yaml new file mode 100644 index 00000000..61a81a7b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer +- repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black + language_version: python3.11 diff --git a/README.md b/README.md index 5afbcf19..b01779df 100644 --- a/README.md +++ b/README.md @@ -4,63 +4,69 @@ The AI Alignment Research Dataset is a collection of documents related to AI Ali ## Sources -The following list of sources may change and items may be renamed: - -- [agentmodels](https://agentmodels.org/) -- [aiimpacts](https://aiimpacts.org/) -- [aisafety.camp](https://aisafety.camp/) -- [aisafety.info](https://aisafety.info/) -- [ai_alignment_playlist]() -- [ai_explained](https://www.youtube.com/@ai-explained-) -- [ai_safety_talks](https://www.youtube.com/@aisafetytalks) -- [ai_safety_reading_group](https://www.youtube.com/@aisafetyreadinggroup/videos) -- [ai_tech_tu_delft](https://www.youtube.com/@AiTechTUDelft/) +Here are the list of sources along with sample contents: + +- [agentmodel](https://agentmodels.org/) +- [agisf](https://course.aisafetyfundamentals.com/) - recommended readings from AGI Safety Fundamentals +- [aisafety.info](https://aisafety.info/) - Stampy's FAQ - [alignmentforum](https://www.alignmentforum.org) - [alignment_newsletter](https://rohinshah.com/alignment-newsletter/) - [arbital](https://arbital.com/) -- arxiv - alignment research papers from [arxiv](https://arxiv.org/) -- [carado.moe](https://carado.moe/) -- [cold_takes](https://www.cold-takes.com/) -- [deepmind_blog](https://deepmindsafetyresearch.medium.com/) -- [deepmind_technical_blog](https://www.deepmind.com/blog-categories/technical-blogs) +- [arxiv](https://arxiv.org/) - relevant research papers + +- blogs - entire websites automatically scraped + - [AI Impacts](https://aiimpacts.org/) + - [AI Safety Camp](https://aisafety.camp/) + - [carado.moe](https://carado.moe/) + - [Cold Takes](https://www.cold-takes.com/) + - [DeepMind technical blogs](https://www.deepmind.com/blog-categories/technical-blogs) + - [DeepMind AI Safety Research](https://deepmindsafetyresearch.medium.com/) + - [EleutherAI](https://blog.eleuther.ai/) + - [generative.ink](https://generative.ink/posts/) + - [Gwern Branwen's blog](https://gwern.net/) + - [Jack Clark's Import AI](https://importai.substack.com/) + - [MIRI](https://intelligence.org/) + - [Jacob Steinhardt's blog](https://jsteinhardt.wordpress.com/) + - [ML Safety Newsletter](https://newsletter.mlsafety.org/) + - [Transformer Circuits Thread](https://transformer-circuits.pub/) + - [Open AI Research](https://openai.com/research/) + - [Victoria Krakovna's blog](https://vkrakovna.wordpress.com/) + - [Eliezer Yudkowsky's blog](https://www.yudkowsky.net/) + - [distill](https://distill.pub/) - [eaforum](https://forum.effectivealtruism.org/) - selected posts -- [eleuther.ai](https://blog.eleuther.ai/) -- [generative.ink](https://generative.ink/posts/) -- [gwern_blog](https://gwern.net/) -- gdocs - various doc files stored on Google drive -- html_articles - various articles on websites -- [import.ai](https://importai.substack.com) -- [jsteinhardt_blog](https://jsteinhardt.wordpress.com/) - [lesswrong](https://www.lesswrong.com/) - selected posts -- markdown -- [miri](https://intelligence.org/) - MIRI -- [ml_safety_newsletter](https://newsletter.mlsafety.org) -- 
[openai.research](https://openai.com/research) -- pdfs - various pdfs from different places -- [rob_miles_ai_safety](https://www.youtube.com/@RobertMilesAI) -- [vkrakovna_blog](https://vkrakovna.wordpress.com) -- [waitbutwhy](https://waitbutwhy.com/) -- [yudkowsky_blog](https://www.yudkowsky.net/) -- xmls - various articles stored as XML files +- special_docs - individual documents curated from various resources + - [Make a suggestion](https://bit.ly/ard-suggestion) for sources not already in the dataset + +- youtube - playlists & channels + - [AI Alignment playlist](https://www.youtube.com/playlist?list=PLCRVRLd2RhZTpdUdEzJjo3qhmX3y3skWA) and other lists + - [AI Explained](https://www.youtube.com/@aiexplained-official) + - [Evan Hubinger's AI Safety Talks](https://www.youtube.com/@aisafetytalks) + - [AI Safety Reading Group](https://www.youtube.com/@aisafetyreadinggroup/videos) + - [AiTech - TU Delft](https://www.youtube.com/@AiTechTUDelft/) + - [Rob Miles AI](https://www.youtube.com/@RobertMilesAI) ## Keys -Not all of the entries contain the same keys, but they all have the following: +All entries contain the following keys: -- `id` - unique identifier -- `source` - based on the data source listed in the previous section -- `title` - title of document +- `id` - string of unique identifier +- `source` - string of data source listed above +- `title` - string of document title of document +- `authors` - list of strings - `text` - full text of document content -- `url` - some values may be `'n/a'`, still being updated -- `date_published` - some `'n/a'` +- `url` - string of valid link to text content +- `date_published` - in UTC format -The values of the keys are still being cleaned up for consistency. Additional keys are available depending on the source document. +Additional keys may be available depending on the source document. ## Development Environment -To set up the development environment, run the following steps. You'll have to also set up [mysqlclient](https://pypi.org/project/mysqlclient/): +Follow the [instructions to install **mysqlclient** on your operating system](https://pypi.org/project/mysqlclient/) toward the middle to bottom of the linked page. 
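(Example for the "Keys" section above.) A single dataset entry with those keys might look roughly like the following sketch; the values are illustrative only and not taken from the actual dataset:

```python
# Hypothetical entry shown for illustration only; field values are made up.
example_entry = {
    "id": "0123456789abcdef",                      # unique identifier
    "source": "alignmentforum",                    # one of the sources listed above
    "title": "An example post title",
    "authors": ["Jane Doe", "John Smith"],
    "text": "Full text of the document ...",
    "url": "https://www.alignmentforum.org/posts/example",
    "date_published": "2023-01-01T00:00:00Z",      # UTC
}
```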
+ +To set up the development environment, run the following steps: ```bash git clone https://github.com/StampyAI/alignment-research-dataset diff --git a/align_data/__init__.py b/align_data/__init__.py index 54041500..68b2ba2c 100644 --- a/align_data/__init__.py +++ b/align_data/__init__.py @@ -1,5 +1,6 @@ import align_data.sources.arbital as arbital import align_data.sources.articles as articles +import align_data.sources.agisf as agisf import align_data.sources.blogs as blogs import align_data.sources.ebooks as ebooks import align_data.sources.greaterwrong as greaterwrong @@ -11,6 +12,7 @@ DATASET_REGISTRY = ( arbital.ARBITAL_REGISTRY + articles.ARTICLES_REGISTRY + + agisf.AGISF_DATASETS + blogs.BLOG_REGISTRY + ebooks.EBOOK_REGISTRY + greaterwrong.GREATERWRONG_REGISTRY diff --git a/align_data/analysis/analyse_jsonl_data.py b/align_data/analysis/analyse_jsonl_data.py index 1fb3f526..9ef49649 100644 --- a/align_data/analysis/analyse_jsonl_data.py +++ b/align_data/analysis/analyse_jsonl_data.py @@ -69,9 +69,7 @@ def process_jsonl_files(data_dir): for id, duplicates in seen_urls.items(): if len(duplicates) > 1: - list_of_duplicates = "\n".join( - get_data_dict_str(duplicate) for duplicate in duplicates - ) + list_of_duplicates = "\n".join(get_data_dict_str(duplicate) for duplicate in duplicates) print( f"{len(duplicates)} duplicate ids found. \nId: {id}\n{list_of_duplicates}\n\n\n\n" ) diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 5e48b999..a1637a51 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -17,6 +17,8 @@ from align_data.db.models import Article, Summary from align_data.db.session import make_session from align_data.settings import ARTICLE_MAIN_KEYS +from align_data.sources.utils import merge_dicts + logger = logging.getLogger(__name__) @@ -69,12 +71,13 @@ def _add_authors(self, article: Article, authors: List[str]) -> Article: article.authors = ",".join(article.authors[:1024].split(",")[:-1]) return article - def make_data_entry(self, data: Dict[str, Any], **kwargs) -> Article: - data = dict(data, **kwargs) + def make_data_entry(self, data, **kwargs) -> Article: + data = merge_dicts(data, kwargs) summary = data.pop("summary", None) authors = data.pop("authors", []) article = Article( + pinecone_update_required=True, meta={k: v for k, v in data.items() if k not in ARTICLE_MAIN_KEYS and v is not None}, **{k: v for k, v in data.items() if k in ARTICLE_MAIN_KEYS}, ) @@ -158,14 +161,9 @@ def _load_outputted_items(self) -> Set[str]: # This doesn't filter by self.name. The good thing about that is that it should handle a lot more # duplicates. The bad thing is that this could potentially return a massive amount of data if there # are lots of items. - return set( - session.scalars(select(getattr(Article, self.done_key))).all() - ) - return { - meta[self.done_key] - for meta in session.scalars(select(Article.meta)).all() - if isinstance(meta, JSON) and meta.get(self.done_key) - } + return set(session.scalars(select(getattr(Article, self.done_key))).all()) + # TODO: Properly handle this - it should create a proper SQL JSON select + return {item.get(self.done_key) for item in session.scalars(select(Article.meta)).all()} def not_processed(self, item): # NOTE: `self._outputted_items` reads in all items. Which could potentially be a lot. 
If this starts to @@ -209,7 +207,7 @@ def fetch_entries(self) -> Generator[Article, None, None]: if self.COOLDOWN: time.sleep(self.COOLDOWN) - def process_entry(self, entry) -> Optional[Article]: + def process_entry(self, entry) -> Article | None: """Process a single entry.""" raise NotImplementedError @@ -218,7 +216,7 @@ def _format_datetime(date: datetime) -> str: return date.strftime("%Y-%m-%dT%H:%M:%SZ") @staticmethod - def _get_published_date(date: str) -> Optional[datetime]: + def _get_published_date(date) -> datetime | None: try: # Totally ignore any timezone info, forcing everything to UTC return parse(str(date)).replace(tzinfo=pytz.UTC) @@ -234,7 +232,11 @@ def unprocessed_items(self, items=None) -> Iterable: urls = map(self.get_item_key, items) with make_session() as session: - articles = session.query(Article).options(joinedload(Article.summaries)).filter(Article.url.in_(urls)) + articles = ( + session.query(Article) + .options(joinedload(Article.summaries)) + .filter(Article.url.in_(urls)) + ) self.articles = {a.url: a for a in articles if a.url} return items @@ -244,9 +246,7 @@ def _load_outputted_items(self) -> Set[str]: with make_session() as session: return set( session.scalars( - select(Article.url) - .join(Article.summaries) - .filter(Summary.source == self.name) + select(Article.url).join(Article.summaries).filter(Summary.source == self.name) ) ) @@ -257,3 +257,40 @@ def merge(item): return item session.add_all(map(merge, batch)) + + +@dataclass +class MultiDataset(AlignmentDataset): + + datasets: List[AlignmentDataset] + + @property + def names(self): + return [dataset.name for dataset in self.datasets] + + @property + def items_list(self) -> Iterable: + """Returns a collection of items to be processed.""" + return ((item, dataset) for dataset in self.datasets for item in dataset.items_list) + + def setup(self): + for dataset in self.datasets: + dataset.setup() + + def get_item_key(self, entry): + item, dataset = entry + return dataset.get_item_key(item) + + def process_entry(self, entry) -> Optional[Article]: + item, dataset = entry + article = dataset.process_entry(item) + article.add_meta('initial_source', article.source) + article.source = self.name + + def fetch_entries(self): + for dataset in self.datasets: + for article in dataset.fetch_entries(): + if article.source != self.name: + article.add_meta('initial_source', article.source) + article.source = self.name + yield article diff --git a/align_data/common/html_dataset.py b/align_data/common/html_dataset.py index fc948687..e9c3aa2d 100644 --- a/align_data/common/html_dataset.py +++ b/align_data/common/html_dataset.py @@ -75,7 +75,7 @@ def get_contents(self, article_url: str) -> Dict[str, Any]: def process_entry(self, article: Tag) -> Article: article_url = self.get_item_key(article) contents = self.get_contents(article_url) - if not contents.get('text'): + if not contents.get("text"): return None return self.make_data_entry(contents) @@ -149,9 +149,12 @@ def fetch_contents(self, url: str): soup=soup, ) + def _extract_item_url(self, item) -> str | None: + return item.get('link') + @property def items_list(self): logger.info(f"Fetching entries from {self.feed_url}") feed = feedparser.parse(self.feed_url) - self.items = {item["link"]: item for item in feed["entries"]} + self.items = {url: item for item in feed["entries"] if (url := self._extract_item_url(item))} return list(self.items.keys()) diff --git a/align_data/db/models.py b/align_data/db/models.py index 43d23caa..0161ff50 100644 --- a/align_data/db/models.py 
+++ b/align_data/db/models.py @@ -17,7 +17,6 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.ext.hybrid import hybrid_property -from align_data.settings import PINECONE_METADATA_KEYS logger = logging.getLogger(__name__) @@ -71,35 +70,26 @@ class Article(Base): def __repr__(self) -> str: return f"Article(id={self.id!r}, title={self.title!r}, url={self.url!r}, source={self.source!r}, authors={self.authors!r}, date_published={self.date_published!r})" - def is_metadata_keys_equal(self, other): - if not isinstance(other, Article): - raise TypeError( - f"Expected an instance of Article, got {type(other).__name__}" - ) - return all( - getattr(self, key, None) == getattr(other, key, None) - for key in PINECONE_METADATA_KEYS - ) # entry_id is implicitly ignored - def generate_id_string(self) -> bytes: - return ( - "".join( - str(getattr(self, field)) - for field in self.__id_fields - ).encode("utf-8") - ) + return "".join(str(getattr(self, field)) for field in self.__id_fields).encode("utf-8") @property def __id_fields(self): - if self.source == 'aisafety.info': - return ['url'] - if self.source in ['importai', 'ml_safety_newsletter', 'alignment_newsletter']: - return ['url', 'title', 'source'] + if self.source == "aisafety.info": + return ["url"] + if self.source in ["importai", "ml_safety_newsletter", "alignment_newsletter"]: + return ["url", "title", "source"] return ["url", "title"] @property def missing_fields(self): - fields = set(self.__id_fields) | {'text', 'title', 'url', 'source', 'date_published'} + fields = set(self.__id_fields) | { + "text", + "title", + "url", + "source", + "date_published", + } return sorted([field for field in fields if not getattr(self, field, None)]) def verify_id(self): @@ -142,13 +132,21 @@ def add_meta(self, key: str, val): # TODO: verify that this actually updates the meta column; # https://amercader.net/blog/beware-of-json-fields-in-sqlalchemy/ + def append_comment(self, comment: str): + """Appends a comment to the article.comments field. 
You must run session.commit() to save the comment to the database.""" + if self.comments is None: + self.comments = "" + self.comments = f"{self.comments}\n\n{comment}".strip() + @hybrid_property def is_valid(self): - return ( - self.text and self.text.strip() and - self.url and self.title and - self.authors is not None and - self.status == OK_STATUS + return bool( + self.text + and self.text.strip() + and self.url + and self.title + and self.authors is not None + and self.status == OK_STATUS ) @is_valid.expression @@ -166,9 +164,11 @@ def before_write(cls, _mapper, _connection, target: "Article"): target.verify_id_fields() if not target.status and target.missing_fields: - target.status = 'Missing fields' + target.status = "Missing fields" target.comments = f'missing fields: {", ".join(target.missing_fields)}' + target.pinecone_update_required = target.is_valid + if target.id: target.verify_id() else: diff --git a/align_data/db/session.py b/align_data/db/session.py index f388cf73..75040812 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -2,31 +2,45 @@ import logging from contextlib import contextmanager -from sqlalchemy import create_engine +from sqlalchemy import create_engine, or_ from sqlalchemy.orm import Session - -from align_data.settings import DB_CONNECTION_URI +from align_data.settings import DB_CONNECTION_URI, MIN_CONFIDENCE from align_data.db.models import Article logger = logging.getLogger(__name__) +# We create a single engine for the entire application +engine = create_engine(DB_CONNECTION_URI, echo=False) + @contextmanager -def make_session(auto_commit: bool = False) -> Generator[Session, None, None]: - engine = create_engine(DB_CONNECTION_URI, echo=False) +def make_session(auto_commit=False): with Session(engine, autoflush=False) as session: yield session if auto_commit: session.commit() -def stream_pinecone_updates(session: Session, custom_sources: List[str]) -> Generator[Article, None, None]: +def stream_pinecone_updates( + session: Session, custom_sources: List[str], force_update: bool = False +): """Yield Pinecone entries that require an update.""" yield from ( - session - .query(Article) - .filter(Article.pinecone_update_required.is_(True)) + session.query(Article) + .filter(or_(Article.pinecone_update_required.is_(True), force_update)) .filter(Article.is_valid) .filter(Article.source.in_(custom_sources)) + .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) .yield_per(1000) ) + + +def get_all_valid_article_ids(session: Session) -> List[str]: + """Return all valid article IDs.""" + query_result = ( + session.query(Article.id) + .filter(Article.is_valid) + .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) + .all() + ) + return [item[0] for item in query_result] diff --git a/align_data/embeddings/embedding_utils.py b/align_data/embeddings/embedding_utils.py new file mode 100644 index 00000000..6ef57b88 --- /dev/null +++ b/align_data/embeddings/embedding_utils.py @@ -0,0 +1,199 @@ +import logging +from typing import List, Tuple, Dict, Any, Optional +from functools import wraps + +import openai +from langchain.embeddings import HuggingFaceEmbeddings +from openai.error import ( + OpenAIError, + RateLimitError, + APIError, +) +from tenacity import ( + retry, + stop_after_attempt, + wait_random_exponential, + retry_if_exception_type, + retry_if_exception, +) + +from align_data.embeddings.pinecone.pinecone_models import MissingEmbeddingModelError +from align_data.settings import ( + 
USE_OPENAI_EMBEDDINGS, + OPENAI_EMBEDDINGS_MODEL, + EMBEDDING_LENGTH_BIAS, + SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, + DEVICE, +) + + +# -------------------- +# CONSTANTS & CONFIGURATION +# -------------------- + +logger = logging.getLogger(__name__) + +hf_embedding_model = None +if not USE_OPENAI_EMBEDDINGS: + hf_embedding_model = HuggingFaceEmbeddings( + model_name=SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, + model_kwargs={"device": DEVICE}, + encode_kwargs={"show_progress_bar": False}, + ) + +ModerationInfoType = Dict[str, Any] + + +# -------------------- +# DECORATORS +# -------------------- + + +def handle_openai_errors(func): + """Decorator to handle OpenAI-specific exceptions with retries.""" + + @wraps(func) + @retry( + wait=wait_random_exponential(multiplier=1, min=2, max=30), + stop=stop_after_attempt(6), + retry=retry_if_exception_type(RateLimitError) + | retry_if_exception_type(APIError) + | retry_if_exception(lambda e: "502" in str(e)), + ) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except RateLimitError as e: + logger.warning(f"OpenAI Rate limit error. Trying again. Error: {e}") + raise + except APIError as e: + if "502" in str(e): + logger.warning(f"OpenAI 502 Bad Gateway error. Trying again. Error: {e}") + else: + logger.error(f"OpenAI API Error encountered: {e}") + raise + except OpenAIError as e: + logger.error(f"OpenAI Error encountered: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error encountered: {e}") + raise + + return wrapper + + +# -------------------- +# MAIN FUNCTIONS +# -------------------- + + +@handle_openai_errors +def moderation_check(texts: List[str]) -> List[ModerationInfoType]: + return openai.Moderation.create(input=texts)["results"] + + +@handle_openai_errors +def _compute_openai_embeddings(non_flagged_texts: List[str], **kwargs) -> List[List[float]]: + data = openai.Embedding.create(input=non_flagged_texts, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data + return [d["embedding"] for d in data] + + +def get_embeddings_without_moderation( + texts: List[str], + source: Optional[str] = None, + **kwargs, +) -> List[List[float]]: + """ + Obtain embeddings without moderation checks. + + Parameters: + - texts (List[str]): List of texts to be embedded. + - source (Optional[str], optional): Source identifier to potentially adjust embedding bias. Defaults to None. + - **kwargs: Additional keyword arguments passed to the embedding function. + + Returns: + - List[List[float]]: List of embeddings for the provided texts. + """ + if not texts: + return [] + + texts = [text.replace("\n", " ") for text in texts] + if USE_OPENAI_EMBEDDINGS: + embeddings = _compute_openai_embeddings(texts, **kwargs) + elif hf_embedding_model: + embeddings = hf_embedding_model.embed_documents(texts) + else: + raise MissingEmbeddingModelError("No embedding model available.") + + # Bias adjustment + if source and (bias := EMBEDDING_LENGTH_BIAS.get(source, 1.0)): + embeddings = [[bias * e for e in embedding] for embedding in embeddings] + + return embeddings + + +def get_embeddings_or_none_if_flagged( + texts: List[str], + source: Optional[str] = None, + **kwargs, +) -> Tuple[List[List[float]] | None, List[ModerationInfoType]]: + """ + Obtain embeddings for the provided texts. If any text is flagged during moderation, + the function returns None for the embeddings while still providing the moderation results. + + Parameters: + - texts (List[str]): List of texts to be embedded. 
+ - source (Optional[str], optional): Source identifier to potentially adjust embedding bias. Defaults to None. + - **kwargs: Additional keyword arguments passed to the embedding function. + + Returns: + - Tuple[Optional[List[List[float]]], ModerationInfoListType]: Tuple containing the list of embeddings (or None if any text is flagged) and the moderation results. + """ + moderation_results = moderation_check(texts) + if any(result["flagged"] for result in moderation_results): + return None, moderation_results + + embeddings = get_embeddings_without_moderation(texts, source, **kwargs) + return embeddings, moderation_results + + +def get_embeddings( + texts: List[str], + source: Optional[str] = None, + **kwargs, +) -> Tuple[List[List[float] | None], List[ModerationInfoType]]: + """ + Obtain embeddings for the provided texts, replacing the embeddings of flagged texts with `None`. + + Parameters: + - texts (List[str]): List of texts to be embedded. + - source (Optional[str], optional): Source identifier to potentially adjust embedding bias. Defaults to None. + - **kwargs: Additional keyword arguments passed to the embedding function. + + Returns: + - Tuple[List[Optional[List[float]]], ModerationInfoListType]: Tuple containing the list of embeddings (with None for flagged texts) and the moderation results. + """ + assert len(texts) <= 2048, "The batch size should not be larger than 2048." + assert all(texts), "No empty strings allowed in the input list." + + # replace newlines, which can negatively affect performance + texts = [text.replace("\n", " ") for text in texts] + + # Check all texts for moderation flags + moderation_results = moderation_check(texts) + flags = [result["flagged"] for result in moderation_results] + + non_flagged_texts = [text for text, flag in zip(texts, flags) if not flag] + non_flagged_embeddings = get_embeddings_without_moderation( + non_flagged_texts, source, **kwargs + ) + embeddings = [None if flag else non_flagged_embeddings.pop(0) for flag in flags] + return embeddings, moderation_results + + +def get_embedding( + text: str, source: Optional[str] = None, **kwargs +) -> Tuple[List[float] | None, ModerationInfoType]: + """Obtain an embedding for a single text.""" + embedding, moderation_result = get_embeddings([text], source, **kwargs) + return embedding[0], moderation_result[0] diff --git a/align_data/embeddings/finetuning/data/best_finetuned_model.pth b/align_data/embeddings/finetuning/data/best_finetuned_model.pth new file mode 100644 index 00000000..e05a5d52 Binary files /dev/null and b/align_data/embeddings/finetuning/data/best_finetuned_model.pth differ diff --git a/align_data/embeddings/finetuning/data/finetuned_model.pth b/align_data/embeddings/finetuning/data/finetuned_model.pth new file mode 100644 index 00000000..7ba44ac6 Binary files /dev/null and b/align_data/embeddings/finetuning/data/finetuned_model.pth differ diff --git a/align_data/embeddings/finetuning/finetuning_dataset.py b/align_data/embeddings/finetuning/finetuning_dataset.py new file mode 100644 index 00000000..8c5eec04 --- /dev/null +++ b/align_data/embeddings/finetuning/finetuning_dataset.py @@ -0,0 +1,105 @@ +import math +import random +from typing import List, Tuple, Generator +from collections import deque + +import torch +from torch.utils.data import IterableDataset, get_worker_info +from sqlalchemy.exc import OperationalError +from sqlalchemy.sql import func +from sqlalchemy.orm import Session + +from align_data.db.session import make_session, get_all_valid_article_ids +from 
align_data.embeddings.embedding_utils import get_embedding +from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB +from align_data.embeddings.text_splitter import ParagraphSentenceUnitTextSplitter +from align_data.embeddings.pinecone.update_pinecone import get_text_chunks +from align_data.db.models import Article + + +class FinetuningDataset(IterableDataset): + def __init__(self, num_batches_per_epoch: int, cache_size: int = 1280): + self.num_batches_per_epoch = num_batches_per_epoch + self.article_cache = deque(maxlen=cache_size) + + self.text_splitter = ParagraphSentenceUnitTextSplitter() + self.pinecone_db = PineconeDB() + + with make_session() as session: + self.all_article_ids = get_all_valid_article_ids(session) + self.total_articles = len(self.all_article_ids) + + def __len__(self): + return self.num_batches_per_epoch + + def __iter__(self): + start, end = 0, None + worker_info = get_worker_info() + if worker_info is not None: # Multi-process loading + per_worker = math.ceil(self.total_articles / worker_info.num_workers) + start = worker_info.id * per_worker + end = min(start + per_worker, self.total_articles) + + with make_session() as session: + return self._generate_pairs(session, start, end) + + def _fetch_random_articles(self, session: Session, batch_size: int = 1) -> List[Article]: + """Fetch a batch of random articles.""" + # If the list has fewer IDs than needed, raise an exception + random_selected_ids = random.sample(self.all_article_ids, batch_size) + return session.query(Article).filter(Article.id.in_(random_selected_ids)).all() + + def _get_random_chunks(self, article: Article, num_chunks: int = 2) -> List[Tuple[int, str]]: + chunked_text = get_text_chunks(article, self.text_splitter) + + chunks = list(enumerate(chunked_text)) + if len(chunks) < num_chunks: + return [] + + return random.sample(chunks, num_chunks) + + def _get_embeddings(self, article: Article, chunks: List[Tuple[int, str]]) -> List[List[float]]: + full_ids = [f"{article.id}_{str(idx).zfill(6)}" for idx, _ in chunks] + _embeddings = self.pinecone_db.get_embeddings_by_ids(full_ids) + + embeddings = [] + for (_, chunk), (_, embedding) in zip(chunks, _embeddings): + if embedding is None: + embedding, _ = get_embedding(chunk, article.source) + embeddings.append(torch.tensor(embedding)) + + return embeddings + + def _generate_pairs( + self, session, start=0, end=None, neg_pos_proportion=0.5 + ) -> Generator[Tuple[List[float], List[float], int], None, None]: + end = end or self.total_articles + + batches_yielded = 0 + while start < end: + start += 1 + if random.random() < neg_pos_proportion: + # Positive pairs + article = self._fetch_random_articles(session)[0] + chunks = self._get_random_chunks(article, 2) + if not chunks: + continue + embedding_1, embedding_2 = self._get_embeddings(article, chunks) + label = 1 + else: + # Negative pairs + article1, article2 = self._fetch_random_articles(session, batch_size=2) + chunk1 = self._get_random_chunks(article1, 1) + chunk2 = self._get_random_chunks(article2, 1) + embedding_1, embedding_2 = ( + self._get_embeddings(article1, chunk1)[0], + self._get_embeddings(article2, chunk2)[0], + ) + label = 0 + yield torch.tensor(embedding_1, dtype=torch.int64), torch.tensor( + embedding_2, dtype=torch.int64 + ), torch.tensor(label, dtype=torch.int64) + batches_yielded += 1 + + if self.num_batches_per_epoch and batches_yielded >= self.num_batches_per_epoch: + break diff --git a/align_data/embeddings/finetuning/training.py 
b/align_data/embeddings/finetuning/training.py new file mode 100644 index 00000000..cd1d7845 --- /dev/null +++ b/align_data/embeddings/finetuning/training.py @@ -0,0 +1,177 @@ +import os + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.utils.data import DataLoader + +from align_data.embeddings.finetuning.finetuning_dataset import FinetuningDataset +from align_data.settings import ( + PINECONE_VALUES_DIMS, + DEVICE, + OPENAI_FINETUNED_LAYER_PATH, + OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH, +) + + +class ContrastiveLoss(nn.Module): + def __init__(self, margin=2.0): + super(ContrastiveLoss, self).__init__() + self.margin = margin + + def forward(self, output1, output2, label): + euclidean_distance = nn.functional.pairwise_distance(output1, output2) + loss_contrastive = torch.mean( + (1 - label) * torch.pow(euclidean_distance, 2) + + (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2) + ) + + return loss_contrastive + + +class NonLinearFineTuneModel(nn.Module): + def __init__(self, embedding_dim=PINECONE_VALUES_DIMS, hidden_dim=2000, dropout=0.5): + super(FineTuneModel, self).__init__() + + self.fc1 = nn.Linear(embedding_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, embedding_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = nn.functional.relu(self.fc1(x)) + x = self.dropout(x) + x = self.fc2(x) + return x + + +class FineTuneModel(nn.Module): + def __init__(self, embedding_dim=PINECONE_VALUES_DIMS): + super(FineTuneModel, self).__init__() + + self.fc = nn.Linear(embedding_dim, embedding_dim) + + def forward(self, x): + x = self.fc(x) + return x + + +def train(model, dataloader, optimizer, criterion): + model.train() + total_loss = 0.0 + + for batch_idx, (text1_embedding, text2_embedding, target) in enumerate(dataloader): + text1_embedding = text1_embedding.to(DEVICE) + text2_embedding = text2_embedding.to(DEVICE) + target = target.float().to(DEVICE) + + optimizer.zero_grad() + + output1 = model(text1_embedding) + output2 = model(text2_embedding) + + loss = criterion(output1, output2, target) + loss.backward() + + # Gradient clipping + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + optimizer.step() + + total_loss += loss.item() + + return total_loss / len(dataloader) + + +def validate(model, dataloader, criterion): + model.eval() + total_loss = 0.0 + + with torch.no_grad(): + for batch_idx, (text1_embedding, text2_embedding, target) in enumerate(dataloader): + text1_embedding = text1_embedding.to(DEVICE) + text2_embedding = text2_embedding.to(DEVICE) + target = target.float().to(DEVICE) + + output1 = model(text1_embedding) + output2 = model(text2_embedding) + + loss = criterion(output1, output2, target) + total_loss += loss.item() + + return total_loss / len(dataloader) + + +def finetune_embeddings(): + # Hyperparameters & Configuration + EPOCHS = 100 + BATCH_PER_EPOCH = 20 + BATCH_SIZE = 64 + LEARNING_RATE = 5.0000e-02 + MARGIN = 2.0 + + dataset = FinetuningDataset(num_batches_per_epoch=BATCH_PER_EPOCH) + dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=5) + + model = FineTuneModel().to(DEVICE) + model = load_best_model_if_exists(model) + optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE) + scheduler = ReduceLROnPlateau(optimizer, "min", patience=2, factor=0.5, verbose=True) + criterion = ContrastiveLoss(MARGIN) + + # Assuming you've split your data and have a separate validation set + validation_dataset = 
FinetuningDataset(num_batches_per_epoch=BATCH_PER_EPOCH) + validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, num_workers=5) + best_val_loss = validate(model, validation_dataloader, criterion) + print(f"Initial validation loss (from loaded model or new model): {best_val_loss:.4f}") + + epochs_without_improvement = 0 + max_epochs_without_improvement = 15 # stop after 5 epochs without improvement + + for epoch in range(EPOCHS): + train_loss = train(model, dataloader, optimizer, criterion) + validate_loss = validate(model, validation_dataloader, criterion) + + scheduler.step(validate_loss) + if validate_loss < best_val_loss: + best_val_loss = validate_loss + torch.save(model.state_dict(), OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH) + epochs_without_improvement = 0 + else: + epochs_without_improvement += 1 + + print( + f"Epoch: {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Validation Loss: {validate_loss:.4f}" + ) + + if epochs_without_improvement >= max_epochs_without_improvement: + print("Early stopping due to no improvement in validation loss.") + break + + torch.save(model.state_dict(), OPENAI_FINETUNED_LAYER_PATH) + + +### HELPER FUNCTIONS ### + + +def load_best_model_if_exists(model): + """ + Load the best saved model if it exists. + + Parameters: + - model (torch.nn.Module): The model architecture. + + Returns: + - model (torch.nn.Module): The loaded model. + """ + if os.path.exists(OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH): + model.load_state_dict( + torch.load(OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH, map_location=DEVICE) + ) + print(f"Loaded model from {OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH}.") + else: + print( + f"No saved model found at {OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH}. Starting from scratch." + ) + + return model diff --git a/align_data/pinecone/__init__.py b/align_data/embeddings/pinecone/__init__.py similarity index 100% rename from align_data/pinecone/__init__.py rename to align_data/embeddings/pinecone/__init__.py diff --git a/align_data/embeddings/pinecone/pinecone_db_handler.py b/align_data/embeddings/pinecone/pinecone_db_handler.py new file mode 100644 index 00000000..dd2990d3 --- /dev/null +++ b/align_data/embeddings/pinecone/pinecone_db_handler.py @@ -0,0 +1,142 @@ +# dataset/pinecone_db_handler.py +import logging +from typing import List, Tuple + +import pinecone +from pinecone.core.client.models import ScoredVector + +from align_data.embeddings.embedding_utils import get_embedding +from align_data.embeddings.pinecone.pinecone_models import ( + PineconeEntry, + PineconeMetadata, +) +from align_data.settings import ( + PINECONE_INDEX_NAME, + PINECONE_VALUES_DIMS, + PINECONE_METRIC, + PINECONE_API_KEY, + PINECONE_ENVIRONMENT, + PINECONE_NAMESPACE, +) + + +logger = logging.getLogger(__name__) + + +class PineconeDB: + def __init__( + self, + index_name: str = PINECONE_INDEX_NAME, + values_dims: int = PINECONE_VALUES_DIMS, + metric: str = PINECONE_METRIC, + create_index: bool = False, + log_index_stats: bool = False, + ): + self.index_name = index_name + self.values_dims = values_dims + self.metric = metric + + pinecone.init( + api_key=PINECONE_API_KEY, + environment=PINECONE_ENVIRONMENT, + ) + + if create_index: + self.create_index() + + self.index = pinecone.Index(index_name=self.index_name) + + if log_index_stats: + index_stats_response = self.index.describe_index_stats() + logger.info(f"{self.index_name}:\n{index_stats_response}") + + def upsert_entry(self, pinecone_entry: PineconeEntry, upsert_size: int = 100): + vectors = 
pinecone_entry.create_pinecone_vectors() + self.index.upsert(vectors=vectors, batch_size=upsert_size, namespace=PINECONE_NAMESPACE) + + def query_vector( + self, + query: List[float], + top_k: int = 10, + include_values: bool = False, + include_metadata: bool = True, + **kwargs, + ) -> List[ScoredVector]: + assert not isinstance( + query, str + ), "query must be a list of floats. Use query_PineconeDB_text for text queries" + + query_response = self.index.query( + vector=query, + top_k=top_k, + include_values=include_values, + include_metadata=include_metadata, + **kwargs, + namespace=PINECONE_NAMESPACE, + ) + + return [ + ScoredVector( + id=match["id"], + score=match["score"], + metadata=PineconeMetadata(**match["metadata"]), + ) + for match in query_response["matches"] + ] + + def query_text( + self, + query: str, + top_k: int = 10, + include_values: bool = False, + include_metadata: bool = True, + **kwargs, + ) -> List[ScoredVector]: + query_vector = get_embedding(query)[0] + return self.query_vector( + query=query_vector, + top_k=top_k, + include_values=include_values, + include_metadata=include_metadata, + **kwargs, + ) + + def delete_entries(self, ids): + self.index.delete(filter={"hash_id": {"$in": ids}}) + + def create_index(self, replace_current_index: bool = True): + if replace_current_index: + self.delete_index() + + pinecone.create_index( + name=self.index_name, + dimension=self.values_dims, + metric=self.metric, + metadata_config={"indexed": list(PineconeMetadata.__annotations__.keys())}, + ) + + def delete_index(self): + if self.index_name in pinecone.list_indexes(): + logger.info(f"Deleting index '{self.index_name}'.") + pinecone.delete_index(self.index_name) + + def get_embeddings_by_ids(self, ids: List[str]) -> List[Tuple[str, List[float] | None]]: + """ + Fetch embeddings for given entry IDs from Pinecone. + + Args: + - ids (List[str]): List of entry IDs for which embeddings are to be fetched. + + Returns: + - List[Tuple[str, List[float] | None]]: List of tuples containing ID and its corresponding embedding. 
+ """ + # TODO: check that this still works + vectors = self.index.fetch( + ids=ids, + namespace=PINECONE_NAMESPACE, + )["vectors"] + return [(id, vectors.get(id, {}).get("values", None)) for id in ids] + + +def strip_block(text: str) -> str: + return "\n".join(text.split("\n")[1:]) diff --git a/align_data/embeddings/pinecone/pinecone_models.py b/align_data/embeddings/pinecone/pinecone_models.py new file mode 100644 index 00000000..fd7b67eb --- /dev/null +++ b/align_data/embeddings/pinecone/pinecone_models.py @@ -0,0 +1,77 @@ +from typing import List, TypedDict + +from pydantic import BaseModel, validator +from pinecone.core.client.models import Vector + + +class MissingFieldsError(Exception): + pass + + +class MissingEmbeddingModelError(Exception): + pass + + +class PineconeMetadata(TypedDict): + hash_id: str + source: str + title: str + url: str + date_published: float + authors: List[str] + text: str + + +class PineconeEntry(BaseModel): + hash_id: str + source: str + title: str + url: str + date_published: float + authors: List[str] + text_chunks: List[str] + embeddings: List[List[float]] + + def __init__(self, **data): + """Check for missing (falsy) fields before initializing.""" + missing_fields = [field for field, value in data.items() if not str(value).strip()] + + if missing_fields: + raise MissingFieldsError(f"Missing fields: {missing_fields}") + + super().__init__(**data) + + def __repr__(self): + def make_small(chunk: str) -> str: + return (chunk[:45] + " [...] " + chunk[-45:]) if len(chunk) > 100 else chunk + + def display_chunks(chunks_lst: List[str]) -> str: + chunks = ", ".join(f'"{make_small(chunk)}"' for chunk in chunks_lst) + return ( + f"[{chunks[:450]} [...] {chunks[-450:]} ]" if len(chunks) > 1000 else f"[{chunks}]" + ) + + return f"PineconeEntry(hash_id={self.hash_id!r}, source={self.source!r}, title={self.title!r}, url={self.url!r}, date_published={self.date_published!r}, authors={self.authors!r}, text_chunks={display_chunks(self.text_chunks)})" + + @property + def chunk_num(self) -> int: + return len(self.text_chunks) + + def create_pinecone_vectors(self) -> List[Vector]: + return [ + Vector( + id=f"{self.hash_id}_{str(i).zfill(6)}", + values=self.embeddings[i], + metadata=PineconeMetadata( + hash_id=self.hash_id, + source=self.source, + title=self.title, + authors=self.authors, + url=self.url, + date_published=self.date_published, + text=self.text_chunks[i], + ), + ) + for i in range(self.chunk_num) + if self.embeddings[i] # Skips flagged chunks + ] diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py new file mode 100644 index 00000000..b425ee9d --- /dev/null +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -0,0 +1,129 @@ +from datetime import datetime +import logging +from itertools import islice +from typing import Callable, List, Tuple, Generator, Iterator, Optional + +from sqlalchemy.orm import Session +from pydantic import ValidationError + +from align_data.embeddings.embedding_utils import get_embeddings +from align_data.db.models import Article +from align_data.db.session import make_session, stream_pinecone_updates +from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB +from align_data.embeddings.pinecone.pinecone_models import ( + PineconeEntry, MissingFieldsError, MissingEmbeddingModelError +) +from align_data.embeddings.text_splitter import ParagraphSentenceUnitTextSplitter + + +logger = logging.getLogger(__name__) + + +# Define type aliases for the Callables 
+LengthFunctionType = Callable[[str], int] +TruncateFunctionType = Callable[[str, int], str] + + +class PineconeUpdater: + def __init__(self): + self.text_splitter = ParagraphSentenceUnitTextSplitter() + self.pinecone_db = PineconeDB() + + def update(self, custom_sources: List[str], force_update: bool = False): + """ + Update the given sources. If no sources are provided, updates all sources. + + :param custom_sources: List of sources to update. + """ + with make_session() as session: + articles_to_update_stream = stream_pinecone_updates( + session, custom_sources, force_update + ) + for batch in self.batch_entries(articles_to_update_stream): + self.save_batch(session, batch) + + def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry]]): + try: + for article, pinecone_entry in batch: + self.pinecone_db.upsert_entry(pinecone_entry) + + article.pinecone_update_required = False + session.add(article) + + session.commit() + + except Exception as e: + # Rollback on any kind of error. The next run will redo this batch, but in the meantime keep trucking + logger.error(e) + session.rollback() + + def batch_entries( + self, article_stream: Generator[Article, None, None] + ) -> Iterator[List[Tuple[Article, PineconeEntry]]]: + while batch := tuple(islice(article_stream, 10)): + yield [ + (article, pinecone_entry) + for article in batch + if (pinecone_entry := self._make_pinecone_entry(article)) is not None + ] + + def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None: + try: + text_chunks = get_text_chunks(article, self.text_splitter) + embeddings, moderation_results = get_embeddings(text_chunks, article.source) + + if any(result['flagged'] for result in moderation_results): + flagged_text_chunks = [f"Chunk {i}: \"{text}\"" for i, (text, result) in enumerate(zip(text_chunks, moderation_results)) if result["flagged"]] + logger.warning(f"OpenAI moderation flagged text chunks for the following article: {article.id}") + article.append_comment(f"OpenAI moderation flagged the following text chunks: {flagged_text_chunks}") + + return PineconeEntry( + hash_id=article.id, # the hash_id of the article + source=article.source, + title=article.title, + url=article.url, + date_published=article.date_published.timestamp(), + authors=[author.strip() for author in article.authors.split(",") if author.strip()], + text_chunks=text_chunks, + embeddings=embeddings, + ) + except (ValueError, TypeError, AttributeError, ValidationError, MissingFieldsError, MissingEmbeddingModelError) as e: + logger.warning(e) + article.append_comment(f"Error encountered while processing this article: {e}") + return None + + except Exception as e: + logger.error(e) + raise + + +def get_text_chunks( + article: Article, text_splitter: ParagraphSentenceUnitTextSplitter +) -> List[str]: + title = article.title.replace("\n", " ") + + authors_lst = [author.strip() for author in article.authors.split(",")] + authors = get_authors_str(authors_lst) + + signature = f"Title: {title}; Author(s): {authors}." 
+ text_chunks = text_splitter.split_text(article.text) + return [f'###{signature}###\n"""{text_chunk}"""' for text_chunk in text_chunks] + + +def get_authors_str(authors_lst: List[str]) -> str: + if not authors_lst: + return "n/a" + + if len(authors_lst) == 1: + authors_str = authors_lst[0] + else: + authors_lst = authors_lst[:4] + authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}" + + authors_str = authors_str.replace("\n", " ") + + # Truncate if necessary + if len(authors_str) > 500: + authors_str = authors_str[:497] + "..." + + return authors_str diff --git a/align_data/pinecone/text_splitter.py b/align_data/embeddings/text_splitter.py similarity index 89% rename from align_data/pinecone/text_splitter.py rename to align_data/embeddings/text_splitter.py index b8af09a3..a364415e 100644 --- a/align_data/pinecone/text_splitter.py +++ b/align_data/embeddings/text_splitter.py @@ -1,5 +1,3 @@ -# dataset/text_splitter.py - from typing import List, Callable, Any from langchain.text_splitter import TextSplitter from nltk.tokenize import sent_tokenize @@ -11,8 +9,6 @@ StrToIntFunction = Callable[[str], int] StrIntBoolToStrFunction = Callable[[str, int, bool], str] -def default_truncate_function(string: str, length: int, from_end: bool = False) -> str: - return string[-length:] if from_end else string[:length] def default_truncate_function(string: str, length: int, from_end: bool = False) -> str: return string[-length:] if from_end else string[:length] @@ -26,7 +22,7 @@ class ParagraphSentenceUnitTextSplitter(TextSplitter): @param length_function: A function that returns the length of a string in units. Defaults to len(). @param truncate_function: A function that truncates a string to a given unit length. """ - + DEFAULT_MIN_CHUNK_SIZE: int = 900 DEFAULT_MAX_CHUNK_SIZE: int = 1100 DEFAULT_LENGTH_FUNCTION: StrToIntFunction = len @@ -38,7 +34,7 @@ def __init__( max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, length_function: StrToIntFunction = DEFAULT_LENGTH_FUNCTION, truncate_function: StrIntBoolToStrFunction = DEFAULT_TRUNCATE_FUNCTION, - **kwargs: Any + **kwargs: Any, ): super().__init__(**kwargs) self.min_chunk_size = min_chunk_size @@ -49,6 +45,9 @@ def __init__( def split_text(self, text: str) -> List[str]: """Split text into chunks of length between min_chunk_size and max_chunk_size.""" + if not text: + return [] + blocks: List[str] = [] current_block: str = "" @@ -90,13 +89,11 @@ def _handle_large_paragraph(self, current_block: str, blocks: List[str], paragra def _truncate_large_block(self, current_block: str, blocks: List[str]) -> str: while self._length_function(current_block) > self.max_chunk_size: # Truncate current_block to max size, set remaining text as current_block - truncated_block = self._truncate_function( - current_block, self.max_chunk_size - ) + truncated_block = self._truncate_function(current_block, self.max_chunk_size, False) blocks.append(truncated_block) - current_block = current_block[len(truncated_block):].lstrip() - + current_block = current_block[len(truncated_block) :].lstrip() + return current_block def _handle_remaining_text(self, last_block: str, blocks: List[str]) -> List[str]: @@ -107,9 +104,7 @@ def _handle_remaining_text(self, last_block: str, blocks: List[str]) -> List[str if self.min_chunk_size - len_last_block > 0: # Add text from previous block to last block if last_block is too short part_prev_block = self._truncate_function( - string=blocks[-1], - length=self.min_chunk_size - len_last_block, - from_end=True + blocks[-1], 
self.min_chunk_size - len_last_block, True ) last_block = part_prev_block + last_block diff --git a/align_data/pinecone/pinecone_db_handler.py b/align_data/pinecone/pinecone_db_handler.py deleted file mode 100644 index d8f565df..00000000 --- a/align_data/pinecone/pinecone_db_handler.py +++ /dev/null @@ -1,91 +0,0 @@ -# dataset/pinecone_db_handler.py - -import logging -from typing import Dict - -import pinecone - -from align_data.settings import ( - PINECONE_INDEX_NAME, - PINECONE_VALUES_DIMS, - PINECONE_METRIC, - PINECONE_METADATA_KEYS, - PINECONE_API_KEY, - PINECONE_ENVIRONMENT, -) - - -logger = logging.getLogger(__name__) - - -class PineconeDB: - def __init__( - self, - index_name: str = PINECONE_INDEX_NAME, - values_dims: int = PINECONE_VALUES_DIMS, - metric: str = PINECONE_METRIC, - metadata_keys: list = PINECONE_METADATA_KEYS, - create_index: bool = False, - log_index_stats: bool = True, - ): - self.index_name = index_name - self.values_dims = values_dims - self.metric = metric - self.metadata_keys = metadata_keys - - pinecone.init( - api_key=PINECONE_API_KEY, - environment=PINECONE_ENVIRONMENT, - ) - - if create_index: - self.create_index() - - self.index = pinecone.Index(index_name=self.index_name) - - if log_index_stats: - index_stats_response = self.index.describe_index_stats() - logger.info(f"{self.index_name}:\n{index_stats_response}") - - def upsert_entry(self, entry: Dict, upsert_size=100): - self.index.upsert( - vectors=list( - zip( - [ - f"{entry['id']}_{str(i).zfill(6)}" - for i in range(len(entry["text_chunks"])) - ], - entry["embeddings"].tolist(), - [ - { - "entry_id": entry["id"], - "source": entry["source"], - "title": entry["title"], - "authors": entry["authors"], - "text": text_chunk, - } - for text_chunk in entry["text_chunks"] - ], - ) - ), - batch_size=upsert_size, - ) - - def delete_entries(self, ids): - self.index.delete(filter={"entry_id": {"$in": ids}}) - - def create_index(self, replace_current_index: bool = True): - if replace_current_index: - self.delete_index() - - pinecone.create_index( - name=self.index_name, - dimension=self.values_dims, - metric=self.metric, - metadata_config={"indexed": self.metadata_keys}, - ) - - def delete_index(self): - if self.index_name in pinecone.list_indexes(): - logger.info(f"Deleting index '{self.index_name}'.") - pinecone.delete_index(self.index_name) diff --git a/align_data/pinecone/update_pinecone.py b/align_data/pinecone/update_pinecone.py deleted file mode 100644 index 5821a276..00000000 --- a/align_data/pinecone/update_pinecone.py +++ /dev/null @@ -1,192 +0,0 @@ -from datetime import datetime -import logging -import numpy as np -import os -from itertools import islice -from typing import Callable, List, Tuple, Generator - -import openai -from pydantic import BaseModel, ValidationError, validator - -from align_data.db.models import Article -from align_data.db.session import make_session, stream_pinecone_updates -from align_data.pinecone.pinecone_db_handler import PineconeDB -from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter -from align_data.settings import ( - USE_OPENAI_EMBEDDINGS, - OPENAI_EMBEDDINGS_MODEL, - OPENAI_EMBEDDINGS_DIMS, - OPENAI_EMBEDDINGS_RATE_LIMIT, - SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, - SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS, - CHUNK_SIZE, - MAX_NUM_AUTHORS_IN_SIGNATURE, - EMBEDDING_LENGTH_BIAS, -) - -logger = logging.getLogger(__name__) - - -# Define type aliases for the Callables -LengthFunctionType = Callable[[str], int] -TruncateFunctionType = Callable[[str, int], 
str] - - -class PineconeEntry(BaseModel): - id: str - source: str - title: str - url: str - date_published: datetime - authors: List[str] - text_chunks: List[str] - embeddings: np.ndarray - - class Config: - arbitrary_types_allowed = True - - def __repr__(self): - return f"PineconeEntry(id={self.id!r}, source={self.source!r}, title={self.title!r}, url={self.url!r}, date_published={self.date_published!r}, authors={self.authors!r}, text_chunks={self.text_chunks[:5]!r})" - - @validator( - "id", - "source", - "title", - "url", - "date_published", - "authors", - "text_chunks", - pre=True, - always=True, - ) - def empty_strings_not_allowed(cls, value): - if not str(value).strip(): - raise ValueError("Attribute should not be empty.") - return value - - -class PineconeUpdater: - def __init__( - self, - min_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MIN_CHUNK_SIZE, - max_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MAX_CHUNK_SIZE, - length_function: LengthFunctionType = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION, - truncate_function: TruncateFunctionType = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION, - ): - self.min_chunk_size = min_chunk_size - self.max_chunk_size = max_chunk_size - self.length_function = length_function - self.truncate_function = truncate_function - - self.text_splitter = ParagraphSentenceUnitTextSplitter( - min_chunk_size=self.min_chunk_size, - max_chunk_size=self.max_chunk_size, - length_function=self.length_function, - truncate_function=self.truncate_function, - ) - self.pinecone_db = PineconeDB() - - if USE_OPENAI_EMBEDDINGS: - import openai - - openai.api_key = os.environ["OPENAI_API_KEY"] - else: - import torch - from langchain.embeddings import HuggingFaceEmbeddings - - self.hf_embeddings = HuggingFaceEmbeddings( - model_name=SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, - model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"}, - encode_kwargs={"show_progress_bar": False}, - ) - - def save_batch(self, session, batch): - try: - for article, pinecone_entry in batch: - self.pinecone_db.upsert_entry(pinecone_entry.dict()) - article.pinecone_update_required = False - session.add(article) - session.commit() - except Exception as e: - # Rollback on any kind of error. The next run will redo this batch, but in the meantime keep trucking - logger.error(e) - session.rollback() - - def update(self, custom_sources: List[str]): - """ - Update the given sources. If no sources are provided, updates all sources. - - :param custom_sources: List of sources to update. 
- """ - with make_session() as session: - entries_stream = stream_pinecone_updates(session, custom_sources) - for batch in self.batch_entries(entries_stream): - self.save_batch(session, batch) - - def _make_pinecone_update(self, article): - try: - text_chunks = self.get_text_chunks(article) - return article, PineconeEntry( - id=article.id, - source=article.source, - title=article.title, - url=article.url, - date_published=article.date_published, - authors=[ - author.strip() - for author in article.authors.split(",") - if author.strip() - ], - text_chunks=text_chunks, - embeddings=self.extract_embeddings( - text_chunks, [article.source] * len(text_chunks) - ), - ) - except (ValueError, ValidationError) as e: - logger.exception(e) - - def batch_entries( - self, article_stream: Generator[Article, None, None] - ) -> Generator[List[Tuple[Article, PineconeEntry]], None, None]: - items = iter(article_stream) - while batch := tuple(islice(items, 10)): - yield list(filter(None, map(self._make_pinecone_update, batch))) - - def get_text_chunks(self, article: Article) -> List[str]: - signature = f"Title: {article.title}, Author(s): {self.get_authors_str(article.authors)}" - text_chunks = self.text_splitter.split_text(article.text) - text_chunks = [f"- {signature}\n\n{text_chunk}" for text_chunk in text_chunks] - return text_chunks - - def extract_embeddings(self, chunks_batch, sources_batch): - if USE_OPENAI_EMBEDDINGS: - return self.get_openai_embeddings(chunks_batch, sources_batch) - else: - return np.array( - self.hf_embeddings.embed_documents(chunks_batch, sources_batch) - ) - - @staticmethod - def get_openai_embeddings(chunks, sources=""): - embeddings = np.zeros((len(chunks), OPENAI_EMBEDDINGS_DIMS)) - - openai_output = openai.Embedding.create( - model=OPENAI_EMBEDDINGS_MODEL, input=chunks - )["data"] - - for i, (embedding, source) in enumerate(zip(openai_output, sources)): - bias = EMBEDDING_LENGTH_BIAS.get(source, 1.0) - embeddings[i] = bias * np.array(embedding["embedding"]) - - return embeddings - - @staticmethod - def get_authors_str(authors_lst: List[str]) -> str: - if authors_lst == []: - return "n/a" - if len(authors_lst) == 1: - return authors_lst[0] - else: - authors_lst = authors_lst[:MAX_NUM_AUTHORS_IN_SIGNATURE] - authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}" - return authors_str diff --git a/align_data/settings.py b/align_data/settings.py index d88bcca3..59304c08 100644 --- a/align_data/settings.py +++ b/align_data/settings.py @@ -1,5 +1,8 @@ import os import logging +from typing import Dict +import openai +import torch from dotenv import load_dotenv load_dotenv() @@ -24,9 +27,12 @@ "METADATA_OUTPUT_SPREADSHEET", "1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4" ) -### YouTube ### +### YOUTUBE ### YOUTUBE_API_KEY = os.environ.get("YOUTUBE_API_KEY") +### Airtable ### +AIRTABLE_API_KEY = os.environ.get("AIRTABLE_API_KEY") + ### MYSQL ### user = os.environ.get("ARD_DB_USER", "user") password = os.environ.get("ARD_DB_PASSWORD", "we all live in a yellow submarine") @@ -38,13 +44,16 @@ ### EMBEDDINGS ### USE_OPENAI_EMBEDDINGS = True # If false, SentenceTransformer embeddings will be used. -EMBEDDING_LENGTH_BIAS = { - "aisafety.info": 1.05, # In search, favor AISafety.info entries. +EMBEDDING_LENGTH_BIAS: Dict[str, float] = { + # TODO: Experiement with these values. For now, let's remove the bias. + # "aisafety.info": 1.05, # In search, favor AISafety.info entries. 
} OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002" OPENAI_EMBEDDINGS_DIMS = 1536 OPENAI_EMBEDDINGS_RATE_LIMIT = 3500 +openai.api_key = os.environ.get("OPENAI_API_KEY", None) +openai.organization = os.environ.get("OPENAI_ORGANIZATION", None) SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1" SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768 @@ -54,13 +63,20 @@ PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None) PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None) PINECONE_VALUES_DIMS = ( - OPENAI_EMBEDDINGS_DIMS - if USE_OPENAI_EMBEDDINGS - else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS + OPENAI_EMBEDDINGS_DIMS if USE_OPENAI_EMBEDDINGS else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS ) PINECONE_METRIC = "dotproduct" -PINECONE_METADATA_KEYS = ["entry_id", "source", "title", "authors", "text", "url"] +PINECONE_NAMESPACE = os.environ.get("PINECONE_NAMESPACE", "normal") # "normal" or "finetuned" + +### FINE-TUNING ### +OPENAI_FINETUNED_LAYER_PATH = os.environ.get( + "OPENAI_FINETUNED_LAYER_PATH", "align_data/finetuning/data/finetuned_model.pth" +) +OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH = os.environ.get( + "OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH", + "align_data/finetuning/data/best_finetuned_model.pth", +) ### MISCELLANEOUS ### -CHUNK_SIZE = 1750 -MAX_NUM_AUTHORS_IN_SIGNATURE = 3 +MIN_CONFIDENCE = 50 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/align_data/sources/agisf/__init__.py b/align_data/sources/agisf/__init__.py new file mode 100644 index 00000000..0fbf757a --- /dev/null +++ b/align_data/sources/agisf/__init__.py @@ -0,0 +1,36 @@ +from align_data.common.alignment_dataset import MultiDataset +from align_data.sources.airtable import AirtableDataset +from align_data.sources.agisf.agisf import AGISFPodcastDataset + + +datasets = [ + AirtableDataset( + name='agisf_governance', + base_id='app9q0E0jlDWlsR0z', + table_id='tblgTb3kszvSbo2Mb', + mappings={ + 'title': '[>] Resource', + 'url': '[h] [>] Link', + 'source_type': '[h] [>] Type', + 'summary': '[h] Resource guide', + 'authors': 'Author(s) (from Resources)', + }, + processors = { + 'source_type': lambda val: val[0] if val else None, + 'authors': lambda val: val and [v.strip() for v in val.split(',')] + } + ), + AGISFPodcastDataset( + name='agisf_readings_alignment', + url='https://feeds.type3.audio/agi-safety-fundamentals--alignment.rss', + ), + AGISFPodcastDataset( + name='agisf_readings_governance', + url='https://feeds.type3.audio/agi-safety-fundamentals--governance.rss', + ), +] + + +AGISF_DATASETS = [ + MultiDataset(name='agisf', datasets=datasets), +] diff --git a/align_data/sources/agisf/agisf.py b/align_data/sources/agisf/agisf.py new file mode 100644 index 00000000..e56a60e4 --- /dev/null +++ b/align_data/sources/agisf/agisf.py @@ -0,0 +1,54 @@ +import re +from typing import Any, Dict +from bs4 import BeautifulSoup + +from align_data.common.html_dataset import RSSDataset +from align_data.sources.articles.parsers import item_metadata +from align_data.sources.utils import merge_dicts + + +class AGISFPodcastDataset(RSSDataset): + + regex = re.compile(r'^\[Week .*?\]\s+“(?P.*?)”\s+by\s+(?P<authors>.*?)$') + + @property + def feed_url(self): + return self.url + + def fetch_contents(self, url: str) -> Dict[str, Any]: + contents = super().fetch_contents(url) + if extracted := self.regex.search(contents.get('title')): + return merge_dicts(contents, extracted.groupdict()) + return contents + + def _get_text(self, item): + contents = item_metadata(item['link']) + # 
Replace any non empty values in item. `item.update()` will happily insert Nones + for k, v in contents.items(): + if v: + item[k] = v + return item.get('text') + + def extract_authors(self, item): + authors = item.get("authors") + if not authors: + return self.authors + if isinstance(authors, str): + return [a.strip() for a in authors.split(',')] + return authors + + def _extra_values(self, contents): + if summary := contents.get('summary'): + soup = BeautifulSoup(summary, "html.parser") + for el in soup.select('b'): + if el.next_sibling: + el.next_sibling.extract() + el.extract() + return {'summary': self._extract_markdown(soup)} + return {} + + def process_entry(self, article): + article_url = self.get_item_key(article) + contents = self.get_contents(article_url) + + return self.make_data_entry(contents) diff --git a/align_data/sources/airtable.py b/align_data/sources/airtable.py new file mode 100644 index 00000000..db24350e --- /dev/null +++ b/align_data/sources/airtable.py @@ -0,0 +1,53 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, Iterable, Optional, Union + +from airtable import airtable +from align_data.db.models import Article + +from align_data.settings import AIRTABLE_API_KEY, ARTICLE_MAIN_KEYS +from align_data.common.alignment_dataset import AlignmentDataset +from align_data.sources.articles.parsers import item_metadata +from align_data.sources.utils import merge_dicts + + +@dataclass +class AirtableDataset(AlignmentDataset): + + base_id: str + table_id: str + mappings: Dict[str, str] + processors: Dict[str, Callable[[Any], str]] + done_key = 'url' + + def setup(self): + if not AIRTABLE_API_KEY: + raise ValueError('No AIRTABLE_API_KEY provided!') + super().setup() + self.at = airtable.Airtable(self.base_id, AIRTABLE_API_KEY) + + def map_cols(self, item: Dict[str, Dict[str, str]]) -> Optional[Dict[str, Optional[str]]]: + fields = item.get('fields', {}) + def map_col(k): + val = fields.get(self.mappings.get(k) or k) + if processor := self.processors.get(k): + val = processor(val) + return val + + mapped = {k: map_col(k) for k in ARTICLE_MAIN_KEYS + ['summary']} + if (mapped.get('url') or '').startswith('http'): + return mapped + + def get_item_key(self, item): + return item.get('url') + + @property + def items_list(self) -> Iterable[Dict[str, Union[str, None]]]: + return filter(None, map(self.map_cols, self.at.iterate(self.table_id))) + + def process_entry(self, entry) -> Optional[Article]: + contents = item_metadata(self.get_item_key(entry)) + if not contents: + return None + + entry['date_published'] = self._get_published_date(entry.get('date_published')) + return self.make_data_entry(merge_dicts(entry, contents), source=self.name) diff --git a/align_data/sources/articles/__init__.py b/align_data/sources/articles/__init__.py index da7f3a6b..6fd45fbc 100644 --- a/align_data/sources/articles/__init__.py +++ b/align_data/sources/articles/__init__.py @@ -1,10 +1,18 @@ from align_data.sources.articles.datasets import ( - ArxivPapers, EbookArticles, DocArticles, HTMLArticles, - MarkdownArticles, PDFArticles, SpecialDocs, XMLArticles + ArxivPapers, + EbookArticles, + DocArticles, + HTMLArticles, + MarkdownArticles, + PDFArticles, + SpecialDocs, + XMLArticles, ) from align_data.sources.articles.indices import IndicesDataset +from align_data.common.alignment_dataset import MultiDataset -ARTICLES_REGISTRY = [ + +ARTICLES_DATASETS = [ PDFArticles( name="pdfs", spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4", @@ -36,14 +44,19 @@ 
sheet_id="1293295703", ), SpecialDocs( - 'special_docs', - spreadsheet_id='1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI', - sheet_id='980957638', + "special_docs", + spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI", + sheet_id="980957638", ), +] + + +ARTICLES_REGISTRY = [ + MultiDataset(name='special_docs', datasets=ARTICLES_DATASETS), ArxivPapers( name="arxiv", spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI", sheet_id="655836697", ), - IndicesDataset('indices'), + IndicesDataset("indices"), ] diff --git a/align_data/sources/articles/articles.py b/align_data/sources/articles/articles.py index 3a210ead..a6a343f6 100644 --- a/align_data/sources/articles/articles.py +++ b/align_data/sources/articles/articles.py @@ -146,8 +146,7 @@ def check_new_articles(source_spreadsheet_id: str, source_sheet_name: str): missing = [ item for title, item in indices_items.items() - if title not in current - and not {item.get("url"), item.get("source_url")} & seen_urls + if title not in current and not {item.get("url"), item.get("source_url")} & seen_urls ] if not missing: logger.info("No new articles found") @@ -163,14 +162,12 @@ def check_new_articles(source_spreadsheet_id: str, source_sheet_name: str): "publication_title", "source_type", ] - res = source_sheet.append_rows( - [[item.get(col) for col in columns] for item in missing] - ) + res = source_sheet.append_rows([[item.get(col) for col in columns] for item in missing]) updated = res["updates"]["updatedRows"] logger.info("Added %s rows", updated) return updated def update_articles(csv_file, delimiter): - dataset = ReplacerDataset(name='updater', csv_path=csv_file, delimiter=delimiter) + dataset = ReplacerDataset(name="updater", csv_path=csv_file, delimiter=delimiter) dataset.add_entries(dataset.fetch_entries()) diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py index 95d18e17..d0fae22a 100644 --- a/align_data/sources/articles/datasets.py +++ b/align_data/sources/articles/datasets.py @@ -15,10 +15,16 @@ from align_data.db.models import Article from align_data.sources.articles.google_cloud import fetch_file, fetch_markdown from align_data.sources.articles.parsers import ( - HTML_PARSERS, extract_gdrive_contents, item_metadata, parse_domain + HTML_PARSERS, + extract_gdrive_contents, + item_metadata, + parse_domain, ) from align_data.sources.articles.pdf import read_pdf -from align_data.sources.arxiv_papers import fetch_arxiv, canonical_url as arxiv_canonical_url +from align_data.sources.arxiv_papers import ( + fetch_arxiv, + canonical_url as arxiv_cannonical_url, +) logger = logging.getLogger(__name__) @@ -43,8 +49,8 @@ def get_item_key(self, item): @property def items_list(self): - url = f'https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}' - logger.info(f'Fetching {url}') + url = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}" + logger.info(f"Fetching {url}") df = pd.read_csv(url) return ( item @@ -88,7 +94,6 @@ def process_entry(self, item): class SpecialDocs(SpreadsheetDataset): - @property def _query_items(self) -> Select[Tuple[Article]]: special_docs_types = ["pdf", "html", "xml", "markdown", "docx"] @@ -99,35 +104,39 @@ def get_contents(self, item) -> Dict: if url := self.maybe(item, "source_url") or self.maybe(item, "url"): contents = item_metadata(url) - return dict(contents, **{ - 'url': self.maybe(item, "url"), - 'title': self.maybe(item, "title") or contents.get('title'), - 
'source': contents.get('source_type') or self.name, - 'source_url': self.maybe(item, "source_url"), - 'source_type': contents.get('source_type') or self.maybe(item, "source_type"), - 'date_published': self._get_published_date(self.maybe(item, 'date_published')) or contents.get('date_published'), - 'authors': self.extract_authors(item) or contents.get('authors', []), - 'text': contents.get('text'), - 'status': 'Invalid' if contents.get('error') else None, - 'comments': contents.get('error'), - }) + return dict( + contents, + **{ + "url": self.maybe(item, "url"), + "title": self.maybe(item, "title") or contents.get("title"), + "source": contents.get("source_type") or self.name, + "source_url": self.maybe(item, "source_url"), + "source_type": contents.get("source_type") or self.maybe(item, "source_type"), + "date_published": self._get_published_date(self.maybe(item, "date_published")) + or contents.get("date_published"), + "authors": self.extract_authors(item) or contents.get("authors", []), + "text": contents.get("text"), + "status": "Invalid" if contents.get("error") else None, + "comments": contents.get("error"), + }, + ) def not_processed(self, item): - url = self.maybe(item, 'url') - source_url = self.maybe(item, 'source_url') + url = self.maybe(item, "url") + source_url = self.maybe(item, "source_url") return ( - self.get_item_key(item) not in self._outputted_items and - url not in self._outputted_items and - source_url not in self._outputted_items and - (not url or arxiv_canonical_url(url) not in self._outputted_items) and - (not source_url or arxiv_canonical_url(source_url) not in self._outputted_items) + self.get_item_key(item) not in self._outputted_items + and url not in self._outputted_items + and source_url not in self._outputted_items + and (not url or arxiv_cannonical_url(url) not in self._outputted_items) + and (not source_url or arxiv_cannonical_url(source_url) not in self._outputted_items) ) def process_entry(self, item): if ArxivPapers.is_arxiv(item.url): contents = ArxivPapers.get_contents(item) - contents['source'] = 'arxiv' + contents["source"] = "arxiv" else: contents = self.get_contents(item) @@ -156,7 +165,7 @@ def _get_text(item): domain = parse_domain(item.source_url) if parser := HTML_PARSERS.get(domain): res = parser(item.source_url) - return res and res.get('text') + return res and res.get("text") class EbookArticles(SpreadsheetDataset): @@ -170,9 +179,7 @@ def setup(self): def _get_text(self, item): file_id = item.source_url.split("/")[-2] - filename = download( - output=str(self.files_path / f"{item.title}.epub"), id=file_id - ) + filename = download(output=str(self.files_path / f"{item.title}.epub"), id=file_id) return convert_file(filename, "plain", "epub", extra_args=["--wrap=none"]) @@ -181,7 +188,7 @@ class XMLArticles(SpreadsheetDataset): def _get_text(self, item): vals = extract_gdrive_contents(item.source_url) - return vals["text"] + return vals.get("text") class MarkdownArticles(SpreadsheetDataset): @@ -190,7 +197,7 @@ class MarkdownArticles(SpreadsheetDataset): def _get_text(self, item): file_id = item.source_url.split("/")[-2] vals = fetch_markdown(file_id) - return vals["text"] + return vals.get("text") class DocArticles(SpreadsheetDataset): @@ -223,12 +230,12 @@ def get_contents(cls, item) -> Dict: contents = fetch_arxiv(item.url or item.source_url) if cls.maybe(item, "authors") and item.authors.strip(): - contents['authors'] = [i.strip() for i in item.authors.split(',')] + contents["authors"] = [i.strip() for i in item.authors.split(",")] if 
cls.maybe(item, "title"): - contents['title'] = cls.maybe(item, "title") + contents["title"] = cls.maybe(item, "title") - contents['date_published'] = cls._get_published_date( - cls.maybe(item, "date_published") or contents.get('date_published') + contents["date_published"] = cls._get_published_date( + cls.maybe(item, "date_published") or contents.get("date_published") ) return contents diff --git a/align_data/sources/articles/google_cloud.py b/align_data/sources/articles/google_cloud.py index 23385b7b..ee6f2ddc 100644 --- a/align_data/sources/articles/google_cloud.py +++ b/align_data/sources/articles/google_cloud.py @@ -158,19 +158,19 @@ def fetch_markdown(file_id: str) -> Dict[str, str]: "source_type": "markdown", } except Exception as e: - return {'error': str(e)} + return {"error": str(e)} def parse_grobid(contents: str | bytes) -> Dict[str, Any]: if isinstance(contents, bytes): contents = contents.decode('utf-8') doc_dict = grobid_tei_xml.parse_document_xml(contents).to_dict() - authors: List[str] = [author["full_name"].strip(' !') for author in doc_dict.get("header", {}).get("authors", [])] + authors: List[str] = [author["full_name"].strip(" !") for author in doc_dict.get("header", {}).get("authors", [])] - if not doc_dict.get('body'): + if not doc_dict.get("body"): return { - 'error': 'No contents in XML file', - 'source_type': 'xml', + "error": "No contents in XML file", + "source_type": "xml", } return { @@ -183,21 +183,21 @@ def parse_grobid(contents: str | bytes) -> Dict[str, Any]: def get_content_type(res: requests.Response) -> Set[str]: - header = res.headers.get('Content-Type') or '' - parts = [c_type.strip().lower() for c_type in header.split(';')] + header = res.headers.get("Content-Type") or "" + parts = [c_type.strip().lower() for c_type in header.split(";")] return set(filter(None, parts)) def extract_gdrive_contents(link: str) -> Dict[str, Any]: - file_id = link.split('/')[-2] - url = f'https://drive.google.com/uc?id={file_id}' - res = fetch(url, 'head') + file_id = link.split("/")[-2] + url = f"https://drive.google.com/uc?id={file_id}" + res = fetch(url, "head") if res.status_code == 403: - logger.error('Could not fetch the file at %s - 403 returned', link) - return {'error': 'Could not read file from google drive - forbidden'} + logger.error("Could not fetch the file at %s - 403 returned", link) + return {"error": "Could not read file from google drive - forbidden"} if res.status_code >= 400: - logger.error('Could not fetch the file at %s - are you sure that link is correct?', link) - return {'error': 'Could not read file from google drive'} + logger.error("Could not fetch the file at %s - are you sure that link is correct?", link) + return {"error": "Could not read file from google drive"} result: Dict[str, Any] = { 'source_url': link, @@ -206,16 +206,16 @@ def extract_gdrive_contents(link: str) -> Dict[str, Any]: content_type = get_content_type(res) if not content_type: - result['error'] = 'no content type' - elif content_type & {'application/octet-stream', 'application/pdf'}: + result["error"] = "no content type" + elif content_type & {"application/octet-stream", "application/pdf"}: result.update(fetch_pdf(url)) - elif content_type & {'text/markdown'}: + elif content_type & {"text/markdown"}: result.update(fetch_markdown(file_id)) - elif content_type & {'application/epub+zip', 'application/epub'}: - result['source_type'] = 'ebook' - elif content_type & {'text/html'}: + elif content_type & {"application/epub+zip", "application/epub"}: + result["source_type"] = 
"ebook" + elif content_type & {"text/html"}: res = fetch(url) - if 'Google Drive - Virus scan warning' in res.text: + if "Google Drive - Virus scan warning" in res.text: soup = BeautifulSoup(res.content, "html.parser") form_tag = soup.select_one('form') @@ -229,30 +229,35 @@ def extract_gdrive_contents(link: str) -> Dict[str, Any]: res = fetch(form_action_url) content_type = get_content_type(res) - if content_type & {'text/xml'}: + if content_type & {"text/xml"}: result.update(parse_grobid(res.content)) - elif content_type & {'text/html'}: + elif content_type & {"text/html"}: soup = BeautifulSoup(res.content, "html.parser") - result.update({ - 'text': MarkdownConverter().convert_soup(soup.select_one('body')).strip(), - 'source_type': 'html', - }) + result.update( + { + "text": MarkdownConverter().convert_soup(soup.select_one("body")).strip(), + "source_type": "html", + } + ) else: - result['error'] = f'unknown content type: {content_type}' + result["error"] = f"unknown content type: {content_type}" else: - result['error'] = f'unknown content type: {content_type}' + result["error"] = f"unknown content type: {content_type}" return result def google_doc(url: str) -> Dict[str, Any]: """Fetch the contents of the given gdoc url as markdown.""" - res = re.search(r'https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/', url) + res = re.search(r"https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/", url) if not res: return {} doc_id = res.group(1) - body = fetch_element(f'https://docs.google.com/document/d/{doc_id}/export?format=html', 'body') + body = fetch_element(f"https://docs.google.com/document/d/{doc_id}/export?format=html", "body") if body: - return {'text': MarkdownConverter().convert_soup(body).strip(), 'source_url': url} + return { + "text": MarkdownConverter().convert_soup(body).strip(), + "source_url": url, + } return {} diff --git a/align_data/sources/articles/html.py b/align_data/sources/articles/html.py index 9365bde5..9c1c15a1 100644 --- a/align_data/sources/articles/html.py +++ b/align_data/sources/articles/html.py @@ -75,9 +75,9 @@ def getter(url: str) -> Dict[str, Any]: for e in elem.select(sel): e.extract() return { - 'text': MarkdownConverter().convert_soup(elem).strip(), - 'source_url': url, - 'source_type': 'html', + "text": MarkdownConverter().convert_soup(elem).strip(), + "source_url": url, + "source_type": "html", } return getter diff --git a/align_data/sources/articles/indices.py b/align_data/sources/articles/indices.py index a673257d..b9cc1621 100644 --- a/align_data/sources/articles/indices.py +++ b/align_data/sources/articles/indices.py @@ -85,16 +85,6 @@ def format_anthropic(post): } -def format_transformer_circuits(item): - if not item.get("href").startswith("http"): - url = f'https://transformer-circuits.pub/{item.get("href")}' - return { - "title": get_text(item, "h3"), - "url": url, - "source_url": url, - } - - def format_safe_ai(item): return { "title": get_text(item, "h4"), @@ -209,12 +199,6 @@ def fetch_all(): "a", format_anthropic, ), - indice_fetcher( - "https://transformer-circuits.pub/", - "div.toc", - "a", - format_transformer_circuits, - ), indice_fetcher( "https://www.safe.ai/research", "#guiding-principles", @@ -264,48 +248,58 @@ def fetch_all(): class IndicesDataset(AlignmentDataset): - - done_key = 'url' + done_key = "url" @property def items_list(self): return fetch_all().values() def get_item_key(self, item): - return item.get('url') + return item.get("url") @staticmethod def extract_authors(item): - if authors := (item.get('authors') or 
'').strip(): - return [author.strip() for author in authors.split(',') if author.strip()] + if authors := (item.get("authors") or "").strip(): + return [author.strip() for author in authors.split(",") if author.strip()] return [] def process_entry(self, item): contents = {} - if url := item.get('source_url') or item.get('url'): - contents= item_metadata(url) - - if not contents.get('text'): - logger.error('Could not get text for %s (%s) - %s - skipping for now', item.get('title'), url, contents.get('error')) + url = item.get("source_url") or item.get("url") + if url: + contents = item_metadata(url) + + if not contents.get("text"): + logger.error( + "Could not get text for %s (%s) - %s - skipping for now", + item.get("title"), + url, + contents.get("error"), + ) return None # If the article is not an arxiv paper, just mark it as ignored - if in the future editors # decide it's worth adding, it can be fetched then - if parse_domain(contents.get('source_url') or '') != 'arxiv.org': - return self.make_data_entry({ - 'source': self.name, - 'url': self.get_item_key(item), - 'title': item.get('title'), - 'date_published': self._get_published_date(item.get('date_published')), - 'authors': self.extract_authors(item), - 'status': 'Ignored', - 'comments': 'Added from indices', - }) - - return self.make_data_entry({ - 'source': self.name, - 'url': self.get_item_key(item), - 'title': item.get('title'), - 'date_published': self._get_published_date(item.get('date_published')), - 'authors': self.extract_authors(item), - }, **contents) + if parse_domain(url or "") != "arxiv.org": + return self.make_data_entry( + { + "source": self.name, + "url": self.get_item_key(item), + "title": item.get("title"), + "date_published": self._get_published_date(item.get("date_published")), + "authors": self.extract_authors(item), + "status": "Ignored", + "comments": "Added from indices", + } + ) + + return self.make_data_entry( + { + "source": "arxiv", + "url": contents.get("url") or self.get_item_key(item), + "title": item.get("title"), + "date_published": self._get_published_date(item.get("date_published")), + "authors": self.extract_authors(item), + }, + **contents, + ) diff --git a/align_data/sources/articles/parsers.py b/align_data/sources/articles/parsers.py index c5fe325a..5dc1da3d 100644 --- a/align_data/sources/articles/parsers.py +++ b/align_data/sources/articles/parsers.py @@ -93,7 +93,7 @@ def error(error_msg: str): def func(url: str) -> Dict[str, Any]: if error_msg: logger.error(error_msg) - return {'error': error_msg, 'source_url': url} + return {"error": error_msg, "source_url": url} return func @@ -160,17 +160,17 @@ def getter(url: str) -> Dict[str, Any]: "mediangroup.org": element_extractor("div.entry-content"), "www.alexirpan.com": element_extractor("article"), "www.incompleteideas.net": element_extractor("body"), - "ai-alignment.com": MediumParser(name='html', url='ai-alignment.com'), + "ai-alignment.com": MediumParser(name="html", url="ai-alignment.com"), "aisrp.org": element_extractor("article"), "bounded-regret.ghost.io": element_extractor("div.post-content"), "carnegieendowment.org": element_extractor( "div.article-body", remove=[".no-print", ".related-pubs"] ), - "casparoesterheld.com": element_extractor( - ".entry-content", remove=["div.sharedaddy"] - ), + "casparoesterheld.com": element_extractor(".entry-content", remove=["div.sharedaddy"]), "cullenokeefe.com": element_extractor("div.sqs-block-content"), - "deepmindsafetyresearch.medium.com": MediumParser(name='html', 
url='deepmindsafetyresearch.medium.com'), + "deepmindsafetyresearch.medium.com": MediumParser( + name="html", url="deepmindsafetyresearch.medium.com" + ), "docs.google.com": google_doc, "docs.microsoft.com": element_extractor("div.content"), "digichina.stanford.edu": element_extractor("div.h_editor-content"), @@ -185,7 +185,7 @@ def getter(url: str) -> Dict[str, Any]: "link.springer.com": element_extractor("article.c-article-body"), "longtermrisk.org": element_extractor("div.entry-content"), "lukemuehlhauser.com": element_extractor("div.entry-content"), - "medium.com": MediumParser(name='html', url='medium.com'), + "medium.com": MediumParser(name="html", url="medium.com"), "openai.com": element_extractor("#content"), "ought.org": element_extractor("div.BlogPostBodyContainer"), "sideways-view.com": element_extractor("article", remove=["header"]), @@ -200,10 +200,8 @@ def getter(url: str) -> Dict[str, Any]: ), "theconversation.com": element_extractor("div.content-body"), "thegradient.pub": element_extractor("div.c-content"), - "towardsdatascience.com": MediumParser(name='html', url='towardsdatascience.com'), - "unstableontology.com": element_extractor( - ".entry-content", remove=["div.sharedaddy"] - ), + "towardsdatascience.com": MediumParser(name="html", url="towardsdatascience.com"), + "unstableontology.com": element_extractor(".entry-content", remove=["div.sharedaddy"]), "waitbutwhy.com": element_extractor("article", remove=[".entry-header"]), "weightagnostic.github.io": element_extractor( "dt-article", remove=["#authors_section", "dt-byline"] @@ -211,9 +209,7 @@ def getter(url: str) -> Dict[str, Any]: "cnas.org": element_extractor("#mainbar-toc"), "econlib.org": element_extractor("div.post-content"), "humanityplus.org": element_extractor("div.content"), - "gleech.org": element_extractor( - "article.post-content", remove=["center", "div.accordion"] - ), + "gleech.org": element_extractor("article.post-content", remove=["center", "div.accordion"]), "ibm.com": element_extractor("div:has(> p)"), # IBM's HTML is really ugly... "microsoft.com": element_extractor("div.content-container"), "mdpi.com": element_extractor( @@ -289,9 +285,7 @@ def getter(url: str) -> Dict[str, Any]: "jstor.org": doi_getter, "ri.cmu.edu": get_pdf_from_page("a.pub-link"), "risksciences.ucla.edu": get_pdf_from_page('a:-soup-contains("Download")'), - "ssrn.com": get_pdf_from_page( - '.abstract-buttons a.button-link:-soup-contains("Download")' - ), + "ssrn.com": get_pdf_from_page('.abstract-buttons a.button-link:-soup-contains("Download")'), "yjolt.org": get_pdf_from_page("span.file a"), } @@ -320,7 +314,7 @@ def item_metadata(url: str) -> Dict[str, Any]: # there is a link to a pdf on it if parser := HTML_PARSERS.get(domain): res = parser(url) - if res and 'error' not in res: + if res and "error" not in res: # Proper contents were found on the page, so use them return res @@ -338,7 +332,9 @@ def item_metadata(url: str) -> Dict[str, Any]: return {"error": f"No domain handler defined for {domain}"} return {"error": "could not parse url"} elif content_type & {"application/octet-stream", "application/pdf"}: - # this looks like it could be a pdf - try to download it as one + if domain == "arxiv.org": + return fetch_arxiv(url) + # just download it as a pdf return fetch_pdf(url) elif content_type & {"application/epub+zip", "application/epub"}: # it looks like an ebook. Assume it's fine. 
diff --git a/align_data/sources/articles/updater.py b/align_data/sources/articles/updater.py index 93971a86..ca2669e5 100644 --- a/align_data/sources/articles/updater.py +++ b/align_data/sources/articles/updater.py @@ -37,8 +37,9 @@ def maybe(item, key): def items_list(self) -> List[Item]: df = pd.read_csv(self.csv_path, delimiter=self.delimiter) self.csv_items = [ - item for item in df.itertuples() - if self.maybe(item, 'id') or self.maybe(item, 'hash_id') + item + for item in df.itertuples() + if self.maybe(item, "id") or self.maybe(item, "hash_id") ] by_id = {id: item for item in self.csv_items if (id := self.maybe(item, 'id'))} by_hash_id = {hash_id: item for item in self.csv_items if (hash_id := self.maybe(item, 'hash_id'))} @@ -53,39 +54,39 @@ def items_list(self) -> List[Item]: @property def _query_items(self) -> Select[Tuple[Article]]: - ids = [i.id for i in self.csv_items if self.maybe(i, 'id')] - hash_ids = [i.hash_id for i in self.csv_items if self.maybe(i, 'hash_id')] + ids = [i.id for i in self.csv_items if self.maybe(i, "id")] + hash_ids = [i.hash_id for i in self.csv_items if self.maybe(i, "hash_id")] return select(Article).where(or_(Article.id.in_(hash_ids), Article._id.in_(ids))) def update_text(self, updates: NamedTuple, article: Article): # If the url is the same as it was before, and there isn't a source url provided, assume that the # previous text is still valid - if article.url == self.maybe(updates, 'url') and not self.maybe(updates, 'source_url'): + if article.url == self.maybe(updates, "url") and not self.maybe(updates, "source_url"): return # If no url found, then don't bother fetching the text - assume it was successfully fetched previously - url = self.maybe(updates, 'source_url') or self.maybe(updates, 'url') + url = self.maybe(updates, "source_url") or self.maybe(updates, "url") if not url: return if article.url != url: - article.add_meta('source_url', url) + article.add_meta("source_url", url) metadata = item_metadata(url) # Only change the text if it could be fetched - better to have outdated values than none - if metadata.get('text'): - article.text = metadata['text'] - article.status = metadata.get('error') + if metadata.get("text"): + article.text = metadata["text"] + article.status = metadata.get("error") def process_entry(self, item: Item) -> Article: updates, article = item - for key in ['url', 'title', 'source', 'authors', 'comment', 'confidence']: + for key in ["url", "title", "source", "authors", "comment", "confidence"]: value = self.maybe(updates, key) if value and getattr(article, key, None) != value: setattr(article, key, value) - if date := getattr(updates, 'date_published', None): + if date := getattr(updates, "date_published", None): article.date_published = self._get_published_date(date) self.update_text(updates, article) diff --git a/align_data/sources/arxiv_papers.py b/align_data/sources/arxiv_papers.py index 89739e11..9067152a 100644 --- a/align_data/sources/arxiv_papers.py +++ b/align_data/sources/arxiv_papers.py @@ -6,6 +6,7 @@ from align_data.sources.articles.pdf import fetch_pdf, parse_vanity from align_data.sources.articles.html import fetch_element +from align_data.sources.utils import merge_dicts logger = logging.getLogger(__name__) @@ -22,7 +23,7 @@ def get_arxiv_metadata(paper_id: str) -> arxiv.Result | None: return None -def get_id(url: str) -> Optional[str]: +def get_id(url: str) -> str | None: if res := re.search(r"https?://arxiv.org/(?:abs|pdf)/(.*?)(?:v\d+)?(?:/|\.pdf)?$", url): return res.group(1) return None @@ -30,25 
+31,30 @@ def get_id(url: str) -> Optional[str]: def canonical_url(url: str) -> str: if paper_id := get_id(url): - return f'https://arxiv.org/abs/{paper_id}' + return f"https://arxiv.org/abs/{paper_id}" return url def get_contents(paper_id: str) -> Dict[str, Any]: arxiv_vanity = parse_vanity(f"https://www.arxiv-vanity.com/papers/{paper_id}") - if 'error' not in arxiv_vanity: + if "error" not in arxiv_vanity: return arxiv_vanity ar5iv = parse_vanity(f"https://ar5iv.org/abs/{paper_id}") - if 'error' not in ar5iv: + if "error" not in ar5iv: return ar5iv return fetch_pdf(f"https://arxiv.org/pdf/{paper_id}.pdf") -def get_version(id: str) -> Optional[str]: - if res := re.search(r'.*v(\d+)$', id): +def get_version(id: str) -> str | None: + if res := re.search(r".*v(\d+)$", id): return res.group(1) + + +def is_withdrawn(url: str): + if elem := fetch_element(canonical_url(url), ".extra-services .full-text ul"): + return elem.text.strip().lower() == "withdrawn" return None @@ -62,7 +68,7 @@ def add_metadata(data: Dict[str, Any], paper_id: str) -> Dict[str, Any]: metadata = get_arxiv_metadata(paper_id) if not metadata: return {} - return dict({ + return merge_dicts({ "authors": metadata.authors, "title": metadata.title, "date_published": metadata.published, @@ -74,28 +80,28 @@ def add_metadata(data: Dict[str, Any], paper_id: str) -> Dict[str, Any]: "primary_category": metadata.primary_category, "categories": metadata.categories, "version": get_version(metadata.get_short_id()), - }, **data) + }, data) def fetch_arxiv(url: str) -> Dict[str, Any]: paper_id = get_id(url) if not paper_id: - return {'error': 'Could not extract arxiv id'} - - if is_withdrawn(url): paper = {'status': 'Withdrawn'} - else: paper = get_contents(paper_id) + return {"error": "Could not extract arxiv id"} - data = add_metadata({ - "url": canonical_url(url), - "source_type": paper.get('data_source'), - }, paper_id) - - authors = data.get('authors') or paper.get("authors") - if not authors: - data['authors'] = [] - elif not isinstance(authors, list): - data['authors'] = [str(authors).strip()] + if is_withdrawn(url): + paper = {"status": "Withdrawn"} else: - data['authors'] = [str(author).strip() for author in authors] - - return dict(data, **paper) + paper = get_contents(paper_id) + + data = add_metadata( + { + "url": canonical_url(url), + "source_type": paper.get("data_source"), + }, + paper_id, + ) + authors = data.get("authors") or paper.get("authors") or [] + data["authors"] = [str(a).strip() for a in authors] + data["source"] = "arxiv" + + return merge_dicts(data, paper) diff --git a/align_data/sources/blogs/__init__.py b/align_data/sources/blogs/__init__.py index 923c5257..64f27310 100644 --- a/align_data/sources/blogs/__init__.py +++ b/align_data/sources/blogs/__init__.py @@ -1,18 +1,22 @@ from align_data.sources.blogs.wp_blog import WordpressBlog from align_data.sources.blogs.gwern_blog import GwernBlog from align_data.sources.blogs.blogs import ( + AXRPDataset, ColdTakes, GenerativeInk, CaradoMoe, EleutherAI, OpenAIResearch, DeepMindTechnicalBlog, + TransformerCircuits, ) from align_data.sources.blogs.substack_blog import SubstackBlog from align_data.sources.articles.parsers import MediumParser +from align_data.common.alignment_dataset import MultiDataset -BLOG_REGISTRY = [ +BLOG_DATASETS = [ + AXRPDataset(name='axrp', url='https://axrp.net', authors=['AXRP']), WordpressBlog(name="aiimpacts", url="https://aiimpacts.org"), WordpressBlog(name="aisafety.camp", url="https://aisafety.camp"), WordpressBlog(name="miri", 
url="https://intelligence.org"), @@ -24,9 +28,7 @@ url="https://deepmindsafetyresearch.medium.com/", authors=["DeepMind Safety Research"], ), - GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ), + GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]), ColdTakes( name="cold_takes", url="https://www.cold-takes.com/", @@ -56,4 +58,10 @@ name="deepmind_technical_blog", url="https://www.deepmind.com/blog-categories/technical-blogs", ), + TransformerCircuits(name='transformer-circuits', url='https://transformer-circuits.pub/'), +] + + +BLOG_REGISTRY = [ + MultiDataset(name='blogs', datasets=BLOG_DATASETS), ] diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py index 849a7889..08414199 100644 --- a/align_data/sources/blogs/blogs.py +++ b/align_data/sources/blogs/blogs.py @@ -1,4 +1,5 @@ import logging +from urllib.parse import urljoin import requests from bs4 import BeautifulSoup @@ -51,12 +52,7 @@ def _get_published_date(self, contents): return "" def extract_authors(self, article): - return ( - article.select_one("header .post-meta") - .text.split("·")[1] - .strip() - .split(", ") - ) + return article.select_one("header .post-meta").text.split("·")[1].strip().split(", ") class OpenAIResearch(HTMLDataset): @@ -116,9 +112,7 @@ def items_list(self): page += 1 # update the tqdm progress bar - pbar.set_postfix_str( - f"page {page}", refresh=True - ) # Set postfix to "page X" + pbar.set_postfix_str(f"page {page}", refresh=True) # Set postfix to "page X" pbar.update() # Here we increment the progress bar by 1 logger.info("Got %s pages", len(articles)) @@ -135,3 +129,64 @@ def extract_authors(self, article): ): return [author.strip() for author in div.text.split(",")] return [] + + +class TransformerCircuits(HTMLDataset): + + item_selector = "div.toc a" + text_selector = 'h3' + + def get_item_key(self, item): + article_url = item.get("href").split("?")[0] + return urljoin(self.url, article_url) + + @property + def items_list(self): + return [i for i in super().items_list if self.get_item_key(i).startswith(self.url)] + + def _metadata(self, contents, selector): + if meta := contents.select_one('div.d-byline'): + return meta.select(selector) + + def _get_title(self, contents): + title = contents.find("title") + return title and title.text.strip() + + def _get_published_date(self, contents): + if date := self._metadata(contents, 'div.published div'): + return super()._get_published_date(date[0].text) + + def _get_text(self, contents): + article = contents.find("d-article") or contents.find("dt-article") + return self._extract_markdown(article) + + def extract_authors(self, contents): + if authors := self._metadata(contents, 'span.author'): + for a in authors: + for sup in a.select('sup'): + sup.extract() + return [a.text.strip().strip(',*') for a in authors] + return [] + + +class AXRPDataset(RSSDataset): + + @property + def feed_url(self): + return f"{self.url}/feed.xml" + + def _extract_item_url(self, item) -> str | None: + if path := item.get('link'): + return self.url + path + return None + + def extract_authors(self, item): + if "authors" in item: + authors = [name for a in item["authors"] if (name := (a.get("name") or '').strip())] + if authors: + return authors + + bits = item.get('title', '').split(' with ') + if len(bits) > 1 and bits[-1].strip(): + return self.authors + [bits[-1].strip()] + return self.authors diff --git a/align_data/sources/blogs/wp_blog.py 
b/align_data/sources/blogs/wp_blog.py index bdb45292..b0c9e9f1 100644 --- a/align_data/sources/blogs/wp_blog.py +++ b/align_data/sources/blogs/wp_blog.py @@ -41,9 +41,7 @@ def items_list(self): self.items[item["link"]] = item # update the tqdm progress bar - pbar.set_postfix_str( - f"page {page_number}", refresh=True - ) # Set postfix to "page X" + pbar.set_postfix_str(f"page {page_number}", refresh=True) # Set postfix to "page X" pbar.update() # Here we increment the progress bar by 1 logger.info(f"Got {len(self.items)} pages") diff --git a/align_data/sources/distill/distill.py b/align_data/sources/distill/distill.py index 2b1690b6..8709b154 100644 --- a/align_data/sources/distill/distill.py +++ b/align_data/sources/distill/distill.py @@ -1,15 +1,14 @@ -from dataclasses import dataclass -from markdownify import markdownify from align_data.common.html_dataset import RSSDataset -@dataclass class Distill(RSSDataset): source_type = "html" done_key = "url" def extract_authors(self, item): - return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] or ["Distill"] + return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] or [ + "Distill" + ] def _get_text(self, item): article = item["soup"].find("d-article") or item["soup"].find("dt-article") diff --git a/align_data/sources/ebooks/__init__.py b/align_data/sources/ebooks/__init__.py index 0055f5e0..7fdcd729 100644 --- a/align_data/sources/ebooks/__init__.py +++ b/align_data/sources/ebooks/__init__.py @@ -1,7 +1,5 @@ from .agentmodels import AgentModels EBOOK_REGISTRY = [ - AgentModels( - name="agentmodels", repo="https://github.com/agentmodels/agentmodels.org.git" - ), + AgentModels(name="agentmodels", repo="https://github.com/agentmodels/agentmodels.org.git"), ] diff --git a/align_data/sources/ebooks/agentmodels.py b/align_data/sources/ebooks/agentmodels.py index 748ceebf..3756524a 100644 --- a/align_data/sources/ebooks/agentmodels.py +++ b/align_data/sources/ebooks/agentmodels.py @@ -29,9 +29,7 @@ def setup(self): self.files_path = self.base_dir / "chapters" def _get_published_date(self, filename): - last_commit = next( - self.repository.iter_commits(paths=f"chapters/{filename.name}") - ) + last_commit = next(self.repository.iter_commits(paths=f"chapters/{filename.name}")) return last_commit.committed_datetime.astimezone(timezone.utc) def process_entry(self, filename): diff --git a/align_data/sources/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py index bd906348..12dc9b80 100644 --- a/align_data/sources/greaterwrong/greaterwrong.py +++ b/align_data/sources/greaterwrong/greaterwrong.py @@ -76,10 +76,7 @@ def setup(self): self.ai_tags = get_allowed_tags(self.base_url, self.name) def tags_ok(self, post): - return ( - not self.ai_tags - or {t["name"] for t in post["tags"] if t.get("name")} & self.ai_tags - ) + return not self.ai_tags or {t["name"] for t in post["tags"] if t.get("name")} & self.ai_tags def get_item_key(self, item): return item["pageUrl"] @@ -176,7 +173,7 @@ def process_entry(self, item): authors = item["coauthors"] if item["user"]: authors = [item["user"]] + authors - authors = [a["displayName"] for a in authors] or ['anonymous'] + authors = [a["displayName"] for a in authors] or ["anonymous"] return self.make_data_entry( { "title": item["title"], diff --git a/align_data/sources/stampy/stampy.py b/align_data/sources/stampy/stampy.py index c6f784c5..7aad532c 100644 --- a/align_data/sources/stampy/stampy.py +++ b/align_data/sources/stampy/stampy.py @@ 
-43,13 +43,9 @@ def _get_published_date(self, entry): def process_entry(self, entry): def clean_text(text): text = html.unescape(text) - return re.sub( - r"\(/\?state=(\w+)\)", r"(http://aisafety.info?state=\1)", text - ) + return re.sub(r"\(/\?state=(\w+)\)", r"(http://aisafety.info?state=\1)", text) - question = clean_text( - entry["Question"] - ) # raise an error if the entry has no question + question = clean_text(entry["Question"]) # raise an error if the entry has no question answer = clean_text(entry["Rich Text"]) url = "https://aisafety.info?state=" + entry["UI ID"] diff --git a/align_data/sources/utils.py b/align_data/sources/utils.py new file mode 100644 index 00000000..3ed9abfe --- /dev/null +++ b/align_data/sources/utils.py @@ -0,0 +1,5 @@ +def merge_dicts(*dicts): + final = {} + for d in dicts: + final = dict(final, **{k: v for k, v in d.items() if v is not None}) + return final diff --git a/align_data/sources/youtube/__init__.py b/align_data/sources/youtube/__init__.py index 06c8defe..ca0d9b33 100644 --- a/align_data/sources/youtube/__init__.py +++ b/align_data/sources/youtube/__init__.py @@ -1,9 +1,10 @@ +from align_data.common.alignment_dataset import MultiDataset from align_data.sources.youtube.youtube import ( YouTubeChannelDataset, YouTubePlaylistDataset, ) -YOUTUBE_REGISTRY = [ +YOUTUBE_DATASETS = [ YouTubeChannelDataset( name="rob_miles_ai_safety", channel_id="UCLB7AzTwc6VFZrBsO2ucBMg", @@ -40,3 +41,8 @@ ], ), ] + + +YOUTUBE_REGISTRY = [ + MultiDataset(name='youtube', datasets=YOUTUBE_DATASETS), +] diff --git a/align_data/sources/youtube/youtube.py b/align_data/sources/youtube/youtube.py index 759c9840..9bdad819 100644 --- a/align_data/sources/youtube/youtube.py +++ b/align_data/sources/youtube/youtube.py @@ -1,4 +1,3 @@ -import collections import logging from dataclasses import dataclass, field from typing import List, Optional, Iterable diff --git a/main.py b/main.py index 3c934374..2edeec9d 100644 --- a/main.py +++ b/main.py @@ -7,8 +7,13 @@ from align_data import ALL_DATASETS, get_dataset from align_data.analysis.count_tokens import count_token -from align_data.sources.articles.articles import update_new_items, check_new_articles, update_articles -from align_data.pinecone.update_pinecone import PineconeUpdater +from align_data.sources.articles.articles import ( + update_new_items, + check_new_articles, + update_articles, +) +from align_data.embeddings.pinecone.update_pinecone import PineconeUpdater +from align_data.embeddings.finetuning.training import finetune_embeddings from align_data.settings import ( METADATA_OUTPUT_SPREADSHEET, METADATA_SOURCE_SHEET, @@ -75,12 +80,10 @@ def count_tokens(self, merged_dataset_path: str) -> None: This function counts the number of tokens, words, and characters in the dataset :return: None """ - assert os.path.exists( - merged_dataset_path - ), "The path to the merged dataset does not exist" + assert os.path.exists(merged_dataset_path), "The path to the merged dataset does not exist" count_token(merged_dataset_path) - def update(self, csv_path, delimiter=','): + def update(self, csv_path, delimiter=","): """Update all articles in the provided csv files, overwriting the provided values and fetching new text if a different url provided. 
:param str csv_path: The path to the csv file to be processed @@ -114,7 +117,7 @@ def fetch_new_articles( """ return check_new_articles(source_spreadsheet, source_sheet) - def pinecone_update(self, *names) -> None: + def pinecone_update(self, *names, force_update=False) -> None: """ This function updates the Pinecone vector DB. @@ -124,14 +127,20 @@ def pinecone_update(self, *names) -> None: names = ALL_DATASETS missing = {name for name in names if name not in ALL_DATASETS} assert not missing, f"{missing} are not valid dataset names" - PineconeUpdater().update(names) + PineconeUpdater().update(names, force_update) - def pinecone_update_all(self, *skip) -> None: + def pinecone_update_all(self, *skip, force_update=False) -> None: """ This function updates the Pinecone vector DB. """ names = [name for name in ALL_DATASETS if name not in skip] - PineconeUpdater().update(names) + PineconeUpdater().update(names, force_update) + + def train_finetuning_layer(self) -> None: + """ + This function trains a finetuning layer on top of the OpenAI embeddings. + """ + finetune_embeddings() if __name__ == "__main__": diff --git a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py index 7a8485fe..e5b9a303 100644 --- a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py +++ b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py @@ -18,9 +18,7 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.add_column( - "articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False) - ) + op.add_column("articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False)) # ### end Alembic commands ### diff --git a/migrations/versions/f5a2bcfa6b2c_add_status_column.py b/migrations/versions/f5a2bcfa6b2c_add_status_column.py index 76c89ee0..d93a8c86 100644 --- a/migrations/versions/f5a2bcfa6b2c_add_status_column.py +++ b/migrations/versions/f5a2bcfa6b2c_add_status_column.py @@ -10,17 +10,17 @@ from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. 
-revision = 'f5a2bcfa6b2c' -down_revision = '59ac3cb671e3' +revision = "f5a2bcfa6b2c" +down_revision = "59ac3cb671e3" branch_labels = None depends_on = None def upgrade() -> None: - op.add_column('articles', sa.Column('status', sa.String(length=256), nullable=True)) - op.add_column('articles', sa.Column('comments', mysql.LONGTEXT(), nullable=True)) + op.add_column("articles", sa.Column("status", sa.String(length=256), nullable=True)) + op.add_column("articles", sa.Column("comments", mysql.LONGTEXT(), nullable=True)) def downgrade() -> None: - op.drop_column('articles', 'comments') - op.drop_column('articles', 'status') + op.drop_column("articles", "comments") + op.drop_column("articles", "status") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..aa4949aa --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 100 diff --git a/requirements.txt b/requirements.txt index b7985f3f..008d629d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,6 +28,7 @@ google-auth-httplib2 google-api-python-client gspread youtube-transcript-api +airtable alembic mysqlclient @@ -36,3 +37,5 @@ openai langchain nltk pinecone-client + +torch diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py index 29b38c46..f7246525 100644 --- a/tests/align_data/articles/test_datasets.py +++ b/tests/align_data/articles/test_datasets.py @@ -48,7 +48,7 @@ def mock_arxiv(): journal_ref="sdf", primary_category="cat", ) - metadata.get_short_id.return_value = '2001.11038' + metadata.get_short_id.return_value = "2001.11038" arxiv = Mock() arxiv.Search.return_value.results.return_value = iter([metadata]) @@ -124,30 +124,24 @@ def test_pdf_articles_process_item(articles): "text": "pdf contents [bla](asd.com)", "title": "article no 0", "url": "http://example.com/item/0", - 'source_url': 'http://example.com/source_url/0', + "source_url": "http://example.com/source_url/0", } def test_html_articles_get_text(): def parser(url): assert url == "http://example.org/bla.bla" - return {'text': "html contents"} + return {"text": "html contents"} - with patch( - "align_data.sources.articles.datasets.HTML_PARSERS", {"example.org": parser} - ): + with patch("align_data.sources.articles.datasets.HTML_PARSERS", {"example.org": parser}): assert ( - HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) - == "html contents" + HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) == "html contents" ) def test_html_articles_get_text_no_parser(): with patch("align_data.sources.articles.datasets.HTML_PARSERS", {}): - assert ( - HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) - is None - ) + assert HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) is None def test_html_articles_process_entry(articles): @@ -156,7 +150,9 @@ def test_html_articles_process_entry(articles): item = list(dataset.items_list)[0] parsers = { - "example.com": lambda _: {'text': ' html contents with <a href="bla.com">proper elements</a> ble ble '} + "example.com": lambda _: { + "text": ' html contents with <a href="bla.com">proper elements</a> ble ble ' + } } with patch("align_data.sources.articles.datasets.HTML_PARSERS", parsers): assert dataset.process_entry(item).to_dict() == { @@ -170,7 +166,7 @@ def test_html_articles_process_entry(articles): "text": "html contents with [proper elements](bla.com) ble ble", "title": "article no 0", "url": "http://example.com/item/0", - 'source_url': 
'http://example.com/source_url/0', + "source_url": "http://example.com/source_url/0", } @@ -201,9 +197,7 @@ def test_ebook_articles_process_entry(articles): contents = ' html contents with <a href="bla.com">proper elements</a> ble ble ' with patch("align_data.sources.articles.datasets.download"): - with patch( - "align_data.sources.articles.datasets.convert_file", return_value=contents - ): + with patch("align_data.sources.articles.datasets.convert_file", return_value=contents): assert dataset.process_entry(item).to_dict() == { "authors": ["John Snow", "mr Blobby"], "date_published": "2023-01-01T12:32:11Z", @@ -215,7 +209,7 @@ def test_ebook_articles_process_entry(articles): "text": "html contents with [proper elements](bla.com) ble ble", "title": "article no 0", "url": "http://example.com/item/0", - 'source_url': 'http://example.com/source_url/0', + "source_url": "http://example.com/source_url/0", } @@ -248,7 +242,7 @@ def test_xml_articles_process_entry(articles): "text": "bla bla", "title": "article no 0", "url": "http://example.com/item/0", - 'source_url': 'http://example.com/source_url/0', + "source_url": "http://example.com/source_url/0", } @@ -281,19 +275,15 @@ def test_markdown_articles_process_entry(articles): "text": "bla bla", "title": "article no 0", "url": "http://example.com/item/0", - 'source_url': 'http://example.com/source_url/0', + "source_url": "http://example.com/source_url/0", } def test_doc_articles_get_text(): dataset = DocArticles(name="bla", spreadsheet_id="123", sheet_id="456") with patch("align_data.sources.articles.datasets.fetch_file"): - with patch( - "align_data.sources.articles.datasets.convert_file", return_value="bla bla" - ): - assert ( - dataset._get_text(Mock(source_url="bla.com/bla/123/bla")) == "bla bla" - ) + with patch("align_data.sources.articles.datasets.convert_file", return_value="bla bla"): + assert dataset._get_text(Mock(source_url="bla.com/bla/123/bla")) == "bla bla" def test_doc_articles_process_entry(articles): @@ -302,9 +292,7 @@ def test_doc_articles_process_entry(articles): item = list(dataset.items_list)[0] with patch("align_data.sources.articles.datasets.fetch_file"): - with patch( - "align_data.sources.articles.datasets.convert_file", return_value="bla bla" - ): + with patch("align_data.sources.articles.datasets.convert_file", return_value="bla bla"): assert dataset.process_entry(item).to_dict() == { "authors": ["John Snow", "mr Blobby"], "date_published": "2023-01-01T12:32:11Z", @@ -316,11 +304,11 @@ def test_doc_articles_process_entry(articles): "text": "bla bla", "title": "article no 0", "url": "http://example.com/item/0", - 'source_url': 'http://example.com/source_url/0', + "source_url": "http://example.com/source_url/0", } -@patch('requests.get', return_value=Mock(content='')) +@patch("requests.get", return_value=Mock(content="")) def test_arxiv_process_entry(_, mock_arxiv): dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da") item = Mock( @@ -335,9 +323,7 @@ def test_arxiv_process_entry(_, mock_arxiv): "authors": ["mr blobby"], "source_type": "html", } - with patch( - "align_data.sources.arxiv_papers.parse_vanity", return_value=contents - ): + with patch("align_data.sources.arxiv_papers.parse_vanity", return_value=contents): assert dataset.process_entry(item).to_dict() == { "comment": "no comment", "authors": ["mr blobby"], @@ -377,12 +363,9 @@ def test_arxiv_process_entry_retracted(mock_arxiv): </div> """ - with patch('requests.get', return_value=Mock(content=response)): + with patch("requests.get", 
return_value=Mock(content=response)): article = dataset.process_entry(item) - print(article.to_dict()) - print(article.status) - print(article.__dir__()) - assert article.status == 'Withdrawn' + assert article.status == "Withdrawn" assert article.to_dict() == { "comment": "no comment", "authors": [], @@ -410,7 +393,7 @@ def test_special_docs_process_entry(): authors="mr. blobby", date_published="2023-10-02T01:23:45", source_type=None, - source_url="https://ble.ble.com" + source_url="https://ble.ble.com", ) contents = { "text": "this is the text", @@ -421,20 +404,20 @@ def test_special_docs_process_entry(): with patch("align_data.sources.articles.datasets.item_metadata", return_value=contents): assert dataset.process_entry(item).to_dict() == { - 'authors': ['mr. blobby'], - 'date_published': '2023-10-02T01:23:45Z', - 'id': None, - 'source': 'html', - 'source_url': "https://ble.ble.com", - 'source_type': 'html', - 'summaries': [], - 'text': 'this is the text', - 'title': 'this is the title', - 'url': 'https://bla.bla.bla', + "authors": ["mr. blobby"], + "date_published": "2023-10-02T01:23:45Z", + "id": None, + "source": "html", + "source_url": "https://ble.ble.com", + "source_type": "html", + "summaries": [], + "text": "this is the text", + "title": "this is the title", + "url": "https://bla.bla.bla", } -@patch('requests.get', return_value=Mock(content='')) +@patch("requests.get", return_value=Mock(content="")) def test_special_docs_process_entry_arxiv(_, mock_arxiv): dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da") item = Mock( @@ -450,9 +433,7 @@ def test_special_docs_process_entry_arxiv(_, mock_arxiv): "source_type": "pdf", } - with patch( - "align_data.sources.arxiv_papers.parse_vanity", return_value=contents - ): + with patch("align_data.sources.arxiv_papers.parse_vanity", return_value=contents): assert dataset.process_entry(item).to_dict() == { "comment": "no comment", "authors": ["mr blobby"], @@ -472,16 +453,22 @@ def test_special_docs_process_entry_arxiv(_, mock_arxiv): } -@pytest.mark.parametrize('url, expected', ( - ("http://bla.bla", "http://bla.bla"), - ("http://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/abs/2001.11038/", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/pdf/2001.11038", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/pdf/2001.11038.pdf", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/pdf/2001.11038v3.pdf", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/abs/math/2001.11038", "https://arxiv.org/abs/math/2001.11038"), -)) +@pytest.mark.parametrize( + "url, expected", + ( + ("http://bla.bla", "http://bla.bla"), + ("http://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/abs/2001.11038/", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/pdf/2001.11038", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/pdf/2001.11038.pdf", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/pdf/2001.11038v3.pdf", "https://arxiv.org/abs/2001.11038"), + ( + "https://arxiv.org/abs/math/2001.11038", + "https://arxiv.org/abs/math/2001.11038", + ), + ), +) def test_special_docs_not_processed_true(url, expected): dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da") dataset._outputted_items = [url, expected] @@ -489,13 +476,15 @@ def 
test_special_docs_not_processed_true(url, expected): assert not dataset.not_processed(Mock(url=None, source_url=url)) -@pytest.mark.parametrize('url', ( - "http://bla.bla" - "http://arxiv.org/abs/2001.11038", - "https://arxiv.org/abs/2001.11038", - "https://arxiv.org/abs/2001.11038/", - "https://arxiv.org/pdf/2001.11038", -)) +@pytest.mark.parametrize( + "url", + ( + "http://bla.bla" "http://arxiv.org/abs/2001.11038", + "https://arxiv.org/abs/2001.11038", + "https://arxiv.org/abs/2001.11038/", + "https://arxiv.org/pdf/2001.11038", + ), +) def test_special_docs_not_processed_false(url): dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da") dataset._outputted_items = [] diff --git a/tests/align_data/articles/test_google_cloud.py b/tests/align_data/articles/test_google_cloud.py index 39cacce3..bc814fe1 100644 --- a/tests/align_data/articles/test_google_cloud.py +++ b/tests/align_data/articles/test_google_cloud.py @@ -1,7 +1,12 @@ from unittest.mock import Mock, patch import pytest -from align_data.sources.articles.google_cloud import extract_gdrive_contents, get_content_type, google_doc, parse_grobid +from align_data.sources.articles.google_cloud import ( + extract_gdrive_contents, + get_content_type, + google_doc, + parse_grobid, +) SAMPLE_XML = """<?xml version="1.0" encoding="UTF-8"?> @@ -45,8 +50,12 @@ def test_google_doc(): def fetcher(url, *args, **kwargs): - assert url == 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html' - return Mock(content=""" + assert ( + url + == "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html" + ) + return Mock( + content=""" <html> <header>bla bla bla</header> <body> @@ -54,35 +63,45 @@ def fetcher(url, *args, **kwargs): </body> </html> - """) + """ + ) - with patch('requests.get', fetcher): - url = 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit' + with patch("requests.get", fetcher): + url = "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit" assert google_doc(url) == { - 'text': 'ble ble [a link](bla.com)', - 'source_url': url + "text": "ble ble [a link](bla.com)", + "source_url": url, } def test_google_doc_no_body(): def fetcher(url, *args, **kwargs): - assert url == 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html' + assert ( + url + == "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html" + ) return Mock(content="<html> <header>bla bla bla</header> </html>") - with patch('requests.get', fetcher): - assert google_doc('https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit') == {} + with patch("requests.get", fetcher): + assert ( + google_doc( + "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit" + ) + == {} + ) def test_google_doc_bad_url(): - assert google_doc('https://docs.google.com/bla/bla') == {} + assert google_doc("https://docs.google.com/bla/bla") == {} + def test_parse_grobid(): assert parse_grobid(SAMPLE_XML) == { - 'abstract': 'this is the abstract', - 'authors': ['Cullen Oâ\x80\x99Keefe'], - 'text': 'This is the contents', - 'title': 'The title!!', - 'source_type': 'xml', + "abstract": "this is the abstract", + "authors": ["Cullen Oâ\x80\x99Keefe"], + "text": "This is the contents", + "title": "The title!!", + "source_type": "xml", } @@ -104,74 +123,94 @@ def test_parse_grobid_no_body(): </text> 
</TEI> """ - assert parse_grobid(xml) == {'error': 'No contents in XML file', 'source_type': 'xml'} - + assert parse_grobid(xml) == { + "error": "No contents in XML file", + "source_type": "xml", + } -@pytest.mark.parametrize('header, expected', ( - (None, set()), - ('', set()), - ('text/html', {'text/html'}), - ('text/html; bla=asdas; fewwe=fe', {'text/html', 'bla=asdas', 'fewwe=fe'}), -)) +@pytest.mark.parametrize( + "header, expected", + ( + (None, set()), + ("", set()), + ("text/html", {"text/html"}), + ("text/html; bla=asdas; fewwe=fe", {"text/html", "bla=asdas", "fewwe=fe"}), + ), +) def test_get_content_type(header, expected): - assert get_content_type(Mock(headers={'Content-Type': header})) == expected - - -@pytest.mark.parametrize('headers', ( - {}, - {'Content-Type': None}, - {'Content-Type': ''}, - {'Content-Type': ' '}, - {'Content-Type': ' ; ;; '}, -)) + assert get_content_type(Mock(headers={"Content-Type": header})) == expected + + +@pytest.mark.parametrize( + "headers", + ( + {}, + {"Content-Type": None}, + {"Content-Type": ""}, + {"Content-Type": " "}, + {"Content-Type": " ; ;; "}, + ), +) def test_extract_gdrive_contents_no_contents(headers): - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=Mock(headers=headers, status_code=200)): + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch("requests.head", return_value=Mock(headers=headers, status_code=200)): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'error': 'no content type' + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "error": "no content type", } -@pytest.mark.parametrize('header', ( - 'application/octet-stream', - 'application/pdf', - 'application/pdf; filename=bla.pdf' -)) +@pytest.mark.parametrize( + "header", + ( + "application/octet-stream", + "application/pdf", + "application/pdf; filename=bla.pdf", + ), +) def test_extract_gdrive_contents_pdf(header): - res = Mock(headers={'Content-Type': header}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=res): - with patch('align_data.sources.articles.google_cloud.fetch_pdf', return_value={'text': 'bla'}): + res = Mock(headers={"Content-Type": header}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch("requests.head", return_value=res): + with patch( + "align_data.sources.articles.google_cloud.fetch_pdf", + return_value={"text": "bla"}, + ): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'text': 'bla', + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "text": "bla", } -@pytest.mark.parametrize('header', ( - 'application/epub', - 'application/epub+zip', - 'application/epub; filename=bla.epub', -)) +@pytest.mark.parametrize( + "header", + ( + "application/epub", + "application/epub+zip", + "application/epub; filename=bla.epub", + ), +) def test_extract_gdrive_contents_ebook(header): - res = 
Mock(headers={'Content-Type': header}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=res): + res = Mock(headers={"Content-Type": header}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch("requests.head", return_value=res): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'source_type': 'ebook', + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "source_type": "ebook", } def test_extract_gdrive_contents_html(): - res = Mock(headers={'Content-Type': 'text/html'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)): + res = Mock(headers={"Content-Type": "text/html"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch( + "requests.head", + return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200), + ): html = """ <html> <header>bleee</header> @@ -179,45 +218,48 @@ def test_extract_gdrive_contents_html(): </html> """ res = Mock( - headers={'Content-Type': 'text/html'}, + headers={"Content-Type": "text/html"}, status_code=200, content=html, text=html, ) - with patch('requests.get', return_value=res): + with patch("requests.get", return_value=res): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'text': 'bla bla', - 'source_type': 'html', + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "text": "bla bla", + "source_type": "html", } def test_extract_gdrive_contents_xml(): - res = Mock(headers={'Content-Type': 'text/html'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)): + res = Mock(headers={"Content-Type": "text/html"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch( + "requests.head", + return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200), + ): res = Mock( - headers={'Content-Type': 'text/xml'}, + headers={"Content-Type": "text/xml"}, status_code=200, content=SAMPLE_XML, text=SAMPLE_XML, ) - with patch('requests.get', return_value=res): + with patch("requests.get", return_value=res): assert extract_gdrive_contents(url) == { - 'abstract': 'this is the abstract', - 'authors': ['Cullen Oâ\x80\x99Keefe'], - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'text': 'This is the contents', - 'title': 'The title!!', - 'source_type': 'xml', + "abstract": "this is the abstract", + "authors": ["Cullen Oâ\x80\x99Keefe"], + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "text": "This is the contents", + 
"title": "The title!!", + "source_type": "xml", } def test_extract_gdrive_contents_xml_with_confirm(): - res = Mock(headers={'Content-Type': 'text/html'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' + res = Mock(headers={"Content-Type": "text/html"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" def fetcher(link, *args, **kwargs): # The first request should get the google drive warning page @@ -228,27 +270,35 @@ def fetcher(link, *args, **kwargs): <form action="fetch/xml/contents"></form> </body> """ - return Mock(headers={'Content-Type': 'text/html'}, status_code=200, text=html, content=html) + return Mock( + headers={"Content-Type": "text/html"}, + status_code=200, + text=html, + content=html, + ) # The second one returns the actual contents - return Mock(headers={'Content-Type': 'text/xml'}, status_code=200, content=SAMPLE_XML) + return Mock(headers={"Content-Type": "text/xml"}, status_code=200, content=SAMPLE_XML) - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)): - with patch('requests.get', fetcher): + with patch( + "requests.head", + return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200), + ): + with patch("requests.get", fetcher): assert extract_gdrive_contents(url) == { - 'abstract': 'this is the abstract', - 'authors': ['Cullen Oâ\x80\x99Keefe'], - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'text': 'This is the contents', - 'title': 'The title!!', - 'source_type': 'xml', + "abstract": "this is the abstract", + "authors": ["Cullen Oâ\x80\x99Keefe"], + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "text": "This is the contents", + "title": "The title!!", + "source_type": "xml", } def test_extract_gdrive_contents_warning_with_unknown(): - res = Mock(headers={'Content-Type': 'text/html'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' + res = Mock(headers={"Content-Type": "text/html"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" def fetcher(link, *args, **kwargs): # The first request should get the google drive warning page @@ -259,26 +309,34 @@ def fetcher(link, *args, **kwargs): <form action="fetch/xml/contents"></form> </body> """ - return Mock(headers={'Content-Type': 'text/html'}, status_code=200, text=html, content=html) + return Mock( + headers={"Content-Type": "text/html"}, + status_code=200, + text=html, + content=html, + ) # The second one returns the actual contents, with an unhandled content type - return Mock(headers={'Content-Type': 'text/bla bla'}, status_code=200) + return Mock(headers={"Content-Type": "text/bla bla"}, status_code=200) - with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)): - with patch('requests.get', fetcher): + with patch( + "requests.head", + return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200), + ): + with patch("requests.get", fetcher): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'error': "unknown content type: {'text/bla bla'}", - 'source_url': 
'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', + "downloaded_from": "google drive", + "error": "unknown content type: {'text/bla bla'}", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", } def test_extract_gdrive_contents_unknown_content_type(): - res = Mock(headers={'Content-Type': 'bla bla'}, status_code=200) - url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing' - with patch('requests.head', return_value=res): + res = Mock(headers={"Content-Type": "bla bla"}, status_code=200) + url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing" + with patch("requests.head", return_value=res): assert extract_gdrive_contents(url) == { - 'downloaded_from': 'google drive', - 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing', - 'error': "unknown content type: {'bla bla'}", + "downloaded_from": "google drive", + "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing", + "error": "unknown content type: {'bla bla'}", } diff --git a/tests/align_data/articles/test_parsers.py b/tests/align_data/articles/test_parsers.py index 8bac313f..5a174e3f 100644 --- a/tests/align_data/articles/test_parsers.py +++ b/tests/align_data/articles/test_parsers.py @@ -42,6 +42,7 @@ </TEI> """ + def test_medium_blog(): html = """ <article> @@ -60,14 +61,14 @@ def test_medium_blog(): </article> """ with patch("requests.get", return_value=Mock(content=html)): - assert MediumParser('html', 'ble')("bla.com") == { - 'authors': [], - 'date_published': parse('Oct 7, 2023').replace(tzinfo=pytz.UTC), - 'source': 'html', - 'source_type': 'blog', - 'text': 'bla bla bla [a link](http://ble.com) bla bla', - 'title': 'This is the title', - 'url': 'bla.com', + assert MediumParser("html", "ble")("bla.com") == { + "authors": [], + "date_published": parse("Oct 7, 2023").replace(tzinfo=pytz.UTC), + "source": "html", + "source_type": "blog", + "text": "bla bla bla [a link](http://ble.com) bla bla", + "title": "This is the title", + "url": "bla.com", } @@ -83,14 +84,14 @@ def test_medium_blog_no_title(): </article> """ with patch("requests.get", return_value=Mock(content=html)): - assert MediumParser(name='html', url='')("bla.com") == { - 'authors': [], - 'date_published': None, - 'source': 'html', - 'source_type': 'blog', - 'text': 'bla bla bla [a link](http://ble.com) bla bla', - 'title': None, - 'url': 'bla.com', + assert MediumParser(name="html", url="")("bla.com") == { + "authors": [], + "date_published": None, + "source": "html", + "source_type": "blog", + "text": "bla bla bla [a link](http://ble.com) bla bla", + "title": None, + "url": "bla.com", } @@ -105,13 +106,13 @@ def test_medium_blog_no_contents(): </div> </div> """ - with patch('requests.get', return_value=Mock(content=html)): - assert MediumParser(name='html', url='')('bla.com') == { - 'authors': [], - 'date_published': None, - 'source': 'html', - 'source_type': 'blog', - 'text': None, - 'title': None, - 'url': 'bla.com', + with patch("requests.get", return_value=Mock(content=html)): + assert MediumParser(name="html", url="")("bla.com") == { + "authors": [], + "date_published": None, + "source": "html", + "source_type": "blog", + "text": None, + "title": None, + "url": "bla.com", } diff --git a/tests/align_data/articles/test_updater.py b/tests/align_data/articles/test_updater.py index 7d11fbb7..f9e2aea2 100644 --- 
a/tests/align_data/articles/test_updater.py +++ b/tests/align_data/articles/test_updater.py @@ -9,39 +9,43 @@ SAMPLE_UPDATES = [ {}, - {'title': 'no id - should be ignored'}, - - {'id': '122', 'hash_id': 'deadbeef000'}, + {"title": "no id - should be ignored"}, + {"id": "122", "hash_id": "deadbeef000"}, + { + "id": "123", + "hash_id": "deadbeef001", + "title": "bla bla", + "url": "http://bla.com", + "source_url": "http://bla.bla.com", + "authors": "mr. blobby, johnny", + }, + { + "id": "124", + "title": "no hash id", + "url": "http://bla.com", + "source_url": "http://bla.bla.com", + "authors": "mr. blobby", + }, { - 'id': '123', 'hash_id': 'deadbeef001', - 'title': 'bla bla', - 'url': 'http://bla.com', - 'source_url': 'http://bla.bla.com', - 'authors': 'mr. blobby, johnny', - }, { - 'id': '124', - 'title': 'no hash id', - 'url': 'http://bla.com', - 'source_url': 'http://bla.bla.com', - 'authors': 'mr. blobby', - }, { - 'hash_id': 'deadbeef002', - 'title': 'no id', - 'url': 'http://bla.com', - 'source_url': 'http://bla.bla.com', - 'authors': 'mr. blobby', - }, { - 'id': '125', - 'title': 'no hash id, url or title', - 'authors': 'mr. blobby', - } + "hash_id": "deadbeef002", + "title": "no id", + "url": "http://bla.com", + "source_url": "http://bla.bla.com", + "authors": "mr. blobby", + }, + { + "id": "125", + "title": "no hash id, url or title", + "authors": "mr. blobby", + }, ] + @pytest.fixture def csv_file(tmp_path): - filename = tmp_path / 'data.csv' - with open(filename, 'w', newline='') as csvfile: - fieldnames = ['id', 'hash_id', 'title', 'url', 'source_url', 'authors'] + filename = tmp_path / "data.csv" + with open(filename, "w", newline="") as csvfile: + fieldnames = ["id", "hash_id", "title", "url", "source_url", "authors"] writer = DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() @@ -51,152 +55,195 @@ def csv_file(tmp_path): def test_items_list(csv_file): - dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",") def mock_entries(): return [ Mock( - _id=dataset.maybe(v, 'id'), - id=dataset.maybe(v, 'hash_id'), - title=dataset.maybe(v, 'title'), - url=dataset.maybe(v, 'url'), - authors=dataset.maybe(v, 'authors') + _id=dataset.maybe(v, "id"), + id=dataset.maybe(v, "hash_id"), + title=dataset.maybe(v, "title"), + url=dataset.maybe(v, "url"), + authors=dataset.maybe(v, "authors"), ) for v in dataset.csv_items ] - with patch.object(dataset, 'read_entries', mock_entries): + with patch.object(dataset, "read_entries", mock_entries): items = dataset.items_list - assert len(items) == 5, "items_list should only contain items with valid ids - something is wrong" + assert ( + len(items) == 5 + ), "items_list should only contain items with valid ids - something is wrong" for item in items: - assert dataset.maybe(item.updates, 'id') == item.article._id - assert dataset.maybe(item.updates, 'hash_id') == item.article.id - assert dataset.maybe(item.updates, 'title') == item.article.title - assert dataset.maybe(item.updates, 'url') == item.article.url - assert dataset.maybe(item.updates, 'authors') == item.article.authors - - -@pytest.mark.parametrize('updates', ( - Mock(url='http://some.other.url'), - Mock(source_url='http://some.other.url'), - Mock(url='http://some.other.url', source_url='http://another.url'), -)) + assert dataset.maybe(item.updates, "id") == item.article._id + assert dataset.maybe(item.updates, "hash_id") == item.article.id + assert dataset.maybe(item.updates, "title") == 
item.article.title + assert dataset.maybe(item.updates, "url") == item.article.url + assert dataset.maybe(item.updates, "authors") == item.article.authors + + +@pytest.mark.parametrize( + "updates", + ( + Mock(url="http://some.other.url"), + Mock(source_url="http://some.other.url"), + Mock(url="http://some.other.url", source_url="http://another.url"), + ), +) def test_update_text(csv_file, updates): - dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",") - article = Mock(text='this should be changed', status='as should this', url='http:/bla.bla.com') + article = Mock(text="this should be changed", status="as should this", url="http:/bla.bla.com") - with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + with patch( + "align_data.sources.articles.updater.item_metadata", + return_value={"text": "bla bla bla"}, + ): dataset.update_text(updates, article) - assert article.text == 'bla bla bla' + assert article.text == "bla bla bla" assert article.status == None -@pytest.mark.parametrize('updates', ( - Mock(url='http://some.other.url'), - Mock(source_url='http://some.other.url'), - Mock(url='http://some.other.url', source_url='http://another.url'), -)) +@pytest.mark.parametrize( + "updates", + ( + Mock(url="http://some.other.url"), + Mock(source_url="http://some.other.url"), + Mock(url="http://some.other.url", source_url="http://another.url"), + ), +) def test_update_text_error(csv_file, updates): - dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",") - article = Mock(text='this should not be changed', status='but this should be', url='http:/bla.bla.com') + article = Mock( + text="this should not be changed", + status="but this should be", + url="http:/bla.bla.com", + ) - with patch('align_data.sources.articles.updater.item_metadata', return_value={'error': 'oh noes!'}): + with patch( + "align_data.sources.articles.updater.item_metadata", + return_value={"error": "oh noes!"}, + ): dataset.update_text(updates, article) - assert article.text == 'this should not be changed' - assert article.status == 'oh noes!' - - -@pytest.mark.parametrize('updates', ( - Mock(url='http://bla.bla.com', source_url=None, comment='Same url as article, no source_url'), - Mock(url='http://bla.bla.com', source_url='', comment='Same url as article, empty source_url'), - Mock(url=None, source_url=None, comment='no urls provided'), -)) + assert article.text == "this should not be changed" + assert article.status == "oh noes!" 
+ + +@pytest.mark.parametrize( + "updates", + ( + Mock( + url="http://bla.bla.com", + source_url=None, + comment="Same url as article, no source_url", + ), + Mock( + url="http://bla.bla.com", + source_url="", + comment="Same url as article, empty source_url", + ), + Mock(url=None, source_url=None, comment="no urls provided"), + ), +) def test_update_text_no_update(csv_file, updates): - dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",") - article = Mock(text='this should not be changed', status='as should not this', url='http://bla.bla.com') + article = Mock( + text="this should not be changed", + status="as should not this", + url="http://bla.bla.com", + ) - with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + with patch( + "align_data.sources.articles.updater.item_metadata", + return_value={"text": "bla bla bla"}, + ): dataset.update_text(updates, article) - assert article.text == 'this should not be changed' - assert article.status == 'as should not this' + assert article.text == "this should not be changed" + assert article.status == "as should not this" def test_process_entry(csv_file): - dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",") article = Article( - _id=123, id='deadbeef0123', - title='this should be changed', - url='this should be changed', - text='this should be changed', - authors='this should be changed', - date_published='this should be changed', + _id=123, + id="deadbeef0123", + title="this should be changed", + url="this should be changed", + text="this should be changed", + authors="this should be changed", + date_published="this should be changed", ) updates = Mock( - id='123', - hash_id='deadbeef001', - title='bla bla', - url='http://bla.com', - source_url='http://bla.bla.com', - source='tests', - authors='mr. blobby, johnny', - date_published='2000-12-23T10:32:43Z', + id="123", + hash_id="deadbeef001", + title="bla bla", + url="http://bla.com", + source_url="http://bla.bla.com", + source="tests", + authors="mr. blobby, johnny", + date_published="2000-12-23T10:32:43Z", ) - with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + with patch( + "align_data.sources.articles.updater.item_metadata", + return_value={"text": "bla bla bla"}, + ): assert dataset.process_entry(Item(updates, article)).to_dict() == { - 'authors': ['mr. blobby', 'johnny'], - 'date_published': '2000-12-23T10:32:43Z', - 'id': 'd8d8cad8d28739a0862654a0e6e8ce6e', - 'source': 'tests', - 'source_type': None, - 'summaries': [], - 'text': 'bla bla bla', - 'title': 'bla bla', - 'url': 'http://bla.com', - 'source_url': 'http://bla.bla.com', + "authors": ["mr. 
blobby", "johnny"], + "date_published": "2000-12-23T10:32:43Z", + "id": "d8d8cad8d28739a0862654a0e6e8ce6e", + "source": "tests", + "source_type": None, + "summaries": [], + "text": "bla bla bla", + "title": "bla bla", + "url": "http://bla.com", + "source_url": "http://bla.bla.com", } def test_process_entry_empty(csv_file): - dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',') + dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",") article = Article( - _id=123, id='deadbeef0123', - title='this should not be changed', - url='this should not be changed', - source='this should not be changed', - authors='this should not be changed', - - text='this should be changed', - date_published='this should be changed', + _id=123, + id="deadbeef0123", + title="this should not be changed", + url="this should not be changed", + source="this should not be changed", + authors="this should not be changed", + text="this should be changed", + date_published="this should be changed", ) updates = Mock( - id='123', - hash_id='deadbeef001', + id="123", + hash_id="deadbeef001", title=None, - url='', - source_url='http://bla.bla.com', - source=' ', - authors=' \n \n \t \t ', - date_published='2000-12-23T10:32:43Z', + url="", + source_url="http://bla.bla.com", + source=" ", + authors=" \n \n \t \t ", + date_published="2000-12-23T10:32:43Z", ) - with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}): + with patch( + "align_data.sources.articles.updater.item_metadata", + return_value={"text": "bla bla bla"}, + ): assert dataset.process_entry(Item(updates, article)).to_dict() == { - 'authors': ['this should not be changed'], - 'date_published': '2000-12-23T10:32:43Z', - 'id': '606e9224254f508d297bcb17bcc6d104', - 'source': 'this should not be changed', - 'source_type': None, - 'summaries': [], - 'text': 'bla bla bla', - 'title': 'this should not be changed', - 'url': 'this should not be changed', - 'source_url': 'http://bla.bla.com', + "authors": ["this should not be changed"], + "date_published": "2000-12-23T10:32:43Z", + "id": "606e9224254f508d297bcb17bcc6d104", + "source": "this should not be changed", + "source_type": None, + "summaries": [], + "text": "bla bla bla", + "title": "this should not be changed", + "url": "this should not be changed", + "source_url": "http://bla.bla.com", } diff --git a/tests/align_data/common/test_alignment_dataset.py b/tests/align_data/common/test_alignment_dataset.py index e45c62eb..d18aaf78 100644 --- a/tests/align_data/common/test_alignment_dataset.py +++ b/tests/align_data/common/test_alignment_dataset.py @@ -75,41 +75,68 @@ def test_data_entry_id_from_urls_and_title(): ) -@pytest.mark.parametrize('item, error', ( - ( - {"key1": 12, "key2": 312, "title": "wikipedia goes to war on porcupines", "url": "asd"}, - 'missing fields: date_published, source, text' - ), - ( - {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "text": "asdasd", "title": "asdasd"}, - 'missing fields: date_published, source' - ), - ( - { - "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", - "source": "dwe", "date_published": "dwe" - }, - 'missing fields: text' - ), - ( - { - "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", - "text": "asdasd", "date_published": "dwe" - }, - 'missing fields: source' - ), +@pytest.mark.parametrize( + "item, error", ( - { - "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", "text": "asdasd", "source": "dwe" - }, - 'missing fields: 
date_published' + ( + { + "key1": 12, + "key2": 312, + "title": "wikipedia goes to war on porcupines", + "url": "asd", + }, + "missing fields: date_published, source, text", + ), + ( + { + "key1": 12, + "key2": 312, + "url": "www.wikipedia.org", + "text": "asdasd", + "title": "asdasd", + }, + "missing fields: date_published, source", + ), + ( + { + "key1": 12, + "key2": 312, + "url": "www.wikipedia.org", + "title": "bla", + "source": "dwe", + "date_published": "dwe", + }, + "missing fields: text", + ), + ( + { + "key1": 12, + "key2": 312, + "url": "www.wikipedia.org", + "title": "bla", + "text": "asdasd", + "date_published": "dwe", + }, + "missing fields: source", + ), + ( + { + "key1": 12, + "key2": 312, + "url": "www.wikipedia.org", + "title": "bla", + "text": "asdasd", + "source": "dwe", + }, + "missing fields: date_published", + ), ), -)) +) def test_data_entry_missing(item, error): dataset = AlignmentDataset(name="blaa") entry = dataset.make_data_entry(item) Article.before_write(None, None, entry) - assert entry.status == 'Missing fields' + assert entry.status == "Missing fields" assert entry.comments == error @@ -136,7 +163,7 @@ def test_data_entry_verify_id_fails(): "id": "f2b4e02fc1dd8ae43845e4f930f2d84f", } ) - expected = 'Entry id f2b4e02fc1dd8ae43845e4f930f2d84f does not match id from id_fields: 770fe57c8c2130eda08dc392b8696f97' + expected = "Entry id f2b4e02fc1dd8ae43845e4f930f2d84f does not match id from id_fields: 770fe57c8c2130eda08dc392b8696f97" with pytest.raises(AssertionError, match=expected): entry.verify_id() @@ -172,7 +199,9 @@ def test_data_entry_verify_fields_fails(data, error): def test_data_entry_id_fields(): dataset = AlignmentDataset(name="blaa") - entry = dataset.make_data_entry({"url": "https://www.google.ca/once_upon_a_time", 'title': 'bla'}) + entry = dataset.make_data_entry( + {"url": "https://www.google.ca/once_upon_a_time", "title": "bla"} + ) Article.before_write(None, None, entry) assert entry.id @@ -246,16 +275,11 @@ def test_unprocessed_items_some_done(numbers_dataset): def test_fetch_entries(numbers_dataset): - assert [i.meta["value"] for i in numbers_dataset.fetch_entries()] == [ - i**2 for i in range(10) - ] + assert [i.meta["value"] for i in numbers_dataset.fetch_entries()] == [i**2 for i in range(10)] def test_format_datatime(dataset): - assert ( - dataset._format_datetime(datetime(2022, 1, 1, 12, 23, 43)) - == "2022-01-01T12:23:43Z" - ) + assert dataset._format_datetime(datetime(2022, 1, 1, 12, 23, 43)) == "2022-01-01T12:23:43Z" def test_format_datatime_ignore_timezone(dataset): diff --git a/tests/align_data/common/test_html_dataset.py b/tests/align_data/common/test_html_dataset.py index 25e84b8b..3efbddb0 100644 --- a/tests/align_data/common/test_html_dataset.py +++ b/tests/align_data/common/test_html_dataset.py @@ -91,16 +91,12 @@ def test_html_dataset_items_list(html_dataset): def test_html_datasetfetch_contents(html_dataset): with patch("requests.get", return_value=Mock(content=SAMPLE_HTML)): - assert html_dataset.fetch_contents("url") == BeautifulSoup( - SAMPLE_HTML, "html.parser" - ) + assert html_dataset.fetch_contents("url") == BeautifulSoup(SAMPLE_HTML, "html.parser") def test_html_dataset_get_text(html_dataset): soup = BeautifulSoup(f"<article>{SAMPLE_CONTENTS}</article>", "html.parser") - assert ( - html_dataset._get_text(soup) == "bla bla bla [a link](http://ble.com) bla bla" - ) + assert html_dataset._get_text(soup) == "bla bla bla [a link](http://ble.com) bla bla" def test_html_dataset_find_date(html_dataset): @@ -125,10 +121,7 @@ 
def test_html_dataset_find_date(html_dataset): ), ) def test_html_dataset_extract_metadata(html_dataset, text): - assert ( - html_dataset._extract_markdown(text) - == "bla bla bla [a link](http://ble.com) bla bla" - ) + assert html_dataset._extract_markdown(text) == "bla bla bla [a link](http://ble.com) bla bla" def test_html_dataset_process_entry(html_dataset): @@ -176,9 +169,7 @@ def test_html_dataset_process_entry_no_text(html_dataset): ), ) def test_rss_dataset_extract_authors(item, authors): - dataset = RSSDataset( - name="bla", url="http://example.org", authors=["default author"] - ) + dataset = RSSDataset(name="bla", url="http://example.org", authors=["default author"]) assert dataset.extract_authors(item) == authors @@ -202,9 +193,7 @@ def test_rss_dataset_get_title(): ), ) def test_rss_dataset_get_published_date(item, date): - dataset = RSSDataset( - name="bla", url="http://example.org", authors=["default author"] - ) + dataset = RSSDataset(name="bla", url="http://example.org", authors=["default author"]) assert dataset._get_published_date(item) == date @@ -263,6 +252,4 @@ def test_rss_dataset_items_list(): } with patch("feedparser.parse", return_value=contents): - assert dataset.items_list == [ - f"http://example.org/article-{i}" for i in range(5) - ] + assert dataset.items_list == [f"http://example.org/article-{i}" for i in range(5)] diff --git a/tests/align_data/embeddings/test_embedding_utils.py b/tests/align_data/embeddings/test_embedding_utils.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/align_data/embeddings/test_pinecone_db_handler.py b/tests/align_data/embeddings/test_pinecone_db_handler.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/align_data/test_alignment_newsletter.py b/tests/align_data/sources/test_alignment_newsletter.py similarity index 98% rename from tests/align_data/test_alignment_newsletter.py rename to tests/align_data/sources/test_alignment_newsletter.py index ffd5e31b..da8fa043 100644 --- a/tests/align_data/test_alignment_newsletter.py +++ b/tests/align_data/sources/test_alignment_newsletter.py @@ -43,9 +43,7 @@ def test_process_entry_no_summary(dataset): def test_format_datatime(dataset): - assert dataset._get_published_date(2022) == datetime( - 2022, 1, 1, tzinfo=timezone.utc - ) + assert dataset._get_published_date(2022) == datetime(2022, 1, 1, tzinfo=timezone.utc) def test_process_entry(dataset): diff --git a/tests/align_data/test_arbital.py b/tests/align_data/sources/test_arbital.py similarity index 96% rename from tests/align_data/test_arbital.py rename to tests/align_data/sources/test_arbital.py index 304c5398..af65ed05 100644 --- a/tests/align_data/test_arbital.py +++ b/tests/align_data/sources/test_arbital.py @@ -127,9 +127,7 @@ def post(url, *args, **kwargs): page = json.loads(kwargs.get("data", "{}")).get("pageAlias") if "json/explore" in url: - response.json.return_value = { - "pages": {f"{page}-{i}": i for i in range(10)} - } + response.json.return_value = {"pages": {f"{page}-{i}": i for i in range(10)}} elif "json/primaryPage" in url: response.json.return_value = { "pages": { @@ -201,9 +199,7 @@ def test_extract_authors_ignore_missing(dataset): page = {"changeLogs": [{"userId": author} for author in authors]} with patch.object(dataset, "get_title", lambda author: author): - assert sorted(dataset.extract_authors(page)) == sorted( - ["John Snow", "mr. blobby"] - ) + assert sorted(dataset.extract_authors(page)) == sorted(["John Snow", "mr. 
blobby"]) @pytest.mark.parametrize( diff --git a/tests/align_data/test_arxiv.py b/tests/align_data/sources/test_arxiv.py similarity index 100% rename from tests/align_data/test_arxiv.py rename to tests/align_data/sources/test_arxiv.py diff --git a/tests/align_data/test_blogs.py b/tests/align_data/sources/test_blogs.py similarity index 74% rename from tests/align_data/test_blogs.py rename to tests/align_data/sources/test_blogs.py index 9e648f14..1e5a01c7 100644 --- a/tests/align_data/test_blogs.py +++ b/tests/align_data/sources/test_blogs.py @@ -5,6 +5,7 @@ from dateutil.parser import parse from align_data.sources.blogs import ( + AXRPDataset, CaradoMoe, ColdTakes, GenerativeInk, @@ -14,6 +15,7 @@ WordpressBlog, OpenAIResearch, DeepMindTechnicalBlog, + TransformerCircuits, ) from align_data.sources.blogs.blogs import EleutherAI @@ -161,10 +163,7 @@ def test_caradomoe_text(): </div> """ soup = BeautifulSoup(contents, "html.parser") - assert ( - dataset._get_text({"soup": soup}) - == "bla bla bla [a link](http://ble.com) bla bla" - ) + assert dataset._get_text({"soup": soup}) == "bla bla bla [a link](http://ble.com) bla bla" def test_caradomoe_process_entry(): @@ -229,9 +228,7 @@ def test_caradomoe_process_entry(): def test_gwern_get_text(): - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) soup = BeautifulSoup(GWERN_CONTENTS, "html.parser") assert dataset._get_text(soup) == "bla bla bla [a link](http://ble.com) bla bla" @@ -251,17 +248,13 @@ def test_gwern_get_text(): ), ) def test_gwern_get_published_date(metadata, date): - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) assert dataset._get_published_date(metadata) == date def test_gwern_get_article(): - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) with patch("requests.get", return_value="article contents"): assert dataset._get_article("http://bla.com") == "article contents" @@ -302,13 +295,9 @@ def test_gwern_process_markdown(): ... {SAMPLE_HTML} """ - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) - assert dataset._process_markdown( - "http://article.url", Mock(text=text) - ).to_dict() == { + assert dataset._process_markdown("http://article.url", Mock(text=text)).to_dict() == { "authors": ["Gwern Branwen"], "date_published": "2020-05-28T00:00:00Z", "id": None, @@ -329,13 +318,9 @@ def test_gwern_process_entry_markdown(): ... 
{SAMPLE_HTML} """ - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) - with patch( - "requests.get", return_value=Mock(text=text, status_code=200, headers={}) - ): + with patch("requests.get", return_value=Mock(text=text, status_code=200, headers={})): assert dataset.process_entry("http://article.url").to_dict() == { "authors": ["Gwern Branwen"], "date_published": "2020-05-28T00:00:00Z", @@ -350,9 +335,7 @@ def test_gwern_process_entry_markdown(): def test_gwern_process_entry_html(): - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) with patch( "requests.get", @@ -376,9 +359,7 @@ def test_gwern_process_entry_html(): def test_gwern_process_entry_erro(): - dataset = GwernBlog( - name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"] - ) + dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]) with patch("requests.get", return_value=Mock(status_code=404)): assert dataset.process_entry("http://article.url") is None @@ -490,9 +471,7 @@ def test_substack_blog_process_entry(): "title": "Eliezer S. Yudkowsky", "link": "https://www.yudkowsky.net", }, - "headers": { - "link": '<https://www.yudkowsky.net/wp-json/>; rel="https://api.w.org/"' - }, + "headers": {"link": '<https://www.yudkowsky.net/wp-json/>; rel="https://api.w.org/"'}, } @@ -508,9 +487,7 @@ def test_wordpress_blog_setup(): @patch("feedparser.parse", return_value=WORDPRESS_FEED) def test_wordpress_blog_items_list(feedparser_parse): blog = WordpressBlog(name="blog", url="https://www.bla.yudkowsky.net") - assert blog.items_list == [ - "https://www.yudkowsky.net/other/fiction/prospiracy-theory" - ] + assert blog.items_list == ["https://www.yudkowsky.net/other/fiction/prospiracy-theory"] def test_wordpress_blog_get_item_key(): @@ -527,9 +504,7 @@ def test_wordpress_blog_get_published_date(): name="blog", url="https://www.bla.yudkowsky.net", ) - date_published = blog._get_published_date( - {"published": "Mon, 26 Jun 2023 13:40:01 +0000"} - ) + date_published = blog._get_published_date({"published": "Mon, 26 Jun 2023 13:40:01 +0000"}) assert date_published == parse("2023-06-26T13:40:01Z") @@ -540,9 +515,7 @@ def test_wordpress_blog_process_entry(feedparser_parse): url="https://www.bla.yudkowsky.net", ) blog.items = {i["link"]: i for i in WORDPRESS_FEED["entries"]} - entry = blog.process_entry( - "https://www.yudkowsky.net/other/fiction/prospiracy-theory" - ) + entry = blog.process_entry("https://www.yudkowsky.net/other/fiction/prospiracy-theory") assert entry.to_dict() == { "authors": ["Eliezer S. 
Yudkowsky"], "date_published": "2020-09-04T04:11:23Z", @@ -642,10 +615,8 @@ def test_openai_research_get_text(): dataset = OpenAIResearch(name="openai", url="bla.bla") soup = BeautifulSoup(OPENAI_HTML, "html.parser") - parsers = {"arxiv.org": lambda _: {'text': 'bla bla bla'}} - with patch( - "requests.head", return_value=Mock(headers={"Content-Type": "text/html"}) - ): + parsers = {"arxiv.org": lambda _: {"text": "bla bla bla"}} + with patch("requests.head", return_value=Mock(headers={"Content-Type": "text/html"})): with patch("align_data.sources.articles.parsers.PDF_PARSERS", parsers): assert dataset._get_text(soup) == "bla bla bla" @@ -696,10 +667,8 @@ def test_openai_research_process_entry(): dataset = OpenAIResearch(name="openai", url="bla.bla") soup = BeautifulSoup(OPENAI_HTML, "html.parser") - parsers = {"arxiv.org": lambda _: {'text': 'bla bla bla'}} - with patch( - "requests.head", return_value=Mock(headers={"Content-Type": "text/html"}) - ): + parsers = {"arxiv.org": lambda _: {"text": "bla bla bla"}} + with patch("requests.head", return_value=Mock(headers={"Content-Type": "text/html"})): with patch("requests.get", return_value=Mock(content=OPENAI_HTML)): with patch("align_data.sources.articles.parsers.PDF_PARSERS", parsers): assert dataset.process_entry(soup).to_dict() == { @@ -781,3 +750,154 @@ def test_deepmind_technical_proces_entry(): "title": "title!", "url": "http://bla.bl", } + + +TRANSFORMER_CIRCUITS_HTML = """<html> + <head> + <title>This is the title + + + +

This is also the title

+
+
+
+
+

Authors

+
+ Nelson Elhage, + Robert Lasenby, + Christopher Olah +
+
+
+

Affiliation

+
Anthropic
+
+
+

Published

+
March 16, 2023
+
+
+
+ + + This is where the text goes. With a link to test + + + +""" + +def test_transformer_circuits_item_key(): + dataset = TransformerCircuits(url='http://bla.com', name='ble') + html = """
+ +

Circuits Updates — July 2023

+ +
+ A collection of small updates from the Anthropic Interpretability Team. +
+
""" + assert dataset.get_item_key(BeautifulSoup(html, 'html.parser').find('a')) == 'http://bla.com/2023/july-update/index.html' + + +def test_transformer_circuits_item_list(): + dataset = TransformerCircuits(url='http://bla.com', name='ble') + html = """
+
+ + + + + + +
""" + with patch("requests.get", return_value=Mock(content=html)): + assert [i.get('href') for i in dataset.items_list] == [ + 'item1.html', 'item2.html', 'item3.html', 'http://bla.com/item4.html' + ] + + +def test_transformer_circuits_get_title(): + dataset = TransformerCircuits(url='http://bla.com', name='ble') + soup = BeautifulSoup(TRANSFORMER_CIRCUITS_HTML, "html.parser") + assert dataset._get_title(soup) == "This is the title" + + +def test_transformer_circuits_get_published_date(): + dataset = TransformerCircuits(url='http://bla.com', name='ble') + soup = BeautifulSoup(TRANSFORMER_CIRCUITS_HTML, "html.parser") + assert dataset._get_published_date(soup).isoformat() == "2023-03-16T00:00:00+00:00" + + +def test_transformer_circuits_get_text(): + dataset = TransformerCircuits(url='http://bla.com', name='ble') + soup = BeautifulSoup(TRANSFORMER_CIRCUITS_HTML, "html.parser") + assert dataset._get_text(soup) == "This is where the text goes. With a [link](bla.com) to test" + + +def test_transformer_circuits_process_item(): + dataset = TransformerCircuits(url='http://bla.com', name='ble') + item = BeautifulSoup('', "html.parser").find('a') + with patch("requests.get", return_value=Mock(content=TRANSFORMER_CIRCUITS_HTML)): + assert dataset.process_entry(item).to_dict() == { + 'authors': ['Nelson Elhage', 'Robert Lasenby', 'Christopher Olah'], + 'date_published': '2023-03-16T00:00:00Z', + 'id': None, + 'source': 'ble', + 'source_type': 'blog', + 'summaries': [], + 'text': 'This is where the text goes. With a [link](bla.com) to test', + 'title': 'This is the title', + 'url': 'http://bla.com/ble/bla', + } + + +def test_axrp_dataset_extract_item_url(): + dataset = AXRPDataset(name='bla', url='https://ble.ble.com') + assert dataset._extract_item_url({'link': '/a/path'}) == 'https://ble.ble.com/a/path' + + +@pytest.mark.parametrize('item, expected', ( + ({}, ['default authors']), + ({'authors': []}, ['default authors']), + ({'authors': [{'bla': 'bla'}]}, ['default authors']), + ({'authors': [{'name': ''}]}, ['default authors']), + ({'authors': [{'name': ' \t \n'}]}, ['default authors']), + + ({'title': 'bla bla bla'}, ['default authors']), + ({'title': 'bla bla bla with'}, ['default authors']), + ({'title': 'bla bla bla with \t \n'}, ['default authors']), + + ({'authors': [{'name': 'mr. blobby'}]}, ['mr. blobby']), + ({'authors': [{'name': 'mr. blobby'}, {'name': 'janek'}]}, ['mr. 
blobby', 'janek']), + + ({'title': 'bla bla bla with your momma'}, ['default authors', 'your momma']), +)) +def test_axrp_dataset_extract_authors(item, expected): + dataset = AXRPDataset(name='bla', url='https://ble.ble.com', authors=['default authors']) + assert dataset.extract_authors(item) == expected + + +def test_axrp_dataset_process_entry(): + dataset = AXRPDataset(name='bla', url='https://ble.ble.com', authors=['default authors']) + url = 'https://ble.ble.com/ble/ble' + dataset.items = { + url: { + 'content': [{'value': 'bla bla'}], + 'link': '/ble/ble', + 'published': '2023-07-27T03:50:00+00:00', + 'title': 'Something or other with your momma', + } + } + assert dataset.process_entry(url).to_dict() == { + 'authors': ['default authors', 'your momma'], + 'date_published': '2023-07-27T03:50:00Z', + 'id': None, + 'source': 'bla', + 'source_type': 'blog', + 'summaries': [], + 'text': 'bla bla', + 'title': 'Something or other with your momma', + 'url': 'https://ble.ble.com/ble/ble', + } diff --git a/tests/align_data/test_distill.py b/tests/align_data/sources/test_distill.py similarity index 98% rename from tests/align_data/test_distill.py rename to tests/align_data/sources/test_distill.py index 6ced02df..ac1b7f34 100644 --- a/tests/align_data/test_distill.py +++ b/tests/align_data/sources/test_distill.py @@ -76,9 +76,7 @@ def test_extra_values(): """ soup = BeautifulSoup(contents, "html.parser") - assert dataset._extra_values( - {"soup": soup, "summary": "A wild summary has appeared!"} - ) == { + assert dataset._extra_values({"soup": soup, "summary": "A wild summary has appeared!"}) == { "bibliography": [ { "link": "https://doi.org/10.23915/distill.00033", diff --git a/tests/align_data/test_greater_wrong.py b/tests/align_data/sources/test_greater_wrong.py similarity index 97% rename from tests/align_data/test_greater_wrong.py rename to tests/align_data/sources/test_greater_wrong.py index 29140794..b8a9e73d 100644 --- a/tests/align_data/test_greater_wrong.py +++ b/tests/align_data/sources/test_greater_wrong.py @@ -84,9 +84,7 @@ def test_greaterwrong_get_item_key(dataset): def test_greaterwrong_get_published_date(dataset): - assert dataset._get_published_date({"postedAt": "2021/02/01"}) == parse( - "2021-02-01T00:00:00Z" - ) + assert dataset._get_published_date({"postedAt": "2021/02/01"}) == parse("2021-02-01T00:00:00Z") def test_greaterwrong_get_published_date_missing(dataset): @@ -152,9 +150,7 @@ def fetcher(next_date): ] return {"results": results} - mock_items = ( - i for i in [Mock(date_published=datetime.fromisoformat("2014-12-12T01:23:45"))] - ) + mock_items = (i for i in [Mock(date_published=datetime.fromisoformat("2014-12-12T01:23:45"))]) with patch.object(dataset, "fetch_posts", fetcher): with patch.object(dataset, "make_query", lambda next_date: next_date): with patch.object(dataset, "read_entries", return_value=mock_items): diff --git a/tests/align_data/test_stampy.py b/tests/align_data/sources/test_stampy.py similarity index 83% rename from tests/align_data/test_stampy.py rename to tests/align_data/sources/test_stampy.py index 5d4500b5..9b40a2c3 100644 --- a/tests/align_data/test_stampy.py +++ b/tests/align_data/sources/test_stampy.py @@ -14,17 +14,14 @@ def test_validate_coda_token(): def test_get_item_key(): dataset = Stampy(name="bla") - assert ( - dataset.get_item_key({"Question": "Why not just?"}) - == "Why\nnot just?" - ) + assert dataset.get_item_key({"Question": "Why not just?"}) == "Why\nnot just?" 
def test_get_published_date(): dataset = Stampy(name="bla") - assert dataset._get_published_date( - {"Doc Last Edited": "2012/01/03 12:23:32"} - ) == parse("2012-01-03T12:23:32Z") + assert dataset._get_published_date({"Doc Last Edited": "2012/01/03 12:23:32"}) == parse( + "2012-01-03T12:23:32Z" + ) def test_get_published_date_missing(): diff --git a/tests/align_data/test_youtube.py b/tests/align_data/sources/test_youtube.py similarity index 93% rename from tests/align_data/test_youtube.py rename to tests/align_data/sources/test_youtube.py index bcb720e8..70bbed94 100644 --- a/tests/align_data/test_youtube.py +++ b/tests/align_data/sources/test_youtube.py @@ -46,9 +46,7 @@ def test_next_page_empty_by_default(): }, { "kind": "youtube#playlistItem", - "snippet": { - "resourceId": {"kind": "youtube#video", "videoId": "your_video_id"} - }, + "snippet": {"resourceId": {"kind": "youtube#video", "videoId": "your_video_id"}}, }, ), ) @@ -72,9 +70,7 @@ def test_get_id_with_id(item): }, { "kind": "youtube#playlistItem", - "snippet": { - "resourceId": {"kind": "invalid_kind", "videoId": "your_video_id"} - }, + "snippet": {"resourceId": {"kind": "invalid_kind", "videoId": "your_video_id"}}, }, ), ) @@ -187,8 +183,7 @@ def test_items_list(): def fetcher(collection_id): return [ - {"id": {"kind": "youtube#video", "videoId": f"{collection_id}_{i}"}} - for i in range(3) + {"id": {"kind": "youtube#video", "videoId": f"{collection_id}_{i}"}} for i in range(3) ] with patch.object(dataset, "fetch_videos", fetcher): @@ -208,9 +203,7 @@ def test_get_item_key(): "id": {"kind": "youtube#video", "videoId": "your_video_id"}, "kind": "youtube#searchResult", } - assert ( - dataset.get_item_key(video) == "https://www.youtube.com/watch?v=your_video_id" - ) + assert dataset.get_item_key(video) == "https://www.youtube.com/watch?v=your_video_id" @pytest.mark.parametrize( @@ -229,9 +222,7 @@ def test_get_contents_with_no_transcript_found(error): } transcriber = Mock() - transcriber.list_transcripts.return_value.find_transcript.return_value.fetch.side_effect = ( - error - ) + transcriber.list_transcripts.return_value.find_transcript.return_value.fetch.side_effect = error with patch("align_data.sources.youtube.youtube.YouTubeTranscriptApi", transcriber): assert dataset._get_contents(video) is None @@ -336,16 +327,12 @@ def test_channel_process_item(transcriber): def test_playlist_collection_ids(): - dataset = YouTubePlaylistDataset( - name="bla", playlist_ids=["a list id", "another id"] - ) + dataset = YouTubePlaylistDataset(name="bla", playlist_ids=["a list id", "another id"]) assert dataset.collection_ids == ["a list id", "another id"] def test_playlist_published_date(): - dataset = YouTubePlaylistDataset( - name="bla", playlist_ids=["a list id", "another id"] - ) + dataset = YouTubePlaylistDataset(name="bla", playlist_ids=["a list id", "another id"]) video = { "kind": "youtube#playlistItem", "snippet": { @@ -359,9 +346,7 @@ def test_playlist_published_date(): def test_channel_process_item(transcriber): - dataset = YouTubePlaylistDataset( - name="bla", playlist_ids=["a list id", "another id"] - ) + dataset = YouTubePlaylistDataset(name="bla", playlist_ids=["a list id", "another id"]) video = { "kind": "youtube#playlistItem", "snippet": { diff --git a/tests/align_data/test_agisf.py b/tests/align_data/test_agisf.py new file mode 100644 index 00000000..380219da --- /dev/null +++ b/tests/align_data/test_agisf.py @@ -0,0 +1,96 @@ +import pytest +from unittest.mock import patch + +from align_data.sources.agisf.agisf import 
AGISFPodcastDataset + + +SAMPLE_ITEM = { + 'title': '[Week 0] “Machine Learning for Humans, Part 2.1: Supervised Learning” by Vishal Maini', + 'content': 'this is needed, but will mostly be ignored', + 'summary': '
+
+Bla bla bla
+
+
+
+Original article:
+https://medium.com/machine-learning-for-humans/supervised-learning-740383a2feab
+
+Author:
+Vishal Maini
+
', + 'link': 'https://ble.ble.com', +} + + +def test_fetch_contents(): + dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com') + url = 'https://test.url' + dataset.items = {url: SAMPLE_ITEM} + assert dataset.fetch_contents(url) == dict( + SAMPLE_ITEM, authors='Vishal Maini', + title='Machine Learning for Humans, Part 2.1: Supervised Learning' + ) + + +def test_fetch_contents_bad_title(): + dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com') + url = 'https://test.url' + dataset.items = {url: dict(SAMPLE_ITEM, title='asdasdasd')} + assert dataset.fetch_contents(url) == dict(SAMPLE_ITEM, title='asdasdasd') + + +def test_get_text(): + dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com') + item = dict(SAMPLE_ITEM) + + with patch("align_data.sources.agisf.agisf.item_metadata", return_value={ + 'text': 'bla bla bla', + 'source_type': 'some kind of thing', + 'title': None, + 'authors': [], + 'content': 'this should now change', + }): + assert dataset._get_text(item) == 'bla bla bla' + assert item == dict( + SAMPLE_ITEM, + content='this should now change', + text='bla bla bla', + source_type='some kind of thing', + ) + + +@pytest.mark.parametrize('authors, expected', ( + (None, ['default']), + ('', ['default']), + ([], ['default']), + + ('bla', ['bla']), + ('johnny bravo, mr. blobby\t\t\t, Hans Klos ', ['johnny bravo', 'mr. blobby', 'Hans Klos']), + (['mr. bean'], ['mr. bean']), + (['johnny bravo', 'mr. blobby', 'Hans Klos'], ['johnny bravo', 'mr. blobby', 'Hans Klos']), +)) +def test_extract_authors(authors, expected): + dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com', authors=['default']) + item = dict(SAMPLE_ITEM, authors=authors) + assert dataset.extract_authors(item) == expected + + +def test_extra_values(): + dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com', authors=['default']) + assert dataset._extra_values(SAMPLE_ITEM) == { + 'summary': 'Bla bla bla', + } + + +def test_extra_values_no_summary(): + dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com', authors=['default']) + assert dataset._extra_values({}) == {} + + +def test_process_entry(): + dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com') + url = 'https://test.url' + dataset.items = {url: SAMPLE_ITEM} + + with patch("align_data.sources.agisf.agisf.item_metadata", return_value={'text': 'bla'}): + assert dataset.process_entry(url).to_dict() == { + 'authors': ['Vishal Maini'], + 'date_published': None, + 'id': None, + 'source': 'bla', + 'source_type': 'blog', + 'summaries': ['Bla bla bla'], + 'text': 'bla', + 'title': 'Machine Learning for Humans, Part 2.1: Supervised Learning', + 'url': 'https://test.url', + } diff --git a/tests/align_data/test_airtable.py b/tests/align_data/test_airtable.py new file mode 100644 index 00000000..4de1fed0 --- /dev/null +++ b/tests/align_data/test_airtable.py @@ -0,0 +1,144 @@ +import pytest +from unittest.mock import patch + +from align_data.sources.airtable import AirtableDataset + + +@pytest.mark.parametrize('item, overwrites', ( + ({'url': 'http://bla.vle'}, {}), + ({'url': 'http://bla.vle', 'source': 'your momma'}, {'source': 'your momma'}), + ({'url': 'http://bla.vle', 'source': 'your momma', 'bla': 'ble'}, {'source': 'your momma'}), + ( + {'url': 'http://bla.vle', 'status': 'fine', 'title': 'Something or other'}, + {'status': 'fine', 'title': 'Something or other'} + ), + ( + {'url': 'http://some.other.url', 'source_type': 'blog', 'authors': 'bla, bla, bla'}, + {'url': 
'http://some.other.url', 'source_type': 'blog', 'authors': 'bla, bla, bla'} + ), +)) +def test_map_cols_no_mapping(item, overwrites): + dataset = AirtableDataset(name='asd', base_id='ddwe', table_id='csdcsc', mappings={}, processors={}) + assert dataset.map_cols({'id': '123', 'fields': item}) == dict({ + 'authors': None, + 'comments': None, + 'date_published': None, + 'id': None, + 'source': None, + 'source_type': None, + 'status': None, + 'text': None, + 'title': None, + 'summary': None, + 'url': 'http://bla.vle' + }, **overwrites) + + +@pytest.mark.parametrize('item, overwrites', ( + ({'an url!': 'http://bla.vle'}, {}), + ({'an url!': 'http://bla.vle', 'source': 'your momma'}, {'source': 'your momma'}), + ({'an url!': 'http://bla.vle', 'source': 'your momma', 'bla': 'ble'}, {'source': 'your momma'}), + ( + {'an url!': 'http://bla.vle', 'status': 'fine', 'title': 'Something or other'}, + {'status': 'fine', 'title': 'Something or other'} + ), + ( + {'an url!': 'http://some.other.url', 'source_type': 'blog', 'whodunnit': 'bla, bla, bla'}, + {'url': 'http://some.other.url', 'source_type': 'blog', 'authors': 'bla, bla, bla'} + ), +)) +def test_map_cols_with_mapping(item, overwrites): + dataset = AirtableDataset( + name='asd', base_id='ddwe', table_id='csdcsc', + mappings={ + 'url': 'an url!', + 'authors': 'whodunnit', + }, + processors={} + ) + assert dataset.map_cols({'id': '123', 'fields': item}) == dict({ + 'authors': None, + 'comments': None, + 'date_published': None, + 'id': None, + 'source': None, + 'source_type': None, + 'status': None, + 'text': None, + 'title': None, + 'summary': None, + 'url': 'http://bla.vle' + }, **overwrites) + + +@pytest.mark.parametrize('item, overwrites', ( + ({'an url!': 'http://bla.vle'}, {}), + ({'an url!': 'http://bla.vle', 'source': 'your momma'}, {'source': 'your momma'}), + ({'an url!': 'http://bla.vle', 'source': 'your momma', 'bla': 'ble'}, {'source': 'your momma'}), + ( + {'an url!': 'http://bla.vle', 'status': 'fine', 'title': 'Something or other'}, + {'status': 'fine', 'title': 'Something or other bla!'} + ), + ( + {'an url!': 'http://some.other.url', 'source_type': 'blog', 'whodunnit': 'bla, bla, bla'}, + {'url': 'http://some.other.url', 'source_type': 'blog', 'authors': 'bla, bla, bla'} + ), +)) +def test_map_cols_with_processing(item, overwrites): + dataset = AirtableDataset( + name='asd', base_id='ddwe', table_id='csdcsc', + mappings={ + 'url': 'an url!', + 'authors': 'whodunnit', + }, + processors={ + 'title': lambda val: val and val + ' bla!', + 'id': lambda _: 123, + } + ) + assert dataset.map_cols({'id': '123', 'fields': item}) == dict({ + 'authors': None, + 'comments': None, + 'date_published': None, + 'id': 123, + 'source': None, + 'source_type': None, + 'status': None, + 'text': None, + 'title': None, + 'summary': None, + 'url': 'http://bla.vle' + }, **overwrites) + + +@pytest.mark.parametrize('url', (None, '', 'asdasdsad')) +def test_map_cols_no_url(url): + dataset = AirtableDataset(name='asd', base_id='ddwe', table_id='csdcsc', mappings={}, processors={}) + assert dataset.map_cols({'id': '123', 'fields': {'url': url}}) is None + + +def test_process_entry(): + dataset = AirtableDataset(name='asd', base_id='ddwe', table_id='csdcsc', mappings={}, processors={}) + entry = { + 'url': 'http://bla.cle', + 'authors': ['johnny', 'your momma', 'mr. 
Blobby', 'Łóżćś Jaś'], + 'date_published': '2023-01-02', + 'source': 'some place', + 'status': 'fine', + 'comments': 'should be ok', + } + with patch("align_data.sources.airtable.item_metadata", return_value={ + 'text': 'bla bla bla', + 'source_type': 'some kind of thing', + }): + assert dataset.process_entry(entry).to_dict() == { + 'authors': ['johnny', 'your momma', 'mr. Blobby', 'Łóżćś Jaś'], + 'date_published': '2023-01-02T00:00:00Z', + 'id': None, + 'source': 'asd', + 'source_type': 'some kind of thing', + 'summaries': [], + 'text': 'bla bla bla', + 'title': None, + 'url': 'http://bla.cle', + } diff --git a/tests/align_data/test_utils.py b/tests/align_data/test_utils.py new file mode 100644 index 00000000..8e32c612 --- /dev/null +++ b/tests/align_data/test_utils.py @@ -0,0 +1,37 @@ +import pytest +from align_data.sources.utils import merge_dicts + + +def test_merge_dicts_no_args(): + """Test merge_dicts function with no arguments.""" + result = merge_dicts() + assert result == {} + + +def test_merge_dicts_single_dict(): + """Test merge_dicts function with a single dictionary.""" + result = merge_dicts({'a': 1, 'b': 2}) + assert result == {'a': 1, 'b': 2} + + +def test_merge_dicts_dicts_with_no_overlap(): + """Test merge_dicts function with multiple dictionaries with no overlapping keys.""" + result = merge_dicts({'a': 1}, {'b': 2}, {'c': 3}) + assert result == {'a': 1, 'b': 2, 'c': 3} + + +def test_merge_dicts_dicts_with_overlap(): + """Test merge_dicts function with multiple dictionaries with overlapping keys.""" + result = merge_dicts({'a': 1, 'b': 2}, {'b': 3, 'c': 4}, {'c': 5, 'd': 6}) + assert result == {'a': 1, 'b': 3, 'c': 5, 'd': 6} + + +@pytest.mark.parametrize("input_dicts, expected", [ + ([{'a': 1, 'b': None}, {'b': 3}], {'a': 1, 'b': 3}), + ([{'a': 0, 'b': 2}, {'b': None}], {'a': 0, 'b': 2}), + ([{'a': None}, {'b': 'test'}], {'b': 'test'}), +]) +def test_merge_dicts_with_none_values(input_dicts, expected): + """Test merge_dicts function with dictionaries containing None or falsey values.""" + result = merge_dicts(*input_dicts) + assert result == expected diff --git a/tests/print_date_published.py b/tests/print_date_published.py index d3ba46fd..32bda616 100644 --- a/tests/print_date_published.py +++ b/tests/print_date_published.py @@ -20,9 +20,7 @@ def validate_date_format(file_path, keys_to_print): # Try to parse the date_published string into a datetime object parse(date_published) except ValueError: - print( - f"Row {i}: date_published is NOT in a valid format: {date_published}" - ) + print(f"Row {i}: date_published is NOT in a valid format: {date_published}") for key in keys_to_print: print(f" {key}: {entry.get(key)}") diff --git a/upload_to_huggingface.py b/upload_to_huggingface.py index 6368d14e..43956b29 100644 --- a/upload_to_huggingface.py +++ b/upload_to_huggingface.py @@ -10,41 +10,39 @@ from huggingface_hub import HfApi -GDOCS_FOLDER = ( - "https://drive.google.com/drive/folders/1n4i0J4CuSfNmrUkKPyTFKJU0XWYLtRF8" -) +GDOCS_FOLDER = "https://drive.google.com/drive/folders/1n4i0J4CuSfNmrUkKPyTFKJU0XWYLtRF8" DATASOURCES = [ - 'agentmodels', - 'aiimpacts', - 'aisafety.camp', - 'aisafety.info', - 'ai_alignment_playlist', - 'ai_explained', - 'ai_safety_talks', - 'ai_safety_reading_group', - 'ai_tech_tu_delft', - 'alignmentforum', - 'arbital', - 'arxiv', - 'carado.moe', - 'cold_takes', - 'deepmind_blog', - 'deepmind_technical_blog', - 'distill', - 'eaforum', - 'eleuther.ai', - 'generative.ink', - 'gwern_blog', - 'importai', - 'jsteinhardt_blog', - 'lesswrong', - 
'miri', - 'ml_safety_newsletter', - 'openai.research', - 'rob_miles_ai_safety', - 'special_docs', - 'vkrakovna_blog', - 'yudkowsky_blog' + "agentmodels", + "aiimpacts", + "aisafety.camp", + "aisafety.info", + "ai_alignment_playlist", + "ai_explained", + "ai_safety_talks", + "ai_safety_reading_group", + "ai_tech_tu_delft", + "alignmentforum", + "arbital", + "arxiv", + "carado.moe", + "cold_takes", + "deepmind_blog", + "deepmind_technical_blog", + "distill", + "eaforum", + "eleuther.ai", + "generative.ink", + "gwern_blog", + "importai", + "jsteinhardt_blog", + "lesswrong", + "miri", + "ml_safety_newsletter", + "openai.research", + "rob_miles_ai_safety", + "special_docs", + "vkrakovna_blog", + "yudkowsky_blog", ] @@ -70,11 +68,7 @@ def get_gdoc_names(url): return None _, id_name_type_iter = _parse_google_drive_file(url=url, content=res.text) - return [ - (id, name) - for id, name, filetype in id_name_type_iter - if name.endswith(".jsonl") - ] + return [(id, name) for id, name, filetype in id_name_type_iter if name.endswith(".jsonl")] def upload_data_file(api, name, repo_name): @@ -84,7 +78,7 @@ def upload_data_file(api, name, repo_name): # Don't download it if it exists locally if not filename.exists(): - print(f'{filename} not found!') + print(f"{filename} not found!") return try: @@ -99,9 +93,7 @@ def upload_data_file(api, name, repo_name): def download_file(repo_name, filename, api): headers = {"Authorization": f"Bearer {api.token}"} - url = ( - f"https://huggingface.co/datasets/StampyAI/{repo_name}/raw/main/{filename.name}" - ) + url = f"https://huggingface.co/datasets/StampyAI/{repo_name}/raw/main/{filename.name}" response = requests.get(url, headers=headers) if response.status_code == 200: