diff --git a/VectorDatabase.py b/VectorDatabase.py
index f75485b..f924c0d 100644
--- a/VectorDatabase.py
+++ b/VectorDatabase.py
@@ -1,5 +1,39 @@
 import psycopg2
-from database_entities import Fragment, Publication
+
+
+# Class to represent a publication with attributes id, title, pmc, pubmed, and doi
+class Publication:
+    id = ""
+    title = ""
+    pmc = ""
+    pubmed = ""
+    doi = ""
+
+    def __init__(self, id, title, pmc, pubmed, doi):
+        self.id = id  # (DOI) Unique identifier for the publication
+        self.title = title  # Title of the publication
+        self.pmc = pmc  # PubMed Central (PMC) link
+        self.pubmed = pubmed  # PubMed link
+        self.doi = doi  # Digital Object Identifier (DOI) link for the publication
+
+
+# Class to represent a fragment of a publication with attributes id, header, content, and vector
+class Fragment:
+    # Class variables to store default values for attributes
+    id = ""
+    header = ""
+    content = ""
+    vector = ""
+
+    def __init__(self, id, header, content, vector):
+        # Set the attributes of the object with the values provided during instantiation
+        self.id = id  # (DOI) Unique identifier of the publication this fragment belongs to
+        self.header = header  # Header or title of the fragment
+        self.content = content  # Content or text of the fragment
+        self.vector = vector  # Vector representation of the fragment
 
 
 # Lantern class that exposes functionality of database to application
 class Lantern:
@@ -249,7 +283,7 @@ def getUnreadPublications(self, delete_unread_entries=True):
 
         if delete_unread_entries:
             cursor.execute('DELETE FROM unread;')
-        
+
         conn.commit()
         cursor.close()
@@ -292,49 +326,15 @@ def publicationExists(self, id):
         - [(text, embedding)] content of a publication's embeddings
         Notes:
         """
+
     def get_embeddings_for_pub(self, id):
         texts = []
         embeddings = []
         if not self.publicationExists(id):
-            return 
+            return []  # no such publication; return an empty list so callers can safely iterate
         fragments = self.getAllFragmentsOfPublication(id)
        for fragment in fragments:
             texts.append(fragment.content)
             embeddings.append(fragment.vector)
         text_embeddings = list(zip(texts, embeddings))
         return text_embeddings
-
-
-# Class to represent a publication with attributes id, title, pmc, pubmed, and doi
-class Publication:
-
-    id = ""
-    title = ""
-    pmc = ""
-    pubmed = ""
-    doi = ""
-
-    def __init__(self, id, title, pmc, pubmed, doi):
-        self.id = id  # (DOI) Unique identifier for the publication
-        self.title = title  # Title of the publication
-        self.pmc = pmc  # PubMed Central (PMC) Link
-        self.pubmed = pubmed  # PubMed Link
-        self.doi = doi  # Digital Object Identifier (DOI) Link for the publication
-
-
-# Class to represent a fragment of a publication with attributes id, header, content, and vector
-class Fragment:
-
-    # Class variables to store default values for attributes
-    id = ""
-    header = ""
-    content = ""
-    vector = ""
-
-    def __init__(self, id, header, content, vector):
-        # Constructor to initialize the attributes of the Fragment object
-
-        # Set the attributes of the object with the values provided during instantiation
-        self.id = id  # (DOI) Unique identifier for the fragment
-        self.header = header  # Header or title of the fragment
-        self.content = content  # Content or text of the fragment
-        self.vector = vector  # Vector representation of the fragment
diff --git a/config.json b/config.json
index 0a315c3..f40fb96 100644
--- a/config.json
+++ b/config.json
@@ -1,4 +1,4 @@
 {
-    "Emails": [],
+    "emails": ["aozalevsky@gmail.com", "steveurkel99@gmail.com"],
     "DEBUG": false
 }
\ No newline at end of file
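Note on the config change above: the new keys are consumed by DocumentAnalyzer.parse_config in document_analysis.py below. "emails" lists the notification recipients, and the DEBUG flag is what parse_config returns as the email-notification switch. A minimal sketch of the expected contract (path and defaults as in this patch; the assert is purely illustrative):

    import json

    with open("./config.json", "r") as config_file:
        config_data = json.load(config_file)

    emails = config_data.get("emails", [])   # notification recipients, e.g. ["aozalevsky@gmail.com"]
    notify = config_data.get("DEBUG", False) # currently doubles as the email-notification switch

    assert isinstance(emails, list) and isinstance(notify, bool)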
diff --git a/document_analysis.py b/document_analysis.py
index 573c149..0326e24 100644
--- a/document_analysis.py
+++ b/document_analysis.py
@@ -1,23 +1,57 @@
+import json
+import re
 
-from VectorDatabase import Lantern, Publication, Fragment
+from VectorDatabase import Lantern, Publication
 from google_sheets import SheetsApiClient
+from prompts import get_qbi_hackathon_prompt, METHODS_KEYWORDS
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chat_models import ChatOpenAI
 from langchain.chains import RetrievalQA
+from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain import PromptTemplate
 from datetime import date
+from langchain.vectorstores import FAISS
 
 
 class DocumentAnalyzer:
-    """Takes in a list of publications to analyze, then prompts the chatbot, processes the response, aggregates the results,
-    and reports the results to the spreadsheet
+    """Takes in a list of publications to analyze, then prompts the chatbot, processes the response,
+    aggregates the results, and reports the results to the spreadsheet
     """
 
+    CONFIG_PATH = "./config.json"
+
     def __init__(self):
-        # self.lantern = Lantern()
+        self.lantern = Lantern()
         self.sheets = SheetsApiClient()
+        self.llm = LlmHandler()
+
+        self.email_addresses, self.notification_via_email = self.parse_config()
+
+    @staticmethod
+    def parse_config():
+        try:
+            with open(DocumentAnalyzer.CONFIG_PATH, 'r') as config_file:
+                config_data = json.load(config_file)
+
+            # Extract the fields from the config data
+            emails = config_data.get('emails', [])  # default to an empty list if 'emails' is not present
+            # NOTE: the DEBUG flag currently doubles as the email-notification switch
+            notify = config_data.get('DEBUG', False)  # default to False if 'DEBUG' is not present
+
+            return emails, notify
+
+        except FileNotFoundError:
+            print(f"Config file '{DocumentAnalyzer.CONFIG_PATH}' not found. Using defaults (no email addresses)")
+            return [], False
+        except json.JSONDecodeError as e:
+            print(f"Error decoding JSON in '{DocumentAnalyzer.CONFIG_PATH}': {e}. Using defaults (no email addresses)")
+            return [], False
+
+    def analyze_all_unread(self):
+        """pulls all new files from Lantern database, evaluates them, and publishes results to google sheets
+        """
+        publications = self.lantern.getUnreadPublications()
+        self.process_publications(publications)
 
     def process_publications(self, publications: [Publication]):
         """takes a list of publications, applies retrievalQA and processes responses
@@ -26,60 +60,62 @@ def process_publications(self, publications: [Publication]):
         Args:
             publications ([]): list of publications
         """
-        query = [f"You are reading a materials and methods section of a scientific paper. Here is the list of structural biology methods {methods_string}.\n\n Did the authors use any methods from the list? \n\n Answer with Yes or No followed by the names of the methods."]
-
         rows = []
         hits = 0
         for pub in publications:
             text_embeddings = self.lantern.get_embeddings_for_pub(pub.id)
             classification, response = 0, ''
-            if self.paper_about_cryoem(text_embeddings): 
+            if self.paper_about_cryoem(text_embeddings):
                 classification, response = self.analyze_publication(text_embeddings)
                 hits += classification
             else:
-                #print('paper not about cryo-em')
+                # print('paper not about cryo-em')
                 pass
             # add date if it's added
             rows.append([pub.doi, pub.title, "", str(date.today()), "", int(classification), response, ""])
         self.update_spreadsheet(rows, hits)
 
-    def update_spreadsheet(rows: [], hits: int, notify=True):
-        """pushes a list of rows to the spreadsheet and notifies via email
+    def update_spreadsheet(self, rows: [], hits: int):
+        """pushes a list of rows to the spreadsheet and, if configured, notifies via email
 
         Args:
             rows ([]): rows of data to be uploaded to sheet
             hits (int): number of positive classifications in the rows
-            notify (bool): notify via email if True
         """
         if hits > len(rows):
             raise ValueError(f"Number of hits ({hits}) is greater than the number of entries ({len(rows)}), sus")
 
-        #print(rows)
+        self.sheets.append_rows(rows)
+
+        if self.notification_via_email:
+            self.email(hits, len(rows))
+
+    def email(self, hits: int, total: int):
+        """notifies the configured recipients with a summary of the analysis
+
+        Args:
+            hits (int): number of positive classifications
+            total (int): total number of papers analyzed
+        """
         msg = f"""
-        This batch of paper analysis has concluded. 
-        {len(rows)} papers were analyzed in total over the date range 11/2 - 11/3
-        {hits} {"were" if (hits != 1) else "was"} classified as having multi-method structural data"""
+This batch of paper analysis has concluded.
+{total} papers were analyzed in total over the date range 11/2 - 11/3
+{hits} {"were" if (hits != 1) else "was"} classified as having multi-method structural data"""
 
-        if notify:
-            sheets.notify_arthur(message=msg)
+        self.sheets.email(msg, self.email_addresses)
 
-    def analyze_publication(self, publication: Publication):
-        """leaving this blank for now because i think the way these are stored is changing
+    def analyze_publication(self, text_embeddings: []):
+        """poses a question about the document, processes the result and returns it
+        NOTE: for now, only uses the hackathon question, might add more later
 
         Args:
-            publication (Publication): publication to be analyzed
+            text_embeddings ([]): list of (text, embedding) pairs from the document to be analyzed
 
         Returns:
             bool: classification of response to query as positive (True) or negative (False)
             str: response from chatGPT
         """
-        #faissIndex = FAISS.from_embeddings(text_embeddings=text_embeddings, embedding=open_ai_emb)
-        #result = llm.evaluate_queries(faissIndex, query)
-        response = None
+        open_ai_emb = OpenAIEmbeddings()
+        # evaluate_queries expects a list of queries; this assumes get_qbi_hackathon_prompt returns a single prompt string
+        queries = [get_qbi_hackathon_prompt(METHODS_KEYWORDS)]
+        faiss_index = FAISS.from_embeddings(text_embeddings=text_embeddings, embedding=open_ai_emb)
+        response = self.llm.evaluate_queries(faiss_index, queries)[0]
         return self.classify_response(response), response
 
     @staticmethod
@@ -92,14 +128,16 @@ def classify_response(response: str):
         Returns:
             bool: True if answer to question is "yes"
         """
-        if result == None:
+        if response is None:
             return False
-        # this was used to filter out cases where ChatGPT said "Yes, Cryo-EM was used..." which is wrong because we asked it about
+        # this was used to filter out cases where ChatGPT said "Yes, Cryo-EM was used...",
+        # which is wrong because we asked it about
         # inclusion of non-cryo-em stuff
-        #if "cryo" in response.lower():
+        #
+        # if "cryo" in response.lower():
         #    return (False, None)
         return response.lower().startswith('yes')
 
     @staticmethod
     def paper_about_cryoem(text_embeddings: []):
         """checks if the string "cryoem" or "cryo-em" is present in the text
@@ -110,45 +148,43 @@ def paper_about_cryoem(text_embeddings: []):
 
         Returns:
             bool: True if the text mentions cryo-em
         """
-        return any(re.search("cryo-?em", text, re.IGNORECASE) for text, _ in embeddings)
+        return any(re.search("cryo-?em", text, re.IGNORECASE) for text, _ in text_embeddings)
 
 
 class LlmHandler:
-    """pulled this straight from the hackathon code, should work though
+    """Handles creation of the langchain retrieval chain and evaluation of queries
     """
 
     def __init__(self):
-        self.llm=ChatOpenAI(
-            temperature=0, model_name="gpt-4", max_tokens=300, request_timeout = 30, max_retries=3
-        )
-
-
+        self.llm = ChatOpenAI(
+            temperature=0, model_name="gpt-4", max_tokens=300, request_timeout=30, max_retries=3
+        )
+
     def evaluate_queries(self, embedding, queries):
         chatbot = RetrievalQA.from_chain_type(
-            llm=self.llm, 
-            chain_type="stuff", 
-            retriever=embedding.as_retriever(search_type="similarity", search_kwargs={"k":3})
+            llm=self.llm,
+            chain_type="stuff",
+            retriever=embedding.as_retriever(search_type="similarity", search_kwargs={"k": 3})
         )
-        
+
         template = """ {query}? """
 
-        response = []
+        responses = []
         for q in queries:
             prompt = PromptTemplate(
                 input_variables=["query"],
                 template=template,
             )
-            response.append(chatbot.run(
+            responses.append(chatbot.run(
                 prompt.format(query=q)
             ))
-        return response
-
-
+        return responses
+
 
 def main():
-    x = DocumentAnalyzer()
-    l = LlmHandler()
+    document_analyzer = DocumentAnalyzer()
+    # document_analyzer.analyze_all_unread()  # analyzes all new files in lantern db
+
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
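Note on the analysis flow above: analyze_publication builds a FAISS index from precomputed (text, embedding) pairs and queries it through RetrievalQA. A minimal standalone sketch of that flow, using the same langchain calls as the patch; it assumes OPENAI_API_KEY is set in the environment, and the sample text is illustrative:

    from langchain.chat_models import ChatOpenAI
    from langchain.chains import RetrievalQA
    from langchain.embeddings.openai import OpenAIEmbeddings
    from langchain.vectorstores import FAISS

    emb = OpenAIEmbeddings()
    texts = ["Methods: cryo-EM maps were collected and X-ray crystallography was used ..."]
    pairs = list(zip(texts, emb.embed_documents(texts)))  # (text, embedding) pairs, the shape Lantern stores

    index = FAISS.from_embeddings(text_embeddings=pairs, embedding=emb)
    qa = RetrievalQA.from_chain_type(
        llm=ChatOpenAI(temperature=0, model_name="gpt-4", max_tokens=300),
        chain_type="stuff",
        retriever=index.as_retriever(search_type="similarity", search_kwargs={"k": 3}),
    )
    print(qa.run("Did the authors use any structural biology methods? Answer with Yes or No."))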
perm_type="user", - role="writer", - notify=True, - email_message=message, - ) + for email_address in email_addresses: + self.spreadsheet.share( + email_address, + perm_type="user", + role="reader", + notify=True, + email_message=message, + ) @staticmethod def _check_row(row: []): diff --git a/hackathon_runner.py b/hackathon_runner.py index 761c6e7..27b8319 100644 --- a/hackathon_runner.py +++ b/hackathon_runner.py @@ -260,7 +260,7 @@ def main(): {hits} {"were" if ((hits>0) or (hits == 0)) else was} classified as having multi-method structural data """ print(msg) - gs.notify_arthur(message=msg) + gs.email(message=msg) main()