diff --git a/biochatter/api_agent/blast.py b/biochatter/api_agent/blast.py index 1eb592b7..cead731b 100644 --- a/biochatter/api_agent/blast.py +++ b/biochatter/api_agent/blast.py @@ -1,7 +1,14 @@ +"""Module for handling BLAST API interactions. + +Provides functionality for building queries, fetching results, and interpreting +BLAST (Basic Local Alignment Search Tool) sequence alignment data. +""" + import re import time import uuid from collections.abc import Callable +from typing import TYPE_CHECKING from urllib.parse import urlencode import requests @@ -9,14 +16,15 @@ from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_core.pydantic_v1 import BaseModel, Field -from pydantic import BaseModel, Field -from biochatter.llm_connect import Conversation +if TYPE_CHECKING: + from biochatter.llm_connect import Conversation from .abc import BaseFetcher, BaseInterpreter, BaseQueryBuilder BLAST_QUERY_PROMPT = """ -You are a world class algorithm for creating queries in structured formats. Your task is to use NCBI Web APIs to answer genomic questions. +You are a world class algorithm for creating queries in structured formats. Your task is to use NCBI Web APIs to answer +genomic questions. For questions about DNA sequences (other than genome alignments) you can use BLAST by: "[https://blast.ncbi.nlm.nih.gov/blast/Blast.cgi?CMD={Put|Get}&PROGRAM=blastn&MEGABLAST=on&DATABASE=nt&FORMAT_TYPE={XML|Text}&QUERY={sequence}&HITLIST_SIZE={max_hit_size}]". BLAST maps a specific DNA {sequence} to DNA sequences among different specices. @@ -29,23 +37,28 @@ For questions about protein sequences you can also use BLAST, but this time by: [https://blast.ncbi.nlm.nih.gov/Blast.cgi?CMD=Put&PROGRAM=blastp&DATABASE=nr&FORMAT_TYPE=XML&QUERY=sequence&HITLIST_SIZE=max_hit_size] Example Question 2: What do you find about protein sequence: MEEPQSDPSV Use BLASTp for such a question -> https://blast.ncbi.nlm.nih.gov/Blast.cgi?CMD=Put&PROGRAM=blastp&DATABASE=nr&FORMAT_TYPE=XML&QUERY=MEEPQSDPSV&HITLIST_SIZE=10" -""" +""" # noqa: E501 BLAST_SUMMARY_PROMPT = """ - You have to answer this question in a clear and concise manner: {question} Be factual!\n\ - If you are asked what organism a specific sequence belongs to, check the 'Hit_def' fields. If you find a synthetic construct or predicted entry, move to the next one and look for an organism name.\n\ - Try to use the hits with the best identity score to answer the question. If it is not possible, move to the next one.\n\ - Be clear, and if organism names are present in ANY of the results, please include them in the answer. Do not make up information and mention how relevant the found information is based on the identity scores.\n\ - Use the same reasoning for any potential BLAST results. If you find information that is manually curated, please use it and state it. You may also state other results, but always include the context.\n\ - Based on the information given here:\n\ - {context} - """ +You have to answer this question in a clear and concise manner: {question} Be factual!\n\ +If you are asked what organism a specific sequence belongs to, check the 'Hit_def' fields. If you find a synthetic +construct or predicted entry, move to the next one and look for an organism name.\n\ +Try to use the hits with the best identity score to answer the question. If it is not possible, move to the next one.\n\ +Be clear, and if organism names are present in ANY of the results, please include them in the answer. Do not make up +information and mention how relevant the found information is based on the identity scores.\n\ +Use the same reasoning for any potential BLAST results. If you find information that is manually curated, please use it +and state it. You may also state other results, but always include the context.\n\ +Based on the information given here:\n\ +{context} +""" class BlastQueryParameters(BaseModel): - """BlastQuery is a Pydantic model for the parameters of a BLAST query request, - used for configuring and sending a request to the NCBI BLAST query API. The - fields are dynamically configured by the LLM based on the user's question. + """Pydantic model for the parameters of a BLAST query request. + + The class is used for configuring and sending a request to the NCBI BLAST + query API. The fields are dynamically configured by the LLM based on the + user's question. """ @@ -59,15 +72,27 @@ class BlastQueryParameters(BaseModel): ) program: str | None = Field( default="blastn", - description="BLAST program to use, e.g., 'blastn' for nucleotide-nucleotide BLAST, 'blastp' for protein-protein BLAST.", + description=( + "BLAST program to use, e.g., 'blastn' for nucleotide-nucleotide BLAST, " + "'blastp' for protein-protein BLAST." + ), ) database: str | None = Field( default="nt", - description="Database to search, e.g., 'nt' for nucleotide database, 'nr' for non redundant protein database, pdb the Protein Data Bank database, which is used specifically for protein structures, 'refseq_rna' and 'refseq_genomic': specialized databases for RNA sequences and genomic sequences", + description=( + "Database to search, e.g., 'nt' for nucleotide database, 'nr' for " + "non redundant protein database, 'pdb' the Protein Data Bank " + "database, which is used specifically for protein structures, " + "'refseq_rna' and 'refseq_genomic': specialized databases for " + "RNA sequences and genomic sequences" + ), ) query: str | None = Field( None, - description="Nucleotide or protein sequence for the BLAST or blat query, make sure to always keep the entire sequence given.", + description=( + "Nucleotide or protein sequence for the BLAST or blat query, " + "make sure to always keep the entire sequence given." + ), ) format_type: str | None = Field( default="Text", @@ -111,7 +136,9 @@ def create_runnable( query_parameters: "BlastQueryParameters", conversation: "Conversation", ) -> Callable: - """Creates a runnable object for executing queries using the LangChain + """Create a runnable object for executing queries. + + Creates a runnable using the LangChain `create_structured_output_runnable` method. Args: @@ -137,7 +164,9 @@ def parameterise_query( question: str, conversation: "Conversation", ) -> BlastQueryParameters: - """Generates a BlastQuery object based on the given question, prompt, and + """Generate a BlastQuery object. + + Generates the object based on the given question, prompt, and BioChatter conversation. Uses a Pydantic model to define the API fields. Creates a runnable that can be invoked on LLMs that are qualified to parameterise functions. @@ -166,15 +195,17 @@ def parameterise_query( class BlastFetcher(BaseFetcher): - """A class for retrieving API results from BLAST given a parameterised - BlastQuery. + """A class for retrieving API results from BLAST. + + Retrieves results from BLAST given a parameterised BlastQuery. TODO add a limit of characters to be returned from the response.text? """ def _submit_query(self, request_data: BlastQueryParameters) -> str: - """Function to POST the BLAST query and retrieve RID. - It submits the structured BlastQuery obj and return the RID. + """POST the BLAST query and retrieve the RID. + + The method submits the structured BlastQuery object and returns the RID. Args: ---- @@ -205,29 +236,30 @@ def _submit_query(self, request_data: BlastQueryParameters) -> str: # Print the full URL request_data.full_url = full_url print("Full URL built by retriever:\n", request_data.full_url) - response = requests.post(request_data.url, data=data) + response = requests.post(request_data.url, data=data, timeout=10) response.raise_for_status() # Extract RID from response print(response) match = re.search(r"RID = (\w+)", response.text) if match: return match.group(1) - else: - raise ValueError("RID not found in BLAST submission response.") + + msg = "RID not found in BLAST submission response." + raise ValueError(msg) def _fetch_results( self, rid: str, question_uuid: str, retries: int = 10000, - ): - """SECOND function to be called for a BLAST query. - Will look for the RID to fetch the data + ) -> str: + """Fetch BLAST query data given RID. + + The second function to be called for a BLAST query. """ ### ### TO DO: Implement logging for all BLAST queries ### - # log_question_uuid_json(request_data.question_uuid,question, file_name, log_file_path,request_data.full_url) base_url = "https://blast.ncbi.nlm.nih.gov/Blast.cgi" check_status_params = { "CMD": "Get", @@ -242,7 +274,7 @@ def _fetch_results( # Check the status of the BLAST job for attempt in range(retries): - status_response = requests.get(base_url, params=check_status_params) + status_response = requests.get(base_url, params=check_status_params, timeout=10) status_response.raise_for_status() status_text = status_response.text print("evaluating status") @@ -250,33 +282,35 @@ def _fetch_results( print(f"{question_uuid} results not ready, waiting...") time.sleep(15) elif "Status=FAILED" in status_text: - raise RuntimeError("BLAST query FAILED.") + msg = "BLAST query FAILED." + raise RuntimeError(msg) elif "Status=UNKNOWN" in status_text: - raise RuntimeError("BLAST query expired or does not exist.") + msg = "BLAST query expired or does not exist." + raise RuntimeError(msg) elif "Status=READY" in status_text: if "ThereAreHits=yes" in status_text: print(f"{question_uuid} results are ready, retrieving.") results_response = requests.get( base_url, params=get_results_params, + timeout=10, ) results_response.raise_for_status() - # Save the results to a file return results_response.text - else: - return "No hits found" - if attempt == retries - 1: - raise TimeoutError( - "Maximum attempts reached. Results may not be ready.", - ) + return "No hits found" + if attempt == retries - 1: + msg = "Maximum attempts reached. Results may not be ready." + raise TimeoutError(msg) + return None def fetch_results( self, query_model: BlastQueryParameters, retries: int = 20, ) -> str: - """Submit request and fetch results from BLAST API. Wraps individual - submission and retrieval of results. + """Submit request and fetch results from BLAST API. + + Wraps individual submission and retrieval of results. Args: ---- @@ -298,13 +332,15 @@ def fetch_results( class BlastInterpreter(BaseInterpreter): + """A class for interpreting BLAST results.""" + def summarise_results( self, question: str, conversation_factory: Callable, response_text: str, ) -> str: - """Function to extract the answer from the BLAST results. + """Extract the answer from the BLAST results. Args: ---- @@ -333,5 +369,4 @@ def summarise_results( output_parser = StrOutputParser() conversation = conversation_factory() chain = prompt | conversation.chat | output_parser - answer = chain.invoke({"input": {summary_prompt}}) - return answer + return chain.invoke({"input": {summary_prompt}})