diff --git a/.env.example b/.env.example
index 057b5ba4..1c29520b 100644
--- a/.env.example
+++ b/.env.example
@@ -9,3 +9,4 @@ OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
PINECONE_INDEX_NAME="stampy-chat-ard"
PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
PINECONE_ENVIRONMENT="xx-xxxxx-gcp"
+YOUTUBE_API_KEY=""
\ No newline at end of file
diff --git a/.github/workflows/fetch-daily.yml b/.github/workflows/fetch-daily.yml
index 9254f461..f4a07050 100644
--- a/.github/workflows/fetch-daily.yml
+++ b/.github/workflows/fetch-daily.yml
@@ -14,6 +14,7 @@ jobs:
with:
datasource: ${{ matrix.datasource }}
coda_token: ${{ inputs.coda_token }}
+ airtable_api_key: ${{ inputs.airtable_api_key }}
youtube_api_key: ${{ inputs.youtube_api_key }}
db_user: ${{ inputs.db_user }}
db_password: ${{ inputs.db_password }}
diff --git a/.github/workflows/fetch-dataset.yml b/.github/workflows/fetch-dataset.yml
index b65a91dd..39dc3faa 100644
--- a/.github/workflows/fetch-dataset.yml
+++ b/.github/workflows/fetch-dataset.yml
@@ -9,6 +9,9 @@ on:
coda_token:
type: string
required: true
+ airtable_api_key:
+ type: string
+ required: true
youtube_api_key:
type: string
required: true
@@ -28,44 +31,19 @@ on:
type: choice
options:
- agentmodels
- - aiimpacts
- - aisafety.camp
+ - agisf
- aisafety.info
- - ai_alignment_playlist
- - ai_explained
- - ai_safety_talks
- - ai_safety_reading_group
- - ai_tech_tu_delft
- - alignmentforum
- alignment_newsletter
+ - alignmentforum
- arbital
- arxiv
- - carado.moe
- - cold_takes
- - deepmind_blog
- - deepmind_technical_blog
+ - blogs
- distill
- eaforum
- - ebooks
- - eleuther.ai
- - gdocs
- - generative.ink
- - gwern_blog
- - html_articles
- - importai
- indices
- - jsteinhardt_blog
- lesswrong
- - markdown
- - miri
- - ml_safety_newsletter
- - openai.research
- - pdfs
- - rob_miles_ai_safety
- special_docs
- - vkrakovna_blog
- - yudkowsky_blog
- - xmls
+ - youtube
jobs:
build-dataset:
@@ -93,6 +71,7 @@ jobs:
- name: Process dataset
env:
CODA_TOKEN: ${{ secrets.CODA_TOKEN || inputs.coda_token }}
+ AIRTABLE_API_KEY: ${{ secrets.AIRTABLE_API_KEY || inputs.airtable_api_key }}
YOUTUBE_API_KEY: ${{ secrets.YOUTUBE_API_KEY || inputs.youtube_api_key }}
ARD_DB_USER: ${{ secrets.ARD_DB_USER || inputs.db_user }}
ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD || inputs.db_password }}
diff --git a/.github/workflows/fetch-weekly.yml b/.github/workflows/fetch-weekly.yml
index e8850ec4..64d17c07 100644
--- a/.github/workflows/fetch-weekly.yml
+++ b/.github/workflows/fetch-weekly.yml
@@ -14,6 +14,7 @@ jobs:
with:
datasource: ${{ matrix.datasource }}
coda_token: ${{ inputs.coda_token }}
+ airtable_api_key: ${{ inputs.airtable_api_key }}
youtube_api_key: ${{ inputs.youtube_api_key }}
db_user: ${{ inputs.db_user }}
db_password: ${{ inputs.db_password }}
diff --git a/.github/workflows/push-dataset.yml b/.github/workflows/push-dataset.yml
index 2ecd539d..1b48a3b4 100644
--- a/.github/workflows/push-dataset.yml
+++ b/.github/workflows/push-dataset.yml
@@ -24,37 +24,17 @@ on:
options:
- all
- agentmodels
- - aiimpacts
- - aisafety.camp
+ - agisf
- aisafety.info
- - ai_alignment_playlist
- - ai_explained
- - ai_safety_talks
- - ai_safety_reading_group
- - ai_tech_tu_delft
- alignmentforum
- arbital
- arxiv
- - carado.moe
- - cold_takes
- - deepmind_blog
- - deepmind_technical_blog
+ - blogs
- distill
- eaforum
- - eleuther.ai
- - gdocs
- - generative.ink
- - gwern_blog
- - importai
- - jsteinhardt_blog
- lesswrong
- - miri
- - ml_safety_newsletter
- - openai.research
- - rob_miles_ai_safety
- special_docs
- - vkrakovna_blog
- - yudkowsky_blog
+ - youtube
jobs:
generate-dataset:
diff --git a/.github/workflows/update-metadata.yml b/.github/workflows/update-metadata.yml
index 70e2ddd5..fe95f002 100644
--- a/.github/workflows/update-metadata.yml
+++ b/.github/workflows/update-metadata.yml
@@ -31,4 +31,9 @@ jobs:
run: curl -L "${{ inputs.csv_url }}" -o data.csv
- name: Run Script
+ env:
+ ARD_DB_USER: ${{ secrets.ARD_DB_USER }}
+ ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD }}
+ ARD_DB_HOST: ${{ secrets.ARD_DB_HOST }}
+ ARD_DB_NAME: alignment_research_dataset
run: python main.py update data.csv ${{ inputs.delimiter }}
diff --git a/.github/workflows/update-pinecone.yml b/.github/workflows/update-pinecone.yml
index 0cb7924b..b87d9ed7 100644
--- a/.github/workflows/update-pinecone.yml
+++ b/.github/workflows/update-pinecone.yml
@@ -32,43 +32,18 @@ on:
options:
- all
- agentmodels
- - aiimpacts
- - aisafety.camp
+ - agisf
- aisafety.info
- - ai_alignment_playlist
- - ai_explained
- - ai_safety_talks
- - ai_safety_reading_group
- - ai_tech_tu_delft
- alignmentforum
- arbital
- arxiv
- - carado.moe
- - cold_takes
- - deepmind_blog
- - deepmind_technical_blog
+ - blogs
- distill
- eaforum
- - ebooks
- - eleuther.ai
- - gdocs
- - generative.ink
- - gwern_blog
- - html_articles
- - importai
- indices
- - jsteinhardt_blog
- lesswrong
- - markdown
- - miri
- - ml_safety_newsletter
- - openai.research
- - pdfs
- - rob_miles_ai_safety
- special_docs
- - vkrakovna_blog
- - yudkowsky_blog
- - xmls
+ - youtube
jobs:
build-dataset:
@@ -78,28 +53,28 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v2
- - name: Setup Python environment
- uses: actions/setup-python@v2
- with:
- python-version: '3.x'
+ # - name: Setup Python environment
+ # uses: actions/setup-python@v2
+ # with:
+ # python-version: '3.x'
- - name: Install dependencies
- run: |
- pip install -r requirements.txt;
- python -c 'import nltk; nltk.download("punkt")'
+ # - name: Install dependencies
+ # run: |
+ # pip install -r requirements.txt;
+ # python -c 'import nltk; nltk.download("punkt")'
- - name: Process dataset
- env:
- ARD_DB_USER: ${{ secrets.ARD_DB_USER || inputs.db_user }}
- ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD || inputs.db_password }}
- ARD_DB_HOST: ${{ secrets.ARD_DB_HOST || inputs.db_host }}
- ARD_DB_NAME: alignment_research_dataset
- OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY || inputs.openai_api_key }}
- PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY || inputs.pinecone_api_key }}
- PINECONE_ENVIRONMENT: ${{ secrets.PINECONE_ENVIRONMENT || inputs.pinecone_environment }}
- run: |
- if [ "${{ inputs.datasource }}" = "all" ]; then
- python main.py pinecone_update_all
- else
- python main.py pinecone_update ${{ inputs.datasource }}
- fi
+ # - name: Process dataset
+ # env:
+ # ARD_DB_USER: ${{ secrets.ARD_DB_USER || inputs.db_user }}
+ # ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD || inputs.db_password }}
+ # ARD_DB_HOST: ${{ secrets.ARD_DB_HOST || inputs.db_host }}
+ # ARD_DB_NAME: alignment_research_dataset
+ # OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY || inputs.openai_api_key }}
+ # PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY || inputs.pinecone_api_key }}
+ # PINECONE_ENVIRONMENT: ${{ secrets.PINECONE_ENVIRONMENT || inputs.pinecone_environment }}
+ # run: |
+ # if [ "${{ inputs.datasource }}" = "all" ]; then
+ # python main.py pinecone_update_all
+ # else
+ # python main.py pinecone_update ${{ inputs.datasource }}
+ # fi
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..61a81a7b
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,11 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.4.0
+ hooks:
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
+- repo: https://github.com/psf/black
+ rev: 23.7.0
+ hooks:
+ - id: black
+ language_version: python3.11
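+
+# To enable these hooks locally, install pre-commit (e.g. `pip install pre-commit`)
+# and run `pre-commit install` once so the hooks run on every commit.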
diff --git a/README.md b/README.md
index 5afbcf19..b01779df 100644
--- a/README.md
+++ b/README.md
@@ -4,63 +4,69 @@ The AI Alignment Research Dataset is a collection of documents related to AI Ali
## Sources
-The following list of sources may change and items may be renamed:
-
-- [agentmodels](https://agentmodels.org/)
-- [aiimpacts](https://aiimpacts.org/)
-- [aisafety.camp](https://aisafety.camp/)
-- [aisafety.info](https://aisafety.info/)
-- [ai_alignment_playlist]()
-- [ai_explained](https://www.youtube.com/@ai-explained-)
-- [ai_safety_talks](https://www.youtube.com/@aisafetytalks)
-- [ai_safety_reading_group](https://www.youtube.com/@aisafetyreadinggroup/videos)
-- [ai_tech_tu_delft](https://www.youtube.com/@AiTechTUDelft/)
+Here is the list of sources, along with sample contents:
+
+- [agentmodels](https://agentmodels.org/)
+- [agisf](https://course.aisafetyfundamentals.com/) - recommended readings from AGI Safety Fundamentals
+- [aisafety.info](https://aisafety.info/) - Stampy's FAQ
- [alignmentforum](https://www.alignmentforum.org)
- [alignment_newsletter](https://rohinshah.com/alignment-newsletter/)
- [arbital](https://arbital.com/)
-- arxiv - alignment research papers from [arxiv](https://arxiv.org/)
-- [carado.moe](https://carado.moe/)
-- [cold_takes](https://www.cold-takes.com/)
-- [deepmind_blog](https://deepmindsafetyresearch.medium.com/)
-- [deepmind_technical_blog](https://www.deepmind.com/blog-categories/technical-blogs)
+- [arxiv](https://arxiv.org/) - relevant research papers
+
+- blogs - entire websites automatically scraped
+ - [AI Impacts](https://aiimpacts.org/)
+ - [AI Safety Camp](https://aisafety.camp/)
+ - [carado.moe](https://carado.moe/)
+ - [Cold Takes](https://www.cold-takes.com/)
+ - [DeepMind technical blogs](https://www.deepmind.com/blog-categories/technical-blogs)
+ - [DeepMind AI Safety Research](https://deepmindsafetyresearch.medium.com/)
+ - [EleutherAI](https://blog.eleuther.ai/)
+ - [generative.ink](https://generative.ink/posts/)
+ - [Gwern Branwen's blog](https://gwern.net/)
+ - [Jack Clark's Import AI](https://importai.substack.com/)
+ - [MIRI](https://intelligence.org/)
+ - [Jacob Steinhardt's blog](https://jsteinhardt.wordpress.com/)
+ - [ML Safety Newsletter](https://newsletter.mlsafety.org/)
+ - [Transformer Circuits Thread](https://transformer-circuits.pub/)
+  - [OpenAI Research](https://openai.com/research/)
+ - [Victoria Krakovna's blog](https://vkrakovna.wordpress.com/)
+ - [Eliezer Yudkowsky's blog](https://www.yudkowsky.net/)
+
- [distill](https://distill.pub/)
- [eaforum](https://forum.effectivealtruism.org/) - selected posts
-- [eleuther.ai](https://blog.eleuther.ai/)
-- [generative.ink](https://generative.ink/posts/)
-- [gwern_blog](https://gwern.net/)
-- gdocs - various doc files stored on Google drive
-- html_articles - various articles on websites
-- [import.ai](https://importai.substack.com)
-- [jsteinhardt_blog](https://jsteinhardt.wordpress.com/)
- [lesswrong](https://www.lesswrong.com/) - selected posts
-- markdown
-- [miri](https://intelligence.org/) - MIRI
-- [ml_safety_newsletter](https://newsletter.mlsafety.org)
-- [openai.research](https://openai.com/research)
-- pdfs - various pdfs from different places
-- [rob_miles_ai_safety](https://www.youtube.com/@RobertMilesAI)
-- [vkrakovna_blog](https://vkrakovna.wordpress.com)
-- [waitbutwhy](https://waitbutwhy.com/)
-- [yudkowsky_blog](https://www.yudkowsky.net/)
-- xmls - various articles stored as XML files
+- special_docs - individual documents curated from various resources
+ - [Make a suggestion](https://bit.ly/ard-suggestion) for sources not already in the dataset
+
+- youtube - playlists & channels
+  - [AI Alignment playlist](https://www.youtube.com/playlist?list=PLCRVRLd2RhZTpdUdEzJjo3qhmX3y3skWA) and other playlists
+ - [AI Explained](https://www.youtube.com/@aiexplained-official)
+ - [Evan Hubinger's AI Safety Talks](https://www.youtube.com/@aisafetytalks)
+ - [AI Safety Reading Group](https://www.youtube.com/@aisafetyreadinggroup/videos)
+ - [AiTech - TU Delft](https://www.youtube.com/@AiTechTUDelft/)
+ - [Rob Miles AI](https://www.youtube.com/@RobertMilesAI)
## Keys
-Not all of the entries contain the same keys, but they all have the following:
+All entries contain the following keys:
-- `id` - unique identifier
-- `source` - based on the data source listed in the previous section
-- `title` - title of document
+- `id` - unique identifier (string)
+- `source` - name of the data source, as listed above (string)
+- `title` - title of the document (string)
+- `authors` - list of author names (strings)
- `text` - full text of document content
-- `url` - some values may be `'n/a'`, still being updated
-- `date_published` - some `'n/a'`
+- `url` - valid link to the text content (string)
+- `date_published` - publication date, in UTC format
-The values of the keys are still being cleaned up for consistency. Additional keys are available depending on the source document.
+Additional keys may be available depending on the source document.
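+
+For illustration, an entry might look something like this (all values below are made up):
+
+```python
+{
+    "id": "0123456789abcdef",
+    "source": "alignmentforum",
+    "title": "An Example Post",
+    "authors": ["Jane Doe", "John Smith"],
+    "text": "Full text of the post...",
+    "url": "https://www.alignmentforum.org/posts/example",
+    "date_published": "2023-01-01T00:00:00Z",
+}
+```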
## Development Environment
-To set up the development environment, run the following steps. You'll have to also set up [mysqlclient](https://pypi.org/project/mysqlclient/):
+Follow the [instructions for installing **mysqlclient** on your operating system](https://pypi.org/project/mysqlclient/), which appear toward the bottom of the linked page.
+
+To set up the development environment, run the following steps:
```bash
git clone https://github.com/StampyAI/alignment-research-dataset
diff --git a/align_data/__init__.py b/align_data/__init__.py
index 54041500..68b2ba2c 100644
--- a/align_data/__init__.py
+++ b/align_data/__init__.py
@@ -1,5 +1,6 @@
import align_data.sources.arbital as arbital
import align_data.sources.articles as articles
+import align_data.sources.agisf as agisf
import align_data.sources.blogs as blogs
import align_data.sources.ebooks as ebooks
import align_data.sources.greaterwrong as greaterwrong
@@ -11,6 +12,7 @@
DATASET_REGISTRY = (
arbital.ARBITAL_REGISTRY
+ articles.ARTICLES_REGISTRY
+ + agisf.AGISF_DATASETS
+ blogs.BLOG_REGISTRY
+ ebooks.EBOOK_REGISTRY
+ greaterwrong.GREATERWRONG_REGISTRY
diff --git a/align_data/analysis/analyse_jsonl_data.py b/align_data/analysis/analyse_jsonl_data.py
index 1fb3f526..9ef49649 100644
--- a/align_data/analysis/analyse_jsonl_data.py
+++ b/align_data/analysis/analyse_jsonl_data.py
@@ -69,9 +69,7 @@ def process_jsonl_files(data_dir):
for id, duplicates in seen_urls.items():
if len(duplicates) > 1:
- list_of_duplicates = "\n".join(
- get_data_dict_str(duplicate) for duplicate in duplicates
- )
+ list_of_duplicates = "\n".join(get_data_dict_str(duplicate) for duplicate in duplicates)
print(
f"{len(duplicates)} duplicate ids found. \nId: {id}\n{list_of_duplicates}\n\n\n\n"
)
diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py
index 5e48b999..a1637a51 100644
--- a/align_data/common/alignment_dataset.py
+++ b/align_data/common/alignment_dataset.py
@@ -17,6 +17,8 @@
from align_data.db.models import Article, Summary
from align_data.db.session import make_session
from align_data.settings import ARTICLE_MAIN_KEYS
+from align_data.sources.utils import merge_dicts
+
logger = logging.getLogger(__name__)
@@ -69,12 +71,13 @@ def _add_authors(self, article: Article, authors: List[str]) -> Article:
article.authors = ",".join(article.authors[:1024].split(",")[:-1])
return article
- def make_data_entry(self, data: Dict[str, Any], **kwargs) -> Article:
- data = dict(data, **kwargs)
+ def make_data_entry(self, data, **kwargs) -> Article:
+ data = merge_dicts(data, kwargs)
summary = data.pop("summary", None)
authors = data.pop("authors", [])
article = Article(
+ pinecone_update_required=True,
meta={k: v for k, v in data.items() if k not in ARTICLE_MAIN_KEYS and v is not None},
**{k: v for k, v in data.items() if k in ARTICLE_MAIN_KEYS},
)
@@ -158,14 +161,9 @@ def _load_outputted_items(self) -> Set[str]:
# This doesn't filter by self.name. The good thing about that is that it should handle a lot more
# duplicates. The bad thing is that this could potentially return a massive amount of data if there
# are lots of items.
- return set(
- session.scalars(select(getattr(Article, self.done_key))).all()
- )
- return {
- meta[self.done_key]
- for meta in session.scalars(select(Article.meta)).all()
- if isinstance(meta, JSON) and meta.get(self.done_key)
- }
+ return set(session.scalars(select(getattr(Article, self.done_key))).all())
+ # TODO: Properly handle this - it should create a proper SQL JSON select
+ return {item.get(self.done_key) for item in session.scalars(select(Article.meta)).all()}
def not_processed(self, item):
# NOTE: `self._outputted_items` reads in all items. Which could potentially be a lot. If this starts to
@@ -209,7 +207,7 @@ def fetch_entries(self) -> Generator[Article, None, None]:
if self.COOLDOWN:
time.sleep(self.COOLDOWN)
- def process_entry(self, entry) -> Optional[Article]:
+ def process_entry(self, entry) -> Article | None:
"""Process a single entry."""
raise NotImplementedError
@@ -218,7 +216,7 @@ def _format_datetime(date: datetime) -> str:
return date.strftime("%Y-%m-%dT%H:%M:%SZ")
@staticmethod
- def _get_published_date(date: str) -> Optional[datetime]:
+ def _get_published_date(date) -> datetime | None:
try:
# Totally ignore any timezone info, forcing everything to UTC
return parse(str(date)).replace(tzinfo=pytz.UTC)
@@ -234,7 +232,11 @@ def unprocessed_items(self, items=None) -> Iterable:
urls = map(self.get_item_key, items)
with make_session() as session:
- articles = session.query(Article).options(joinedload(Article.summaries)).filter(Article.url.in_(urls))
+ articles = (
+ session.query(Article)
+ .options(joinedload(Article.summaries))
+ .filter(Article.url.in_(urls))
+ )
self.articles = {a.url: a for a in articles if a.url}
return items
@@ -244,9 +246,7 @@ def _load_outputted_items(self) -> Set[str]:
with make_session() as session:
return set(
session.scalars(
- select(Article.url)
- .join(Article.summaries)
- .filter(Summary.source == self.name)
+ select(Article.url).join(Article.summaries).filter(Summary.source == self.name)
)
)
@@ -257,3 +257,40 @@ def merge(item):
return item
session.add_all(map(merge, batch))
+
+
+@dataclass
+class MultiDataset(AlignmentDataset):
+
+ datasets: List[AlignmentDataset]
+
+ @property
+ def names(self):
+ return [dataset.name for dataset in self.datasets]
+
+ @property
+ def items_list(self) -> Iterable:
+ """Returns a collection of items to be processed."""
+ return ((item, dataset) for dataset in self.datasets for item in dataset.items_list)
+
+ def setup(self):
+ for dataset in self.datasets:
+ dataset.setup()
+
+ def get_item_key(self, entry):
+ item, dataset = entry
+ return dataset.get_item_key(item)
+
+ def process_entry(self, entry) -> Article | None:
+ item, dataset = entry
+ article = dataset.process_entry(item)
+ article.add_meta("initial_source", article.source)
+ article.source = self.name
+ return article
+
+ def fetch_entries(self):
+ for dataset in self.datasets:
+ for article in dataset.fetch_entries():
+ if article.source != self.name:
+ article.add_meta('initial_source', article.source)
+ article.source = self.name
+ yield article
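+
+
+# Rough usage sketch (illustrative only: assumes AlignmentDataset exposes a `name` field and
+# that the wrapped datasets are constructed elsewhere):
+# blogs = MultiDataset(name="blogs", datasets=[cold_takes_dataset, gwern_blog_dataset])
+# for article in blogs.fetch_entries():
+#     ...  # article.source is "blogs"; the original source is kept in meta["initial_source"]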
diff --git a/align_data/common/html_dataset.py b/align_data/common/html_dataset.py
index fc948687..e9c3aa2d 100644
--- a/align_data/common/html_dataset.py
+++ b/align_data/common/html_dataset.py
@@ -75,7 +75,7 @@ def get_contents(self, article_url: str) -> Dict[str, Any]:
def process_entry(self, article: Tag) -> Article:
article_url = self.get_item_key(article)
contents = self.get_contents(article_url)
- if not contents.get('text'):
+ if not contents.get("text"):
return None
return self.make_data_entry(contents)
@@ -149,9 +149,12 @@ def fetch_contents(self, url: str):
soup=soup,
)
+ def _extract_item_url(self, item) -> str | None:
+ return item.get('link')
+
@property
def items_list(self):
logger.info(f"Fetching entries from {self.feed_url}")
feed = feedparser.parse(self.feed_url)
- self.items = {item["link"]: item for item in feed["entries"]}
+ self.items = {url: item for item in feed["entries"] if (url := self._extract_item_url(item))}
return list(self.items.keys())
diff --git a/align_data/db/models.py b/align_data/db/models.py
index 43d23caa..0161ff50 100644
--- a/align_data/db/models.py
+++ b/align_data/db/models.py
@@ -17,7 +17,6 @@
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.hybrid import hybrid_property
-from align_data.settings import PINECONE_METADATA_KEYS
logger = logging.getLogger(__name__)
@@ -71,35 +70,26 @@ class Article(Base):
def __repr__(self) -> str:
return f"Article(id={self.id!r}, title={self.title!r}, url={self.url!r}, source={self.source!r}, authors={self.authors!r}, date_published={self.date_published!r})"
- def is_metadata_keys_equal(self, other):
- if not isinstance(other, Article):
- raise TypeError(
- f"Expected an instance of Article, got {type(other).__name__}"
- )
- return all(
- getattr(self, key, None) == getattr(other, key, None)
- for key in PINECONE_METADATA_KEYS
- ) # entry_id is implicitly ignored
-
def generate_id_string(self) -> bytes:
- return (
- "".join(
- str(getattr(self, field))
- for field in self.__id_fields
- ).encode("utf-8")
- )
+ return "".join(str(getattr(self, field)) for field in self.__id_fields).encode("utf-8")
@property
def __id_fields(self):
- if self.source == 'aisafety.info':
- return ['url']
- if self.source in ['importai', 'ml_safety_newsletter', 'alignment_newsletter']:
- return ['url', 'title', 'source']
+ if self.source == "aisafety.info":
+ return ["url"]
+ if self.source in ["importai", "ml_safety_newsletter", "alignment_newsletter"]:
+ return ["url", "title", "source"]
return ["url", "title"]
@property
def missing_fields(self):
- fields = set(self.__id_fields) | {'text', 'title', 'url', 'source', 'date_published'}
+ fields = set(self.__id_fields) | {
+ "text",
+ "title",
+ "url",
+ "source",
+ "date_published",
+ }
return sorted([field for field in fields if not getattr(self, field, None)])
def verify_id(self):
@@ -142,13 +132,21 @@ def add_meta(self, key: str, val):
# TODO: verify that this actually updates the meta column;
# https://amercader.net/blog/beware-of-json-fields-in-sqlalchemy/
+ def append_comment(self, comment: str):
+ """Appends a comment to the article.comments field. You must run session.commit() to save the comment to the database."""
+ if self.comments is None:
+ self.comments = ""
+ self.comments = f"{self.comments}\n\n{comment}".strip()
+
@hybrid_property
def is_valid(self):
- return (
- self.text and self.text.strip() and
- self.url and self.title and
- self.authors is not None and
- self.status == OK_STATUS
+ return bool(
+ self.text
+ and self.text.strip()
+ and self.url
+ and self.title
+ and self.authors is not None
+ and self.status == OK_STATUS
)
@is_valid.expression
@@ -166,9 +164,11 @@ def before_write(cls, _mapper, _connection, target: "Article"):
target.verify_id_fields()
if not target.status and target.missing_fields:
- target.status = 'Missing fields'
+ target.status = "Missing fields"
target.comments = f'missing fields: {", ".join(target.missing_fields)}'
+ target.pinecone_update_required = target.is_valid
+
if target.id:
target.verify_id()
else:
diff --git a/align_data/db/session.py b/align_data/db/session.py
index f388cf73..75040812 100644
--- a/align_data/db/session.py
+++ b/align_data/db/session.py
@@ -2,31 +2,45 @@
import logging
from contextlib import contextmanager
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, or_
from sqlalchemy.orm import Session
-
-from align_data.settings import DB_CONNECTION_URI
+from align_data.settings import DB_CONNECTION_URI, MIN_CONFIDENCE
from align_data.db.models import Article
logger = logging.getLogger(__name__)
+# We create a single engine for the entire application
+engine = create_engine(DB_CONNECTION_URI, echo=False)
+
@contextmanager
-def make_session(auto_commit: bool = False) -> Generator[Session, None, None]:
- engine = create_engine(DB_CONNECTION_URI, echo=False)
+def make_session(auto_commit=False):
with Session(engine, autoflush=False) as session:
yield session
if auto_commit:
session.commit()
-def stream_pinecone_updates(session: Session, custom_sources: List[str]) -> Generator[Article, None, None]:
+def stream_pinecone_updates(
+ session: Session, custom_sources: List[str], force_update: bool = False
+):
"""Yield Pinecone entries that require an update."""
yield from (
- session
- .query(Article)
- .filter(Article.pinecone_update_required.is_(True))
+ session.query(Article)
+ .filter(or_(Article.pinecone_update_required.is_(True), force_update))
.filter(Article.is_valid)
.filter(Article.source.in_(custom_sources))
+ .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE))
.yield_per(1000)
)
+
+
+def get_all_valid_article_ids(session: Session) -> List[str]:
+ """Return all valid article IDs."""
+ query_result = (
+ session.query(Article.id)
+ .filter(Article.is_valid)
+ .filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE))
+ .all()
+ )
+ return [item[0] for item in query_result]
diff --git a/align_data/embeddings/embedding_utils.py b/align_data/embeddings/embedding_utils.py
new file mode 100644
index 00000000..6ef57b88
--- /dev/null
+++ b/align_data/embeddings/embedding_utils.py
@@ -0,0 +1,199 @@
+import logging
+from typing import List, Tuple, Dict, Any, Optional
+from functools import wraps
+
+import openai
+from langchain.embeddings import HuggingFaceEmbeddings
+from openai.error import (
+ OpenAIError,
+ RateLimitError,
+ APIError,
+)
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_random_exponential,
+ retry_if_exception_type,
+ retry_if_exception,
+)
+
+from align_data.embeddings.pinecone.pinecone_models import MissingEmbeddingModelError
+from align_data.settings import (
+ USE_OPENAI_EMBEDDINGS,
+ OPENAI_EMBEDDINGS_MODEL,
+ EMBEDDING_LENGTH_BIAS,
+ SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
+ DEVICE,
+)
+
+
+# --------------------
+# CONSTANTS & CONFIGURATION
+# --------------------
+
+logger = logging.getLogger(__name__)
+
+hf_embedding_model = None
+if not USE_OPENAI_EMBEDDINGS:
+ hf_embedding_model = HuggingFaceEmbeddings(
+ model_name=SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
+ model_kwargs={"device": DEVICE},
+ encode_kwargs={"show_progress_bar": False},
+ )
+
+ModerationInfoType = Dict[str, Any]
+
+
+# --------------------
+# DECORATORS
+# --------------------
+
+
+def handle_openai_errors(func):
+ """Decorator to handle OpenAI-specific exceptions with retries."""
+
+ @wraps(func)
+ @retry(
+ wait=wait_random_exponential(multiplier=1, min=2, max=30),
+ stop=stop_after_attempt(6),
+ retry=retry_if_exception_type(RateLimitError)
+ | retry_if_exception_type(APIError)
+ | retry_if_exception(lambda e: "502" in str(e)),
+ )
+ def wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except RateLimitError as e:
+ logger.warning(f"OpenAI Rate limit error. Trying again. Error: {e}")
+ raise
+ except APIError as e:
+ if "502" in str(e):
+ logger.warning(f"OpenAI 502 Bad Gateway error. Trying again. Error: {e}")
+ else:
+ logger.error(f"OpenAI API Error encountered: {e}")
+ raise
+ except OpenAIError as e:
+ logger.error(f"OpenAI Error encountered: {e}")
+ raise
+ except Exception as e:
+ logger.error(f"Unexpected error encountered: {e}")
+ raise
+
+ return wrapper
+
+
+# --------------------
+# MAIN FUNCTIONS
+# --------------------
+
+
+@handle_openai_errors
+def moderation_check(texts: List[str]) -> List[ModerationInfoType]:
+ return openai.Moderation.create(input=texts)["results"]
+
+
+@handle_openai_errors
+def _compute_openai_embeddings(non_flagged_texts: List[str], **kwargs) -> List[List[float]]:
+ data = openai.Embedding.create(input=non_flagged_texts, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data
+ return [d["embedding"] for d in data]
+
+
+def get_embeddings_without_moderation(
+ texts: List[str],
+ source: Optional[str] = None,
+ **kwargs,
+) -> List[List[float]]:
+ """
+ Obtain embeddings without moderation checks.
+
+ Parameters:
+ - texts (List[str]): List of texts to be embedded.
+ - source (Optional[str], optional): Source identifier to potentially adjust embedding bias. Defaults to None.
+ - **kwargs: Additional keyword arguments passed to the embedding function.
+
+ Returns:
+ - List[List[float]]: List of embeddings for the provided texts.
+ """
+ if not texts:
+ return []
+
+ texts = [text.replace("\n", " ") for text in texts]
+ if USE_OPENAI_EMBEDDINGS:
+ embeddings = _compute_openai_embeddings(texts, **kwargs)
+ elif hf_embedding_model:
+ embeddings = hf_embedding_model.embed_documents(texts)
+ else:
+ raise MissingEmbeddingModelError("No embedding model available.")
+
+ # Bias adjustment
+ if source and (bias := EMBEDDING_LENGTH_BIAS.get(source, 1.0)):
+ embeddings = [[bias * e for e in embedding] for embedding in embeddings]
+
+ return embeddings
+
+
+def get_embeddings_or_none_if_flagged(
+ texts: List[str],
+ source: Optional[str] = None,
+ **kwargs,
+) -> Tuple[List[List[float]] | None, List[ModerationInfoType]]:
+ """
+ Obtain embeddings for the provided texts. If any text is flagged during moderation,
+ the function returns None for the embeddings while still providing the moderation results.
+
+ Parameters:
+ - texts (List[str]): List of texts to be embedded.
+ - source (Optional[str], optional): Source identifier to potentially adjust embedding bias. Defaults to None.
+ - **kwargs: Additional keyword arguments passed to the embedding function.
+
+ Returns:
+ - Tuple[List[List[float]] | None, List[ModerationInfoType]]: Tuple containing the list of embeddings (or None if any text is flagged) and the moderation results.
+ """
+ moderation_results = moderation_check(texts)
+ if any(result["flagged"] for result in moderation_results):
+ return None, moderation_results
+
+ embeddings = get_embeddings_without_moderation(texts, source, **kwargs)
+ return embeddings, moderation_results
+
+
+def get_embeddings(
+ texts: List[str],
+ source: Optional[str] = None,
+ **kwargs,
+) -> Tuple[List[List[float] | None], List[ModerationInfoType]]:
+ """
+ Obtain embeddings for the provided texts, replacing the embeddings of flagged texts with `None`.
+
+ Parameters:
+ - texts (List[str]): List of texts to be embedded.
+ - source (Optional[str], optional): Source identifier to potentially adjust embedding bias. Defaults to None.
+ - **kwargs: Additional keyword arguments passed to the embedding function.
+
+ Returns:
+ - Tuple[List[List[float] | None], List[ModerationInfoType]]: Tuple containing the list of embeddings (with None for flagged texts) and the moderation results.
+ """
+ assert len(texts) <= 2048, "The batch size should not be larger than 2048."
+ assert all(texts), "No empty strings allowed in the input list."
+
+ # replace newlines, which can negatively affect performance
+ texts = [text.replace("\n", " ") for text in texts]
+
+ # Check all texts for moderation flags
+ moderation_results = moderation_check(texts)
+ flags = [result["flagged"] for result in moderation_results]
+
+ non_flagged_texts = [text for text, flag in zip(texts, flags) if not flag]
+ non_flagged_embeddings = get_embeddings_without_moderation(
+ non_flagged_texts, source, **kwargs
+ )
+ embeddings = [None if flag else non_flagged_embeddings.pop(0) for flag in flags]
+ return embeddings, moderation_results
+
+
+def get_embedding(
+ text: str, source: Optional[str] = None, **kwargs
+) -> Tuple[List[float] | None, ModerationInfoType]:
+ """Obtain an embedding for a single text."""
+ embedding, moderation_result = get_embeddings([text], source, **kwargs)
+ return embedding[0], moderation_result[0]
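+
+
+# Rough usage sketch (the query text and source name below are illustrative):
+# embedding, moderation_info = get_embedding("What is AI alignment?", source="alignmentforum")
+# if embedding is None:
+#     ...  # the text was flagged by the moderation check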
diff --git a/align_data/embeddings/finetuning/data/best_finetuned_model.pth b/align_data/embeddings/finetuning/data/best_finetuned_model.pth
new file mode 100644
index 00000000..e05a5d52
Binary files /dev/null and b/align_data/embeddings/finetuning/data/best_finetuned_model.pth differ
diff --git a/align_data/embeddings/finetuning/data/finetuned_model.pth b/align_data/embeddings/finetuning/data/finetuned_model.pth
new file mode 100644
index 00000000..7ba44ac6
Binary files /dev/null and b/align_data/embeddings/finetuning/data/finetuned_model.pth differ
diff --git a/align_data/embeddings/finetuning/finetuning_dataset.py b/align_data/embeddings/finetuning/finetuning_dataset.py
new file mode 100644
index 00000000..8c5eec04
--- /dev/null
+++ b/align_data/embeddings/finetuning/finetuning_dataset.py
@@ -0,0 +1,105 @@
+import math
+import random
+from typing import List, Tuple, Generator
+from collections import deque
+
+import torch
+from torch.utils.data import IterableDataset, get_worker_info
+from sqlalchemy.exc import OperationalError
+from sqlalchemy.sql import func
+from sqlalchemy.orm import Session
+
+from align_data.db.session import make_session, get_all_valid_article_ids
+from align_data.embeddings.embedding_utils import get_embedding
+from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB
+from align_data.embeddings.text_splitter import ParagraphSentenceUnitTextSplitter
+from align_data.embeddings.pinecone.update_pinecone import get_text_chunks
+from align_data.db.models import Article
+
+
+class FinetuningDataset(IterableDataset):
+ def __init__(self, num_batches_per_epoch: int, cache_size: int = 1280):
+ self.num_batches_per_epoch = num_batches_per_epoch
+ self.article_cache = deque(maxlen=cache_size)
+
+ self.text_splitter = ParagraphSentenceUnitTextSplitter()
+ self.pinecone_db = PineconeDB()
+
+ with make_session() as session:
+ self.all_article_ids = get_all_valid_article_ids(session)
+ self.total_articles = len(self.all_article_ids)
+
+ def __len__(self):
+ return self.num_batches_per_epoch
+
+ def __iter__(self):
+ start, end = 0, None
+ worker_info = get_worker_info()
+ if worker_info is not None: # Multi-process loading
+ per_worker = math.ceil(self.total_articles / worker_info.num_workers)
+ start = worker_info.id * per_worker
+ end = min(start + per_worker, self.total_articles)
+
+ with make_session() as session:
+ # yield from (rather than return) keeps the session open while pairs are generated
+ yield from self._generate_pairs(session, start, end)
+
+ def _fetch_random_articles(self, session: Session, batch_size: int = 1) -> List[Article]:
+ """Fetch a batch of random articles."""
+ # If the list has fewer IDs than needed, raise an exception
+ random_selected_ids = random.sample(self.all_article_ids, batch_size)
+ return session.query(Article).filter(Article.id.in_(random_selected_ids)).all()
+
+ def _get_random_chunks(self, article: Article, num_chunks: int = 2) -> List[Tuple[int, str]]:
+ chunked_text = get_text_chunks(article, self.text_splitter)
+
+ chunks = list(enumerate(chunked_text))
+ if len(chunks) < num_chunks:
+ return []
+
+ return random.sample(chunks, num_chunks)
+
+ def _get_embeddings(self, article: Article, chunks: List[Tuple[int, str]]) -> List[List[float]]:
+ full_ids = [f"{article.id}_{str(idx).zfill(6)}" for idx, _ in chunks]
+ _embeddings = self.pinecone_db.get_embeddings_by_ids(full_ids)
+
+ embeddings = []
+ for (_, chunk), (_, embedding) in zip(chunks, _embeddings):
+ if embedding is None:
+ embedding, _ = get_embedding(chunk, article.source)
+ embeddings.append(torch.tensor(embedding))
+
+ return embeddings
+
+ def _generate_pairs(
+ self, session, start=0, end=None, neg_pos_proportion=0.5
+ ) -> Generator[Tuple[List[float], List[float], int], None, None]:
+ end = end or self.total_articles
+
+ batches_yielded = 0
+ while start < end:
+ start += 1
+ if random.random() < neg_pos_proportion:
+ # Positive pairs
+ article = self._fetch_random_articles(session)[0]
+ chunks = self._get_random_chunks(article, 2)
+ if not chunks:
+ continue
+ embedding_1, embedding_2 = self._get_embeddings(article, chunks)
+ label = 1
+ else:
+ # Negative pairs
+ article1, article2 = self._fetch_random_articles(session, batch_size=2)
+ chunk1 = self._get_random_chunks(article1, 1)
+ chunk2 = self._get_random_chunks(article2, 1)
+ embedding_1, embedding_2 = (
+ self._get_embeddings(article1, chunk1)[0],
+ self._get_embeddings(article2, chunk2)[0],
+ )
+ label = 0
+ # the embeddings must stay floating point; only the label is an integer tensor
+ yield embedding_1.float(), embedding_2.float(), torch.tensor(
+ label, dtype=torch.int64
+ )
+ batches_yielded += 1
+
+ if self.num_batches_per_epoch and batches_yielded >= self.num_batches_per_epoch:
+ break
diff --git a/align_data/embeddings/finetuning/training.py b/align_data/embeddings/finetuning/training.py
new file mode 100644
index 00000000..cd1d7845
--- /dev/null
+++ b/align_data/embeddings/finetuning/training.py
@@ -0,0 +1,177 @@
+import os
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.utils.data import DataLoader
+
+from align_data.embeddings.finetuning.finetuning_dataset import FinetuningDataset
+from align_data.settings import (
+ PINECONE_VALUES_DIMS,
+ DEVICE,
+ OPENAI_FINETUNED_LAYER_PATH,
+ OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH,
+)
+
+
+class ContrastiveLoss(nn.Module):
+ def __init__(self, margin=2.0):
+ super(ContrastiveLoss, self).__init__()
+ self.margin = margin
+
+ def forward(self, output1, output2, label):
+ euclidean_distance = nn.functional.pairwise_distance(output1, output2)
+ loss_contrastive = torch.mean(
+ (1 - label) * torch.pow(euclidean_distance, 2)
+ + (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
+ )
+
+ return loss_contrastive
+
+
+class NonLinearFineTuneModel(nn.Module):
+ def __init__(self, embedding_dim=PINECONE_VALUES_DIMS, hidden_dim=2000, dropout=0.5):
+ super(NonLinearFineTuneModel, self).__init__()
+
+ self.fc1 = nn.Linear(embedding_dim, hidden_dim)
+ self.fc2 = nn.Linear(hidden_dim, embedding_dim)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ x = nn.functional.relu(self.fc1(x))
+ x = self.dropout(x)
+ x = self.fc2(x)
+ return x
+
+
+class FineTuneModel(nn.Module):
+ def __init__(self, embedding_dim=PINECONE_VALUES_DIMS):
+ super(FineTuneModel, self).__init__()
+
+ self.fc = nn.Linear(embedding_dim, embedding_dim)
+
+ def forward(self, x):
+ x = self.fc(x)
+ return x
+
+
+def train(model, dataloader, optimizer, criterion):
+ model.train()
+ total_loss = 0.0
+
+ for batch_idx, (text1_embedding, text2_embedding, target) in enumerate(dataloader):
+ text1_embedding = text1_embedding.to(DEVICE)
+ text2_embedding = text2_embedding.to(DEVICE)
+ target = target.float().to(DEVICE)
+
+ optimizer.zero_grad()
+
+ output1 = model(text1_embedding)
+ output2 = model(text2_embedding)
+
+ loss = criterion(output1, output2, target)
+ loss.backward()
+
+ # Gradient clipping
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
+ optimizer.step()
+
+ total_loss += loss.item()
+
+ return total_loss / len(dataloader)
+
+
+def validate(model, dataloader, criterion):
+ model.eval()
+ total_loss = 0.0
+
+ with torch.no_grad():
+ for batch_idx, (text1_embedding, text2_embedding, target) in enumerate(dataloader):
+ text1_embedding = text1_embedding.to(DEVICE)
+ text2_embedding = text2_embedding.to(DEVICE)
+ target = target.float().to(DEVICE)
+
+ output1 = model(text1_embedding)
+ output2 = model(text2_embedding)
+
+ loss = criterion(output1, output2, target)
+ total_loss += loss.item()
+
+ return total_loss / len(dataloader)
+
+
+def finetune_embeddings():
+ # Hyperparameters & Configuration
+ EPOCHS = 100
+ BATCH_PER_EPOCH = 20
+ BATCH_SIZE = 64
+ LEARNING_RATE = 5.0000e-02
+ MARGIN = 2.0
+
+ dataset = FinetuningDataset(num_batches_per_epoch=BATCH_PER_EPOCH)
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=5)
+
+ model = FineTuneModel().to(DEVICE)
+ model = load_best_model_if_exists(model)
+ optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
+ scheduler = ReduceLROnPlateau(optimizer, "min", patience=2, factor=0.5, verbose=True)
+ criterion = ContrastiveLoss(MARGIN)
+
+ # Assuming you've split your data and have a separate validation set
+ validation_dataset = FinetuningDataset(num_batches_per_epoch=BATCH_PER_EPOCH)
+ validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, num_workers=5)
+ best_val_loss = validate(model, validation_dataloader, criterion)
+ print(f"Initial validation loss (from loaded model or new model): {best_val_loss:.4f}")
+
+ epochs_without_improvement = 0
+ max_epochs_without_improvement = 15  # stop after 15 epochs without improvement
+
+ for epoch in range(EPOCHS):
+ train_loss = train(model, dataloader, optimizer, criterion)
+ validate_loss = validate(model, validation_dataloader, criterion)
+
+ scheduler.step(validate_loss)
+ if validate_loss < best_val_loss:
+ best_val_loss = validate_loss
+ torch.save(model.state_dict(), OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH)
+ epochs_without_improvement = 0
+ else:
+ epochs_without_improvement += 1
+
+ print(
+ f"Epoch: {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Validation Loss: {validate_loss:.4f}"
+ )
+
+ if epochs_without_improvement >= max_epochs_without_improvement:
+ print("Early stopping due to no improvement in validation loss.")
+ break
+
+ torch.save(model.state_dict(), OPENAI_FINETUNED_LAYER_PATH)
+
+
+### HELPER FUNCTIONS ###
+
+
+def load_best_model_if_exists(model):
+ """
+ Load the best saved model if it exists.
+
+ Parameters:
+ - model (torch.nn.Module): The model architecture.
+
+ Returns:
+ - model (torch.nn.Module): The loaded model.
+ """
+ if os.path.exists(OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH):
+ model.load_state_dict(
+ torch.load(OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH, map_location=DEVICE)
+ )
+ print(f"Loaded model from {OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH}.")
+ else:
+ print(
+ f"No saved model found at {OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH}. Starting from scratch."
+ )
+
+ return model
diff --git a/align_data/pinecone/__init__.py b/align_data/embeddings/pinecone/__init__.py
similarity index 100%
rename from align_data/pinecone/__init__.py
rename to align_data/embeddings/pinecone/__init__.py
diff --git a/align_data/embeddings/pinecone/pinecone_db_handler.py b/align_data/embeddings/pinecone/pinecone_db_handler.py
new file mode 100644
index 00000000..dd2990d3
--- /dev/null
+++ b/align_data/embeddings/pinecone/pinecone_db_handler.py
@@ -0,0 +1,142 @@
+# align_data/embeddings/pinecone/pinecone_db_handler.py
+import logging
+from typing import List, Tuple
+
+import pinecone
+from pinecone.core.client.models import ScoredVector
+
+from align_data.embeddings.embedding_utils import get_embedding
+from align_data.embeddings.pinecone.pinecone_models import (
+ PineconeEntry,
+ PineconeMetadata,
+)
+from align_data.settings import (
+ PINECONE_INDEX_NAME,
+ PINECONE_VALUES_DIMS,
+ PINECONE_METRIC,
+ PINECONE_API_KEY,
+ PINECONE_ENVIRONMENT,
+ PINECONE_NAMESPACE,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class PineconeDB:
+ def __init__(
+ self,
+ index_name: str = PINECONE_INDEX_NAME,
+ values_dims: int = PINECONE_VALUES_DIMS,
+ metric: str = PINECONE_METRIC,
+ create_index: bool = False,
+ log_index_stats: bool = False,
+ ):
+ self.index_name = index_name
+ self.values_dims = values_dims
+ self.metric = metric
+
+ pinecone.init(
+ api_key=PINECONE_API_KEY,
+ environment=PINECONE_ENVIRONMENT,
+ )
+
+ if create_index:
+ self.create_index()
+
+ self.index = pinecone.Index(index_name=self.index_name)
+
+ if log_index_stats:
+ index_stats_response = self.index.describe_index_stats()
+ logger.info(f"{self.index_name}:\n{index_stats_response}")
+
+ def upsert_entry(self, pinecone_entry: PineconeEntry, upsert_size: int = 100):
+ vectors = pinecone_entry.create_pinecone_vectors()
+ self.index.upsert(vectors=vectors, batch_size=upsert_size, namespace=PINECONE_NAMESPACE)
+
+ def query_vector(
+ self,
+ query: List[float],
+ top_k: int = 10,
+ include_values: bool = False,
+ include_metadata: bool = True,
+ **kwargs,
+ ) -> List[ScoredVector]:
+ assert not isinstance(
+ query, str
+ ), "query must be a list of floats. Use query_PineconeDB_text for text queries"
+
+ query_response = self.index.query(
+ vector=query,
+ top_k=top_k,
+ include_values=include_values,
+ include_metadata=include_metadata,
+ **kwargs,
+ namespace=PINECONE_NAMESPACE,
+ )
+
+ return [
+ ScoredVector(
+ id=match["id"],
+ score=match["score"],
+ metadata=PineconeMetadata(**match["metadata"]),
+ )
+ for match in query_response["matches"]
+ ]
+
+ def query_text(
+ self,
+ query: str,
+ top_k: int = 10,
+ include_values: bool = False,
+ include_metadata: bool = True,
+ **kwargs,
+ ) -> List[ScoredVector]:
+ query_vector = get_embedding(query)[0]
+ return self.query_vector(
+ query=query_vector,
+ top_k=top_k,
+ include_values=include_values,
+ include_metadata=include_metadata,
+ **kwargs,
+ )
+
+ def delete_entries(self, ids):
+ self.index.delete(filter={"hash_id": {"$in": ids}})
+
+ def create_index(self, replace_current_index: bool = True):
+ if replace_current_index:
+ self.delete_index()
+
+ pinecone.create_index(
+ name=self.index_name,
+ dimension=self.values_dims,
+ metric=self.metric,
+ metadata_config={"indexed": list(PineconeMetadata.__annotations__.keys())},
+ )
+
+ def delete_index(self):
+ if self.index_name in pinecone.list_indexes():
+ logger.info(f"Deleting index '{self.index_name}'.")
+ pinecone.delete_index(self.index_name)
+
+ def get_embeddings_by_ids(self, ids: List[str]) -> List[Tuple[str, List[float] | None]]:
+ """
+ Fetch embeddings for given entry IDs from Pinecone.
+
+ Args:
+ - ids (List[str]): List of entry IDs for which embeddings are to be fetched.
+
+ Returns:
+ - List[Tuple[str, List[float] | None]]: List of tuples containing ID and its corresponding embedding.
+ """
+ # TODO: check that this still works
+ vectors = self.index.fetch(
+ ids=ids,
+ namespace=PINECONE_NAMESPACE,
+ )["vectors"]
+ return [(id, vectors.get(id, {}).get("values", None)) for id in ids]
+
+
+def strip_block(text: str) -> str:
+ return "\n".join(text.split("\n")[1:])
diff --git a/align_data/embeddings/pinecone/pinecone_models.py b/align_data/embeddings/pinecone/pinecone_models.py
new file mode 100644
index 00000000..fd7b67eb
--- /dev/null
+++ b/align_data/embeddings/pinecone/pinecone_models.py
@@ -0,0 +1,77 @@
+from typing import List, TypedDict
+
+from pydantic import BaseModel, validator
+from pinecone.core.client.models import Vector
+
+
+class MissingFieldsError(Exception):
+ pass
+
+
+class MissingEmbeddingModelError(Exception):
+ pass
+
+
+class PineconeMetadata(TypedDict):
+ hash_id: str
+ source: str
+ title: str
+ url: str
+ date_published: float
+ authors: List[str]
+ text: str
+
+
+class PineconeEntry(BaseModel):
+ hash_id: str
+ source: str
+ title: str
+ url: str
+ date_published: float
+ authors: List[str]
+ text_chunks: List[str]
+ embeddings: List[List[float]]
+
+ def __init__(self, **data):
+ """Check for missing (falsy) fields before initializing."""
+ missing_fields = [field for field, value in data.items() if not str(value).strip()]
+
+ if missing_fields:
+ raise MissingFieldsError(f"Missing fields: {missing_fields}")
+
+ super().__init__(**data)
+
+ def __repr__(self):
+ def make_small(chunk: str) -> str:
+ return (chunk[:45] + " [...] " + chunk[-45:]) if len(chunk) > 100 else chunk
+
+ def display_chunks(chunks_lst: List[str]) -> str:
+ chunks = ", ".join(f'"{make_small(chunk)}"' for chunk in chunks_lst)
+ return (
+ f"[{chunks[:450]} [...] {chunks[-450:]} ]" if len(chunks) > 1000 else f"[{chunks}]"
+ )
+
+ return f"PineconeEntry(hash_id={self.hash_id!r}, source={self.source!r}, title={self.title!r}, url={self.url!r}, date_published={self.date_published!r}, authors={self.authors!r}, text_chunks={display_chunks(self.text_chunks)})"
+
+ @property
+ def chunk_num(self) -> int:
+ return len(self.text_chunks)
+
+ def create_pinecone_vectors(self) -> List[Vector]:
+ return [
+ Vector(
+ id=f"{self.hash_id}_{str(i).zfill(6)}",
+ values=self.embeddings[i],
+ metadata=PineconeMetadata(
+ hash_id=self.hash_id,
+ source=self.source,
+ title=self.title,
+ authors=self.authors,
+ url=self.url,
+ date_published=self.date_published,
+ text=self.text_chunks[i],
+ ),
+ )
+ for i in range(self.chunk_num)
+ if self.embeddings[i] # Skips flagged chunks
+ ]
diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py
new file mode 100644
index 00000000..b425ee9d
--- /dev/null
+++ b/align_data/embeddings/pinecone/update_pinecone.py
@@ -0,0 +1,129 @@
+from datetime import datetime
+import logging
+from itertools import islice
+from typing import Callable, List, Tuple, Generator, Iterator, Optional
+
+from sqlalchemy.orm import Session
+from pydantic import ValidationError
+
+from align_data.embeddings.embedding_utils import get_embeddings
+from align_data.db.models import Article
+from align_data.db.session import make_session, stream_pinecone_updates
+from align_data.embeddings.pinecone.pinecone_db_handler import PineconeDB
+from align_data.embeddings.pinecone.pinecone_models import (
+ PineconeEntry, MissingFieldsError, MissingEmbeddingModelError
+)
+from align_data.embeddings.text_splitter import ParagraphSentenceUnitTextSplitter
+
+
+logger = logging.getLogger(__name__)
+
+
+# Define type aliases for the Callables
+LengthFunctionType = Callable[[str], int]
+TruncateFunctionType = Callable[[str, int], str]
+
+
+class PineconeUpdater:
+ def __init__(self):
+ self.text_splitter = ParagraphSentenceUnitTextSplitter()
+ self.pinecone_db = PineconeDB()
+
+ def update(self, custom_sources: List[str], force_update: bool = False):
+ """
+ Update the given sources. If no sources are provided, updates all sources.
+
+ :param custom_sources: List of sources to update.
+ """
+ with make_session() as session:
+ articles_to_update_stream = stream_pinecone_updates(
+ session, custom_sources, force_update
+ )
+ for batch in self.batch_entries(articles_to_update_stream):
+ self.save_batch(session, batch)
+
+ def save_batch(self, session: Session, batch: List[Tuple[Article, PineconeEntry]]):
+ try:
+ for article, pinecone_entry in batch:
+ self.pinecone_db.upsert_entry(pinecone_entry)
+
+ article.pinecone_update_required = False
+ session.add(article)
+
+ session.commit()
+
+ except Exception as e:
+ # Rollback on any kind of error. The next run will redo this batch, but in the meantime keep trucking
+ logger.error(e)
+ session.rollback()
+
+ def batch_entries(
+ self, article_stream: Generator[Article, None, None]
+ ) -> Iterator[List[Tuple[Article, PineconeEntry]]]:
+ while batch := tuple(islice(article_stream, 10)):
+ yield [
+ (article, pinecone_entry)
+ for article in batch
+ if (pinecone_entry := self._make_pinecone_entry(article)) is not None
+ ]
+
+ def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None:
+ try:
+ text_chunks = get_text_chunks(article, self.text_splitter)
+ embeddings, moderation_results = get_embeddings(text_chunks, article.source)
+
+ if any(result['flagged'] for result in moderation_results):
+ flagged_text_chunks = [f"Chunk {i}: \"{text}\"" for i, (text, result) in enumerate(zip(text_chunks, moderation_results)) if result["flagged"]]
+ logger.warning(f"OpenAI moderation flagged text chunks for the following article: {article.id}")
+ article.append_comment(f"OpenAI moderation flagged the following text chunks: {flagged_text_chunks}")
+
+ return PineconeEntry(
+ hash_id=article.id, # the hash_id of the article
+ source=article.source,
+ title=article.title,
+ url=article.url,
+ date_published=article.date_published.timestamp(),
+ authors=[author.strip() for author in article.authors.split(",") if author.strip()],
+ text_chunks=text_chunks,
+ embeddings=embeddings,
+ )
+ except (ValueError, TypeError, AttributeError, ValidationError, MissingFieldsError, MissingEmbeddingModelError) as e:
+ logger.warning(e)
+ article.append_comment(f"Error encountered while processing this article: {e}")
+ return None
+
+ except Exception as e:
+ logger.error(e)
+ raise
+
+
+def get_text_chunks(
+ article: Article, text_splitter: ParagraphSentenceUnitTextSplitter
+) -> List[str]:
+ title = article.title.replace("\n", " ")
+
+ authors_lst = [author.strip() for author in article.authors.split(",")]
+ authors = get_authors_str(authors_lst)
+
+ signature = f"Title: {title}; Author(s): {authors}."
+ text_chunks = text_splitter.split_text(article.text)
+ return [f'###{signature}###\n"""{text_chunk}"""' for text_chunk in text_chunks]
+
+
+def get_authors_str(authors_lst: List[str]) -> str:
+ if not authors_lst:
+ return "n/a"
+
+ if len(authors_lst) == 1:
+ authors_str = authors_lst[0]
+ else:
+ authors_lst = authors_lst[:4]
+ authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}"
+
+ authors_str = authors_str.replace("\n", " ")
+
+ # Truncate if necessary
+ if len(authors_str) > 500:
+ authors_str = authors_str[:497] + "..."
+
+ return authors_str
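+
+
+# For example, get_authors_str(["A. One", "B. Two", "C. Three"]) returns "A. One, B. Two and C. Three",
+# and get_text_chunks() prefixes each chunk with a signature of the form
+# ###Title: <title>; Author(s): <authors>.### followed by the chunk text in triple quotes.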
diff --git a/align_data/pinecone/text_splitter.py b/align_data/embeddings/text_splitter.py
similarity index 89%
rename from align_data/pinecone/text_splitter.py
rename to align_data/embeddings/text_splitter.py
index b8af09a3..a364415e 100644
--- a/align_data/pinecone/text_splitter.py
+++ b/align_data/embeddings/text_splitter.py
@@ -1,5 +1,3 @@
-# dataset/text_splitter.py
-
from typing import List, Callable, Any
from langchain.text_splitter import TextSplitter
from nltk.tokenize import sent_tokenize
@@ -11,8 +9,6 @@
StrToIntFunction = Callable[[str], int]
StrIntBoolToStrFunction = Callable[[str, int, bool], str]
-def default_truncate_function(string: str, length: int, from_end: bool = False) -> str:
- return string[-length:] if from_end else string[:length]
def default_truncate_function(string: str, length: int, from_end: bool = False) -> str:
return string[-length:] if from_end else string[:length]
@@ -26,7 +22,7 @@ class ParagraphSentenceUnitTextSplitter(TextSplitter):
@param length_function: A function that returns the length of a string in units. Defaults to len().
@param truncate_function: A function that truncates a string to a given unit length.
"""
-
+
DEFAULT_MIN_CHUNK_SIZE: int = 900
DEFAULT_MAX_CHUNK_SIZE: int = 1100
DEFAULT_LENGTH_FUNCTION: StrToIntFunction = len
@@ -38,7 +34,7 @@ def __init__(
max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE,
length_function: StrToIntFunction = DEFAULT_LENGTH_FUNCTION,
truncate_function: StrIntBoolToStrFunction = DEFAULT_TRUNCATE_FUNCTION,
- **kwargs: Any
+ **kwargs: Any,
):
super().__init__(**kwargs)
self.min_chunk_size = min_chunk_size
@@ -49,6 +45,9 @@ def __init__(
def split_text(self, text: str) -> List[str]:
"""Split text into chunks of length between min_chunk_size and max_chunk_size."""
+ if not text:
+ return []
+
blocks: List[str] = []
current_block: str = ""
@@ -90,13 +89,11 @@ def _handle_large_paragraph(self, current_block: str, blocks: List[str], paragra
def _truncate_large_block(self, current_block: str, blocks: List[str]) -> str:
while self._length_function(current_block) > self.max_chunk_size:
# Truncate current_block to max size, set remaining text as current_block
- truncated_block = self._truncate_function(
- current_block, self.max_chunk_size
- )
+ truncated_block = self._truncate_function(current_block, self.max_chunk_size, False)
blocks.append(truncated_block)
- current_block = current_block[len(truncated_block):].lstrip()
-
+ current_block = current_block[len(truncated_block) :].lstrip()
+
return current_block
def _handle_remaining_text(self, last_block: str, blocks: List[str]) -> List[str]:
@@ -107,9 +104,7 @@ def _handle_remaining_text(self, last_block: str, blocks: List[str]) -> List[str
if self.min_chunk_size - len_last_block > 0:
# Add text from previous block to last block if last_block is too short
part_prev_block = self._truncate_function(
- string=blocks[-1],
- length=self.min_chunk_size - len_last_block,
- from_end=True
+ blocks[-1], self.min_chunk_size - len_last_block, True
)
last_block = part_prev_block + last_block
diff --git a/align_data/pinecone/pinecone_db_handler.py b/align_data/pinecone/pinecone_db_handler.py
deleted file mode 100644
index d8f565df..00000000
--- a/align_data/pinecone/pinecone_db_handler.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# dataset/pinecone_db_handler.py
-
-import logging
-from typing import Dict
-
-import pinecone
-
-from align_data.settings import (
- PINECONE_INDEX_NAME,
- PINECONE_VALUES_DIMS,
- PINECONE_METRIC,
- PINECONE_METADATA_KEYS,
- PINECONE_API_KEY,
- PINECONE_ENVIRONMENT,
-)
-
-
-logger = logging.getLogger(__name__)
-
-
-class PineconeDB:
- def __init__(
- self,
- index_name: str = PINECONE_INDEX_NAME,
- values_dims: int = PINECONE_VALUES_DIMS,
- metric: str = PINECONE_METRIC,
- metadata_keys: list = PINECONE_METADATA_KEYS,
- create_index: bool = False,
- log_index_stats: bool = True,
- ):
- self.index_name = index_name
- self.values_dims = values_dims
- self.metric = metric
- self.metadata_keys = metadata_keys
-
- pinecone.init(
- api_key=PINECONE_API_KEY,
- environment=PINECONE_ENVIRONMENT,
- )
-
- if create_index:
- self.create_index()
-
- self.index = pinecone.Index(index_name=self.index_name)
-
- if log_index_stats:
- index_stats_response = self.index.describe_index_stats()
- logger.info(f"{self.index_name}:\n{index_stats_response}")
-
- def upsert_entry(self, entry: Dict, upsert_size=100):
- self.index.upsert(
- vectors=list(
- zip(
- [
- f"{entry['id']}_{str(i).zfill(6)}"
- for i in range(len(entry["text_chunks"]))
- ],
- entry["embeddings"].tolist(),
- [
- {
- "entry_id": entry["id"],
- "source": entry["source"],
- "title": entry["title"],
- "authors": entry["authors"],
- "text": text_chunk,
- }
- for text_chunk in entry["text_chunks"]
- ],
- )
- ),
- batch_size=upsert_size,
- )
-
- def delete_entries(self, ids):
- self.index.delete(filter={"entry_id": {"$in": ids}})
-
- def create_index(self, replace_current_index: bool = True):
- if replace_current_index:
- self.delete_index()
-
- pinecone.create_index(
- name=self.index_name,
- dimension=self.values_dims,
- metric=self.metric,
- metadata_config={"indexed": self.metadata_keys},
- )
-
- def delete_index(self):
- if self.index_name in pinecone.list_indexes():
- logger.info(f"Deleting index '{self.index_name}'.")
- pinecone.delete_index(self.index_name)
diff --git a/align_data/pinecone/update_pinecone.py b/align_data/pinecone/update_pinecone.py
deleted file mode 100644
index 5821a276..00000000
--- a/align_data/pinecone/update_pinecone.py
+++ /dev/null
@@ -1,192 +0,0 @@
-from datetime import datetime
-import logging
-import numpy as np
-import os
-from itertools import islice
-from typing import Callable, List, Tuple, Generator
-
-import openai
-from pydantic import BaseModel, ValidationError, validator
-
-from align_data.db.models import Article
-from align_data.db.session import make_session, stream_pinecone_updates
-from align_data.pinecone.pinecone_db_handler import PineconeDB
-from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter
-from align_data.settings import (
- USE_OPENAI_EMBEDDINGS,
- OPENAI_EMBEDDINGS_MODEL,
- OPENAI_EMBEDDINGS_DIMS,
- OPENAI_EMBEDDINGS_RATE_LIMIT,
- SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
- SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS,
- CHUNK_SIZE,
- MAX_NUM_AUTHORS_IN_SIGNATURE,
- EMBEDDING_LENGTH_BIAS,
-)
-
-logger = logging.getLogger(__name__)
-
-
-# Define type aliases for the Callables
-LengthFunctionType = Callable[[str], int]
-TruncateFunctionType = Callable[[str, int], str]
-
-
-class PineconeEntry(BaseModel):
- id: str
- source: str
- title: str
- url: str
- date_published: datetime
- authors: List[str]
- text_chunks: List[str]
- embeddings: np.ndarray
-
- class Config:
- arbitrary_types_allowed = True
-
- def __repr__(self):
- return f"PineconeEntry(id={self.id!r}, source={self.source!r}, title={self.title!r}, url={self.url!r}, date_published={self.date_published!r}, authors={self.authors!r}, text_chunks={self.text_chunks[:5]!r})"
-
- @validator(
- "id",
- "source",
- "title",
- "url",
- "date_published",
- "authors",
- "text_chunks",
- pre=True,
- always=True,
- )
- def empty_strings_not_allowed(cls, value):
- if not str(value).strip():
- raise ValueError("Attribute should not be empty.")
- return value
-
-
-class PineconeUpdater:
- def __init__(
- self,
- min_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MIN_CHUNK_SIZE,
- max_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MAX_CHUNK_SIZE,
- length_function: LengthFunctionType = ParagraphSentenceUnitTextSplitter.DEFAULT_LENGTH_FUNCTION,
- truncate_function: TruncateFunctionType = ParagraphSentenceUnitTextSplitter.DEFAULT_TRUNCATE_FUNCTION,
- ):
- self.min_chunk_size = min_chunk_size
- self.max_chunk_size = max_chunk_size
- self.length_function = length_function
- self.truncate_function = truncate_function
-
- self.text_splitter = ParagraphSentenceUnitTextSplitter(
- min_chunk_size=self.min_chunk_size,
- max_chunk_size=self.max_chunk_size,
- length_function=self.length_function,
- truncate_function=self.truncate_function,
- )
- self.pinecone_db = PineconeDB()
-
- if USE_OPENAI_EMBEDDINGS:
- import openai
-
- openai.api_key = os.environ["OPENAI_API_KEY"]
- else:
- import torch
- from langchain.embeddings import HuggingFaceEmbeddings
-
- self.hf_embeddings = HuggingFaceEmbeddings(
- model_name=SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
- model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
- encode_kwargs={"show_progress_bar": False},
- )
-
- def save_batch(self, session, batch):
- try:
- for article, pinecone_entry in batch:
- self.pinecone_db.upsert_entry(pinecone_entry.dict())
- article.pinecone_update_required = False
- session.add(article)
- session.commit()
- except Exception as e:
- # Rollback on any kind of error. The next run will redo this batch, but in the meantime keep trucking
- logger.error(e)
- session.rollback()
-
- def update(self, custom_sources: List[str]):
- """
- Update the given sources. If no sources are provided, updates all sources.
-
- :param custom_sources: List of sources to update.
- """
- with make_session() as session:
- entries_stream = stream_pinecone_updates(session, custom_sources)
- for batch in self.batch_entries(entries_stream):
- self.save_batch(session, batch)
-
- def _make_pinecone_update(self, article):
- try:
- text_chunks = self.get_text_chunks(article)
- return article, PineconeEntry(
- id=article.id,
- source=article.source,
- title=article.title,
- url=article.url,
- date_published=article.date_published,
- authors=[
- author.strip()
- for author in article.authors.split(",")
- if author.strip()
- ],
- text_chunks=text_chunks,
- embeddings=self.extract_embeddings(
- text_chunks, [article.source] * len(text_chunks)
- ),
- )
- except (ValueError, ValidationError) as e:
- logger.exception(e)
-
- def batch_entries(
- self, article_stream: Generator[Article, None, None]
- ) -> Generator[List[Tuple[Article, PineconeEntry]], None, None]:
- items = iter(article_stream)
- while batch := tuple(islice(items, 10)):
- yield list(filter(None, map(self._make_pinecone_update, batch)))
-
- def get_text_chunks(self, article: Article) -> List[str]:
- signature = f"Title: {article.title}, Author(s): {self.get_authors_str(article.authors)}"
- text_chunks = self.text_splitter.split_text(article.text)
- text_chunks = [f"- {signature}\n\n{text_chunk}" for text_chunk in text_chunks]
- return text_chunks
-
- def extract_embeddings(self, chunks_batch, sources_batch):
- if USE_OPENAI_EMBEDDINGS:
- return self.get_openai_embeddings(chunks_batch, sources_batch)
- else:
- return np.array(
- self.hf_embeddings.embed_documents(chunks_batch, sources_batch)
- )
-
- @staticmethod
- def get_openai_embeddings(chunks, sources=""):
- embeddings = np.zeros((len(chunks), OPENAI_EMBEDDINGS_DIMS))
-
- openai_output = openai.Embedding.create(
- model=OPENAI_EMBEDDINGS_MODEL, input=chunks
- )["data"]
-
- for i, (embedding, source) in enumerate(zip(openai_output, sources)):
- bias = EMBEDDING_LENGTH_BIAS.get(source, 1.0)
- embeddings[i] = bias * np.array(embedding["embedding"])
-
- return embeddings
-
- @staticmethod
- def get_authors_str(authors_lst: List[str]) -> str:
- if authors_lst == []:
- return "n/a"
- if len(authors_lst) == 1:
- return authors_lst[0]
- else:
- authors_lst = authors_lst[:MAX_NUM_AUTHORS_IN_SIGNATURE]
- authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}"
- return authors_str
diff --git a/align_data/settings.py b/align_data/settings.py
index d88bcca3..59304c08 100644
--- a/align_data/settings.py
+++ b/align_data/settings.py
@@ -1,5 +1,8 @@
import os
import logging
+from typing import Dict
+import openai
+import torch
from dotenv import load_dotenv
load_dotenv()
@@ -24,9 +27,12 @@
"METADATA_OUTPUT_SPREADSHEET", "1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4"
)
-### YouTube ###
+### YOUTUBE ###
YOUTUBE_API_KEY = os.environ.get("YOUTUBE_API_KEY")
+### Airtable ###
+AIRTABLE_API_KEY = os.environ.get("AIRTABLE_API_KEY")
+
### MYSQL ###
user = os.environ.get("ARD_DB_USER", "user")
password = os.environ.get("ARD_DB_PASSWORD", "we all live in a yellow submarine")
@@ -38,13 +44,16 @@
### EMBEDDINGS ###
USE_OPENAI_EMBEDDINGS = True # If false, SentenceTransformer embeddings will be used.
-EMBEDDING_LENGTH_BIAS = {
- "aisafety.info": 1.05, # In search, favor AISafety.info entries.
+EMBEDDING_LENGTH_BIAS: Dict[str, float] = {
+ # TODO: Experiment with these values. For now, let's remove the bias.
+ # "aisafety.info": 1.05, # In search, favor AISafety.info entries.
}
OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002"
OPENAI_EMBEDDINGS_DIMS = 1536
OPENAI_EMBEDDINGS_RATE_LIMIT = 3500
+openai.api_key = os.environ.get("OPENAI_API_KEY", None)
+openai.organization = os.environ.get("OPENAI_ORGANIZATION", None)
SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1"
SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768
@@ -54,13 +63,20 @@
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
PINECONE_VALUES_DIMS = (
- OPENAI_EMBEDDINGS_DIMS
- if USE_OPENAI_EMBEDDINGS
- else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS
+ OPENAI_EMBEDDINGS_DIMS if USE_OPENAI_EMBEDDINGS else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS
)
PINECONE_METRIC = "dotproduct"
-PINECONE_METADATA_KEYS = ["entry_id", "source", "title", "authors", "text", "url"]
+PINECONE_NAMESPACE = os.environ.get("PINECONE_NAMESPACE", "normal") # "normal" or "finetuned"
+
+### FINE-TUNING ###
+OPENAI_FINETUNED_LAYER_PATH = os.environ.get(
+ "OPENAI_FINETUNED_LAYER_PATH", "align_data/finetuning/data/finetuned_model.pth"
+)
+OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH = os.environ.get(
+ "OPENAI_CURRENT_BEST_FINETUNED_LAYER_PATH",
+ "align_data/finetuning/data/best_finetuned_model.pth",
+)
### MISCELLANEOUS ###
-CHUNK_SIZE = 1750
-MAX_NUM_AUTHORS_IN_SIGNATURE = 3
+MIN_CONFIDENCE = 50
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/align_data/sources/agisf/__init__.py b/align_data/sources/agisf/__init__.py
new file mode 100644
index 00000000..0fbf757a
--- /dev/null
+++ b/align_data/sources/agisf/__init__.py
@@ -0,0 +1,36 @@
+from align_data.common.alignment_dataset import MultiDataset
+from align_data.sources.airtable import AirtableDataset
+from align_data.sources.agisf.agisf import AGISFPodcastDataset
+
+
+datasets = [
+ AirtableDataset(
+ name='agisf_governance',
+ base_id='app9q0E0jlDWlsR0z',
+ table_id='tblgTb3kszvSbo2Mb',
+ mappings={
+ 'title': '[>] Resource',
+ 'url': '[h] [>] Link',
+ 'source_type': '[h] [>] Type',
+ 'summary': '[h] Resource guide',
+ 'authors': 'Author(s) (from Resources)',
+ },
+ processors={
+ 'source_type': lambda val: val[0] if val else None,
+ 'authors': lambda val: val and [v.strip() for v in val.split(',')],
+ }
+ ),
+ AGISFPodcastDataset(
+ name='agisf_readings_alignment',
+ url='https://feeds.type3.audio/agi-safety-fundamentals--alignment.rss',
+ ),
+ AGISFPodcastDataset(
+ name='agisf_readings_governance',
+ url='https://feeds.type3.audio/agi-safety-fundamentals--governance.rss',
+ ),
+]
+
+
+AGISF_DATASETS = [
+ MultiDataset(name='agisf', datasets=datasets),
+]
diff --git a/align_data/sources/agisf/agisf.py b/align_data/sources/agisf/agisf.py
new file mode 100644
index 00000000..e56a60e4
--- /dev/null
+++ b/align_data/sources/agisf/agisf.py
@@ -0,0 +1,54 @@
+import re
+from typing import Any, Dict
+from bs4 import BeautifulSoup
+
+from align_data.common.html_dataset import RSSDataset
+from align_data.sources.articles.parsers import item_metadata
+from align_data.sources.utils import merge_dicts
+
+
+class AGISFPodcastDataset(RSSDataset):
+
+ regex = re.compile(r'^\[Week .*?\]\s+“(?P<title>.*?)”\s+by\s+(?P<authors>.*?)$')
+
+ @property
+ def feed_url(self):
+ return self.url
+
+ def fetch_contents(self, url: str) -> Dict[str, Any]:
+ contents = super().fetch_contents(url)
+ if extracted := self.regex.search(contents.get('title')):
+ return merge_dicts(contents, extracted.groupdict())
+ return contents
+
+ def _get_text(self, item):
+ contents = item_metadata(item['link'])
+ # Copy over any non-empty values from contents - `item.update()` would happily insert Nones
+ for k, v in contents.items():
+ if v:
+ item[k] = v
+ return item.get('text')
+
+ def extract_authors(self, item):
+ authors = item.get("authors")
+ if not authors:
+ return self.authors
+ if isinstance(authors, str):
+ return [a.strip() for a in authors.split(',')]
+ return authors
+
+ def _extra_values(self, contents):
+ if summary := contents.get('summary'):
+ soup = BeautifulSoup(summary, "html.parser")
+ for el in soup.select('b'):
+ if el.next_sibling:
+ el.next_sibling.extract()
+ el.extract()
+ return {'summary': self._extract_markdown(soup)}
+ return {}
+
+ def process_entry(self, article):
+ article_url = self.get_item_key(article)
+ contents = self.get_contents(article_url)
+
+ return self.make_data_entry(contents)
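The feed titles this regex targets look roughly like the made-up example below. The named groups (title and authors here) are an assumption, inferred from how groupdict() is merged into the RSS contents and later read by extract_authors:

    import re

    regex = re.compile(r'^\[Week .*?\]\s+“(?P<title>.*?)”\s+by\s+(?P<authors>.*?)$')
    feed_title = '[Week 1] “Some reading title” by Jane Doe, John Smith'
    print(regex.search(feed_title).groupdict())
    # {'title': 'Some reading title', 'authors': 'Jane Doe, John Smith'}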
diff --git a/align_data/sources/airtable.py b/align_data/sources/airtable.py
new file mode 100644
index 00000000..db24350e
--- /dev/null
+++ b/align_data/sources/airtable.py
@@ -0,0 +1,53 @@
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Iterable, Optional, Union
+
+from airtable import airtable
+from align_data.db.models import Article
+
+from align_data.settings import AIRTABLE_API_KEY, ARTICLE_MAIN_KEYS
+from align_data.common.alignment_dataset import AlignmentDataset
+from align_data.sources.articles.parsers import item_metadata
+from align_data.sources.utils import merge_dicts
+
+
+@dataclass
+class AirtableDataset(AlignmentDataset):
+
+ base_id: str
+ table_id: str
+ mappings: Dict[str, str]
+ processors: Dict[str, Callable[[Any], str]]
+ done_key = 'url'
+
+ def setup(self):
+ if not AIRTABLE_API_KEY:
+ raise ValueError('No AIRTABLE_API_KEY provided!')
+ super().setup()
+ self.at = airtable.Airtable(self.base_id, AIRTABLE_API_KEY)
+
+ def map_cols(self, item: Dict[str, Dict[str, str]]) -> Optional[Dict[str, Optional[str]]]:
+ fields = item.get('fields', {})
+ def map_col(k):
+ val = fields.get(self.mappings.get(k) or k)
+ if processor := self.processors.get(k):
+ val = processor(val)
+ return val
+
+ mapped = {k: map_col(k) for k in ARTICLE_MAIN_KEYS + ['summary']}
+ if (mapped.get('url') or '').startswith('http'):
+ return mapped
+
+ def get_item_key(self, item):
+ return item.get('url')
+
+ @property
+ def items_list(self) -> Iterable[Dict[str, Union[str, None]]]:
+ return filter(None, map(self.map_cols, self.at.iterate(self.table_id)))
+
+ def process_entry(self, entry) -> Optional[Article]:
+ contents = item_metadata(self.get_item_key(entry))
+ if not contents:
+ return None
+
+ entry['date_published'] = self._get_published_date(entry.get('date_published'))
+ return self.make_data_entry(merge_dicts(entry, contents), source=self.name)
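For context, a sketch of what map_cols does with the AGISF governance mappings defined earlier in this patch. The Airtable record is invented, and the exact set of output keys depends on ARTICLE_MAIN_KEYS:

    from align_data.sources.airtable import AirtableDataset

    dataset = AirtableDataset(
        name='agisf_governance',
        base_id='app9q0E0jlDWlsR0z',
        table_id='tblgTb3kszvSbo2Mb',
        mappings={
            'title': '[>] Resource',
            'url': '[h] [>] Link',
            'source_type': '[h] [>] Type',
            'summary': '[h] Resource guide',
            'authors': 'Author(s) (from Resources)',
        },
        processors={
            'source_type': lambda val: val[0] if val else None,
            'authors': lambda val: val and [v.strip() for v in val.split(',')],
        },
    )
    record = {'fields': {
        '[>] Resource': 'Some reading',
        '[h] [>] Link': 'https://example.org/reading',
        '[h] [>] Type': ['Blog post'],
        'Author(s) (from Resources)': 'Jane Doe, John Smith',
    }}
    print(dataset.map_cols(record))
    # Roughly: {'title': 'Some reading', 'url': 'https://example.org/reading',
    #           'source_type': 'Blog post', 'authors': ['Jane Doe', 'John Smith'], ...}
    # Rows whose mapped url does not start with "http" come back as None and are
    # filtered out of items_list before any Airtable row is processed.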
diff --git a/align_data/sources/articles/__init__.py b/align_data/sources/articles/__init__.py
index da7f3a6b..6fd45fbc 100644
--- a/align_data/sources/articles/__init__.py
+++ b/align_data/sources/articles/__init__.py
@@ -1,10 +1,18 @@
from align_data.sources.articles.datasets import (
- ArxivPapers, EbookArticles, DocArticles, HTMLArticles,
- MarkdownArticles, PDFArticles, SpecialDocs, XMLArticles
+ ArxivPapers,
+ EbookArticles,
+ DocArticles,
+ HTMLArticles,
+ MarkdownArticles,
+ PDFArticles,
+ SpecialDocs,
+ XMLArticles,
)
from align_data.sources.articles.indices import IndicesDataset
+from align_data.common.alignment_dataset import MultiDataset
-ARTICLES_REGISTRY = [
+
+ARTICLES_DATASETS = [
PDFArticles(
name="pdfs",
spreadsheet_id="1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4",
@@ -36,14 +44,19 @@
sheet_id="1293295703",
),
SpecialDocs(
- 'special_docs',
- spreadsheet_id='1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI',
- sheet_id='980957638',
+ "special_docs",
+ spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI",
+ sheet_id="980957638",
),
+]
+
+
+ARTICLES_REGISTRY = [
+ MultiDataset(name='special_docs', datasets=ARTICLES_DATASETS),
ArxivPapers(
name="arxiv",
spreadsheet_id="1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI",
sheet_id="655836697",
),
- IndicesDataset('indices'),
+ IndicesDataset("indices"),
]
diff --git a/align_data/sources/articles/articles.py b/align_data/sources/articles/articles.py
index 3a210ead..a6a343f6 100644
--- a/align_data/sources/articles/articles.py
+++ b/align_data/sources/articles/articles.py
@@ -146,8 +146,7 @@ def check_new_articles(source_spreadsheet_id: str, source_sheet_name: str):
missing = [
item
for title, item in indices_items.items()
- if title not in current
- and not {item.get("url"), item.get("source_url")} & seen_urls
+ if title not in current and not {item.get("url"), item.get("source_url")} & seen_urls
]
if not missing:
logger.info("No new articles found")
@@ -163,14 +162,12 @@ def check_new_articles(source_spreadsheet_id: str, source_sheet_name: str):
"publication_title",
"source_type",
]
- res = source_sheet.append_rows(
- [[item.get(col) for col in columns] for item in missing]
- )
+ res = source_sheet.append_rows([[item.get(col) for col in columns] for item in missing])
updated = res["updates"]["updatedRows"]
logger.info("Added %s rows", updated)
return updated
def update_articles(csv_file, delimiter):
- dataset = ReplacerDataset(name='updater', csv_path=csv_file, delimiter=delimiter)
+ dataset = ReplacerDataset(name="updater", csv_path=csv_file, delimiter=delimiter)
dataset.add_entries(dataset.fetch_entries())
diff --git a/align_data/sources/articles/datasets.py b/align_data/sources/articles/datasets.py
index 95d18e17..d0fae22a 100644
--- a/align_data/sources/articles/datasets.py
+++ b/align_data/sources/articles/datasets.py
@@ -15,10 +15,16 @@
from align_data.db.models import Article
from align_data.sources.articles.google_cloud import fetch_file, fetch_markdown
from align_data.sources.articles.parsers import (
- HTML_PARSERS, extract_gdrive_contents, item_metadata, parse_domain
+ HTML_PARSERS,
+ extract_gdrive_contents,
+ item_metadata,
+ parse_domain,
)
from align_data.sources.articles.pdf import read_pdf
-from align_data.sources.arxiv_papers import fetch_arxiv, canonical_url as arxiv_canonical_url
+from align_data.sources.arxiv_papers import (
+ fetch_arxiv,
+ canonical_url as arxiv_canonical_url,
+)
logger = logging.getLogger(__name__)
@@ -43,8 +49,8 @@ def get_item_key(self, item):
@property
def items_list(self):
- url = f'https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}'
- logger.info(f'Fetching {url}')
+ url = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/export?format=csv&gid={self.sheet_id}"
+ logger.info(f"Fetching {url}")
df = pd.read_csv(url)
return (
item
@@ -88,7 +94,6 @@ def process_entry(self, item):
class SpecialDocs(SpreadsheetDataset):
-
@property
def _query_items(self) -> Select[Tuple[Article]]:
special_docs_types = ["pdf", "html", "xml", "markdown", "docx"]
@@ -99,35 +104,39 @@ def get_contents(self, item) -> Dict:
if url := self.maybe(item, "source_url") or self.maybe(item, "url"):
contents = item_metadata(url)
- return dict(contents, **{
- 'url': self.maybe(item, "url"),
- 'title': self.maybe(item, "title") or contents.get('title'),
- 'source': contents.get('source_type') or self.name,
- 'source_url': self.maybe(item, "source_url"),
- 'source_type': contents.get('source_type') or self.maybe(item, "source_type"),
- 'date_published': self._get_published_date(self.maybe(item, 'date_published')) or contents.get('date_published'),
- 'authors': self.extract_authors(item) or contents.get('authors', []),
- 'text': contents.get('text'),
- 'status': 'Invalid' if contents.get('error') else None,
- 'comments': contents.get('error'),
- })
+ return dict(
+ contents,
+ **{
+ "url": self.maybe(item, "url"),
+ "title": self.maybe(item, "title") or contents.get("title"),
+ "source": contents.get("source_type") or self.name,
+ "source_url": self.maybe(item, "source_url"),
+ "source_type": contents.get("source_type") or self.maybe(item, "source_type"),
+ "date_published": self._get_published_date(self.maybe(item, "date_published"))
+ or contents.get("date_published"),
+ "authors": self.extract_authors(item) or contents.get("authors", []),
+ "text": contents.get("text"),
+ "status": "Invalid" if contents.get("error") else None,
+ "comments": contents.get("error"),
+ },
+ )
def not_processed(self, item):
- url = self.maybe(item, 'url')
- source_url = self.maybe(item, 'source_url')
+ url = self.maybe(item, "url")
+ source_url = self.maybe(item, "source_url")
return (
- self.get_item_key(item) not in self._outputted_items and
- url not in self._outputted_items and
- source_url not in self._outputted_items and
- (not url or arxiv_canonical_url(url) not in self._outputted_items) and
- (not source_url or arxiv_canonical_url(source_url) not in self._outputted_items)
+ self.get_item_key(item) not in self._outputted_items
+ and url not in self._outputted_items
+ and source_url not in self._outputted_items
+ and (not url or arxiv_canonical_url(url) not in self._outputted_items)
+ and (not source_url or arxiv_canonical_url(source_url) not in self._outputted_items)
)
def process_entry(self, item):
if ArxivPapers.is_arxiv(item.url):
contents = ArxivPapers.get_contents(item)
- contents['source'] = 'arxiv'
+ contents["source"] = "arxiv"
else:
contents = self.get_contents(item)
@@ -156,7 +165,7 @@ def _get_text(item):
domain = parse_domain(item.source_url)
if parser := HTML_PARSERS.get(domain):
res = parser(item.source_url)
- return res and res.get('text')
+ return res and res.get("text")
class EbookArticles(SpreadsheetDataset):
@@ -170,9 +179,7 @@ def setup(self):
def _get_text(self, item):
file_id = item.source_url.split("/")[-2]
- filename = download(
- output=str(self.files_path / f"{item.title}.epub"), id=file_id
- )
+ filename = download(output=str(self.files_path / f"{item.title}.epub"), id=file_id)
return convert_file(filename, "plain", "epub", extra_args=["--wrap=none"])
@@ -181,7 +188,7 @@ class XMLArticles(SpreadsheetDataset):
def _get_text(self, item):
vals = extract_gdrive_contents(item.source_url)
- return vals["text"]
+ return vals.get("text")
class MarkdownArticles(SpreadsheetDataset):
@@ -190,7 +197,7 @@ class MarkdownArticles(SpreadsheetDataset):
def _get_text(self, item):
file_id = item.source_url.split("/")[-2]
vals = fetch_markdown(file_id)
- return vals["text"]
+ return vals.get("text")
class DocArticles(SpreadsheetDataset):
@@ -223,12 +230,12 @@ def get_contents(cls, item) -> Dict:
contents = fetch_arxiv(item.url or item.source_url)
if cls.maybe(item, "authors") and item.authors.strip():
- contents['authors'] = [i.strip() for i in item.authors.split(',')]
+ contents["authors"] = [i.strip() for i in item.authors.split(",")]
if cls.maybe(item, "title"):
- contents['title'] = cls.maybe(item, "title")
+ contents["title"] = cls.maybe(item, "title")
- contents['date_published'] = cls._get_published_date(
- cls.maybe(item, "date_published") or contents.get('date_published')
+ contents["date_published"] = cls._get_published_date(
+ cls.maybe(item, "date_published") or contents.get("date_published")
)
return contents
diff --git a/align_data/sources/articles/google_cloud.py b/align_data/sources/articles/google_cloud.py
index 23385b7b..ee6f2ddc 100644
--- a/align_data/sources/articles/google_cloud.py
+++ b/align_data/sources/articles/google_cloud.py
@@ -158,19 +158,19 @@ def fetch_markdown(file_id: str) -> Dict[str, str]:
"source_type": "markdown",
}
except Exception as e:
- return {'error': str(e)}
+ return {"error": str(e)}
def parse_grobid(contents: str | bytes) -> Dict[str, Any]:
if isinstance(contents, bytes):
contents = contents.decode('utf-8')
doc_dict = grobid_tei_xml.parse_document_xml(contents).to_dict()
- authors: List[str] = [author["full_name"].strip(' !') for author in doc_dict.get("header", {}).get("authors", [])]
+ authors: List[str] = [author["full_name"].strip(" !") for author in doc_dict.get("header", {}).get("authors", [])]
- if not doc_dict.get('body'):
+ if not doc_dict.get("body"):
return {
- 'error': 'No contents in XML file',
- 'source_type': 'xml',
+ "error": "No contents in XML file",
+ "source_type": "xml",
}
return {
@@ -183,21 +183,21 @@ def parse_grobid(contents: str | bytes) -> Dict[str, Any]:
def get_content_type(res: requests.Response) -> Set[str]:
- header = res.headers.get('Content-Type') or ''
- parts = [c_type.strip().lower() for c_type in header.split(';')]
+ header = res.headers.get("Content-Type") or ""
+ parts = [c_type.strip().lower() for c_type in header.split(";")]
return set(filter(None, parts))
def extract_gdrive_contents(link: str) -> Dict[str, Any]:
- file_id = link.split('/')[-2]
- url = f'https://drive.google.com/uc?id={file_id}'
- res = fetch(url, 'head')
+ file_id = link.split("/")[-2]
+ url = f"https://drive.google.com/uc?id={file_id}"
+ res = fetch(url, "head")
if res.status_code == 403:
- logger.error('Could not fetch the file at %s - 403 returned', link)
- return {'error': 'Could not read file from google drive - forbidden'}
+ logger.error("Could not fetch the file at %s - 403 returned", link)
+ return {"error": "Could not read file from google drive - forbidden"}
if res.status_code >= 400:
- logger.error('Could not fetch the file at %s - are you sure that link is correct?', link)
- return {'error': 'Could not read file from google drive'}
+ logger.error("Could not fetch the file at %s - are you sure that link is correct?", link)
+ return {"error": "Could not read file from google drive"}
result: Dict[str, Any] = {
'source_url': link,
@@ -206,16 +206,16 @@ def extract_gdrive_contents(link: str) -> Dict[str, Any]:
content_type = get_content_type(res)
if not content_type:
- result['error'] = 'no content type'
- elif content_type & {'application/octet-stream', 'application/pdf'}:
+ result["error"] = "no content type"
+ elif content_type & {"application/octet-stream", "application/pdf"}:
result.update(fetch_pdf(url))
- elif content_type & {'text/markdown'}:
+ elif content_type & {"text/markdown"}:
result.update(fetch_markdown(file_id))
- elif content_type & {'application/epub+zip', 'application/epub'}:
- result['source_type'] = 'ebook'
- elif content_type & {'text/html'}:
+ elif content_type & {"application/epub+zip", "application/epub"}:
+ result["source_type"] = "ebook"
+ elif content_type & {"text/html"}:
res = fetch(url)
- if 'Google Drive - Virus scan warning' in res.text:
+ if "Google Drive - Virus scan warning" in res.text:
soup = BeautifulSoup(res.content, "html.parser")
form_tag = soup.select_one('form')
@@ -229,30 +229,35 @@ def extract_gdrive_contents(link: str) -> Dict[str, Any]:
res = fetch(form_action_url)
content_type = get_content_type(res)
- if content_type & {'text/xml'}:
+ if content_type & {"text/xml"}:
result.update(parse_grobid(res.content))
- elif content_type & {'text/html'}:
+ elif content_type & {"text/html"}:
soup = BeautifulSoup(res.content, "html.parser")
- result.update({
- 'text': MarkdownConverter().convert_soup(soup.select_one('body')).strip(),
- 'source_type': 'html',
- })
+ result.update(
+ {
+ "text": MarkdownConverter().convert_soup(soup.select_one("body")).strip(),
+ "source_type": "html",
+ }
+ )
else:
- result['error'] = f'unknown content type: {content_type}'
+ result["error"] = f"unknown content type: {content_type}"
else:
- result['error'] = f'unknown content type: {content_type}'
+ result["error"] = f"unknown content type: {content_type}"
return result
def google_doc(url: str) -> Dict[str, Any]:
"""Fetch the contents of the given gdoc url as markdown."""
- res = re.search(r'https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/', url)
+ res = re.search(r"https://docs.google.com/document/(?:u/)?(?:0/)?d/(.*?)/", url)
if not res:
return {}
doc_id = res.group(1)
- body = fetch_element(f'https://docs.google.com/document/d/{doc_id}/export?format=html', 'body')
+ body = fetch_element(f"https://docs.google.com/document/d/{doc_id}/export?format=html", "body")
if body:
- return {'text': MarkdownConverter().convert_soup(body).strip(), 'source_url': url}
+ return {
+ "text": MarkdownConverter().convert_soup(body).strip(),
+ "source_url": url,
+ }
return {}
diff --git a/align_data/sources/articles/html.py b/align_data/sources/articles/html.py
index 9365bde5..9c1c15a1 100644
--- a/align_data/sources/articles/html.py
+++ b/align_data/sources/articles/html.py
@@ -75,9 +75,9 @@ def getter(url: str) -> Dict[str, Any]:
for e in elem.select(sel):
e.extract()
return {
- 'text': MarkdownConverter().convert_soup(elem).strip(),
- 'source_url': url,
- 'source_type': 'html',
+ "text": MarkdownConverter().convert_soup(elem).strip(),
+ "source_url": url,
+ "source_type": "html",
}
return getter
diff --git a/align_data/sources/articles/indices.py b/align_data/sources/articles/indices.py
index a673257d..b9cc1621 100644
--- a/align_data/sources/articles/indices.py
+++ b/align_data/sources/articles/indices.py
@@ -85,16 +85,6 @@ def format_anthropic(post):
}
-def format_transformer_circuits(item):
- if not item.get("href").startswith("http"):
- url = f'https://transformer-circuits.pub/{item.get("href")}'
- return {
- "title": get_text(item, "h3"),
- "url": url,
- "source_url": url,
- }
-
-
def format_safe_ai(item):
return {
"title": get_text(item, "h4"),
@@ -209,12 +199,6 @@ def fetch_all():
"a",
format_anthropic,
),
- indice_fetcher(
- "https://transformer-circuits.pub/",
- "div.toc",
- "a",
- format_transformer_circuits,
- ),
indice_fetcher(
"https://www.safe.ai/research",
"#guiding-principles",
@@ -264,48 +248,58 @@ def fetch_all():
class IndicesDataset(AlignmentDataset):
-
- done_key = 'url'
+ done_key = "url"
@property
def items_list(self):
return fetch_all().values()
def get_item_key(self, item):
- return item.get('url')
+ return item.get("url")
@staticmethod
def extract_authors(item):
- if authors := (item.get('authors') or '').strip():
- return [author.strip() for author in authors.split(',') if author.strip()]
+ if authors := (item.get("authors") or "").strip():
+ return [author.strip() for author in authors.split(",") if author.strip()]
return []
def process_entry(self, item):
contents = {}
- if url := item.get('source_url') or item.get('url'):
- contents= item_metadata(url)
-
- if not contents.get('text'):
- logger.error('Could not get text for %s (%s) - %s - skipping for now', item.get('title'), url, contents.get('error'))
+ url = item.get("source_url") or item.get("url")
+ if url:
+ contents = item_metadata(url)
+
+ if not contents.get("text"):
+ logger.error(
+ "Could not get text for %s (%s) - %s - skipping for now",
+ item.get("title"),
+ url,
+ contents.get("error"),
+ )
return None
# If the article is not an arxiv paper, just mark it as ignored - if in the future editors
# decide it's worth adding, it can be fetched then
- if parse_domain(contents.get('source_url') or '') != 'arxiv.org':
- return self.make_data_entry({
- 'source': self.name,
- 'url': self.get_item_key(item),
- 'title': item.get('title'),
- 'date_published': self._get_published_date(item.get('date_published')),
- 'authors': self.extract_authors(item),
- 'status': 'Ignored',
- 'comments': 'Added from indices',
- })
-
- return self.make_data_entry({
- 'source': self.name,
- 'url': self.get_item_key(item),
- 'title': item.get('title'),
- 'date_published': self._get_published_date(item.get('date_published')),
- 'authors': self.extract_authors(item),
- }, **contents)
+ if parse_domain(url or "") != "arxiv.org":
+ return self.make_data_entry(
+ {
+ "source": self.name,
+ "url": self.get_item_key(item),
+ "title": item.get("title"),
+ "date_published": self._get_published_date(item.get("date_published")),
+ "authors": self.extract_authors(item),
+ "status": "Ignored",
+ "comments": "Added from indices",
+ }
+ )
+
+ return self.make_data_entry(
+ {
+ "source": "arxiv",
+ "url": contents.get("url") or self.get_item_key(item),
+ "title": item.get("title"),
+ "date_published": self._get_published_date(item.get("date_published")),
+ "authors": self.extract_authors(item),
+ },
+ **contents,
+ )
diff --git a/align_data/sources/articles/parsers.py b/align_data/sources/articles/parsers.py
index c5fe325a..5dc1da3d 100644
--- a/align_data/sources/articles/parsers.py
+++ b/align_data/sources/articles/parsers.py
@@ -93,7 +93,7 @@ def error(error_msg: str):
def func(url: str) -> Dict[str, Any]:
if error_msg:
logger.error(error_msg)
- return {'error': error_msg, 'source_url': url}
+ return {"error": error_msg, "source_url": url}
return func
@@ -160,17 +160,17 @@ def getter(url: str) -> Dict[str, Any]:
"mediangroup.org": element_extractor("div.entry-content"),
"www.alexirpan.com": element_extractor("article"),
"www.incompleteideas.net": element_extractor("body"),
- "ai-alignment.com": MediumParser(name='html', url='ai-alignment.com'),
+ "ai-alignment.com": MediumParser(name="html", url="ai-alignment.com"),
"aisrp.org": element_extractor("article"),
"bounded-regret.ghost.io": element_extractor("div.post-content"),
"carnegieendowment.org": element_extractor(
"div.article-body", remove=[".no-print", ".related-pubs"]
),
- "casparoesterheld.com": element_extractor(
- ".entry-content", remove=["div.sharedaddy"]
- ),
+ "casparoesterheld.com": element_extractor(".entry-content", remove=["div.sharedaddy"]),
"cullenokeefe.com": element_extractor("div.sqs-block-content"),
- "deepmindsafetyresearch.medium.com": MediumParser(name='html', url='deepmindsafetyresearch.medium.com'),
+ "deepmindsafetyresearch.medium.com": MediumParser(
+ name="html", url="deepmindsafetyresearch.medium.com"
+ ),
"docs.google.com": google_doc,
"docs.microsoft.com": element_extractor("div.content"),
"digichina.stanford.edu": element_extractor("div.h_editor-content"),
@@ -185,7 +185,7 @@ def getter(url: str) -> Dict[str, Any]:
"link.springer.com": element_extractor("article.c-article-body"),
"longtermrisk.org": element_extractor("div.entry-content"),
"lukemuehlhauser.com": element_extractor("div.entry-content"),
- "medium.com": MediumParser(name='html', url='medium.com'),
+ "medium.com": MediumParser(name="html", url="medium.com"),
"openai.com": element_extractor("#content"),
"ought.org": element_extractor("div.BlogPostBodyContainer"),
"sideways-view.com": element_extractor("article", remove=["header"]),
@@ -200,10 +200,8 @@ def getter(url: str) -> Dict[str, Any]:
),
"theconversation.com": element_extractor("div.content-body"),
"thegradient.pub": element_extractor("div.c-content"),
- "towardsdatascience.com": MediumParser(name='html', url='towardsdatascience.com'),
- "unstableontology.com": element_extractor(
- ".entry-content", remove=["div.sharedaddy"]
- ),
+ "towardsdatascience.com": MediumParser(name="html", url="towardsdatascience.com"),
+ "unstableontology.com": element_extractor(".entry-content", remove=["div.sharedaddy"]),
"waitbutwhy.com": element_extractor("article", remove=[".entry-header"]),
"weightagnostic.github.io": element_extractor(
"dt-article", remove=["#authors_section", "dt-byline"]
@@ -211,9 +209,7 @@ def getter(url: str) -> Dict[str, Any]:
"cnas.org": element_extractor("#mainbar-toc"),
"econlib.org": element_extractor("div.post-content"),
"humanityplus.org": element_extractor("div.content"),
- "gleech.org": element_extractor(
- "article.post-content", remove=["center", "div.accordion"]
- ),
+ "gleech.org": element_extractor("article.post-content", remove=["center", "div.accordion"]),
"ibm.com": element_extractor("div:has(> p)"), # IBM's HTML is really ugly...
"microsoft.com": element_extractor("div.content-container"),
"mdpi.com": element_extractor(
@@ -289,9 +285,7 @@ def getter(url: str) -> Dict[str, Any]:
"jstor.org": doi_getter,
"ri.cmu.edu": get_pdf_from_page("a.pub-link"),
"risksciences.ucla.edu": get_pdf_from_page('a:-soup-contains("Download")'),
- "ssrn.com": get_pdf_from_page(
- '.abstract-buttons a.button-link:-soup-contains("Download")'
- ),
+ "ssrn.com": get_pdf_from_page('.abstract-buttons a.button-link:-soup-contains("Download")'),
"yjolt.org": get_pdf_from_page("span.file a"),
}
@@ -320,7 +314,7 @@ def item_metadata(url: str) -> Dict[str, Any]:
# there is a link to a pdf on it
if parser := HTML_PARSERS.get(domain):
res = parser(url)
- if res and 'error' not in res:
+ if res and "error" not in res:
# Proper contents were found on the page, so use them
return res
@@ -338,7 +332,9 @@ def item_metadata(url: str) -> Dict[str, Any]:
return {"error": f"No domain handler defined for {domain}"}
return {"error": "could not parse url"}
elif content_type & {"application/octet-stream", "application/pdf"}:
- # this looks like it could be a pdf - try to download it as one
+ if domain == "arxiv.org":
+ return fetch_arxiv(url)
+ # just download it as a pdf
return fetch_pdf(url)
elif content_type & {"application/epub+zip", "application/epub"}:
# it looks like an ebook. Assume it's fine.
diff --git a/align_data/sources/articles/updater.py b/align_data/sources/articles/updater.py
index 93971a86..ca2669e5 100644
--- a/align_data/sources/articles/updater.py
+++ b/align_data/sources/articles/updater.py
@@ -37,8 +37,9 @@ def maybe(item, key):
def items_list(self) -> List[Item]:
df = pd.read_csv(self.csv_path, delimiter=self.delimiter)
self.csv_items = [
- item for item in df.itertuples()
- if self.maybe(item, 'id') or self.maybe(item, 'hash_id')
+ item
+ for item in df.itertuples()
+ if self.maybe(item, "id") or self.maybe(item, "hash_id")
]
by_id = {id: item for item in self.csv_items if (id := self.maybe(item, 'id'))}
by_hash_id = {hash_id: item for item in self.csv_items if (hash_id := self.maybe(item, 'hash_id'))}
@@ -53,39 +54,39 @@ def items_list(self) -> List[Item]:
@property
def _query_items(self) -> Select[Tuple[Article]]:
- ids = [i.id for i in self.csv_items if self.maybe(i, 'id')]
- hash_ids = [i.hash_id for i in self.csv_items if self.maybe(i, 'hash_id')]
+ ids = [i.id for i in self.csv_items if self.maybe(i, "id")]
+ hash_ids = [i.hash_id for i in self.csv_items if self.maybe(i, "hash_id")]
return select(Article).where(or_(Article.id.in_(hash_ids), Article._id.in_(ids)))
def update_text(self, updates: NamedTuple, article: Article):
# If the url is the same as it was before, and there isn't a source url provided, assume that the
# previous text is still valid
- if article.url == self.maybe(updates, 'url') and not self.maybe(updates, 'source_url'):
+ if article.url == self.maybe(updates, "url") and not self.maybe(updates, "source_url"):
return
# If no url found, then don't bother fetching the text - assume it was successfully fetched previously
- url = self.maybe(updates, 'source_url') or self.maybe(updates, 'url')
+ url = self.maybe(updates, "source_url") or self.maybe(updates, "url")
if not url:
return
if article.url != url:
- article.add_meta('source_url', url)
+ article.add_meta("source_url", url)
metadata = item_metadata(url)
# Only change the text if it could be fetched - better to have outdated values than none
- if metadata.get('text'):
- article.text = metadata['text']
- article.status = metadata.get('error')
+ if metadata.get("text"):
+ article.text = metadata["text"]
+ article.status = metadata.get("error")
def process_entry(self, item: Item) -> Article:
updates, article = item
- for key in ['url', 'title', 'source', 'authors', 'comment', 'confidence']:
+ for key in ["url", "title", "source", "authors", "comment", "confidence"]:
value = self.maybe(updates, key)
if value and getattr(article, key, None) != value:
setattr(article, key, value)
- if date := getattr(updates, 'date_published', None):
+ if date := getattr(updates, "date_published", None):
article.date_published = self._get_published_date(date)
self.update_text(updates, article)
diff --git a/align_data/sources/arxiv_papers.py b/align_data/sources/arxiv_papers.py
index 89739e11..9067152a 100644
--- a/align_data/sources/arxiv_papers.py
+++ b/align_data/sources/arxiv_papers.py
@@ -6,6 +6,7 @@
from align_data.sources.articles.pdf import fetch_pdf, parse_vanity
from align_data.sources.articles.html import fetch_element
+from align_data.sources.utils import merge_dicts
logger = logging.getLogger(__name__)
@@ -22,7 +23,7 @@ def get_arxiv_metadata(paper_id: str) -> arxiv.Result | None:
return None
-def get_id(url: str) -> Optional[str]:
+def get_id(url: str) -> str | None:
if res := re.search(r"https?://arxiv.org/(?:abs|pdf)/(.*?)(?:v\d+)?(?:/|\.pdf)?$", url):
return res.group(1)
return None
@@ -30,25 +31,30 @@ def get_id(url: str) -> Optional[str]:
def canonical_url(url: str) -> str:
if paper_id := get_id(url):
- return f'https://arxiv.org/abs/{paper_id}'
+ return f"https://arxiv.org/abs/{paper_id}"
return url
def get_contents(paper_id: str) -> Dict[str, Any]:
arxiv_vanity = parse_vanity(f"https://www.arxiv-vanity.com/papers/{paper_id}")
- if 'error' not in arxiv_vanity:
+ if "error" not in arxiv_vanity:
return arxiv_vanity
ar5iv = parse_vanity(f"https://ar5iv.org/abs/{paper_id}")
- if 'error' not in ar5iv:
+ if "error" not in ar5iv:
return ar5iv
return fetch_pdf(f"https://arxiv.org/pdf/{paper_id}.pdf")
-def get_version(id: str) -> Optional[str]:
- if res := re.search(r'.*v(\d+)$', id):
+def get_version(id: str) -> str | None:
+ if res := re.search(r".*v(\d+)$", id):
return res.group(1)
+
+
+def is_withdrawn(url: str):
+ if elem := fetch_element(canonical_url(url), ".extra-services .full-text ul"):
+ return elem.text.strip().lower() == "withdrawn"
return None
@@ -62,7 +68,7 @@ def add_metadata(data: Dict[str, Any], paper_id: str) -> Dict[str, Any]:
metadata = get_arxiv_metadata(paper_id)
if not metadata:
return {}
- return dict({
+ return merge_dicts({
"authors": metadata.authors,
"title": metadata.title,
"date_published": metadata.published,
@@ -74,28 +80,28 @@ def add_metadata(data: Dict[str, Any], paper_id: str) -> Dict[str, Any]:
"primary_category": metadata.primary_category,
"categories": metadata.categories,
"version": get_version(metadata.get_short_id()),
- }, **data)
+ }, data)
def fetch_arxiv(url: str) -> Dict[str, Any]:
paper_id = get_id(url)
if not paper_id:
- return {'error': 'Could not extract arxiv id'}
-
- if is_withdrawn(url): paper = {'status': 'Withdrawn'}
- else: paper = get_contents(paper_id)
+ return {"error": "Could not extract arxiv id"}
- data = add_metadata({
- "url": canonical_url(url),
- "source_type": paper.get('data_source'),
- }, paper_id)
-
- authors = data.get('authors') or paper.get("authors")
- if not authors:
- data['authors'] = []
- elif not isinstance(authors, list):
- data['authors'] = [str(authors).strip()]
+ if is_withdrawn(url):
+ paper = {"status": "Withdrawn"}
else:
- data['authors'] = [str(author).strip() for author in authors]
-
- return dict(data, **paper)
+ paper = get_contents(paper_id)
+
+ data = add_metadata(
+ {
+ "url": canonical_url(url),
+ "source_type": paper.get("data_source"),
+ },
+ paper_id,
+ )
+ authors = data.get("authors") or paper.get("authors") or []
+ data["authors"] = [str(a).strip() for a in authors]
+ data["source"] = "arxiv"
+
+ return merge_dicts(data, paper)
diff --git a/align_data/sources/blogs/__init__.py b/align_data/sources/blogs/__init__.py
index 923c5257..64f27310 100644
--- a/align_data/sources/blogs/__init__.py
+++ b/align_data/sources/blogs/__init__.py
@@ -1,18 +1,22 @@
from align_data.sources.blogs.wp_blog import WordpressBlog
from align_data.sources.blogs.gwern_blog import GwernBlog
from align_data.sources.blogs.blogs import (
+ AXRPDataset,
ColdTakes,
GenerativeInk,
CaradoMoe,
EleutherAI,
OpenAIResearch,
DeepMindTechnicalBlog,
+ TransformerCircuits,
)
from align_data.sources.blogs.substack_blog import SubstackBlog
from align_data.sources.articles.parsers import MediumParser
+from align_data.common.alignment_dataset import MultiDataset
-BLOG_REGISTRY = [
+BLOG_DATASETS = [
+ AXRPDataset(name='axrp', url='https://axrp.net', authors=['AXRP']),
WordpressBlog(name="aiimpacts", url="https://aiimpacts.org"),
WordpressBlog(name="aisafety.camp", url="https://aisafety.camp"),
WordpressBlog(name="miri", url="https://intelligence.org"),
@@ -24,9 +28,7 @@
url="https://deepmindsafetyresearch.medium.com/",
authors=["DeepMind Safety Research"],
),
- GwernBlog(
- name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]
- ),
+ GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]),
ColdTakes(
name="cold_takes",
url="https://www.cold-takes.com/",
@@ -56,4 +58,10 @@
name="deepmind_technical_blog",
url="https://www.deepmind.com/blog-categories/technical-blogs",
),
+ TransformerCircuits(name='transformer-circuits', url='https://transformer-circuits.pub/'),
+]
+
+
+BLOG_REGISTRY = [
+ MultiDataset(name='blogs', datasets=BLOG_DATASETS),
]
diff --git a/align_data/sources/blogs/blogs.py b/align_data/sources/blogs/blogs.py
index 849a7889..08414199 100644
--- a/align_data/sources/blogs/blogs.py
+++ b/align_data/sources/blogs/blogs.py
@@ -1,4 +1,5 @@
import logging
+from urllib.parse import urljoin
import requests
from bs4 import BeautifulSoup
@@ -51,12 +52,7 @@ def _get_published_date(self, contents):
return ""
def extract_authors(self, article):
- return (
- article.select_one("header .post-meta")
- .text.split("·")[1]
- .strip()
- .split(", ")
- )
+ return article.select_one("header .post-meta").text.split("·")[1].strip().split(", ")
class OpenAIResearch(HTMLDataset):
@@ -116,9 +112,7 @@ def items_list(self):
page += 1
# update the tqdm progress bar
- pbar.set_postfix_str(
- f"page {page}", refresh=True
- ) # Set postfix to "page X"
+ pbar.set_postfix_str(f"page {page}", refresh=True) # Set postfix to "page X"
pbar.update() # Here we increment the progress bar by 1
logger.info("Got %s pages", len(articles))
@@ -135,3 +129,64 @@ def extract_authors(self, article):
):
return [author.strip() for author in div.text.split(",")]
return []
+
+
+class TransformerCircuits(HTMLDataset):
+
+ item_selector = "div.toc a"
+ text_selector = 'h3'
+
+ def get_item_key(self, item):
+ article_url = item.get("href").split("?")[0]
+ return urljoin(self.url, article_url)
+
+ @property
+ def items_list(self):
+ return [i for i in super().items_list if self.get_item_key(i).startswith(self.url)]
+
+ def _metadata(self, contents, selector):
+ if meta := contents.select_one('div.d-byline'):
+ return meta.select(selector)
+
+ def _get_title(self, contents):
+ title = contents.find("title")
+ return title and title.text.strip()
+
+ def _get_published_date(self, contents):
+ if date := self._metadata(contents, 'div.published div'):
+ return super()._get_published_date(date[0].text)
+
+ def _get_text(self, contents):
+ article = contents.find("d-article") or contents.find("dt-article")
+ return self._extract_markdown(article)
+
+ def extract_authors(self, contents):
+ if authors := self._metadata(contents, 'span.author'):
+ for a in authors:
+ for sup in a.select('sup'):
+ sup.extract()
+ return [a.text.strip().strip(',*') for a in authors]
+ return []
+
+
+class AXRPDataset(RSSDataset):
+
+ @property
+ def feed_url(self):
+ return f"{self.url}/feed.xml"
+
+ def _extract_item_url(self, item) -> str | None:
+ if path := item.get('link'):
+ return self.url + path
+ return None
+
+ def extract_authors(self, item):
+ if "authors" in item:
+ authors = [name for a in item["authors"] if (name := (a.get("name") or '').strip())]
+ if authors:
+ return authors
+
+ bits = item.get('title', '').split(' with ')
+ if len(bits) > 1 and bits[-1].strip():
+ return self.authors + [bits[-1].strip()]
+ return self.authors
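AXRPDataset falls back to pulling the interviewee out of the episode title when the feed carries no author entries. An invented example of both helpers:

    from align_data.sources.blogs.blogs import AXRPDataset

    dataset = AXRPDataset(name='axrp', url='https://axrp.net', authors=['AXRP'])
    item = {'title': 'Episode 1 - Some topic with Jane Doe', 'link': '/episodes/episode-1.html'}
    print(dataset._extract_item_url(item))  # 'https://axrp.net/episodes/episode-1.html'
    print(dataset.extract_authors(item))    # ['AXRP', 'Jane Doe'] - title split on ' with '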
diff --git a/align_data/sources/blogs/wp_blog.py b/align_data/sources/blogs/wp_blog.py
index bdb45292..b0c9e9f1 100644
--- a/align_data/sources/blogs/wp_blog.py
+++ b/align_data/sources/blogs/wp_blog.py
@@ -41,9 +41,7 @@ def items_list(self):
self.items[item["link"]] = item
# update the tqdm progress bar
- pbar.set_postfix_str(
- f"page {page_number}", refresh=True
- ) # Set postfix to "page X"
+ pbar.set_postfix_str(f"page {page_number}", refresh=True) # Set postfix to "page X"
pbar.update() # Here we increment the progress bar by 1
logger.info(f"Got {len(self.items)} pages")
diff --git a/align_data/sources/distill/distill.py b/align_data/sources/distill/distill.py
index 2b1690b6..8709b154 100644
--- a/align_data/sources/distill/distill.py
+++ b/align_data/sources/distill/distill.py
@@ -1,15 +1,14 @@
-from dataclasses import dataclass
-from markdownify import markdownify
from align_data.common.html_dataset import RSSDataset
-@dataclass
class Distill(RSSDataset):
source_type = "html"
done_key = "url"
def extract_authors(self, item):
- return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] or ["Distill"]
+ return [a.text for a in item["soup"].select(".authors-affiliations p.author a")] or [
+ "Distill"
+ ]
def _get_text(self, item):
article = item["soup"].find("d-article") or item["soup"].find("dt-article")
diff --git a/align_data/sources/ebooks/__init__.py b/align_data/sources/ebooks/__init__.py
index 0055f5e0..7fdcd729 100644
--- a/align_data/sources/ebooks/__init__.py
+++ b/align_data/sources/ebooks/__init__.py
@@ -1,7 +1,5 @@
from .agentmodels import AgentModels
EBOOK_REGISTRY = [
- AgentModels(
- name="agentmodels", repo="https://github.com/agentmodels/agentmodels.org.git"
- ),
+ AgentModels(name="agentmodels", repo="https://github.com/agentmodels/agentmodels.org.git"),
]
diff --git a/align_data/sources/ebooks/agentmodels.py b/align_data/sources/ebooks/agentmodels.py
index 748ceebf..3756524a 100644
--- a/align_data/sources/ebooks/agentmodels.py
+++ b/align_data/sources/ebooks/agentmodels.py
@@ -29,9 +29,7 @@ def setup(self):
self.files_path = self.base_dir / "chapters"
def _get_published_date(self, filename):
- last_commit = next(
- self.repository.iter_commits(paths=f"chapters/{filename.name}")
- )
+ last_commit = next(self.repository.iter_commits(paths=f"chapters/{filename.name}"))
return last_commit.committed_datetime.astimezone(timezone.utc)
def process_entry(self, filename):
diff --git a/align_data/sources/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py
index bd906348..12dc9b80 100644
--- a/align_data/sources/greaterwrong/greaterwrong.py
+++ b/align_data/sources/greaterwrong/greaterwrong.py
@@ -76,10 +76,7 @@ def setup(self):
self.ai_tags = get_allowed_tags(self.base_url, self.name)
def tags_ok(self, post):
- return (
- not self.ai_tags
- or {t["name"] for t in post["tags"] if t.get("name")} & self.ai_tags
- )
+ return not self.ai_tags or {t["name"] for t in post["tags"] if t.get("name")} & self.ai_tags
def get_item_key(self, item):
return item["pageUrl"]
@@ -176,7 +173,7 @@ def process_entry(self, item):
authors = item["coauthors"]
if item["user"]:
authors = [item["user"]] + authors
- authors = [a["displayName"] for a in authors] or ['anonymous']
+ authors = [a["displayName"] for a in authors] or ["anonymous"]
return self.make_data_entry(
{
"title": item["title"],
diff --git a/align_data/sources/stampy/stampy.py b/align_data/sources/stampy/stampy.py
index c6f784c5..7aad532c 100644
--- a/align_data/sources/stampy/stampy.py
+++ b/align_data/sources/stampy/stampy.py
@@ -43,13 +43,9 @@ def _get_published_date(self, entry):
def process_entry(self, entry):
def clean_text(text):
text = html.unescape(text)
- return re.sub(
- r"\(/\?state=(\w+)\)", r"(http://aisafety.info?state=\1)", text
- )
+ return re.sub(r"\(/\?state=(\w+)\)", r"(http://aisafety.info?state=\1)", text)
- question = clean_text(
- entry["Question"]
- ) # raise an error if the entry has no question
+ question = clean_text(entry["Question"]) # raise an error if the entry has no question
answer = clean_text(entry["Rich Text"])
url = "https://aisafety.info?state=" + entry["UI ID"]
diff --git a/align_data/sources/utils.py b/align_data/sources/utils.py
new file mode 100644
index 00000000..3ed9abfe
--- /dev/null
+++ b/align_data/sources/utils.py
@@ -0,0 +1,5 @@
+def merge_dicts(*dicts):
+ final = {}
+ for d in dicts:
+ final = dict(final, **{k: v for k, v in d.items() if v is not None})
+ return final
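merge_dicts replaces the dict(a, **b) calls in fetch_arxiv and backs the new Airtable/AGISF sources: later dicts win, but None values never overwrite earlier ones. A quick illustration:

    from align_data.sources.utils import merge_dicts

    print(merge_dicts(
        {'title': 'Fallback title', 'authors': []},
        {'title': 'Fetched title', 'authors': None, 'source_type': 'pdf'},
    ))
    # {'title': 'Fetched title', 'authors': [], 'source_type': 'pdf'}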
diff --git a/align_data/sources/youtube/__init__.py b/align_data/sources/youtube/__init__.py
index 06c8defe..ca0d9b33 100644
--- a/align_data/sources/youtube/__init__.py
+++ b/align_data/sources/youtube/__init__.py
@@ -1,9 +1,10 @@
+from align_data.common.alignment_dataset import MultiDataset
from align_data.sources.youtube.youtube import (
YouTubeChannelDataset,
YouTubePlaylistDataset,
)
-YOUTUBE_REGISTRY = [
+YOUTUBE_DATASETS = [
YouTubeChannelDataset(
name="rob_miles_ai_safety",
channel_id="UCLB7AzTwc6VFZrBsO2ucBMg",
@@ -40,3 +41,8 @@
],
),
]
+
+
+YOUTUBE_REGISTRY = [
+ MultiDataset(name='youtube', datasets=YOUTUBE_DATASETS),
+]
diff --git a/align_data/sources/youtube/youtube.py b/align_data/sources/youtube/youtube.py
index 759c9840..9bdad819 100644
--- a/align_data/sources/youtube/youtube.py
+++ b/align_data/sources/youtube/youtube.py
@@ -1,4 +1,3 @@
-import collections
import logging
from dataclasses import dataclass, field
from typing import List, Optional, Iterable
diff --git a/main.py b/main.py
index 3c934374..2edeec9d 100644
--- a/main.py
+++ b/main.py
@@ -7,8 +7,13 @@
from align_data import ALL_DATASETS, get_dataset
from align_data.analysis.count_tokens import count_token
-from align_data.sources.articles.articles import update_new_items, check_new_articles, update_articles
-from align_data.pinecone.update_pinecone import PineconeUpdater
+from align_data.sources.articles.articles import (
+ update_new_items,
+ check_new_articles,
+ update_articles,
+)
+from align_data.embeddings.pinecone.update_pinecone import PineconeUpdater
+from align_data.embeddings.finetuning.training import finetune_embeddings
from align_data.settings import (
METADATA_OUTPUT_SPREADSHEET,
METADATA_SOURCE_SHEET,
@@ -75,12 +80,10 @@ def count_tokens(self, merged_dataset_path: str) -> None:
This function counts the number of tokens, words, and characters in the dataset
:return: None
"""
- assert os.path.exists(
- merged_dataset_path
- ), "The path to the merged dataset does not exist"
+ assert os.path.exists(merged_dataset_path), "The path to the merged dataset does not exist"
count_token(merged_dataset_path)
- def update(self, csv_path, delimiter=','):
+ def update(self, csv_path, delimiter=","):
"""Update all articles in the provided csv files, overwriting the provided values and fetching new text if a different url provided.
:param str csv_path: The path to the csv file to be processed
@@ -114,7 +117,7 @@ def fetch_new_articles(
"""
return check_new_articles(source_spreadsheet, source_sheet)
- def pinecone_update(self, *names) -> None:
+ def pinecone_update(self, *names, force_update=False) -> None:
"""
This function updates the Pinecone vector DB.
@@ -124,14 +127,20 @@ def pinecone_update(self, *names) -> None:
names = ALL_DATASETS
missing = {name for name in names if name not in ALL_DATASETS}
assert not missing, f"{missing} are not valid dataset names"
- PineconeUpdater().update(names)
+ PineconeUpdater().update(names, force_update)
- def pinecone_update_all(self, *skip) -> None:
+ def pinecone_update_all(self, *skip, force_update=False) -> None:
"""
This function updates the Pinecone vector DB.
"""
names = [name for name in ALL_DATASETS if name not in skip]
- PineconeUpdater().update(names)
+ PineconeUpdater().update(names, force_update)
+
+ def train_finetuning_layer(self) -> None:
+ """
+ This function trains a finetuning layer on top of the OpenAI embeddings.
+ """
+ finetune_embeddings()
if __name__ == "__main__":
diff --git a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py
index 7a8485fe..e5b9a303 100644
--- a/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py
+++ b/migrations/versions/59ac3cb671e3_added_pinecone_update_required_to_.py
@@ -18,9 +18,7 @@
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column(
- "articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False)
- )
+ op.add_column("articles", sa.Column("pinecone_update_required", sa.Boolean(), nullable=False))
# ### end Alembic commands ###
diff --git a/migrations/versions/f5a2bcfa6b2c_add_status_column.py b/migrations/versions/f5a2bcfa6b2c_add_status_column.py
index 76c89ee0..d93a8c86 100644
--- a/migrations/versions/f5a2bcfa6b2c_add_status_column.py
+++ b/migrations/versions/f5a2bcfa6b2c_add_status_column.py
@@ -10,17 +10,17 @@
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
-revision = 'f5a2bcfa6b2c'
-down_revision = '59ac3cb671e3'
+revision = "f5a2bcfa6b2c"
+down_revision = "59ac3cb671e3"
branch_labels = None
depends_on = None
def upgrade() -> None:
- op.add_column('articles', sa.Column('status', sa.String(length=256), nullable=True))
- op.add_column('articles', sa.Column('comments', mysql.LONGTEXT(), nullable=True))
+ op.add_column("articles", sa.Column("status", sa.String(length=256), nullable=True))
+ op.add_column("articles", sa.Column("comments", mysql.LONGTEXT(), nullable=True))
def downgrade() -> None:
- op.drop_column('articles', 'comments')
- op.drop_column('articles', 'status')
+ op.drop_column("articles", "comments")
+ op.drop_column("articles", "status")
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..aa4949aa
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,2 @@
+[tool.black]
+line-length = 100
diff --git a/requirements.txt b/requirements.txt
index b7985f3f..008d629d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,6 +28,7 @@ google-auth-httplib2
google-api-python-client
gspread
youtube-transcript-api
+airtable
alembic
mysqlclient
@@ -36,3 +37,5 @@ openai
langchain
nltk
pinecone-client
+
+torch
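
The torch dependency is presumably needed by the new train_finetuning_layer command; the actual finetune_embeddings implementation is not shown in this diff. A minimal sketch of what a finetuning layer over frozen OpenAI embeddings might look like is given below; the class name, dimensions, and triplet objective are all assumptions, not the project's code.

    # Hedged sketch only - not the align_data implementation.
    import torch
    import torch.nn as nn


    class FinetuningLayer(nn.Module):
        """A small projection trained on top of fixed OpenAI embeddings."""

        def __init__(self, dim: int = 1536, out_dim: int = 1536):
            super().__init__()
            self.proj = nn.Linear(dim, out_dim)

        def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
            return self.proj(embeddings)


    def train_step(model, optimizer, anchor, positive, negative, margin=0.5):
        # Triplet-style objective: pull related chunks together, push unrelated apart.
        loss = nn.TripletMarginLoss(margin=margin)(model(anchor), model(positive), model(negative))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()
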
diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py
index 29b38c46..f7246525 100644
--- a/tests/align_data/articles/test_datasets.py
+++ b/tests/align_data/articles/test_datasets.py
@@ -48,7 +48,7 @@ def mock_arxiv():
journal_ref="sdf",
primary_category="cat",
)
- metadata.get_short_id.return_value = '2001.11038'
+ metadata.get_short_id.return_value = "2001.11038"
arxiv = Mock()
arxiv.Search.return_value.results.return_value = iter([metadata])
@@ -124,30 +124,24 @@ def test_pdf_articles_process_item(articles):
"text": "pdf contents [bla](asd.com)",
"title": "article no 0",
"url": "http://example.com/item/0",
- 'source_url': 'http://example.com/source_url/0',
+ "source_url": "http://example.com/source_url/0",
}
def test_html_articles_get_text():
def parser(url):
assert url == "http://example.org/bla.bla"
- return {'text': "html contents"}
+ return {"text": "html contents"}
- with patch(
- "align_data.sources.articles.datasets.HTML_PARSERS", {"example.org": parser}
- ):
+ with patch("align_data.sources.articles.datasets.HTML_PARSERS", {"example.org": parser}):
assert (
- HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla"))
- == "html contents"
+ HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) == "html contents"
)
def test_html_articles_get_text_no_parser():
with patch("align_data.sources.articles.datasets.HTML_PARSERS", {}):
- assert (
- HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla"))
- is None
- )
+ assert HTMLArticles._get_text(Mock(source_url="http://example.org/bla.bla")) is None
def test_html_articles_process_entry(articles):
@@ -156,7 +150,9 @@ def test_html_articles_process_entry(articles):
item = list(dataset.items_list)[0]
parsers = {
- "example.com": lambda _: {'text': ' html contents with proper elements ble ble '}
+ "example.com": lambda _: {
+ "text": ' html contents with proper elements ble ble '
+ }
}
with patch("align_data.sources.articles.datasets.HTML_PARSERS", parsers):
assert dataset.process_entry(item).to_dict() == {
@@ -170,7 +166,7 @@ def test_html_articles_process_entry(articles):
"text": "html contents with [proper elements](bla.com) ble ble",
"title": "article no 0",
"url": "http://example.com/item/0",
- 'source_url': 'http://example.com/source_url/0',
+ "source_url": "http://example.com/source_url/0",
}
@@ -201,9 +197,7 @@ def test_ebook_articles_process_entry(articles):
contents = ' html contents with proper elements ble ble '
with patch("align_data.sources.articles.datasets.download"):
- with patch(
- "align_data.sources.articles.datasets.convert_file", return_value=contents
- ):
+ with patch("align_data.sources.articles.datasets.convert_file", return_value=contents):
assert dataset.process_entry(item).to_dict() == {
"authors": ["John Snow", "mr Blobby"],
"date_published": "2023-01-01T12:32:11Z",
@@ -215,7 +209,7 @@ def test_ebook_articles_process_entry(articles):
"text": "html contents with [proper elements](bla.com) ble ble",
"title": "article no 0",
"url": "http://example.com/item/0",
- 'source_url': 'http://example.com/source_url/0',
+ "source_url": "http://example.com/source_url/0",
}
@@ -248,7 +242,7 @@ def test_xml_articles_process_entry(articles):
"text": "bla bla",
"title": "article no 0",
"url": "http://example.com/item/0",
- 'source_url': 'http://example.com/source_url/0',
+ "source_url": "http://example.com/source_url/0",
}
@@ -281,19 +275,15 @@ def test_markdown_articles_process_entry(articles):
"text": "bla bla",
"title": "article no 0",
"url": "http://example.com/item/0",
- 'source_url': 'http://example.com/source_url/0',
+ "source_url": "http://example.com/source_url/0",
}
def test_doc_articles_get_text():
dataset = DocArticles(name="bla", spreadsheet_id="123", sheet_id="456")
with patch("align_data.sources.articles.datasets.fetch_file"):
- with patch(
- "align_data.sources.articles.datasets.convert_file", return_value="bla bla"
- ):
- assert (
- dataset._get_text(Mock(source_url="bla.com/bla/123/bla")) == "bla bla"
- )
+ with patch("align_data.sources.articles.datasets.convert_file", return_value="bla bla"):
+ assert dataset._get_text(Mock(source_url="bla.com/bla/123/bla")) == "bla bla"
def test_doc_articles_process_entry(articles):
@@ -302,9 +292,7 @@ def test_doc_articles_process_entry(articles):
item = list(dataset.items_list)[0]
with patch("align_data.sources.articles.datasets.fetch_file"):
- with patch(
- "align_data.sources.articles.datasets.convert_file", return_value="bla bla"
- ):
+ with patch("align_data.sources.articles.datasets.convert_file", return_value="bla bla"):
assert dataset.process_entry(item).to_dict() == {
"authors": ["John Snow", "mr Blobby"],
"date_published": "2023-01-01T12:32:11Z",
@@ -316,11 +304,11 @@ def test_doc_articles_process_entry(articles):
"text": "bla bla",
"title": "article no 0",
"url": "http://example.com/item/0",
- 'source_url': 'http://example.com/source_url/0',
+ "source_url": "http://example.com/source_url/0",
}
-@patch('requests.get', return_value=Mock(content=''))
+@patch("requests.get", return_value=Mock(content=""))
def test_arxiv_process_entry(_, mock_arxiv):
dataset = ArxivPapers(name="asd", spreadsheet_id="ad", sheet_id="da")
item = Mock(
@@ -335,9 +323,7 @@ def test_arxiv_process_entry(_, mock_arxiv):
"authors": ["mr blobby"],
"source_type": "html",
}
- with patch(
- "align_data.sources.arxiv_papers.parse_vanity", return_value=contents
- ):
+ with patch("align_data.sources.arxiv_papers.parse_vanity", return_value=contents):
assert dataset.process_entry(item).to_dict() == {
"comment": "no comment",
"authors": ["mr blobby"],
@@ -377,12 +363,9 @@ def test_arxiv_process_entry_retracted(mock_arxiv):
"""
- with patch('requests.get', return_value=Mock(content=response)):
+ with patch("requests.get", return_value=Mock(content=response)):
article = dataset.process_entry(item)
- print(article.to_dict())
- print(article.status)
- print(article.__dir__())
- assert article.status == 'Withdrawn'
+ assert article.status == "Withdrawn"
assert article.to_dict() == {
"comment": "no comment",
"authors": [],
@@ -410,7 +393,7 @@ def test_special_docs_process_entry():
authors="mr. blobby",
date_published="2023-10-02T01:23:45",
source_type=None,
- source_url="https://ble.ble.com"
+ source_url="https://ble.ble.com",
)
contents = {
"text": "this is the text",
@@ -421,20 +404,20 @@ def test_special_docs_process_entry():
with patch("align_data.sources.articles.datasets.item_metadata", return_value=contents):
assert dataset.process_entry(item).to_dict() == {
- 'authors': ['mr. blobby'],
- 'date_published': '2023-10-02T01:23:45Z',
- 'id': None,
- 'source': 'html',
- 'source_url': "https://ble.ble.com",
- 'source_type': 'html',
- 'summaries': [],
- 'text': 'this is the text',
- 'title': 'this is the title',
- 'url': 'https://bla.bla.bla',
+ "authors": ["mr. blobby"],
+ "date_published": "2023-10-02T01:23:45Z",
+ "id": None,
+ "source": "html",
+ "source_url": "https://ble.ble.com",
+ "source_type": "html",
+ "summaries": [],
+ "text": "this is the text",
+ "title": "this is the title",
+ "url": "https://bla.bla.bla",
}
-@patch('requests.get', return_value=Mock(content=''))
+@patch("requests.get", return_value=Mock(content=""))
def test_special_docs_process_entry_arxiv(_, mock_arxiv):
dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da")
item = Mock(
@@ -450,9 +433,7 @@ def test_special_docs_process_entry_arxiv(_, mock_arxiv):
"source_type": "pdf",
}
- with patch(
- "align_data.sources.arxiv_papers.parse_vanity", return_value=contents
- ):
+ with patch("align_data.sources.arxiv_papers.parse_vanity", return_value=contents):
assert dataset.process_entry(item).to_dict() == {
"comment": "no comment",
"authors": ["mr blobby"],
@@ -472,16 +453,22 @@ def test_special_docs_process_entry_arxiv(_, mock_arxiv):
}
-@pytest.mark.parametrize('url, expected', (
- ("http://bla.bla", "http://bla.bla"),
- ("http://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"),
- ("https://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"),
- ("https://arxiv.org/abs/2001.11038/", "https://arxiv.org/abs/2001.11038"),
- ("https://arxiv.org/pdf/2001.11038", "https://arxiv.org/abs/2001.11038"),
- ("https://arxiv.org/pdf/2001.11038.pdf", "https://arxiv.org/abs/2001.11038"),
- ("https://arxiv.org/pdf/2001.11038v3.pdf", "https://arxiv.org/abs/2001.11038"),
- ("https://arxiv.org/abs/math/2001.11038", "https://arxiv.org/abs/math/2001.11038"),
-))
+@pytest.mark.parametrize(
+ "url, expected",
+ (
+ ("http://bla.bla", "http://bla.bla"),
+ ("http://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"),
+ ("https://arxiv.org/abs/2001.11038", "https://arxiv.org/abs/2001.11038"),
+ ("https://arxiv.org/abs/2001.11038/", "https://arxiv.org/abs/2001.11038"),
+ ("https://arxiv.org/pdf/2001.11038", "https://arxiv.org/abs/2001.11038"),
+ ("https://arxiv.org/pdf/2001.11038.pdf", "https://arxiv.org/abs/2001.11038"),
+ ("https://arxiv.org/pdf/2001.11038v3.pdf", "https://arxiv.org/abs/2001.11038"),
+ (
+ "https://arxiv.org/abs/math/2001.11038",
+ "https://arxiv.org/abs/math/2001.11038",
+ ),
+ ),
+)
def test_special_docs_not_processed_true(url, expected):
dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da")
dataset._outputted_items = [url, expected]
@@ -489,13 +476,15 @@ def test_special_docs_not_processed_true(url, expected):
assert not dataset.not_processed(Mock(url=None, source_url=url))
-@pytest.mark.parametrize('url', (
- "http://bla.bla"
- "http://arxiv.org/abs/2001.11038",
- "https://arxiv.org/abs/2001.11038",
- "https://arxiv.org/abs/2001.11038/",
- "https://arxiv.org/pdf/2001.11038",
-))
+@pytest.mark.parametrize(
+ "url",
+ (
+        "http://bla.bla",
+        "http://arxiv.org/abs/2001.11038",
+ "https://arxiv.org/abs/2001.11038",
+ "https://arxiv.org/abs/2001.11038/",
+ "https://arxiv.org/pdf/2001.11038",
+ ),
+)
def test_special_docs_not_processed_false(url):
dataset = SpecialDocs(name="asd", spreadsheet_id="ad", sheet_id="da")
dataset._outputted_items = []
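
The parametrized cases in test_special_docs_not_processed_true above imply a canonical form for arXiv links: pdf paths, version suffixes, trailing slashes, and plain http all collapse to an https abs URL, while anything else passes through unchanged. Roughly, the implied mapping could be sketched as follows; canonical_arxiv_url is a hypothetical name, not the helper the repository actually uses.

    import re


    def canonical_arxiv_url(url: str) -> str:
        # Sketch of the mapping implied by the test cases; non-arXiv URLs pass through.
        match = re.match(r"https?://arxiv\.org/(?:abs|pdf)/(.*?)(?:v\d+)?(?:\.pdf)?/?$", url)
        if not match:
            return url
        return f"https://arxiv.org/abs/{match.group(1)}"

    # canonical_arxiv_url("https://arxiv.org/pdf/2001.11038v3.pdf")
    # -> "https://arxiv.org/abs/2001.11038"
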
diff --git a/tests/align_data/articles/test_google_cloud.py b/tests/align_data/articles/test_google_cloud.py
index 39cacce3..bc814fe1 100644
--- a/tests/align_data/articles/test_google_cloud.py
+++ b/tests/align_data/articles/test_google_cloud.py
@@ -1,7 +1,12 @@
from unittest.mock import Mock, patch
import pytest
-from align_data.sources.articles.google_cloud import extract_gdrive_contents, get_content_type, google_doc, parse_grobid
+from align_data.sources.articles.google_cloud import (
+ extract_gdrive_contents,
+ get_content_type,
+ google_doc,
+ parse_grobid,
+)
SAMPLE_XML = """
@@ -45,8 +50,12 @@
def test_google_doc():
def fetcher(url, *args, **kwargs):
- assert url == 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html'
- return Mock(content="""
+ assert (
+ url
+ == "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html"
+ )
+ return Mock(
+ content="""
@@ -54,35 +63,45 @@ def fetcher(url, *args, **kwargs):
- """)
+ """
+ )
- with patch('requests.get', fetcher):
- url = 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit'
+ with patch("requests.get", fetcher):
+ url = "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit"
assert google_doc(url) == {
- 'text': 'ble ble [a link](bla.com)',
- 'source_url': url
+ "text": "ble ble [a link](bla.com)",
+ "source_url": url,
}
def test_google_doc_no_body():
def fetcher(url, *args, **kwargs):
- assert url == 'https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html'
+ assert (
+ url
+ == "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/export?format=html"
+ )
return Mock(content=" ")
- with patch('requests.get', fetcher):
- assert google_doc('https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit') == {}
+ with patch("requests.get", fetcher):
+ assert (
+ google_doc(
+ "https://docs.google.com/document/d/1fenKXrbvGeZ83hxYf_6mghsZMChxWXjGsZSqY3LZzms/edit"
+ )
+ == {}
+ )
def test_google_doc_bad_url():
- assert google_doc('https://docs.google.com/bla/bla') == {}
+ assert google_doc("https://docs.google.com/bla/bla") == {}
+
def test_parse_grobid():
assert parse_grobid(SAMPLE_XML) == {
- 'abstract': 'this is the abstract',
- 'authors': ['Cullen Oâ\x80\x99Keefe'],
- 'text': 'This is the contents',
- 'title': 'The title!!',
- 'source_type': 'xml',
+ "abstract": "this is the abstract",
+ "authors": ["Cullen Oâ\x80\x99Keefe"],
+ "text": "This is the contents",
+ "title": "The title!!",
+ "source_type": "xml",
}
@@ -104,74 +123,94 @@ def test_parse_grobid_no_body():
"""
- assert parse_grobid(xml) == {'error': 'No contents in XML file', 'source_type': 'xml'}
-
+ assert parse_grobid(xml) == {
+ "error": "No contents in XML file",
+ "source_type": "xml",
+ }
-@pytest.mark.parametrize('header, expected', (
- (None, set()),
- ('', set()),
- ('text/html', {'text/html'}),
- ('text/html; bla=asdas; fewwe=fe', {'text/html', 'bla=asdas', 'fewwe=fe'}),
-))
+@pytest.mark.parametrize(
+ "header, expected",
+ (
+ (None, set()),
+ ("", set()),
+ ("text/html", {"text/html"}),
+ ("text/html; bla=asdas; fewwe=fe", {"text/html", "bla=asdas", "fewwe=fe"}),
+ ),
+)
def test_get_content_type(header, expected):
- assert get_content_type(Mock(headers={'Content-Type': header})) == expected
-
-
-@pytest.mark.parametrize('headers', (
- {},
- {'Content-Type': None},
- {'Content-Type': ''},
- {'Content-Type': ' '},
- {'Content-Type': ' ; ;; '},
-))
+ assert get_content_type(Mock(headers={"Content-Type": header})) == expected
+
+
+@pytest.mark.parametrize(
+ "headers",
+ (
+ {},
+ {"Content-Type": None},
+ {"Content-Type": ""},
+ {"Content-Type": " "},
+ {"Content-Type": " ; ;; "},
+ ),
+)
def test_extract_gdrive_contents_no_contents(headers):
- url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing'
- with patch('requests.head', return_value=Mock(headers=headers, status_code=200)):
+ url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing"
+ with patch("requests.head", return_value=Mock(headers=headers, status_code=200)):
assert extract_gdrive_contents(url) == {
- 'downloaded_from': 'google drive',
- 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing',
- 'error': 'no content type'
+ "downloaded_from": "google drive",
+ "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing",
+ "error": "no content type",
}
-@pytest.mark.parametrize('header', (
- 'application/octet-stream',
- 'application/pdf',
- 'application/pdf; filename=bla.pdf'
-))
+@pytest.mark.parametrize(
+ "header",
+ (
+ "application/octet-stream",
+ "application/pdf",
+ "application/pdf; filename=bla.pdf",
+ ),
+)
def test_extract_gdrive_contents_pdf(header):
- res = Mock(headers={'Content-Type': header}, status_code=200)
- url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing'
- with patch('requests.head', return_value=res):
- with patch('align_data.sources.articles.google_cloud.fetch_pdf', return_value={'text': 'bla'}):
+ res = Mock(headers={"Content-Type": header}, status_code=200)
+ url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing"
+ with patch("requests.head", return_value=res):
+ with patch(
+ "align_data.sources.articles.google_cloud.fetch_pdf",
+ return_value={"text": "bla"},
+ ):
assert extract_gdrive_contents(url) == {
- 'downloaded_from': 'google drive',
- 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing',
- 'text': 'bla',
+ "downloaded_from": "google drive",
+ "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing",
+ "text": "bla",
}
-@pytest.mark.parametrize('header', (
- 'application/epub',
- 'application/epub+zip',
- 'application/epub; filename=bla.epub',
-))
+@pytest.mark.parametrize(
+ "header",
+ (
+ "application/epub",
+ "application/epub+zip",
+ "application/epub; filename=bla.epub",
+ ),
+)
def test_extract_gdrive_contents_ebook(header):
- res = Mock(headers={'Content-Type': header}, status_code=200)
- url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing'
- with patch('requests.head', return_value=res):
+ res = Mock(headers={"Content-Type": header}, status_code=200)
+ url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing"
+ with patch("requests.head", return_value=res):
assert extract_gdrive_contents(url) == {
- 'downloaded_from': 'google drive',
- 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing',
- 'source_type': 'ebook',
+ "downloaded_from": "google drive",
+ "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing",
+ "source_type": "ebook",
}
def test_extract_gdrive_contents_html():
- res = Mock(headers={'Content-Type': 'text/html'}, status_code=200)
- url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing'
- with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)):
+ res = Mock(headers={"Content-Type": "text/html"}, status_code=200)
+ url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing"
+ with patch(
+ "requests.head",
+ return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200),
+ ):
html = """
@@ -179,45 +218,48 @@ def test_extract_gdrive_contents_html():
"""
res = Mock(
- headers={'Content-Type': 'text/html'},
+ headers={"Content-Type": "text/html"},
status_code=200,
content=html,
text=html,
)
- with patch('requests.get', return_value=res):
+ with patch("requests.get", return_value=res):
assert extract_gdrive_contents(url) == {
- 'downloaded_from': 'google drive',
- 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing',
- 'text': 'bla bla',
- 'source_type': 'html',
+ "downloaded_from": "google drive",
+ "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing",
+ "text": "bla bla",
+ "source_type": "html",
}
def test_extract_gdrive_contents_xml():
- res = Mock(headers={'Content-Type': 'text/html'}, status_code=200)
- url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing'
- with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)):
+ res = Mock(headers={"Content-Type": "text/html"}, status_code=200)
+ url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing"
+ with patch(
+ "requests.head",
+ return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200),
+ ):
res = Mock(
- headers={'Content-Type': 'text/xml'},
+ headers={"Content-Type": "text/xml"},
status_code=200,
content=SAMPLE_XML,
text=SAMPLE_XML,
)
- with patch('requests.get', return_value=res):
+ with patch("requests.get", return_value=res):
assert extract_gdrive_contents(url) == {
- 'abstract': 'this is the abstract',
- 'authors': ['Cullen Oâ\x80\x99Keefe'],
- 'downloaded_from': 'google drive',
- 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing',
- 'text': 'This is the contents',
- 'title': 'The title!!',
- 'source_type': 'xml',
+ "abstract": "this is the abstract",
+ "authors": ["Cullen Oâ\x80\x99Keefe"],
+ "downloaded_from": "google drive",
+ "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing",
+ "text": "This is the contents",
+ "title": "The title!!",
+ "source_type": "xml",
}
def test_extract_gdrive_contents_xml_with_confirm():
- res = Mock(headers={'Content-Type': 'text/html'}, status_code=200)
- url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing'
+ res = Mock(headers={"Content-Type": "text/html"}, status_code=200)
+ url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing"
def fetcher(link, *args, **kwargs):
# The first request should get the google drive warning page
@@ -228,27 +270,35 @@ def fetcher(link, *args, **kwargs):
+                [sample HTML page for this test - markup lost in extraction; recoverable text: "This is also the title", "Published", "March 16, 2023", "This is where the text goes. With a link to test"]
"""
- return Mock(headers={'Content-Type': 'text/html'}, status_code=200, text=html, content=html)
+ return Mock(
+ headers={"Content-Type": "text/html"},
+ status_code=200,
+ text=html,
+ content=html,
+ )
# The second one returns the actual contents
- return Mock(headers={'Content-Type': 'text/xml'}, status_code=200, content=SAMPLE_XML)
+ return Mock(headers={"Content-Type": "text/xml"}, status_code=200, content=SAMPLE_XML)
- with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)):
- with patch('requests.get', fetcher):
+ with patch(
+ "requests.head",
+ return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200),
+ ):
+ with patch("requests.get", fetcher):
assert extract_gdrive_contents(url) == {
- 'abstract': 'this is the abstract',
- 'authors': ['Cullen Oâ\x80\x99Keefe'],
- 'downloaded_from': 'google drive',
- 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing',
- 'text': 'This is the contents',
- 'title': 'The title!!',
- 'source_type': 'xml',
+ "abstract": "this is the abstract",
+ "authors": ["Cullen Oâ\x80\x99Keefe"],
+ "downloaded_from": "google drive",
+ "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing",
+ "text": "This is the contents",
+ "title": "The title!!",
+ "source_type": "xml",
}
def test_extract_gdrive_contents_warning_with_unknown():
- res = Mock(headers={'Content-Type': 'text/html'}, status_code=200)
- url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing'
+ res = Mock(headers={"Content-Type": "text/html"}, status_code=200)
+ url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing"
def fetcher(link, *args, **kwargs):
# The first request should get the google drive warning page
@@ -259,26 +309,34 @@ def fetcher(link, *args, **kwargs):
"""
- return Mock(headers={'Content-Type': 'text/html'}, status_code=200, text=html, content=html)
+ return Mock(
+ headers={"Content-Type": "text/html"},
+ status_code=200,
+ text=html,
+ content=html,
+ )
# The second one returns the actual contents, with an unhandled content type
- return Mock(headers={'Content-Type': 'text/bla bla'}, status_code=200)
+ return Mock(headers={"Content-Type": "text/bla bla"}, status_code=200)
- with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'}, status_code=200)):
- with patch('requests.get', fetcher):
+ with patch(
+ "requests.head",
+ return_value=Mock(headers={"Content-Type": "text/html"}, status_code=200),
+ ):
+ with patch("requests.get", fetcher):
assert extract_gdrive_contents(url) == {
- 'downloaded_from': 'google drive',
- 'error': "unknown content type: {'text/bla bla'}",
- 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing',
+ "downloaded_from": "google drive",
+ "error": "unknown content type: {'text/bla bla'}",
+ "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing",
}
def test_extract_gdrive_contents_unknown_content_type():
- res = Mock(headers={'Content-Type': 'bla bla'}, status_code=200)
- url = 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing'
- with patch('requests.head', return_value=res):
+ res = Mock(headers={"Content-Type": "bla bla"}, status_code=200)
+ url = "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing"
+ with patch("requests.head", return_value=res):
assert extract_gdrive_contents(url) == {
- 'downloaded_from': 'google drive',
- 'source_url': 'https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing',
- 'error': "unknown content type: {'bla bla'}",
+ "downloaded_from": "google drive",
+ "source_url": "https://drive.google.com/file/d/1OrKZlksba2a8gKa5bAQfP2qF717O_57I/view?usp=sharing",
+ "error": "unknown content type: {'bla bla'}",
}
diff --git a/tests/align_data/articles/test_parsers.py b/tests/align_data/articles/test_parsers.py
index 8bac313f..5a174e3f 100644
--- a/tests/align_data/articles/test_parsers.py
+++ b/tests/align_data/articles/test_parsers.py
@@ -42,6 +42,7 @@
"""
+
def test_medium_blog():
html = """
@@ -60,14 +61,14 @@ def test_medium_blog():
"""
with patch("requests.get", return_value=Mock(content=html)):
- assert MediumParser('html', 'ble')("bla.com") == {
- 'authors': [],
- 'date_published': parse('Oct 7, 2023').replace(tzinfo=pytz.UTC),
- 'source': 'html',
- 'source_type': 'blog',
- 'text': 'bla bla bla [a link](http://ble.com) bla bla',
- 'title': 'This is the title',
- 'url': 'bla.com',
+ assert MediumParser("html", "ble")("bla.com") == {
+ "authors": [],
+ "date_published": parse("Oct 7, 2023").replace(tzinfo=pytz.UTC),
+ "source": "html",
+ "source_type": "blog",
+ "text": "bla bla bla [a link](http://ble.com) bla bla",
+ "title": "This is the title",
+ "url": "bla.com",
}
@@ -83,14 +84,14 @@ def test_medium_blog_no_title():
"""
with patch("requests.get", return_value=Mock(content=html)):
- assert MediumParser(name='html', url='')("bla.com") == {
- 'authors': [],
- 'date_published': None,
- 'source': 'html',
- 'source_type': 'blog',
- 'text': 'bla bla bla [a link](http://ble.com) bla bla',
- 'title': None,
- 'url': 'bla.com',
+ assert MediumParser(name="html", url="")("bla.com") == {
+ "authors": [],
+ "date_published": None,
+ "source": "html",
+ "source_type": "blog",
+ "text": "bla bla bla [a link](http://ble.com) bla bla",
+ "title": None,
+ "url": "bla.com",
}
@@ -105,13 +106,13 @@ def test_medium_blog_no_contents():
"""
- with patch('requests.get', return_value=Mock(content=html)):
- assert MediumParser(name='html', url='')('bla.com') == {
- 'authors': [],
- 'date_published': None,
- 'source': 'html',
- 'source_type': 'blog',
- 'text': None,
- 'title': None,
- 'url': 'bla.com',
+ with patch("requests.get", return_value=Mock(content=html)):
+ assert MediumParser(name="html", url="")("bla.com") == {
+ "authors": [],
+ "date_published": None,
+ "source": "html",
+ "source_type": "blog",
+ "text": None,
+ "title": None,
+ "url": "bla.com",
}
diff --git a/tests/align_data/articles/test_updater.py b/tests/align_data/articles/test_updater.py
index 7d11fbb7..f9e2aea2 100644
--- a/tests/align_data/articles/test_updater.py
+++ b/tests/align_data/articles/test_updater.py
@@ -9,39 +9,43 @@
SAMPLE_UPDATES = [
{},
- {'title': 'no id - should be ignored'},
-
- {'id': '122', 'hash_id': 'deadbeef000'},
+ {"title": "no id - should be ignored"},
+ {"id": "122", "hash_id": "deadbeef000"},
+ {
+ "id": "123",
+ "hash_id": "deadbeef001",
+ "title": "bla bla",
+ "url": "http://bla.com",
+ "source_url": "http://bla.bla.com",
+ "authors": "mr. blobby, johnny",
+ },
+ {
+ "id": "124",
+ "title": "no hash id",
+ "url": "http://bla.com",
+ "source_url": "http://bla.bla.com",
+ "authors": "mr. blobby",
+ },
{
- 'id': '123', 'hash_id': 'deadbeef001',
- 'title': 'bla bla',
- 'url': 'http://bla.com',
- 'source_url': 'http://bla.bla.com',
- 'authors': 'mr. blobby, johnny',
- }, {
- 'id': '124',
- 'title': 'no hash id',
- 'url': 'http://bla.com',
- 'source_url': 'http://bla.bla.com',
- 'authors': 'mr. blobby',
- }, {
- 'hash_id': 'deadbeef002',
- 'title': 'no id',
- 'url': 'http://bla.com',
- 'source_url': 'http://bla.bla.com',
- 'authors': 'mr. blobby',
- }, {
- 'id': '125',
- 'title': 'no hash id, url or title',
- 'authors': 'mr. blobby',
- }
+ "hash_id": "deadbeef002",
+ "title": "no id",
+ "url": "http://bla.com",
+ "source_url": "http://bla.bla.com",
+ "authors": "mr. blobby",
+ },
+ {
+ "id": "125",
+ "title": "no hash id, url or title",
+ "authors": "mr. blobby",
+ },
]
+
@pytest.fixture
def csv_file(tmp_path):
- filename = tmp_path / 'data.csv'
- with open(filename, 'w', newline='') as csvfile:
- fieldnames = ['id', 'hash_id', 'title', 'url', 'source_url', 'authors']
+ filename = tmp_path / "data.csv"
+ with open(filename, "w", newline="") as csvfile:
+ fieldnames = ["id", "hash_id", "title", "url", "source_url", "authors"]
writer = DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
@@ -51,152 +55,195 @@ def csv_file(tmp_path):
def test_items_list(csv_file):
- dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',')
+ dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",")
def mock_entries():
return [
Mock(
- _id=dataset.maybe(v, 'id'),
- id=dataset.maybe(v, 'hash_id'),
- title=dataset.maybe(v, 'title'),
- url=dataset.maybe(v, 'url'),
- authors=dataset.maybe(v, 'authors')
+ _id=dataset.maybe(v, "id"),
+ id=dataset.maybe(v, "hash_id"),
+ title=dataset.maybe(v, "title"),
+ url=dataset.maybe(v, "url"),
+ authors=dataset.maybe(v, "authors"),
)
for v in dataset.csv_items
]
- with patch.object(dataset, 'read_entries', mock_entries):
+ with patch.object(dataset, "read_entries", mock_entries):
items = dataset.items_list
- assert len(items) == 5, "items_list should only contain items with valid ids - something is wrong"
+ assert (
+ len(items) == 5
+ ), "items_list should only contain items with valid ids - something is wrong"
for item in items:
- assert dataset.maybe(item.updates, 'id') == item.article._id
- assert dataset.maybe(item.updates, 'hash_id') == item.article.id
- assert dataset.maybe(item.updates, 'title') == item.article.title
- assert dataset.maybe(item.updates, 'url') == item.article.url
- assert dataset.maybe(item.updates, 'authors') == item.article.authors
-
-
-@pytest.mark.parametrize('updates', (
- Mock(url='http://some.other.url'),
- Mock(source_url='http://some.other.url'),
- Mock(url='http://some.other.url', source_url='http://another.url'),
-))
+ assert dataset.maybe(item.updates, "id") == item.article._id
+ assert dataset.maybe(item.updates, "hash_id") == item.article.id
+ assert dataset.maybe(item.updates, "title") == item.article.title
+ assert dataset.maybe(item.updates, "url") == item.article.url
+ assert dataset.maybe(item.updates, "authors") == item.article.authors
+
+
+@pytest.mark.parametrize(
+ "updates",
+ (
+ Mock(url="http://some.other.url"),
+ Mock(source_url="http://some.other.url"),
+ Mock(url="http://some.other.url", source_url="http://another.url"),
+ ),
+)
def test_update_text(csv_file, updates):
- dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',')
+ dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",")
- article = Mock(text='this should be changed', status='as should this', url='http:/bla.bla.com')
+    article = Mock(text="this should be changed", status="as should this", url="http://bla.bla.com")
- with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}):
+ with patch(
+ "align_data.sources.articles.updater.item_metadata",
+ return_value={"text": "bla bla bla"},
+ ):
dataset.update_text(updates, article)
- assert article.text == 'bla bla bla'
+ assert article.text == "bla bla bla"
    assert article.status is None
-@pytest.mark.parametrize('updates', (
- Mock(url='http://some.other.url'),
- Mock(source_url='http://some.other.url'),
- Mock(url='http://some.other.url', source_url='http://another.url'),
-))
+@pytest.mark.parametrize(
+ "updates",
+ (
+ Mock(url="http://some.other.url"),
+ Mock(source_url="http://some.other.url"),
+ Mock(url="http://some.other.url", source_url="http://another.url"),
+ ),
+)
def test_update_text_error(csv_file, updates):
- dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',')
+ dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",")
- article = Mock(text='this should not be changed', status='but this should be', url='http:/bla.bla.com')
+ article = Mock(
+ text="this should not be changed",
+ status="but this should be",
+        url="http://bla.bla.com",
+ )
- with patch('align_data.sources.articles.updater.item_metadata', return_value={'error': 'oh noes!'}):
+ with patch(
+ "align_data.sources.articles.updater.item_metadata",
+ return_value={"error": "oh noes!"},
+ ):
dataset.update_text(updates, article)
- assert article.text == 'this should not be changed'
- assert article.status == 'oh noes!'
-
-
-@pytest.mark.parametrize('updates', (
- Mock(url='http://bla.bla.com', source_url=None, comment='Same url as article, no source_url'),
- Mock(url='http://bla.bla.com', source_url='', comment='Same url as article, empty source_url'),
- Mock(url=None, source_url=None, comment='no urls provided'),
-))
+ assert article.text == "this should not be changed"
+ assert article.status == "oh noes!"
+
+
+@pytest.mark.parametrize(
+ "updates",
+ (
+ Mock(
+ url="http://bla.bla.com",
+ source_url=None,
+ comment="Same url as article, no source_url",
+ ),
+ Mock(
+ url="http://bla.bla.com",
+ source_url="",
+ comment="Same url as article, empty source_url",
+ ),
+ Mock(url=None, source_url=None, comment="no urls provided"),
+ ),
+)
def test_update_text_no_update(csv_file, updates):
- dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',')
+ dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",")
- article = Mock(text='this should not be changed', status='as should not this', url='http://bla.bla.com')
+ article = Mock(
+ text="this should not be changed",
+ status="as should not this",
+ url="http://bla.bla.com",
+ )
- with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}):
+ with patch(
+ "align_data.sources.articles.updater.item_metadata",
+ return_value={"text": "bla bla bla"},
+ ):
dataset.update_text(updates, article)
- assert article.text == 'this should not be changed'
- assert article.status == 'as should not this'
+ assert article.text == "this should not be changed"
+ assert article.status == "as should not this"
def test_process_entry(csv_file):
- dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',')
+ dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",")
article = Article(
- _id=123, id='deadbeef0123',
- title='this should be changed',
- url='this should be changed',
- text='this should be changed',
- authors='this should be changed',
- date_published='this should be changed',
+ _id=123,
+ id="deadbeef0123",
+ title="this should be changed",
+ url="this should be changed",
+ text="this should be changed",
+ authors="this should be changed",
+ date_published="this should be changed",
)
updates = Mock(
- id='123',
- hash_id='deadbeef001',
- title='bla bla',
- url='http://bla.com',
- source_url='http://bla.bla.com',
- source='tests',
- authors='mr. blobby, johnny',
- date_published='2000-12-23T10:32:43Z',
+ id="123",
+ hash_id="deadbeef001",
+ title="bla bla",
+ url="http://bla.com",
+ source_url="http://bla.bla.com",
+ source="tests",
+ authors="mr. blobby, johnny",
+ date_published="2000-12-23T10:32:43Z",
)
- with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}):
+ with patch(
+ "align_data.sources.articles.updater.item_metadata",
+ return_value={"text": "bla bla bla"},
+ ):
assert dataset.process_entry(Item(updates, article)).to_dict() == {
- 'authors': ['mr. blobby', 'johnny'],
- 'date_published': '2000-12-23T10:32:43Z',
- 'id': 'd8d8cad8d28739a0862654a0e6e8ce6e',
- 'source': 'tests',
- 'source_type': None,
- 'summaries': [],
- 'text': 'bla bla bla',
- 'title': 'bla bla',
- 'url': 'http://bla.com',
- 'source_url': 'http://bla.bla.com',
+ "authors": ["mr. blobby", "johnny"],
+ "date_published": "2000-12-23T10:32:43Z",
+ "id": "d8d8cad8d28739a0862654a0e6e8ce6e",
+ "source": "tests",
+ "source_type": None,
+ "summaries": [],
+ "text": "bla bla bla",
+ "title": "bla bla",
+ "url": "http://bla.com",
+ "source_url": "http://bla.bla.com",
}
def test_process_entry_empty(csv_file):
- dataset = ReplacerDataset(name='bla', csv_path=csv_file, delimiter=',')
+ dataset = ReplacerDataset(name="bla", csv_path=csv_file, delimiter=",")
article = Article(
- _id=123, id='deadbeef0123',
- title='this should not be changed',
- url='this should not be changed',
- source='this should not be changed',
- authors='this should not be changed',
-
- text='this should be changed',
- date_published='this should be changed',
+ _id=123,
+ id="deadbeef0123",
+ title="this should not be changed",
+ url="this should not be changed",
+ source="this should not be changed",
+ authors="this should not be changed",
+ text="this should be changed",
+ date_published="this should be changed",
)
updates = Mock(
- id='123',
- hash_id='deadbeef001',
+ id="123",
+ hash_id="deadbeef001",
title=None,
- url='',
- source_url='http://bla.bla.com',
- source=' ',
- authors=' \n \n \t \t ',
- date_published='2000-12-23T10:32:43Z',
+ url="",
+ source_url="http://bla.bla.com",
+ source=" ",
+ authors=" \n \n \t \t ",
+ date_published="2000-12-23T10:32:43Z",
)
- with patch('align_data.sources.articles.updater.item_metadata', return_value={'text': 'bla bla bla'}):
+ with patch(
+ "align_data.sources.articles.updater.item_metadata",
+ return_value={"text": "bla bla bla"},
+ ):
assert dataset.process_entry(Item(updates, article)).to_dict() == {
- 'authors': ['this should not be changed'],
- 'date_published': '2000-12-23T10:32:43Z',
- 'id': '606e9224254f508d297bcb17bcc6d104',
- 'source': 'this should not be changed',
- 'source_type': None,
- 'summaries': [],
- 'text': 'bla bla bla',
- 'title': 'this should not be changed',
- 'url': 'this should not be changed',
- 'source_url': 'http://bla.bla.com',
+ "authors": ["this should not be changed"],
+ "date_published": "2000-12-23T10:32:43Z",
+ "id": "606e9224254f508d297bcb17bcc6d104",
+ "source": "this should not be changed",
+ "source_type": None,
+ "summaries": [],
+ "text": "bla bla bla",
+ "title": "this should not be changed",
+ "url": "this should not be changed",
+ "source_url": "http://bla.bla.com",
}
diff --git a/tests/align_data/common/test_alignment_dataset.py b/tests/align_data/common/test_alignment_dataset.py
index e45c62eb..d18aaf78 100644
--- a/tests/align_data/common/test_alignment_dataset.py
+++ b/tests/align_data/common/test_alignment_dataset.py
@@ -75,41 +75,68 @@ def test_data_entry_id_from_urls_and_title():
)
-@pytest.mark.parametrize('item, error', (
- (
- {"key1": 12, "key2": 312, "title": "wikipedia goes to war on porcupines", "url": "asd"},
- 'missing fields: date_published, source, text'
- ),
- (
- {"key1": 12, "key2": 312, "url": "www.wikipedia.org", "text": "asdasd", "title": "asdasd"},
- 'missing fields: date_published, source'
- ),
- (
- {
- "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla",
- "source": "dwe", "date_published": "dwe"
- },
- 'missing fields: text'
- ),
- (
- {
- "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla",
- "text": "asdasd", "date_published": "dwe"
- },
- 'missing fields: source'
- ),
+@pytest.mark.parametrize(
+ "item, error",
(
- {
- "key1": 12, "key2": 312, "url": "www.wikipedia.org", "title": "bla", "text": "asdasd", "source": "dwe"
- },
- 'missing fields: date_published'
+ (
+ {
+ "key1": 12,
+ "key2": 312,
+ "title": "wikipedia goes to war on porcupines",
+ "url": "asd",
+ },
+ "missing fields: date_published, source, text",
+ ),
+ (
+ {
+ "key1": 12,
+ "key2": 312,
+ "url": "www.wikipedia.org",
+ "text": "asdasd",
+ "title": "asdasd",
+ },
+ "missing fields: date_published, source",
+ ),
+ (
+ {
+ "key1": 12,
+ "key2": 312,
+ "url": "www.wikipedia.org",
+ "title": "bla",
+ "source": "dwe",
+ "date_published": "dwe",
+ },
+ "missing fields: text",
+ ),
+ (
+ {
+ "key1": 12,
+ "key2": 312,
+ "url": "www.wikipedia.org",
+ "title": "bla",
+ "text": "asdasd",
+ "date_published": "dwe",
+ },
+ "missing fields: source",
+ ),
+ (
+ {
+ "key1": 12,
+ "key2": 312,
+ "url": "www.wikipedia.org",
+ "title": "bla",
+ "text": "asdasd",
+ "source": "dwe",
+ },
+ "missing fields: date_published",
+ ),
),
-))
+)
def test_data_entry_missing(item, error):
dataset = AlignmentDataset(name="blaa")
entry = dataset.make_data_entry(item)
Article.before_write(None, None, entry)
- assert entry.status == 'Missing fields'
+ assert entry.status == "Missing fields"
assert entry.comments == error
@@ -136,7 +163,7 @@ def test_data_entry_verify_id_fails():
"id": "f2b4e02fc1dd8ae43845e4f930f2d84f",
}
)
- expected = 'Entry id f2b4e02fc1dd8ae43845e4f930f2d84f does not match id from id_fields: 770fe57c8c2130eda08dc392b8696f97'
+ expected = "Entry id f2b4e02fc1dd8ae43845e4f930f2d84f does not match id from id_fields: 770fe57c8c2130eda08dc392b8696f97"
with pytest.raises(AssertionError, match=expected):
entry.verify_id()
@@ -172,7 +199,9 @@ def test_data_entry_verify_fields_fails(data, error):
def test_data_entry_id_fields():
dataset = AlignmentDataset(name="blaa")
- entry = dataset.make_data_entry({"url": "https://www.google.ca/once_upon_a_time", 'title': 'bla'})
+ entry = dataset.make_data_entry(
+ {"url": "https://www.google.ca/once_upon_a_time", "title": "bla"}
+ )
Article.before_write(None, None, entry)
assert entry.id
@@ -246,16 +275,11 @@ def test_unprocessed_items_some_done(numbers_dataset):
def test_fetch_entries(numbers_dataset):
- assert [i.meta["value"] for i in numbers_dataset.fetch_entries()] == [
- i**2 for i in range(10)
- ]
+ assert [i.meta["value"] for i in numbers_dataset.fetch_entries()] == [i**2 for i in range(10)]
def test_format_datetime(dataset):
- assert (
- dataset._format_datetime(datetime(2022, 1, 1, 12, 23, 43))
- == "2022-01-01T12:23:43Z"
- )
+ assert dataset._format_datetime(datetime(2022, 1, 1, 12, 23, 43)) == "2022-01-01T12:23:43Z"
def test_format_datetime_ignore_timezone(dataset):
diff --git a/tests/align_data/common/test_html_dataset.py b/tests/align_data/common/test_html_dataset.py
index 25e84b8b..3efbddb0 100644
--- a/tests/align_data/common/test_html_dataset.py
+++ b/tests/align_data/common/test_html_dataset.py
@@ -91,16 +91,12 @@ def test_html_dataset_items_list(html_dataset):
def test_html_datasetfetch_contents(html_dataset):
with patch("requests.get", return_value=Mock(content=SAMPLE_HTML)):
- assert html_dataset.fetch_contents("url") == BeautifulSoup(
- SAMPLE_HTML, "html.parser"
- )
+ assert html_dataset.fetch_contents("url") == BeautifulSoup(SAMPLE_HTML, "html.parser")
def test_html_dataset_get_text(html_dataset):
soup = BeautifulSoup(f"{SAMPLE_CONTENTS}", "html.parser")
- assert (
- html_dataset._get_text(soup) == "bla bla bla [a link](http://ble.com) bla bla"
- )
+ assert html_dataset._get_text(soup) == "bla bla bla [a link](http://ble.com) bla bla"
def test_html_dataset_find_date(html_dataset):
@@ -125,10 +121,7 @@ def test_html_dataset_find_date(html_dataset):
),
)
def test_html_dataset_extract_metadata(html_dataset, text):
- assert (
- html_dataset._extract_markdown(text)
- == "bla bla bla [a link](http://ble.com) bla bla"
- )
+ assert html_dataset._extract_markdown(text) == "bla bla bla [a link](http://ble.com) bla bla"
def test_html_dataset_process_entry(html_dataset):
@@ -176,9 +169,7 @@ def test_html_dataset_process_entry_no_text(html_dataset):
),
)
def test_rss_dataset_extract_authors(item, authors):
- dataset = RSSDataset(
- name="bla", url="http://example.org", authors=["default author"]
- )
+ dataset = RSSDataset(name="bla", url="http://example.org", authors=["default author"])
assert dataset.extract_authors(item) == authors
@@ -202,9 +193,7 @@ def test_rss_dataset_get_title():
),
)
def test_rss_dataset_get_published_date(item, date):
- dataset = RSSDataset(
- name="bla", url="http://example.org", authors=["default author"]
- )
+ dataset = RSSDataset(name="bla", url="http://example.org", authors=["default author"])
assert dataset._get_published_date(item) == date
@@ -263,6 +252,4 @@ def test_rss_dataset_items_list():
}
with patch("feedparser.parse", return_value=contents):
- assert dataset.items_list == [
- f"http://example.org/article-{i}" for i in range(5)
- ]
+ assert dataset.items_list == [f"http://example.org/article-{i}" for i in range(5)]
diff --git a/tests/align_data/embeddings/test_embedding_utils.py b/tests/align_data/embeddings/test_embedding_utils.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/align_data/embeddings/test_pinecone_db_handler.py b/tests/align_data/embeddings/test_pinecone_db_handler.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/align_data/test_alignment_newsletter.py b/tests/align_data/sources/test_alignment_newsletter.py
similarity index 98%
rename from tests/align_data/test_alignment_newsletter.py
rename to tests/align_data/sources/test_alignment_newsletter.py
index ffd5e31b..da8fa043 100644
--- a/tests/align_data/test_alignment_newsletter.py
+++ b/tests/align_data/sources/test_alignment_newsletter.py
@@ -43,9 +43,7 @@ def test_process_entry_no_summary(dataset):
def test_format_datetime(dataset):
- assert dataset._get_published_date(2022) == datetime(
- 2022, 1, 1, tzinfo=timezone.utc
- )
+ assert dataset._get_published_date(2022) == datetime(2022, 1, 1, tzinfo=timezone.utc)
def test_process_entry(dataset):
diff --git a/tests/align_data/test_arbital.py b/tests/align_data/sources/test_arbital.py
similarity index 96%
rename from tests/align_data/test_arbital.py
rename to tests/align_data/sources/test_arbital.py
index 304c5398..af65ed05 100644
--- a/tests/align_data/test_arbital.py
+++ b/tests/align_data/sources/test_arbital.py
@@ -127,9 +127,7 @@ def post(url, *args, **kwargs):
page = json.loads(kwargs.get("data", "{}")).get("pageAlias")
if "json/explore" in url:
- response.json.return_value = {
- "pages": {f"{page}-{i}": i for i in range(10)}
- }
+ response.json.return_value = {"pages": {f"{page}-{i}": i for i in range(10)}}
elif "json/primaryPage" in url:
response.json.return_value = {
"pages": {
@@ -201,9 +199,7 @@ def test_extract_authors_ignore_missing(dataset):
page = {"changeLogs": [{"userId": author} for author in authors]}
with patch.object(dataset, "get_title", lambda author: author):
- assert sorted(dataset.extract_authors(page)) == sorted(
- ["John Snow", "mr. blobby"]
- )
+ assert sorted(dataset.extract_authors(page)) == sorted(["John Snow", "mr. blobby"])
@pytest.mark.parametrize(
diff --git a/tests/align_data/test_arxiv.py b/tests/align_data/sources/test_arxiv.py
similarity index 100%
rename from tests/align_data/test_arxiv.py
rename to tests/align_data/sources/test_arxiv.py
diff --git a/tests/align_data/test_blogs.py b/tests/align_data/sources/test_blogs.py
similarity index 74%
rename from tests/align_data/test_blogs.py
rename to tests/align_data/sources/test_blogs.py
index 9e648f14..1e5a01c7 100644
--- a/tests/align_data/test_blogs.py
+++ b/tests/align_data/sources/test_blogs.py
@@ -5,6 +5,7 @@
from dateutil.parser import parse
from align_data.sources.blogs import (
+ AXRPDataset,
CaradoMoe,
ColdTakes,
GenerativeInk,
@@ -14,6 +15,7 @@
WordpressBlog,
OpenAIResearch,
DeepMindTechnicalBlog,
+ TransformerCircuits,
)
from align_data.sources.blogs.blogs import EleutherAI
@@ -161,10 +163,7 @@ def test_caradomoe_text():
"""
soup = BeautifulSoup(contents, "html.parser")
- assert (
- dataset._get_text({"soup": soup})
- == "bla bla bla [a link](http://ble.com) bla bla"
- )
+ assert dataset._get_text({"soup": soup}) == "bla bla bla [a link](http://ble.com) bla bla"
def test_caradomoe_process_entry():
@@ -229,9 +228,7 @@ def test_caradomoe_process_entry():
def test_gwern_get_text():
- dataset = GwernBlog(
- name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]
- )
+ dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"])
soup = BeautifulSoup(GWERN_CONTENTS, "html.parser")
assert dataset._get_text(soup) == "bla bla bla [a link](http://ble.com) bla bla"
@@ -251,17 +248,13 @@ def test_gwern_get_text():
),
)
def test_gwern_get_published_date(metadata, date):
- dataset = GwernBlog(
- name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]
- )
+ dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"])
assert dataset._get_published_date(metadata) == date
def test_gwern_get_article():
- dataset = GwernBlog(
- name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]
- )
+ dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"])
with patch("requests.get", return_value="article contents"):
assert dataset._get_article("http://bla.com") == "article contents"
@@ -302,13 +295,9 @@ def test_gwern_process_markdown():
...
{SAMPLE_HTML}
"""
- dataset = GwernBlog(
- name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]
- )
+ dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"])
- assert dataset._process_markdown(
- "http://article.url", Mock(text=text)
- ).to_dict() == {
+ assert dataset._process_markdown("http://article.url", Mock(text=text)).to_dict() == {
"authors": ["Gwern Branwen"],
"date_published": "2020-05-28T00:00:00Z",
"id": None,
@@ -329,13 +318,9 @@ def test_gwern_process_entry_markdown():
...
{SAMPLE_HTML}
"""
- dataset = GwernBlog(
- name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]
- )
+ dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"])
- with patch(
- "requests.get", return_value=Mock(text=text, status_code=200, headers={})
- ):
+ with patch("requests.get", return_value=Mock(text=text, status_code=200, headers={})):
assert dataset.process_entry("http://article.url").to_dict() == {
"authors": ["Gwern Branwen"],
"date_published": "2020-05-28T00:00:00Z",
@@ -350,9 +335,7 @@ def test_gwern_process_entry_markdown():
def test_gwern_process_entry_html():
- dataset = GwernBlog(
- name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]
- )
+ dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"])
with patch(
"requests.get",
@@ -376,9 +359,7 @@ def test_gwern_process_entry_html():
def test_gwern_process_entry_error():
- dataset = GwernBlog(
- name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"]
- )
+ dataset = GwernBlog(name="gwern_blog", url="https://www.gwern.net/", authors=["Gwern Branwen"])
with patch("requests.get", return_value=Mock(status_code=404)):
assert dataset.process_entry("http://article.url") is None
@@ -490,9 +471,7 @@ def test_substack_blog_process_entry():
"title": "Eliezer S. Yudkowsky",
"link": "https://www.yudkowsky.net",
},
- "headers": {
- "link": '; rel="https://api.w.org/"'
- },
+ "headers": {"link": '; rel="https://api.w.org/"'},
}
@@ -508,9 +487,7 @@ def test_wordpress_blog_setup():
@patch("feedparser.parse", return_value=WORDPRESS_FEED)
def test_wordpress_blog_items_list(feedparser_parse):
blog = WordpressBlog(name="blog", url="https://www.bla.yudkowsky.net")
- assert blog.items_list == [
- "https://www.yudkowsky.net/other/fiction/prospiracy-theory"
- ]
+ assert blog.items_list == ["https://www.yudkowsky.net/other/fiction/prospiracy-theory"]
def test_wordpress_blog_get_item_key():
@@ -527,9 +504,7 @@ def test_wordpress_blog_get_published_date():
name="blog",
url="https://www.bla.yudkowsky.net",
)
- date_published = blog._get_published_date(
- {"published": "Mon, 26 Jun 2023 13:40:01 +0000"}
- )
+ date_published = blog._get_published_date({"published": "Mon, 26 Jun 2023 13:40:01 +0000"})
assert date_published == parse("2023-06-26T13:40:01Z")
@@ -540,9 +515,7 @@ def test_wordpress_blog_process_entry(feedparser_parse):
url="https://www.bla.yudkowsky.net",
)
blog.items = {i["link"]: i for i in WORDPRESS_FEED["entries"]}
- entry = blog.process_entry(
- "https://www.yudkowsky.net/other/fiction/prospiracy-theory"
- )
+ entry = blog.process_entry("https://www.yudkowsky.net/other/fiction/prospiracy-theory")
assert entry.to_dict() == {
"authors": ["Eliezer S. Yudkowsky"],
"date_published": "2020-09-04T04:11:23Z",
@@ -642,10 +615,8 @@ def test_openai_research_get_text():
dataset = OpenAIResearch(name="openai", url="bla.bla")
soup = BeautifulSoup(OPENAI_HTML, "html.parser")
- parsers = {"arxiv.org": lambda _: {'text': 'bla bla bla'}}
- with patch(
- "requests.head", return_value=Mock(headers={"Content-Type": "text/html"})
- ):
+ parsers = {"arxiv.org": lambda _: {"text": "bla bla bla"}}
+ with patch("requests.head", return_value=Mock(headers={"Content-Type": "text/html"})):
with patch("align_data.sources.articles.parsers.PDF_PARSERS", parsers):
assert dataset._get_text(soup) == "bla bla bla"
@@ -696,10 +667,8 @@ def test_openai_research_process_entry():
dataset = OpenAIResearch(name="openai", url="bla.bla")
soup = BeautifulSoup(OPENAI_HTML, "html.parser")
- parsers = {"arxiv.org": lambda _: {'text': 'bla bla bla'}}
- with patch(
- "requests.head", return_value=Mock(headers={"Content-Type": "text/html"})
- ):
+ parsers = {"arxiv.org": lambda _: {"text": "bla bla bla"}}
+ with patch("requests.head", return_value=Mock(headers={"Content-Type": "text/html"})):
with patch("requests.get", return_value=Mock(content=OPENAI_HTML)):
with patch("align_data.sources.articles.parsers.PDF_PARSERS", parsers):
assert dataset.process_entry(soup).to_dict() == {
@@ -781,3 +750,154 @@ def test_deepmind_technical_proces_entry():
"title": "title!",
"url": "http://bla.bl",
}
+
+
+TRANSFORMER_CIRCUITS_HTML = """
+
+
This is the title
+
+
+