Skip to content

Commit

Permalink
minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
henri123lemoine committed Aug 21, 2023
1 parent 950b69f commit 15e4bec
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 24 deletions.
40 changes: 23 additions & 17 deletions align_data/pinecone/pinecone_db_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def upsert_entry(self, entry: Dict, upsert_size=100):
"title": entry["title"],
"authors": entry["authors"],
"url": entry["url"],
# Note: there is a slight inconsistency between the name of the field in the DB and the name of the field in the Pinecone index.
"date": entry["date_published"],
"date_published": entry["date_published"],
"text": text_chunk,
}
for text_chunk in entry["text_chunks"]
Expand All @@ -83,11 +82,13 @@ def query(self, query: Union[str, List[float]], top_k=10, include_values=False,

@dataclass
class Block:
id: str
source: str
title: str
author: str
date: str
url: str
authors: str
text: str
url: str
date_published: str

if isinstance(query, str):
query = list(get_embeddings(query)[0])
Expand All @@ -100,21 +101,23 @@ class Block:
namespace=PINECONE_NAMESPACE,
**kwargs,
)
# print(query_response)

blocks = []
for match in query_response['matches']:

date = match['metadata']['date']
date_published = match['metadata']['date']

if type(date) == datetime.date: date = date.strftime("%Y-%m-%d") # iso8601
if type(date_published) == datetime.date:
date_published = date_published.strftime("%Y-%m-%d") # iso8601

blocks.append(Block(
id = match['id'],
source = match['metadata']['source'],
title = match['metadata']['title'],
author = match['metadata']['authors'],
date = date,
authors = match['metadata']['authors'],
text = strip_block(match['metadata']['text']),
url = match['metadata']['url'],
text = strip_block(match['metadata']['text'])
date_published = date_published,
))

return blocks
Expand Down Expand Up @@ -167,10 +170,13 @@ def get_embeddings_by_ids(self, ids: List[str]) -> List[Tuple[str, Union[List[fl
# we add the title and authors inside the contents of the block, so that
# searches for the title or author will be more likely to pull it up. This
# strips it back out.
import re
# import re
# def strip_block(text: str) -> str:
# r = re.match(r"^\"(.*)\"\s*-\s*Title:.*$", text, re.DOTALL)
# if not r:
# print("Warning: couldn't strip block")
# print(text)
# return r.group(1) if r else text

def strip_block(text: str) -> str:
# Returns the block text with the title/author signature that was added to it
# stripped back out.
#
# NOTE(review): this span is a rendered diff hunk whose +/- markers and
# indentation were lost, so two alternative bodies appear back to back:
#   * the regex body captures the quoted text of a `"<text>" - Title: ...`
#     block; when nothing matches it prints a warning plus the text and
#     returns `text` unchanged;
#   * the final return drops everything up to and including the first "\n"
#     (and yields "" when `text` contains no newline).
# Presumably the regex body is the removed version and the last line is its
# replacement -- confirm against the repository before relying on either.
r = re.match(r"^\"(.*)\"\s*-\s*Title:.*$", text, re.DOTALL)
if not r:
print("Warning: couldn't strip block")
print(text)
return r.group(1) if r else text
# Alternative body: keep every line after the first one.
return "\n".join(text.split("\n")[1:])
9 changes: 5 additions & 4 deletions align_data/pinecone/update_pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,14 @@ def _make_pinecone_update(self, article: Article):


def get_text_chunks(article: Article, text_splitter: ParagraphSentenceUnitTextSplitter) -> List[str]:
# Splits article.text with text_splitter and returns one string per chunk,
# each combined with a signature built from the article's title and authors.
#
# NOTE(review): this span is a rendered diff hunk whose +/- markers and
# indentation were lost, so it shows two `signature` assignments and two
# `return` statements. Presumably the first of each pair is the removed line
# and the second its replacement -- confirm against the repository.

# Title with newlines replaced by spaces; only the "###" signature below uses it.
title = article.title.replace("\n", " ")

# article.authors is split on "," and each name is stripped of whitespace.
authors_lst = [author.strip() for author in article.authors.split(",")]
authors = get_authors_str(authors_lst)

# Two signature formats: "Title: <title>, Author(s): <authors>" (built from the
# raw article.title) and "### <title>, by <authors>" (built from the
# newline-free title).
signature = f"Title: {article.title}, Author(s): {authors}"
signature = f"### {title}, by {authors}"
text_chunks = text_splitter.split_text(article.text)
# Two chunk layouts: "- <signature>", a blank line, then the chunk; or the
# signature on the first line followed by the chunk wrapped in double quotes.
return [f"- {signature}\n\n{text_chunk}" for text_chunk in text_chunks]

return [f"{signature}\n\"{text_chunk}\"" for text_chunk in text_chunks]

def get_authors_str(authors_lst: List[str]) -> str:
if authors_lst == []:
Expand All @@ -127,4 +128,4 @@ def get_authors_str(authors_lst: List[str]) -> str:
else:
authors_lst = authors_lst[:3]
authors_str = f"{', '.join(authors_lst[:-1])} and {authors_lst[-1]}"
return authors_str
return authors_str.replace("\n", " ")
6 changes: 3 additions & 3 deletions align_data/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
# OpenAI embedding settings: model name, embedding dimensionality and a rate limit.
# NOTE(review): the unit of the rate limit is not shown here -- confirm.
OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002"
OPENAI_EMBEDDINGS_DIMS = 1536
OPENAI_EMBEDDINGS_RATE_LIMIT = 3500
# NOTE(review): this span is a rendered diff hunk whose +/- markers were lost,
# so each credential is assigned twice. The subscript form raises KeyError when
# the environment variable is unset; the .get(..., None) form falls back to
# None instead. Presumably the subscript lines are the removed version and the
# .get lines their replacement -- confirm against the repository.
openai.api_key = os.environ["OPENAI_API_KEY"]
openai.organization = os.environ["OPENAI_ORGANIZATION"]
openai.api_key = os.environ.get("OPENAI_API_KEY", None)
openai.organization = os.environ.get("OPENAI_ORGANIZATION", None)

SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1"
SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768
Expand All @@ -62,7 +62,7 @@
else SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS
)
# Similarity metric configured for the Pinecone index.
PINECONE_METRIC = "dotproduct"
# NOTE(review): this span is a rendered diff hunk whose +/- markers were lost,
# so the key list is assigned twice. The second list uses "authors" and
# "date_published" where the first uses "author" and "date"; presumably the
# first is the removed line and the second its replacement -- confirm.
# NOTE(review): the query hunk on this page reads match['metadata']['date'] in
# both of its variants, while the upsert hunk writes "date_published" -- verify
# that the read side was renamed as well.
PINECONE_METADATA_KEYS = ["entry_id", "source", "title", "author", "text", "url", "date"]
PINECONE_METADATA_KEYS = ["entry_id", "source", "title", "authors", "text", "url", "date_published"]
# Namespace comes from the PINECONE_NAMESPACE environment variable, defaulting to "normal".
PINECONE_NAMESPACE = os.environ.get("PINECONE_NAMESPACE", "normal") # If the finetuned layer is used, this should be "finetuned"

### FINE-TUNING ###
Expand Down

0 comments on commit 15e4bec

Please sign in to comment.