From aafefbcafc9f1e113b7f0b4d3aebfbf48ea3c87f Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Fri, 14 Jul 2023 19:33:52 +0200 Subject: [PATCH 1/7] basic mysql support --- README.md | 9 +- .../alignment_newsletter.py | 5 +- align_data/arbital/arbital.py | 7 +- align_data/articles/datasets.py | 7 +- align_data/arxiv_papers/arxiv_papers.py | 12 +- .../audio_transcripts/audio_transcripts.py | 9 +- align_data/blogs/blogs.py | 2 +- align_data/blogs/gwern_blog.py | 6 +- align_data/blogs/medium_blog.py | 2 - align_data/common/alignment_dataset.py | 167 +++++--------- align_data/common/html_dataset.py | 7 +- align_data/db/models.py | 113 ++++++++++ align_data/db/session.py | 13 ++ align_data/distill/distill.py | 2 +- align_data/ebooks/agentmodels.py | 7 +- align_data/ebooks/gdrive_ebooks.py | 5 +- align_data/ebooks/mdebooks.py | 11 +- align_data/gdocs/gdocs.py | 7 +- align_data/reports/reports.py | 7 +- align_data/settings.py | 8 + align_data/stampy/stampy.py | 1 - local_db.sh | 26 +++ main.py | 31 +-- migrations/README | 1 + migrations/alembic.ini | 110 +++++++++ migrations/env.py | 82 +++++++ migrations/script.py.mako | 24 ++ .../8c11b666e86f_initial_structure.py | 65 ++++++ requirements.txt | 2 + tests/align_data/articles/test_datasets.py | 27 +-- .../common/test_alignment_dataset.py | 209 ++++++------------ tests/align_data/common/test_html_dataset.py | 17 +- tests/align_data/test_alignment_newsletter.py | 7 +- tests/align_data/test_arbital.py | 18 +- tests/align_data/test_blogs.py | 74 +++---- tests/align_data/test_distill.py | 6 +- tests/align_data/test_greater_wrong.py | 8 +- tests/align_data/test_stampy.py | 10 +- tests/conftest.py | 10 + 39 files changed, 706 insertions(+), 428 deletions(-) create mode 100644 align_data/db/models.py create mode 100644 align_data/db/session.py create mode 100755 local_db.sh create mode 100644 migrations/README create mode 100644 migrations/alembic.ini create mode 100644 migrations/env.py create mode 100644 migrations/script.py.mako create mode 100644 migrations/versions/8c11b666e86f_initial_structure.py diff --git a/README.md b/README.md index ff128c0b..3d3974aa 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ The values of the keys are still being cleaned up for consistency. Additional ke ## Development Environment -To set up the development environment, run the following steps: +To set up the development environment, run the following steps. You'll have to also set up [mysqlclient](https://pypi.org/project/mysqlclient/): ```bash git clone https://github.com/StampyAI/alignment-research-dataset @@ -60,6 +60,13 @@ cd alignment-research-dataset pip install -r requirements.txt ``` +### Database + +You'll also have to set up a MySQL database. To do so with Docker, you can run `./local_db.sh` which should spin up a container +with the database initialised. + +### CLI options + The available CLI options are list, fetch, fetch-all, and count-tokens. 
To get a list of all available datasets: diff --git a/align_data/alignment_newsletter/alignment_newsletter.py b/align_data/alignment_newsletter/alignment_newsletter.py index c4d80624..749c09a4 100644 --- a/align_data/alignment_newsletter/alignment_newsletter.py +++ b/align_data/alignment_newsletter/alignment_newsletter.py @@ -28,9 +28,8 @@ def get_item_key(self, row): @staticmethod def _get_published_date(year): if not year or pd.isna(year): - return '' - dt = datetime(int(year), 1, 1, tzinfo=timezone.utc) - return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + return None + return datetime(int(year), 1, 1, tzinfo=timezone.utc) @property def items_list(self): diff --git a/align_data/arbital/arbital.py b/align_data/arbital/arbital.py index 627e524a..4f147e12 100644 --- a/align_data/arbital/arbital.py +++ b/align_data/arbital/arbital.py @@ -132,7 +132,7 @@ def process_entry(self, alias): 'authors': self.extract_authors(page), 'alias': alias, 'tags': list(filter(None, map(self.get_title, page['tagIds']))), - 'summary': [summary] if summary else [], + 'summary': summary, }) except Exception as e: logger.error(f"Error getting page {alias}: {e}") @@ -149,9 +149,8 @@ def get_arbital_page_aliases(self, subspace): def _get_published_date(page): date_published = page.get('editCreatedAt') or page.get('pageCreatedAt') if date_published: - dt = parse(date_published).astimezone(timezone.utc) - return dt.strftime("%Y-%m-%dT%H:%M:%SZ") - return '' + return parse(date_published).astimezone(timezone.utc) + return None def get_page(self, alias): headers = self.headers.copy() diff --git a/align_data/articles/datasets.py b/align_data/articles/datasets.py index ed521cf4..aba2a8e8 100644 --- a/align_data/articles/datasets.py +++ b/align_data/articles/datasets.py @@ -34,9 +34,6 @@ def items_list(self): def get_item_key(self, item): return getattr(item, self.done_key) - def _get_published_date(self, item): - return self._format_datetime(parse(item.date_published)) - @staticmethod def _get_text(item): raise NotImplemented @@ -58,9 +55,9 @@ def process_entry(self, item): 'source': self.name, 'source_type': item.source_type, 'source_filetype': self.source_filetype, - 'date_published': self._get_published_date(item), + 'date_published': self._get_published_date(item.date_published), 'authors': self.extract_authors(item), - 'summary': [] if pd.isna(item.summary) else [item.summary], + 'summary': None if pd.isna(item.summary) else item.summary, }) diff --git a/align_data/arxiv_papers/arxiv_papers.py b/align_data/arxiv_papers/arxiv_papers.py index 85e82f88..cfc2ba3b 100644 --- a/align_data/arxiv_papers/arxiv_papers.py +++ b/align_data/arxiv_papers/arxiv_papers.py @@ -59,7 +59,7 @@ def process_entry(self, ids) -> None: "converted_with": "markdownify", "title": paper.title, "authors": [str(x) for x in paper.authors], - "date_published": paper.published.strftime("%Y-%m-%dT%H:%M:%SZ"), + "date_published": paper.published, "data_last_modified": str(paper.updated), "abstract": paper.summary.replace("\n", " "), "author_comment": paper.comment, @@ -81,9 +81,9 @@ def _is_bad_soup(self, soup, parser='vanity') -> bool: return vanity_wrapper and "don’t have to squint at a PDF" not in vanity_wrapper if parser == 'ar5iv': ar5iv_error = soup.find("span", class_="ltx_ERROR") - if ar5iv_error is None: + if ar5iv_error is None: return False - else: + else: ar5iv_error = ar5iv_error.text if "document may be truncated or damaged" in ar5iv_error: return True @@ -95,8 +95,8 @@ def _is_dud(self, markdown) -> bool: Check if markdown is a dud """ return ( 
- "Paper Not Renderable" in markdown or - "This document may be truncated" in markdown or + "Paper Not Renderable" in markdown or + "This document may be truncated" in markdown or "don’t have to squint at a PDF" not in markdown ) @@ -152,7 +152,7 @@ def _remove_bib_from_article_soup(self, article_soup) -> str: if bib: bib.decompose() return article_soup - + def _strip_markdown(self, s_markdown): return s_markdown.split("\nReferences\n")[0].replace("\n\n", "\n") diff --git a/align_data/audio_transcripts/audio_transcripts.py b/align_data/audio_transcripts/audio_transcripts.py index a70e6e29..cf5cff77 100644 --- a/align_data/audio_transcripts/audio_transcripts.py +++ b/align_data/audio_transcripts/audio_transcripts.py @@ -14,6 +14,7 @@ class AudioTranscripts(GdocDataset): def setup(self): super().setup() + # FIXME: This isn't working - missing files, perhaps? self.files_path = self.raw_data_path / 'transcripts' if not self.files_path.exists(): self.files_path.mkdir(parents=True, exist_ok=True) @@ -50,16 +51,14 @@ def extract_authors(text): if res := re.search('^(.*?):', firstline): return [res.group(1)] return [] - + @staticmethod def _get_published_date(filename): date_str = re.search(r"\d{4}\d{2}\d{2}", str(filename)) if not date_str: - return '' + return None date_str = date_str.group(0) - dt = datetime.strptime(date_str, "%Y%m%d").astimezone(timezone.utc) - return dt.strftime("%Y-%m-%dT%H:%M:%SZ") - + return datetime.strptime(date_str, "%Y%m%d").astimezone(timezone.utc) def process_entry(self, filename): logger.info(f"Processing {filename.name}") diff --git a/align_data/blogs/blogs.py b/align_data/blogs/blogs.py index 9c58f363..9c6b962b 100644 --- a/align_data/blogs/blogs.py +++ b/align_data/blogs/blogs.py @@ -28,7 +28,7 @@ def _get_published_date(self, contents): elem for info in contents.select('div.post-info') for elem in info.children ] - return super()._get_published_date(self._find_date(possible_date_elements)) + return self._find_date(possible_date_elements) class CaradoMoe(RSSDataset): diff --git a/align_data/blogs/gwern_blog.py b/align_data/blogs/gwern_blog.py index da425ba4..325bd7d3 100644 --- a/align_data/blogs/gwern_blog.py +++ b/align_data/blogs/gwern_blog.py @@ -1,8 +1,6 @@ import requests import logging from dataclasses import dataclass -from datetime import datetime, timezone -from dateutil.parser import parse from align_data.common.html_dataset import HTMLDataset @@ -86,9 +84,7 @@ def _get_published_date(self, contents): contents.select_one('.page-date-range .page-modified') or contents.select_one('.page-date-range .page-created') ).text.strip() - if date_published: - return self._format_datetime(parse(date_published)) - return '' + return super()._get_published_date(date_published) def _get_text(self, contents): return self._extract_markdown(contents.select_one('div#markdownBody')) diff --git a/align_data/blogs/medium_blog.py b/align_data/blogs/medium_blog.py index 7e25a293..9d57e2ab 100644 --- a/align_data/blogs/medium_blog.py +++ b/align_data/blogs/medium_blog.py @@ -26,8 +26,6 @@ class MediumBlog(HTMLDataset): but various fixes were added to handle a wider range of Medium blogs. 
""" - url: str - done_key = "url" source_type = "medium_blog" ignored_selectors = ['div:first-child span'] diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 894093db..2f8d236f 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -1,19 +1,20 @@ -import hashlib import logging import time import zipfile -from collections import UserDict -from contextlib import contextmanager from dataclasses import dataclass, field, KW_ONLY -from functools import partial from pathlib import Path -from typing import Optional, List +from typing import List +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError import gdown import jsonlines import pytz from dateutil.parser import parse, ParserError from tqdm import tqdm +from align_data.db.models import Article, Author +from align_data.db.session import make_session + INIT_DICT = { "source": None, @@ -22,7 +23,6 @@ "date_published": None, "title": None, "url": None, - "summary": lambda: [], "authors": lambda: [], } @@ -71,7 +71,7 @@ class AlignmentDataset: """A list of fields to use as the id of the entry. If not set, will use ['url', 'title']""" def __str__(self) -> str: - return f"{self.name} dataset will be written to {self.jsonl_path}" + return self.name def __post_init__(self, data_path=Path(__file__).parent / '../../data/'): self.data_path = data_path @@ -79,72 +79,48 @@ def __post_init__(self, data_path=Path(__file__).parent / '../../data/'): # set the default place to look for data self.files_path = self.raw_data_path / self.name + # TODO: get rid of self.jsonl_path + self.jsonl_path = self.data_path / f"{self.name}.jsonl" - # and the default place to write data - self._set_output_paths(self.data_path) - - def _set_output_paths(self, out_path): - self.jsonl_path = Path(out_path) / f"{self.name}.jsonl" - self.txt_path = Path(out_path) / f"{self.name}.txt" - - def write_entry(self, entry, jsonl_writer, text_writer): - jsonl_writer.write(entry.to_dict()) - - # Save the entry in plain text, mainly for debugging - text = entry["text"].lstrip().replace('\n', '\n ') - text_writer.write(f'[ENTRY {self._entry_idx}]\n {text}\n\n') - - self._entry_idx += 1 - self._outputted_items.add(entry[self.done_key]) - def make_data_entry(self, data, **kwargs): - return DataEntry(dict(data, **kwargs), id_fields=self.id_fields) - - @contextmanager - def writer(self, out_path=None, overwrite=False): - """Returns a function that can be used to write entries to the output file. + data = dict(data, **kwargs) + # TODO: Don't keep adding the same authors - come up with some way to reuse them + # TODO: Prettify this + data['authors'] = [Author(name=name) for name in data.get('authors', [])] + if summary := ('summary' in data and data.pop('summary')): + data['summaries'] = [summary] + return Article( + id_fields=self.id_fields, + meta={k: v for k, v in data.items() if k not in INIT_DICT}, + **{k: v for k, v in data.items() if k in INIT_DICT}, + ) - The resulting function expects to only get a single `DataEntry`, which will then - be written as a json object. 
- """ - if overwrite: - write_mode = 'w' - self._entry_idx = 0 - else: - write_mode = 'a' + def to_jsonl(self, out_path=None, filename=None): + if not out_path: + out_path=Path(__file__).parent / '../../data/' - if out_path: - self._set_output_paths(out_path) + if not filename: + filename = f"{self.name}.jsonl" - with jsonlines.open(self.jsonl_path, mode=write_mode) as jsonl_writer: - with open(self.txt_path, mode=write_mode, errors="backslashreplace") as text_writer: - yield partial(self.write_entry, jsonl_writer=jsonl_writer, text_writer=text_writer) + with jsonlines.open(Path(out_path) / filename, 'w') as jsonl_writer: + for article in self.read_entries(): + jsonl_writer.write(article.to_dict()) def read_entries(self): """Iterate through all the saved entries.""" - if not self.jsonl_path.exists(): - return [] - - with jsonlines.open(self.jsonl_path) as f: - for line in f: - yield line - - def merge_summaries(self, summaries): - if not self.summary_key or not self.jsonl_path.exists(): - return - - updated = 0 - tmp_file = self.jsonl_path.parent / f'{self.jsonl_path.name}-tmp' - with jsonlines.open(tmp_file, 'w') as writer: - for line in self.read_entries(): - url = line.get('url') - summary = summaries.get(url, {}) - line[self.summary_key] += list(summary.values()) - updated += bool(summary) - writer.write(line) - - logger.info('Updated %s summaries for %s', updated, self.name) - tmp_file.rename(self.jsonl_path) + with make_session() as session: + for item in session.scalars(select(Article).where(Article.source==self.name)): + yield item + + def add_entries(self, entries): + with make_session() as session: + for entry in entries: + session.add(entry) + try: + session.commit() + except IntegrityError: + logger.error(f'found duplicate of {entry}') + session.rollback() def setup(self): # make sure the path to the raw data exists @@ -166,12 +142,11 @@ def get_item_key(self, item): def _load_outputted_items(self): """Load the output file (if it exists) in order to know which items have already been output.""" - if not self.jsonl_path.exists(): - logger.info(f"No previous data found at {self.jsonl_path}") - return set() - - with jsonlines.open(self.jsonl_path, mode='r') as reader: - return {entry.get(self.done_key) for entry in reader} + with make_session() as session: + if hasattr(Article, self.done_key): + return set(session.scalars(select(getattr(Article, self.done_key)).where(Article.source==self.name)).all()) + # TODO: Properly handle this - it should create a proper SQL JSON select + return {getattr(item, self.done_key) for item in session.scalars(select(Article.meta).where(Article.source==self.name)).all()} def unprocessed_items(self, items=None): """Return a list of all items to be processed. 
@@ -199,7 +174,6 @@ def fetch_entries(self): if not entry: continue - entry.add_id() yield entry if self.COOLDOWN: @@ -211,16 +185,15 @@ def process_entry(self, entry): @staticmethod def _format_datetime(date): - # Totally ignore any timezone info, forcing everything to UTC - dt = date.replace(tzinfo=pytz.UTC) - return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + return date.strftime("%Y-%m-%dT%H:%M:%SZ") def _get_published_date(self, date): try: - return self._format_datetime(parse(str(date))) + # Totally ignore any timezone info, forcing everything to UTC + return parse(str(date)).replace(tzinfo=pytz.UTC) except ParserError: pass - return '' + return None @dataclass @@ -264,43 +237,3 @@ def folder_from_gdrive(self, url=None, output=None): output=str(output or self.files_path), quiet=False ) - - -class DataEntry(UserDict): - def __init__(self, *args, id_fields, **kwargs): - super().__init__(*args, **kwargs) - for k, default in INIT_DICT.items(): - if k not in self: - self[k] = default and default() - # Store id_fields in a way that does not interfere with UserDict's functionality - assert isinstance(id_fields, list), "id_fields must be a list" - assert id_fields, "id_fields must not be empty" - assert all(isinstance(field, str) for field in id_fields), "id_fields must be a list of strings" - self.__id_fields = id_fields - - def generate_id_string(self): - return ''.join(str(self[field]) for field in self.__id_fields).encode("utf-8") - - def verify_fields(self): - missing = [field for field in self.__id_fields if not self.get(field)] - assert not missing, f'Entry is missing the following fields: {missing}' - - def add_id(self): - self.verify_fields() - - id_string = self.generate_id_string() - self["id"] = hashlib.md5(id_string).hexdigest() - - def _verify_id(self): - assert self["id"] is not None, "Entry is missing id" - self.verify_fields() - - id_string = self.generate_id_string() - id_from_fields = hashlib.md5(id_string).hexdigest() - assert self["id"] == id_from_fields, f"Entry id {self['id']} does not match id from id_fields, {id_from_fields}" - - def to_dict(self): - for k, _ in INIT_DICT.items(): - assert self[k] is not None, f"Entry is missing key {k}" - self._verify_id() - return dict(self.data) diff --git a/align_data/common/html_dataset.py b/align_data/common/html_dataset.py index 83864828..e374dfc9 100644 --- a/align_data/common/html_dataset.py +++ b/align_data/common/html_dataset.py @@ -1,3 +1,4 @@ +import pytz import regex as re import logging from datetime import datetime @@ -93,7 +94,7 @@ def _get_text(self, contents): def _find_date(self, items): for i in items: if re.match('\w+ \d{1,2}, \d{4}', i.text): - return self._format_datetime(datetime.strptime(i.text, '%b %d, %Y')) + return datetime.strptime(i.text, '%b %d, %Y').replace(tzinfo=pytz.UTC) def _extract_markdown(self, element): return element and markdownify(str(element)).strip() @@ -121,9 +122,7 @@ def _get_title(item): def _get_published_date(self, item): date_published = item.get('published') or item.get('pubDate') - if date_published: - return self._format_datetime(parse(date_published)) - return '' + return super()._get_published_date(date_published) def _get_text(self, item): text = item.get('content') and item['content'][0].get('value') diff --git a/align_data/db/models.py b/align_data/db/models.py new file mode 100644 index 00000000..e8c4fcf9 --- /dev/null +++ b/align_data/db/models.py @@ -0,0 +1,113 @@ +import pytz +import hashlib +from datetime import datetime +from typing import List, Optional +from sqlalchemy 
import JSON, DateTime, ForeignKey, Table, String, Column, Integer, func, Text, event +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy.ext.associationproxy import association_proxy, AssociationProxy +from sqlalchemy.dialects.mysql import LONGTEXT + + +class Base(DeclarativeBase): + pass + + +author_article = Table( + 'author_article', + Base.metadata, + Column('article_id', Integer, ForeignKey('articles.id'), primary_key=True), + Column('author_id', Integer, ForeignKey('authors.id'), primary_key=True), +) + + +class Author(Base): + + __tablename__ = "authors" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(256), nullable=False) + articles: Mapped[List["Article"]] = relationship(secondary=author_article, back_populates="authors") + + +class Summary(Base): + + __tablename__ = "summaries" + + id: Mapped[int] = mapped_column(primary_key=True) + text: Mapped[str] = mapped_column(Text, nullable=False) + source: Mapped[Optional[str]] = mapped_column(String(256)) + article_id: Mapped[str] = mapped_column(ForeignKey("articles.id")) + + article: Mapped["Article"] = relationship(back_populates="summaries") + + +class Article(Base): + __tablename__ = "articles" + + _id: Mapped[int] = mapped_column('id', primary_key=True) + id: Mapped[str] = mapped_column('hash_id', String(32), unique=True, nullable=False) + title: Mapped[Optional[str]] = mapped_column(String(1028)) + url: Mapped[Optional[str]] = mapped_column(String(1028)) + source: Mapped[Optional[str]] = mapped_column(String(128)) + source_type: Mapped[Optional[str]] = mapped_column(String(128)) + text: Mapped[Optional[str]] = mapped_column(LONGTEXT) + date_published: Mapped[Optional[datetime]] + meta: Mapped[Optional[JSON]] = mapped_column(JSON, name='metadata', default='{}') + date_created: Mapped[datetime] = mapped_column(DateTime, default=func.now()) + date_updated: Mapped[Optional[datetime]] = mapped_column(DateTime, onupdate=func.current_timestamp()) + + authors: Mapped[List['Author']] = relationship(secondary=author_article, back_populates="articles") + summaries: Mapped[List["Summary"]] = relationship(back_populates="article", cascade="all, delete-orphan") + + __id_fields = ['title', 'url'] + + def __init__(self, *args, id_fields, **kwargs): + self.__id_fields = id_fields + super().__init__(*args, **kwargs) + + def __repr__(self) -> str: + return f"User(id={self.id!r}, name={self.title!r}, fullname={self.url!r})" + + def generate_id_string(self): + return ''.join(str(getattr(self, field)) for field in self.__id_fields).encode("utf-8") + + def verify_fields(self): + missing = [field for field in self.__id_fields if not getattr(self, field)] + assert not missing, f'Entry is missing the following fields: {missing}' + + def verify_id(self): + assert self.id is not None, "Entry is missing id" + + id_string = self.generate_id_string() + id_from_fields = hashlib.md5(id_string).hexdigest() + assert self.id == id_from_fields, f"Entry id {self.id} does not match id from id_fields, {id_from_fields}" + + @classmethod + def before_write(cls, mapper, connection, target): + target.verify_fields() + + if target.id: + target.verify_id() + else: + id_string = target.generate_id_string() + target.id = hashlib.md5(id_string).hexdigest() + + def to_dict(self): + if date := self.date_published: + date = date.replace(tzinfo=pytz.UTC).strftime("%Y-%m-%dT%H:%M:%SZ") + return { + 'id': self.id, + 'title': self.title, + 'url': self.url, + 'source': self.source, + 
'source_type': self.source_type, + 'text': self.text, + 'date_published': date, + 'authors': [a.name for a in self.authors], + 'summaries': [s.text for s in self.summaries], + **self.meta, + } + + +event.listen(Article, 'before_insert', Article.before_write) +event.listen(Article, 'before_update', Article.before_write) diff --git a/align_data/db/session.py b/align_data/db/session.py new file mode 100644 index 00000000..16ff48e4 --- /dev/null +++ b/align_data/db/session.py @@ -0,0 +1,13 @@ +from contextlib import contextmanager +from sqlalchemy import create_engine +from sqlalchemy.orm import Session +from align_data.settings import DB_CONNECTION_URI + + +@contextmanager +def make_session(auto_commit=False): + engine = create_engine(DB_CONNECTION_URI, echo=False) + with Session(engine) as session: + yield session + if auto_commit: + session.commit() diff --git a/align_data/distill/distill.py b/align_data/distill/distill.py index 65ac7662..4a9ea388 100644 --- a/align_data/distill/distill.py +++ b/align_data/distill/distill.py @@ -24,7 +24,7 @@ def _extra_values(self, item): return { 'doi': doi_elem and doi_elem.text, - 'summary': [item['summary']], + 'summary': item['summary'], 'journal_ref': 'distill-pub', 'bibliography': [ {'title': el.find('span').text, 'link': el.find('a').get('href')} diff --git a/align_data/ebooks/agentmodels.py b/align_data/ebooks/agentmodels.py index 8b412402..dd04a30a 100644 --- a/align_data/ebooks/agentmodels.py +++ b/align_data/ebooks/agentmodels.py @@ -26,8 +26,7 @@ def setup(self): def _get_published_date(self, filename): last_commit = next(self.repository.iter_commits(paths=f'chapters/{filename.name}')) - dt = last_commit.committed_datetime.astimezone(timezone.utc) - return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + return last_commit.committed_datetime.astimezone(timezone.utc) def process_entry(self, filename): return self.make_data_entry({ @@ -36,6 +35,6 @@ def process_entry(self, filename): 'authors': ['Owain Evans', 'Andreas Stuhlmüller', 'John Salvatier', 'Daniel Filan'], 'date_published': self._get_published_date(filename), 'title': 'Modeling Agents with Probabilistic Programs', - 'url': f'https://agentmodels.org/chapters/{filename.stem}.html', + 'url': f'https://agentmodels.org/chapters/{filename.stem}.html', 'text': filename.read_text(encoding='utf-8'), - }) \ No newline at end of file + }) diff --git a/align_data/ebooks/gdrive_ebooks.py b/align_data/ebooks/gdrive_ebooks.py index bfde9e61..b01bc89f 100644 --- a/align_data/ebooks/gdrive_ebooks.py +++ b/align_data/ebooks/gdrive_ebooks.py @@ -65,6 +65,5 @@ def process_entry(self, epub_file): def _get_published_date(metadata): date_published = metadata["publication_date"] if date_published: - dt = parse(date_published).astimezone(timezone.utc) - return dt.strftime("%Y-%m-%dT%H:%M:%SZ") - return '' \ No newline at end of file + return parse(date_published).astimezone(timezone.utc) + return None diff --git a/align_data/ebooks/mdebooks.py b/align_data/ebooks/mdebooks.py index 87dc7075..c7b4b7aa 100644 --- a/align_data/ebooks/mdebooks.py +++ b/align_data/ebooks/mdebooks.py @@ -34,14 +34,13 @@ def process_entry(self, filename): "url": "", "filename": filename.name, }) - + @staticmethod def _get_published_date(filename): date_str = re.search(r"\d{4}-\d{2}-\d{2}", filename.name) - + if not date_str: - return '' - + return None + date_str = date_str.group(0) - dt = datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc) - return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + return datetime.strptime(date_str, 
"%Y-%m-%d").replace(tzinfo=timezone.utc) diff --git a/align_data/gdocs/gdocs.py b/align_data/gdocs/gdocs.py index 79eb843a..5c41888e 100644 --- a/align_data/gdocs/gdocs.py +++ b/align_data/gdocs/gdocs.py @@ -48,15 +48,14 @@ def process_entry(self, docx_filename): "url": "", "docx_name": docx_filename.name, }) - + @staticmethod def _get_published_date(metadata): date_published = metadata.created or metadata.modified if date_published: assert isinstance(date_published, datetime), f"Expected datetime, got {type(date_published)}" - dt = date_published.replace(tzinfo=timezone.utc) - return dt.strftime("%Y-%m-%dT%H:%M:%SZ") - return '' + return date_published.replace(tzinfo=timezone.utc) + return None def _get_metadata(self , docx_filename): diff --git a/align_data/reports/reports.py b/align_data/reports/reports.py index ac9e84a0..e79a8909 100644 --- a/align_data/reports/reports.py +++ b/align_data/reports/reports.py @@ -25,14 +25,13 @@ def setup(self): @property def zip_file(self): return self.raw_data_path / "report_teis.zip" - + @staticmethod def _get_published_data(doc_dict): date_str = doc_dict["header"].get('date') if date_str: - dt = parse(date_str).astimezone(timezone.utc) - return dt.strftime("%Y-%m-%dT%H:%M:%SZ") - return '' + return parse(date_str).astimezone(timezone.utc) + return None def process_entry(self, filename): logger.info(f"Processing {filename.name}") diff --git a/align_data/settings.py b/align_data/settings.py index c67e66c5..49be5460 100644 --- a/align_data/settings.py +++ b/align_data/settings.py @@ -12,3 +12,11 @@ METADATA_SOURCE_SPREADSHEET = os.environ.get('METADATA_SOURCE_SPREADSHEET', '1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI') METADATA_SOURCE_SHEET = os.environ.get('METADATA_SOURCE_SHEET', 'special_docs.csv') METADATA_OUTPUT_SPREADSHEET = os.environ.get('METADATA_OUTPUT_SPREADSHEET', '1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4') + + +user = os.environ.get('ARD_DB_USER', 'user') +password = os.environ.get('ARD_DB_PASSWORD', 'we all live in a yellow submarine') +host = os.environ.get('ARD_DB_HOST', '127.0.0.1') +port = os.environ.get('ARD_DB_PORT', '3306') +db_name = os.environ.get('ARD_DB_NAME', 'alignment_research_dataset') +DB_CONNECTION_URI = f'mysql+mysqldb://{user}:{password}@{host}:{port}/{db_name}' diff --git a/align_data/stampy/stampy.py b/align_data/stampy/stampy.py index 9a101b52..025e4658 100644 --- a/align_data/stampy/stampy.py +++ b/align_data/stampy/stampy.py @@ -37,7 +37,6 @@ def items_list(self): def get_item_key(self, entry): return html.unescape(entry['Question']) - def _get_published_date(self, entry): date_published = entry['Doc Last Edited'] return super()._get_published_date(date_published) diff --git a/local_db.sh b/local_db.sh new file mode 100755 index 00000000..1ad506d4 --- /dev/null +++ b/local_db.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash +ROOT_PASSWORD=my-secret-pw + +docker start alignment-research-dataset +if [ $? -ne 0 ]; then + echo 'No docker container found - creating a new one' + docker run --name alignment-research-dataset -p 3306:3306 -e MYSQL_ROOT_PASSWORD=$ROOT_PASSWORD -d mysql:latest +fi + +echo "Waiting till mysql is available..." +while ! mysql -h 127.0.0.1 --user root --password=$ROOT_PASSWORD -e "SELECT 1" ; do + sleep 1 +done + +echo "Setting up database..." 
+mysql -h 127.0.0.1 -u root -p$ROOT_PASSWORD << EOF +CREATE DATABASE IF NOT EXISTS alignment_research_dataset; +CREATE USER IF NOT EXISTS user IDENTIFIED BY 'we all live in a yellow submarine'; +GRANT ALL PRIVILEGES ON alignment_research_dataset.* TO user; +EOF + +echo "Running migrations" + +alembic --config migrations/alembic.ini upgrade head + +echo "The database is set up. Connect to it via 'mysql -h 127.0.0.1 -u user \"--password=we all live in a yellow submarine\" alignment_research_dataset'" diff --git a/main.py b/main.py index f2b8dc21..4e3f35ac 100644 --- a/main.py +++ b/main.py @@ -64,14 +64,15 @@ def fetch(self, *names, rebuild=False, fetch_prev=False) -> None: dataset = get_dataset(name) if fetch_prev: + # TODO: what should this do? Download and load the data? download_from_hf(dataset) elif rebuild: + # TODO: Get this to work properly dataset.jsonl_path.unlink(missing_ok=True) - with dataset.writer(self.out_path) as writer: - for entry in dataset.fetch_entries(): - writer(entry) - + dataset.add_entries(dataset.fetch_entries()) + # TODO: Get rid of jsonl stuff here + dataset.to_jsonl() print(dataset.jsonl_path) def fetch_all(self, *skip, rebuild=False, fetch_prev=False) -> str: @@ -92,28 +93,6 @@ def fetch_all(self, *skip, rebuild=False, fetch_prev=False) -> str: return self.merge_summaries(*names) - def merge_summaries(self, *names): - """Update all source materials with summaries if they have any. - - Some datasets are actual alignment content, e.g. arXiv articles, while other datasets are mainly - summaries of other articles, e.g. the alignment newsletter. This command merges the two, adding all - summaries to all found entries. In theory it's possible for a single article to have multiple different - summaries, therefore the summaries are added as a dict of : - """ - summaries = defaultdict(lambda: dict()) - for dataset in DATASET_REGISTRY: - if dataset.source_key and dataset.summary_key: - add_summaries(summaries, dataset) - - if names: - datasets = [get_dataset(name) for name in names] - else: - datasets = DATASET_REGISTRY - - for dataset in datasets: - if not dataset.source_key and dataset.summary_key: - dataset.merge_summaries(summaries) - def count_tokens(self, merged_dataset_path: str) -> None: """ This function counts the number of tokens, words, and characters in the dataset diff --git a/migrations/README b/migrations/README new file mode 100644 index 00000000..98e4f9c4 --- /dev/null +++ b/migrations/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/migrations/alembic.ini b/migrations/alembic.ini new file mode 100644 index 00000000..dbb0b3f6 --- /dev/null +++ b/migrations/alembic.ini @@ -0,0 +1,110 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = migrations + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. 
+# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/migrations/env.py b/migrations/env.py new file mode 100644 index 00000000..838bfb97 --- /dev/null +++ b/migrations/env.py @@ -0,0 +1,82 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. 
+if config.config_file_name is not None: + fileConfig(config.config_file_name) + +from align_data.settings import DB_CONNECTION_URI +config.set_main_option('sqlalchemy.url', DB_CONNECTION_URI) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +from align_data.db.models import Base +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/migrations/script.py.mako b/migrations/script.py.mako new file mode 100644 index 00000000..55df2863 --- /dev/null +++ b/migrations/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/migrations/versions/8c11b666e86f_initial_structure.py b/migrations/versions/8c11b666e86f_initial_structure.py new file mode 100644 index 00000000..113371db --- /dev/null +++ b/migrations/versions/8c11b666e86f_initial_structure.py @@ -0,0 +1,65 @@ +"""initial structure + +Revision ID: 8c11b666e86f +Revises: +Create Date: 2023-07-14 15:48:49.149905 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. 
+revision = '8c11b666e86f' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + 'articles', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('hash_id', sa.String(length=32), nullable=False), + sa.Column('title', sa.String(length=1028), nullable=True), + sa.Column('url', sa.String(length=1028), nullable=True), + sa.Column('source', sa.String(length=128), nullable=True), + sa.Column('source_type', sa.String(length=128), nullable=True), + sa.Column('text', mysql.LONGTEXT(), nullable=True), + sa.Column('date_published', sa.DateTime(), nullable=True), + sa.Column('metadata', sa.JSON(), nullable=True), + sa.Column('date_created', sa.DateTime(), nullable=False), + sa.Column('date_updated', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('hash_id') + ) + op.create_table( + 'authors', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=256), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table( + 'author_article', + sa.Column('article_id', sa.Integer(), nullable=False), + sa.Column('author_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['article_id'], ['articles.id'], ), + sa.ForeignKeyConstraint(['author_id'], ['authors.id'], ), + sa.PrimaryKeyConstraint('article_id', 'author_id') + ) + op.create_table( + 'summaries', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('text', sa.Text(), nullable=False), + sa.Column('source', sa.String(length=256), nullable=True), + sa.Column('article_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['article_id'], ['articles.id'], ), + sa.PrimaryKeyConstraint('id') + ) + + +def downgrade() -> None: + op.drop_table('summaries') + op.drop_table('author_article') + op.drop_table('authors') + op.drop_table('articles') diff --git a/requirements.txt b/requirements.txt index 02549664..1320ea7b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,3 +27,5 @@ google-auth-oauthlib google-auth-httplib2 google-api-python-client gspread +alembic +mysqlclient diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py index f49859fa..d85c3d13 100644 --- a/tests/align_data/articles/test_datasets.py +++ b/tests/align_data/articles/test_datasets.py @@ -1,10 +1,8 @@ -from unittest.mock import patch, Mock +from unittest.mock import Mock, patch import pandas as pd import pytest -from bs4 import BeautifulSoup - -from align_data.articles.datasets import SpreadsheetDataset, PDFArticles, HTMLArticles, EbookArticles, XMLArticles +from align_data.articles.datasets import EbookArticles, HTMLArticles, PDFArticles, SpreadsheetDataset, XMLArticles @pytest.fixture @@ -40,11 +38,6 @@ def test_spreadsheet_dataset_get_item_key(): assert dataset.get_item_key(Mock(bla='ble', title='the key')) == 'the key' -def test_spreadsheet_dataset_get_published_date(): - dataset = SpreadsheetDataset(name='bla', spreadsheet_id='123', sheet_id='456') - assert dataset._get_published_date(Mock(date_published='2023/01/03 12:23:34')) == '2023-01-03T12:23:34Z' - - @pytest.mark.parametrize('authors, expected', ( ('', []), (' \n \n \t', []), @@ -80,14 +73,14 @@ def test_pdf_articles_process_item(articles): with patch('align_data.articles.datasets.download'): with patch('align_data.articles.datasets.read_pdf', return_value='pdf contents bla'): - assert dataset.process_entry(item) == { + assert dataset.process_entry(item).to_dict() == { 'authors': ['John Snow', 'mr Blobby'], 
'date_published': '2023-01-01T12:32:11Z', 'id': None, 'source': 'bla', 'source_filetype': 'pdf', 'source_type': 'something', - 'summary': ['the summary of article 0'], + 'summaries': ['the summary of article 0'], 'text': 'pdf contents [bla](asd.com)', 'title': 'article no 0', 'url': 'http://example.com/item/0', @@ -115,14 +108,14 @@ def test_html_articles_process_entry(articles): parsers = {'example.com': lambda _: ' html contents with proper elements ble ble '} with patch('align_data.articles.datasets.HTML_PARSERS', parsers): - assert dataset.process_entry(item) == { + assert dataset.process_entry(item).to_dict() == { 'authors': ['John Snow', 'mr Blobby'], 'date_published': '2023-01-01T12:32:11Z', 'id': None, 'source': 'bla', 'source_filetype': 'html', 'source_type': 'something', - 'summary': ['the summary of article 0'], + 'summaries': ['the summary of article 0'], 'text': 'html contents with [proper elements](bla.com) ble ble', 'title': 'article no 0', 'url': 'http://example.com/item/0', @@ -157,14 +150,14 @@ def test_ebook_articles_process_entry(articles): contents = ' html contents with proper elements ble ble ' with patch('align_data.articles.datasets.download'): with patch('pypandoc.convert_file', return_value=contents): - assert dataset.process_entry(item) == { + assert dataset.process_entry(item).to_dict() == { 'authors': ['John Snow', 'mr Blobby'], 'date_published': '2023-01-01T12:32:11Z', 'id': None, 'source': 'bla', 'source_filetype': 'epub', 'source_type': 'something', - 'summary': ['the summary of article 0'], + 'summaries': ['the summary of article 0'], 'text': 'html contents with [proper elements](bla.com) ble ble', 'title': 'article no 0', 'url': 'http://example.com/item/0', @@ -183,14 +176,14 @@ def test_xml_articles_process_entry(articles): item = list(dataset.items_list)[0] with patch('align_data.articles.datasets.extract_gdrive_contents', return_value={'text': 'bla bla'}): - assert dataset.process_entry(item) == { + assert dataset.process_entry(item).to_dict() == { 'authors': ['John Snow', 'mr Blobby'], 'date_published': '2023-01-01T12:32:11Z', 'id': None, 'source': 'bla', 'source_filetype': 'xml', 'source_type': 'something', - 'summary': ['the summary of article 0'], + 'summaries': ['the summary of article 0'], 'text': 'bla bla', 'title': 'article no 0', 'url': 'http://example.com/item/0', diff --git a/tests/align_data/common/test_alignment_dataset.py b/tests/align_data/common/test_alignment_dataset.py index fce4eef3..d0d8097a 100644 --- a/tests/align_data/common/test_alignment_dataset.py +++ b/tests/align_data/common/test_alignment_dataset.py @@ -1,12 +1,12 @@ -import json -import re import pytest +from align_data.db.models import Article import jsonlines +from unittest.mock import patch from datetime import datetime from dataclasses import dataclass from pathlib import Path from typing import List -from align_data.common.alignment_dataset import AlignmentDataset, GdocDataset +from align_data.common.alignment_dataset import AlignmentDataset @pytest.fixture @@ -22,7 +22,7 @@ def data_entries(): }) for i in range(5) ] for entry in entries: - entry.add_id() + Article.before_write(None, None, entry) return entries @@ -37,29 +37,30 @@ def test_data_entry_default_fields(): dataset = AlignmentDataset(name='blaa') entry = dataset.make_data_entry({}) - assert entry == { + assert entry.to_dict() == { 'date_published': None, 'source': None, + 'source_type': None, 'title': None, 'url': None, 'id': None, 'text': None, - 'summary': [], + 'summaries': [], 'authors': [], - } + } def 
test_data_entry_id_from_urls_and_title(): data = {'key1': 12, 'key2': 312, 'url': 'www.arbital.org', 'title': 'once upon a time'} dataset = AlignmentDataset(name='blaa') entry = dataset.make_data_entry(data) - entry.add_id() - print(entry) - assert entry == dict({ + Article.before_write(None, None, entry) + assert entry.to_dict() == dict({ 'date_published': None, 'id': '770fe57c8c2130eda08dc392b8696f97', 'source': None, + 'source_type': None, 'text': None, - 'summary': [], + 'summaries': [], 'authors': [], }, **data ) @@ -69,105 +70,126 @@ def test_data_entry_no_url_and_title(): dataset = AlignmentDataset(name='blaa') entry = dataset.make_data_entry({'key1': 12, 'key2': 312}) with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url', 'title'\\]"): - entry.add_id() + Article.before_write(None, None, entry) + def test_data_entry_no_url(): dataset = AlignmentDataset(name='blaa') entry = dataset.make_data_entry({'key1': 12, 'key2': 312, 'title': 'wikipedia goes to war on porcupines'}) with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url'\\]"): - entry.add_id() + Article.before_write(None, None, entry) + def test_data_entry_none_url(): dataset = AlignmentDataset(name='blaa') entry = dataset.make_data_entry({'key1': 12, 'key2': 312, 'url': None}) with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url', 'title'\\]"): - entry.add_id() + Article.before_write(None, None, entry) + def test_data_entry_none_title(): dataset = AlignmentDataset(name='blaa') entry = dataset.make_data_entry({'key1': 12, 'key2': 312, 'url': 'www.wikipedia.org', 'title': None}) with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['title'\\]"): - entry.add_id() - + Article.before_write(None, None, entry) + + def test_data_entry_empty_url_and_title(): dataset = AlignmentDataset(name='blaa') entry = dataset.make_data_entry({'key1': 12, 'key2': 312, 'url': '', 'title': ''}) with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url', 'title'\\]"): - entry.add_id() + Article.before_write(None, None, entry) + def test_data_entry_empty_url_only(): dataset = AlignmentDataset(name='blaa') entry = dataset.make_data_entry({'key1': 12, 'key2': 312, 'url': '', 'title': 'once upon a time'}) with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url'\\]"): - entry.add_id() + Article.before_write(None, None, entry) + def test_data_entry_empty_title_only(): dataset = AlignmentDataset(name='blaa') entry = dataset.make_data_entry({'key1': 12, 'key2': 312, 'url': 'www.wikipedia.org', 'title':''}) with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['title'\\]"): - entry.add_id() + Article.before_write(None, None, entry) + def test_data_entry_verify_id_passes(): dataset = AlignmentDataset(name='blaa') entry = dataset.make_data_entry({'source': 'arbital', 'text': 'once upon a time', 'url': 'www.arbital.org', 'title': 'once upon a time', 'id': '770fe57c8c2130eda08dc392b8696f97'}) - entry._verify_id() + entry.verify_id() + def test_data_entry_verify_id_fails(): dataset = AlignmentDataset(name='blaa') entry = dataset.make_data_entry({'url': 'www.arbital.org', 'title': 'once upon a time', 'id': 'f2b4e02fc1dd8ae43845e4f930f2d84f'}) with pytest.raises(AssertionError, match='Entry id does not match id_fields'): - entry._verify_id() + entry.verify_id() + def test_data_entry_id_fields_url_no_url(): dataset = 
AlignmentDataset(name='blaa', id_fields=['url']) entry = dataset.make_data_entry({'source': 'arbital', 'text': 'once upon a time'}) with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url'\\]"): - entry.add_id() + Article.before_write(None, None, entry) + def test_data_entry_id_fields_url_empty_url(): dataset = AlignmentDataset(name='blaa', id_fields=['url']) entry = dataset.make_data_entry({'url': ''}) with pytest.raises(AssertionError, match="Entry is missing the following fields: \\['url'\\]"): - entry.add_id() - + Article.before_write(None, None, entry) + + def test_data_entry_id_fields_url(): dataset = AlignmentDataset(name='blaa', id_fields=['url']) entry = dataset.make_data_entry({'url': 'https://www.google.ca/once_upon_a_time'}) - entry.add_id() + + Article.before_write(None, None, entry) + assert entry.id + def test_data_entry_id_fields_url_verify_id_passes(): dataset = AlignmentDataset(name='blaa', id_fields=['url']) entry = dataset.make_data_entry({'url': 'arbitalonce upon a time', 'id':'809d336a0b9b38c4f585e862317e667d'}) - entry._verify_id() + entry.verify_id() + def test_data_entry_different_id_from_different_url(): dataset = AlignmentDataset(name='blaa', id_fields=['url']) entry1 = dataset.make_data_entry({'url': ' https://aisafety.info?state=6478'}) - entry1.add_id() entry2 = dataset.make_data_entry({'source': 'arbital', 'text': 'once upon a time', 'url': ' https://aisafety.info?state=6479'}) - entry2.add_id() - assert entry1['id'] != entry2['id'] + assert entry1.generate_id_string() != entry2.generate_id_string() @pytest.mark.parametrize('data, error', ( ({'text': 'bla bla bla'}, "Entry is missing id"), ({'text': 'bla bla bla', 'id': None}, "Entry is missing id"), + ({'id': '123', 'url':'www.google.com/winter_wonderland','title': 'winter wonderland'}, 'Entry id 123 does not match id from id_fields, [0-9a-fA-F]{32}'), + ({'id': '457c21e0ecabebcb85c12022d481d9f4', 'url':'www.google.com', 'title': 'winter wonderland'}, 'Entry id [0-9a-fA-F]{32} does not match id from id_fields, [0-9a-fA-F]{32}'), + ({'id': '457c21e0ecabebcb85c12022d481d9f4', 'url':'www.google.com', 'title': 'Once upon a time'}, 'Entry id [0-9a-fA-F]{32} does not match id from id_fields, [0-9a-fA-F]{32}'), +)) +def test_data_entry_verify_id_fails(data, error): + dataset = AlignmentDataset(name='blaa', id_fields=['url', 'title']) + entry = dataset.make_data_entry(data) + with pytest.raises(AssertionError, match=error): + entry.verify_id() + + +@pytest.mark.parametrize('data, error', ( ({'id': '123'}, "Entry is missing the following fields: \\['url', 'title'\\]"), ({'id': '123', 'url': None}, "Entry is missing the following fields: \\['url', 'title'\\]"), ({'id': '123', 'url': 'www.google.com/'}, "Entry is missing the following fields: \\['title'\\]"), ({'id': '123', 'url': 'google', 'text': None}, "Entry is missing the following fields: \\['title'\\]"), ({'id': '123', 'url': '', 'title': ''}, "Entry is missing the following fields: \\['url', 'title'\\]"), - - ({'id': '123', 'url':'www.google.com/winter_wonderland','title': 'winter wonderland'}, 'Entry id 123 does not match id from id_fields, [0-9a-fA-F]{32}'), - ({'id': '457c21e0ecabebcb85c12022d481d9f4', 'url':'www.google.com', 'title': 'winter wonderland'}, 'Entry id [0-9a-fA-F]{32} does not match id from id_fields, [0-9a-fA-F]{32}'), - ({'id': '457c21e0ecabebcb85c12022d481d9f4', 'url':'www.google.com', 'title': 'Once upon a time'}, 'Entry id [0-9a-fA-F]{32} does not match id from id_fields, [0-9a-fA-F]{32}'), )) -def 
test_data_entry_verify_id_fails(data, error): +def test_data_entry_verify_fields_fails(data, error): dataset = AlignmentDataset(name='blaa', id_fields=['url', 'title']) entry = dataset.make_data_entry(data) with pytest.raises(AssertionError, match=error): - entry._verify_id() + entry.verify_fields() def test_alignment_dataset_default_values(dataset, tmp_path): @@ -180,7 +202,6 @@ def test_alignment_dataset_default_values(dataset, tmp_path): # Make sure the output files are correct assert dataset.jsonl_path.resolve() == (tmp_path / f'{dataset.name}.jsonl').resolve() - assert dataset.txt_path.resolve() == (tmp_path / f'{dataset.name}.txt').resolve() def test_alignment_dataset_file_list(dataset, tmp_path): @@ -197,95 +218,6 @@ def test_alignment_dataset_file_list(dataset, tmp_path): assert files == list(Path(tmp_path).glob('*bla')) -def check_written_files(output_path, name, entries): - with jsonlines.open(Path(output_path) / f'{name}.jsonl', mode='r') as reader: - assert list(reader) == entries, f'Not all entries were output to the {name}.jsonl file' - - with open(Path(output_path) / f'{name}.txt') as f: - assert len(f.readlines()) == len(entries) * 3, f'Not all entries were output to the {name}.txt file' - - return True - - -def test_alignment_dataset_writer_default_paths(dataset, tmp_path, data_entries): - with dataset.writer() as writer: - for entry in data_entries: - writer(entry) - - assert check_written_files(tmp_path, dataset.name, data_entries) - - -def test_alignment_dataset_writer_provided_paths(dataset, tmp_path, data_entries): - with dataset.writer(out_path=tmp_path) as writer: - for entry in data_entries: - writer(entry) - - assert check_written_files(tmp_path, dataset.name, data_entries) - - -def test_alignment_dataset_writer_append(dataset, tmp_path, data_entries): - with dataset.writer() as writer: - for entry in data_entries: - writer(entry) - - with dataset.writer(overwrite=False) as writer: - for entry in data_entries: - writer(entry) - - assert check_written_files(tmp_path, dataset.name, data_entries * 2) - - -def test_alignment_dataset_writer_overwrite(dataset, tmp_path, data_entries): - with dataset.writer() as writer: - for entry in data_entries: - writer(entry) - - with dataset.writer(overwrite=True) as writer: - for entry in data_entries: - writer(entry) - - assert check_written_files(tmp_path, dataset.name, data_entries) - - -def test_read_entries(dataset, tmp_path, data_entries): - with dataset.writer() as writer: - for entry in data_entries: - writer(entry) - - assert list(dataset.read_entries()) == data_entries - - -def test_merge_summaries_no_key(dataset): - dataset.summary_key = None - - assert dataset.merge_summaries({}) is None - - -def test_merge_summaries_no_file(dataset): - assert dataset.merge_summaries({}) is None - - -def test_merge_summaries(dataset, data_entries): - dataset.summary_key = 'summary' - with dataset.writer() as writer: - for entry in data_entries: - writer(entry) - - dataset.merge_summaries({ - 'http://bla.bla.bla?page=1': { - 'source1': 'This should be the first summary', - 'source2': 'This should be the second one' - }, - 'http://bla.bla.bla?page=3': { - 'source': 'This should be the only one' - }, - }) - - data_entries[1]['summary'] = ['This should be the first summary', 'This should be the second one'] - data_entries[3]['summary'] = ['This should be the only one'] - assert data_entries == list(dataset.read_entries()) - - @pytest.fixture def numbers_dataset(tmp_path): """Make a dataset that raises its items to the power of 2.""" @@ 
-310,6 +242,7 @@ def process_entry(self, item): 'url': f'http://bla.bla.bla?page={item}', 'number': item, 'value': item ** 2, + 'authors': [], }) dataset = NumbersDataset(name='numbers', nums=list(range(10))) @@ -319,33 +252,27 @@ def process_entry(self, item): def test_unprocessed_items_fresh(numbers_dataset): """Getting the unprocessed items from a dataset that hasn't written anything should get all items.""" - assert list(numbers_dataset.unprocessed_items()) == list(range(10)) + seen = set() + with patch.object(numbers_dataset, '_load_outputted_items', return_value=seen): + assert list(numbers_dataset.unprocessed_items()) == list(range(10)) def test_unprocessed_items_all_done(numbers_dataset): """Getting the unprocessed items from a dataset that has already processed everything should return an empty list.""" - with numbers_dataset.writer() as writer: - for i in range(10): - entry = numbers_dataset.process_entry(i) - entry.add_id() - writer(entry) - - assert list(numbers_dataset.unprocessed_items()) == [] + seen = set(range(0, 10)) + with patch.object(numbers_dataset, '_load_outputted_items', return_value=seen): + assert list(numbers_dataset.unprocessed_items()) == [] def test_unprocessed_items_some_done(numbers_dataset): """Getting the uprocessed items from a dataset that has partially completed should return the items that haven't been processed.""" - with numbers_dataset.writer() as writer: - for i in range(0, 10, 2): - entry = numbers_dataset.process_entry(i) - entry.add_id() - writer(entry) - - assert list(numbers_dataset.unprocessed_items()) == list(range(1, 10, 2)) + seen = set(range(0, 10, 2)) + with patch.object(numbers_dataset, '_load_outputted_items', return_value=seen): + assert list(numbers_dataset.unprocessed_items()) == list(range(1, 10, 2)) def test_fetch_entries(numbers_dataset): - assert [i['value'] for i in numbers_dataset.fetch_entries()] == [i**2 for i in range(10)] + assert [i.meta['value'] for i in numbers_dataset.fetch_entries()] == [i**2 for i in range(10)] def test_format_datatime(dataset): diff --git a/tests/align_data/common/test_html_dataset.py b/tests/align_data/common/test_html_dataset.py index e50f7fca..f6812716 100644 --- a/tests/align_data/common/test_html_dataset.py +++ b/tests/align_data/common/test_html_dataset.py @@ -2,6 +2,7 @@ import pytest from bs4 import BeautifulSoup +from dateutil.parser import parse from align_data.common.html_dataset import HTMLDataset, RSSDataset @@ -104,7 +105,7 @@ def test_html_dataset_find_date(html_dataset): """ soup = BeautifulSoup(text, "html.parser") - assert html_dataset._find_date(soup.select('span')) == '2023-10-07T00:00:00Z' + assert html_dataset._find_date(soup.select('span')) == parse('2023-10-07T00:00:00Z') @pytest.mark.parametrize('text', ( @@ -125,13 +126,13 @@ def test_html_dataset_process_entry(html_dataset): article = BeautifulSoup(item, "html.parser") with patch('requests.get', return_value=Mock(content=SAMPLE_HTML)): - assert html_dataset.process_entry(article) == { + assert html_dataset.process_entry(article).to_dict() == { 'authors': ['John Smith', 'Your momma'], - 'date_published': '', + 'date_published': None, 'id': None, 'source': 'bla', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla [a link](http://ble.com) bla bla', 'title': 'This is the title', 'url': 'http://example.com/path/to/article', @@ -165,14 +166,14 @@ def test_rss_dataset_get_title(): @pytest.mark.parametrize('item, date', ( - ({'published': '2012/01/02 12:32'}, '2012-01-02T12:32:00Z'), - ({'pubDate': 
'2012/01/02 12:32'}, '2012-01-02T12:32:00Z'), + ({'published': '2012/01/02 12:32'}, parse('2012-01-02T12:32:00Z')), + ({'pubDate': '2012/01/02 12:32'}, parse('2012-01-02T12:32:00Z')), ({ 'pubDate': '2032/01/02 12:32', 'published': '2012/01/02 12:32', - }, '2012-01-02T12:32:00Z'), + }, parse('2012-01-02T12:32:00Z')), - ({'bla': 'bla'}, ''), + ({'bla': 'bla'}, None), )) def test_rss_dataset_get_published_date(item, date): dataset = RSSDataset(name='bla', url='http://example.org', authors=['default author']) diff --git a/tests/align_data/test_alignment_newsletter.py b/tests/align_data/test_alignment_newsletter.py index 0ef66eff..a0f039d7 100644 --- a/tests/align_data/test_alignment_newsletter.py +++ b/tests/align_data/test_alignment_newsletter.py @@ -1,3 +1,4 @@ +from datetime import datetime, timezone import pytest import pandas as pd @@ -32,14 +33,14 @@ def test_process_entry_no_summary(dataset): def test_format_datatime(dataset): - assert dataset._get_published_date(2022) == '2022-01-01T00:00:00Z' + assert dataset._get_published_date(2022) == datetime(2022, 1, 1, tzinfo=timezone.utc) def test_process_entry(dataset): # Do a basic sanity test of the output. If this starts failing and is too much # of a bother to keep up to date, then it can be deleted items = list(dataset.items_list) - assert dataset.process_entry(items[0]) == { + assert dataset.process_entry(items[0]).to_dict() == { 'authors': ['Andrew Ilyas*', 'Shibani Santurkar*', 'Dimitris Tsipras*', @@ -67,7 +68,7 @@ def test_process_entry(dataset): 'source': 'text', 'source_type': 'google-sheets', 'summarizer': 'Rohin', - 'summary': [], + 'summaries': [], 'text': ( '_Distill published a discussion of this paper. This highlights ' 'section will cover the full discussion; all of these summaries and ' diff --git a/tests/align_data/test_arbital.py b/tests/align_data/test_arbital.py index f45ec167..f2f4cd98 100644 --- a/tests/align_data/test_arbital.py +++ b/tests/align_data/test_arbital.py @@ -1,8 +1,10 @@ import json from unittest.mock import Mock, patch + import pytest +from dateutil.parser import parse -from align_data.arbital.arbital import parse_arbital_link, flatten, markdownify_text, extract_text, Arbital +from align_data.arbital.arbital import Arbital, extract_text, flatten, parse_arbital_link @pytest.mark.parametrize('contents, expected', ( @@ -159,15 +161,15 @@ def test_extract_authors_ignore_missing(dataset): @pytest.mark.parametrize('page, expected', ( - ({'editCreatedAt': '2021-02-01T01:23:45Z'}, '2021-02-01T01:23:45Z'), - ({'pageCreatedAt': '2021-02-01T01:23:45Z'}, '2021-02-01T01:23:45Z'), + ({'editCreatedAt': '2021-02-01T01:23:45Z'}, parse('2021-02-01T01:23:45Z')), + ({'pageCreatedAt': '2021-02-01T01:23:45Z'}, parse('2021-02-01T01:23:45Z')), ({ 'editCreatedAt': '2021-02-01T01:23:45Z', 'pageCreatedAt': '2024-02-01T01:23:45Z', - }, '2021-02-01T01:23:45Z'), + }, parse('2021-02-01T01:23:45Z')), - ({}, ''), - ({'bla': 'asdasd'}, ''), + ({}, None), + ({'bla': 'asdasd'}, None), )) def test_get_published_date(dataset, page, expected): assert dataset._get_published_date(page) == expected @@ -182,14 +184,14 @@ def test_process_entry(dataset): 'tagIds': [], } with patch.object(dataset, 'get_page', return_value=page): - assert dataset.process_entry('bla') == { + assert dataset.process_entry('bla').to_dict() == { 'alias': 'bla', 'authors': [], 'date_published': '2001-02-03T12:34:45Z', 'id': None, 'source': 'arbital', 'source_type': 'text', - 'summary': [], + 'summaries': [], 'tags': [], 'text': 'bla bla bla', 'title': 'test article', 
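The arbital tests above and the blog tests below all lean on the same date convention: `_get_published_date` now hands back a timezone-aware `datetime` (or `None` when no date is available), and `Article.to_dict()` renders it as an ISO-8601 string ending in `Z`. A minimal sketch of that convention, assuming dateutil is available; the helper names `to_utc_datetime` and `render_date` are hypothetical and chosen purely for illustration:

```python
from datetime import datetime, timezone
from typing import Optional

from dateutil.parser import parse


def to_utc_datetime(raw: str) -> Optional[datetime]:
    """Return a timezone-aware UTC datetime, or None when no date is available."""
    if not raw:
        return None
    dt = parse(raw)
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)  # assumption: treat naive dates as UTC
    return dt.astimezone(timezone.utc)


def render_date(dt: Optional[datetime]) -> Optional[str]:
    """Serialize the way to_dict() is expected to: ISO-8601 with a trailing 'Z'."""
    return dt.strftime("%Y-%m-%dT%H:%M:%SZ") if dt else None


assert to_utc_datetime("") is None
assert render_date(to_utc_datetime("2021/02/01 01:23:45")) == "2021-02-01T01:23:45Z"
```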
diff --git a/tests/align_data/test_blogs.py b/tests/align_data/test_blogs.py index 1cd42f5e..62de36ae 100644 --- a/tests/align_data/test_blogs.py +++ b/tests/align_data/test_blogs.py @@ -2,7 +2,7 @@ import pytest from bs4 import BeautifulSoup -from requests import request +from dateutil.parser import parse from align_data.blogs import ( CaradoMoe, ColdTakes, GenerativeInk, GwernBlog, MediumBlog, SubstackBlog, WordpressBlog, @@ -32,7 +32,7 @@ def test_cold_takes_published_date(): """ soup = BeautifulSoup(contents, "html.parser") - assert dataset._get_published_date(soup) == '2001-02-03T00:00:00Z' + assert dataset._get_published_date(soup) == parse('2001-02-03T00:00:00Z') def test_cold_takes_process_entry(): @@ -71,13 +71,13 @@ def test_cold_takes_process_entry(): """ with patch('requests.get', return_value=Mock(content=article)): - assert dataset.process_entry(BeautifulSoup(item, "html.parser")) == { + assert dataset.process_entry(BeautifulSoup(item, "html.parser")).to_dict() == { 'authors': ['Holden Karnofsky'], 'date_published': '2023-02-28T00:00:00Z', 'id': None, 'source': 'cold_takes', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla', 'title': 'What does Bing Chat tell us about AI risk?', 'url': 'https://www.cold-takes.com/how-governments-can-help-with-the-most-important-century/', @@ -107,7 +107,7 @@ def test_generative_ink_published_date(): ) soup = BeautifulSoup(GENERITIVE_INK_HTML, "html.parser") - assert dataset._get_published_date(soup) == '2023-02-09T00:00:00Z' + assert dataset._get_published_date(soup) == parse('2023-02-09T00:00:00Z') def test_generative_ink_process_entry(): @@ -126,13 +126,13 @@ def test_generative_ink_process_entry(): """ with patch('requests.get', return_value=Mock(content=GENERITIVE_INK_HTML)): - assert dataset.process_entry(BeautifulSoup(item, "html.parser")) == { + assert dataset.process_entry(BeautifulSoup(item, "html.parser")).to_dict() == { 'authors': ['janus'], 'date_published': '2023-02-09T00:00:00Z', 'id': None, 'source': 'generative.ink', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla', 'title': 'Anomalous tokens reveal the original identities of Instruct models', 'url': 'https://generative.ink/posts/simulators/', @@ -174,13 +174,13 @@ def test_caradomoe_process_entry(): """ with patch('requests.get', return_value=Mock(content=contents)): - assert dataset.process_entry(item['link']) == { + assert dataset.process_entry(item['link']).to_dict() == { 'authors': ['Tamsin Leake'], 'date_published': '2023-06-10T07:00:00Z', 'id': None, 'source': 'carado.moe', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla [a link](http://ble.com) bla bla', 'title': 'the title', 'url': 'http://example.com/bla' @@ -222,12 +222,12 @@ def test_gwern_get_text(): @pytest.mark.parametrize('metadata, date', ( - ({'modified': '2022-01-02'}, '2022-01-02T00:00:00Z'), - ({'created': '2022-01-02'}, '2022-01-02T00:00:00Z'), - ({'created': '2000-01-01', 'modified': '2022-01-02'}, '2022-01-02T00:00:00Z'), + ({'modified': '2022-01-02'}, parse('2022-01-02T00:00:00Z')), + ({'created': '2022-01-02'}, parse('2022-01-02T00:00:00Z')), + ({'created': '2000-01-01', 'modified': '2022-01-02'}, parse('2022-01-02T00:00:00Z')), - ({}, ''), - ({'bla': 'asda'}, '') + ({}, None), + ({'bla': 'asda'}, None) )) def test_gwern_get_published_date(metadata, date): dataset = GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]) @@ -279,13 +279,13 @@ def test_gwern_process_markdown(): 
""" dataset = GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]) - assert dataset._process_markdown('http://article.url', Mock(text=text)) == { + assert dataset._process_markdown('http://article.url', Mock(text=text)).to_dict() == { 'authors': ['Gwern Branwen'], 'date_published': '2020-05-28T00:00:00Z', 'id': None, 'source': 'gwern_blog', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla [a link](http://ble.com) bla bla', 'title': '"The Scaling Hypothesis"', 'url': 'http://article.url', @@ -303,13 +303,13 @@ def test_gwern_process_entry_markdown(): dataset = GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]) with patch('requests.get', return_value=Mock(text=text, status_code=200, headers={})): - assert dataset.process_entry('http://article.url') == { + assert dataset.process_entry('http://article.url').to_dict() == { 'authors': ['Gwern Branwen'], 'date_published': '2020-05-28T00:00:00Z', 'id': None, 'source': 'gwern_blog', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla [a link](http://ble.com) bla bla', 'title': '"The Scaling Hypothesis"', 'url': 'http://article.url', @@ -320,13 +320,13 @@ def test_gwern_process_entry_html(): dataset = GwernBlog(name="gwern_blog", url='https://www.gwern.net/', authors=["Gwern Branwen"]) with patch('requests.get', return_value=Mock(content=GWERN_CONTENTS, status_code=200, headers={'Content-Type': 'text/html'})): - assert dataset.process_entry('http://article.url') == { + assert dataset.process_entry('http://article.url').to_dict() == { 'authors': ['Gwern Branwen'], 'date_published': '2023-01-01T00:00:00Z', 'id': None, 'source': 'gwern_blog', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla [a link](http://ble.com) bla bla', 'title': 'The title of the article', 'url': 'http://article.url', @@ -358,7 +358,7 @@ def test_medium_get_published_date(): dataset = MediumBlog(name="deepmind_blog", url="https://bla.medium.com/", authors=["mr Blobby"]) soup = BeautifulSoup(MEDIUM_HTML, "html.parser") - assert dataset._get_published_date(soup) == '2023-10-07T00:00:00Z' + assert dataset._get_published_date(soup) == parse('2023-10-07T00:00:00Z') def test_medium_get_text(): @@ -379,13 +379,13 @@ def test_medium_process_entry(): """ with patch('requests.get', return_value=Mock(content=MEDIUM_HTML)): - assert dataset.process_entry(BeautifulSoup(item, "html.parser")) == { + assert dataset.process_entry(BeautifulSoup(item, "html.parser")).to_dict() == { 'authors': ['mr Blobby'], 'date_published': '2023-10-07T00:00:00Z', 'id': None, 'source': 'deepmind_blog', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla [a link](http://ble.com) bla bla', 'title': 'This is the title', 'url': 'https://bla.medium.com/discovering-when-an-agent-is-present-in-a-system-41154de11e7b', @@ -410,13 +410,13 @@ def test_substack_blog_process_entry(): with patch('feedparser.parse', return_value=contents): dataset.items_list - assert dataset.process_entry('http://example.org/bla') == { + assert dataset.process_entry('http://example.org/bla').to_dict() == { 'authors': ['mr Blobby'], 'date_published': '2023-06-26T13:40:01Z', 'id': None, 'source': 'blog', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla [a link](http://ble.com) bla bla', 'title': 'the article title', 'url': 'http://example.org/bla', @@ -474,7 +474,7 @@ def test_wordpress_blog_get_published_date(): 
url="https://www.bla.yudkowsky.net", ) date_published = blog._get_published_date({'published': "Mon, 26 Jun 2023 13:40:01 +0000"}) - assert date_published == '2023-06-26T13:40:01Z' + assert date_published == parse('2023-06-26T13:40:01Z') @patch('feedparser.parse', return_value=WORDPRESS_FEED) @@ -485,13 +485,13 @@ def test_wordpress_blog_process_entry(feedparser_parse): ) blog.items = {i['link']: i for i in WORDPRESS_FEED['entries']} entry = blog.process_entry('https://www.yudkowsky.net/other/fiction/prospiracy-theory') - assert entry == { + assert entry.to_dict() == { 'authors': ['Eliezer S. Yudkowsky'], 'date_published': '2020-09-04T04:11:23Z', 'id': None, 'source': 'blog_name', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla [a link](http://ble.com) bla bla', 'title': 'Prospiracy Theory', 'url': 'https://www.yudkowsky.net/other/fiction/prospiracy-theory', @@ -516,7 +516,7 @@ def test_eleutherai_get_published_date(): dataset = EleutherAI(name='eleuther', url='http://bla.bla') soup = BeautifulSoup(ELEUTHER_HTML, "html.parser") - assert dataset._get_published_date(soup) == "2023-07-08T00:00:00Z" + assert dataset._get_published_date(soup) == parse("2023-07-08T00:00:00Z") def test_eleutherai_extract_authors(): @@ -531,13 +531,13 @@ def test_eleutherai_process_entry(): article = BeautifulSoup('', "html.parser") with patch('requests.get', return_value=Mock(content=ELEUTHER_HTML)): - assert dataset.process_entry(article) == { + assert dataset.process_entry(article).to_dict() == { 'authors': ['Curtis Huebner', 'Robert Klassert', 'Stepan Shabalin', 'Edwin Fennell', 'Delta Hessler'], 'date_published': '2023-07-08T00:00:00Z', 'id': None, 'source': 'eleuther', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla', 'title': 'Minetester: A fully open RL environment built on Minetest', 'url': 'http://bla.bla/bla.bla', @@ -562,7 +562,7 @@ def test_openai_research_get_published_date(): dataset = OpenAIResearch(name='openai', url='bla.bla') soup = BeautifulSoup(OPENAI_HTML, "html.parser") - assert dataset._get_published_date(soup) == '2023-07-06T00:00:00Z' + assert dataset._get_published_date(soup) == parse('2023-07-06T00:00:00Z') def test_openai_research_get_text(): @@ -620,13 +620,13 @@ def test_openai_research_process_entry(): with patch('requests.head', return_value=Mock(headers={'Content-Type': 'text/html'})): with patch('requests.get', return_value=Mock(content=OPENAI_HTML)): with patch('align_data.articles.pdf.fetch_pdf', return_value={'text': 'bla bla bla'}): - assert dataset.process_entry(soup) == { + assert dataset.process_entry(soup).to_dict() == { 'authors': ['Mr. Blobby', 'John Snow'], 'date_published': '2023-07-06T00:00:00Z', 'id': None, 'source': 'openai', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla', 'title': None, 'url': 'https://arxiv.org', @@ -673,7 +673,7 @@ def getter(url, *args, **params): def test_deepmind_technical_get_published_date(): dataset = DeepMindTechnicalBlog(name='bla', url='http://bla.com') soup = BeautifulSoup(DEEPMIND_HTML, "html.parser") - assert dataset._get_published_date(soup) == '2023-07-11T00:00:00Z' + assert dataset._get_published_date(soup) == parse('2023-07-11T00:00:00Z') def test_deepmind_technical_extract_authors(): @@ -686,13 +686,13 @@ def test_deepmind_technical_proces_entry(): dataset = DeepMindTechnicalBlog(name='bla', url='http://bla.com') soup = BeautifulSoup('
', "html.parser") with patch('requests.get', return_value=Mock(content=DEEPMIND_HTML)): - assert dataset.process_entry(soup) == { + assert dataset.process_entry(soup).to_dict() == { 'authors': ['Mr. Blobby', 'John Snow'], 'date_published': '2023-07-11T00:00:00Z', 'id': None, 'source': 'bla', 'source_type': 'blog', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla', 'title': 'title!', 'url': 'http://bla.bl', diff --git a/tests/align_data/test_distill.py b/tests/align_data/test_distill.py index 112aa9a7..e0c8943f 100644 --- a/tests/align_data/test_distill.py +++ b/tests/align_data/test_distill.py @@ -74,7 +74,7 @@ def test_extra_values(): ], 'doi': '10.23915/distill.00032', 'journal_ref': 'distill-pub', - 'summary': ['A wild summary has appeared!'], + 'summary': 'A wild summary has appeared!', } @@ -128,7 +128,7 @@ def test_process_entry(): dataset.items_list with patch('requests.get', return_value=Mock(content=contents)): - assert dataset.process_entry('http://example.org/bla') == { + assert dataset.process_entry('http://example.org/bla').to_dict() == { 'authors': ['Ameya Daigavane', 'Balaraman Ravindran', 'Gaurav Aggarwal'], 'bibliography': [ { @@ -145,7 +145,7 @@ def test_process_entry(): 'journal_ref': 'distill-pub', 'source': 'distill', 'source_type': 'blog', - 'summary': ['A wild summary has appeared!'], + 'summaries': ['A wild summary has appeared!'], 'text': 'bla bla [a link](bla.com) ble', 'title': 'the article title', 'url': 'http://example.org/bla', diff --git a/tests/align_data/test_greater_wrong.py b/tests/align_data/test_greater_wrong.py index e288d034..c5cca825 100644 --- a/tests/align_data/test_greater_wrong.py +++ b/tests/align_data/test_greater_wrong.py @@ -72,11 +72,11 @@ def test_greaterwrong_get_item_key(dataset): def test_greaterwrong_get_published_date(dataset): - assert dataset._get_published_date({'postedAt': '2021/02/01'}) == '2021-02-01T00:00:00Z' + assert dataset._get_published_date({'postedAt': '2021/02/01'}) == parse('2021-02-01T00:00:00Z') def test_greaterwrong_get_published_date_missing(dataset): - assert dataset._get_published_date({}) == '' + assert dataset._get_published_date({}) == None def test_items_list_no_previous(dataset): @@ -159,7 +159,7 @@ def test_process_entry(dataset): 'wordCount': 123, 'commentCount': 423, } - assert dataset.process_entry(entry) == { + assert dataset.process_entry(entry).to_dict() == { 'authors': ['Me', 'John Snow', 'Mr Blobby'], 'comment_count': 423, 'date_published': '2012-02-01T12:23:34Z', @@ -168,7 +168,7 @@ def test_process_entry(dataset): 'modified_at': '2001-02-10', 'source': 'bla', 'source_type': 'GreaterWrong', - 'summary': [], + 'summaries': [], 'tags': ['tag1', 'tag2'], 'text': 'bla bla [a link](bla.com)', 'title': 'The title', diff --git a/tests/align_data/test_stampy.py b/tests/align_data/test_stampy.py index b0a2cbc2..1381941a 100644 --- a/tests/align_data/test_stampy.py +++ b/tests/align_data/test_stampy.py @@ -1,5 +1,5 @@ from unittest.mock import patch - +from dateutil.parser import parse from align_data.stampy import Stampy @@ -19,12 +19,12 @@ def test_get_item_key(): def test_get_published_date(): dataset = Stampy(name='bla') - assert dataset._get_published_date({'Doc Last Edited': '2012/01/03 12:23:32'}) == '2012-01-03T12:23:32Z' + assert dataset._get_published_date({'Doc Last Edited': '2012/01/03 12:23:32'}) == parse('2012-01-03T12:23:32Z') def test_get_published_date_missing(): dataset = Stampy(name='bla') - assert dataset._get_published_date({'Doc Last Edited': ''}) == '' + assert 
dataset._get_published_date({'Doc Last Edited': ''}) == None def test_process_entry(): @@ -35,13 +35,13 @@ def test_process_entry(): 'UI ID': '1234', 'Doc Last Edited': '2012-02-03', } - assert dataset.process_entry(entry) == { + assert dataset.process_entry(entry).to_dict() == { 'authors': ['Stampy aisafety.info'], 'date_published': '2012-02-03T00:00:00Z', 'id': None, 'source': 'bla', 'source_type': 'markdown', - 'summary': [], + 'summaries': [], 'text': 'bla bla bla', 'title': 'Why\nnot just?', 'url': 'https://aisafety.info?state=1234', diff --git a/tests/conftest.py b/tests/conftest.py index e69de29b..5373b4ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +from unittest.mock import patch, Mock +import pytest +from align_data.common.alignment_dataset import make_session + + +@pytest.fixture(autouse=True, scope='session') +def mock_db(): + # This just mocks out all db calls, nothing more + with patch('align_data.common.alignment_dataset.make_session'): + yield From 1ac1b47e2c8aabd0ddf26534a3dd15b28a711e7a Mon Sep 17 00:00:00 2001 From: Daniel O'Connell Date: Tue, 18 Jul 2023 16:46:05 +0200 Subject: [PATCH 2/7] authors as string --- align_data/common/alignment_dataset.py | 35 +++++++++++++------ align_data/db/models.py | 24 ++----------- align_data/greaterwrong/greaterwrong.py | 24 +++++++------ ...e.py => 983b5bdef5f6_initial_structure.py} | 23 +++--------- tests/align_data/test_greater_wrong.py | 12 +++---- 5 files changed, 50 insertions(+), 68 deletions(-) rename migrations/versions/{8c11b666e86f_initial_structure.py => 983b5bdef5f6_initial_structure.py} (68%) diff --git a/align_data/common/alignment_dataset.py b/align_data/common/alignment_dataset.py index 2f8d236f..8086a66a 100644 --- a/align_data/common/alignment_dataset.py +++ b/align_data/common/alignment_dataset.py @@ -2,6 +2,7 @@ import time import zipfile from dataclasses import dataclass, field, KW_ONLY +from itertools import islice from pathlib import Path from typing import List from sqlalchemy import select @@ -12,7 +13,7 @@ import pytz from dateutil.parser import parse, ParserError from tqdm import tqdm -from align_data.db.models import Article, Author +from align_data.db.models import Article from align_data.db.session import make_session @@ -86,7 +87,7 @@ def make_data_entry(self, data, **kwargs): data = dict(data, **kwargs) # TODO: Don't keep adding the same authors - come up with some way to reuse them # TODO: Prettify this - data['authors'] = [Author(name=name) for name in data.get('authors', [])] + data['authors'] = ','.join(data.get('authors', [])) if summary := ('summary' in data and data.pop('summary')): data['summaries'] = [summary] return Article( @@ -106,21 +107,33 @@ def to_jsonl(self, out_path=None, filename=None): for article in self.read_entries(): jsonl_writer.write(article.to_dict()) - def read_entries(self): + def read_entries(self, sort_by=None): """Iterate through all the saved entries.""" with make_session() as session: - for item in session.scalars(select(Article).where(Article.source==self.name)): + query = select(Article).where(Article.source==self.name) + if sort_by is not None: + query = query.order_by(sort_by) + for item in session.scalars(query): yield item def add_entries(self, entries): + def commit(): + try: + session.commit() + return True + except IntegrityError: + session.rollback() + with make_session() as session: - for entry in entries: - session.add(entry) - try: - session.commit() - except IntegrityError: - logger.error(f'found duplicate of {entry}') - 
session.rollback() + while batch := tuple(islice(entries, 20)): + session.add_all(entries) + # there might be duplicates in the batch, so if they cause + # an exception, try to commit them one by one + if not commit(): + for entry in batch: + session.add(entry) + if not commit(): + logger.error(f'found duplicate of {entry}') def setup(self): # make sure the path to the raw data exists diff --git a/align_data/db/models.py b/align_data/db/models.py index e8c4fcf9..f4c68712 100644 --- a/align_data/db/models.py +++ b/align_data/db/models.py @@ -2,9 +2,8 @@ import hashlib from datetime import datetime from typing import List, Optional -from sqlalchemy import JSON, DateTime, ForeignKey, Table, String, Column, Integer, func, Text, event +from sqlalchemy import JSON, DateTime, ForeignKey, String, func, Text, event from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship -from sqlalchemy.ext.associationproxy import association_proxy, AssociationProxy from sqlalchemy.dialects.mysql import LONGTEXT @@ -12,23 +11,6 @@ class Base(DeclarativeBase): pass -author_article = Table( - 'author_article', - Base.metadata, - Column('article_id', Integer, ForeignKey('articles.id'), primary_key=True), - Column('author_id', Integer, ForeignKey('authors.id'), primary_key=True), -) - - -class Author(Base): - - __tablename__ = "authors" - - id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] = mapped_column(String(256), nullable=False) - articles: Mapped[List["Article"]] = relationship(secondary=author_article, back_populates="authors") - - class Summary(Base): __tablename__ = "summaries" @@ -50,13 +32,13 @@ class Article(Base): url: Mapped[Optional[str]] = mapped_column(String(1028)) source: Mapped[Optional[str]] = mapped_column(String(128)) source_type: Mapped[Optional[str]] = mapped_column(String(128)) + authors: Mapped[str] = mapped_column(String(1024)) text: Mapped[Optional[str]] = mapped_column(LONGTEXT) date_published: Mapped[Optional[datetime]] meta: Mapped[Optional[JSON]] = mapped_column(JSON, name='metadata', default='{}') date_created: Mapped[datetime] = mapped_column(DateTime, default=func.now()) date_updated: Mapped[Optional[datetime]] = mapped_column(DateTime, onupdate=func.current_timestamp()) - authors: Mapped[List['Author']] = relationship(secondary=author_article, back_populates="articles") summaries: Mapped[List["Summary"]] = relationship(back_populates="article", cascade="all, delete-orphan") __id_fields = ['title', 'url'] @@ -103,7 +85,7 @@ def to_dict(self): 'source_type': self.source_type, 'text': self.text, 'date_published': date, - 'authors': [a.name for a in self.authors], + 'authors': [i.strip() for i in self.authors.split(',')] if self.authors.strip() else [], 'summaries': [s.text for s in self.summaries], **self.meta, } diff --git a/align_data/greaterwrong/greaterwrong.py b/align_data/greaterwrong/greaterwrong.py index 84974068..850b12b3 100644 --- a/align_data/greaterwrong/greaterwrong.py +++ b/align_data/greaterwrong/greaterwrong.py @@ -1,17 +1,15 @@ -from datetime import datetime, timezone -from dateutil.parser import parse +from datetime import datetime import logging import time from dataclasses import dataclass -from pathlib import Path import requests import jsonlines from bs4 import BeautifulSoup -from tqdm import tqdm from markdownify import markdownify from align_data.common.alignment_dataset import AlignmentDataset +from align_data.db.models import Article logger = logging.getLogger(__name__) @@ -139,14 +137,18 @@ def fetch_posts(self, 
query: str): return res.json()['data']['posts'] @property - def items_list(self): - next_date = datetime(self.start_year, 1, 1).isoformat() + 'Z' - if self.jsonl_path.exists() and self.jsonl_path.lstat().st_size: - with jsonlines.open(self.jsonl_path) as f: - for item in f: - if item['date_published'] > next_date: - next_date = item['date_published'] + def last_date_published(self): + try: + prev_item = next(self.read_entries(sort_by=Article.date_published.desc())) + if prev_item and prev_item.date_published: + return prev_item.date_published.isoformat() + 'Z' + except StopIteration: + pass + return datetime(self.start_year, 1, 1).isoformat() + 'Z' + @property + def items_list(self): + next_date = self.last_date_published logger.info('Starting from %s', next_date) while next_date: posts = self.fetch_posts(self.make_query(next_date)) diff --git a/migrations/versions/8c11b666e86f_initial_structure.py b/migrations/versions/983b5bdef5f6_initial_structure.py similarity index 68% rename from migrations/versions/8c11b666e86f_initial_structure.py rename to migrations/versions/983b5bdef5f6_initial_structure.py index 113371db..ff1ef321 100644 --- a/migrations/versions/8c11b666e86f_initial_structure.py +++ b/migrations/versions/983b5bdef5f6_initial_structure.py @@ -1,8 +1,8 @@ """initial structure -Revision ID: 8c11b666e86f +Revision ID: 983b5bdef5f6 Revises: -Create Date: 2023-07-14 15:48:49.149905 +Create Date: 2023-07-18 15:54:58.299651 """ from alembic import op @@ -10,7 +10,7 @@ from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. -revision = '8c11b666e86f' +revision = '983b5bdef5f6' down_revision = None branch_labels = None depends_on = None @@ -25,6 +25,7 @@ def upgrade() -> None: sa.Column('url', sa.String(length=1028), nullable=True), sa.Column('source', sa.String(length=128), nullable=True), sa.Column('source_type', sa.String(length=128), nullable=True), + sa.Column('authors', sa.String(length=1024), nullable=False), sa.Column('text', mysql.LONGTEXT(), nullable=True), sa.Column('date_published', sa.DateTime(), nullable=True), sa.Column('metadata', sa.JSON(), nullable=True), @@ -33,20 +34,6 @@ def upgrade() -> None: sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('hash_id') ) - op.create_table( - 'authors', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(length=256), nullable=False), - sa.PrimaryKeyConstraint('id') - ) - op.create_table( - 'author_article', - sa.Column('article_id', sa.Integer(), nullable=False), - sa.Column('author_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['article_id'], ['articles.id'], ), - sa.ForeignKeyConstraint(['author_id'], ['authors.id'], ), - sa.PrimaryKeyConstraint('article_id', 'author_id') - ) op.create_table( 'summaries', sa.Column('id', sa.Integer(), nullable=False), @@ -60,6 +47,4 @@ def upgrade() -> None: def downgrade() -> None: op.drop_table('summaries') - op.drop_table('author_article') - op.drop_table('authors') op.drop_table('articles') diff --git a/tests/align_data/test_greater_wrong.py b/tests/align_data/test_greater_wrong.py index c5cca825..72b32b83 100644 --- a/tests/align_data/test_greater_wrong.py +++ b/tests/align_data/test_greater_wrong.py @@ -112,8 +112,6 @@ def fetcher(next_date): def test_items_list_with_previous_items(dataset): dataset.ai_tags = {'tag1', 'tag2'} - with open(dataset.jsonl_path, 'w') as f: - f.write('{"date_published": "2014-12-12T01:23:45Z"}\n') def make_item(date): return { @@ -135,12 +133,14 @@ def fetcher(next_date): ] return {'results': 
results} + mock_items = (i for i in [Mock(date_published=datetime.fromisoformat('2014-12-12T01:23:45'))]) with patch.object(dataset, 'fetch_posts', fetcher): with patch.object(dataset, 'make_query', lambda next_date: next_date): - # All items that are older than the newest item in the jsonl file are ignored - assert list(dataset.items_list) == [ - make_item(datetime(2014, 12, 12, 1, 23, 45).replace(tzinfo=pytz.UTC) + timedelta(days=i*30)) - for i in range(1, 4) + with patch.object(dataset, 'read_entries', return_value=mock_items): + # All items that are older than the newest item in the jsonl file are ignored + assert list(dataset.items_list) == [ + make_item(datetime(2014, 12, 12, 1, 23, 45).replace(tzinfo=pytz.UTC) + timedelta(days=i*30)) + for i in range(1, 4) ] From 5ccf59687e25328ad7028c49729474f32e552274 Mon Sep 17 00:00:00 2001 From: Henri Lemoine Date: Mon, 24 Jul 2023 19:36:41 -0400 Subject: [PATCH 3/7] Restructure align_data directory --- align_data/__init__.py | 24 +-- align_data/pinecone/__init__.py | 0 align_data/pinecone/pinecone_db_handler.py | 106 ++++++++++ align_data/pinecone/sql_db_handler.py | 130 ++++++++++++ align_data/pinecone/text_splitter.py | 102 ++++++++++ align_data/pinecone/update_pinecone.py | 190 ++++++++++++++++++ align_data/settings.py | 32 ++- .../alignment_newsletter/__init__.py | 0 .../alignment_newsletter.py | 0 align_data/{ => sources}/arbital/__init__.py | 0 align_data/{ => sources}/arbital/arbital.py | 0 align_data/{ => sources}/articles/__init__.py | 2 +- align_data/{ => sources}/articles/articles.py | 6 +- align_data/{ => sources}/articles/datasets.py | 4 +- .../{ => sources}/articles/google_cloud.py | 0 align_data/{ => sources}/articles/html.py | 0 align_data/{ => sources}/articles/indices.py | 2 +- align_data/{ => sources}/articles/parsers.py | 4 +- align_data/{ => sources}/articles/pdf.py | 2 +- .../{ => sources}/arxiv_papers/__init__.py | 0 .../arxiv_papers/arxiv_papers.py | 0 .../audio_transcripts/__init__.py | 0 .../audio_transcripts/audio_transcripts.py | 0 align_data/{ => sources}/blogs/__init__.py | 10 +- align_data/{ => sources}/blogs/blogs.py | 2 +- align_data/{ => sources}/blogs/gwern_blog.py | 0 align_data/{ => sources}/blogs/medium_blog.py | 0 .../{ => sources}/blogs/substack_blog.py | 0 align_data/{ => sources}/blogs/wp_blog.py | 0 align_data/{ => sources}/distill/__init__.py | 0 align_data/{ => sources}/distill/distill.py | 0 align_data/{ => sources}/ebooks/__init__.py | 0 .../{ => sources}/ebooks/agentmodels.py | 0 .../{ => sources}/ebooks/gdrive_ebooks.py | 0 align_data/{ => sources}/ebooks/mdebooks.py | 0 align_data/{ => sources}/gdocs/__init__.py | 0 align_data/{ => sources}/gdocs/gdocs.py | 0 .../{ => sources}/greaterwrong/__init__.py | 0 .../greaterwrong/greaterwrong.py | 0 align_data/{ => sources}/reports/__init__.py | 0 align_data/{ => sources}/reports/reports.py | 0 align_data/{ => sources}/stampy/__init__.py | 0 align_data/{ => sources}/stampy/stampy.py | 0 main.py | 9 +- 44 files changed, 594 insertions(+), 31 deletions(-) create mode 100644 align_data/pinecone/__init__.py create mode 100644 align_data/pinecone/pinecone_db_handler.py create mode 100644 align_data/pinecone/sql_db_handler.py create mode 100644 align_data/pinecone/text_splitter.py create mode 100644 align_data/pinecone/update_pinecone.py rename align_data/{ => sources}/alignment_newsletter/__init__.py (100%) rename align_data/{ => sources}/alignment_newsletter/alignment_newsletter.py (100%) rename align_data/{ => sources}/arbital/__init__.py (100%) rename 
align_data/{ => sources}/arbital/arbital.py (100%) rename align_data/{ => sources}/articles/__init__.py (85%) rename align_data/{ => sources}/articles/articles.py (94%) rename align_data/{ => sources}/articles/datasets.py (95%) rename align_data/{ => sources}/articles/google_cloud.py (100%) rename align_data/{ => sources}/articles/html.py (100%) rename align_data/{ => sources}/articles/indices.py (98%) rename align_data/{ => sources}/articles/parsers.py (98%) rename align_data/{ => sources}/articles/pdf.py (98%) rename align_data/{ => sources}/arxiv_papers/__init__.py (100%) rename align_data/{ => sources}/arxiv_papers/arxiv_papers.py (100%) rename align_data/{ => sources}/audio_transcripts/__init__.py (100%) rename align_data/{ => sources}/audio_transcripts/audio_transcripts.py (100%) rename align_data/{ => sources}/blogs/__init__.py (86%) rename align_data/{ => sources}/blogs/blogs.py (98%) rename align_data/{ => sources}/blogs/gwern_blog.py (100%) rename align_data/{ => sources}/blogs/medium_blog.py (100%) rename align_data/{ => sources}/blogs/substack_blog.py (100%) rename align_data/{ => sources}/blogs/wp_blog.py (100%) rename align_data/{ => sources}/distill/__init__.py (100%) rename align_data/{ => sources}/distill/distill.py (100%) rename align_data/{ => sources}/ebooks/__init__.py (100%) rename align_data/{ => sources}/ebooks/agentmodels.py (100%) rename align_data/{ => sources}/ebooks/gdrive_ebooks.py (100%) rename align_data/{ => sources}/ebooks/mdebooks.py (100%) rename align_data/{ => sources}/gdocs/__init__.py (100%) rename align_data/{ => sources}/gdocs/gdocs.py (100%) rename align_data/{ => sources}/greaterwrong/__init__.py (100%) rename align_data/{ => sources}/greaterwrong/greaterwrong.py (100%) rename align_data/{ => sources}/reports/__init__.py (100%) rename align_data/{ => sources}/reports/reports.py (100%) rename align_data/{ => sources}/stampy/__init__.py (100%) rename align_data/{ => sources}/stampy/stampy.py (100%) diff --git a/align_data/__init__.py b/align_data/__init__.py index 7c2b55e2..4028ab5b 100644 --- a/align_data/__init__.py +++ b/align_data/__init__.py @@ -1,15 +1,15 @@ -import align_data.arbital as arbital -import align_data.articles as articles -import align_data.blogs as blogs -import align_data.ebooks as ebooks -import align_data.arxiv_papers as arxiv_papers -import align_data.reports as reports -import align_data.greaterwrong as greaterwrong -import align_data.stampy as stampy -import align_data.audio_transcripts as audio_transcripts -import align_data.alignment_newsletter as alignment_newsletter -import align_data.distill as distill -import align_data.gdocs as gdocs +import align_data.sources.arbital as arbital +import align_data.sources.articles as articles +import align_data.sources.blogs as blogs +import align_data.sources.ebooks as ebooks +import align_data.sources.arxiv_papers as arxiv_papers +import align_data.sources.reports as reports +import align_data.sources.greaterwrong as greaterwrong +import align_data.sources.stampy as stampy +import align_data.sources.audio_transcripts as audio_transcripts +import align_data.sources.alignment_newsletter as alignment_newsletter +import align_data.sources.distill as distill +import align_data.sources.gdocs as gdocs DATASET_REGISTRY = ( arbital.ARBITAL_REGISTRY diff --git a/align_data/pinecone/__init__.py b/align_data/pinecone/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/align_data/pinecone/pinecone_db_handler.py b/align_data/pinecone/pinecone_db_handler.py new file mode 
100644 index 00000000..4ee2ce7e --- /dev/null +++ b/align_data/pinecone/pinecone_db_handler.py @@ -0,0 +1,106 @@ +# dataset/pinecone_db_handler.py + +import pinecone + +from align_data.settings import PINECONE_INDEX_NAME, PINECONE_VALUES_DIMS, PINECONE_METRIC, PINECONE_METADATA_ENTRIES, PINECONE_API_KEY, PINECONE_ENVIRONMENT + +import logging +logger = logging.getLogger(__name__) + + +class PineconeDB: + def __init__( + self, + index_name: str = PINECONE_INDEX_NAME, + values_dims: int = PINECONE_VALUES_DIMS, + metric: str = PINECONE_METRIC, + metadata_entries: list = PINECONE_METADATA_ENTRIES, + create_index: bool = False, + log_index_stats: bool = True, + ): + self.index_name = index_name + self.values_dims = values_dims + self.metric = metric + self.metadata_entries = metadata_entries + + pinecone.init( + api_key = PINECONE_API_KEY, + environment = PINECONE_ENVIRONMENT, + ) + + if create_index: + self.create_index() + + self.index = pinecone.Index(index_name=self.index_name) + + if log_index_stats: + index_stats_response = self.index.describe_index_stats() + logger.info(f"{self.index_name}:\n{index_stats_response}") + + def upsert_entry(self, entry, chunks, embeddings, upsert_size=100): + self.index.upsert( + vectors=list( + zip( + [f"{entry['id']}_{str(i).zfill(6)}" for i in range(len(chunks))], + embeddings.tolist(), + [ + { + 'entry_id': entry['id'], + 'source': entry['source'], + 'title': entry['title'], + 'authors': entry['authors'], + 'text': chunk, + } for chunk in chunks + ] + ) + ), + batch_size=upsert_size + ) + + def upsert_entries(self, entries_batch, chunks_batch, chunks_ids_batch, embeddings, upsert_size=100): + self.index.upsert( + vectors=list( + zip( + chunks_ids_batch, + embeddings.tolist(), + [ + { + 'entry_id': entry['id'], + 'source': entry['source'], + 'title': entry['title'], + 'authors': entry['authors'], + 'text': chunk, + } + for entry in entries_batch + for chunk in chunks_batch + ] + ) + ), + batch_size=upsert_size + ) + + def delete_entry(self, id): + self.index.delete( + filter={"entry_id": {"$eq": id}} + ) + + def delete_entries(self, ids): + self.index.delete( + filter={"entry_id": {"$in": ids}} + ) + + def create_index(self, replace_current_index: bool = True): + if replace_current_index: + self.delete_index() + + pinecone.create_index( + name=self.index_name, + dimension=self.values_dims, + metric=self.metric, + metadata_config = {"indexed": self.metadata_entries}, + ) + + def delete_index(self): + if self.index_name in pinecone.list_indexes(): + logger.info(f"Deleting index '{self.index_name}'.") + pinecone.delete_index(self.index_name) \ No newline at end of file diff --git a/align_data/pinecone/sql_db_handler.py b/align_data/pinecone/sql_db_handler.py new file mode 100644 index 00000000..6877900a --- /dev/null +++ b/align_data/pinecone/sql_db_handler.py @@ -0,0 +1,130 @@ +# dataset/sql_db_handler.py + +from typing import List, Dict, Union +import numpy as np +import sqlite3 + +from align_data.settings import SQL_DB_PATH + +import logging +logger = logging.getLogger(__name__) + + +class SQLDB: + def __init__(self, db_name: str = SQL_DB_PATH): + self.db_name = db_name + + self.create_tables() + + def create_tables(self, reset: bool = False): + with sqlite3.connect(self.db_name) as conn: + cursor = conn.cursor() + try: + if reset: + # Drop the tables if reset is True + cursor.execute("DROP TABLE IF EXISTS entry_database") + cursor.execute("DROP TABLE IF EXISTS chunk_database") + + # Create entry table + query = """ + CREATE TABLE IF NOT EXISTS 
entry_database ( + id TEXT PRIMARY KEY, + source TEXT, + title TEXT, + text TEXT, + url TEXT, + date_published TEXT, + authors TEXT + ) + """ + cursor.execute(query) + + # Create chunk table + query = """ + CREATE TABLE IF NOT EXISTS chunk_database ( + id TEXT PRIMARY KEY, + text TEXT, + embedding BLOB, + entry_id TEXT, + FOREIGN KEY (entry_id) REFERENCES entry_database(id) + ) + """ + cursor.execute(query) + + except sqlite3.Error as e: + logger.error(f"The error '{e}' occurred.") + + def upsert_entry(self, entry: Dict[str, Union[str, list]]) -> bool: + with sqlite3.connect(self.db_name) as conn: + cursor = conn.cursor() + try: + # Fetch existing data + cursor.execute("SELECT * FROM entry_database WHERE id=?", (entry['id'],)) + existing_entry = cursor.fetchone() + + new_entry = ( + entry['id'], + entry['source'], + entry['title'], + entry['text'], + entry['url'], + entry['date_published'], + ', '.join(entry['authors']) + ) + + if existing_entry != new_entry: + query = """ + INSERT OR REPLACE INTO entry_database + (id, source, title, text, url, date_published, authors) + VALUES (?, ?, ?, ?, ?, ?, ?) + """ + cursor.execute(query, new_entry) + return True + else: + return False + + except sqlite3.Error as e: + logger.error(f"The error '{e}' occurred.") + return False + + finally: + conn.commit() + + def upsert_chunks(self, chunks_ids_batch: List[str], chunks_batch: List[str], embeddings_batch: List[np.ndarray]) -> bool: + with sqlite3.connect(self.db_name) as conn: + cursor = conn.cursor() + try: + for chunk_id, chunk, embedding in zip(chunks_ids_batch, chunks_batch, embeddings_batch): + cursor.execute(""" + INSERT OR REPLACE INTO chunk_database + (id, text, embedding) + VALUES (?, ?, ?) + """, (chunk_id, chunk, embedding.tobytes())) + except sqlite3.Error as e: + logger.error(f"The error '{e}' occurred.") + finally: + conn.commit() + + + def stream_chunks(self): + with sqlite3.connect(self.db_name) as conn: + cursor = conn.cursor() + + # Join entry_database and chunk_database tables and order by source + cursor.execute(""" + SELECT c.id, c.text, c.embedding, e.source + FROM chunk_database c + JOIN entry_database e ON c.entry_id = e.id + ORDER BY e.source + """) + + for row in cursor: + # Convert bytes back to numpy array + embedding = np.frombuffer(row[2], dtype=np.float64) if row[2] else None + + yield { + 'id': row[0], + 'text': row[1], + 'embedding': embedding, + 'source': row[3], + } \ No newline at end of file diff --git a/align_data/pinecone/text_splitter.py b/align_data/pinecone/text_splitter.py new file mode 100644 index 00000000..8e5dc0b8 --- /dev/null +++ b/align_data/pinecone/text_splitter.py @@ -0,0 +1,102 @@ +# dataset/text_splitter.py + +from typing import List, Callable, Any +from langchain.text_splitter import TextSplitter +from nltk.tokenize import sent_tokenize + + +class ParagraphSentenceUnitTextSplitter(TextSplitter): + """A custom TextSplitter that breaks text by paragraphs, sentences, and then units (chars/words/tokens/etc). + + @param min_chunk_size: The minimum number of units in a chunk. + @param max_chunk_size: The maximum number of units in a chunk. + @param length_function: A function that returns the length of a string in units. + @param truncate_function: A function that truncates a string to a given unit length. 
+ """ + + DEFAULT_MIN_CHUNK_SIZE = 900 + DEFAULT_MAX_CHUNK_SIZE = 1100 + DEFAULT_TRUNCATE_FUNCTION = lambda string, length, from_end=False: string[-length:] if from_end else string[:length] + + def __init__( + self, + min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, + max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, + truncate_function: Callable[[str, int], str] = DEFAULT_TRUNCATE_FUNCTION, + **kwargs: Any + ): + super().__init__(**kwargs) + self.min_chunk_size = min_chunk_size + self.max_chunk_size = max_chunk_size + + self._truncate_function = truncate_function + + def split_text(self, text: str) -> List[str]: + blocks = [] + current_block = "" + + paragraphs = text.split("\n\n") + for paragraph in paragraphs: + current_block += "\n\n" + paragraph + block_length = self._length_function(current_block) + + if block_length > self.max_chunk_size: # current block is too large, truncate it + current_block = self._handle_large_paragraph(current_block, blocks, paragraph) + elif block_length >= self.min_chunk_size: + blocks.append(current_block) + current_block = "" + else: # current block is too small, continue appending to it + continue + + blocks = self._handle_remaining_text(current_block, blocks) + + return [block.strip() for block in blocks] + + def _handle_large_paragraph(self, current_block, blocks, paragraph): + # Undo adding the whole paragraph + current_block = current_block[:-(len(paragraph)+2)] # +2 accounts for "\n\n" + + sentences = sent_tokenize(paragraph) + for sentence in sentences: + current_block += f" {sentence}" + + block_length = self._length_function(current_block) + if block_length < self.min_chunk_size: + continue + elif block_length <= self.max_chunk_size: + blocks.append(current_block) + current_block = "" + else: + current_block = self._truncate_large_block(current_block, blocks, sentence) + + return current_block + + def _truncate_large_block(self, current_block, blocks, sentence): + while self._length_function(current_block) > self.max_chunk_size: + # Truncate current_block to max size, set remaining sentence as next sentence + truncated_block = self._truncate_function(current_block, self.max_chunk_size) + blocks.append(truncated_block) + + remaining_sentence = current_block[len(truncated_block):].lstrip() + current_block = sentence = remaining_sentence + + return current_block + + def _handle_remaining_text(self, current_block, blocks): + if blocks == []: # no blocks were added + return [current_block] + elif current_block: # any leftover text + len_current_block = self._length_function(current_block) + if len_current_block < self.min_chunk_size: + # it needs to take the last min_chunk_size-len_current_block units from the previous block + previous_block = blocks[-1] + required_units = self.min_chunk_size - len_current_block # calculate the required units + + part_prev_block = self._truncate_function(previous_block, required_units, from_end=True) # get the required units from the previous block + last_block = part_prev_block + current_block + + blocks.append(last_block) + else: + blocks.append(current_block) + + return blocks \ No newline at end of file diff --git a/align_data/pinecone/update_pinecone.py b/align_data/pinecone/update_pinecone.py new file mode 100644 index 00000000..e3a27540 --- /dev/null +++ b/align_data/pinecone/update_pinecone.py @@ -0,0 +1,190 @@ +import os +from typing import Dict, List, Union +import numpy as np +import openai + +from align_data.pinecone.text_splitter import ParagraphSentenceUnitTextSplitter +from 
align_data.pinecone.pinecone_db_handler import PineconeDB + +from align_data.settings import USE_OPENAI_EMBEDDINGS, OPENAI_EMBEDDINGS_MODEL, \ + OPENAI_EMBEDDINGS_DIMS, OPENAI_EMBEDDINGS_RATE_LIMIT, \ + SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS, \ + CHUNK_SIZE, MAX_NUM_AUTHORS_IN_SIGNATURE, EMBEDDING_LENGTH_BIAS + +import logging +logger = logging.getLogger(__name__) + + +class ARDUpdater: + def __init__( + self, + min_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MIN_CHUNK_SIZE, + max_chunk_size: int = ParagraphSentenceUnitTextSplitter.DEFAULT_MAX_CHUNK_SIZE, + ): + self.text_splitter = ParagraphSentenceUnitTextSplitter( + min_chunk_size=min_chunk_size, + max_chunk_size=max_chunk_size, + ) + + self.pinecone_db = PineconeDB() + + if USE_OPENAI_EMBEDDINGS: + import openai + openai.api_key = os.environ['OPENAI_API_KEY'] + else: + import torch + from langchain.embeddings import HuggingFaceEmbeddings + + self.hf_embeddings = HuggingFaceEmbeddings( + model_name=SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL, + model_kwargs={'device': "cuda" if torch.cuda.is_available() else "cpu"}, + encode_kwargs={'show_progress_bar': False} + ) + + def update(self, custom_sources: List[str] = ['all']): + """ + Update the given sources. If no sources are provided, updates all sources. + + :param custom_sources: List of sources to update. + """ + + for source in custom_sources: + self.update_source(source) + + def update_source(self, source: str): + """ + Updates the entries from the given source. + + :param source: The name of the source to update. + """ + + logger.info(f"Updating {source} entries...") + + # TODO: loop through mysql entries and update the pinecone db + + logger.info(f"Successfully updated {source} entries.") + + def batchify(self, iterable): + """ + Divides the iterable into batches of size ~CHUNK_SIZE. + + :param iterable: The iterable to divide into batches. + :returns: A generator that yields batches from the iterable. + """ + + entries_batch = [] + chunks_batch = [] + chunks_ids_batch = [] + sources_batch = [] + + for entry in iterable: + chunks, chunks_ids = self.create_chunk_ids_and_authors(entry) + + entries_batch.append(entry) + chunks_batch.extend(chunks) + chunks_ids_batch.extend(chunks_ids) + sources_batch.extend([entry['source']] * len(chunks)) + + # If this batch is large enough, yield it and start a new one. + if len(chunks_batch) >= CHUNK_SIZE: + yield self._create_batch(entries_batch, chunks_batch, chunks_ids_batch, sources_batch) + + entries_batch = [] + chunks_batch = [] + chunks_ids_batch = [] + sources_batch = [] + + # Yield any remaining items. 
+ if entries_batch: + yield self._create_batch(entries_batch, chunks_batch, chunks_ids_batch, sources_batch) + + def create_chunk_ids_and_authors(self, entry): + signature = f"Title: {entry['title']}, Author(s): {self.get_authors_str(entry['authors'])}" + chunks = self.text_splitter.split_text(entry['text']) + chunks = [f"- {signature}\n\n{chunk}" for chunk in chunks] + chunks_ids = [f"{entry['id']}_{str(i).zfill(6)}" for i in range(len(chunks))] + return chunks, chunks_ids + + def _create_batch(self, entries_batch, chunks_batch, chunks_ids_batch, sources_batch): + return {'entries_batch': entries_batch, 'chunks_batch': chunks_batch, 'chunks_ids_batch': chunks_ids_batch, 'sources_batch': sources_batch} + + def is_sql_entry_upserted(self, entry): + """Upserts an entry to the SQL database and returns the success status""" + return self.sql_db.upsert_entry(entry) + + def extract_embeddings(self, chunks_batch, sources_batch): + if USE_OPENAI_EMBEDDINGS: + return self.get_openai_embeddings(chunks_batch, sources_batch) + else: + return np.array(self.hf_embeddings.embed_documents(chunks_batch, sources_batch)) + + def reset_dbs(self): + self.sql_db.create_tables(True) + self.pinecone_db.create_index(True) + + @staticmethod + def preprocess_and_validate(entry): + """Preprocesses and validates the entry data""" + try: + ARDUpdater.validate_entry(entry) + + return { + 'id': entry['id'], + 'source': entry['source'], + 'title': entry['title'], + 'text': entry['text'], + 'url': entry['url'], + 'date_published': entry['date_published'], + 'authors': entry['authors'] + } + except ValueError as e: + logger.error(f"Entry validation failed: {str(e)}", exc_info=True) + return None + + @staticmethod + def validate_entry(entry: Dict[str, Union[str, list]], char_len_lower_limit: int = 0): + metadata_types = { + 'id': str, + 'source': str, + 'title': str, + 'text': str, + 'url': str, + 'date_published': str, + 'authors': list + } + + for metadata_type, metadata_type_type in metadata_types.items(): + if not isinstance(entry.get(metadata_type), metadata_type_type): + raise ValueError(f"Entry metadata '{metadata_type}' is not of type '{metadata_type_type}' or is missing.") + + if len(entry['text']) < char_len_lower_limit: + raise ValueError(f"Entry text is too short (< {char_len_lower_limit} characters).") + + @staticmethod + def is_valid_entry(entry): + """Checks if the entry is valid""" + return entry is not None + + @staticmethod + def get_openai_embeddings(chunks, sources=''): + embeddings = np.zeros((len(chunks), OPENAI_EMBEDDINGS_DIMS)) + + openai_output = openai.Embedding.create( + model=OPENAI_EMBEDDINGS_MODEL, + input=chunks + )['data'] + + for i, (embedding, source) in enumerate(zip(openai_output, sources)): + bias = EMBEDDING_LENGTH_BIAS.get(source, 1.0) + embeddings[i] = bias * np.array(embedding['embedding']) + + return embeddings + + @staticmethod + def get_authors_str(authors_lst: List[str]) -> str: + if authors_lst == []: return 'n/a' + if len(authors_lst) == 1: return authors_lst[0] + else: + authors_lst = authors_lst[:MAX_NUM_AUTHORS_IN_SIGNATURE] + authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}" + return authors_str \ No newline at end of file diff --git a/align_data/settings.py b/align_data/settings.py index 49be5460..425810f8 100644 --- a/align_data/settings.py +++ b/align_data/settings.py @@ -2,21 +2,49 @@ from dotenv import load_dotenv load_dotenv() - +### CODA ### CODA_TOKEN = os.environ.get("CODA_TOKEN") CODA_DOC_ID = os.environ.get("CODA_DOC_ID", "fau7sl2hmG") 
 ON_SITE_TABLE = os.environ.get('CODA_ON_SITE_TABLE', 'table-aOTSHIz_mN')
+### GOOGLE DRIVE ###
 PDFS_FOLDER_ID = os.environ.get('PDF_FOLDER_ID', '1etWiXPRl0QqdgYzivVIj6wCv9xj5VYoN')
+### GOOGLE SHEETS ###
 METADATA_SOURCE_SPREADSHEET = os.environ.get('METADATA_SOURCE_SPREADSHEET', '1pgG3HzercOhf4gniaqp3tBc3uvZnHpPhXErwHcthmbI')
 METADATA_SOURCE_SHEET = os.environ.get('METADATA_SOURCE_SHEET', 'special_docs.csv')
 METADATA_OUTPUT_SPREADSHEET = os.environ.get('METADATA_OUTPUT_SPREADSHEET', '1l3azVJVukGAvZPgg0GyeqiaQe8bEMZvycBJaA8cRXf4')
-
+### MYSQL ###
 user = os.environ.get('ARD_DB_USER', 'user')
 password = os.environ.get('ARD_DB_PASSWORD', 'we all live in a yellow submarine')
 host = os.environ.get('ARD_DB_HOST', '127.0.0.1')
 port = os.environ.get('ARD_DB_PORT', '3306')
 db_name = os.environ.get('ARD_DB_NAME', 'alignment_research_dataset')
 DB_CONNECTION_URI = f'mysql+mysqldb://{user}:{password}@{host}:{port}/{db_name}'
+DB_CONNECTION_URI = f'mysql+mysqldb://user:we all live in a yellow submarine@127.0.0.1:3306/alignment_research_dataset'
+
+### EMBEDDINGS ###
+USE_OPENAI_EMBEDDINGS = True # If false, SentenceTransformer embeddings will be used.
+EMBEDDING_LENGTH_BIAS = {
+    "aisafety.info": 1.05, # In search, favor AISafety.info entries.
+}
+
+OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002"
+OPENAI_EMBEDDINGS_DIMS = 1536
+OPENAI_EMBEDDINGS_RATE_LIMIT = 3500
+
+SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1"
+SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768
+
+### PINECONE ###
+PINECONE_INDEX_NAME = "stampy-chat-ard"
+PINECONE_VALUES_DIMS = OPENAI_EMBEDDINGS_DIMS if USE_OPENAI_EMBEDDINGS else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS
+PINECONE_METRIC = "dotproduct"
+PINECONE_METADATA_ENTRIES = ["entry_id", "source", "title", "authors", "text"]
+PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
+PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
+
+### MISCELLANEOUS ###
+CHUNK_SIZE = 1750
+MAX_NUM_AUTHORS_IN_SIGNATURE = 3
\ No newline at end of file
diff --git a/align_data/alignment_newsletter/__init__.py b/align_data/sources/alignment_newsletter/__init__.py
similarity index 100%
rename from align_data/alignment_newsletter/__init__.py
rename to align_data/sources/alignment_newsletter/__init__.py
diff --git a/align_data/alignment_newsletter/alignment_newsletter.py b/align_data/sources/alignment_newsletter/alignment_newsletter.py
similarity index 100%
rename from align_data/alignment_newsletter/alignment_newsletter.py
rename to align_data/sources/alignment_newsletter/alignment_newsletter.py
diff --git a/align_data/arbital/__init__.py b/align_data/sources/arbital/__init__.py
similarity index 100%
rename from align_data/arbital/__init__.py
rename to align_data/sources/arbital/__init__.py
diff --git a/align_data/arbital/arbital.py b/align_data/sources/arbital/arbital.py
similarity index 100%
rename from align_data/arbital/arbital.py
rename to align_data/sources/arbital/arbital.py
diff --git a/align_data/articles/__init__.py b/align_data/sources/articles/__init__.py
similarity index 85%
rename from align_data/articles/__init__.py
rename to align_data/sources/articles/__init__.py
index cd93345b..04664fd6 100644
--- a/align_data/articles/__init__.py
+++ b/align_data/sources/articles/__init__.py
@@ -1,4 +1,4 @@
-from align_data.articles.datasets import PDFArticles, HTMLArticles, EbookArticles, XMLArticles
+from align_data.sources.articles.datasets import PDFArticles, HTMLArticles, EbookArticles, XMLArticles
 ARTICLES_REGISTRY = [
     PDFArticles(
diff --git a/align_data/articles/articles.py b/align_data/sources/articles/articles.py
similarity index 94%
rename from align_data/articles/articles.py
rename to align_data/sources/articles/articles.py
index 9b670e74..f32c3e64 100644
--- a/align_data/articles/articles.py
+++ b/align_data/sources/articles/articles.py
@@ -3,9 +3,9 @@
 from tqdm import tqdm
-from align_data.articles.google_cloud import iterate_rows, get_spreadsheet, get_sheet, upload_file, OK, with_retry
-from align_data.articles.parsers import item_metadata, fetch
-from align_data.articles.indices import fetch_all
+from align_data.sources.articles.google_cloud import iterate_rows, get_spreadsheet, get_sheet, upload_file, OK, with_retry
+from align_data.sources.articles.parsers import item_metadata, fetch
+from align_data.sources.articles.indices import fetch_all
 from align_data.settings import PDFS_FOLDER_ID
diff --git a/align_data/articles/datasets.py b/align_data/sources/articles/datasets.py
similarity index 95%
rename from align_data/articles/datasets.py
rename to align_data/sources/articles/datasets.py
index aba2a8e8..fff7f051 100644
--- a/align_data/articles/datasets.py
+++ b/align_data/sources/articles/datasets.py
@@ -10,8 +10,8 @@
 from gdown.download import download
 from markdownify import markdownify
-from align_data.articles.pdf import fetch_pdf, read_pdf, fetch
-from align_data.articles.parsers import HTML_PARSERS, extract_gdrive_contents
+from align_data.sources.articles.pdf import fetch_pdf, read_pdf, fetch
+from align_data.sources.articles.parsers import HTML_PARSERS, extract_gdrive_contents
 from align_data.common.alignment_dataset import AlignmentDataset
 logger = logging.getLogger(__name__)
diff --git a/align_data/articles/google_cloud.py b/align_data/sources/articles/google_cloud.py
similarity index 100%
rename from align_data/articles/google_cloud.py
rename to align_data/sources/articles/google_cloud.py
diff --git a/align_data/articles/html.py b/align_data/sources/articles/html.py
similarity index 100%
rename from align_data/articles/html.py
rename to align_data/sources/articles/html.py
diff --git a/align_data/articles/indices.py b/align_data/sources/articles/indices.py
similarity index 98%
rename from align_data/articles/indices.py
rename to align_data/sources/articles/indices.py
index 5da56e31..6eb5761c 100644
--- a/align_data/articles/indices.py
+++ b/align_data/sources/articles/indices.py
@@ -1,7 +1,7 @@
 import re
 from collections import defaultdict
-from align_data.articles.html import fetch, fetch_element
+from align_data.sources.articles.html import fetch, fetch_element
 from align_data.common.alignment_dataset import AlignmentDataset
 from dateutil.parser import ParserError, parse
 from markdownify import MarkdownConverter
diff --git a/align_data/articles/parsers.py b/align_data/sources/articles/parsers.py
similarity index 98%
rename from align_data/articles/parsers.py
rename to align_data/sources/articles/parsers.py
index 4b8faaa7..62426335 100644
--- a/align_data/articles/parsers.py
+++ b/align_data/sources/articles/parsers.py
@@ -4,8 +4,8 @@
 import grobid_tei_xml
 import regex as re
-from align_data.articles.html import element_extractor, fetch, fetch_element
-from align_data.articles.pdf import doi_getter, fetch_pdf, get_pdf_from_page, get_arxiv_pdf
+from align_data.sources.articles.html import element_extractor, fetch, fetch_element
+from align_data.sources.articles.pdf import doi_getter, fetch_pdf, get_pdf_from_page, get_arxiv_pdf
 from markdownify import MarkdownConverter
 from bs4 import BeautifulSoup
 from markdownify import MarkdownConverter
diff --git a/align_data/articles/pdf.py b/align_data/sources/articles/pdf.py
similarity index 98%
rename from align_data/articles/pdf.py
rename to align_data/sources/articles/pdf.py
index 7be755f1..ae4492f6 100644
--- a/align_data/articles/pdf.py
+++ b/align_data/sources/articles/pdf.py
@@ -9,7 +9,7 @@
 from PyPDF2 import PdfReader
 from PyPDF2.errors import PdfReadError
-from align_data.articles.html import fetch, fetch_element
+from align_data.sources.articles.html import fetch, fetch_element
 logger = logging.getLogger(__name__)
diff --git a/align_data/arxiv_papers/__init__.py b/align_data/sources/arxiv_papers/__init__.py
similarity index 100%
rename from align_data/arxiv_papers/__init__.py
rename to align_data/sources/arxiv_papers/__init__.py
diff --git a/align_data/arxiv_papers/arxiv_papers.py b/align_data/sources/arxiv_papers/arxiv_papers.py
similarity index 100%
rename from align_data/arxiv_papers/arxiv_papers.py
rename to align_data/sources/arxiv_papers/arxiv_papers.py
diff --git a/align_data/audio_transcripts/__init__.py b/align_data/sources/audio_transcripts/__init__.py
similarity index 100%
rename from align_data/audio_transcripts/__init__.py
rename to align_data/sources/audio_transcripts/__init__.py
diff --git a/align_data/audio_transcripts/audio_transcripts.py b/align_data/sources/audio_transcripts/audio_transcripts.py
similarity index 100%
rename from align_data/audio_transcripts/audio_transcripts.py
rename to align_data/sources/audio_transcripts/audio_transcripts.py
diff --git a/align_data/blogs/__init__.py b/align_data/sources/blogs/__init__.py
similarity index 86%
rename from align_data/blogs/__init__.py
rename to align_data/sources/blogs/__init__.py
index 3c2c7c0c..8f1d5fc1 100644
--- a/align_data/blogs/__init__.py
+++ b/align_data/sources/blogs/__init__.py
@@ -1,10 +1,10 @@
-from align_data.blogs.wp_blog import WordpressBlog
-from align_data.blogs.medium_blog import MediumBlog
-from align_data.blogs.gwern_blog import GwernBlog
-from align_data.blogs.blogs import (
+from align_data.sources.blogs.wp_blog import WordpressBlog
+from align_data.sources.blogs.medium_blog import MediumBlog
+from align_data.sources.blogs.gwern_blog import GwernBlog
+from align_data.sources.blogs.blogs import (
     ColdTakes, GenerativeInk, CaradoMoe, EleutherAI, OpenAIResearch, DeepMindTechnicalBlog
 )
-from align_data.blogs.substack_blog import SubstackBlog
+from align_data.sources.blogs.substack_blog import SubstackBlog
 BLOG_REGISTRY = [
diff --git a/align_data/blogs/blogs.py b/align_data/sources/blogs/blogs.py
similarity index 98%
rename from align_data/blogs/blogs.py
rename to align_data/sources/blogs/blogs.py
index 9c6b962b..65ad95e6 100644
--- a/align_data/blogs/blogs.py
+++ b/align_data/sources/blogs/blogs.py
@@ -1,7 +1,7 @@
 import logging
 import requests
-from align_data.articles.parsers import item_metadata
+from align_data.sources.articles.parsers import item_metadata
 from align_data.common.html_dataset import HTMLDataset, RSSDataset
 from bs4 import BeautifulSoup
 from dateutil.parser import ParserError
diff --git a/align_data/blogs/gwern_blog.py b/align_data/sources/blogs/gwern_blog.py
similarity index 100%
rename from align_data/blogs/gwern_blog.py
rename to align_data/sources/blogs/gwern_blog.py
diff --git a/align_data/blogs/medium_blog.py b/align_data/sources/blogs/medium_blog.py
similarity index 100%
rename from align_data/blogs/medium_blog.py
rename to align_data/sources/blogs/medium_blog.py
diff --git a/align_data/blogs/substack_blog.py b/align_data/sources/blogs/substack_blog.py
similarity index 100%
rename from align_data/blogs/substack_blog.py
rename to align_data/sources/blogs/substack_blog.py
diff --git a/align_data/blogs/wp_blog.py b/align_data/sources/blogs/wp_blog.py
similarity index 100%
rename from align_data/blogs/wp_blog.py
rename to align_data/sources/blogs/wp_blog.py
diff --git a/align_data/distill/__init__.py b/align_data/sources/distill/__init__.py
similarity index 100%
rename from align_data/distill/__init__.py
rename to align_data/sources/distill/__init__.py
diff --git a/align_data/distill/distill.py b/align_data/sources/distill/distill.py
similarity index 100%
rename from align_data/distill/distill.py
rename to align_data/sources/distill/distill.py
diff --git a/align_data/ebooks/__init__.py b/align_data/sources/ebooks/__init__.py
similarity index 100%
rename from align_data/ebooks/__init__.py
rename to align_data/sources/ebooks/__init__.py
diff --git a/align_data/ebooks/agentmodels.py b/align_data/sources/ebooks/agentmodels.py
similarity index 100%
rename from align_data/ebooks/agentmodels.py
rename to align_data/sources/ebooks/agentmodels.py
diff --git a/align_data/ebooks/gdrive_ebooks.py b/align_data/sources/ebooks/gdrive_ebooks.py
similarity index 100%
rename from align_data/ebooks/gdrive_ebooks.py
rename to align_data/sources/ebooks/gdrive_ebooks.py
diff --git a/align_data/ebooks/mdebooks.py b/align_data/sources/ebooks/mdebooks.py
similarity index 100%
rename from align_data/ebooks/mdebooks.py
rename to align_data/sources/ebooks/mdebooks.py
diff --git a/align_data/gdocs/__init__.py b/align_data/sources/gdocs/__init__.py
similarity index 100%
rename from align_data/gdocs/__init__.py
rename to align_data/sources/gdocs/__init__.py
diff --git a/align_data/gdocs/gdocs.py b/align_data/sources/gdocs/gdocs.py
similarity index 100%
rename from align_data/gdocs/gdocs.py
rename to align_data/sources/gdocs/gdocs.py
diff --git a/align_data/greaterwrong/__init__.py b/align_data/sources/greaterwrong/__init__.py
similarity index 100%
rename from align_data/greaterwrong/__init__.py
rename to align_data/sources/greaterwrong/__init__.py
diff --git a/align_data/greaterwrong/greaterwrong.py b/align_data/sources/greaterwrong/greaterwrong.py
similarity index 100%
rename from align_data/greaterwrong/greaterwrong.py
rename to align_data/sources/greaterwrong/greaterwrong.py
diff --git a/align_data/reports/__init__.py b/align_data/sources/reports/__init__.py
similarity index 100%
rename from align_data/reports/__init__.py
rename to align_data/sources/reports/__init__.py
diff --git a/align_data/reports/reports.py b/align_data/sources/reports/reports.py
similarity index 100%
rename from align_data/reports/reports.py
rename to align_data/sources/reports/reports.py
diff --git a/align_data/stampy/__init__.py b/align_data/sources/stampy/__init__.py
similarity index 100%
rename from align_data/stampy/__init__.py
rename to align_data/sources/stampy/__init__.py
diff --git a/align_data/stampy/stampy.py b/align_data/sources/stampy/stampy.py
similarity index 100%
rename from align_data/stampy/stampy.py
rename to align_data/sources/stampy/stampy.py
diff --git a/main.py b/main.py
index 4e3f35ac..e1bce889 100644
--- a/main.py
+++ b/main.py
@@ -9,7 +9,8 @@
 from align_data import ALL_DATASETS, DATASET_REGISTRY, get_dataset
 from align_data.analysis.count_tokens import count_token
-from align_data.articles.articles import update_new_items, check_new_articles
+from align_data.sources.articles.articles import update_new_items, check_new_articles
+from align_data.pinecone.update_pinecone import ARDUpdater
 from align_data.settings import (
     METADATA_OUTPUT_SPREADSHEET, METADATA_SOURCE_SHEET, METADATA_SOURCE_SPREADSHEET
 )
@@ -121,6 +122,12 @@ def fetch_new_articles(self, source_spreadsheet=METADATA_SOURCE_SPREADSHEET, sou
         """
         return check_new_articles(source_spreadsheet, source_sheet)
+    def update_pinecone(self):
+        """
+        This function updates the Pinecone vector DB.
+        """
+        updater = ARDUpdater()
+        updater.update()
 if __name__ == "__main__":
     fire.Fire(AlignmentDataset)

From 45e40a7635062567014ec6ba4f63d57019da126c Mon Sep 17 00:00:00 2001
From: Henri Lemoine
Date: Mon, 24 Jul 2023 19:47:53 -0400
Subject: [PATCH 4/7] Bug fix and sqlite code removal

---
 align_data/pinecone/sql_db_handler.py | 130 --------------------------
 align_data/settings.py | 1 -
 2 files changed, 131 deletions(-)
 delete mode 100644 align_data/pinecone/sql_db_handler.py
diff --git a/align_data/pinecone/sql_db_handler.py b/align_data/pinecone/sql_db_handler.py
deleted file mode 100644
index 6877900a..00000000
--- a/align_data/pinecone/sql_db_handler.py
+++ /dev/null
@@ -1,130 +0,0 @@
-# dataset/sql_db_handler.py
-
-from typing import List, Dict, Union
-import numpy as np
-import sqlite3
-
-from align_data.settings import SQL_DB_PATH
-
-import logging
-logger = logging.getLogger(__name__)
-
-
-class SQLDB:
-    def __init__(self, db_name: str = SQL_DB_PATH):
-        self.db_name = db_name
-
-        self.create_tables()
-
-    def create_tables(self, reset: bool = False):
-        with sqlite3.connect(self.db_name) as conn:
-            cursor = conn.cursor()
-            try:
-                if reset:
-                    # Drop the tables if reset is True
-                    cursor.execute("DROP TABLE IF EXISTS entry_database")
-                    cursor.execute("DROP TABLE IF EXISTS chunk_database")
-
-                # Create entry table
-                query = """
-                    CREATE TABLE IF NOT EXISTS entry_database (
-                        id TEXT PRIMARY KEY,
-                        source TEXT,
-                        title TEXT,
-                        text TEXT,
-                        url TEXT,
-                        date_published TEXT,
-                        authors TEXT
-                    )
-                """
-                cursor.execute(query)
-
-                # Create chunk table
-                query = """
-                    CREATE TABLE IF NOT EXISTS chunk_database (
-                        id TEXT PRIMARY KEY,
-                        text TEXT,
-                        embedding BLOB,
-                        entry_id TEXT,
-                        FOREIGN KEY (entry_id) REFERENCES entry_database(id)
-                    )
-                """
-                cursor.execute(query)
-
-            except sqlite3.Error as e:
-                logger.error(f"The error '{e}' occurred.")
-
-    def upsert_entry(self, entry: Dict[str, Union[str, list]]) -> bool:
-        with sqlite3.connect(self.db_name) as conn:
-            cursor = conn.cursor()
-            try:
-                # Fetch existing data
-                cursor.execute("SELECT * FROM entry_database WHERE id=?", (entry['id'],))
-                existing_entry = cursor.fetchone()
-
-                new_entry = (
-                    entry['id'],
-                    entry['source'],
-                    entry['title'],
-                    entry['text'],
-                    entry['url'],
-                    entry['date_published'],
-                    ', '.join(entry['authors'])
-                )
-
-                if existing_entry != new_entry:
-                    query = """
-                        INSERT OR REPLACE INTO entry_database
-                        (id, source, title, text, url, date_published, authors)
-                        VALUES (?, ?, ?, ?, ?, ?, ?)
-                    """
-                    cursor.execute(query, new_entry)
-                    return True
-                else:
-                    return False
-
-            except sqlite3.Error as e:
-                logger.error(f"The error '{e}' occurred.")
-                return False
-
-            finally:
-                conn.commit()
-
-    def upsert_chunks(self, chunks_ids_batch: List[str], chunks_batch: List[str], embeddings_batch: List[np.ndarray]) -> bool:
-        with sqlite3.connect(self.db_name) as conn:
-            cursor = conn.cursor()
-            try:
-                for chunk_id, chunk, embedding in zip(chunks_ids_batch, chunks_batch, embeddings_batch):
-                    cursor.execute("""
-                        INSERT OR REPLACE INTO chunk_database
-                        (id, text, embedding)
-                        VALUES (?, ?, ?)
-                    """, (chunk_id, chunk, embedding.tobytes()))
-            except sqlite3.Error as e:
-                logger.error(f"The error '{e}' occurred.")
-            finally:
-                conn.commit()
-
-
-    def stream_chunks(self):
-        with sqlite3.connect(self.db_name) as conn:
-            cursor = conn.cursor()
-
-            # Join entry_database and chunk_database tables and order by source
-            cursor.execute("""
-                SELECT c.id, c.text, c.embedding, e.source
-                FROM chunk_database c
-                JOIN entry_database e ON c.entry_id = e.id
-                ORDER BY e.source
-            """)
-
-            for row in cursor:
-                # Convert bytes back to numpy array
-                embedding = np.frombuffer(row[2], dtype=np.float64) if row[2] else None
-
-                yield {
-                    'id': row[0],
-                    'text': row[1],
-                    'embedding': embedding,
-                    'source': row[3],
-                }
\ No newline at end of file
diff --git a/align_data/settings.py b/align_data/settings.py
index 425810f8..0b08ed3d 100644
--- a/align_data/settings.py
+++ b/align_data/settings.py
@@ -22,7 +22,6 @@
 port = os.environ.get('ARD_DB_PORT', '3306')
 db_name = os.environ.get('ARD_DB_NAME', 'alignment_research_dataset')
 DB_CONNECTION_URI = f'mysql+mysqldb://{user}:{password}@{host}:{port}/{db_name}'
-DB_CONNECTION_URI = f'mysql+mysqldb://user:we all live in a yellow submarine@127.0.0.1:3306/alignment_research_dataset'
 ### EMBEDDINGS ###
 USE_OPENAI_EMBEDDINGS = True # If false, SentenceTransformer embeddings will be used.

From b9369a97c710abdceadafbff5cf6d5566760a262 Mon Sep 17 00:00:00 2001
From: Henri Lemoine
Date: Mon, 24 Jul 2023 19:59:49 -0400
Subject: [PATCH 5/7] Fixed tests imports

---
 .dockerignore | 1 +
 Dockerfile | 15 +++++++++++++++
 tests/align_data/articles/test_datasets.py | 2 +-
 tests/align_data/articles/test_parsers.py | 2 +-
 tests/align_data/test_alignment_newsletter.py | 2 +-
 tests/align_data/test_arbital.py | 2 +-
 tests/align_data/test_blogs.py | 4 ++--
 tests/align_data/test_distill.py | 2 +-
 tests/align_data/test_greater_wrong.py | 2 +-
 tests/align_data/test_stampy.py | 2 +-
 10 files changed, 25 insertions(+), 9 deletions(-)
 create mode 100644 .dockerignore
 create mode 100644 Dockerfile
diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 00000000..81c914eb
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1 @@
+data/raw/
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 00000000..ed26ac81
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,15 @@
+FROM python:3.11-slim-buster
+
+WORKDIR /app
+
+ADD . /app
+
+RUN apt-get update && apt-get install -y \
+    git \
+    pkg-config \
+    default-libmysqlclient-dev \
+    gcc
+
+RUN pip install --no-cache-dir -r requirements.txt
+
+CMD [ "python", "main.py" ]
\ No newline at end of file
diff --git a/tests/align_data/articles/test_datasets.py b/tests/align_data/articles/test_datasets.py
index d85c3d13..1c5f7a2d 100644
--- a/tests/align_data/articles/test_datasets.py
+++ b/tests/align_data/articles/test_datasets.py
@@ -2,7 +2,7 @@
 import pandas as pd
 import pytest
-from align_data.articles.datasets import EbookArticles, HTMLArticles, PDFArticles, SpreadsheetDataset, XMLArticles
+from align_data.sources.articles.datasets import EbookArticles, HTMLArticles, PDFArticles, SpreadsheetDataset, XMLArticles
 @pytest.fixture
diff --git a/tests/align_data/articles/test_parsers.py b/tests/align_data/articles/test_parsers.py
index 7e480b22..f062343a 100644
--- a/tests/align_data/articles/test_parsers.py
+++ b/tests/align_data/articles/test_parsers.py
@@ -4,7 +4,7 @@
 import pytest
 from bs4 import BeautifulSoup
-from align_data.articles.parsers import (
+from align_data.sources.articles.parsers import (
     google_doc, medium_blog, parse_grobid, get_content_type, extract_gdrive_contents
 )
diff --git a/tests/align_data/test_alignment_newsletter.py b/tests/align_data/test_alignment_newsletter.py
index a0f039d7..0e9db7a4 100644
--- a/tests/align_data/test_alignment_newsletter.py
+++ b/tests/align_data/test_alignment_newsletter.py
@@ -2,7 +2,7 @@
 import pytest
 import pandas as pd
-from align_data.alignment_newsletter import AlignmentNewsletter
+from align_data.sources.alignment_newsletter import AlignmentNewsletter
 @pytest.fixture(scope="module")
diff --git a/tests/align_data/test_arbital.py b/tests/align_data/test_arbital.py
index f2f4cd98..d0c454dd 100644
--- a/tests/align_data/test_arbital.py
+++ b/tests/align_data/test_arbital.py
@@ -4,7 +4,7 @@
 import pytest
 from dateutil.parser import parse
-from align_data.arbital.arbital import Arbital, extract_text, flatten, parse_arbital_link
+from align_data.sources.arbital.arbital import Arbital, extract_text, flatten, parse_arbital_link
 @pytest.mark.parametrize('contents, expected', (
diff --git a/tests/align_data/test_blogs.py b/tests/align_data/test_blogs.py
index 62de36ae..f2dc7d21 100644
--- a/tests/align_data/test_blogs.py
+++ b/tests/align_data/test_blogs.py
@@ -4,11 +4,11 @@
 from bs4 import BeautifulSoup
 from dateutil.parser import parse
-from align_data.blogs import (
+from align_data.sources.blogs import (
     CaradoMoe, ColdTakes, GenerativeInk, GwernBlog, MediumBlog, SubstackBlog, WordpressBlog, OpenAIResearch, DeepMindTechnicalBlog
 )
-from align_data.blogs.blogs import EleutherAI
+from align_data.sources.blogs.blogs import EleutherAI
 SAMPLE_HTML = """
diff --git a/tests/align_data/test_distill.py b/tests/align_data/test_distill.py
index e0c8943f..b94b5bda 100644
--- a/tests/align_data/test_distill.py
+++ b/tests/align_data/test_distill.py
@@ -3,7 +3,7 @@
 import pytest
 from bs4 import BeautifulSoup
-from align_data.distill import Distill
+from align_data.sources.distill import Distill
 def test_extract_authors():
diff --git a/tests/align_data/test_greater_wrong.py b/tests/align_data/test_greater_wrong.py
index 72b32b83..f7c273c2 100644
--- a/tests/align_data/test_greater_wrong.py
+++ b/tests/align_data/test_greater_wrong.py
@@ -5,7 +5,7 @@
 import pytest
-from align_data.greaterwrong.greaterwrong import (
+from align_data.sources.greaterwrong.greaterwrong import (
     fetch_LW_tags, fetch_ea_forum_topics, GreaterWrong
 )
diff --git a/tests/align_data/test_stampy.py b/tests/align_data/test_stampy.py
index 1381941a..c3694086 100644
--- a/tests/align_data/test_stampy.py
+++ b/tests/align_data/test_stampy.py
@@ -1,7 +1,7 @@
 from unittest.mock import patch
 from dateutil.parser import parse
-from align_data.stampy import Stampy
+from align_data.sources.stampy import Stampy
 def test_validate_coda_token():

From 594354b05ce71a17cd7592b31f1a15ffe3494512 Mon Sep 17 00:00:00 2001
From: Henri Lemoine
Date: Mon, 24 Jul 2023 20:03:14 -0400
Subject: [PATCH 6/7] Removed docker files

---
 .dockerignore | 1 -
 Dockerfile | 15 ---------------
 2 files changed, 16 deletions(-)
 delete mode 100644 .dockerignore
 delete mode 100644 Dockerfile
diff --git a/.dockerignore b/.dockerignore
deleted file mode 100644
index 81c914eb..00000000
--- a/.dockerignore
+++ /dev/null
@@ -1 +0,0 @@
-data/raw/
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index ed26ac81..00000000
--- a/Dockerfile
+++ /dev/null
@@ -1,15 +0,0 @@
-FROM python:3.11-slim-buster
-
-WORKDIR /app
-
-ADD . /app
-
-RUN apt-get update && apt-get install -y \
-    git \
-    pkg-config \
-    default-libmysqlclient-dev \
-    gcc
-
-RUN pip install --no-cache-dir -r requirements.txt
-
-CMD [ "python", "main.py" ]
\ No newline at end of file

From b9a735f668a19ea0110e43fa606fed25c2fe00c7 Mon Sep 17 00:00:00 2001
From: Henri Lemoine
Date: Mon, 24 Jul 2023 20:37:11 -0400
Subject: [PATCH 7/7] Add .env.example file and fixed pinecone index def

---
 .env.example | 10 ++++++++++
 align_data/settings.py | 6 +++---
 2 files changed, 13 insertions(+), 3 deletions(-)
 create mode 100644 .env.example
diff --git a/.env.example b/.env.example
new file mode 100644
index 00000000..7636d2da
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,10 @@
+CODA_TOKEN=""
+ARD_DB_USER="user"
+ARD_DB_PASSWORD="we all live in a yellow submarine"
+ARD_DB_HOST="127.0.0.1"
+ARD_DB_PORT="3306"
+ARD_DB_NAME="alignment_research_dataset"
+OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
+PINECONE_INDEX_NAME="stampy-chat-ard"
+PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
+PINECONE_ENVIRONMENT="xx-xxxxx-gcp"
diff --git a/align_data/settings.py b/align_data/settings.py
index 0b08ed3d..8e0da373 100644
--- a/align_data/settings.py
+++ b/align_data/settings.py
@@ -37,12 +37,12 @@
 SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768
 ### PINECONE ###
-PINECONE_INDEX_NAME = "stampy-chat-ard"
+PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME", "stampy-chat-ard")
+PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
+PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
 PINECONE_VALUES_DIMS = OPENAI_EMBEDDINGS_DIMS if USE_OPENAI_EMBEDDINGS else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS
 PINECONE_METRIC = "dotproduct"
 PINECONE_METADATA_ENTRIES = ["entry_id", "source", "title", "authors", "text"]
-PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
-PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
 ### MISCELLANEOUS ###
 CHUNK_SIZE = 1750