allow any number of email addresses
AntounMichael committed Nov 12, 2023
1 parent 3bae482 commit 832ff11
Showing 3 changed files with 66 additions and 59 deletions.
94 changes: 49 additions & 45 deletions document_analysis.py
@@ -1,30 +1,35 @@
import re

from VectorDatabase import Lantern, Publication, Fragment
from VectorDatabase import Lantern, Publication
from google_sheets import SheetsApiClient
from prompts import get_qbi_hackathon_prompt
from prompts import get_qbi_hackathon_prompt, METHODS_KEYWORDS

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain import PromptTemplate
from datetime import date
from langchain.vectorstores import FAISS


class DocumentAnalyzer:
"""Takes in a list of publications to analyze, then prompts the chatbot, processes the response, aggregates the results,
and reports the results to the spreadsheet
"""
"""Takes in a list of publications to analyze, then prompts the chatbot, processes the response,
aggregates the results, and reports the results to the spreadsheet
"""

def __init__(self):
self.lantern = Lantern()
self.sheets = SheetsApiClient()
self.llm = LlmHandler()


self.email_addresses = []
self.notification_via_email = True

def analyze_all_unread(self):
"""pulls all new files from Lantern database, evaluates them, and publishes results to google sheets
"""
publications = lantern.getUnreadPublications()
publications = self.lantern.getUnreadPublications()
self.process_publications(publications)

def process_publications(self, publications: [Publication]):
@@ -34,46 +39,45 @@ def process_publications(self, publications: [Publication]):
Args:
publications ([]): list of publications
"""
query = [f"You are reading a materials and methods section of a scientific paper. Here is the list of structural biology methods {methods_string}.\n\n Did the authors use any methods from the list? \n\n Answer with Yes or No followed by the names of the methods."]

rows = []
hits = 0
for pub in publications:
text_embeddings = self.lantern.get_embeddings_for_pub(pub.id)
classification, response = 0, ''
if self.paper_about_cryoem(text_embeddings):
if self.paper_about_cryoem(text_embeddings):
classification, response = self.analyze_publication(text_embeddings)
hits += classification
else:
#print('paper not about cryo-em')
# print('paper not about cryo-em')
pass
# add date if it's added
rows.append([pub.doi, pub.title, "", str(date.today()), "", int(classification), response, ""])

self.update_spreadsheet(rows, hits)
def update_spreadsheet(rows: [], hits: int, notify=True):

def update_spreadsheet(self, rows: [], hits: int):
"""pushes a list of rows to the spreadsheet and notifies via email
Args:
rows ([]): rows of data to be uploaded to sheet
hits (int): number of positive classifications in the rows
notify (bool): notify via email if True
"""

if hits > len(rows):
raise ValueError(f"Number of hits ({hits}) is greater than the number of entries ({len(rows)}), sus")

#print(rows)

self.sheets.append_rows(rows)

if self.notification_via_email:
self.email(hits, len(rows))

def email(self, hits: int, total: int):
msg = f"""
This batch of paper analysis has concluded.
{len(rows)} papers were analyzed in total over the date range 11/2 - 11/3
{hits} {"were" if (hits != 1) else "was"} classified as having multi-method structural data"""
This batch of paper analysis has concluded.
{total} papers were analyzed in total over the date range 11/2 - 11/3
{hits} {"were" if (hits != 1) else "was"} classified as having multi-method structural data"""

if notify:
sheets.notify_arthur(message=msg)

self.sheets.email(msg, self.email_addresses)

def analyze_publication(self, text_embeddings: []):
"""poses a question about the document, processes the result and returns it
@@ -88,9 +92,9 @@ def analyze_publication(self, text_embeddings: []):
"""
# NOTE: These very likely need to change
open_ai_emb = OpenAIEmbeddings()
query = get_qbi_hackathon_prompt()
faissIndex = FAISS.from_embeddings(text_embeddings=text_embeddings, embedding=open_ai_emb)
response = self.llm.evaluate_queries(faissIndex, query)[0]
query = get_qbi_hackathon_prompt(METHODS_KEYWORDS)
faiss_index = FAISS.from_embeddings(text_embeddings=text_embeddings, embedding=open_ai_emb)
response = self.llm.evaluate_queries(faiss_index, query)[0]
return self.classify_response(response), response

@staticmethod
@@ -103,14 +107,16 @@ def classify_response(response: str):
Returns:
bool: True if answer to question is "yes"
"""
if result == None:
if response is None:
return False
# this was used to filter out cases where ChatGPT said "Yes, Cryo-EM was used..." which is wrong because we asked it about
# this was used to filter out cases where ChatGPT said "Yes, Cryo-EM was used...",
# which is wrong because we asked it about
# inclusion of non-cryo-em stuff
#if "cryo" in response.lower():
#
# if "cryo" in response.lower():
# return (False, None)
return response.lower().startswith('yes')

@staticmethod
def paper_about_cryoem(text_embeddings: []):
"""checks if the string "cryoem" or "cryo-em" is present in the text
@@ -121,26 +127,25 @@ def paper_about_cryoem(text_embeddings: []):
Returns:
bool: True if the text mentions cryo-em
"""
return any(re.search("cryo-?em", text, re.IGNORECASE) for text, _ in embeddings)
return any(re.search("cryo-?em", text, re.IGNORECASE) for text, _ in text_embeddings)


class LlmHandler:
"""pulled this straight from the hackathon code, should work though
"""Handles creation of langchain and evaluation of queries
"""

def __init__(self):
self.llm=ChatOpenAI(
temperature=0, model_name="gpt-4", max_tokens=300, request_timeout = 30, max_retries=3
)


self.llm = ChatOpenAI(
temperature=0, model_name="gpt-4", max_tokens=300, request_timeout=30, max_retries=3
)

def evaluate_queries(self, embedding, queries):
chatbot = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff",
retriever=embedding.as_retriever(search_type="similarity", search_kwargs={"k":3})
llm=self.llm,
chain_type="stuff",
retriever=embedding.as_retriever(search_type="similarity", search_kwargs={"k": 3})
)

template = """ {query}? """
responses = []
for q in queries:
@@ -155,11 +160,10 @@ def evaluate_queries(self, embedding, queries):
return responses




def main():
document_analyzer = DocumentAnalyzer()
document_analyzer.analyze_all_unread() #analyzes all new files in lantern db
document_analyzer.analyze_all_unread() # analyzes all new files in lantern db


if __name__ == '__main__':
main()
main()
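
With this change, DocumentAnalyzer keeps a list of recipient addresses instead of notifying a single hard-coded account. A minimal usage sketch of the new configuration (the addresses below are placeholders; Lantern, OpenAI, and Google Sheets credentials are assumed to be set up as elsewhere in this repo):

from document_analysis import DocumentAnalyzer

analyzer = DocumentAnalyzer()
# Any number of recipients can now receive the summary email;
# set notification_via_email to False to skip the email step entirely.
analyzer.email_addresses = ["first.reader@example.org", "second.reader@example.org"]
analyzer.analyze_all_unread()
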
29 changes: 16 additions & 13 deletions google_sheets.py
@@ -1,6 +1,6 @@
import os
import gspread
import typing


class SheetsApiClient:
"""interface for all functionality with google sheets
@@ -20,16 +20,17 @@ class SheetsApiClient:
]

def __init__(self):
self.connect()
self.client = self.connect()
self.spreadsheet = self.client.open(type(self).SPREADSHEET_NAME)
self.worksheet = self.spreadsheet.get_worksheet(0)

def connect(self):
@staticmethod
def connect():
"""connects to Google Sheets API service using private key file
"""
try:
secret_file = os.path.join(os.getcwd(), "google_sheets_credentials.json")
self.client = gspread.service_account(secret_file)
return gspread.service_account(secret_file)
except OSError as e:
print(e)

@@ -48,18 +49,20 @@ def append_rows(self, rows: [[str]]):
self._check_row(row)
self.worksheet.append_rows(rows)

def notify_arthur(self, message: str):
def email(self, message: str, email_addresses: [str]):
"""Shares the spreadsheet with arthur, along with the message in an email
Args:
message (str):
message (str): message to be sent
email_addresses ([str]): recipients of notification
"""
self.spreadsheet.share(
"[email protected]",
perm_type="user",
role="writer",
notify=True,
email_message=message,
)
for email_address in email_addresses:
self.spreadsheet.share(
email_address,
perm_type="user",
role="writer",
notify=True,
email_message=message,
)

@staticmethod
def _check_row(row: []):
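
On the sheets side, email() now takes the recipient list as an argument and shares the spreadsheet with (and notifies) each address in turn, replacing the single-recipient notify_arthur(). A minimal usage sketch (the message and addresses are placeholders; the google_sheets_credentials.json service-account file is assumed to be present, as in connect()):

from google_sheets import SheetsApiClient

sheets = SheetsApiClient()
# Each address is added to the spreadsheet as a writer and receives the
# message in the Google-generated notification email.
sheets.email(
    "This batch of paper analysis has concluded.",
    ["first.reader@example.org", "second.reader@example.org"],
)
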
2 changes: 1 addition & 1 deletion hackathon_runner.py
@@ -260,7 +260,7 @@ def main():
{hits} {"were" if ((hits>0) or (hits == 0)) else was} classified as having multi-method structural data
"""
print(msg)
gs.notify_arthur(message=msg)
gs.email(message=msg)


main()
