From 5e630c8790646baf9948014541b6a898b0f68fe0 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Fri, 25 Aug 2023 10:33:00 +0200 Subject: [PATCH 01/25] handle link types in axrp (#161) --- align_data/sources/blogs/blogs.py | 5 +++-- tests/align_data/sources/test_blogs.py | 9 +++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py index a645e5a2..6d7b144b 100644 --- a/align_data/sources/blogs/blogs.py +++ b/align_data/sources/blogs/blogs.py @@ -171,9 +171,10 @@ def feed_url(self): return f"{self.url}/feed.xml" def _extract_item_url(self, item) -> str | None: - if path := item.get('link'): + path = item.get('link') + if path and not path.startswith('http'): return self.url + path - return None + return path def extract_authors(self, item): if "authors" in item: diff --git a/tests/align_data/sources/test_blogs.py b/tests/align_data/sources/test_blogs.py index 1e5a01c7..154daa6b 100644 --- a/tests/align_data/sources/test_blogs.py +++ b/tests/align_data/sources/test_blogs.py @@ -853,9 +853,14 @@ def test_transformer_circuits_process_item(): } -def test_axrp_dataset_extract_item_url(): +@pytest.mark.parametrize('url, expected', ( + ('/a/path', 'https://ble.ble.com/a/path'), + ('http://ble.ble.com/bla', 'http://ble.ble.com/bla'), + ('https://ble.ble.com/bla', 'https://ble.ble.com/bla'), +)) +def test_axrp_dataset_extract_item_url(url, expected): dataset = AXRPDataset(name='bla', url='https://ble.ble.com') - assert dataset._extract_item_url({'link': '/a/path'}) == 'https://ble.ble.com/a/path' + assert dataset._extract_item_url({'link': url}) == expected @pytest.mark.parametrize('item, expected', ( From c995e4cb94c3e988bf67ab1f7e0f4a23ae539612 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Fri, 25 Aug 2023 20:49:47 +0200 Subject: [PATCH 02/25] Tidy up (#167) * change mysql connector library * unify docker DB name * mark indice from which articles came from * Skip greater wrong posts that have equivalent title,authors pairs in the db --- README.md | 2 - align_data/settings.py | 2 +- align_data/sources/articles/indices.py | 11 +++- .../sources/greaterwrong/greaterwrong.py | 42 +++++++++++--- local_db.sh | 4 +- requirements.txt | 2 +- .../align_data/sources/test_greater_wrong.py | 56 +++++++++++++++++-- 7 files changed, 101 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index b01779df..9a557547 100644 --- a/README.md +++ b/README.md @@ -64,8 +64,6 @@ Additional keys may be available depending on the source document. ## Development Environment -Follow the [instructions to install **mysqlclient** on your operating system](https://pypi.org/project/mysqlclient/) toward the middle to bottom of the linked page. 
- To set up the development environment, run the following steps: ```bash diff --git a/align_data/settings.py b/align_data/settings.py index 59304c08..73459b8f 100644 --- a/align_data/settings.py +++ b/align_data/settings.py @@ -39,7 +39,7 @@ host = os.environ.get("ARD_DB_HOST", "127.0.0.1") port = os.environ.get("ARD_DB_PORT", "3306") db_name = os.environ.get("ARD_DB_NAME", "alignment_research_dataset") -DB_CONNECTION_URI = f"mysql+mysqldb://{user}:{password}@{host}:{port}/{db_name}" +DB_CONNECTION_URI = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{db_name}" ARTICLE_MAIN_KEYS = ["id", "source", 'source_type', "title", "authors", "text", "url", "date_published", "status", "comments"] ### EMBEDDINGS ### diff --git a/align_data/sources/articles/indices.py b/align_data/sources/articles/indices.py index 5e7d5a05..220a9005 100644 --- a/align_data/sources/articles/indices.py +++ b/align_data/sources/articles/indices.py @@ -60,7 +60,7 @@ def aisafetysupport(): def format_mlsafety_course(a): if (a.get("href") or "").startswith("http"): - return {"title": a.text, "url": a.get("href")} + return {"title": a.text, "url": a.get("href"), "initial_source": "mlsafety_course"} def format_anthropic(post): @@ -75,6 +75,7 @@ def format_anthropic(post): "title": get_text(post, "div.post-heading"), "url": url, "source_url": source_url, + "initial_source": "anthropic", "date_published": date_published, } @@ -84,6 +85,7 @@ def format_safe_ai(item): "title": get_text(item, "h4"), "url": item.find("a").get("href"), "source_url": item.find("a").get("href"), + "initial_source": "safe.ai", "authors": get_text(item, "h4 ~ p"), } @@ -94,6 +96,7 @@ def format_far_ai(item): "url": f'https://www.safe.ai/research{item.select_one(".article-title a").get("href")}', "source_url": item.select_one('div.btn-links a:-soup-contains("PDF")').get("href"), "authors": ", ".join(i.text for i in item.select(".article-metadata a")), + "initial_source": "far.ai", } @@ -114,6 +117,7 @@ def format_redwoodresearch(item): "source_url": url, "authors": authors, "date_published": date_published, + "initial_source": "redwood_research", } @@ -134,6 +138,7 @@ def format_chai_research(item): "source_url": url, "authors": ", ".join(authors), "date_published": date_published, + "initial_source": "chai_research", } @@ -151,6 +156,7 @@ def format_chai_newsletter(item): "title": item.text, "url": item.get("href"), "source_url": item.get("href"), + "initial_source": "chai_newsletter", } @@ -168,6 +174,7 @@ def format_neel_nanda_fav(item): "title": title.replace("\n", " "), "url": url, "summary": MarkdownConverter().convert_soup(item).strip(), + "initial_source": "neelnanda", } @@ -281,12 +288,14 @@ def process_entry(self, item): "authors": self.extract_authors(item), "status": "Ignored", "comments": "Added from indices", + "initial_source": item.get("initial_source"), } ) return self.make_data_entry( { "source": "arxiv", + "initial_source": item.get("initial_source"), "url": contents.get("url") or self.get_item_key(item), "title": item.get("title"), "date_published": self._get_published_date(item.get("date_published")), diff --git a/align_data/sources/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py index 5f32b68f..8925fc22 100644 --- a/align_data/sources/greaterwrong/greaterwrong.py +++ b/align_data/sources/greaterwrong/greaterwrong.py @@ -2,13 +2,16 @@ import logging import time from dataclasses import dataclass +from typing import Set, Tuple import requests import jsonlines from bs4 import BeautifulSoup from 
markdownify import markdownify +from sqlalchemy import select from align_data.common.alignment_dataset import AlignmentDataset +from align_data.db.session import make_session from align_data.db.models import Article logger = logging.getLogger(__name__) @@ -69,6 +72,8 @@ class GreaterWrong(AlignmentDataset): summary_key: str = "summary" done_key = "url" lazy_eval = True + source_type = 'GreaterWrong' + _outputted_items = (set(), set()) def setup(self): super().setup() @@ -79,8 +84,29 @@ def setup(self): def tags_ok(self, post): return not self.ai_tags or {t["name"] for t in post["tags"] if t.get("name")} & self.ai_tags - def get_item_key(self, item): - return item["pageUrl"] + def _load_outputted_items(self) -> Tuple[Set[str], Set[Tuple[str, str]]]: + """Load the output file (if it exists) in order to know which items have already been output.""" + with make_session() as session: + articles = ( + session + .query(Article.url, Article.title, Article.authors) + .where(Article.source_type == self.source_type) + .all() + ) + return ( + {a.url for a in articles}, + {(a.title.replace('\n', '').strip(), a.authors) for a in articles}, + ) + + def not_processed(self, item): + title = item["title"] + url = item["pageUrl"] + authors = ','.join(self.extract_authors(item)) + + return ( + url not in self._outputted_items[0] + and (title, authors) not in self._outputted_items[1] + ) def _get_published_date(self, item): return super()._get_published_date(item.get("postedAt")) @@ -167,11 +193,14 @@ def items_list(self): next_date = posts["results"][-1]["postedAt"] time.sleep(self.COOLDOWN) - def process_entry(self, item): + def extract_authors(self, item): authors = item["coauthors"] if item["user"]: authors = [item["user"]] + authors - authors = [a["displayName"] for a in authors] or ["anonymous"] + # Some posts don't have authors, for some reaason + return [a["displayName"] for a in authors] or ["anonymous"] + + def process_entry(self, item): return self.make_data_entry( { "title": item["title"], @@ -180,13 +209,12 @@ def process_entry(self, item): "date_published": self._get_published_date(item), "modified_at": item["modifiedAt"], "source": self.name, - "source_type": "GreaterWrong", + "source_type": self.source_type, "votes": item["voteCount"], "karma": item["baseScore"], "tags": [t["name"] for t in item["tags"]], "words": item["wordCount"], "comment_count": item["commentCount"], - # Some posts don't have authors, for some reaason - "authors": authors, + "authors": self.extract_authors(item), } ) diff --git a/local_db.sh b/local_db.sh index 853dc7fc..97738a95 100755 --- a/local_db.sh +++ b/local_db.sh @@ -1,10 +1,10 @@ #!/usr/bin/env bash ROOT_PASSWORD=my-secret-pw -docker start alignment-research-dataset +docker start stampy-db if [ $? -ne 0 ]; then echo 'No docker container found - creating a new one' - docker run --name alignment-research-dataset -p 3306:3306 -e MYSQL_ROOT_PASSWORD=$ROOT_PASSWORD -d mysql:latest + docker run --name stampy-db -p 3306:3306 -e MYSQL_ROOT_PASSWORD=$ROOT_PASSWORD -d mysql:latest fi echo "Waiting till mysql is available..." 
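A quick way to sanity-check the connector swap in this commit — a minimal sketch, not taken from the patch, assuming the `stampy-db` container from `local_db.sh` is running with its default root password and the host/port defaults from `align_data/settings.py`:

```python
# Minimal sketch: confirm SQLAlchemy can reach the local MySQL container
# through the new mysql+mysqlconnector driver (mysql-connector-python).
# The credentials and host/port mirror local_db.sh and the settings.py
# defaults; no schema is selected, so this works before migrations have run.
from sqlalchemy import create_engine, text

uri = "mysql+mysqlconnector://root:my-secret-pw@127.0.0.1:3306"
engine = create_engine(uri)
with engine.connect() as conn:
    print(conn.execute(text("SELECT VERSION()")).scalar())
```
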
diff --git a/requirements.txt b/requirements.txt index 008d629d..0068c61e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,7 +31,7 @@ youtube-transcript-api airtable alembic -mysqlclient +mysql-connector-python openai langchain diff --git a/tests/align_data/sources/test_greater_wrong.py b/tests/align_data/sources/test_greater_wrong.py index b8a9e73d..a1ba0eaa 100644 --- a/tests/align_data/sources/test_greater_wrong.py +++ b/tests/align_data/sources/test_greater_wrong.py @@ -79,10 +79,6 @@ def test_greaterwrong_tags_ok_missing(dataset, tags): assert not dataset.tags_ok({"tags": tags}) -def test_greaterwrong_get_item_key(dataset): - assert dataset.get_item_key({"pageUrl": "item key"}) == "item key" - - def test_greaterwrong_get_published_date(dataset): assert dataset._get_published_date({"postedAt": "2021/02/01"}) == parse("2021-02-01T00:00:00Z") @@ -230,3 +226,55 @@ def test_process_entry_no_authors(dataset): "votes": 12, "words": 123, } + + +@pytest.mark.parametrize('item', ( + { + # non seen url + 'pageUrl': 'http://bla.bla', + 'title': 'new item', 'coauthors': [{'displayName': 'your momma'}] + }, + { + # already seen title, but different authors + 'title': 'this has already been seen', + 'pageUrl': 'http://bla.bla', 'coauthors': [{'displayName': 'your momma'}] + }, + { + # new title, but same authors + 'coauthors': [{'displayName': 'johnny'}], + 'title': 'new item', 'pageUrl': 'http://bla.bla' + }, +)) +def test_not_processed_true(item, dataset): + dataset._outputted_items = ( + {'http://already.seen'}, + {('this has been seen', 'johnny')} + ) + item['user'] = None + assert dataset.not_processed(item) + + +@pytest.mark.parametrize('item', ( + { + # url seen + 'pageUrl': 'http://already.seen', + 'title': 'new item', 'coauthors': [{'displayName': 'your momma'}] + }, + { + # already seen title and authors pair, but different url + 'title': 'this has already been seen', 'coauthors': [{'displayName': 'johnny'}], + 'pageUrl': 'http://bla.bla', + }, + { + # already seen everything + 'pageUrl': 'http://already.seen', + 'title': 'this has already been seen', 'coauthors': [{'displayName': 'johnny'}], + } +)) +def test_not_processed_false(item, dataset): + dataset._outputted_items = ( + {'http://already.seen'}, + {('this has already been seen', 'johnny')} + ) + item['user'] = None + assert not dataset.not_processed(item) From 1fafcc7fe8438b68efbb1ca5853aa1825c06d507 Mon Sep 17 00:00:00 2001 From: Henri Lemoine Date: Wed, 23 Aug 2023 09:53:52 -0400 Subject: [PATCH 03/25] fixed moderation, removed engine --- align_data/embeddings/embedding_utils.py | 30 +++++++++++++++++++----- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/align_data/embeddings/embedding_utils.py b/align_data/embeddings/embedding_utils.py index 6ef57b88..cf98228c 100644 --- a/align_data/embeddings/embedding_utils.py +++ b/align_data/embeddings/embedding_utils.py @@ -88,13 +88,33 @@ def wrapper(*args, **kwargs): @handle_openai_errors -def moderation_check(texts: List[str]) -> List[ModerationInfoType]: - return openai.Moderation.create(input=texts)["results"] +def moderation_check(texts: List[str], max_texts_num: int = 32) -> List[ModerationInfoType]: + """ + Check moderation on a list of texts. + + Parameters: + - texts (List[str]): List of texts to be checked for moderation. + - max_texts_num (int): Number of texts to check at once. Defaults to 32. + + Returns: + - List[ModerationInfoType]: List of moderation results for the provided texts. 
+ """ + total_texts = len(texts) + results = [] + + for i in range(0, total_texts, max_texts_num): + batch_texts = texts[i : i + max_texts_num] + batch_results = openai.Moderation.create(input=batch_texts)["results"] + results.extend(batch_results) + + return results @handle_openai_errors def _compute_openai_embeddings(non_flagged_texts: List[str], **kwargs) -> List[List[float]]: - data = openai.Embedding.create(input=non_flagged_texts, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data + data = openai.Embedding.create( + input=non_flagged_texts, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs + ).data return [d["embedding"] for d in data] @@ -184,9 +204,7 @@ def get_embeddings( flags = [result["flagged"] for result in moderation_results] non_flagged_texts = [text for text, flag in zip(texts, flags) if not flag] - non_flagged_embeddings = get_embeddings_without_moderation( - non_flagged_texts, source, **kwargs - ) + non_flagged_embeddings = get_embeddings_without_moderation(non_flagged_texts, source, **kwargs) embeddings = [None if flag else non_flagged_embeddings.pop(0) for flag in flags] return embeddings, moderation_results From 9a4a859299751a85112b1c55674532c6c18cf080 Mon Sep 17 00:00:00 2001 From: Henri Lemoine Date: Wed, 23 Aug 2023 09:54:26 -0400 Subject: [PATCH 04/25] added update_articles_by_ids in update_pinecone --- align_data/db/session.py | 14 +++++++++++--- align_data/embeddings/pinecone/update_pinecone.py | 11 +++++++++++ main.py | 11 +++++++++++ 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/align_data/db/session.py b/align_data/db/session.py index 4aa23a87..3999546b 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -23,18 +23,26 @@ def make_session(auto_commit=False): def stream_pinecone_updates( - session: Session, custom_sources: List[str], force_update: bool = False + session: Session, + custom_sources: List[str], + force_update: bool = False, + article_ids: List[int] | None = None, ): """Yield Pinecone entries that require an update.""" - yield from ( + query = ( session.query(Article) .filter(or_(Article.pinecone_update_required.is_(True), force_update)) .filter(Article.is_valid) .filter(Article.source.in_(custom_sources)) .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) - .yield_per(1000) ) + # If article_ids are provided, filter based on those IDs + if article_ids: + query = query.filter(Article.id.in_(article_ids)) + + yield from query.yield_per(1000) + def get_all_valid_article_ids(session: Session) -> List[str]: """Return all valid article IDs.""" diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py index b425ee9d..dc3f23e6 100644 --- a/align_data/embeddings/pinecone/update_pinecone.py +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -42,6 +42,17 @@ def update(self, custom_sources: List[str], force_update: bool = False): for batch in self.batch_entries(articles_to_update_stream): self.save_batch(session, batch) + def update_articles_by_ids( + self, custom_sources: List[str], article_ids: List[int], force_update: bool = False + ): + """Update the Pinecone entries of specific articles based on their IDs.""" + with make_session() as session: + articles_to_update_stream = stream_pinecone_updates( + session, custom_sources, force_update, article_ids + ) + for batch in self.batch_entries(articles_to_update_stream): + self.save_batch(session, batch) + def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry]]): try: for 
article, pinecone_entry in batch: diff --git a/main.py b/main.py index 2c1d8da2..feeece54 100644 --- a/main.py +++ b/main.py @@ -137,6 +137,17 @@ def pinecone_update_all(self, *skip, force_update=False) -> None: names = [name for name in ALL_DATASETS if name not in skip] PineconeUpdater().update(names, force_update) + def pinecone_update_individual_articles(self, hash_ids: str, force_update=False) -> None: + """ + Update the Pinecone entries of specific articles based on their IDs given as a comma-separated string. + + :param str ids: Comma-separated list of article IDs. + """ + names = ALL_DATASETS + + article_ids = [hash_id for hash_id in hash_ids.split(",")] + PineconeUpdater().update_articles_by_ids(names, article_ids, force_update) + def train_finetuning_layer(self) -> None: """ This function trains a finetuning layer on top of the OpenAI embeddings. From e0d2197aa3e8b9f9605879dee5c6313ae96d6fa9 Mon Sep 17 00:00:00 2001 From: Henri Lemoine Date: Wed, 23 Aug 2023 12:44:58 -0400 Subject: [PATCH 05/25] merging mostly? --- align_data/common/alignment_dataset.py | 5 +- align_data/db/session.py | 29 ++- .../embeddings/pinecone/pinecone_models.py | 2 +- .../embeddings/pinecone/update_pinecone.py | 60 ++++-- align_data/settings.py | 15 +- align_data/sources/agisf/__init__.py | 30 ++- align_data/sources/agisf/agisf.py | 17 +- align_data/sources/airtable.py | 16 +- align_data/sources/articles/__init__.py | 2 +- align_data/sources/arxiv_papers.py | 29 +-- align_data/sources/blogs/__init__.py | 4 +- align_data/sources/blogs/blogs.py | 13 +- align_data/sources/youtube/__init__.py | 2 +- main.py | 7 +- tests/align_data/sources/test_agisf.py | 105 ++++++++++ tests/align_data/sources/test_airtable.py | 189 ++++++++++++++++++ tests/align_data/sources/test_arxiv.py | 38 ++-- tests/align_data/sources/test_blogs.py | 45 +++-- tests/align_data/{ => sources}/test_utils.py | 29 +-- tests/align_data/test_agisf.py | 96 --------- tests/align_data/test_airtable.py | 144 ------------- 21 files changed, 493 insertions(+), 384 deletions(-) create mode 100644 tests/align_data/sources/test_agisf.py create mode 100644 tests/align_data/sources/test_airtable.py rename tests/align_data/{ => sources}/test_utils.py (58%) delete mode 100644 tests/align_data/test_agisf.py delete mode 100644 tests/align_data/test_airtable.py diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 377e3880..0d607b13 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -253,7 +253,6 @@ def merge(item): @dataclass class MultiDataset(AlignmentDataset): - datasets: List[AlignmentDataset] @property @@ -276,13 +275,13 @@ def get_item_key(self, entry): def process_entry(self, entry) -> Optional[Article]: item, dataset = entry article = dataset.process_entry(item) - article.add_meta('initial_source', article.source) + article.add_meta("initial_source", article.source) article.source = self.name def fetch_entries(self): for dataset in self.datasets: for article in dataset.fetch_entries(): if article.source != self.name: - article.add_meta('initial_source', article.source) + article.add_meta("initial_source", article.source) article.source = self.name yield article diff --git a/align_data/db/session.py b/align_data/db/session.py index 3999546b..cab798b5 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -22,26 +22,41 @@ def make_session(auto_commit=False): session.commit() -def stream_pinecone_updates( +def get_pinecone_articles_to_update( 
session: Session, custom_sources: List[str], force_update: bool = False, - article_ids: List[int] | None = None, ): """Yield Pinecone entries that require an update.""" - query = ( + yield from ( session.query(Article) .filter(or_(Article.pinecone_update_required.is_(True), force_update)) .filter(Article.is_valid) .filter(Article.source.in_(custom_sources)) .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) + # .yield_per(10) ) - # If article_ids are provided, filter based on those IDs - if article_ids: - query = query.filter(Article.id.in_(article_ids)) - yield from query.yield_per(1000) +def get_pinecone_articles_by_ids( + session: Session, + custom_sources: List[str], + force_update: bool = False, + hash_ids: List[int] | None = None, +): + """Yield Pinecone entries that require an update and match the given IDs.""" + if hash_ids is None: + hash_ids = [] + + yield from ( + session.query(Article) + .filter(or_(Article.pinecone_update_required.is_(True), force_update)) + .filter(Article.is_valid) + .filter(Article.source.in_(custom_sources)) + .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) + .filter(Article.id.in_(hash_ids)) + # .yield_per(10) + ) def get_all_valid_article_ids(session: Session) -> List[str]: diff --git a/align_data/embeddings/pinecone/pinecone_models.py b/align_data/embeddings/pinecone/pinecone_models.py index fd7b67eb..3c6af194 100644 --- a/align_data/embeddings/pinecone/pinecone_models.py +++ b/align_data/embeddings/pinecone/pinecone_models.py @@ -30,7 +30,7 @@ class PineconeEntry(BaseModel): date_published: float authors: List[str] text_chunks: List[str] - embeddings: List[List[float]] + embeddings: List[List[float] | None] def __init__(self, **data): """Check for missing (falsy) fields before initializing.""" diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py index dc3f23e6..f0aa1b57 100644 --- a/align_data/embeddings/pinecone/update_pinecone.py +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -8,10 +8,16 @@ from align_data.embeddings.embedding_utils import get_embeddings from align_data.db.models import Article -from align_data.db.session import make_session, stream_pinecone_updates +from align_data.db.session import ( + make_session, + get_pinecone_articles_to_update, + get_pinecone_articles_by_ids, +) from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB from align_data.embeddings.pinecone.pinecone_models import ( - PineconeEntry, MissingFieldsError, MissingEmbeddingModelError + PineconeEntry, + MissingFieldsError, + MissingEmbeddingModelError, ) from align_data.embeddings.text_splitter import ParagraphSentenceUnitTextSplitter @@ -36,27 +42,28 @@ def update(self, custom_sources: List[str], force_update: bool = False): :param custom_sources: List of sources to update. 
""" with make_session() as session: - articles_to_update_stream = stream_pinecone_updates( + articles_to_update_stream = get_pinecone_articles_to_update( session, custom_sources, force_update ) for batch in self.batch_entries(articles_to_update_stream): self.save_batch(session, batch) def update_articles_by_ids( - self, custom_sources: List[str], article_ids: List[int], force_update: bool = False + self, custom_sources: List[str], hash_ids: List[int], force_update: bool = False ): - """Update the Pinecone entries of specific articles based on their IDs.""" + """Update the Pinecone entries of specific articles based on their hash_ids.""" with make_session() as session: - articles_to_update_stream = stream_pinecone_updates( - session, custom_sources, force_update, article_ids + articles_to_update_stream = get_pinecone_articles_by_ids( + session, custom_sources, force_update, hash_ids ) for batch in self.batch_entries(articles_to_update_stream): self.save_batch(session, batch) - def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry]]): + def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry | None]]): try: for article, pinecone_entry in batch: - self.pinecone_db.upsert_entry(pinecone_entry) + if pinecone_entry: + self.pinecone_db.upsert_entry(pinecone_entry) article.pinecone_update_required = False session.add(article) @@ -70,23 +77,27 @@ def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry] def batch_entries( self, article_stream: Generator[Article, None, None] - ) -> Iterator[List[Tuple[Article, PineconeEntry]]]: + ) -> Iterator[List[Tuple[Article, PineconeEntry | None]]]: while batch := tuple(islice(article_stream, 10)): - yield [ - (article, pinecone_entry) - for article in batch - if (pinecone_entry := self._make_pinecone_entry(article)) is not None - ] + yield [(article, self._make_pinecone_entry(article)) for article in batch] def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None: try: text_chunks = get_text_chunks(article, self.text_splitter) embeddings, moderation_results = get_embeddings(text_chunks, article.source) - if any(result['flagged'] for result in moderation_results): - flagged_text_chunks = [f"Chunk {i}: \"{text}\"" for i, (text, result) in enumerate(zip(text_chunks, moderation_results)) if result["flagged"]] - logger.warning(f"OpenAI moderation flagged text chunks for the following article: {article.id}") - article.append_comment(f"OpenAI moderation flagged the following text chunks: {flagged_text_chunks}") + if any(result["flagged"] for result in moderation_results): + flagged_text_chunks = [ + f'Chunk {i}: "{text}"' + for i, (text, result) in enumerate(zip(text_chunks, moderation_results)) + if result["flagged"] + ] + logger.warning( + f"OpenAI moderation flagged text chunks for the following article: {article.id}" + ) + article.append_comment( + f"OpenAI moderation flagged the following text chunks: {flagged_text_chunks}" + ) return PineconeEntry( hash_id=article.id, # the hash_id of the article @@ -98,11 +109,18 @@ def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None: text_chunks=text_chunks, embeddings=embeddings, ) - except (ValueError, TypeError, AttributeError, ValidationError, MissingFieldsError, MissingEmbeddingModelError) as e: + except ( + ValueError, + TypeError, + AttributeError, + ValidationError, + MissingFieldsError, + MissingEmbeddingModelError, + ) as e: logger.warning(e) article.append_comment(f"Error encountered while processing this 
article: {e}") return None - + except Exception as e: logger.error(e) raise diff --git a/align_data/settings.py b/align_data/settings.py index 73459b8f..8e91403a 100644 --- a/align_data/settings.py +++ b/align_data/settings.py @@ -39,8 +39,19 @@ host = os.environ.get("ARD_DB_HOST", "127.0.0.1") port = os.environ.get("ARD_DB_PORT", "3306") db_name = os.environ.get("ARD_DB_NAME", "alignment_research_dataset") -DB_CONNECTION_URI = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{db_name}" -ARTICLE_MAIN_KEYS = ["id", "source", 'source_type', "title", "authors", "text", "url", "date_published", "status", "comments"] +DB_CONNECTION_URI = f"mysql+mysqldb://{user}:{password}@{host}:{port}/{db_name}" +ARTICLE_MAIN_KEYS = [ + "id", + "source", + "source_type", + "title", + "authors", + "text", + "url", + "date_published", + "status", + "comments", +] ### EMBEDDINGS ### USE_OPENAI_EMBEDDINGS = True # If false, SentenceTransformer embeddings will be used. diff --git a/align_data/sources/agisf/__init__.py b/align_data/sources/agisf/__init__.py index 0fbf757a..60517c4c 100644 --- a/align_data/sources/agisf/__init__.py +++ b/align_data/sources/agisf/__init__.py @@ -5,32 +5,28 @@ datasets = [ AirtableDataset( - name='agisf_governance', - base_id='app9q0E0jlDWlsR0z', - table_id='tblgTb3kszvSbo2Mb', + name="agisf_governance", + base_id="app9q0E0jlDWlsR0z", + table_id="tblgTb3kszvSbo2Mb", mappings={ - 'title': '[>] Resource', - 'url': '[h] [>] Link', - 'source_type': '[h] [>] Type', - 'summary': '[h] Resource guide', - 'authors': 'Author(s) (from Resources)', + "title": "[>] Resource", + "url": "[h] [>] Link", + "source_type": "[h] [>] Type", + "comments": "[h] Resource guide", + "authors": "Author(s) (from Resources)", }, - processors = { - 'source_type': lambda val: val[0] if val else None, - 'authors': lambda val: val and [v.strip() for v in val.split(',')] - } ), AGISFPodcastDataset( - name='agisf_readings_alignment', - url='https://feeds.type3.audio/agi-safety-fundamentals--alignment.rss', + name="agisf_readings_alignment", + url="https://feeds.type3.audio/agi-safety-fundamentals--alignment.rss", ), AGISFPodcastDataset( - name='agisf_readings_governance', - url='https://feeds.type3.audio/agi-safety-fundamentals--governance.rss', + name="agisf_readings_governance", + url="https://feeds.type3.audio/agi-safety-fundamentals--governance.rss", ), ] AGISF_DATASETS = [ - MultiDataset(name='agisf', datasets=datasets), + MultiDataset(name="agisf", datasets=datasets), ] diff --git a/align_data/sources/agisf/agisf.py b/align_data/sources/agisf/agisf.py index e56a60e4..73e9eef4 100644 --- a/align_data/sources/agisf/agisf.py +++ b/align_data/sources/agisf/agisf.py @@ -8,8 +8,7 @@ class AGISFPodcastDataset(RSSDataset): - - regex = re.compile(r'^\[Week .*?\]\s+“(?P.*?)”\s+by\s+(?P<authors>.*?)$') + regex = re.compile(r"^\[Week .*?\]\s+“(?P<title>.*?)”\s+by\s+(?P<authors>.*?)$") @property def feed_url(self): @@ -17,34 +16,34 @@ def feed_url(self): def fetch_contents(self, url: str) -> Dict[str, Any]: contents = super().fetch_contents(url) - if extracted := self.regex.search(contents.get('title')): + if extracted := self.regex.search(contents.get("title")): return merge_dicts(contents, extracted.groupdict()) return contents def _get_text(self, item): - contents = item_metadata(item['link']) + contents = item_metadata(item["link"]) # Replace any non empty values in item. 
`item.update()` will happily insert Nones for k, v in contents.items(): if v: item[k] = v - return item.get('text') + return item.get("text") def extract_authors(self, item): authors = item.get("authors") if not authors: return self.authors if isinstance(authors, str): - return [a.strip() for a in authors.split(',')] + return [a.strip() for a in authors.split(",")] return authors def _extra_values(self, contents): - if summary := contents.get('summary'): + if summary := contents.get("summary"): soup = BeautifulSoup(summary, "html.parser") - for el in soup.select('b'): + for el in soup.select("b"): if el.next_sibling: el.next_sibling.extract() el.extract() - return {'summary': self._extract_markdown(soup)} + return {"summary": self._extract_markdown(soup)} return {} def process_entry(self, article): diff --git a/align_data/sources/airtable.py b/align_data/sources/airtable.py index db24350e..bcb07055 100644 --- a/align_data/sources/airtable.py +++ b/align_data/sources/airtable.py @@ -12,33 +12,33 @@ @dataclass class AirtableDataset(AlignmentDataset): - base_id: str table_id: str mappings: Dict[str, str] processors: Dict[str, Callable[[Any], str]] - done_key = 'url' + done_key = "url" def setup(self): if not AIRTABLE_API_KEY: - raise ValueError('No AIRTABLE_API_KEY provided!') + raise ValueError("No AIRTABLE_API_KEY provided!") super().setup() self.at = airtable.Airtable(self.base_id, AIRTABLE_API_KEY) def map_cols(self, item: Dict[str, Dict[str, str]]) -> Optional[Dict[str, Optional[str]]]: - fields = item.get('fields', {}) + fields = item.get("fields", {}) + def map_col(k): val = fields.get(self.mappings.get(k) or k) if processor := self.processors.get(k): val = processor(val) return val - mapped = {k: map_col(k) for k in ARTICLE_MAIN_KEYS + ['summary']} - if (mapped.get('url') or '').startswith('http'): + mapped = {k: map_col(k) for k in ARTICLE_MAIN_KEYS + ["summary"]} + if (mapped.get("url") or "").startswith("http"): return mapped def get_item_key(self, item): - return item.get('url') + return item.get("url") @property def items_list(self) -> Iterable[Dict[str, Union[str, None]]]: @@ -49,5 +49,5 @@ def process_entry(self, entry) -> Optional[Article]: if not contents: return None - entry['date_published'] = self._get_published_date(entry.get('date_published')) + entry["date_published"] = self._get_published_date(entry.get("date_published")) return self.make_data_entry(merge_dicts(entry, contents), source=self.name) diff --git a/align_data/sources/articles/__init__.py b/align_data/sources/articles/__init__.py index 6fd45fbc..430c6449 100644 --- a/align_data/sources/articles/__init__.py +++ b/align_data/sources/articles/__init__.py @@ -52,7 +52,7 @@ ARTICLES_REGISTRY = [ - MultiDataset(name='special_docs', datasets=ARTICLES_DATASETS), + MultiDataset(name="special_docs", datasets=ARTICLES_DATASETS), ArxivPapers( name="arxiv", spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI", diff --git a/align_data/sources/arxiv_papers.py b/align_data/sources/arxiv_papers.py index ea95d6be..30002a78 100644 --- a/align_data/sources/arxiv_papers.py +++ b/align_data/sources/arxiv_papers.py @@ -60,19 +60,22 @@ def add_metadata(data, paper_id): metadata = get_arxiv_metadata(paper_id) if not metadata: return {} - return merge_dicts({ - "authors": metadata.authors, - "title": metadata.title, - "date_published": metadata.published, - "data_last_modified": metadata.updated.isoformat(), - "summary": metadata.summary.replace("\n", " "), - "comment": metadata.comment, - "journal_ref": metadata.journal_ref, 
- "doi": metadata.doi, - "primary_category": metadata.primary_category, - "categories": metadata.categories, - "version": get_version(metadata.get_short_id()), - }, data) + return merge_dicts( + { + "authors": metadata.authors, + "title": metadata.title, + "date_published": metadata.published, + "data_last_modified": metadata.updated.isoformat(), + "summary": metadata.summary.replace("\n", " "), + "comment": metadata.comment, + "journal_ref": metadata.journal_ref, + "doi": metadata.doi, + "primary_category": metadata.primary_category, + "categories": metadata.categories, + "version": get_version(metadata.get_short_id()), + }, + data, + ) def fetch_arxiv(url) -> Dict: diff --git a/align_data/sources/blogs/__init__.py b/align_data/sources/blogs/__init__.py index 64f27310..c7fd34db 100644 --- a/align_data/sources/blogs/__init__.py +++ b/align_data/sources/blogs/__init__.py @@ -58,10 +58,10 @@ name="deepmind_technical_blog", url="https://www.deepmind.com/blog-categories/technical-blogs", ), - TransformerCircuits(name='transformer-circuits', url='https://transformer-circuits.pub/'), + TransformerCircuits(name="transformer-circuits", url="https://transformer-circuits.pub/"), ] BLOG_REGISTRY = [ - MultiDataset(name='blogs', datasets=BLOG_DATASETS), + MultiDataset(name="blogs", datasets=BLOG_DATASETS), ] diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py index 6d7b144b..df0f1633 100644 --- a/align_data/sources/blogs/blogs.py +++ b/align_data/sources/blogs/blogs.py @@ -127,9 +127,8 @@ def extract_authors(self, article): class TransformerCircuits(HTMLDataset): - item_selector = "div.toc a" - text_selector = 'h3' + text_selector = "h3" def get_item_key(self, item): article_url = item.get("href").split("?")[0] @@ -140,7 +139,7 @@ def items_list(self): return [i for i in super().items_list if self.get_item_key(i).startswith(self.url)] def _metadata(self, contents, selector): - if meta := contents.select_one('div.d-byline'): + if meta := contents.select_one("div.d-byline"): return meta.select(selector) def _get_title(self, contents): @@ -148,7 +147,7 @@ def _get_title(self, contents): return title and title.text.strip() def _get_published_date(self, contents): - if date := self._metadata(contents, 'div.published div'): + if date := self._metadata(contents, "div.published div"): return super()._get_published_date(date[0].text) def _get_text(self, contents): @@ -156,11 +155,11 @@ def _get_text(self, contents): return self._extract_markdown(article) def extract_authors(self, contents): - if authors := self._metadata(contents, 'span.author'): + if authors := self._metadata(contents, "span.author"): for a in authors: - for sup in a.select('sup'): + for sup in a.select("sup"): sup.extract() - return [a.text.strip().strip(',*') for a in authors] + return [a.text.strip().strip(",*") for a in authors] return [] diff --git a/align_data/sources/youtube/__init__.py b/align_data/sources/youtube/__init__.py index ca0d9b33..8a617b52 100644 --- a/align_data/sources/youtube/__init__.py +++ b/align_data/sources/youtube/__init__.py @@ -44,5 +44,5 @@ YOUTUBE_REGISTRY = [ - MultiDataset(name='youtube', datasets=YOUTUBE_DATASETS), + MultiDataset(name="youtube", datasets=YOUTUBE_DATASETS), ] diff --git a/main.py b/main.py index feeece54..4e28cd08 100644 --- a/main.py +++ b/main.py @@ -137,16 +137,15 @@ def pinecone_update_all(self, *skip, force_update=False) -> None: names = [name for name in ALL_DATASETS if name not in skip] PineconeUpdater().update(names, force_update) - def 
pinecone_update_individual_articles(self, hash_ids: str, force_update=False) -> None: + def pinecone_update_individual_articles(self, *hash_ids: str, force_update=False) -> None: """ Update the Pinecone entries of specific articles based on their IDs given as a comma-separated string. - :param str ids: Comma-separated list of article IDs. + :param str hash_ids: space-separated list of article IDs. """ names = ALL_DATASETS - article_ids = [hash_id for hash_id in hash_ids.split(",")] - PineconeUpdater().update_articles_by_ids(names, article_ids, force_update) + PineconeUpdater().update_articles_by_ids(names, hash_ids, force_update) def train_finetuning_layer(self) -> None: """ diff --git a/tests/align_data/sources/test_agisf.py b/tests/align_data/sources/test_agisf.py new file mode 100644 index 00000000..c3d67be6 --- /dev/null +++ b/tests/align_data/sources/test_agisf.py @@ -0,0 +1,105 @@ +import pytest +from unittest.mock import patch + +from align_data.sources.agisf.agisf import AGISFPodcastDataset + + +SAMPLE_ITEM = { + "title": "[Week 0] “Machine Learning for Humans, Part 2.1: Supervised Learning” by Vishal Maini", + "content": "this is needed, but will mostly be ignored", + "summary": '<p>Bla bla bla</p><br /><br /><b>Original article:<br /></b><a href="https://medium.com/machine-learning-for-humans/supervised-learning-740383a2feab">https://medium.com/machine-learning-for-humans/supervised-learning-740383a2feab</a><br /><br /><b>Author:<br /></b>Vishal Maini</p>', + "link": "https://ble.ble.com", +} + + +def test_fetch_contents(): + dataset = AGISFPodcastDataset(name="bla", url="https://bla.bla.com") + url = "https://test.url" + dataset.items = {url: SAMPLE_ITEM} + assert dataset.fetch_contents(url) == dict( + SAMPLE_ITEM, + authors="Vishal Maini", + title="Machine Learning for Humans, Part 2.1: Supervised Learning", + ) + + +def test_fetch_contents_bad_title(): + dataset = AGISFPodcastDataset(name="bla", url="https://bla.bla.com") + url = "https://test.url" + dataset.items = {url: dict(SAMPLE_ITEM, title="asdasdasd")} + assert dataset.fetch_contents(url) == dict(SAMPLE_ITEM, title="asdasdasd") + + +def test_get_text(): + dataset = AGISFPodcastDataset(name="bla", url="https://bla.bla.com") + item = dict(SAMPLE_ITEM) + + with patch( + "align_data.sources.agisf.agisf.item_metadata", + return_value={ + "text": "bla bla bla", + "source_type": "some kind of thing", + "title": None, + "authors": [], + "content": "this should now change", + }, + ): + assert dataset._get_text(item) == "bla bla bla" + assert item == dict( + SAMPLE_ITEM, + content="this should now change", + text="bla bla bla", + source_type="some kind of thing", + ) + + +@pytest.mark.parametrize( + "authors, expected", + ( + (None, ["default"]), + ("", ["default"]), + ([], ["default"]), + ("bla", ["bla"]), + ( + "johnny bravo, mr. blobby\t\t\t, Hans Klos ", + ["johnny bravo", "mr. blobby", "Hans Klos"], + ), + (["mr. bean"], ["mr. bean"]), + (["johnny bravo", "mr. blobby", "Hans Klos"], ["johnny bravo", "mr. 
blobby", "Hans Klos"]), + ), +) +def test_extract_authors(authors, expected): + dataset = AGISFPodcastDataset(name="bla", url="https://bla.bla.com", authors=["default"]) + item = dict(SAMPLE_ITEM, authors=authors) + assert dataset.extract_authors(item) == expected + + +def test_extra_values(): + dataset = AGISFPodcastDataset(name="bla", url="https://bla.bla.com", authors=["default"]) + assert dataset._extra_values(SAMPLE_ITEM) == { + "summary": "Bla bla bla", + } + + +def test_extra_values_no_summary(): + dataset = AGISFPodcastDataset(name="bla", url="https://bla.bla.com", authors=["default"]) + assert dataset._extra_values({}) == {} + + +def test_process_entry(): + dataset = AGISFPodcastDataset(name="bla", url="https://bla.bla.com") + url = "https://test.url" + dataset.items = {url: SAMPLE_ITEM} + + with patch("align_data.sources.agisf.agisf.item_metadata", return_value={"text": "bla"}): + assert dataset.process_entry(url).to_dict() == { + "authors": ["Vishal Maini"], + "date_published": None, + "id": None, + "source": "bla", + "source_type": "blog", + "summaries": ["Bla bla bla"], + "text": "bla", + "title": "Machine Learning for Humans, Part 2.1: Supervised Learning", + "url": "https://test.url", + } diff --git a/tests/align_data/sources/test_airtable.py b/tests/align_data/sources/test_airtable.py new file mode 100644 index 00000000..ae36f731 --- /dev/null +++ b/tests/align_data/sources/test_airtable.py @@ -0,0 +1,189 @@ +import pytest +from unittest.mock import patch + +from align_data.sources.airtable import AirtableDataset + + +@pytest.mark.parametrize( + "item, overwrites", + ( + ({"url": "http://bla.vle"}, {}), + ({"url": "http://bla.vle", "source": "your momma"}, {"source": "your momma"}), + ({"url": "http://bla.vle", "source": "your momma", "bla": "ble"}, {"source": "your momma"}), + ( + {"url": "http://bla.vle", "status": "fine", "title": "Something or other"}, + {"status": "fine", "title": "Something or other"}, + ), + ( + {"url": "http://some.other.url", "source_type": "blog", "authors": "bla, bla, bla"}, + {"url": "http://some.other.url", "source_type": "blog", "authors": "bla, bla, bla"}, + ), + ), +) +def test_map_cols_no_mapping(item, overwrites): + dataset = AirtableDataset( + name="asd", base_id="ddwe", table_id="csdcsc", mappings={}, processors={} + ) + assert dataset.map_cols({"id": "123", "fields": item}) == dict( + { + "authors": None, + "comments": None, + "date_published": None, + "id": None, + "source": None, + "source_type": None, + "status": None, + "text": None, + "title": None, + "summary": None, + "url": "http://bla.vle", + }, + **overwrites + ) + + +@pytest.mark.parametrize( + "item, overwrites", + ( + ({"an url!": "http://bla.vle"}, {}), + ({"an url!": "http://bla.vle", "source": "your momma"}, {"source": "your momma"}), + ( + {"an url!": "http://bla.vle", "source": "your momma", "bla": "ble"}, + {"source": "your momma"}, + ), + ( + {"an url!": "http://bla.vle", "status": "fine", "title": "Something or other"}, + {"status": "fine", "title": "Something or other"}, + ), + ( + { + "an url!": "http://some.other.url", + "source_type": "blog", + "whodunnit": "bla, bla, bla", + }, + {"url": "http://some.other.url", "source_type": "blog", "authors": "bla, bla, bla"}, + ), + ), +) +def test_map_cols_with_mapping(item, overwrites): + dataset = AirtableDataset( + name="asd", + base_id="ddwe", + table_id="csdcsc", + mappings={ + "url": "an url!", + "authors": "whodunnit", + }, + processors={}, + ) + assert dataset.map_cols({"id": "123", "fields": item}) == dict( + { + 
"authors": None, + "comments": None, + "date_published": None, + "id": None, + "source": None, + "source_type": None, + "status": None, + "text": None, + "title": None, + "summary": None, + "url": "http://bla.vle", + }, + **overwrites + ) + + +@pytest.mark.parametrize( + "item, overwrites", + ( + ({"an url!": "http://bla.vle"}, {}), + ({"an url!": "http://bla.vle", "source": "your momma"}, {"source": "your momma"}), + ( + {"an url!": "http://bla.vle", "source": "your momma", "bla": "ble"}, + {"source": "your momma"}, + ), + ( + {"an url!": "http://bla.vle", "status": "fine", "title": "Something or other"}, + {"status": "fine", "title": "Something or other bla!"}, + ), + ( + { + "an url!": "http://some.other.url", + "source_type": "blog", + "whodunnit": "bla, bla, bla", + }, + {"url": "http://some.other.url", "source_type": "blog", "authors": "bla, bla, bla"}, + ), + ), +) +def test_map_cols_with_processing(item, overwrites): + dataset = AirtableDataset( + name="asd", + base_id="ddwe", + table_id="csdcsc", + mappings={ + "url": "an url!", + "authors": "whodunnit", + }, + processors={ + "title": lambda val: val and val + " bla!", + "id": lambda _: 123, + }, + ) + assert dataset.map_cols({"id": "123", "fields": item}) == dict( + { + "authors": None, + "comments": None, + "date_published": None, + "id": 123, + "source": None, + "source_type": None, + "status": None, + "text": None, + "title": None, + "summary": None, + "url": "http://bla.vle", + }, + **overwrites + ) + + +@pytest.mark.parametrize("url", (None, "", "asdasdsad")) +def test_map_cols_no_url(url): + dataset = AirtableDataset( + name="asd", base_id="ddwe", table_id="csdcsc", mappings={}, processors={} + ) + assert dataset.map_cols({"id": "123", "fields": {"url": url}}) is None + + +def test_process_entry(): + dataset = AirtableDataset( + name="asd", base_id="ddwe", table_id="csdcsc", mappings={}, processors={} + ) + entry = { + "url": "http://bla.cle", + "authors": ["johnny", "your momma", "mr. Blobby", "Łóżćś Jaś"], + "date_published": "2023-01-02", + "source": "some place", + "status": "fine", + "comments": "should be ok", + } + with patch( + "align_data.sources.airtable.item_metadata", + return_value={ + "text": "bla bla bla", + "source_type": "some kind of thing", + }, + ): + assert dataset.process_entry(entry).to_dict() == { + "authors": ["johnny", "your momma", "mr. 
Blobby", "Łóżćś Jaś"], + "date_published": "2023-01-02T00:00:00Z", + "id": None, + "source": "asd", + "source_type": "some kind of thing", + "summaries": [], + "text": "bla bla bla", + "title": None, + "url": "http://bla.cle", + } diff --git a/tests/align_data/sources/test_arxiv.py b/tests/align_data/sources/test_arxiv.py index d5bf1c8e..1f6ed039 100644 --- a/tests/align_data/sources/test_arxiv.py +++ b/tests/align_data/sources/test_arxiv.py @@ -14,25 +14,31 @@ def test_get_id(url, expected): assert get_id("https://arxiv.org/abs/2001.11038") == "2001.11038" -@pytest.mark.parametrize('url, expected', ( - ("http://bla.bla", "http://bla.bla"), - ("http://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/abs/2001.11038/", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/pdf/2001.11038", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/pdf/2001.11038.pdf", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/pdf/2001.11038v3.pdf", "https://arxiv.org/abs/2001.11038"), - ("https://arxiv.org/abs/math/2001.11038", "https://arxiv.org/abs/math/2001.11038"), -)) +@pytest.mark.parametrize( + "url, expected", + ( + ("http://bla.bla", "http://bla.bla"), + ("http://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/abs/2001.11038/", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/pdf/2001.11038", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/pdf/2001.11038.pdf", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/pdf/2001.11038v3.pdf", "https://arxiv.org/abs/2001.11038"), + ("https://arxiv.org/abs/math/2001.11038", "https://arxiv.org/abs/math/2001.11038"), + ), +) def test_canonical_url(url, expected): assert canonical_url(url) == expected -@pytest.mark.parametrize('id, version', ( - ('123.123', None), - ('math/312', None), - ('3123123v1', '1'), - ('3123123v123', '123'), -)) +@pytest.mark.parametrize( + "id, version", + ( + ("123.123", None), + ("math/312", None), + ("3123123v1", "1"), + ("3123123v123", "123"), + ), +) def test_get_version(id, version): assert get_version(id) == version diff --git a/tests/align_data/sources/test_blogs.py b/tests/align_data/sources/test_blogs.py index 154daa6b..9debd46b 100644 --- a/tests/align_data/sources/test_blogs.py +++ b/tests/align_data/sources/test_blogs.py @@ -788,8 +788,9 @@ def test_deepmind_technical_proces_entry(): </html> """ + def test_transformer_circuits_item_key(): - dataset = TransformerCircuits(url='http://bla.com', name='ble') + dataset = TransformerCircuits(url="http://bla.com", name="ble") html = """<div> <a class="paper" href="2023/july-update/index.html"> <h3>Circuits Updates — July 2023</h3> @@ -798,11 +799,14 @@ def test_transformer_circuits_item_key(): A collection of small updates from the Anthropic Interpretability Team. 
</div> </a></div>""" - assert dataset.get_item_key(BeautifulSoup(html, 'html.parser').find('a')) == 'http://bla.com/2023/july-update/index.html' + assert ( + dataset.get_item_key(BeautifulSoup(html, "html.parser").find("a")) + == "http://bla.com/2023/july-update/index.html" + ) def test_transformer_circuits_item_list(): - dataset = TransformerCircuits(url='http://bla.com', name='ble') + dataset = TransformerCircuits(url="http://bla.com", name="ble") html = """<div> <div class="toc"> <a href="item1.html"></a> @@ -813,43 +817,46 @@ def test_transformer_circuits_item_list(): <a href="http://this.will.be.skipped"></a> </div></div>""" with patch("requests.get", return_value=Mock(content=html)): - assert [i.get('href') for i in dataset.items_list] == [ - 'item1.html', 'item2.html', 'item3.html', 'http://bla.com/item4.html' + assert [i.get("href") for i in dataset.items_list] == [ + "item1.html", + "item2.html", + "item3.html", + "http://bla.com/item4.html", ] def test_transformer_circuits_get_title(): - dataset = TransformerCircuits(url='http://bla.com', name='ble') + dataset = TransformerCircuits(url="http://bla.com", name="ble") soup = BeautifulSoup(TRANSFORMER_CIRCUITS_HTML, "html.parser") assert dataset._get_title(soup) == "This is the title" def test_transformer_circuits_get_published_date(): - dataset = TransformerCircuits(url='http://bla.com', name='ble') + dataset = TransformerCircuits(url="http://bla.com", name="ble") soup = BeautifulSoup(TRANSFORMER_CIRCUITS_HTML, "html.parser") assert dataset._get_published_date(soup).isoformat() == "2023-03-16T00:00:00+00:00" def test_transformer_circuits_get_text(): - dataset = TransformerCircuits(url='http://bla.com', name='ble') + dataset = TransformerCircuits(url="http://bla.com", name="ble") soup = BeautifulSoup(TRANSFORMER_CIRCUITS_HTML, "html.parser") assert dataset._get_text(soup) == "This is where the text goes. With a [link](bla.com) to test" def test_transformer_circuits_process_item(): - dataset = TransformerCircuits(url='http://bla.com', name='ble') - item = BeautifulSoup('<a href="ble/bla"</a>', "html.parser").find('a') + dataset = TransformerCircuits(url="http://bla.com", name="ble") + item = BeautifulSoup('<a href="ble/bla"</a>', "html.parser").find("a") with patch("requests.get", return_value=Mock(content=TRANSFORMER_CIRCUITS_HTML)): assert dataset.process_entry(item).to_dict() == { - 'authors': ['Nelson Elhage', 'Robert Lasenby', 'Christopher Olah'], - 'date_published': '2023-03-16T00:00:00Z', - 'id': None, - 'source': 'ble', - 'source_type': 'blog', - 'summaries': [], - 'text': 'This is where the text goes. With a [link](bla.com) to test', - 'title': 'This is the title', - 'url': 'http://bla.com/ble/bla', + "authors": ["Nelson Elhage", "Robert Lasenby", "Christopher Olah"], + "date_published": "2023-03-16T00:00:00Z", + "id": None, + "source": "ble", + "source_type": "blog", + "summaries": [], + "text": "This is where the text goes. 
With a [link](bla.com) to test", + "title": "This is the title", + "url": "http://bla.com/ble/bla", } diff --git a/tests/align_data/test_utils.py b/tests/align_data/sources/test_utils.py similarity index 58% rename from tests/align_data/test_utils.py rename to tests/align_data/sources/test_utils.py index 8e32c612..42b6ac7e 100644 --- a/tests/align_data/test_utils.py +++ b/tests/align_data/sources/test_utils.py @@ -10,27 +10,30 @@ def test_merge_dicts_no_args(): def test_merge_dicts_single_dict(): """Test merge_dicts function with a single dictionary.""" - result = merge_dicts({'a': 1, 'b': 2}) - assert result == {'a': 1, 'b': 2} + result = merge_dicts({"a": 1, "b": 2}) + assert result == {"a": 1, "b": 2} def test_merge_dicts_dicts_with_no_overlap(): """Test merge_dicts function with multiple dictionaries with no overlapping keys.""" - result = merge_dicts({'a': 1}, {'b': 2}, {'c': 3}) - assert result == {'a': 1, 'b': 2, 'c': 3} + result = merge_dicts({"a": 1}, {"b": 2}, {"c": 3}) + assert result == {"a": 1, "b": 2, "c": 3} def test_merge_dicts_dicts_with_overlap(): """Test merge_dicts function with multiple dictionaries with overlapping keys.""" - result = merge_dicts({'a': 1, 'b': 2}, {'b': 3, 'c': 4}, {'c': 5, 'd': 6}) - assert result == {'a': 1, 'b': 3, 'c': 5, 'd': 6} - - -@pytest.mark.parametrize("input_dicts, expected", [ - ([{'a': 1, 'b': None}, {'b': 3}], {'a': 1, 'b': 3}), - ([{'a': 0, 'b': 2}, {'b': None}], {'a': 0, 'b': 2}), - ([{'a': None}, {'b': 'test'}], {'b': 'test'}), -]) + result = merge_dicts({"a": 1, "b": 2}, {"b": 3, "c": 4}, {"c": 5, "d": 6}) + assert result == {"a": 1, "b": 3, "c": 5, "d": 6} + + +@pytest.mark.parametrize( + "input_dicts, expected", + [ + ([{"a": 1, "b": None}, {"b": 3}], {"a": 1, "b": 3}), + ([{"a": 0, "b": 2}, {"b": None}], {"a": 0, "b": 2}), + ([{"a": None}, {"b": "test"}], {"b": "test"}), + ], +) def test_merge_dicts_with_none_values(input_dicts, expected): """Test merge_dicts function with dictionaries containing None or falsey values.""" result = merge_dicts(*input_dicts) diff --git a/tests/align_data/test_agisf.py b/tests/align_data/test_agisf.py deleted file mode 100644 index 380219da..00000000 --- a/tests/align_data/test_agisf.py +++ /dev/null @@ -1,96 +0,0 @@ -import pytest -from unittest.mock import patch - -from align_data.sources.agisf.agisf import AGISFPodcastDataset - - -SAMPLE_ITEM = { - 'title': '[Week 0] “Machine Learning for Humans, Part 2.1: Supervised Learning” by Vishal Maini', - 'content': 'this is needed, but will mostly be ignored', - 'summary': '<p>Bla bla bla</p><br /><br /><b>Original article:<br /></b><a href="https://medium.com/machine-learning-for-humans/supervised-learning-740383a2feab">https://medium.com/machine-learning-for-humans/supervised-learning-740383a2feab</a><br /><br /><b>Author:<br /></b>Vishal Maini</p>', - 'link': 'https://ble.ble.com', -} - - -def test_fetch_contents(): - dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com') - url = 'https://test.url' - dataset.items = {url: SAMPLE_ITEM} - assert dataset.fetch_contents(url) == dict( - SAMPLE_ITEM, authors='Vishal Maini', - title='Machine Learning for Humans, Part 2.1: Supervised Learning' - ) - - -def test_fetch_contents_bad_title(): - dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com') - url = 'https://test.url' - dataset.items = {url: dict(SAMPLE_ITEM, title='asdasdasd')} - assert dataset.fetch_contents(url) == dict(SAMPLE_ITEM, title='asdasdasd') - - -def test_get_text(): - dataset = AGISFPodcastDataset(name='bla', 
url='https://bla.bla.com') - item = dict(SAMPLE_ITEM) - - with patch("align_data.sources.agisf.agisf.item_metadata", return_value={ - 'text': 'bla bla bla', - 'source_type': 'some kind of thing', - 'title': None, - 'authors': [], - 'content': 'this should now change', - }): - assert dataset._get_text(item) == 'bla bla bla' - assert item == dict( - SAMPLE_ITEM, - content='this should now change', - text='bla bla bla', - source_type='some kind of thing', - ) - - -@pytest.mark.parametrize('authors, expected', ( - (None, ['default']), - ('', ['default']), - ([], ['default']), - - ('bla', ['bla']), - ('johnny bravo, mr. blobby\t\t\t, Hans Klos ', ['johnny bravo', 'mr. blobby', 'Hans Klos']), - (['mr. bean'], ['mr. bean']), - (['johnny bravo', 'mr. blobby', 'Hans Klos'], ['johnny bravo', 'mr. blobby', 'Hans Klos']), -)) -def test_extract_authors(authors, expected): - dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com', authors=['default']) - item = dict(SAMPLE_ITEM, authors=authors) - assert dataset.extract_authors(item) == expected - - -def test_extra_values(): - dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com', authors=['default']) - assert dataset._extra_values(SAMPLE_ITEM) == { - 'summary': 'Bla bla bla', - } - - -def test_extra_values_no_summary(): - dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com', authors=['default']) - assert dataset._extra_values({}) == {} - - -def test_process_entry(): - dataset = AGISFPodcastDataset(name='bla', url='https://bla.bla.com') - url = 'https://test.url' - dataset.items = {url: SAMPLE_ITEM} - - with patch("align_data.sources.agisf.agisf.item_metadata", return_value={'text': 'bla'}): - assert dataset.process_entry(url).to_dict() == { - 'authors': ['Vishal Maini'], - 'date_published': None, - 'id': None, - 'source': 'bla', - 'source_type': 'blog', - 'summaries': ['Bla bla bla'], - 'text': 'bla', - 'title': 'Machine Learning for Humans, Part 2.1: Supervised Learning', - 'url': 'https://test.url', - } diff --git a/tests/align_data/test_airtable.py b/tests/align_data/test_airtable.py deleted file mode 100644 index 4de1fed0..00000000 --- a/tests/align_data/test_airtable.py +++ /dev/null @@ -1,144 +0,0 @@ -import pytest -from unittest.mock import patch - -from align_data.sources.airtable import AirtableDataset - - -@pytest.mark.parametrize('item, overwrites', ( - ({'url': 'http://bla.vle'}, {}), - ({'url': 'http://bla.vle', 'source': 'your momma'}, {'source': 'your momma'}), - ({'url': 'http://bla.vle', 'source': 'your momma', 'bla': 'ble'}, {'source': 'your momma'}), - ( - {'url': 'http://bla.vle', 'status': 'fine', 'title': 'Something or other'}, - {'status': 'fine', 'title': 'Something or other'} - ), - ( - {'url': 'http://some.other.url', 'source_type': 'blog', 'authors': 'bla, bla, bla'}, - {'url': 'http://some.other.url', 'source_type': 'blog', 'authors': 'bla, bla, bla'} - ), -)) -def test_map_cols_no_mapping(item, overwrites): - dataset = AirtableDataset(name='asd', base_id='ddwe', table_id='csdcsc', mappings={}, processors={}) - assert dataset.map_cols({'id': '123', 'fields': item}) == dict({ - 'authors': None, - 'comments': None, - 'date_published': None, - 'id': None, - 'source': None, - 'source_type': None, - 'status': None, - 'text': None, - 'title': None, - 'summary': None, - 'url': 'http://bla.vle' - }, **overwrites) - - -@pytest.mark.parametrize('item, overwrites', ( - ({'an url!': 'http://bla.vle'}, {}), - ({'an url!': 'http://bla.vle', 'source': 'your momma'}, {'source': 'your momma'}), - ({'an 
url!': 'http://bla.vle', 'source': 'your momma', 'bla': 'ble'}, {'source': 'your momma'}), - ( - {'an url!': 'http://bla.vle', 'status': 'fine', 'title': 'Something or other'}, - {'status': 'fine', 'title': 'Something or other'} - ), - ( - {'an url!': 'http://some.other.url', 'source_type': 'blog', 'whodunnit': 'bla, bla, bla'}, - {'url': 'http://some.other.url', 'source_type': 'blog', 'authors': 'bla, bla, bla'} - ), -)) -def test_map_cols_with_mapping(item, overwrites): - dataset = AirtableDataset( - name='asd', base_id='ddwe', table_id='csdcsc', - mappings={ - 'url': 'an url!', - 'authors': 'whodunnit', - }, - processors={} - ) - assert dataset.map_cols({'id': '123', 'fields': item}) == dict({ - 'authors': None, - 'comments': None, - 'date_published': None, - 'id': None, - 'source': None, - 'source_type': None, - 'status': None, - 'text': None, - 'title': None, - 'summary': None, - 'url': 'http://bla.vle' - }, **overwrites) - - -@pytest.mark.parametrize('item, overwrites', ( - ({'an url!': 'http://bla.vle'}, {}), - ({'an url!': 'http://bla.vle', 'source': 'your momma'}, {'source': 'your momma'}), - ({'an url!': 'http://bla.vle', 'source': 'your momma', 'bla': 'ble'}, {'source': 'your momma'}), - ( - {'an url!': 'http://bla.vle', 'status': 'fine', 'title': 'Something or other'}, - {'status': 'fine', 'title': 'Something or other bla!'} - ), - ( - {'an url!': 'http://some.other.url', 'source_type': 'blog', 'whodunnit': 'bla, bla, bla'}, - {'url': 'http://some.other.url', 'source_type': 'blog', 'authors': 'bla, bla, bla'} - ), -)) -def test_map_cols_with_processing(item, overwrites): - dataset = AirtableDataset( - name='asd', base_id='ddwe', table_id='csdcsc', - mappings={ - 'url': 'an url!', - 'authors': 'whodunnit', - }, - processors={ - 'title': lambda val: val and val + ' bla!', - 'id': lambda _: 123, - } - ) - assert dataset.map_cols({'id': '123', 'fields': item}) == dict({ - 'authors': None, - 'comments': None, - 'date_published': None, - 'id': 123, - 'source': None, - 'source_type': None, - 'status': None, - 'text': None, - 'title': None, - 'summary': None, - 'url': 'http://bla.vle' - }, **overwrites) - - -@pytest.mark.parametrize('url', (None, '', 'asdasdsad')) -def test_map_cols_no_url(url): - dataset = AirtableDataset(name='asd', base_id='ddwe', table_id='csdcsc', mappings={}, processors={}) - assert dataset.map_cols({'id': '123', 'fields': {'url': url}}) is None - - -def test_process_entry(): - dataset = AirtableDataset(name='asd', base_id='ddwe', table_id='csdcsc', mappings={}, processors={}) - entry = { - 'url': 'http://bla.cle', - 'authors': ['johnny', 'your momma', 'mr. Blobby', 'Łóżćś Jaś'], - 'date_published': '2023-01-02', - 'source': 'some place', - 'status': 'fine', - 'comments': 'should be ok', - } - with patch("align_data.sources.airtable.item_metadata", return_value={ - 'text': 'bla bla bla', - 'source_type': 'some kind of thing', - }): - assert dataset.process_entry(entry).to_dict() == { - 'authors': ['johnny', 'your momma', 'mr. 
Blobby', 'Łóżćś Jaś'], - 'date_published': '2023-01-02T00:00:00Z', - 'id': None, - 'source': 'asd', - 'source_type': 'some kind of thing', - 'summaries': [], - 'text': 'bla bla bla', - 'title': None, - 'url': 'http://bla.cle', - } From 88b280da57cc76b294057e0b561ffb6ec988a0d9 Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Fri, 25 Aug 2023 00:37:34 -0400 Subject: [PATCH 06/25] added check_for_changes to deal with pinecone_update_required --- align_data/db/models.py | 70 +++++++++++++++++++++++++++++------------ 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/align_data/db/models.py b/align_data/db/models.py index 98716768..77a78c9d 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -1,9 +1,12 @@ +import re import json import logging import pytz import hashlib from datetime import datetime -from typing import List, Optional +from typing import Any, Dict, List, Optional +from urllib.parse import urlparse + from sqlalchemy import ( JSON, DateTime, @@ -15,9 +18,11 @@ event, ) from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy.orm.attributes import get_history from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.ext.hybrid import hybrid_property +from align_data.embeddings.pinecone.pinecone_models import PineconeMetadata logger = logging.getLogger(__name__) OK_STATUS = None @@ -74,7 +79,7 @@ def generate_id_string(self) -> bytes: return "".join(str(getattr(self, field)) for field in self.__id_fields).encode("utf-8") @property - def __id_fields(self): + def __id_fields(self) -> List[str]: if self.source == "aisafety.info": return ["url"] if self.source in ["importai", "ml_safety_newsletter", "alignment_newsletter"]: @@ -82,7 +87,7 @@ def __id_fields(self): return ["url", "title"] @property - def missing_fields(self): + def missing_fields(self) -> List[str]: fields = set(self.__id_fields) | { "text", "title", @@ -107,9 +112,14 @@ def verify_id_fields(self): def update(self, other): for field in self.__table__.columns.keys(): - if field not in ["id", "hash_id", "metadata"] and getattr(other, field): - setattr(self, field, getattr(other, field)) - self.meta = dict((self.meta or {}), **{k: v for k, v in other.meta.items() if k and v}) + if field not in ["id", "hash_id", "metadata"]: + new_value = getattr(other, field) + if new_value and getattr(self, field) != new_value: + setattr(self, field, new_value) + + updated_meta = dict((self.meta or {}), **{k: v for k, v in other.meta.items() if k and v}) + if self.meta != updated_meta: + self.meta = updated_meta if other._id: self._id = other._id @@ -120,30 +130,40 @@ def _set_id(self): id_string = self.generate_id_string() self.id = hashlib.md5(id_string).hexdigest() - def add_meta(self, key, val): + def add_meta(self, key: str, val: Any): if self.meta is None: self.meta = {} self.meta[key] = val def append_comment(self, comment: str): - """Appends a comment to the article.comments field. 
You must run session.commit() to save the comment to the database.""" if self.comments is None: self.comments = "" self.comments = f"{self.comments}\n\n{comment}".strip() @hybrid_property - def is_valid(self): - return bool( - self.text - and self.text.strip() - and self.url - and self.title - and self.authors is not None - and self.status == OK_STATUS + def is_valid(self) -> bool: + # Check if the basic attributes are present and non-empty + basic_check = all( + [ + self.text and self.text.strip(), + self.url and self.url.strip(), + self.title and self.title.strip(), + self.authors, + self.status == OK_STATUS, + ] ) + # URL validation + try: + result = urlparse(self.url) + url_check = all([result.scheme in ["http", "https"], result.netloc]) + except: + url_check = False + + return basic_check and url_check + @is_valid.expression - def is_valid(cls): + def is_valid(cls) -> bool: return ( (cls.status == OK_STATUS) & (cls.text != None) @@ -160,14 +180,23 @@ def before_write(cls, _mapper, _connection, target): target.status = "Missing fields" target.comments = f'missing fields: {", ".join(target.missing_fields)}' - target.pinecone_update_required = target.is_valid - if target.id: target.verify_id() else: target._set_id() - def to_dict(self): + @classmethod + def check_for_changes(cls, mapper, connection, target): + # Attributes we want to monitor for changes + monitored_attributes = list(PineconeMetadata.__annotations__.keys()) + monitored_attributes.remove("hash_id") + + changed = any(get_history(target, attr).has_changes() for attr in monitored_attributes) + + if changed and target.is_valid: + target.pinecone_update_required = True + + def to_dict(self) -> Dict[str, Any]: if date := self.date_published: date = date.replace(tzinfo=pytz.UTC).strftime("%Y-%m-%dT%H:%M:%SZ") meta = json.loads(self.meta) if isinstance(self.meta, str) else self.meta @@ -192,3 +221,4 @@ def to_dict(self): event.listen(Article, "before_insert", Article.before_write) event.listen(Article, "before_update", Article.before_write) +event.listen(Article, "before_update", Article.check_for_changes) From 1505ba6ffc8e7a3caf85ff9618fdfeab4836357e Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Fri, 25 Aug 2023 00:43:42 -0400 Subject: [PATCH 07/25] minor bug fixes --- align_data/db/session.py | 2 -- align_data/embeddings/pinecone/pinecone_db_handler.py | 11 +++++++++-- align_data/embeddings/pinecone/update_pinecone.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/align_data/db/session.py b/align_data/db/session.py index cab798b5..d2557827 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -34,7 +34,6 @@ def get_pinecone_articles_to_update( .filter(Article.is_valid) .filter(Article.source.in_(custom_sources)) .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) - # .yield_per(10) ) @@ -55,7 +54,6 @@ def get_pinecone_articles_by_ids( .filter(Article.source.in_(custom_sources)) .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) .filter(Article.id.in_(hash_ids)) - # .yield_per(10) ) diff --git a/align_data/embeddings/pinecone/pinecone_db_handler.py b/align_data/embeddings/pinecone/pinecone_db_handler.py index dd2990d3..8222ba83 100644 --- a/align_data/embeddings/pinecone/pinecone_db_handler.py +++ b/align_data/embeddings/pinecone/pinecone_db_handler.py @@ -50,9 +50,16 @@ def __init__( index_stats_response = self.index.describe_index_stats() logger.info(f"{self.index_name}:\n{index_stats_response}") - def 
upsert_entry(self, pinecone_entry: PineconeEntry, upsert_size: int = 100): + def upsert_entry( + self, pinecone_entry: PineconeEntry, upsert_size: int = 100, show_progress: bool = True + ): vectors = pinecone_entry.create_pinecone_vectors() - self.index.upsert(vectors=vectors, batch_size=upsert_size, namespace=PINECONE_NAMESPACE) + self.index.upsert( + vectors=vectors, + batch_size=upsert_size, + namespace=PINECONE_NAMESPACE, + show_progress=show_progress, + ) def query_vector( self, diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py index f0aa1b57..cf1a937d 100644 --- a/align_data/embeddings/pinecone/update_pinecone.py +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -59,7 +59,7 @@ def update_articles_by_ids( for batch in self.batch_entries(articles_to_update_stream): self.save_batch(session, batch) - def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry | None]]): + def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry]]): try: for article, pinecone_entry in batch: if pinecone_entry: From 96ab92e46dded1cfc9c9afab0b3415b85e58ea66 Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Fri, 25 Aug 2023 02:16:19 -0400 Subject: [PATCH 08/25] pinecone_update_required bug fix --- align_data/common/alignment_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 0d607b13..e1eab49f 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -74,7 +74,6 @@ def make_data_entry(self, data, **kwargs) -> Article: authors = data.pop("authors", []) article = Article( - pinecone_update_required=True, meta={k: v for k, v in data.items() if k not in ARTICLE_MAIN_KEYS and v is not None}, **{k: v for k, v in data.items() if k in ARTICLE_MAIN_KEYS}, ) From 113d34f84931c8ffdfa24b80486e21f4732bd5e4 Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Wed, 23 Aug 2023 12:44:58 -0400 Subject: [PATCH 09/25] merging mostly? 
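Patches 06 through 08 move the pinecone_update_required flag from an unconditional assignment in make_data_entry to a value derived from SQLAlchemy's attribute history inside a mapper event listener. A minimal, self-contained sketch of that pattern follows; the Doc model, the needs_reindex column and the MONITORED tuple are illustrative stand-ins for the project's Article, pinecone_update_required and PineconeMetadata fields, not the real schema.

# Illustrative sketch only, assuming a toy model rather than the project's Article class.
from sqlalchemy import Boolean, String, create_engine, event
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from sqlalchemy.orm.attributes import get_history


class Base(DeclarativeBase):
    pass


class Doc(Base):  # hypothetical stand-in for Article
    __tablename__ = "docs"
    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str] = mapped_column(String(128))
    text: Mapped[str] = mapped_column(String(1024))
    needs_reindex: Mapped[bool] = mapped_column(Boolean, default=False)


MONITORED = ("title", "text")  # stand-in for the PineconeMetadata field names


def check_for_changes(_mapper, _connection, target):
    # get_history reports pending changes per attribute before the flush writes them;
    # the real listener also returns early when the article is not valid.
    changed = any(get_history(target, attr).has_changes() for attr in MONITORED)
    target.needs_reindex = changed


event.listen(Doc, "before_insert", check_for_changes)
event.listen(Doc, "before_update", check_for_changes)

if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        doc = Doc(title="t", text="hello")
        session.add(doc)
        session.commit()            # new rows come out flagged for indexing
        doc.needs_reindex = False   # pretend the indexer has since run
        session.commit()
        doc.title = "t2"            # a monitored field changed
        session.commit()            # before_update flips the flag back on
        print(doc.needs_reindex)    # True

Deriving the flag in the listener means only edits to fields that actually feed the Pinecone index mark an article for re-embedding, rather than every write.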
--- align_data/common/html_dataset.py | 6 ++- align_data/sources/agisf/__init__.py | 4 ++ align_data/sources/blogs/__init__.py | 2 +- align_data/sources/blogs/blogs.py | 5 +- tests/align_data/sources/test_blogs.py | 64 +++++++++++++------------- 5 files changed, 43 insertions(+), 38 deletions(-) diff --git a/align_data/common/html_dataset.py b/align_data/common/html_dataset.py index c7e6c469..95459312 100644 --- a/align_data/common/html_dataset.py +++ b/align_data/common/html_dataset.py @@ -150,11 +150,13 @@ def fetch_contents(self, url): ) def _extract_item_url(self, item) -> str | None: - return item.get('link') + return item.get("link") @property def items_list(self): logger.info(f"Fetching entries from {self.feed_url}") feed = feedparser.parse(self.feed_url) - self.items = {url: item for item in feed["entries"] if (url := self._extract_item_url(item))} + self.items = { + url: item for item in feed["entries"] if (url := self._extract_item_url(item)) + } return list(self.items.keys()) diff --git a/align_data/sources/agisf/__init__.py b/align_data/sources/agisf/__init__.py index 60517c4c..bf484913 100644 --- a/align_data/sources/agisf/__init__.py +++ b/align_data/sources/agisf/__init__.py @@ -15,6 +15,10 @@ "comments": "[h] Resource guide", "authors": "Author(s) (from Resources)", }, + processors={ + "source_type": lambda val: val[0] if val else None, + "authors": lambda val: val and [v.strip() for v in val.split(",")], + }, ), AGISFPodcastDataset( name="agisf_readings_alignment", diff --git a/align_data/sources/blogs/__init__.py b/align_data/sources/blogs/__init__.py index c7fd34db..fbc576b1 100644 --- a/align_data/sources/blogs/__init__.py +++ b/align_data/sources/blogs/__init__.py @@ -16,7 +16,7 @@ BLOG_DATASETS = [ - AXRPDataset(name='axrp', url='https://axrp.net', authors=['AXRP']), + AXRPDataset(name="axrp", url="https://axrp.net", authors=["AXRP"]), WordpressBlog(name="aiimpacts", url="https://aiimpacts.org"), WordpressBlog(name="aisafety.camp", url="https://aisafety.camp"), WordpressBlog(name="miri", url="https://intelligence.org"), diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py index df0f1633..5a1f1874 100644 --- a/align_data/sources/blogs/blogs.py +++ b/align_data/sources/blogs/blogs.py @@ -164,7 +164,6 @@ def extract_authors(self, contents): class AXRPDataset(RSSDataset): - @property def feed_url(self): return f"{self.url}/feed.xml" @@ -177,11 +176,11 @@ def _extract_item_url(self, item) -> str | None: def extract_authors(self, item): if "authors" in item: - authors = [name for a in item["authors"] if (name := (a.get("name") or '').strip())] + authors = [name for a in item["authors"] if (name := (a.get("name") or "").strip())] if authors: return authors - bits = item.get('title', '').split(' with ') + bits = item.get("title", "").split(" with ") if len(bits) > 1 and bits[-1].strip(): return self.authors + [bits[-1].strip()] return self.authors diff --git a/tests/align_data/sources/test_blogs.py b/tests/align_data/sources/test_blogs.py index 9debd46b..7e889935 100644 --- a/tests/align_data/sources/test_blogs.py +++ b/tests/align_data/sources/test_blogs.py @@ -870,46 +870,46 @@ def test_axrp_dataset_extract_item_url(url, expected): assert dataset._extract_item_url({'link': url}) == expected -@pytest.mark.parametrize('item, expected', ( - ({}, ['default authors']), - ({'authors': []}, ['default authors']), - ({'authors': [{'bla': 'bla'}]}, ['default authors']), - ({'authors': [{'name': ''}]}, ['default authors']), - ({'authors': [{'name': ' \t 
\n'}]}, ['default authors']), - - ({'title': 'bla bla bla'}, ['default authors']), - ({'title': 'bla bla bla with'}, ['default authors']), - ({'title': 'bla bla bla with \t \n'}, ['default authors']), - - ({'authors': [{'name': 'mr. blobby'}]}, ['mr. blobby']), - ({'authors': [{'name': 'mr. blobby'}, {'name': 'janek'}]}, ['mr. blobby', 'janek']), - - ({'title': 'bla bla bla with your momma'}, ['default authors', 'your momma']), -)) +@pytest.mark.parametrize( + "item, expected", + ( + ({}, ["default authors"]), + ({"authors": []}, ["default authors"]), + ({"authors": [{"bla": "bla"}]}, ["default authors"]), + ({"authors": [{"name": ""}]}, ["default authors"]), + ({"authors": [{"name": " \t \n"}]}, ["default authors"]), + ({"title": "bla bla bla"}, ["default authors"]), + ({"title": "bla bla bla with"}, ["default authors"]), + ({"title": "bla bla bla with \t \n"}, ["default authors"]), + ({"authors": [{"name": "mr. blobby"}]}, ["mr. blobby"]), + ({"authors": [{"name": "mr. blobby"}, {"name": "janek"}]}, ["mr. blobby", "janek"]), + ({"title": "bla bla bla with your momma"}, ["default authors", "your momma"]), + ), +) def test_axrp_dataset_extract_authors(item, expected): - dataset = AXRPDataset(name='bla', url='https://ble.ble.com', authors=['default authors']) + dataset = AXRPDataset(name="bla", url="https://ble.ble.com", authors=["default authors"]) assert dataset.extract_authors(item) == expected def test_axrp_dataset_process_entry(): - dataset = AXRPDataset(name='bla', url='https://ble.ble.com', authors=['default authors']) - url = 'https://ble.ble.com/ble/ble' + dataset = AXRPDataset(name="bla", url="https://ble.ble.com", authors=["default authors"]) + url = "https://ble.ble.com/ble/ble" dataset.items = { url: { - 'content': [{'value': 'bla bla'}], - 'link': '/ble/ble', - 'published': '2023-07-27T03:50:00+00:00', - 'title': 'Something or other with your momma', + "content": [{"value": "bla bla"}], + "link": "/ble/ble", + "published": "2023-07-27T03:50:00+00:00", + "title": "Something or other with your momma", } } assert dataset.process_entry(url).to_dict() == { - 'authors': ['default authors', 'your momma'], - 'date_published': '2023-07-27T03:50:00Z', - 'id': None, - 'source': 'bla', - 'source_type': 'blog', - 'summaries': [], - 'text': 'bla bla', - 'title': 'Something or other with your momma', - 'url': 'https://ble.ble.com/ble/ble', + "authors": ["default authors", "your momma"], + "date_published": "2023-07-27T03:50:00Z", + "id": None, + "source": "bla", + "source_type": "blog", + "summaries": [], + "text": "bla bla", + "title": "Something or other with your momma", + "url": "https://ble.ble.com/ble/ble", } From 55a0c8175bc23014e84f4ab1f09c5e4def772e3b Mon Sep 17 00:00:00 2001 From: Thomas Lemoine <43831409+Thomas-Lemoine@users.noreply.github.com> Date: Sat, 26 Aug 2023 10:44:04 -0400 Subject: [PATCH 10/25] fix titles (#159) * fix titles * tests * removed title from id_fields * removed useless test * get_item_key returns a string * default title is None instead of '' * normalize urls when comparing processed and processing items * special docs edge case and tests * simplify youtube link conversion * fix typo * type annotation fixes --------- Co-authored-by: Daniel O'Connell <github@ahiru.pl> --- align_data/common/alignment_dataset.py | 48 ++++++++++++++++--- align_data/common/html_dataset.py | 8 ++-- align_data/db/models.py | 26 +++++----- align_data/sources/airtable.py | 2 +- .../alignment_newsletter.py | 2 +- align_data/sources/arbital/arbital.py | 4 +- 
align_data/sources/articles/datasets.py | 35 +++++++++----- align_data/sources/articles/indices.py | 2 +- align_data/sources/articles/parsers.py | 2 +- align_data/sources/articles/updater.py | 2 +- align_data/sources/blogs/blogs.py | 2 +- align_data/sources/blogs/gwern_blog.py | 2 +- align_data/sources/stampy/stampy.py | 4 +- align_data/sources/youtube/youtube.py | 6 +-- tests/align_data/articles/test_datasets.py | 32 ++++++++++++- tests/align_data/articles/test_updater.py | 4 +- .../common/test_alignment_dataset.py | 39 ++++++++------- tests/align_data/common/test_html_dataset.py | 31 +++++++++++- tests/align_data/sources/test_stampy.py | 2 +- 19 files changed, 179 insertions(+), 74 deletions(-) diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 377e3880..5298a944 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -1,3 +1,4 @@ +import re from datetime import datetime from itertools import islice import logging @@ -72,6 +73,7 @@ def make_data_entry(self, data, **kwargs) -> Article: data = merge_dicts(data, kwargs) summary = data.pop("summary", None) authors = data.pop("authors", []) + data['title'] = (data.get('title') or '').replace('\n', ' ').replace('\r', '') or None article = Article( pinecone_update_required=True, @@ -140,30 +142,64 @@ def items_list(self) -> Iterable: """Returns a collection of items to be processed.""" return [] - def get_item_key(self, item): + def get_item_key(self, item) -> str: """Get the identifier of the given `item` so it can be checked to see whether it's been output. The default assumption is that the `item` is a Path to a file. """ return item.name + @staticmethod + def _normalize_url(url: str | None) -> str | None: + if not url: + return url + + # ending '/' + url = url.rstrip("/") + + # Remove http and use https consistently + url = url.replace("http://", "https://") + + # Remove www + url = url.replace("https://www.", "https://") + + # Remove index.html or index.htm + url = re.sub(r'/index\.html?$', '', url) + + # Convert youtu.be links to youtube.com + url = url.replace("https://youtu.be/", "https://youtube.com/watch?v=") + + # Additional rules for mirror domains can be added here + + # agisafetyfundamentals.com -> aisafetyfundamentals.com + url = url.replace("https://agisafetyfundamentals.com", "https://aisafetyfundamentals.com") + + return url + + def _normalize_urls(self, urls: Iterable[str]) -> Set[str]: + return {self._normalize_url(url) for url in urls} + + def _load_outputted_items(self) -> Set[str]: """Load the output file (if it exists) in order to know which items have already been output.""" with make_session() as session: + items = set() if hasattr(Article, self.done_key): # This doesn't filter by self.name. The good thing about that is that it should handle a lot more # duplicates. The bad thing is that this could potentially return a massive amount of data if there # are lots of items. 
- return set(session.scalars(select(getattr(Article, self.done_key))).all()) + items = set(session.scalars(select(getattr(Article, self.done_key))).all()) # TODO: Properly handle this - it should create a proper SQL JSON select - return {item.get(self.done_key) for item in session.scalars(select(Article.meta)).all()} + else: + items = {item.get(self.done_key) for item in session.scalars(select(Article.meta)).all()} + return self._normalize_urls(items) - def not_processed(self, item): + def not_processed(self, item) -> bool: # NOTE: `self._outputted_items` reads in all items. Which could potentially be a lot. If this starts to # cause problems (e.g. massive RAM usage, big slow downs) then it will have to be switched around, so that # this function runs a query to check if the item is in the database rather than first getting all done_keys. # If it get's to that level, consider batching it somehow - return self.get_item_key(item) not in self._outputted_items + return self._normalize_url(self.get_item_key(item)) not in self._outputted_items def unprocessed_items(self, items=None) -> Iterable: """Return a list of all items to be processed. @@ -269,7 +305,7 @@ def setup(self): for dataset in self.datasets: dataset.setup() - def get_item_key(self, entry): + def get_item_key(self, entry) -> str | None: item, dataset = entry return dataset.get_item_key(item) diff --git a/align_data/common/html_dataset.py b/align_data/common/html_dataset.py index c7e6c469..ed70f561 100644 --- a/align_data/common/html_dataset.py +++ b/align_data/common/html_dataset.py @@ -39,7 +39,7 @@ class HTMLDataset(AlignmentDataset): def extract_authors(self, article): return self.authors - def get_item_key(self, item): + def get_item_key(self, item) -> str: article_url = item.find_all("a")[0]["href"].split("?")[0] return urljoin(self.url, article_url) @@ -55,7 +55,7 @@ def items_list(self): def _extra_values(self, contents): return {} - def get_contents(self, article_url): + def get_contents(self, article_url: str): contents = self.fetch_contents(article_url) title = self._get_title(contents) @@ -101,7 +101,7 @@ def _get_text(self, contents): def _find_date(self, items): for i in items: - if re.match("\w+ \d{1,2}, \d{4}", i.text): + if re.match(r"\w+ \d{1,2}, \d{4}", i.text): return datetime.strptime(i.text, "%b %d, %Y").replace(tzinfo=pytz.UTC) def _extract_markdown(self, element): @@ -112,7 +112,7 @@ def _extract_markdown(self, element): class RSSDataset(HTMLDataset): date_format = "%a, %d %b %Y %H:%M:%S %z" - def get_item_key(self, item): + def get_item_key(self, item: str) -> str: return item @property diff --git a/align_data/db/models.py b/align_data/db/models.py index 98716768..30dedff8 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -1,9 +1,10 @@ import json +import re import logging import pytz import hashlib from datetime import datetime -from typing import List, Optional +from typing import List, Optional, Dict, Any from sqlalchemy import ( JSON, DateTime, @@ -71,18 +72,19 @@ def __repr__(self) -> str: return f"Article(id={self.id!r}, title={self.title!r}, url={self.url!r}, source={self.source!r}, authors={self.authors!r}, date_published={self.date_published!r})" def generate_id_string(self) -> bytes: - return "".join(str(getattr(self, field)) for field in self.__id_fields).encode("utf-8") + return "".join( + re.sub(r'[^a-zA-Z0-9\s]', '', str(getattr(self, field))).strip().lower() + for field in self.__id_fields + ).encode("utf-8") @property - def __id_fields(self): - if self.source == 
"aisafety.info": - return ["url"] + def __id_fields(self) -> List[str]: if self.source in ["importai", "ml_safety_newsletter", "alignment_newsletter"]: - return ["url", "title", "source"] - return ["url", "title"] + return ["url", "source"] + return ["url"] @property - def missing_fields(self): + def missing_fields(self) -> List[str]: fields = set(self.__id_fields) | { "text", "title", @@ -105,7 +107,7 @@ def verify_id_fields(self): missing = [field for field in self.__id_fields if not getattr(self, field)] assert not missing, f"Entry is missing the following fields: {missing}" - def update(self, other): + def update(self, other: "Article") -> "Article": for field in self.__table__.columns.keys(): if field not in ["id", "hash_id", "metadata"] and getattr(other, field): setattr(self, field, getattr(other, field)) @@ -120,7 +122,7 @@ def _set_id(self): id_string = self.generate_id_string() self.id = hashlib.md5(id_string).hexdigest() - def add_meta(self, key, val): + def add_meta(self, key: str, val): if self.meta is None: self.meta = {} self.meta[key] = val @@ -153,7 +155,7 @@ def is_valid(cls): ) @classmethod - def before_write(cls, _mapper, _connection, target): + def before_write(cls, _mapper, _connection, target: "Article"): target.verify_id_fields() if not target.status and target.missing_fields: @@ -167,7 +169,7 @@ def before_write(cls, _mapper, _connection, target): else: target._set_id() - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: if date := self.date_published: date = date.replace(tzinfo=pytz.UTC).strftime("%Y-%m-%dT%H:%M:%SZ") meta = json.loads(self.meta) if isinstance(self.meta, str) else self.meta diff --git a/align_data/sources/airtable.py b/align_data/sources/airtable.py index db24350e..6ee1337a 100644 --- a/align_data/sources/airtable.py +++ b/align_data/sources/airtable.py @@ -37,7 +37,7 @@ def map_col(k): if (mapped.get('url') or '').startswith('http'): return mapped - def get_item_key(self, item): + def get_item_key(self, item) -> str | None: return item.get('url') @property diff --git a/align_data/sources/alignment_newsletter/alignment_newsletter.py b/align_data/sources/alignment_newsletter/alignment_newsletter.py index 2b68e32f..87541d4a 100644 --- a/align_data/sources/alignment_newsletter/alignment_newsletter.py +++ b/align_data/sources/alignment_newsletter/alignment_newsletter.py @@ -30,7 +30,7 @@ def maybe(val): return None return val - def get_item_key(self, row): + def get_item_key(self, row) -> str | None: return self.maybe(row.URL) def _get_published_date(self, year): diff --git a/align_data/sources/arbital/arbital.py b/align_data/sources/arbital/arbital.py index 40ce16c3..ab19aab7 100644 --- a/align_data/sources/arbital/arbital.py +++ b/align_data/sources/arbital/arbital.py @@ -85,7 +85,7 @@ def markdownify_text(current, view): def extract_text(text): - parts = [i for i in re.split("([\[\]()])", text) if i] + parts = [i for i in re.split(r"([\[\]()])", text) if i] return markdownify_text([], zip(parts, parts[1:] + [None])) @@ -119,7 +119,7 @@ def items_list(self): logger.info("Got %s page aliases", len(items)) return items - def get_item_key(self, item): + def get_item_key(self, item: str) -> str: return item def process_entry(self, alias): diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py index 5545dfa7..cbf7f9d9 100644 --- a/align_data/sources/articles/datasets.py +++ b/align_data/sources/articles/datasets.py @@ -2,7 +2,7 @@ import os from dataclasses import dataclass from pathlib import Path -from 
typing import Dict +from typing import Dict, Iterable import pandas as pd from gdown.download import download @@ -22,7 +22,7 @@ from align_data.sources.articles.pdf import read_pdf from align_data.sources.arxiv_papers import ( fetch_arxiv, - canonical_url as arxiv_cannonical_url, + canonical_url as arxiv_canonical_url, ) logger = logging.getLogger(__name__) @@ -43,11 +43,11 @@ def maybe(item, key: str): return None return val - def get_item_key(self, item): + def get_item_key(self, item) -> str | None: return self.maybe(item, self.done_key) @property - def items_list(self): + def items_list(self) -> Iterable[tuple]: url = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}" logger.info(f"Fetching {url}") df = pd.read_csv(url) @@ -63,7 +63,7 @@ def extract_authors(item): return [] return [author.strip() for author in item.authors.split(",") if author.strip()] - def process_entry(self, item): + def process_entry(self, item: tuple): text = self._get_text(item) if not text: logger.error("Could not get text for %s - skipping for now", item.title) @@ -116,17 +116,26 @@ def get_contents(self, item) -> Dict: }, ) - def not_processed(self, item): + + def not_processed(self, item: tuple) -> bool: + item_key = self.get_item_key(item) url = self.maybe(item, "url") source_url = self.maybe(item, "source_url") - return ( - self.get_item_key(item) not in self._outputted_items - and url not in self._outputted_items - and source_url not in self._outputted_items - and (not url or arxiv_cannonical_url(url) not in self._outputted_items) - and (not source_url or arxiv_cannonical_url(source_url) not in self._outputted_items) - ) + if item_key and self._normalize_url(item_key) in self._outputted_items: + return False + + for given_url in [url, source_url]: + if given_url: + norm_url = self._normalize_url(given_url) + if norm_url in self._outputted_items: + return False + + norm_canonical_url = self._normalize_url(arxiv_canonical_url(given_url)) + if norm_canonical_url in self._outputted_items: + return False + + return True def process_entry(self, item): if ArxivPapers.is_arxiv(item.url): diff --git a/align_data/sources/articles/indices.py b/align_data/sources/articles/indices.py index 220a9005..6deb02d8 100644 --- a/align_data/sources/articles/indices.py +++ b/align_data/sources/articles/indices.py @@ -252,7 +252,7 @@ class IndicesDataset(AlignmentDataset): def items_list(self): return fetch_all().values() - def get_item_key(self, item): + def get_item_key(self, item) -> str | None: return item.get("url") @staticmethod diff --git a/align_data/sources/articles/parsers.py b/align_data/sources/articles/parsers.py index c210e565..4bd493be 100644 --- a/align_data/sources/articles/parsers.py +++ b/align_data/sources/articles/parsers.py @@ -267,7 +267,7 @@ def parse_domain(url: str) -> str: return url and urlparse(url).netloc.lstrip("www.") -def item_metadata(url) -> Dict[str, any]: +def item_metadata(url: str) -> Dict[str, any]: domain = parse_domain(url) try: res = fetch(url, "head") diff --git a/align_data/sources/articles/updater.py b/align_data/sources/articles/updater.py index f453b5ae..c2a1d29e 100644 --- a/align_data/sources/articles/updater.py +++ b/align_data/sources/articles/updater.py @@ -19,7 +19,7 @@ class ReplacerDataset(AlignmentDataset): delimiter: str done_key = "url" - def get_item_key(self, item): + def get_item_key(self, item) -> None: return None @staticmethod diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py index 
6d7b144b..7b0f8918 100644 --- a/align_data/sources/blogs/blogs.py +++ b/align_data/sources/blogs/blogs.py @@ -131,7 +131,7 @@ class TransformerCircuits(HTMLDataset): item_selector = "div.toc a" text_selector = 'h3' - def get_item_key(self, item): + def get_item_key(self, item) -> str: article_url = item.get("href").split("?")[0] return urljoin(self.url, article_url) diff --git a/align_data/sources/blogs/gwern_blog.py b/align_data/sources/blogs/gwern_blog.py index 9328d874..1d573a8e 100644 --- a/align_data/sources/blogs/gwern_blog.py +++ b/align_data/sources/blogs/gwern_blog.py @@ -16,7 +16,7 @@ class GwernBlog(HTMLDataset): COOLDOWN: int = 1 done_key = "url" - def get_item_key(self, item): + def get_item_key(self, item: str) -> str: return item @property diff --git a/align_data/sources/stampy/stampy.py b/align_data/sources/stampy/stampy.py index 49f57620..95319820 100644 --- a/align_data/sources/stampy/stampy.py +++ b/align_data/sources/stampy/stampy.py @@ -3,8 +3,6 @@ import logging from dataclasses import dataclass from codaio import Coda, Document -from datetime import timezone -from dateutil.parser import parse from align_data.common.alignment_dataset import AlignmentDataset from align_data.settings import CODA_TOKEN, CODA_DOC_ID, ON_SITE_TABLE @@ -35,7 +33,7 @@ def items_list(self): table = doc.get_table(ON_SITE_TABLE) return table.to_dict() # a list of dicts - def get_item_key(self, entry): + def get_item_key(self, entry) -> str: return html.unescape(entry["Question"]) def _get_published_date(self, entry): diff --git a/align_data/sources/youtube/youtube.py b/align_data/sources/youtube/youtube.py index 876dc09e..740597d0 100644 --- a/align_data/sources/youtube/youtube.py +++ b/align_data/sources/youtube/youtube.py @@ -34,7 +34,7 @@ def next_page(self, collection_id, next_page_token): return {"items": []} @staticmethod - def _get_id(item): + def _get_id(item) -> str | None: if item.get("kind") == "youtube#searchResult": resource = item["id"] elif item.get("kind") == "youtube#playlistItem": @@ -66,9 +66,9 @@ def items_list(self): for video in self.fetch_videos(collection_id) ) - def get_item_key(self, item): + def get_item_key(self, item) -> str | None: video_id = self._get_id(item) - return f"https://www.youtube.com/watch?v={video_id}" + return video_id and f"https://www.youtube.com/watch?v={video_id}" def _get_contents(self, video): video_id = self._get_id(video) diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py index 5b73e05b..7f911208 100644 --- a/tests/align_data/articles/test_datasets.py +++ b/tests/align_data/articles/test_datasets.py @@ -471,7 +471,7 @@ def test_special_docs_process_entry_arxiv(_, mock_arxiv): ) def test_special_docs_not_processed_true(url, expected): dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da") - dataset._outputted_items = [url, expected] + dataset._outputted_items = dataset._normalize_urls({url, expected}) assert not dataset.not_processed(Mock(url=url, source_url=None)) assert not dataset.not_processed(Mock(url=None, source_url=url)) @@ -490,3 +490,33 @@ def test_special_docs_not_processed_false(url): dataset._outputted_items = [] assert dataset.not_processed(Mock(url=url, source_url=None)) assert dataset.not_processed(Mock(url=None, source_url=url)) + + +@pytest.fixture +def _outputted_items(): + return [ + "http://bla.bla", + "http://ble.ble", + "https://arxiv.org/abs/2001.11038/", + "https://arxiv.org/pdf/2001.00038", + "https://www.arxiv.org/pdf/2001.11038.pdf", + 
"https://arxiv.org/pdf/2002.11038v3.pdf", + ] + +def test_special_docs_check_not_processed(_outputted_items): + dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da") + dataset._outputted_items = dataset._normalize_urls(_outputted_items) + + # existing tests + assert dataset.not_processed(Mock(url="http://google.com", source_url=None)) + assert dataset.not_processed(Mock(url=None, source_url="http://google.com")) + assert dataset.not_processed(Mock(url=None, source_url=None)) + assert dataset.not_processed(Mock(url="http://ble.ble.com", source_url=None)) + + assert not dataset.not_processed(Mock(url="http://bla.bla", source_url=None)) + assert not dataset.not_processed(Mock(url="https://ble.ble/index.htm", source_url=None)) + assert not dataset.not_processed(Mock(url="https://arxiv.org/abs/2001.11038", source_url=None)) + assert not dataset.not_processed(Mock(url="https://www.arxiv.org/abs/2001.11038", source_url=None)) + + assert not dataset.not_processed(Mock(url=None, source_url="https://arxiv.org/pdf/2001.11038v3.pdf")) + assert not dataset.not_processed(Mock(url="dummy_url", source_url="https://arxiv.org/pdf/2001.11038v3.pdf")) diff --git a/tests/align_data/articles/test_updater.py b/tests/align_data/articles/test_updater.py index f9e2aea2..b5b5dff8 100644 --- a/tests/align_data/articles/test_updater.py +++ b/tests/align_data/articles/test_updater.py @@ -195,7 +195,7 @@ def test_process_entry(csv_file): assert dataset.process_entry(Item(updates, article)).to_dict() == { "authors": ["mr. blobby", "johnny"], "date_published": "2000-12-23T10:32:43Z", - "id": "d8d8cad8d28739a0862654a0e6e8ce6e", + "id": "3073112dd44a96a7efdf0253f8575e56", # id str is 'httpblacom' "source": "tests", "source_type": None, "summaries": [], @@ -238,7 +238,7 @@ def test_process_entry_empty(csv_file): assert dataset.process_entry(Item(updates, article)).to_dict() == { "authors": ["this should not be changed"], "date_published": "2000-12-23T10:32:43Z", - "id": "606e9224254f508d297bcb17bcc6d104", + "id": "283f362287e87a2d4a036d69c04b436b", "source": "this should not be changed", "source_type": None, "summaries": [], diff --git a/tests/align_data/common/test_alignment_dataset.py b/tests/align_data/common/test_alignment_dataset.py index d18aaf78..4cad29f9 100644 --- a/tests/align_data/common/test_alignment_dataset.py +++ b/tests/align_data/common/test_alignment_dataset.py @@ -64,7 +64,7 @@ def test_data_entry_id_from_urls_and_title(): assert entry.to_dict() == dict( { "date_published": None, - "id": "770fe57c8c2130eda08dc392b8696f97", + "id": "761edb1a245f56b2ece52d652658b8eb", "source": None, "source_type": None, "text": None, @@ -148,7 +148,7 @@ def test_data_entry_verify_id_passes(): "text": "once upon a time", "url": "www.arbital.org", "title": "once upon a time", - "id": "770fe57c8c2130eda08dc392b8696f97", + "id": "761edb1a245f56b2ece52d652658b8eb", } ) entry.verify_id() @@ -163,30 +163,33 @@ def test_data_entry_verify_id_fails(): "id": "f2b4e02fc1dd8ae43845e4f930f2d84f", } ) - expected = "Entry id f2b4e02fc1dd8ae43845e4f930f2d84f does not match id from id_fields: 770fe57c8c2130eda08dc392b8696f97" + expected = "Entry id f2b4e02fc1dd8ae43845e4f930f2d84f does not match id from id_fields: 761edb1a245f56b2ece52d652658b8eb" with pytest.raises(AssertionError, match=expected): entry.verify_id() +def test_generate_id_string(): + dataset = AlignmentDataset(name="blaa") + entry = dataset.make_data_entry( + { + "url": "www.arbital.org", + "title": "once upon a time", + "id": 
"f2b4e02fc1dd8ae43845e4f930f2d84f", + } + ) + assert entry.generate_id_string() == b"wwwarbitalorg" + @pytest.mark.parametrize( "data, error", ( - ({"id": "123"}, "Entry is missing the following fields: \\['url', 'title'\\]"), + ({"id": "123"}, "Entry is missing the following fields: \\['url'\\]"), ( {"id": "123", "url": None}, - "Entry is missing the following fields: \\['url', 'title'\\]", - ), - ( - {"id": "123", "url": "www.google.com/"}, - "Entry is missing the following fields: \\['title'\\]", - ), - ( - {"id": "123", "url": "google", "text": None}, - "Entry is missing the following fields: \\['title'\\]", + "Entry is missing the following fields: \\['url'\\]", ), ( {"id": "123", "url": "", "title": ""}, - "Entry is missing the following fields: \\['url', 'title'\\]", + "Entry is missing the following fields: \\['url'\\]", ), ), ) @@ -233,8 +236,8 @@ class NumbersDataset(AlignmentDataset): def items_list(self): return self.nums - def get_item_key(self, item): - return item + def get_item_key(self, item) -> str: + return str(item) def process_entry(self, item): return self.make_data_entry( @@ -262,14 +265,14 @@ def test_unprocessed_items_fresh(numbers_dataset): def test_unprocessed_items_all_done(numbers_dataset): """Getting the unprocessed items from a dataset that has already processed everything should return an empty list.""" - seen = set(range(0, 10)) + seen = set(str(i) for i in range(0, 10)) with patch.object(numbers_dataset, "_load_outputted_items", return_value=seen): assert list(numbers_dataset.unprocessed_items()) == [] def test_unprocessed_items_some_done(numbers_dataset): """Getting the uprocessed items from a dataset that has partially completed should return the items that haven't been processed.""" - seen = set(range(0, 10, 2)) + seen = set(str(i) for i in range(0, 10, 2)) with patch.object(numbers_dataset, "_load_outputted_items", return_value=seen): assert list(numbers_dataset.unprocessed_items()) == list(range(1, 10, 2)) diff --git a/tests/align_data/common/test_html_dataset.py b/tests/align_data/common/test_html_dataset.py index 3efbddb0..6f3f1bec 100644 --- a/tests/align_data/common/test_html_dataset.py +++ b/tests/align_data/common/test_html_dataset.py @@ -34,14 +34,14 @@ def html_dataset(): """ -def test_html_dataset_extract_authors(html_dataset): +def test_html_dataset_extract_authors(html_dataset: HTMLDataset): assert html_dataset.extract_authors("dummy variable") == [ "John Smith", "Your momma", ] -def test_html_dataset_get_title(html_dataset): +def test_html_dataset_get_title(html_dataset: HTMLDataset): item = f""" <article> <h1> This is the title @@ -155,6 +155,33 @@ def test_html_dataset_process_entry_no_text(html_dataset): assert html_dataset.process_entry(article) is None +def test_html_dataset_newline_in_title(html_dataset: HTMLDataset): + html_with_newline_title = f""" + <article> + <h1> + the title\nwith a newline + </h1> + <div> + bla bla bla <a href="{html_dataset.url}/path/to/article">click to read more</a> bla bla + </div> + </article> + """ + article = BeautifulSoup(html_with_newline_title, "html.parser") + + with patch("requests.get", return_value=Mock(content=html_with_newline_title)): + assert html_dataset.process_entry(article).to_dict() == { + "authors": ["John Smith", "Your momma"], + "date_published": None, + "id": None, + "source": "bla", + "source_type": "blog", + "summaries": [], + "text": "bla bla bla [click to read more](http://example.com/path/to/article) bla bla", + "title": "the title with a newline", + "url": 
"http://example.com/path/to/article", + } + + @pytest.mark.parametrize( "item, authors", ( diff --git a/tests/align_data/sources/test_stampy.py b/tests/align_data/sources/test_stampy.py index 9b40a2c3..06041316 100644 --- a/tests/align_data/sources/test_stampy.py +++ b/tests/align_data/sources/test_stampy.py @@ -45,6 +45,6 @@ def test_process_entry(): "source_type": "markdown", "summaries": [], "text": "bla bla bla", - "title": "Why\nnot just?", + "title": "Why not just?", "url": "https://aisafety.info?state=1234", } From 3bf8901db2a131f8484c41facea9018e1a16cf2d Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Sun, 27 Aug 2023 03:26:02 -0400 Subject: [PATCH 11/25] minor refactor and bug fixes --- align_data/db/models.py | 17 +++-------- align_data/db/session.py | 28 ++++++++----------- .../embeddings/pinecone/update_pinecone.py | 2 +- align_data/settings.py | 2 +- 4 files changed, 17 insertions(+), 32 deletions(-) diff --git a/align_data/db/models.py b/align_data/db/models.py index 77a78c9d..574360eb 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -153,14 +153,7 @@ def is_valid(self) -> bool: ] ) - # URL validation - try: - result = urlparse(self.url) - url_check = all([result.scheme in ["http", "https"], result.netloc]) - except: - url_check = False - - return basic_check and url_check + return basic_check @is_valid.expression def is_valid(cls) -> bool: @@ -187,14 +180,12 @@ def before_write(cls, _mapper, _connection, target): @classmethod def check_for_changes(cls, mapper, connection, target): - # Attributes we want to monitor for changes monitored_attributes = list(PineconeMetadata.__annotations__.keys()) monitored_attributes.remove("hash_id") - changed = any(get_history(target, attr).has_changes() for attr in monitored_attributes) - - if changed and target.is_valid: - target.pinecone_update_required = True + if target.is_valid: + changed = any(get_history(target, attr).has_changes() for attr in monitored_attributes) + target.pinecone_update_required = changed def to_dict(self) -> Dict[str, Any]: if date := self.date_published: diff --git a/align_data/db/session.py b/align_data/db/session.py index d2557827..30796a3f 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -22,39 +22,33 @@ def make_session(auto_commit=False): session.commit() -def get_pinecone_articles_to_update( +def get_pinecone( session: Session, - custom_sources: List[str], force_update: bool = False, ): - """Yield Pinecone entries that require an update.""" yield from ( session.query(Article) .filter(or_(Article.pinecone_update_required.is_(True), force_update)) .filter(Article.is_valid) - .filter(Article.source.in_(custom_sources)) .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) ) -def get_pinecone_articles_by_ids( +def get_pinecone_from_sources( session: Session, custom_sources: List[str], force_update: bool = False, - hash_ids: List[int] | None = None, ): - """Yield Pinecone entries that require an update and match the given IDs.""" - if hash_ids is None: - hash_ids = [] + yield from get_pinecone(session, force_update).filter(Article.source.in_(custom_sources)) - yield from ( - session.query(Article) - .filter(or_(Article.pinecone_update_required.is_(True), force_update)) - .filter(Article.is_valid) - .filter(Article.source.in_(custom_sources)) - .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) - .filter(Article.id.in_(hash_ids)) - ) + +def get_pinecone_articles_by_ids( + session: Session, 
+ hash_ids: List[int], + force_update: bool = False, +): + """Yield Pinecone entries that require an update and match the given IDs.""" + yield from get_pinecone_from_sources(session, force_update).filter(Article.id.in_(hash_ids)) def get_all_valid_article_ids(session: Session) -> List[str]: diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py index cf1a937d..9864d00a 100644 --- a/align_data/embeddings/pinecone/update_pinecone.py +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -54,7 +54,7 @@ def update_articles_by_ids( """Update the Pinecone entries of specific articles based on their hash_ids.""" with make_session() as session: articles_to_update_stream = get_pinecone_articles_by_ids( - session, custom_sources, force_update, hash_ids + session, hash_ids, custom_sources, force_update ) for batch in self.batch_entries(articles_to_update_stream): self.save_batch(session, batch) diff --git a/align_data/settings.py b/align_data/settings.py index 8e91403a..2f99d285 100644 --- a/align_data/settings.py +++ b/align_data/settings.py @@ -39,7 +39,7 @@ host = os.environ.get("ARD_DB_HOST", "127.0.0.1") port = os.environ.get("ARD_DB_PORT", "3306") db_name = os.environ.get("ARD_DB_NAME", "alignment_research_dataset") -DB_CONNECTION_URI = f"mysql+mysqldb://{user}:{password}@{host}:{port}/{db_name}" +DB_CONNECTION_URI = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{db_name}" ARTICLE_MAIN_KEYS = [ "id", "source", From 847ecc3fcbb0eb36af8e207a0ebc119ebe376796 Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Sun, 27 Aug 2023 04:51:14 -0400 Subject: [PATCH 12/25] refactored check_for_changes --- align_data/db/models.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/align_data/db/models.py b/align_data/db/models.py index 88d0e33f..415162ec 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -77,7 +77,7 @@ def __repr__(self) -> str: def generate_id_string(self) -> bytes: return "".join( - re.sub(r'[^a-zA-Z0-9\s]', '', str(getattr(self, field))).strip().lower() + re.sub(r"[^a-zA-Z0-9\s]", "", str(getattr(self, field))).strip().lower() for field in self.__id_fields ).encode("utf-8") @@ -181,12 +181,13 @@ def before_write(cls, _mapper, _connection, target: "Article"): @classmethod def check_for_changes(cls, mapper, connection, target): + if not target.is_valid: + return monitored_attributes = list(PineconeMetadata.__annotations__.keys()) monitored_attributes.remove("hash_id") - if target.is_valid: - changed = any(get_history(target, attr).has_changes() for attr in monitored_attributes) - target.pinecone_update_required = changed + changed = any(get_history(target, attr).has_changes() for attr in monitored_attributes) + target.pinecone_update_required = changed def to_dict(self) -> Dict[str, Any]: if date := self.date_published: From 6fa3584dd555a6904978e7acc1e83d1f3a5cb2ba Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Mon, 28 Aug 2023 22:04:02 -0400 Subject: [PATCH 13/25] oops nemo was right --- align_data/db/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/align_data/db/models.py b/align_data/db/models.py index 415162ec..ae24650d 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -214,4 +214,5 @@ def to_dict(self) -> Dict[str, Any]: event.listen(Article, "before_insert", Article.before_write) event.listen(Article, "before_update", Article.before_write) +event.listen(Article, "before_insert", 
Article.check_for_changes) event.listen(Article, "before_update", Article.check_for_changes) From f92d6335c5724331d7917282bd1db37574c1959a Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Mon, 28 Aug 2023 22:07:26 -0400 Subject: [PATCH 14/25] added confidence, added summaries to pinecone db, fixed session --- align_data/db/session.py | 13 +++++----- .../embeddings/pinecone/pinecone_models.py | 25 ++++++++++++------- .../embeddings/pinecone/update_pinecone.py | 17 ++++++++----- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/align_data/db/session.py b/align_data/db/session.py index 30796a3f..ead0a029 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -22,11 +22,11 @@ def make_session(auto_commit=False): session.commit() -def get_pinecone( +def get_pinecone_query( session: Session, force_update: bool = False, ): - yield from ( + return ( session.query(Article) .filter(or_(Article.pinecone_update_required.is_(True), force_update)) .filter(Article.is_valid) @@ -34,21 +34,20 @@ def get_pinecone( ) -def get_pinecone_from_sources( +def get_pinecone_from_sources_query( session: Session, custom_sources: List[str], force_update: bool = False, ): - yield from get_pinecone(session, force_update).filter(Article.source.in_(custom_sources)) + return get_pinecone_query(session, force_update).filter(Article.source.in_(custom_sources)) -def get_pinecone_articles_by_ids( +def get_pinecone_articles_by_ids_query( session: Session, hash_ids: List[int], force_update: bool = False, ): - """Yield Pinecone entries that require an update and match the given IDs.""" - yield from get_pinecone_from_sources(session, force_update).filter(Article.id.in_(hash_ids)) + return get_pinecone_from_sources_query(session, force_update).filter(Article.id.in_(hash_ids)) def get_all_valid_article_ids(session: Session) -> List[str]: diff --git a/align_data/embeddings/pinecone/pinecone_models.py b/align_data/embeddings/pinecone/pinecone_models.py index 3c6af194..e813551d 100644 --- a/align_data/embeddings/pinecone/pinecone_models.py +++ b/align_data/embeddings/pinecone/pinecone_models.py @@ -20,6 +20,7 @@ class PineconeMetadata(TypedDict): date_published: float authors: List[str] text: str + confidence: float | None class PineconeEntry(BaseModel): @@ -30,6 +31,7 @@ class PineconeEntry(BaseModel): date_published: float authors: List[str] text_chunks: List[str] + confidence: float | None embeddings: List[List[float] | None] def __init__(self, **data): @@ -62,15 +64,20 @@ def create_pinecone_vectors(self) -> List[Vector]: Vector( id=f"{self.hash_id}_{str(i).zfill(6)}", values=self.embeddings[i], - metadata=PineconeMetadata( - hash_id=self.hash_id, - source=self.source, - title=self.title, - authors=self.authors, - url=self.url, - date_published=self.date_published, - text=self.text_chunks[i], - ), + metadata={ + key: value + for key, value in PineconeMetadata( + hash_id=self.hash_id, + source=self.source, + title=self.title, + authors=self.authors, + url=self.url, + date_published=self.date_published, + text=self.text_chunks[i], + confidence=self.confidence, + ).items() + if value is not None # Filter out keys with None values + }, ) for i in range(self.chunk_num) if self.embeddings[i] # Skips flagged chunks diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py index 9864d00a..a6dcf3d0 100644 --- a/align_data/embeddings/pinecone/update_pinecone.py +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -10,8 
+10,8 @@ from align_data.db.models import Article from align_data.db.session import ( make_session, - get_pinecone_articles_to_update, - get_pinecone_articles_by_ids, + get_pinecone_from_sources_query, + get_pinecone_articles_by_ids_query, ) from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB from align_data.embeddings.pinecone.pinecone_models import ( @@ -42,10 +42,10 @@ def update(self, custom_sources: List[str], force_update: bool = False): :param custom_sources: List of sources to update. """ with make_session() as session: - articles_to_update_stream = get_pinecone_articles_to_update( + articles_to_update_query = get_pinecone_from_sources_query( session, custom_sources, force_update ) - for batch in self.batch_entries(articles_to_update_stream): + for batch in self.batch_entries(articles_to_update_query): self.save_batch(session, batch) def update_articles_by_ids( @@ -53,10 +53,10 @@ def update_articles_by_ids( ): """Update the Pinecone entries of specific articles based on their hash_ids.""" with make_session() as session: - articles_to_update_stream = get_pinecone_articles_by_ids( + articles_to_update_query = get_pinecone_articles_by_ids_query( session, hash_ids, custom_sources, force_update ) - for batch in self.batch_entries(articles_to_update_stream): + for batch in self.batch_entries(articles_to_update_query): self.save_batch(session, batch) def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry]]): @@ -108,6 +108,7 @@ def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None: authors=[author.strip() for author in article.authors.split(",") if author.strip()], text_chunks=text_chunks, embeddings=embeddings, + confidence=article.confidence, ) except ( ValueError, @@ -135,7 +136,11 @@ def get_text_chunks( authors = get_authors_str(authors_lst) signature = f"Title: {title}; Author(s): {authors}." 
+ text_chunks = text_splitter.split_text(article.text) + for summary in article.summaries: + text_chunks += text_splitter.split_text(summary.text) + return [f'###{signature}###\n"""{text_chunk}"""' for text_chunk in text_chunks] From ccefe52986b2645deb1e9e87049a30d6ed449947 Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Mon, 28 Aug 2023 22:27:30 -0400 Subject: [PATCH 15/25] fix session function names --- align_data/db/session.py | 10 +++++----- align_data/embeddings/pinecone/update_pinecone.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/align_data/db/session.py b/align_data/db/session.py index ead0a029..d76fc796 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -22,7 +22,7 @@ def make_session(auto_commit=False): session.commit() -def get_pinecone_query( +def get_pinecone_articles( session: Session, force_update: bool = False, ): @@ -34,20 +34,20 @@ def get_pinecone_query( ) -def get_pinecone_from_sources_query( +def get_pinecone_articles_by_sources( session: Session, custom_sources: List[str], force_update: bool = False, ): - return get_pinecone_query(session, force_update).filter(Article.source.in_(custom_sources)) + return get_pinecone_articles(session, force_update).filter(Article.source.in_(custom_sources)) -def get_pinecone_articles_by_ids_query( +def get_pinecone_articles_by_ids( session: Session, hash_ids: List[int], force_update: bool = False, ): - return get_pinecone_from_sources_query(session, force_update).filter(Article.id.in_(hash_ids)) + return get_pinecone_articles(session, force_update).filter(Article.id.in_(hash_ids)) def get_all_valid_article_ids(session: Session) -> List[str]: diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py index a6dcf3d0..21b23f7e 100644 --- a/align_data/embeddings/pinecone/update_pinecone.py +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -10,8 +10,8 @@ from align_data.db.models import Article from align_data.db.session import ( make_session, - get_pinecone_from_sources_query, - get_pinecone_articles_by_ids_query, + get_pinecone_articles_by_sources, + get_pinecone_articles_by_ids, ) from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB from align_data.embeddings.pinecone.pinecone_models import ( @@ -42,7 +42,7 @@ def update(self, custom_sources: List[str], force_update: bool = False): :param custom_sources: List of sources to update. 
""" with make_session() as session: - articles_to_update_query = get_pinecone_from_sources_query( + articles_to_update_query = get_pinecone_articles_by_sources( session, custom_sources, force_update ) for batch in self.batch_entries(articles_to_update_query): @@ -53,7 +53,7 @@ def update_articles_by_ids( ): """Update the Pinecone entries of specific articles based on their hash_ids.""" with make_session() as session: - articles_to_update_query = get_pinecone_articles_by_ids_query( + articles_to_update_query = get_pinecone_articles_by_ids( session, hash_ids, custom_sources, force_update ) for batch in self.batch_entries(articles_to_update_query): From a40207da514c3feac3ab2abf063f8ecc5cc970fa Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Mon, 28 Aug 2023 22:32:01 -0400 Subject: [PATCH 16/25] changed articles_to_update_query name to articles_to_update --- align_data/embeddings/pinecone/update_pinecone.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py index 21b23f7e..64850433 100644 --- a/align_data/embeddings/pinecone/update_pinecone.py +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -42,10 +42,10 @@ def update(self, custom_sources: List[str], force_update: bool = False): :param custom_sources: List of sources to update. """ with make_session() as session: - articles_to_update_query = get_pinecone_articles_by_sources( + articles_to_update = get_pinecone_articles_by_sources( session, custom_sources, force_update ) - for batch in self.batch_entries(articles_to_update_query): + for batch in self.batch_entries(articles_to_update): self.save_batch(session, batch) def update_articles_by_ids( @@ -53,10 +53,10 @@ def update_articles_by_ids( ): """Update the Pinecone entries of specific articles based on their hash_ids.""" with make_session() as session: - articles_to_update_query = get_pinecone_articles_by_ids( + articles_to_update = get_pinecone_articles_by_ids( session, hash_ids, custom_sources, force_update ) - for batch in self.batch_entries(articles_to_update_query): + for batch in self.batch_entries(articles_to_update): self.save_batch(session, batch) def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry]]): From a457fa89918ae746540f994edbf1c5b76e3c1782 Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Mon, 28 Aug 2023 22:37:36 -0400 Subject: [PATCH 17/25] bug fix --- align_data/embeddings/pinecone/update_pinecone.py | 8 ++------ main.py | 4 +--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py index 64850433..7f662991 100644 --- a/align_data/embeddings/pinecone/update_pinecone.py +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -48,14 +48,10 @@ def update(self, custom_sources: List[str], force_update: bool = False): for batch in self.batch_entries(articles_to_update): self.save_batch(session, batch) - def update_articles_by_ids( - self, custom_sources: List[str], hash_ids: List[int], force_update: bool = False - ): + def update_articles_by_ids(self, hash_ids: List[int], force_update: bool = False): """Update the Pinecone entries of specific articles based on their hash_ids.""" with make_session() as session: - articles_to_update = get_pinecone_articles_by_ids( - session, hash_ids, custom_sources, force_update - ) + articles_to_update = 
get_pinecone_articles_by_ids(session, hash_ids, force_update) for batch in self.batch_entries(articles_to_update): self.save_batch(session, batch) diff --git a/main.py b/main.py index 4e28cd08..82c30f07 100644 --- a/main.py +++ b/main.py @@ -143,9 +143,7 @@ def pinecone_update_individual_articles(self, *hash_ids: str, force_update=False :param str hash_ids: space-separated list of article IDs. """ - names = ALL_DATASETS - - PineconeUpdater().update_articles_by_ids(names, hash_ids, force_update) + PineconeUpdater().update_articles_by_ids(hash_ids, force_update) def train_finetuning_layer(self) -> None: """ From 1238ac8853bbbe26721be6274c7881f9b6cfa981 Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Mon, 28 Aug 2023 22:43:54 -0400 Subject: [PATCH 18/25] fixed typing issues --- align_data/embeddings/finetuning/finetuning_dataset.py | 2 +- align_data/embeddings/pinecone/pinecone_db_handler.py | 4 ++++ align_data/embeddings/pinecone/update_pinecone.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/align_data/embeddings/finetuning/finetuning_dataset.py b/align_data/embeddings/finetuning/finetuning_dataset.py index 8c5eec04..1b588cb9 100644 --- a/align_data/embeddings/finetuning/finetuning_dataset.py +++ b/align_data/embeddings/finetuning/finetuning_dataset.py @@ -20,7 +20,7 @@ class FinetuningDataset(IterableDataset): def __init__(self, num_batches_per_epoch: int, cache_size: int = 1280): self.num_batches_per_epoch = num_batches_per_epoch - self.article_cache = deque(maxlen=cache_size) + self.article_cache: deque = deque(maxlen=cache_size) self.text_splitter = ParagraphSentenceUnitTextSplitter() self.pinecone_db = PineconeDB() diff --git a/align_data/embeddings/pinecone/pinecone_db_handler.py b/align_data/embeddings/pinecone/pinecone_db_handler.py index 8222ba83..b0b09b9f 100644 --- a/align_data/embeddings/pinecone/pinecone_db_handler.py +++ b/align_data/embeddings/pinecone/pinecone_db_handler.py @@ -100,6 +100,10 @@ def query_text( **kwargs, ) -> List[ScoredVector]: query_vector = get_embedding(query)[0] + if query_vector is None: + print("The query is invalid.") + return [] + return self.query_vector( query=query_vector, top_k=top_k, diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py index 7f662991..301b67eb 100644 --- a/align_data/embeddings/pinecone/update_pinecone.py +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -55,7 +55,7 @@ def update_articles_by_ids(self, hash_ids: List[int], force_update: bool = False for batch in self.batch_entries(articles_to_update): self.save_batch(session, batch) - def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry]]): + def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry | None]]): try: for article, pinecone_entry in batch: if pinecone_entry: From f4575c886124b99520927aae4a960cbe0925ed56 Mon Sep 17 00:00:00 2001 From: Daniel O'Connell <github@ahiru.pl> Date: Fri, 1 Sep 2023 15:53:03 +0200 Subject: [PATCH 19/25] renable pinecone updates (#175) --- .github/workflows/update-pinecone.yml | 46 +++++++++++++-------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/.github/workflows/update-pinecone.yml b/.github/workflows/update-pinecone.yml index b87d9ed7..4640e4cb 100644 --- a/.github/workflows/update-pinecone.yml +++ b/.github/workflows/update-pinecone.yml @@ -53,28 +53,28 @@ jobs: - name: Checkout repository uses: actions/checkout@v2 - # - name: Setup Python 
environment - # uses: actions/setup-python@v2 - # with: - # python-version: '3.x' + - name: Setup Python environment + uses: actions/setup-python@v2 + with: + python-version: '3.x' - # - name: Install dependencies - # run: | - # pip install -r requirements.txt; - # python -c 'import nltk; nltk.download("punkt")' + - name: Install dependencies + run: | + pip install -r requirements.txt; + python -c 'import nltk; nltk.download("punkt")' - # - name: Process dataset - # env: - # ARD_DB_USER: ${{ secrets.ARD_DB_USER || inputs.db_user }} - # ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD || inputs.db_password }} - # ARD_DB_HOST: ${{ secrets.ARD_DB_HOST || inputs.db_host }} - # ARD_DB_NAME: alignment_research_dataset - # OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY || inputs.openai_api_key }} - # PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY || inputs.pinecone_api_key }} - # PINECONE_ENVIRONMENT: ${{ secrets.PINECONE_ENVIRONMENT || inputs.pinecone_environment }} - # run: | - # if [ "${{ inputs.datasource }}" = "all" ]; then - # python main.py pinecone_update_all - # else - # python main.py pinecone_update ${{ inputs.datasource }} - # fi + - name: Process dataset + env: + ARD_DB_USER: ${{ secrets.ARD_DB_USER || inputs.db_user }} + ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD || inputs.db_password }} + ARD_DB_HOST: ${{ secrets.ARD_DB_HOST || inputs.db_host }} + ARD_DB_NAME: alignment_research_dataset + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY || inputs.openai_api_key }} + PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY || inputs.pinecone_api_key }} + PINECONE_ENVIRONMENT: ${{ secrets.PINECONE_ENVIRONMENT || inputs.pinecone_environment }} + run: | + if [ "${{ inputs.datasource }}" = "all" ]; then + python main.py pinecone_update_all + else + python main.py pinecone_update ${{ inputs.datasource }} + fi From fd4a5870e4dda41e978726e8e2a9bfc156e4b09e Mon Sep 17 00:00:00 2001 From: Thomas Lemoine <43831409+Thomas-Lemoine@users.noreply.github.com> Date: Fri, 1 Sep 2023 13:39:14 -0400 Subject: [PATCH 20/25] Arbital refactor (#174) * first commit * refactor markdownify_text with summaries * added test for new arbital summary behaviour * minor refactor of parse_arbital_link * added edge cases to parse_arbital_link --- align_data/sources/arbital/arbital.py | 195 +++++++++++++++-------- tests/align_data/sources/test_arbital.py | 29 ++-- 2 files changed, 143 insertions(+), 81 deletions(-) diff --git a/align_data/sources/arbital/arbital.py b/align_data/sources/arbital/arbital.py index ab19aab7..f6087937 100644 --- a/align_data/sources/arbital/arbital.py +++ b/align_data/sources/arbital/arbital.py @@ -1,5 +1,9 @@ import re +from dataclasses import dataclass, field +from datetime import datetime, timezone import logging +from typing import List, Tuple, Iterator, Dict, Union, Any, TypedDict + import requests from datetime import datetime, timezone from dateutil.parser import parse @@ -10,81 +14,132 @@ logger = logging.getLogger(__name__) -def parse_arbital_link(contents): - text = contents[1].split(" ") - url = f"https://arbital.com/p/{text[0]}" - if len(text) > 1: - title = " ".join(text[1:]) - else: - title = url - return f"[{title}]({url})" +class Page(TypedDict, total=False): + text: str + likeableId: str + likeableType: str + title: str + editCreatedAt: str + pageCreatedAt: str + alias: str + userId: str + tagIds: str + changeLogs: List[Dict[str, Any]] -def flatten(val): - if isinstance(val, (list, tuple)): - return [item for i in val for item in flatten(i)] - return [val] +def parse_arbital_link(internal_link: 
str) -> str: + """ + Parses the Arbital internal link. + :param str internal_link: The internal link to parse. + :return: The parsed link. + :rtype: str -def markdownify_text(current, view): - """Recursively parse the text parts in `view` to create a markdown AST from them. + Typical format: `123 Some title` -> `[Some title](https://arbital.com/p/123)` + Special cases: + `toc:` -> `toc:` + `https://www.gwern.net/ Gwern Branwen` -> `[Gwern Branwen](https://www.gwern.net/)` + """ + page_id, *title_parts = internal_link.split(" ") + if not page_id or page_id.startswith("toc:"): + # could be a regular text bracket, ignore it + return internal_link + if page_id.startswith("http"): + # could be a regular link, ignore it + return f"[{' '.join(title_parts)}]({page_id})" + url = f"https://arbital.com/p/{page_id}" + title = " ".join(title_parts) if title_parts else url + return f"[{title}]({url})" - Arbital adds some funky extra stuff to markdown. The known things are: - * "[summary: <contents>]" blocks to add summaries - * "[123 <title>]" are internal links to `<123>` - The `view` parameter should be a generator, so recursive calls can iterate over it without needing - to mess about with indexes etc. +def flatten(val: Union[List[str], Tuple[str], str]) -> List[str]: + """Flattens a nested list.""" + if isinstance(val, (list, tuple)): + return [item for sublist in val for item in flatten(sublist)] + return [val] - :param List[str] current: the list of parsed items. Should generally be passed in as `[]` - :param generator(str, str) view: a generator that returns `part` and `next_part`, where `part` is the current item - and `next_part` is a lookahead - :returns: a tuple of `(<summary string>, <markdown contents>)` +def markdownify_text(current: List[str], view: Iterator[Tuple[str, str]]) -> Tuple[str, str]: + """ + Recursively parse text segments from `view` to generate a markdown Abstract Syntax Tree (AST). + + This function helps in transitioning from Arbital's specific markdown extensions to standard markdown. It specifically + handles two main features: + - "[summary: <contents>]" blocks, which are used in Arbital to add summaries. + - "[123 <title>]" which are Arbital's internal links pointing to https://arbital.com/p/123, with link title <title>. + + Args: + :param List[str] current: A list of parsed items. Should generally be initialized as an empty list. + :param Iterator[Tuple[str, str]] view: An iterator that returns pairs of `part` and `next_part`, where `part` is the + current segment and `next_part` provides a lookahead. + + :return: <summary>, <text>, where <summary> is the summary extracted from the text, and <text> is the text with all + Arbital-specific markdown extensions replaced with standard markdown. + :rtype: Tuple[str, str] + + Example: + From the text: "[summary: A behaviorist [6w genie]]" + We get the input: + current = [] + view = iter([('[', 'summary: A behaviorist '), ('summary: A behaviorist ', '['), ('[', '6w genie'), ('6w genie', ']'), (']', ']'), (']', None)]) + The function should return: + `('A behaviorist [genie](https://arbital.com/p/6w)', '')` + + Note: + This function assumes that `view` provides a valid Arbital markdown sequence. Malformed sequences might lead to + unexpected results. 
""" in_link = False + summary = "" for part, next_part in view: if part == "[": # Recursively try to parse this new section - it's probably a link, but can be something else - current.append(markdownify_text([part], view)) - elif part == "]" and next_part == "(": - # mark that it's now in the url part of a markdown link - current.append("]") - in_link = True + sub_summary, text = markdownify_text([part], view) + summary += sub_summary + "\n\n" + current.append(text) + elif part == "]": - # this is the arbital summary - just join it for now, but it'll have to be handled later - if current[1].startswith("summary"): - return "".join(current[1:]) - # if this was a TODO section, then ignore it - if current[1].startswith("todo"): - return "" - # Otherwise it's an arbital link - return parse_arbital_link(current) + if next_part == "(": + # Indicate that it's in the URL part of a markdown link. + current.append(part) + in_link = True + else: + # Extract the descriptor, which might be a summary tag, TODO tag, or an Arbital internal link's "<page_id> <title>". + descriptor = current[1] + + # Handle Arbital summary. + if descriptor.startswith("summary"): + summary_tag, summary_content = "".join(current[1:]).split(":", 1) + return f"{summary_tag}: {summary_content.strip()}", "" + + # Handle TODO section (ignore it). + if descriptor.startswith("todo"): + return "", "" + + # Handle Arbital link (e.g., "6w genie" -> "[6w genie](https://arbital.com/p/6w)"). + return "", parse_arbital_link(descriptor) + elif in_link and part == ")": # this is the end of a markdown link - just join the contents, as they're already correct - return "".join(current + [part]) + return "", "".join(current + [part]) + elif in_link and current[-1] == "(" and next_part != ")": # This link is strange... looks like it could be malformed? # Assuming that it's malformed and missing a closing `)` # This will remove any additional info in the link, but that seems a reasonable price? 
words = part.split(" ") - return "".join(current + [words[0], ") ", " ".join(words[1:])]) + return "", "".join(current + [words[0], ") ", " ".join(words[1:])]) + else: # Just your basic text - add it to the processed parts and go on your merry way current.append(part) - # Check if the first item is the summary - if so, extract it - summary = "" - if current[0].startswith("summary"): - _, summary = re.split(r"summary[()\w]*:", current[0], 1) - current = current[1:] - # Otherwise just join all the parts back together return summary.strip(), "".join(flatten(current)).strip() -def extract_text(text): +def extract_text(text: str) -> Tuple[str, str]: parts = [i for i in re.split(r"([\[\]()])", text) if i] return markdownify_text([], zip(parts, parts[1:] + [None])) @@ -106,10 +161,10 @@ class Arbital(AlignmentDataset): "sec-fetch-dest": "empty", "accept-language": "en-US,en;q=0.9", } - titles_map = {} + titles_map: Dict[str, str] = field(default_factory=dict) @property - def items_list(self): + def items_list(self) -> List[str]: logger.info("Getting page aliases") items = [ alias @@ -122,7 +177,7 @@ def items_list(self): def get_item_key(self, item: str) -> str: return item - def process_entry(self, alias): + def process_entry(self, alias: str): try: page = self.get_page(alias) summary, text = extract_text(page["text"]) @@ -144,33 +199,37 @@ def process_entry(self, alias): except Exception as e: logger.error(f"Error getting page {alias}: {e}") return None - - def get_arbital_page_aliases(self, subspace): + + def send_post_request(self, url: str, page_alias: str, referer_base: str) -> requests.Response: headers = self.headers.copy() - headers["referer"] = f"https://arbital.com/explore/{subspace}/" - data = f'{{"pageAlias":"{subspace}"}}' - response = requests.post( - "https://arbital.com/json/explore/", headers=headers, data=data - ).json() - return list(response["pages"].keys()) + headers['referer'] = f"{referer_base}{page_alias}/" + data = f'{{"pageAlias":"{page_alias}"}}' + return requests.post(url, headers=headers, data=data) + + def get_arbital_page_aliases(self, subspace: str) -> List[str]: + response = self.send_post_request( + url='https://arbital.com/json/explore/', + page_alias=subspace, + referer_base='https://arbital.com/explore/' + ) + return list(response.json()['pages'].keys()) + + def get_page(self, alias: str) -> Page: + response = self.send_post_request( + url='https://arbital.com/json/primaryPage/', + page_alias=alias, + referer_base='https://arbital.com/p/' + ) + return response.json()['pages'][alias] @staticmethod - def _get_published_date(page): + def _get_published_date(page: Page) -> datetime | None: date_published = page.get("editCreatedAt") or page.get("pageCreatedAt") if date_published: return parse(date_published).astimezone(timezone.utc) return None - def get_page(self, alias): - headers = self.headers.copy() - headers["referer"] = "https://arbital.com/" - data = f'{{"pageAlias":"{alias}"}}' - response = requests.post( - "https://arbital.com/json/primaryPage/", headers=headers, data=data - ) - return response.json()["pages"][alias] - - def get_title(self, itemId): + def get_title(self, itemId: str) -> str | None: if title := self.titles_map.get(itemId): return title @@ -186,7 +245,7 @@ def get_title(self, itemId): return title return None - def extract_authors(self, page): + def extract_authors(self, page: Page) -> List[str]: """Get all authors of this page. This will work faster the more its used, as it only fetches info for authors it hasn't yet seen. 
diff --git a/tests/align_data/sources/test_arbital.py b/tests/align_data/sources/test_arbital.py index af65ed05..19ad8e97 100644 --- a/tests/align_data/sources/test_arbital.py +++ b/tests/align_data/sources/test_arbital.py @@ -15,12 +15,14 @@ @pytest.mark.parametrize( "contents, expected", ( - (["[", "123"], "[https://arbital.com/p/123](https://arbital.com/p/123)"), - (["[", "123 Some title"], "[Some title](https://arbital.com/p/123)"), + ("123", "[https://arbital.com/p/123](https://arbital.com/p/123)"), + ("123 Some title", "[Some title](https://arbital.com/p/123)"), ( - ["[", "123 Some title with multiple words"], + "123 Some title with multiple words", "[Some title with multiple words](https://arbital.com/p/123)", ), + ("https://www.gwern.net/ Gwern Branwen", "[Gwern Branwen](https://www.gwern.net/)"), + ("toc:", "toc:"), # `toc:` is a mysterious thing ), ) def test_parse_arbital_link(contents, expected): @@ -84,37 +86,38 @@ def test_markdownify_text_contents_arbital_markdown(text, expected): ( ( "[summary: summaries should be extracted] bla bla bla", - "summaries should be extracted", + ("summary: summaries should be extracted", "bla bla bla"), ), ( "[summary: \n whitespace should be stripped \n] bla bla bla", - "whitespace should be stripped", + ("summary: whitespace should be stripped", "bla bla bla"), ), ( "[summary(Bold): special summaries should be extracted] bla bla bla", - "special summaries should be extracted", + ("summary(Bold): special summaries should be extracted", "bla bla bla"), ), ( "[summary(Markdown): special summaries should be extracted] bla bla bla", - "special summaries should be extracted", + ("summary(Markdown): special summaries should be extracted", "bla bla bla"), ), ( "[summary(BLEEEE): special summaries should be extracted] bla bla bla", - "special summaries should be extracted", + ("summary(BLEEEE): special summaries should be extracted", "bla bla bla"), ), ( "[summary: markdown is handled: [bla](https://bla.bla)] bla bla bla", - "markdown is handled: [bla](https://bla.bla)", + ("summary: markdown is handled: [bla](https://bla.bla)", "bla bla bla"), ), ( "[summary: markdown is handled: [123 ble ble]] bla bla bla", - "markdown is handled: [ble ble](https://arbital.com/p/123)", + ("summary: markdown is handled: [ble ble](https://arbital.com/p/123)", "bla bla bla"), ), ), ) -def test_markdownify_text_summary(text, expected): - summary, _ = extract_text(text) - assert summary == expected +def test_markdownify_text_summary_and_content(text, expected): + summary, text = extract_text(text) + assert summary == expected[0] + assert text == expected[1] @pytest.fixture From 68e32b8a07146d5c8ff4869cf39e59bcc2e716ab Mon Sep 17 00:00:00 2001 From: Henri Lemoine <henri123lemoine@gmail.com> Date: Fri, 1 Sep 2023 13:51:00 -0400 Subject: [PATCH 21/25] fix 2048+ batch embedding --- align_data/embeddings/embedding_utils.py | 50 +++++++++++++----------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/align_data/embeddings/embedding_utils.py b/align_data/embeddings/embedding_utils.py index cf98228c..ff386325 100644 --- a/align_data/embeddings/embedding_utils.py +++ b/align_data/embeddings/embedding_utils.py @@ -88,34 +88,39 @@ def wrapper(*args, **kwargs): @handle_openai_errors -def moderation_check(texts: List[str], max_texts_num: int = 32) -> List[ModerationInfoType]: - """ - Check moderation on a list of texts. 
+def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]: + """Process a batch for moderation checks.""" + return openai.Moderation.create(input=batch)["results"] - Parameters: - - texts (List[str]): List of texts to be checked for moderation. - - max_texts_num (int): Number of texts to check at once. Defaults to 32. - Returns: - - List[ModerationInfoType]: List of moderation results for the provided texts. - """ - total_texts = len(texts) - results = [] +def moderation_check(texts: List[str], max_texts_num: int = 32) -> List[ModerationInfoType]: + """Batch moderation checks on list of texts.""" + return [ + result + for batch in (texts[i : i + max_texts_num] for i in range(0, len(texts), max_texts_num)) + for result in _single_batch_moderation_check(batch) + ] - for i in range(0, total_texts, max_texts_num): - batch_texts = texts[i : i + max_texts_num] - batch_results = openai.Moderation.create(input=batch_texts)["results"] - results.extend(batch_results) - return results +@handle_openai_errors +def _single_batch_compute_openai_embeddings(batch: List[str], **kwargs) -> List[List[float]]: + """Compute embeddings for a batch.""" + batch_data = openai.Embedding.create(input=batch, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data + return [d["embedding"] for d in batch_data] -@handle_openai_errors -def _compute_openai_embeddings(non_flagged_texts: List[str], **kwargs) -> List[List[float]]: - data = openai.Embedding.create( - input=non_flagged_texts, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs - ).data - return [d["embedding"] for d in data] +def _compute_openai_embeddings( + non_flagged_texts: List[str], max_texts_num: int = 2048, **kwargs +) -> List[List[float]]: + """Batch computation of embeddings for non-flagged texts.""" + return [ + embedding + for batch in ( + non_flagged_texts[i : i + max_texts_num] + for i in range(0, len(non_flagged_texts), max_texts_num) + ) + for embedding in _single_batch_compute_openai_embeddings(batch, **kwargs) + ] def get_embeddings_without_moderation( @@ -193,7 +198,6 @@ def get_embeddings( Returns: - Tuple[List[Optional[List[float]]], ModerationInfoListType]: Tuple containing the list of embeddings (with None for flagged texts) and the moderation results. """ - assert len(texts) <= 2048, "The batch size should not be larger than 2048." assert all(texts), "No empty strings allowed in the input list." 
# replace newlines, which can negatively affect performance From 0bcf0d903c8e6c399a2ed679c089bdc565dcc5cc Mon Sep 17 00:00:00 2001 From: Thomas Lemoine <43831409+Thomas-Lemoine@users.noreply.github.com> Date: Sat, 2 Sep 2023 17:55:49 -0400 Subject: [PATCH 22/25] Arbital many summaries (#179) * first commit * refactor markdownify_text with summaries * added test for new arbital summary behaviour * minor refactor of parse_arbital_link * added edge cases to parse_arbital_link * summaries optional key in data_entry added * arbital now uses a list of summaries instead of appending many summaries together --- align_data/common/alignment_dataset.py | 6 ++++- align_data/sources/arbital/arbital.py | 32 ++++++++++++------------ tests/align_data/sources/test_arbital.py | 26 ++++++++++++------- 3 files changed, 38 insertions(+), 26 deletions(-) diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 7e461662..67b4bfe5 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -71,7 +71,11 @@ def _add_authors(self, article: Article, authors: List[str]) -> Article: def make_data_entry(self, data, **kwargs) -> Article: data = merge_dicts(data, kwargs) + + summaries = data.pop("summaries", []) summary = data.pop("summary", None) + summaries += [summary] if summary else [] + authors = data.pop("authors", []) data['title'] = (data.get('title') or '').replace('\n', ' ').replace('\r', '') or None @@ -80,7 +84,7 @@ def make_data_entry(self, data, **kwargs) -> Article: **{k: v for k, v in data.items() if k in ARTICLE_MAIN_KEYS}, ) self._add_authors(article, authors) - if summary: + for summary in summaries: # Note: This will be skipped if summaries is empty article.summaries.append(Summary(text=summary, source=self.name)) return article diff --git a/align_data/sources/arbital/arbital.py b/align_data/sources/arbital/arbital.py index f6087937..b08393c4 100644 --- a/align_data/sources/arbital/arbital.py +++ b/align_data/sources/arbital/arbital.py @@ -59,7 +59,7 @@ def flatten(val: Union[List[str], Tuple[str], str]) -> List[str]: return [val] -def markdownify_text(current: List[str], view: Iterator[Tuple[str, str]]) -> Tuple[str, str]: +def markdownify_text(current: List[str], view: Iterator[Tuple[str, str]]) -> Tuple[List[str], str]: """ Recursively parse text segments from `view` to generate a markdown Abstract Syntax Tree (AST). @@ -73,9 +73,9 @@ def markdownify_text(current: List[str], view: Iterator[Tuple[str, str]]) -> Tup :param Iterator[Tuple[str, str]] view: An iterator that returns pairs of `part` and `next_part`, where `part` is the current segment and `next_part` provides a lookahead. - :return: <summary>, <text>, where <summary> is the summary extracted from the text, and <text> is the text with all + :return: <summaries>, <text>, where <summaries> are the summaries extracted from the text, and <text> is the text with all Arbital-specific markdown extensions replaced with standard markdown. 
- :rtype: Tuple[str, str] + :rtype: Tuple[List[str], str] Example: From the text: "[summary: A behaviorist [6w genie]]" @@ -83,20 +83,20 @@ def markdownify_text(current: List[str], view: Iterator[Tuple[str, str]]) -> Tup current = [] view = iter([('[', 'summary: A behaviorist '), ('summary: A behaviorist ', '['), ('[', '6w genie'), ('6w genie', ']'), (']', ']'), (']', None)]) The function should return: - `('A behaviorist [genie](https://arbital.com/p/6w)', '')` + `(['A behaviorist [genie](https://arbital.com/p/6w)'], '')` Note: This function assumes that `view` provides a valid Arbital markdown sequence. Malformed sequences might lead to unexpected results. """ in_link = False - summary = "" + summaries = [] for part, next_part in view: if part == "[": # Recursively try to parse this new section - it's probably a link, but can be something else - sub_summary, text = markdownify_text([part], view) - summary += sub_summary + "\n\n" + sub_summaries, text = markdownify_text([part], view) + summaries.extend(sub_summaries) current.append(text) elif part == "]": @@ -110,33 +110,34 @@ def markdownify_text(current: List[str], view: Iterator[Tuple[str, str]]) -> Tup # Handle Arbital summary. if descriptor.startswith("summary"): + # descriptor will be something like "summary(Technical): <contents>", so we split by `:` summary_tag, summary_content = "".join(current[1:]).split(":", 1) - return f"{summary_tag}: {summary_content.strip()}", "" + return [f"{summary_tag}: {summary_content.strip()}"], "" # Handle TODO section (ignore it). if descriptor.startswith("todo"): - return "", "" + return [], "" # Handle Arbital link (e.g., "6w genie" -> "[6w genie](https://arbital.com/p/6w)"). - return "", parse_arbital_link(descriptor) + return [], parse_arbital_link(descriptor) elif in_link and part == ")": # this is the end of a markdown link - just join the contents, as they're already correct - return "", "".join(current + [part]) + return [], "".join(current + [part]) elif in_link and current[-1] == "(" and next_part != ")": # This link is strange... looks like it could be malformed? # Assuming that it's malformed and missing a closing `)` # This will remove any additional info in the link, but that seems a reasonable price? 
words = part.split(" ") - return "", "".join(current + [words[0], ") ", " ".join(words[1:])]) + return [], "".join(current + [words[0], ") ", " ".join(words[1:])]) else: # Just your basic text - add it to the processed parts and go on your merry way current.append(part) # Otherwise just join all the parts back together - return summary.strip(), "".join(flatten(current)).strip() + return summaries, "".join(flatten(current)).strip() def extract_text(text: str) -> Tuple[str, str]: @@ -146,7 +147,6 @@ def extract_text(text: str) -> Tuple[str, str]: @dataclass class Arbital(AlignmentDataset): - summary_key: str = "summary" ARBITAL_SUBSPACES = ["ai_alignment", "math", "rationality"] done_key = "alias" @@ -180,7 +180,7 @@ def get_item_key(self, item: str) -> str: def process_entry(self, alias: str): try: page = self.get_page(alias) - summary, text = extract_text(page["text"]) + summaries, text = extract_text(page["text"]) return self.make_data_entry( { @@ -193,7 +193,7 @@ def process_entry(self, alias: str): "authors": self.extract_authors(page), "alias": alias, "tags": list(filter(None, map(self.get_title, page["tagIds"]))), - "summary": summary, + "summaries": summaries, } ) except Exception as e: diff --git a/tests/align_data/sources/test_arbital.py b/tests/align_data/sources/test_arbital.py index 19ad8e97..87fed8cf 100644 --- a/tests/align_data/sources/test_arbital.py +++ b/tests/align_data/sources/test_arbital.py @@ -86,37 +86,45 @@ def test_markdownify_text_contents_arbital_markdown(text, expected): ( ( "[summary: summaries should be extracted] bla bla bla", - ("summary: summaries should be extracted", "bla bla bla"), + (["summary: summaries should be extracted"], "bla bla bla"), + ), + ( + "[summary: summaries should be extracted] [summary(Technical): technical summary should be handled separately] bla bla bla", + (["summary: summaries should be extracted", "summary(Technical): technical summary should be handled separately"], "bla bla bla"), + ), + ( + "[summary: summaries should be extracted] bla bla bla [summary(Technical): summaries should work in the middle too] bla bla bla", + (["summary: summaries should be extracted", "summary(Technical): summaries should work in the middle too"], "bla bla bla bla bla bla"), ), ( "[summary: \n whitespace should be stripped \n] bla bla bla", - ("summary: whitespace should be stripped", "bla bla bla"), + (["summary: whitespace should be stripped"], "bla bla bla"), ), ( "[summary(Bold): special summaries should be extracted] bla bla bla", - ("summary(Bold): special summaries should be extracted", "bla bla bla"), + (["summary(Bold): special summaries should be extracted"], "bla bla bla"), ), ( "[summary(Markdown): special summaries should be extracted] bla bla bla", - ("summary(Markdown): special summaries should be extracted", "bla bla bla"), + (["summary(Markdown): special summaries should be extracted"], "bla bla bla"), ), ( "[summary(BLEEEE): special summaries should be extracted] bla bla bla", - ("summary(BLEEEE): special summaries should be extracted", "bla bla bla"), + (["summary(BLEEEE): special summaries should be extracted"], "bla bla bla"), ), ( "[summary: markdown is handled: [bla](https://bla.bla)] bla bla bla", - ("summary: markdown is handled: [bla](https://bla.bla)", "bla bla bla"), + (["summary: markdown is handled: [bla](https://bla.bla)"], "bla bla bla"), ), ( "[summary: markdown is handled: [123 ble ble]] bla bla bla", - ("summary: markdown is handled: [ble ble](https://arbital.com/p/123)", "bla bla bla"), + (["summary: markdown is 
handled: [ble ble](https://arbital.com/p/123)"], "bla bla bla"), ), ), ) def test_markdownify_text_summary_and_content(text, expected): - summary, text = extract_text(text) - assert summary == expected[0] + summaries, text = extract_text(text) + assert summaries == expected[0] assert text == expected[1] From d9232a478a0af60f4bf57b347b2be0a36654f5fc Mon Sep 17 00:00:00 2001 From: Thomas Lemoine <43831409+Thomas-Lemoine@users.noreply.github.com> Date: Sun, 3 Sep 2023 16:15:12 -0400 Subject: [PATCH 23/25] Adding back summaries entry key (#178) * first commit * refactor markdownify_text with summaries * added test for new arbital summary behaviour * minor refactor of parse_arbital_link * added edge cases to parse_arbital_link * summaries optional key in data_entry added * nit --- align_data/common/alignment_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 67b4bfe5..344ee89d 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -84,8 +84,7 @@ def make_data_entry(self, data, **kwargs) -> Article: **{k: v for k, v in data.items() if k in ARTICLE_MAIN_KEYS}, ) self._add_authors(article, authors) - for summary in summaries: # Note: This will be skipped if summaries is empty - article.summaries.append(Summary(text=summary, source=self.name)) + article.summaries += [Summary(text=summary, source=self.name) for summary in summaries] return article def to_jsonl(self, out_path=None, filename=None) -> Path: From bbd6d81cc75372548629b1cf45f54ca34923ac2c Mon Sep 17 00:00:00 2001 From: Daniel O'Connell <github@ahiru.pl> Date: Mon, 4 Sep 2023 15:08:35 +0200 Subject: [PATCH 24/25] Update pinecone (#181) * More pinecone statuses * Add last checked column * Remove invalid items from pinecone --- align_data/db/models.py | 15 ++- align_data/db/session.py | 40 +++++- .../embeddings/pinecone/update_pinecone.py | 114 +++++++++++++++--- .../versions/1866340e456a_pinecone_status.py | 54 +++++++++ ...71e3_added_pinecone_update_required_to_.py | 4 - 5 files changed, 194 insertions(+), 33 deletions(-) create mode 100644 migrations/versions/1866340e456a_pinecone_status.py diff --git a/align_data/db/models.py b/align_data/db/models.py index ae24650d..e79da232 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -1,3 +1,4 @@ +import enum import re import json import re @@ -10,6 +11,7 @@ from sqlalchemy import ( JSON, DateTime, + Enum, ForeignKey, String, Boolean, @@ -43,6 +45,13 @@ class Summary(Base): article: Mapped["Article"] = relationship(back_populates="summaries") +class PineconeStatus(enum.Enum): + absent = 1 + pending_removal = 2 + pending_addition = 3 + added = 4 + + class Article(Base): __tablename__ = "articles" @@ -66,7 +75,7 @@ class Article(Base): status: Mapped[Optional[str]] = mapped_column(String(256)) comments: Mapped[Optional[str]] = mapped_column(LONGTEXT) # Editor comments. 
Can be anything - pinecone_update_required: Mapped[bool] = mapped_column(Boolean, default=False) + pinecone_status: Mapped[PineconeStatus] = mapped_column(Enum(PineconeStatus), default=PineconeStatus.absent) summaries: Mapped[List["Summary"]] = relationship( back_populates="article", cascade="all, delete-orphan" @@ -186,8 +195,8 @@ def check_for_changes(cls, mapper, connection, target): monitored_attributes = list(PineconeMetadata.__annotations__.keys()) monitored_attributes.remove("hash_id") - changed = any(get_history(target, attr).has_changes() for attr in monitored_attributes) - target.pinecone_update_required = changed + if any(get_history(target, attr).has_changes() for attr in monitored_attributes): + target.pinecone_status = PineconeStatus.pending_addition def to_dict(self) -> Dict[str, Any]: if date := self.date_published: diff --git a/align_data/db/session.py b/align_data/db/session.py index d76fc796..331de9b7 100644 --- a/align_data/db/session.py +++ b/align_data/db/session.py @@ -5,7 +5,7 @@ from sqlalchemy import create_engine, or_ from sqlalchemy.orm import Session from align_data.settings import DB_CONNECTION_URI, MIN_CONFIDENCE -from align_data.db.models import Article +from align_data.db.models import Article, PineconeStatus logger = logging.getLogger(__name__) @@ -25,12 +25,24 @@ def make_session(auto_commit=False): def get_pinecone_articles( session: Session, force_update: bool = False, + statuses: List[PineconeStatus] = [PineconeStatus.pending_addition], ): return ( session.query(Article) - .filter(or_(Article.pinecone_update_required.is_(True), force_update)) + .filter(or_(Article.pinecone_status.in_(statuses), force_update)) .filter(Article.is_valid) - .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE)) + .filter(or_(Article.confidence == None, Article.confidence >= MIN_CONFIDENCE)) + ) + + +def get_pinecone_articles_to_remove(session: Session): + return ( + session.query(Article) + .filter(or_( + Article.pinecone_status == PineconeStatus.pending_removal, + Article.is_valid == False, + Article.confidence < MIN_CONFIDENCE + )) ) @@ -38,16 +50,18 @@ def get_pinecone_articles_by_sources( session: Session, custom_sources: List[str], force_update: bool = False, + statuses: List[PineconeStatus] = [PineconeStatus.pending_addition], ): - return get_pinecone_articles(session, force_update).filter(Article.source.in_(custom_sources)) + return get_pinecone_articles(session, force_update, statuses).filter(Article.source.in_(custom_sources)) def get_pinecone_articles_by_ids( session: Session, - hash_ids: List[int], + hash_ids: List[str], force_update: bool = False, + statuses: List[PineconeStatus] = [PineconeStatus.pending_addition], ): - return get_pinecone_articles(session, force_update).filter(Article.id.in_(hash_ids)) + return get_pinecone_articles(session, force_update, statuses).filter(Article.id.in_(hash_ids)) def get_all_valid_article_ids(session: Session) -> List[str]: @@ -59,3 +73,17 @@ def get_all_valid_article_ids(session: Session) -> List[str]: .all() ) return [item[0] for item in query_result] + + +def get_pinecone_to_delete_by_sources( + session: Session, + custom_sources: List[str], +): + return get_pinecone_articles_to_remove(session).filter(Article.source.in_(custom_sources)) + + +def get_pinecone_to_delete_by_ids( + session: Session, + hash_ids: List[str], +): + return get_pinecone_articles_to_remove(session).filter(Article.id.in_(hash_ids)) diff --git a/align_data/embeddings/pinecone/update_pinecone.py 
b/align_data/embeddings/pinecone/update_pinecone.py index 301b67eb..321fe6b1 100644 --- a/align_data/embeddings/pinecone/update_pinecone.py +++ b/align_data/embeddings/pinecone/update_pinecone.py @@ -1,17 +1,18 @@ -from datetime import datetime import logging from itertools import islice -from typing import Callable, List, Tuple, Generator, Iterator, Optional +from typing import Any, Callable, Iterable, List, Tuple, Generator, Iterator from sqlalchemy.orm import Session from pydantic import ValidationError from align_data.embeddings.embedding_utils import get_embeddings -from align_data.db.models import Article +from align_data.db.models import Article, PineconeStatus from align_data.db.session import ( make_session, get_pinecone_articles_by_sources, get_pinecone_articles_by_ids, + get_pinecone_to_delete_by_sources, + get_pinecone_to_delete_by_ids, ) from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB from align_data.embeddings.pinecone.pinecone_models import ( @@ -30,10 +31,17 @@ TruncateFunctionType = Callable[[str, int], str] -class PineconeUpdater: - def __init__(self): - self.text_splitter = ParagraphSentenceUnitTextSplitter() - self.pinecone_db = PineconeDB() +class PineconeAction: + batch_size = 10 + + def __init__(self, pinecone=None): + self.pinecone_db = pinecone or PineconeDB() + + def _articles_by_source(self, session: Session, sources: List[str], force_update: bool) -> Iterable[Article]: + raise NotImplementedError + + def _articles_by_id(self, session: Session, ids: List[str], force_update: bool) -> Iterable[Article]: + raise NotImplementedError def update(self, custom_sources: List[str], force_update: bool = False): """ @@ -42,28 +50,24 @@ def update(self, custom_sources: List[str], force_update: bool = False): :param custom_sources: List of sources to update. 
""" with make_session() as session: - articles_to_update = get_pinecone_articles_by_sources( - session, custom_sources, force_update - ) + articles_to_update = self._articles_by_source(session, custom_sources, force_update) + logger.info('Processing %s items', articles_to_update.count()) for batch in self.batch_entries(articles_to_update): self.save_batch(session, batch) def update_articles_by_ids(self, hash_ids: List[int], force_update: bool = False): """Update the Pinecone entries of specific articles based on their hash_ids.""" with make_session() as session: - articles_to_update = get_pinecone_articles_by_ids(session, hash_ids, force_update) + articles_to_update = self._articles_by_id(session, hash_ids, force_update) for batch in self.batch_entries(articles_to_update): self.save_batch(session, batch) - def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry | None]]): - try: - for article, pinecone_entry in batch: - if pinecone_entry: - self.pinecone_db.upsert_entry(pinecone_entry) - - article.pinecone_update_required = False - session.add(article) + def process_batch(self, batch: List[Tuple[Article, PineconeEntry | None]]) -> List[Article]: + raise NotImplementedError + def save_batch(self, session: Session, batch: List[Any]): + try: + session.add_all(self.process_batch(batch)) session.commit() except Exception as e: @@ -71,10 +75,36 @@ def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry logger.error(e) session.rollback() + def batch_entries(self, article_stream: Generator[Article, None, None]) -> Iterator[List[Article]]: + while batch := tuple(islice(article_stream, self.batch_size)): + yield list(batch) + + +class PineconeAdder(PineconeAction): + batch_size = 10 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.text_splitter = ParagraphSentenceUnitTextSplitter() + + def _articles_by_source(self, session, sources: List[str], force_update: bool): + return get_pinecone_articles_by_sources(session, sources, force_update) + + def _articles_by_id(self, session, ids: List[str], force_update: bool): + return get_pinecone_articles_by_ids(session, ids, force_update) + + def process_batch(self, batch: List[Tuple[Article, PineconeEntry | None]]): + for article, pinecone_entry in batch: + if pinecone_entry: + self.pinecone_db.upsert_entry(pinecone_entry) + + article.pinecone_status = PineconeStatus.added + return [a for a, _ in batch] + def batch_entries( self, article_stream: Generator[Article, None, None] ) -> Iterator[List[Tuple[Article, PineconeEntry | None]]]: - while batch := tuple(islice(article_stream, 10)): + while batch := tuple(islice(article_stream, self.batch_size)): yield [(article, self._make_pinecone_entry(article)) for article in batch] def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None: @@ -123,6 +153,50 @@ def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None: raise +class PineconeDeleter(PineconeAction): + batch_size = 100 + pinecone_statuses = [PineconeStatus.pending_removal] + + def _articles_by_source(self, session, sources: List[str], _force_update: bool): + return get_pinecone_to_delete_by_sources(session, sources) + + def _articles_by_id(self, session, ids: List[str], _force_update: bool): + return get_pinecone_to_delete_by_ids(session, ids) + + def process_batch(self, batch: List[Article]): + self.pinecone_db.delete_entries([a.id for a in batch]) + logger.info('removing batch %s', len(batch)) + for article in batch: + article.pinecone_status = 
PineconeStatus.removed + return batch + + +class PineconeUpdater(PineconeAction): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.adder = PineconeAdder(*args, pinecone=self.pinecone_db, **kwargs) + self.remover = PineconeDeleter(*args, pinecone=self.pinecone_db, **kwargs) + + def update(self, custom_sources: List[str], force_update: bool = False): + """ + Update the given sources. If no sources are provided, updates all sources. + + :param custom_sources: List of sources to update. + """ + logger.info('Adding pinecone entries for %s', custom_sources) + self.adder.update(custom_sources, force_update) + + logger.info('Removing pinecone entries for %s', custom_sources) + self.remover.update(custom_sources, force_update) + + def update_articles_by_ids(self, hash_ids: List[int], force_update: bool = False): + """Update the Pinecone entries of specific articles based on their hash_ids.""" + logger.info('Adding pinecone entries by hash_id') + self.adder.update_articles_by_ids(hash_ids, force_update) + logger.info('Removing pinecone entries by hash_id') + self.remover.update_articles_by_ids(hash_ids, force_update) + + def get_text_chunks( article: Article, text_splitter: ParagraphSentenceUnitTextSplitter ) -> List[str]: diff --git a/migrations/versions/1866340e456a_pinecone_status.py b/migrations/versions/1866340e456a_pinecone_status.py new file mode 100644 index 00000000..8a73964a --- /dev/null +++ b/migrations/versions/1866340e456a_pinecone_status.py @@ -0,0 +1,54 @@ +"""pinecone status + +Revision ID: 1866340e456a +Revises: f5a2bcfa6b2c +Create Date: 2023-09-03 15:34:02.755588 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + + +# revision identifiers, used by Alembic. +revision = '1866340e456a' +down_revision = 'f5a2bcfa6b2c' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + ## Set the pinecone status + op.add_column('articles', sa.Column('pinecone_status', sa.String(length=32), nullable=False)) + + IS_VALID = """( + articles.status IS NULL AND + articles.text IS NOT NULL AND + articles.url IS NOT NULL AND + articles.title IS NOT NULL AND + articles.authors IS NOT NULL + )""" + op.execute(f""" + UPDATE articles SET pinecone_status = 'absent' + WHERE NOT articles.pinecone_update_required AND NOT {IS_VALID} + """) + op.execute(f""" + UPDATE articles SET pinecone_status = 'pending_removal' + WHERE articles.pinecone_update_required AND NOT {IS_VALID} + """) + op.execute(f""" + UPDATE articles SET pinecone_status = 'pending_addition' + WHERE articles.pinecone_update_required AND {IS_VALID} + """) + op.execute(f""" + UPDATE articles SET pinecone_status = 'added' + WHERE NOT articles.pinecone_update_required AND {IS_VALID} + """) + + op.drop_column('articles', 'pinecone_update_required') + + +def downgrade() -> None: + op.add_column("articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False)) + op.execute("UPDATE articles SET articles.pinecone_update_required = (pinecone_status = 'pending_addition')") + op.drop_column('articles', 'pinecone_status') diff --git a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py index e5b9a303..92a63bb6 100644 --- a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py +++ b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py @@ -17,12 +17,8 @@ def upgrade() -> None: - # ### commands auto generated by Alembic - please 
adjust! ### op.add_column("articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False)) - # ### end Alembic commands ### def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.drop_column("articles", "pinecone_update_required") - # ### end Alembic commands ### From af6f5777be031eb9009bbcc5fe2de87aedb6f8df Mon Sep 17 00:00:00 2001 From: Daniel O'Connell <github@ahiru.pl> Date: Thu, 7 Sep 2023 00:23:10 +0200 Subject: [PATCH 25/25] Validate MIN_CONFIDENCE (#180) --- .env.example | 7 ++++++- align_data/settings.py | 5 ++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/.env.example b/.env.example index 1c29520b..759b915c 100644 --- a/.env.example +++ b/.env.example @@ -1,12 +1,17 @@ LOG_LEVEL="INFO" +MIN_CONFIDENCE="0.5" + CODA_TOKEN="" +YOUTUBE_API_KEY="" + ARD_DB_USER="user" ARD_DB_PASSWORD="we all live in a yellow submarine" ARD_DB_HOST="127.0.0.1" ARD_DB_PORT="3306" ARD_DB_NAME="alignment_research_dataset" + OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + PINECONE_INDEX_NAME="stampy-chat-ard" PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" PINECONE_ENVIRONMENT="xx-xxxxx-gcp" -YOUTUBE_API_KEY="" \ No newline at end of file diff --git a/align_data/settings.py b/align_data/settings.py index 2f99d285..6f747d8f 100644 --- a/align_data/settings.py +++ b/align_data/settings.py @@ -89,5 +89,8 @@ ) ### MISCELLANEOUS ### -MIN_CONFIDENCE = 50 +MIN_CONFIDENCE = float(os.environ.get('MIN_CONFIDENCE') or '0.5') +if MIN_CONFIDENCE < 0 or MIN_CONFIDENCE > 1: + raise ValueError(f'MIN_CONFIDENCE must be between 0 and 1 - got {MIN_CONFIDENCE}') + DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
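
A minimal sketch of how the new MIN_CONFIDENCE handling in the final patch behaves, assuming the settings module is imported as align_data.settings (the environment variable name, the 0-1 bounds, and the error message are taken from the diff above; the import path and the surrounding usage are only an illustration, not part of the change):

    import os

    os.environ["MIN_CONFIDENCE"] = "0.75"   # must parse as a float in the range [0, 1]

    from align_data import settings         # assumed import of align_data/settings.py

    print(settings.MIN_CONFIDENCE)           # 0.75
    # With e.g. MIN_CONFIDENCE="1.5" the import would instead raise:
    # ValueError: MIN_CONFIDENCE must be between 0 and 1 - got 1.5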