Merge pull request #18 from aozalevsky/document-analyzer-script
Document analyzer script
AntounMichael authored Nov 12, 2023
2 parents f20c6cd + 02fc7b4 commit 4ca6d96
Showing 5 changed files with 144 additions and 104 deletions.
76 changes: 38 additions & 38 deletions VectorDatabase.py
@@ -1,5 +1,39 @@
import psycopg2
from database_entities import Fragment, Publication


# Class to represent a publication with attributes id, title, pmc, pubmed, and doi
class Publication:
id = ""
title = ""
pmc = ""
pubmed = ""
doi = ""

def __init__(self, id, title, pmc, pubmed, doi):
self.id = id # (DOI) Unique identifier for the publication
self.title = title # Title of the publication
self.pmc = pmc # PubMed Central (PMC) Link
self.pubmed = pubmed # PubMed Link
self.doi = doi # Digital Object Identifier (DOI) Link for the publication


# Class to represent a fragment of a publication with attributes id, header, content, and vector
class Fragment:
# Class variables to store default values for attributes
id = ""
header = ""
content = ""
vector = ""

def __init__(self, id, header, content, vector):
# Constructor to initialize the attributes of the Fragment object

# Set the attributes of the object with the values provided during instantiation
self.id = id # (DOI) Unique identifier for the fragment
self.header = header # Header or title of the fragment
self.content = content # Content or text of the fragment
self.vector = vector # Vector representation of the fragment


# Lantern class that exposes functionality of database to application
class Lantern:
@@ -249,7 +283,7 @@ def getUnreadPublications(self, delete_unread_entries=True):

if delete_unread_entries:
cursor.execute('DELETE FROM unread;')

conn.commit()
cursor.close()

@@ -292,49 +326,15 @@ def publicationExists(self, id):
- [(text, embedding)] content of a publication's embeddings
Notes:
"""

def get_embeddings_for_pub(self, id):
texts = []
embeddings = []
if not self.publicationExists(id):
return
fragments = self.getAllFragmentsOfPublication(id)
for fragment in fragments:
texts.append(fragment.content)
embeddings.append(fragment.vector)
text_embeddings = list(zip(texts, embeddings))
return text_embeddings
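# Usage sketch (hypothetical DOI): returns None when the publication is absent,
# otherwise a list of (text, embedding) pairs:
#   pairs = Lantern().get_embeddings_for_pub("10.1000/example-doi")
#   # e.g. [("Methods ... cryo-EM ...", [0.012, -0.034, ...]), ...]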

# Class to represent a publication with attributes id, title, pmc, pubmed, and doi
class Publication:

id = ""
title = ""
pmc = ""
pubmed = ""
doi = ""

def __init__(self, id, title, pmc, pubmed, doi):
self.id = id # (DOI) Unique identifier for the publication
self.title = title # Title of the publication
self.pmc = pmc # PubMed Central (PMC) Link
self.pubmed = pubmed # PubMed Link
self.doi = doi # Digital Object Identifier (DOI) Link for the publication

# Class to represent a fragment of a publication with attributes id, header, content, and vector
class Fragment:


# Class variables to store default values for attributes
id = ""
header = ""
content = ""
vector = ""

def __init__(self, id, header, content, vector):
# Constructor to initialize the attributes of the Fragment object

# Set the attributes of the object with the values provided during instantiation
self.id = id # (DOI) Unique identifier for the fragment
self.header = header # Header or title of the fragment
self.content = content # Content or text of the fragment
self.vector = vector # Vector representation of the fragment
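With the entity classes now living in VectorDatabase.py, callers can work entirely against that module. A minimal usage sketch (the DOI, links, and database contents are hypothetical; Lantern is assumed to connect to its default Postgres instance):

from VectorDatabase import Lantern, Publication

lantern = Lantern()
pub = Publication(
    id="10.1000/example-doi",
    title="An Example Structural Biology Paper",
    pmc="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC0000000/",
    pubmed="https://pubmed.ncbi.nlm.nih.gov/00000000/",
    doi="https://doi.org/10.1000/example-doi",
)

if lantern.publicationExists(pub.id):
    text_embeddings = lantern.get_embeddings_for_pub(pub.id)  # [(text, embedding), ...]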
2 changes: 1 addition & 1 deletion config.json
@@ -1,4 +1,4 @@
{
"Emails": [],
"emails": ["[email protected]", "[email protected]"],
"DEBUG": false
}
138 changes: 87 additions & 51 deletions document_analysis.py
@@ -1,23 +1,57 @@
import json
import re

from VectorDatabase import Lantern, Publication, Fragment
from VectorDatabase import Lantern, Publication
from google_sheets import SheetsApiClient
from prompts import get_qbi_hackathon_prompt, METHODS_KEYWORDS

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain import PromptTemplate
from datetime import date
from langchain.vectorstores import FAISS


class DocumentAnalyzer:
"""Takes in a list of publications to analyze, then prompts the chatbot, processes the response, aggregates the results,
and reports the results to the spreadsheet
"""

"""Takes in a list of publications to analyze, then prompts the chatbot, processes the response,
aggregates the results, and reports the results to the spreadsheet
"""

CONFIG_PATH = "./config.json"

def __init__(self):
self.lantern = Lantern()
self.sheets = SheetsApiClient()

self.llm = LlmHandler()

self.email_addresses, self.notification_via_email = self.parse_config()

@staticmethod
def parse_config():
try:
with open(DocumentAnalyzer.CONFIG_PATH, 'r') as config_file:
config_data = json.load(config_file)

# Extracting fields from the config_data
emails = config_data.get('emails', [])  # Default to an empty list if 'emails' is not present
debug = config_data.get('DEBUG', False)  # Default to False if 'DEBUG' is not present

return emails, debug

except FileNotFoundError:
print(f"Config file '{DocumentAnalyzer.CONFIG_PATH}' not found. Using defaults (no email addresses)")
return [], False
except json.JSONDecodeError as e:
print(f"Error decoding JSON in '{DocumentAnalyzer.CONFIG_PATH}': {e}")
return None, None
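# Note the asymmetry: a missing config file falls back to ([], False), while a
# malformed one returns (None, None); callers unpack the tuple as
# (email_addresses, notification_via_email).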

def analyze_all_unread(self):
"""pulls all new files from Lantern database, evaluates them, and publishes results to google sheets
"""
publications = self.lantern.getUnreadPublications()
self.process_publications(publications)

def process_publications(self, publications: [Publication]):
"""takes a list of publications, applies retrievalQA and processes responses
@@ -26,60 +60,62 @@ def process_publications(self, publications: [Publication]):
Args:
publications ([]): list of publications
"""

rows = []
hits = 0
for pub in publications:
text_embeddings = self.lantern.get_embeddings_for_pub(pub.id)
classification, response = 0, ''
if self.paper_about_cryoem(text_embeddings):
classification, response = self.analyze_publication(text_embeddings)
hits += classification
else:
# print('paper not about cryo-em')
pass
# add date if it's added
rows.append([pub.doi, pub.title, "", str(date.today()), "", int(classification), response, ""])
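# Row layout, presumably matching the spreadsheet's column order:
# [doi, title, (blank), date analyzed, (blank), classification (0 or 1), response, (blank)]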

self.update_spreadsheet(rows, hits)

def update_spreadsheet(self, rows: [], hits: int):
"""pushes a list of rows to the spreadsheet and notifies via email
Args:
rows ([]): rows of data to be uploaded to sheet
hits (int): number of positive classifications in the rows
"""

if hits > len(rows):
raise ValueError(f"Number of hits ({hits}) is greater than the number of rows ({len(rows)})")

#print(rows)

self.sheets.append_rows(rows)

if self.notification_via_email:
self.email(hits, len(rows))

def email(self, hits: int, total: int):
msg = f"""
This batch of paper analysis has concluded.
{total} papers were analyzed in total over the date range 11/2 - 11/3
{hits} {"were" if (hits != 1) else "was"} classified as having multi-method structural data"""

self.sheets.email(msg, self.email_addresses)

def analyze_publication(self, text_embeddings: []):
"""poses a question about the document, processes the result and returns it
NOTE: for now, only uses the hackathon question, might add more later
Args:
text_embeddings ([]): list of (text, embedding) pairs from document to be analyzed
Returns:
bool: classification of response to query as positive (True) or negative (False)
str: response from chatGPT
"""
# NOTE: These very likely need to change
open_ai_emb = OpenAIEmbeddings()
query = get_qbi_hackathon_prompt(METHODS_KEYWORDS)
faiss_index = FAISS.from_embeddings(text_embeddings=text_embeddings, embedding=open_ai_emb)
response = self.llm.evaluate_queries(faiss_index, query)[0]
return self.classify_response(response), response

@staticmethod
@@ -92,14 +128,16 @@ def classify_response(response: str):
Returns:
bool: True if answer to question is "yes"
"""
if response is None:
return False
# this was used to filter out cases where ChatGPT said "Yes, Cryo-EM was used...",
# which is wrong because we asked it about
# inclusion of non-cryo-em stuff
# if "cryo" in response.lower():
#     return (False, None)
return response.lower().startswith('yes')
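# For example (illustrative responses, not real model output):
#   classify_response("Yes. X-ray crystallography and SAXS were used.")  -> True
#   classify_response("No, none of the listed methods were used.")       -> False
#   classify_response(None)                                              -> False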

@staticmethod
def paper_about_cryoem(text_embeddings: []):
"""checks if the string "cryoem" or "cryo-em" is present in the text
@@ -110,45 +148,43 @@ def paper_about_cryoem(text_embeddings: []):
Returns:
bool: True if the text mentions cryo-em
"""
return any(re.search("cryo-?em", text, re.IGNORECASE) for text, _ in text_embeddings)


class LlmHandler:
"""pulled this straight from the hackathon code, should work though
"""Handles creation of langchain and evaluation of queries
"""

def __init__(self):
self.llm = ChatOpenAI(
temperature=0, model_name="gpt-4", max_tokens=300, request_timeout=30, max_retries=3
)

def evaluate_queries(self, embedding, queries):
chatbot = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff",
retriever=embedding.as_retriever(search_type="similarity", search_kwargs={"k": 3})
)

template = """ {query}? """
responses = []
for q in queries:
prompt = PromptTemplate(
input_variables=["query"],
template=template,
)

responses.append(chatbot.run(
prompt.format(query=q)
))
return responses


def main():
document_analyzer = DocumentAnalyzer()
#document_analyzer.analyze_all_unread() # analyzes all new files in lantern db


if __name__ == '__main__':
main()
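End to end, the script is meant to run as below (a minimal sketch, assuming a populated Lantern database, OpenAI credentials in the environment, and Google Sheets access configured for SheetsApiClient):

from document_analysis import DocumentAnalyzer

analyzer = DocumentAnalyzer()  # reads config.json, wires up Lantern, Sheets, and the LLM
analyzer.analyze_all_unread()  # classifies unread publications, appends rows, sends email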