Merge pull request #150 from StampyAI/pinecone-fix-vector-search
Pinecone fix vector search
henri123lemoine authored Aug 23, 2023
2 parents 76e154c + 176b052 commit c9ceb24
Showing 62 changed files with 1,754 additions and 1,141 deletions.
1 change: 1 addition & 0 deletions .env.example
@@ -9,3 +9,4 @@ OPENAI_API_KEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
PINECONE_INDEX_NAME="stampy-chat-ard"
PINECONE_API_KEY="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
PINECONE_ENVIRONMENT="xx-xxxxx-gcp"
YOUTUBE_API_KEY=""
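The new YOUTUBE_API_KEY line suggests a YouTube data source was added in this PR. A minimal sketch of how these variables might be read at startup, assuming the project's settings module pulls them from the environment (the loading code below is illustrative, not taken from this diff):

import os

PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME", "stampy-chat-ard")
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT")
# Empty default mirrors the blank value in .env.example
YOUTUBE_API_KEY = os.environ.get("YOUTUBE_API_KEY", "")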
11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,11 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
  - repo: https://github.com/psf/black
    rev: 23.7.0
    hooks:
      - id: black
        language_version: python3.11
4 changes: 1 addition & 3 deletions align_data/analysis/analyse_jsonl_data.py
@@ -68,9 +68,7 @@ def process_jsonl_files(data_dir):

for id, duplicates in seen_urls.items():
if len(duplicates) > 1:
list_of_duplicates = "\n".join(
get_data_dict_str(duplicate) for duplicate in duplicates
)
list_of_duplicates = "\n".join(get_data_dict_str(duplicate) for duplicate in duplicates)
print(
f"{len(duplicates)} duplicate ids found. \nId: {id}\n{list_of_duplicates}\n\n\n\n"
)
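The reformatted loop above reports duplicate ids from a seen_urls mapping built earlier in process_jsonl_files. A simplified sketch of how such a mapping could be assembled, assuming one JSON object per line with an "id" field (the helper below is illustrative; only the reporting loop appears in this diff):

import json
from collections import defaultdict
from pathlib import Path

def collect_entries_by_id(data_dir):
    # Group jsonl entries by id so the loop above can report duplicates.
    seen_urls = defaultdict(list)
    for path in Path(data_dir).glob("*.jsonl"):
        with open(path) as f:
            for line in f:
                entry = json.loads(line)
                seen_urls[entry.get("id")].append(entry)
    return seen_urls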
23 changes: 10 additions & 13 deletions align_data/common/alignment_dataset.py
@@ -164,14 +164,9 @@ def _load_outputted_items(self) -> Set[str]:
# This doesn't filter by self.name. The good thing about that is that it should handle a lot more
# duplicates. The bad thing is that this could potentially return a massive amount of data if there
# are lots of items.
return set(
session.scalars(select(getattr(Article, self.done_key))).all()
)
return set(session.scalars(select(getattr(Article, self.done_key))).all())
# TODO: Properly handle this - it should create a proper SQL JSON select
return {
item.get(self.done_key)
for item in session.scalars(select(Article.meta)).all()
}
return {item.get(self.done_key) for item in session.scalars(select(Article.meta)).all()}

def not_processed(self, item):
# NOTE: `self._outputted_items` reads in all items. Which could potentially be a lot. If this starts to
@@ -214,7 +209,7 @@ def fetch_entries(self):
if self.COOLDOWN:
time.sleep(self.COOLDOWN)

def process_entry(self, entry) -> Optional[Article]:
def process_entry(self, entry) -> Article | None:
"""Process a single entry."""
raise NotImplementedError

@@ -223,7 +218,7 @@ def _format_datetime(date) -> str:
return date.strftime("%Y-%m-%dT%H:%M:%SZ")

@staticmethod
def _get_published_date(date) -> Optional[datetime]:
def _get_published_date(date) -> datetime | None:
try:
# Totally ignore any timezone info, forcing everything to UTC
return parse(str(date)).replace(tzinfo=pytz.UTC)
@@ -239,7 +234,11 @@ def unprocessed_items(self, items=None) -> Iterable:

urls = map(self.get_item_key, items)
with make_session() as session:
articles = session.query(Article).options(joinedload(Article.summaries)).filter(Article.url.in_(urls))
articles = (
session.query(Article)
.options(joinedload(Article.summaries))
.filter(Article.url.in_(urls))
)
self.articles = {a.url: a for a in articles if a.url}

return items
@@ -249,9 +248,7 @@ def _load_outputted_items(self) -> Set[str]:
with make_session() as session:
return set(
session.scalars(
select(Article.url)
.join(Article.summaries)
.filter(Summary.source == self.name)
select(Article.url).join(Article.summaries).filter(Summary.source == self.name)
)
)

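The hunks above modernize the base dataset's type hints (Article | None instead of Optional[Article]) and let black reflow the session queries. A sketch of what a concrete dataset's process_entry now looks like under that contract, assuming the base class in alignment_dataset.py is named AlignmentDataset and that entries are dicts (the field names below are illustrative):

from align_data.common.alignment_dataset import AlignmentDataset  # assumed class name
from align_data.db.models import Article


class ExampleDataset(AlignmentDataset):
    done_key = "url"

    def process_entry(self, entry) -> Article | None:
        text = entry.get("content")
        if not text:
            # Returning None skips the entry, as html_dataset.py does below.
            return None
        return self.make_data_entry(
            {
                "title": entry.get("title"),
                "url": entry.get("link"),
                "text": text,
                "source": self.name,
                "date_published": self._get_published_date(entry.get("published")),
            }
        )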
2 changes: 1 addition & 1 deletion align_data/common/html_dataset.py
@@ -75,7 +75,7 @@ def get_contents(self, article_url):
def process_entry(self, article):
article_url = self.get_item_key(article)
contents = self.get_contents(article_url)
if not contents.get('text'):
if not contents.get("text"):
return None

return self.make_data_entry(contents)
50 changes: 25 additions & 25 deletions align_data/db/models.py
@@ -17,7 +17,6 @@
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.hybrid import hybrid_property
from align_data.settings import PINECONE_METADATA_KEYS


logger = logging.getLogger(__name__)
@@ -71,33 +70,26 @@ class Article(Base):
def __repr__(self) -> str:
return f"Article(id={self.id!r}, title={self.title!r}, url={self.url!r}, source={self.source!r}, authors={self.authors!r}, date_published={self.date_published!r})"

def is_metadata_keys_equal(self, other):
if not isinstance(other, Article):
raise TypeError(
f"Expected an instance of Article, got {type(other).__name__}"
)
return not any(
getattr(self, key, None)
!= getattr(other, key, None) # entry_id is implicitly ignored
for key in PINECONE_METADATA_KEYS
)

def generate_id_string(self) -> bytes:
return "".join(str(getattr(self, field)) for field in self.__id_fields).encode(
"utf-8"
)
return "".join(str(getattr(self, field)) for field in self.__id_fields).encode("utf-8")

@property
def __id_fields(self):
if self.source == 'aisafety.info':
return ['url']
if self.source in ['importai', 'ml_safety_newsletter', 'alignment_newsletter']:
return ['url', 'title', 'source']
if self.source == "aisafety.info":
return ["url"]
if self.source in ["importai", "ml_safety_newsletter", "alignment_newsletter"]:
return ["url", "title", "source"]
return ["url", "title"]

@property
def missing_fields(self):
fields = set(self.__id_fields) | {'text', 'title', 'url', 'source', 'date_published'}
fields = set(self.__id_fields) | {
"text",
"title",
"url",
"source",
"date_published",
}
return sorted([field for field in fields if not getattr(self, field, None)])

def verify_id(self):
@@ -133,13 +125,21 @@ def add_meta(self, key, val):
self.meta = {}
self.meta[key] = val

def append_comment(self, comment: str):
"""Appends a comment to the article.comments field. You must run session.commit() to save the comment to the database."""
if self.comments is None:
self.comments = ""
self.comments = f"{self.comments}\n\n{comment}".strip()

@hybrid_property
def is_valid(self):
return (
self.text and self.text.strip() and
self.url and self.title and
self.authors is not None and
self.status == OK_STATUS
self.text
and self.text.strip()
and self.url
and self.title
and self.authors is not None
and self.status == OK_STATUS
)

@is_valid.expression
@@ -157,7 +157,7 @@ def before_write(cls, _mapper, _connection, target):
target.verify_id_fields()

if not target.status and target.missing_fields:
target.status = 'Missing fields'
target.status = "Missing fields"
target.comments = f'missing fields: {", ".join(target.missing_fields)}'

if target.id:
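The models.py changes drop the Pinecone-specific is_metadata_keys_equal helper, switch string literals to black's double quotes, and add an append_comment helper. A short usage sketch combining it with the is_valid hybrid property (the URL is made up; make_session comes from the session module changed below):

from align_data.db.models import Article
from align_data.db.session import make_session

with make_session(auto_commit=True) as session:
    article = session.query(Article).filter(Article.url == "https://example.com/post").first()
    if article and not article.is_valid:
        # append_comment only mutates the in-memory object; auto_commit=True
        # makes the surrounding session persist it on exit.
        article.append_comment(f"missing fields: {', '.join(article.missing_fields)}")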
26 changes: 20 additions & 6 deletions align_data/db/session.py
@@ -10,24 +10,38 @@

logger = logging.getLogger(__name__)

# We create a single engine for the entire application
engine = create_engine(DB_CONNECTION_URI, echo=False)


@contextmanager
def make_session(auto_commit=False):
engine = create_engine(DB_CONNECTION_URI, echo=False)
with Session(engine).no_autoflush as session:
with Session(engine, autoflush=False) as session:
yield session
if auto_commit:
session.commit()


def stream_pinecone_updates(session, custom_sources: List[str]):
def stream_pinecone_updates(
session: Session, custom_sources: List[str], force_update: bool = False
):
"""Yield Pinecone entries that require an update."""
yield from (
session
.query(Article)
.filter(Article.pinecone_update_required.is_(True))
session.query(Article)
.filter(or_(Article.pinecone_update_required.is_(True), force_update))
.filter(Article.is_valid)
.filter(Article.source.in_(custom_sources))
.filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE))
.yield_per(1000)
)


def get_all_valid_article_ids(session: Session) -> List[str]:
"""Return all valid article IDs."""
query_result = (
session.query(Article.id)
.filter(Article.is_valid)
.filter(or_(Article.confidence == None, Article.confidence > MIN_CONFIDENCE))
.all()
)
return [item[0] for item in query_result]
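The session module now shares one engine across the application, exposes a force_update flag for Pinecone refreshes, and adds get_all_valid_article_ids. A sketch of how the streaming helper might be driven (the source names are illustrative):

from align_data.db.session import make_session, stream_pinecone_updates

with make_session() as session:
    # force_update=True bypasses the pinecone_update_required flag and
    # re-streams every valid article from the listed sources.
    for article in stream_pinecone_updates(session, ["arxiv", "lesswrong"], force_update=True):
        print(article.url)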