diff --git a/404.html b/404.html index 0fd1f6ac..a55da65d 100644 --- a/404.html +++ b/404.html @@ -12,7 +12,7 @@ - + diff --git a/api-docs/api-calling/index.html b/api-docs/api-calling/index.html index 57aaf533..845aece3 100644 --- a/api-docs/api-calling/index.html +++ b/api-docs/api-calling/index.html @@ -18,7 +18,7 @@ - + @@ -1155,6 +1155,21 @@ + + @@ -1239,6 +1254,21 @@ + +
  • @@ -1632,6 +1662,21 @@ + +
  • @@ -1716,6 +1761,21 @@ + +
  • @@ -2048,7 +2108,10 @@

    Source code in biochatter/api_agent/abc.py -
     83
    +                
     80
    + 81
    + 82
    + 83
      84
      85
      86
    @@ -2068,35 +2131,30 @@ 

    100 101 102 -103 -104 -105 -106 -107

    class BaseFetcher(ABC):
    -    """
    -    Abstract base class for fetchers. A fetcher is responsible for submitting
    -    queries (in systems where submission and fetching are separate) and fetching
    -    and saving results of queries. It has to implement a `fetch_results()`
    -    method, which can wrap a multi-step procedure to submit and retrieve. Should
    -    implement retry method to account for connectivity issues or processing
    -    times.
    -    """
    -
    -    @abstractmethod
    -    def fetch_results(
    -        self,
    -        query_model: BaseModel,
    -        retries: Optional[int] = 3,
    -    ):
    -        """
    -        Fetches results by submitting a query. Can implement a multi-step
    -        procedure if submitting and fetching are distinct processes (e.g., in
    -        the case of long processing times as in the case of BLAST).
    -
    -        Args:
    -            query_model: the Pydantic model describing the parameterised query
    -        """
    -        pass
    +103
    class BaseFetcher(ABC):
    +    """Abstract base class for fetchers. A fetcher is responsible for submitting
    +    queries (in systems where submission and fetching are separate) and fetching
    +    and saving results of queries. It has to implement a `fetch_results()`
    +    method, which can wrap a multi-step procedure to submit and retrieve. Should
    +    implement retry method to account for connectivity issues or processing
    +    times.
    +    """
    +
    +    @abstractmethod
    +    def fetch_results(
    +        self,
    +        query_model: BaseModel,
    +        retries: int | None = 3,
    +    ):
    +        """Fetches results by submitting a query. Can implement a multi-step
    +        procedure if submitting and fetching are distinct processes (e.g., in
    +        the case of long processing times as in the case of BLAST).
    +
    +        Args:
    +        ----
    +            query_model: the Pydantic model describing the parameterised query
    +
    +        """
     
    @@ -2130,41 +2188,17 @@

    Parameters:

    - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - query_model - - BaseModel - -
    -

    the Pydantic model describing the parameterised query

    -
    -
    - required -
    +
    +
    query_model: the Pydantic model describing the parameterised query
    +
    Source code in biochatter/api_agent/abc.py -
     93
    +              
     89
    + 90
    + 91
    + 92
    + 93
      94
      95
      96
    @@ -2174,25 +2208,21 @@ 

    100 101 102 -103 -104 -105 -106 -107

    @abstractmethod
    -def fetch_results(
    -    self,
    -    query_model: BaseModel,
    -    retries: Optional[int] = 3,
    -):
    -    """
    -    Fetches results by submitting a query. Can implement a multi-step
    -    procedure if submitting and fetching are distinct processes (e.g., in
    -    the case of long processing times as in the case of BLAST).
    -
    -    Args:
    -        query_model: the Pydantic model describing the parameterised query
    -    """
    -    pass
    +103
    @abstractmethod
    +def fetch_results(
    +    self,
    +    query_model: BaseModel,
    +    retries: int | None = 3,
    +):
    +    """Fetches results by submitting a query. Can implement a multi-step
    +    procedure if submitting and fetching are distinct processes (e.g., in
    +    the case of long processing times as in the case of BLAST).
    +
    +    Args:
    +    ----
    +        query_model: the Pydantic model describing the parameterised query
    +
    +    """
     
    @@ -2234,7 +2264,11 @@

    Source code in biochatter/api_agent/abc.py -
    110
    +                
    106
    +107
    +108
    +109
    +110
     111
     112
     113
    @@ -2264,43 +2298,41 @@ 

    137 138 139 -140 -141 -142 -143

    class BaseInterpreter(ABC):
    -    """
    -    Abstract base class for result interpreters. The interpreter is aware of the
    -    nature and structure of the results and can extract and summarise
    -    information from them.
    -    """
    -
    -    @abstractmethod
    -    def summarise_results(
    -        self,
    -        question: str,
    -        conversation_factory: Callable,
    -        response_text: str,
    -    ) -> str:
    -        """
    -        Summarises an answer based on the given parameters.
    -
    -        Args:
    -            question (str): The question that was asked.
    +140
    class BaseInterpreter(ABC):
    +    """Abstract base class for result interpreters. The interpreter is aware of the
    +    nature and structure of the results and can extract and summarise
    +    information from them.
    +    """
    +
    +    @abstractmethod
    +    def summarise_results(
    +        self,
    +        question: str,
    +        conversation_factory: Callable,
    +        response_text: str,
    +    ) -> str:
    +        """Summarises an answer based on the given parameters.
    +
    +        Args:
    +        ----
    +            question (str): The question that was asked.
    +
    +            conversation_factory (Callable): A function that creates a
    +                BioChatter conversation.
    +
    +            response_text (str): The response.text returned from the request.
     
    -            conversation_factory (Callable): A function that creates a
    -                BioChatter conversation.
    -
    -            response_text (str): The response.text returned from the request.
    -
    -        Returns:
    -            A summary of the answer.
    -
    -        Todo:
    -            Genericise (remove file path and n_lines parameters, and use a
    -            generic way to get the results). The child classes should manage the
    -            specifics of the results.
    -        """
    -        pass
    +        Returns:
    +        -------
    +            A summary of the answer.
    +
    +        Todo:
    +        ----
    +            Genericise (remove file path and n_lines parameters, and use a
    +            generic way to get the results). The child classes should manage the
    +            specifics of the results.
    +
    +        """
     
    @@ -2332,104 +2364,31 @@

    Summarises an answer based on the given parameters.

    - - -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - question - - str - -
    -

    The question that was asked.

    -
    -
    - required -
    - conversation_factory - - Callable - -
    -

    A function that creates a -BioChatter conversation.

    -
    -
    - required -
    - response_text - - str - -
    -

    The response.text returned from the request.

    -
    -
    - required -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    TypeDescription
    - str - -
    -

    A summary of the answer.

    -
    -
    - - -
    - Todo -

    Genericise (remove file path and n_lines parameters, and use a +


    +
    question (str): The question that was asked.
    +
    +conversation_factory (Callable): A function that creates a
    +    BioChatter conversation.
    +
    +response_text (str): The response.text returned from the request.
    +
    +
    +
    A summary of the answer.
    +
    +
    Todo:
    +
    Genericise (remove file path and n_lines parameters, and use a
     generic way to get the results). The child classes should manage the
    -specifics of the results.

    -
    +specifics of the results. +

    +
    Source code in biochatter/api_agent/abc.py -
    117
    +              
    112
    +113
    +114
    +115
    +116
    +117
     118
     119
     120
    @@ -2452,36 +2411,35 @@ 

    137 138 139 -140 -141 -142 -143

    @abstractmethod
    -def summarise_results(
    -    self,
    -    question: str,
    -    conversation_factory: Callable,
    -    response_text: str,
    -) -> str:
    -    """
    -    Summarises an answer based on the given parameters.
    -
    -    Args:
    -        question (str): The question that was asked.
    +140
    @abstractmethod
    +def summarise_results(
    +    self,
    +    question: str,
    +    conversation_factory: Callable,
    +    response_text: str,
    +) -> str:
    +    """Summarises an answer based on the given parameters.
    +
    +    Args:
    +    ----
    +        question (str): The question that was asked.
    +
    +        conversation_factory (Callable): A function that creates a
    +            BioChatter conversation.
    +
    +        response_text (str): The response.text returned from the request.
     
    -        conversation_factory (Callable): A function that creates a
    -            BioChatter conversation.
    -
    -        response_text (str): The response.text returned from the request.
    -
    -    Returns:
    -        A summary of the answer.
    -
    -    Todo:
    -        Genericise (remove file path and n_lines parameters, and use a
    -        generic way to get the results). The child classes should manage the
    -        specifics of the results.
    -    """
    -    pass
    +    Returns:
    +    -------
    +        A summary of the answer.
    +
    +    Todo:
    +    ----
    +        Genericise (remove file path and n_lines parameters, and use a
    +        generic way to get the results). The child classes should manage the
    +        specifics of the results.
    +
    +    """
     
    @@ -2521,7 +2479,8 @@

    Source code in biochatter/api_agent/abc.py -
    11
    +                
    10
    +11
     12
     13
     14
    @@ -2587,79 +2546,74 @@ 

    74 75 76 -77 -78 -79 -80

    class BaseQueryBuilder(ABC):
    -    """
    -    An abstract base class for query builders.
    -    """
    -
    -    @property
    -    def structured_output_prompt(self) -> ChatPromptTemplate:
    -        """
    -        Defines a structured output prompt template. This provides a default
    -        implementation for an API agent that can be overridden by subclasses to
    -        return a ChatPromptTemplate-compatible object.
    -        """
    -        return ChatPromptTemplate.from_messages(
    -            [
    +77
    class BaseQueryBuilder(ABC):
    +    """An abstract base class for query builders."""
    +
    +    @property
    +    def structured_output_prompt(self) -> ChatPromptTemplate:
    +        """Defines a structured output prompt template. This provides a default
    +        implementation for an API agent that can be overridden by subclasses to
    +        return a ChatPromptTemplate-compatible object.
    +        """
    +        return ChatPromptTemplate.from_messages(
    +            [
    +                (
    +                    "system",
    +                    "You are a world class algorithm for extracting information in structured formats.",
    +                ),
                     (
    -                    "system",
    -                    "You are a world class algorithm for extracting information in structured formats.",
    +                    "human",
    +                    "Use the given format to extract information from the following input: {input}",
                     ),
    -                (
    -                    "human",
    -                    "Use the given format to extract information from the following input: {input}",
    -                ),
    -                ("human", "Tip: Make sure to answer in the correct format"),
    -            ]
    -        )
    -
    -    @abstractmethod
    -    def create_runnable(
    -        self,
    -        query_parameters: "BaseModel",
    -        conversation: "Conversation",
    -    ) -> Callable:
    -        """
    -        Creates a runnable object for executing queries. Must be implemented by
    -        subclasses. Should use the LangChain `create_structured_output_runnable`
    -        method to generate the Callable.
    +                ("human", "Tip: Make sure to answer in the correct format"),
    +            ],
    +        )
    +
    +    @abstractmethod
    +    def create_runnable(
    +        self,
    +        query_parameters: "BaseModel",
    +        conversation: "Conversation",
    +    ) -> Callable:
    +        """Creates a runnable object for executing queries. Must be implemented by
    +        subclasses. Should use the LangChain `create_structured_output_runnable`
    +        method to generate the Callable.
    +
    +        Args:
    +        ----
    +            query_parameters: A Pydantic data model that specifies the fields of
    +                the API that should be queried.
     
    -        Args:
    -            query_parameters: A Pydantic data model that specifies the fields of
    -                the API that should be queried.
    -
    -            conversation: A BioChatter conversation object.
    +            conversation: A BioChatter conversation object.
    +
    +        Returns:
    +        -------
    +            A Callable object that can execute the query.
     
    -        Returns:
    -            A Callable object that can execute the query.
    -        """
    -        pass
    -
    -    @abstractmethod
    -    def parameterise_query(
    -        self,
    -        question: str,
    -        conversation: "Conversation",
    -    ) -> BaseModel:
    -        """
    -
    -        Parameterises a query object (a Pydantic model with the fields of the
    -        API) based on the given question using a BioChatter conversation
    -        instance. Must be implemented by subclasses.
    -
    -        Args:
    -            question (str): The question to be answered.
    -
    -            conversation: The BioChatter conversation object containing the LLM
    -                that should parameterise the query.
    +        """
    +
    +    @abstractmethod
    +    def parameterise_query(
    +        self,
    +        question: str,
    +        conversation: "Conversation",
    +    ) -> BaseModel:
    +        """Parameterises a query object (a Pydantic model with the fields of the
    +        API) based on the given question using a BioChatter conversation
    +        instance. Must be implemented by subclasses.
    +
    +        Args:
    +        ----
    +            question (str): The question to be answered.
    +
    +            conversation: The BioChatter conversation object containing the LLM
    +                that should parameterise the query.
    +
    +        Returns:
    +        -------
    +            A parameterised instance of the query object (Pydantic BaseModel)
     
    -        Returns:
    -            A parameterised instance of the query object (Pydantic BaseModel)
    -        """
    -        pass
    +        """
     
    @@ -2716,81 +2670,23 @@

    query_parameters: A Pydantic data model that specifies the fields of
    +    the API that should be queried.
     
    -
    -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - query_parameters - - BaseModel - -
    -

    A Pydantic data model that specifies the fields of -the API that should be queried.

    -
    -
    - required -
    - conversation - - Conversation - -
    -

    A BioChatter conversation object.

    -
    -
    - required -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    TypeDescription
    - Callable - -
    -

    A Callable object that can execute the query.

    -
    -
    +conversation: A BioChatter conversation object. +

    +
    +
    A Callable object that can execute the query.
    +
    Source code in biochatter/api_agent/abc.py -
    37
    +              
    33
    +34
    +35
    +36
    +37
     38
     39
     40
    @@ -2807,30 +2703,28 @@ 

    51 52 53 -54 -55 -56 -57

    @abstractmethod
    -def create_runnable(
    -    self,
    -    query_parameters: "BaseModel",
    -    conversation: "Conversation",
    -) -> Callable:
    -    """
    -    Creates a runnable object for executing queries. Must be implemented by
    -    subclasses. Should use the LangChain `create_structured_output_runnable`
    -    method to generate the Callable.
    +54
    @abstractmethod
    +def create_runnable(
    +    self,
    +    query_parameters: "BaseModel",
    +    conversation: "Conversation",
    +) -> Callable:
    +    """Creates a runnable object for executing queries. Must be implemented by
    +    subclasses. Should use the LangChain `create_structured_output_runnable`
    +    method to generate the Callable.
    +
    +    Args:
    +    ----
    +        query_parameters: A Pydantic data model that specifies the fields of
    +            the API that should be queried.
     
    -    Args:
    -        query_parameters: A Pydantic data model that specifies the fields of
    -            the API that should be queried.
    -
    -        conversation: A BioChatter conversation object.
    +        conversation: A BioChatter conversation object.
    +
    +    Returns:
    +    -------
    +        A Callable object that can execute the query.
     
    -    Returns:
    -        A Callable object that can execute the query.
    -    """
    -    pass
    +    """
     
    @@ -2855,81 +2749,22 @@

    question (str): The question to be answered.
     
    -
    -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - question - - str - -
    -

    The question to be answered.

    -
    -
    - required -
    - conversation - - Conversation - -
    -

    The BioChatter conversation object containing the LLM -that should parameterise the query.

    -
    -
    - required -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    TypeDescription
    - BaseModel - -
    -

    A parameterised instance of the query object (Pydantic BaseModel)

    -
    -
    +conversation: The BioChatter conversation object containing the LLM + that should parameterise the query. +
    +
    +
    A parameterised instance of the query object (Pydantic BaseModel)
    +
    Source code in biochatter/api_agent/abc.py -
    59
    +              
    56
    +57
    +58
    +59
     60
     61
     62
    @@ -2947,31 +2782,28 @@ 

    74 75 76 -77 -78 -79 -80

    @abstractmethod
    -def parameterise_query(
    -    self,
    -    question: str,
    -    conversation: "Conversation",
    -) -> BaseModel:
    -    """
    -
    -    Parameterises a query object (a Pydantic model with the fields of the
    -    API) based on the given question using a BioChatter conversation
    -    instance. Must be implemented by subclasses.
    -
    -    Args:
    -        question (str): The question to be answered.
    -
    -        conversation: The BioChatter conversation object containing the LLM
    -            that should parameterise the query.
    +77
    @abstractmethod
    +def parameterise_query(
    +    self,
    +    question: str,
    +    conversation: "Conversation",
    +) -> BaseModel:
    +    """Parameterises a query object (a Pydantic model with the fields of the
    +    API) based on the given question using a BioChatter conversation
    +    instance. Must be implemented by subclasses.
    +
    +    Args:
    +    ----
    +        question (str): The question to be answered.
    +
    +        conversation: The BioChatter conversation object containing the LLM
    +            that should parameterise the query.
    +
    +    Returns:
    +    -------
    +        A parameterised instance of the query object (Pydantic BaseModel)
     
    -    Returns:
    -        A parameterised instance of the query object (Pydantic BaseModel)
    -    """
    -    pass
    +    """
     
    @@ -3040,7 +2872,8 @@

    Source code in biochatter/api_agent/api_agent.py -
     24
    +                
     23
    + 24
      25
      26
      27
    @@ -3157,95 +2990,93 @@ 

    138 139 140 -141 -142 -143 -144

    class APIAgent:
    -    def __init__(
    -        self,
    -        conversation_factory: Callable,
    -        query_builder: "BaseQueryBuilder",
    -        fetcher: "BaseFetcher",
    -        interpreter: "BaseInterpreter",
    -    ):
    -        """
    -
    -        API agent class to interact with a tool's API for querying and fetching
    -        results.  The query fields have to be defined in a Pydantic model
    -        (`BaseModel`) and used (i.e., parameterised by the LLM) in the query
    -        builder. Specific API agents are defined in submodules of this directory
    -        (`api_agent`). The agent's logic is implemented in the `execute` method.
    -
    -        Attributes:
    -            conversation_factory (Callable): A function used to create a
    -                BioChatter conversation, providing LLM access.
    -
    -            query_builder (BaseQueryBuilder): An instance of a child of the
    -                BaseQueryBuilder class.
    -
    -            result_fetcher (BaseFetcher): An instance of a child of the
    -                BaseFetcher class.
    -
    -            result_interpreter (BaseInterpreter): An instance of a child of the
    -                BaseInterpreter class.
    -        """
    -        self.conversation_factory = conversation_factory
    -        self.query_builder = query_builder
    -        self.fetcher = fetcher
    -        self.interpreter = interpreter
    -        self.final_answer = None
    -
    -    def parameterise_query(self, question: str) -> Optional[BaseModel]:
    -        """
    -        Use LLM to parameterise a query (a Pydantic model) based on the given
    -        question using a BioChatter conversation instance.
    -        """
    -        try:
    -            conversation = self.conversation_factory()
    -            return self.query_builder.parameterise_query(question, conversation)
    -        except Exception as e:
    -            print(f"Error generating query: {e}")
    -            return None
    -
    -    def fetch_results(self, query_model: str) -> Optional[str]:
    -        """
    -        Fetch the results of the query using the individual API's implementation
    -        (either single-step or submit-retrieve).
    -
    -        Args:
    -            query_model: the parameterised query Pydantic model
    -        """
    -        try:
    -            return self.fetcher.fetch_results(query_model, 100)
    -        except Exception as e:
    -            print(f"Error fetching results: {e}")
    -            return None
    -
    -    def summarise_results(
    -        self, question: str, response_text: str
    -    ) -> Optional[str]:
    -        """
    -        Summarise the retrieved results to extract the answer to the question.
    -        """
    -        try:
    -            return self.interpreter.summarise_results(
    -                question=question,
    -                conversation_factory=self.conversation_factory,
    -                response_text=response_text,
    -            )
    -        except Exception as e:
    -            print(f"Error extracting answer: {e}")
    -            return None
    -
    -    def execute(self, question: str) -> Optional[str]:
    -        """
    -        Wrapper that uses class methods to execute the API agent logic. Consists
    -        of 1) query generation, 2) query submission, 3) results fetching, and
    -        4) answer extraction. The final answer is stored in the final_answer
    -        attribute.
    -
    -        Args:
    -            question (str): The question to be answered.
    +141
    class APIAgent:
    +    def __init__(
    +        self,
    +        conversation_factory: Callable,
    +        query_builder: "BaseQueryBuilder",
    +        fetcher: "BaseFetcher",
    +        interpreter: "BaseInterpreter",
    +    ):
    +        """API agent class to interact with a tool's API for querying and fetching
    +        results.  The query fields have to be defined in a Pydantic model
    +        (`BaseModel`) and used (i.e., parameterised by the LLM) in the query
    +        builder. Specific API agents are defined in submodules of this directory
    +        (`api_agent`). The agent's logic is implemented in the `execute` method.
    +
    +        Attributes
    +        ----------
    +            conversation_factory (Callable): A function used to create a
    +                BioChatter conversation, providing LLM access.
    +
    +            query_builder (BaseQueryBuilder): An instance of a child of the
    +                BaseQueryBuilder class.
    +
    +            result_fetcher (BaseFetcher): An instance of a child of the
    +                BaseFetcher class.
    +
    +            result_interpreter (BaseInterpreter): An instance of a child of the
    +                BaseInterpreter class.
    +
    +        """
    +        self.conversation_factory = conversation_factory
    +        self.query_builder = query_builder
    +        self.fetcher = fetcher
    +        self.interpreter = interpreter
    +        self.final_answer = None
    +
    +    def parameterise_query(self, question: str) -> BaseModel | None:
    +        """Use LLM to parameterise a query (a Pydantic model) based on the given
    +        question using a BioChatter conversation instance.
    +        """
    +        try:
    +            conversation = self.conversation_factory()
    +            return self.query_builder.parameterise_query(question, conversation)
    +        except Exception as e:
    +            print(f"Error generating query: {e}")
    +            return None
    +
    +    def fetch_results(self, query_model: str) -> str | None:
    +        """Fetch the results of the query using the individual API's implementation
    +        (either single-step or submit-retrieve).
    +
    +        Args:
    +        ----
    +            query_model: the parameterised query Pydantic model
    +
    +        """
    +        try:
    +            return self.fetcher.fetch_results(query_model, 100)
    +        except Exception as e:
    +            print(f"Error fetching results: {e}")
    +            return None
    +
    +    def summarise_results(
    +        self,
    +        question: str,
    +        response_text: str,
    +    ) -> str | None:
    +        """Summarise the retrieved results to extract the answer to the question."""
    +        try:
    +            return self.interpreter.summarise_results(
    +                question=question,
    +                conversation_factory=self.conversation_factory,
    +                response_text=response_text,
    +            )
    +        except Exception as e:
    +            print(f"Error extracting answer: {e}")
    +            return None
    +
    +    def execute(self, question: str) -> str | None:
    +        """Wrapper that uses class methods to execute the API agent logic. Consists
    +        of 1) query generation, 2) query submission, 3) results fetching, and
    +        4) answer extraction. The final answer is stored in the final_answer
    +        attribute.
    +
    +        Args:
    +        ----
    +            question (str): The question to be answered.
    +
             """
             # Generate query
             try:
    @@ -3277,10 +3108,7 @@ 

    return final_answer def get_description(self, tool_name: str, tool_desc: str): - return ( - f"This API agent interacts with {tool_name}'s API for querying and " - f"fetching results. {tool_desc}" - ) + return f"This API agent interacts with {tool_name}'s API for querying and fetching results. {tool_desc}"

    @@ -3312,72 +3140,24 @@

    Attributes

    +
    conversation_factory (Callable): A function used to create a
    +    BioChatter conversation, providing LLM access.
     
    +query_builder (BaseQueryBuilder): An instance of a child of the
    +    BaseQueryBuilder class.
     
    -

    Attributes:

    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescription
    conversation_factory - Callable - -
    -

    A function used to create a -BioChatter conversation, providing LLM access.

    -
    -
    query_builder - BaseQueryBuilder - -
    -

    An instance of a child of the -BaseQueryBuilder class.

    -
    -
    result_fetcher - BaseFetcher - -
    -

    An instance of a child of the -BaseFetcher class.

    -
    -
    result_interpreter - BaseInterpreter - -
    -

    An instance of a child of the -BaseInterpreter class.

    -
    -
    +result_fetcher (BaseFetcher): An instance of a child of the + BaseFetcher class. + +result_interpreter (BaseInterpreter): An instance of a child of the + BaseInterpreter class. +
    Source code in biochatter/api_agent/api_agent.py -
    25
    +              
    24
    +25
     26
     27
     28
    @@ -3408,40 +3188,39 @@ 

    53 54 55 -56 -57

    def __init__(
    -    self,
    -    conversation_factory: Callable,
    -    query_builder: "BaseQueryBuilder",
    -    fetcher: "BaseFetcher",
    -    interpreter: "BaseInterpreter",
    -):
    -    """
    -
    -    API agent class to interact with a tool's API for querying and fetching
    -    results.  The query fields have to be defined in a Pydantic model
    -    (`BaseModel`) and used (i.e., parameterised by the LLM) in the query
    -    builder. Specific API agents are defined in submodules of this directory
    -    (`api_agent`). The agent's logic is implemented in the `execute` method.
    -
    -    Attributes:
    -        conversation_factory (Callable): A function used to create a
    -            BioChatter conversation, providing LLM access.
    -
    -        query_builder (BaseQueryBuilder): An instance of a child of the
    -            BaseQueryBuilder class.
    -
    -        result_fetcher (BaseFetcher): An instance of a child of the
    -            BaseFetcher class.
    -
    -        result_interpreter (BaseInterpreter): An instance of a child of the
    -            BaseInterpreter class.
    -    """
    -    self.conversation_factory = conversation_factory
    -    self.query_builder = query_builder
    -    self.fetcher = fetcher
    -    self.interpreter = interpreter
    -    self.final_answer = None
    +56
    def __init__(
    +    self,
    +    conversation_factory: Callable,
    +    query_builder: "BaseQueryBuilder",
    +    fetcher: "BaseFetcher",
    +    interpreter: "BaseInterpreter",
    +):
    +    """API agent class to interact with a tool's API for querying and fetching
    +    results.  The query fields have to be defined in a Pydantic model
    +    (`BaseModel`) and used (i.e., parameterised by the LLM) in the query
    +    builder. Specific API agents are defined in submodules of this directory
    +    (`api_agent`). The agent's logic is implemented in the `execute` method.
    +
    +    Attributes
    +    ----------
    +        conversation_factory (Callable): A function used to create a
    +            BioChatter conversation, providing LLM access.
    +
    +        query_builder (BaseQueryBuilder): An instance of a child of the
    +            BaseQueryBuilder class.
    +
    +        result_fetcher (BaseFetcher): An instance of a child of the
    +            BaseFetcher class.
    +
    +        result_interpreter (BaseInterpreter): An instance of a child of the
    +            BaseInterpreter class.
    +
    +    """
    +    self.conversation_factory = conversation_factory
    +    self.query_builder = query_builder
    +    self.fetcher = fetcher
    +    self.interpreter = interpreter
    +    self.final_answer = None
     
    @@ -3463,41 +3242,14 @@

    - - -

    Parameters:

    - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - question - - str - -
    -

    The question to be answered.

    -
    -
    - required -
    +
    +
    question (str): The question to be answered.
    +
    Source code in biochatter/api_agent/api_agent.py -
    101
    +              
    100
    +101
     102
     103
     104
    @@ -3534,15 +3286,16 @@ 

    135 136 137 -138

    def execute(self, question: str) -> Optional[str]:
    -    """
    -    Wrapper that uses class methods to execute the API agent logic. Consists
    -    of 1) query generation, 2) query submission, 3) results fetching, and
    -    4) answer extraction. The final answer is stored in the final_answer
    -    attribute.
    -
    -    Args:
    -        question (str): The question to be answered.
    +138
    def execute(self, question: str) -> str | None:
    +    """Wrapper that uses class methods to execute the API agent logic. Consists
    +    of 1) query generation, 2) query submission, 3) results fetching, and
    +    4) answer extraction. The final answer is stored in the final_answer
    +    attribute.
    +
    +    Args:
    +    ----
    +        question (str): The question to be answered.
    +
         """
         # Generate query
         try:
    @@ -3591,41 +3344,15 @@ 

    Parameters:

    - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - query_model - - str - -
    -

    the parameterised query Pydantic model

    -
    -
    - required -
    +
    +
    query_model: the parameterised query Pydantic model
    +
    Source code in biochatter/api_agent/api_agent.py -
    71
    +              
    69
    +70
    +71
     72
     73
     74
    @@ -3636,20 +3363,20 @@ 

    79 80 81 -82 -83

    def fetch_results(self, query_model: str) -> Optional[str]:
    -    """
    -    Fetch the results of the query using the individual API's implementation
    -    (either single-step or submit-retrieve).
    -
    -    Args:
    -        query_model: the parameterised query Pydantic model
    -    """
    -    try:
    -        return self.fetcher.fetch_results(query_model, 100)
    -    except Exception as e:
    -        print(f"Error fetching results: {e}")
    -        return None
    +82
    def fetch_results(self, query_model: str) -> str | None:
    +    """Fetch the results of the query using the individual API's implementation
    +    (either single-step or submit-retrieve).
    +
    +    Args:
    +    ----
    +        query_model: the parameterised query Pydantic model
    +
    +    """
    +    try:
    +        return self.fetcher.fetch_results(query_model, 100)
    +    except Exception as e:
    +        print(f"Error fetching results: {e}")
    +        return None
     
    @@ -3672,7 +3399,8 @@

    Source code in biochatter/api_agent/api_agent.py -
    59
    +              
    58
    +59
     60
     61
     62
    @@ -3680,19 +3408,16 @@ 

    64 65 66 -67 -68 -69

    def parameterise_query(self, question: str) -> Optional[BaseModel]:
    -    """
    -    Use LLM to parameterise a query (a Pydantic model) based on the given
    -    question using a BioChatter conversation instance.
    -    """
    -    try:
    -        conversation = self.conversation_factory()
    -        return self.query_builder.parameterise_query(question, conversation)
    -    except Exception as e:
    -        print(f"Error generating query: {e}")
    -        return None
    +67
    def parameterise_query(self, question: str) -> BaseModel | None:
    +    """Use LLM to parameterise a query (a Pydantic model) based on the given
    +    question using a BioChatter conversation instance.
    +    """
    +    try:
    +        conversation = self.conversation_factory()
    +        return self.query_builder.parameterise_query(question, conversation)
    +    except Exception as e:
    +        print(f"Error generating query: {e}")
    +        return None
     
    @@ -3714,7 +3439,8 @@

    Source code in biochatter/api_agent/api_agent.py -
    85
    +              
    84
    +85
     86
     87
     88
    @@ -3727,22 +3453,21 @@ 

    95 96 97 -98 -99

    def summarise_results(
    -    self, question: str, response_text: str
    -) -> Optional[str]:
    -    """
    -    Summarise the retrieved results to extract the answer to the question.
    -    """
    -    try:
    -        return self.interpreter.summarise_results(
    -            question=question,
    -            conversation_factory=self.conversation_factory,
    -            response_text=response_text,
    -        )
    -    except Exception as e:
    -        print(f"Error extracting answer: {e}")
    -        return None
    +98
    def summarise_results(
    +    self,
    +    question: str,
    +    response_text: str,
    +) -> str | None:
    +    """Summarise the retrieved results to extract the answer to the question."""
    +    try:
    +        return self.interpreter.summarise_results(
    +            question=question,
    +            conversation_factory=self.conversation_factory,
    +            response_text=response_text,
    +        )
    +    except Exception as e:
    +        print(f"Error extracting answer: {e}")
    +        return None
     
    @@ -3817,11 +3542,7 @@

    Source code in biochatter/api_agent/blast.py -
    164
    -165
    -166
    -167
    -168
    +                
    168
     169
     170
     171
    @@ -3938,128 +3659,148 @@ 

    282 283 284 -285

    class BlastFetcher(BaseFetcher):
    -    """
    -    A class for retrieving API results from BLAST given a parameterised
    -    BlastQuery.
    -
    -    TODO add a limit of characters to be returned from the response.text?
    -    """
    +285
    +286
    +287
    +288
    +289
    +290
    +291
    +292
    +293
    +294
    +295
    +296
    +297
    class BlastFetcher(BaseFetcher):
    +    """A class for retrieving API results from BLAST given a parameterised
    +    BlastQuery.
     
    -    def _submit_query(self, request_data: BlastQueryParameters) -> str:
    -        """Function to POST the BLAST query and retrieve RID.
    -        It submits the structured BlastQuery obj and return the RID.
    -
    -        Args:
    -            request_data: BlastQuery object containing the BLAST query
    -                parameters.
    -        Returns:
    -            str: The Request ID (RID) for the submitted BLAST query.
    -        """
    -        data = {
    -            "CMD": request_data.cmd,
    -            "PROGRAM": request_data.program,
    -            "DATABASE": request_data.database,
    -            "QUERY": request_data.query,
    -            "FORMAT_TYPE": request_data.format_type,
    -            "MEGABLAST": request_data.megablast,
    -            "HITLIST_SIZE": request_data.max_hits,
    -        }
    -        # Include any other_params if provided
    -        if request_data.other_params:
    -            data.update(request_data.other_params)
    -        # Make the API call
    -        query_string = urlencode(data)
    -        # Combine base URL with the query string
    -        full_url = f"{request_data.url}?{query_string}"
    -        # Print the full URL
    -        request_data.full_url = full_url
    -        print("Full URL built by retriever:\n", request_data.full_url)
    -        response = requests.post(request_data.url, data=data)
    -        response.raise_for_status()
    -        # Extract RID from response
    -        print(response)
    -        match = re.search(r"RID = (\w+)", response.text)
    -        if match:
    -            return match.group(1)
    -        else:
    -            raise ValueError("RID not found in BLAST submission response.")
    -
    -    def _fetch_results(
    -        self,
    -        rid: str,
    -        question_uuid: str,
    -        retries: int = 10000,
    -    ):
    -        """SECOND function to be called for a BLAST query.
    -        Will look for the RID to fetch the data
    -        """
    -        ###
    -        ###    TO DO: Implement logging for all BLAST queries
    -        ###
    -        # log_question_uuid_json(request_data.question_uuid,question, file_name, log_file_path,request_data.full_url)
    -        base_url = "https://blast.ncbi.nlm.nih.gov/Blast.cgi"
    -        check_status_params = {
    -            "CMD": "Get",
    -            "FORMAT_OBJECT": "SearchInfo",
    -            "RID": rid,
    -        }
    -        get_results_params = {
    -            "CMD": "Get",
    -            "FORMAT_TYPE": "XML",
    -            "RID": rid,
    -        }
    -
    -        # Check the status of the BLAST job
    -        for attempt in range(retries):
    -            status_response = requests.get(base_url, params=check_status_params)
    -            status_response.raise_for_status()
    -            status_text = status_response.text
    -            print("evaluating status")
    -            if "Status=WAITING" in status_text:
    -                print(f"{question_uuid} results not ready, waiting...")
    -                time.sleep(15)
    -            elif "Status=FAILED" in status_text:
    -                raise RuntimeError("BLAST query FAILED.")
    -            elif "Status=UNKNOWN" in status_text:
    -                raise RuntimeError("BLAST query expired or does not exist.")
    -            elif "Status=READY" in status_text:
    -                if "ThereAreHits=yes" in status_text:
    -                    print(f"{question_uuid} results are ready, retrieving.")
    -                    results_response = requests.get(
    -                        base_url, params=get_results_params
    -                    )
    -                    results_response.raise_for_status()
    -                    # Save the results to a file
    -                    return results_response.text
    -                else:
    -                    return "No hits found"
    -        if attempt == retries - 1:
    -            raise TimeoutError(
    -                "Maximum attempts reached. Results may not be ready."
    -            )
    -
    -    def fetch_results(
    -        self, query_model: BlastQueryParameters, retries: int = 20
    -    ) -> str:
    -        """
    -        Submit request and fetch results from BLAST API. Wraps individual
    -        submission and retrieval of results.
    -
    -        Args:
    -            query_model: the Pydantic model of the query
    -
    -            retries: the number of maximum retries
    -
    -        Returns:
    -            str: the result from the BLAST API
    -        """
    -        rid = self._submit_query(request_data=query_model)
    -        return self._fetch_results(
    -            rid=rid,
    -            question_uuid=query_model.question_uuid,
    -            retries=retries,
    -        )
    +    TODO add a limit of characters to be returned from the response.text?
    +    """
    +
    +    def _submit_query(self, request_data: BlastQueryParameters) -> str:
    +        """Function to POST the BLAST query and retrieve RID.
    +        It submits the structured BlastQuery obj and return the RID.
    +
    +        Args:
    +        ----
    +            request_data: BlastQuery object containing the BLAST query
    +                parameters.
    +
    +        Returns:
    +        -------
    +            str: The Request ID (RID) for the submitted BLAST query.
    +
    +        """
    +        data = {
    +            "CMD": request_data.cmd,
    +            "PROGRAM": request_data.program,
    +            "DATABASE": request_data.database,
    +            "QUERY": request_data.query,
    +            "FORMAT_TYPE": request_data.format_type,
    +            "MEGABLAST": request_data.megablast,
    +            "HITLIST_SIZE": request_data.max_hits,
    +        }
    +        # Include any other_params if provided
    +        if request_data.other_params:
    +            data.update(request_data.other_params)
    +        # Make the API call
    +        query_string = urlencode(data)
    +        # Combine base URL with the query string
    +        full_url = f"{request_data.url}?{query_string}"
    +        # Print the full URL
    +        request_data.full_url = full_url
    +        print("Full URL built by retriever:\n", request_data.full_url)
    +        response = requests.post(request_data.url, data=data)
    +        response.raise_for_status()
    +        # Extract RID from response
    +        print(response)
    +        match = re.search(r"RID = (\w+)", response.text)
    +        if match:
    +            return match.group(1)
    +        else:
    +            raise ValueError("RID not found in BLAST submission response.")
    +
    +    def _fetch_results(
    +        self,
    +        rid: str,
    +        question_uuid: str,
    +        retries: int = 10000,
    +    ):
    +        """SECOND function to be called for a BLAST query.
    +        Will look for the RID to fetch the data
    +        """
    +        ###
    +        ###    TO DO: Implement logging for all BLAST queries
    +        ###
    +        # log_question_uuid_json(request_data.question_uuid,question, file_name, log_file_path,request_data.full_url)
    +        base_url = "https://blast.ncbi.nlm.nih.gov/Blast.cgi"
    +        check_status_params = {
    +            "CMD": "Get",
    +            "FORMAT_OBJECT": "SearchInfo",
    +            "RID": rid,
    +        }
    +        get_results_params = {
    +            "CMD": "Get",
    +            "FORMAT_TYPE": "XML",
    +            "RID": rid,
    +        }
    +
    +        # Check the status of the BLAST job
    +        for attempt in range(retries):
    +            status_response = requests.get(base_url, params=check_status_params)
    +            status_response.raise_for_status()
    +            status_text = status_response.text
    +            print("evaluating status")
    +            if "Status=WAITING" in status_text:
    +                print(f"{question_uuid} results not ready, waiting...")
    +                time.sleep(15)
    +            elif "Status=FAILED" in status_text:
    +                raise RuntimeError("BLAST query FAILED.")
    +            elif "Status=UNKNOWN" in status_text:
    +                raise RuntimeError("BLAST query expired or does not exist.")
    +            elif "Status=READY" in status_text:
    +                if "ThereAreHits=yes" in status_text:
    +                    print(f"{question_uuid} results are ready, retrieving.")
    +                    results_response = requests.get(
    +                        base_url,
    +                        params=get_results_params,
    +                    )
    +                    results_response.raise_for_status()
    +                    # Save the results to a file
    +                    return results_response.text
    +                else:
    +                    return "No hits found"
    +        if attempt == retries - 1:
    +            raise TimeoutError(
    +                "Maximum attempts reached. Results may not be ready.",
    +            )
    +
    +    def fetch_results(
    +        self,
    +        query_model: BlastQueryParameters,
    +        retries: int = 20,
    +    ) -> str:
    +        """Submit request and fetch results from BLAST API. Wraps individual
    +        submission and retrieval of results.
    +
    +        Args:
    +        ----
    +            query_model: the Pydantic model of the query
    +
    +            retries: the number of maximum retries
    +
    +        Returns:
    +        -------
    +            str: the result from the BLAST API
    +
    +        """
    +        rid = self._submit_query(request_data=query_model)
    +        return self._fetch_results(
    +            rid=rid,
    +            question_uuid=query_model.question_uuid,
    +            retries=retries,
    +        )
     
    @@ -4088,88 +3829,18 @@

    query_model: the Pydantic model of the query
     
    -
    -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - query_model - - BlastQueryParameters - -
    -

    the Pydantic model of the query

    -
    -
    - required -
    - retries - - int - -
    -

    the number of maximum retries

    -
    -
    - 20 -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    Name TypeDescription
    str - str - -
    -

    the result from the BLAST API

    -
    -
    +retries: the number of maximum retries +

    +
    +
    str: the result from the BLAST API
    +
    Source code in biochatter/api_agent/blast.py -
    265
    -266
    -267
    -268
    -269
    -270
    -271
    -272
    -273
    +              
    273
     274
     275
     276
    @@ -4181,27 +3852,43 @@ 

    282 283 284 -285

    def fetch_results(
    -    self, query_model: BlastQueryParameters, retries: int = 20
    -) -> str:
    -    """
    -    Submit request and fetch results from BLAST API. Wraps individual
    -    submission and retrieval of results.
    -
    -    Args:
    -        query_model: the Pydantic model of the query
    -
    -        retries: the number of maximum retries
    -
    -    Returns:
    -        str: the result from the BLAST API
    -    """
    -    rid = self._submit_query(request_data=query_model)
    -    return self._fetch_results(
    -        rid=rid,
    -        question_uuid=query_model.question_uuid,
    -        retries=retries,
    -    )
    +285
    +286
    +287
    +288
    +289
    +290
    +291
    +292
    +293
    +294
    +295
    +296
    +297
    def fetch_results(
    +    self,
    +    query_model: BlastQueryParameters,
    +    retries: int = 20,
    +) -> str:
    +    """Submit request and fetch results from BLAST API. Wraps individual
    +    submission and retrieval of results.
    +
    +    Args:
    +    ----
    +        query_model: the Pydantic model of the query
    +
    +        retries: the number of maximum retries
    +
    +    Returns:
    +    -------
    +        str: the result from the BLAST API
    +
    +    """
    +    rid = self._submit_query(request_data=query_model)
    +    return self._fetch_results(
    +        rid=rid,
    +        question_uuid=query_model.question_uuid,
    +        retries=retries,
    +    )
     
    @@ -4239,19 +3926,7 @@

    Source code in biochatter/api_agent/blast.py -
    288
    -289
    -290
    -291
    -292
    -293
    -294
    -295
    -296
    -297
    -298
    -299
    -300
    +                
    300
     301
     302
     303
    @@ -4274,42 +3949,58 @@ 

    320 321 322 -323

    class BlastInterpreter(BaseInterpreter):
    -    def summarise_results(
    -        self,
    -        question: str,
    -        conversation_factory: Callable,
    -        response_text: str,
    -    ) -> str:
    -        """
    -        Function to extract the answer from the BLAST results.
    -
    -        Args:
    -            question (str): The question to be answered.
    -            conversation_factory: A BioChatter conversation object.
    -            response_text (str): The response.text returned by NCBI.
    -
    -        Returns:
    -            str: The extracted answer from the BLAST results.
    -
    -        """
    -        prompt = ChatPromptTemplate.from_messages(
    -            [
    -                (
    -                    "system",
    -                    "You are a world class molecular biologist who knows everything about NCBI and BLAST results.",
    -                ),
    -                ("user", "{input}"),
    -            ]
    -        )
    -        summary_prompt = BLAST_SUMMARY_PROMPT.format(
    -            question=question, context=response_text
    -        )
    -        output_parser = StrOutputParser()
    -        conversation = conversation_factory()
    -        chain = prompt | conversation.chat | output_parser
    -        answer = chain.invoke({"input": {summary_prompt}})
    -        return answer
    +323
    +324
    +325
    +326
    +327
    +328
    +329
    +330
    +331
    +332
    +333
    +334
    +335
    +336
    +337
    class BlastInterpreter(BaseInterpreter):
    +    def summarise_results(
    +        self,
    +        question: str,
    +        conversation_factory: Callable,
    +        response_text: str,
    +    ) -> str:
    +        """Function to extract the answer from the BLAST results.
    +
    +        Args:
    +        ----
    +            question (str): The question to be answered.
    +            conversation_factory: A BioChatter conversation object.
    +            response_text (str): The response.text returned by NCBI.
    +
    +        Returns:
    +        -------
    +            str: The extracted answer from the BLAST results.
    +
    +        """
    +        prompt = ChatPromptTemplate.from_messages(
    +            [
    +                (
    +                    "system",
    +                    "You are a world class molecular biologist who knows everything about NCBI and BLAST results.",
    +                ),
    +                ("user", "{input}"),
    +            ],
    +        )
    +        summary_prompt = BLAST_SUMMARY_PROMPT.format(
    +            question=question,
    +            context=response_text,
    +        )
    +        output_parser = StrOutputParser()
    +        conversation = conversation_factory()
    +        chain = prompt | conversation.chat | output_parser
    +        answer = chain.invoke({"input": {summary_prompt}})
    +        return answer
     
    @@ -4337,108 +4028,18 @@

    Function to extract the answer from the BLAST results.

    - - -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - question - - str - -
    -

    The question to be answered.

    -
    -
    - required -
    - conversation_factory - - Callable - -
    -

    A BioChatter conversation object.

    -
    -
    - required -
    - response_text - - str - -
    -

    The response.text returned by NCBI.

    -
    -
    - required -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    Name TypeDescription
    str - str - -
    -

    The extracted answer from the BLAST results.

    -
    -
    +
    +
    question (str): The question to be answered.
    +conversation_factory: A BioChatter conversation object.
    +response_text (str): The response.text returned by NCBI.
    +
    +
    +
    str: The extracted answer from the BLAST results.
    +
    Source code in biochatter/api_agent/blast.py -
    289
    -290
    -291
    -292
    -293
    -294
    -295
    -296
    -297
    -298
    -299
    -300
    -301
    +              
    301
     302
     303
     304
    @@ -4460,41 +4061,57 @@ 

    320 321 322 -323

    def summarise_results(
    -    self,
    -    question: str,
    -    conversation_factory: Callable,
    -    response_text: str,
    -) -> str:
    -    """
    -    Function to extract the answer from the BLAST results.
    -
    -    Args:
    -        question (str): The question to be answered.
    -        conversation_factory: A BioChatter conversation object.
    -        response_text (str): The response.text returned by NCBI.
    -
    -    Returns:
    -        str: The extracted answer from the BLAST results.
    -
    -    """
    -    prompt = ChatPromptTemplate.from_messages(
    -        [
    -            (
    -                "system",
    -                "You are a world class molecular biologist who knows everything about NCBI and BLAST results.",
    -            ),
    -            ("user", "{input}"),
    -        ]
    -    )
    -    summary_prompt = BLAST_SUMMARY_PROMPT.format(
    -        question=question, context=response_text
    -    )
    -    output_parser = StrOutputParser()
    -    conversation = conversation_factory()
    -    chain = prompt | conversation.chat | output_parser
    -    answer = chain.invoke({"input": {summary_prompt}})
    -    return answer
    +323
    +324
    +325
    +326
    +327
    +328
    +329
    +330
    +331
    +332
    +333
    +334
    +335
    +336
    +337
    def summarise_results(
    +    self,
    +    question: str,
    +    conversation_factory: Callable,
    +    response_text: str,
    +) -> str:
    +    """Function to extract the answer from the BLAST results.
    +
    +    Args:
    +    ----
    +        question (str): The question to be answered.
    +        conversation_factory: A BioChatter conversation object.
    +        response_text (str): The response.text returned by NCBI.
    +
    +    Returns:
    +    -------
    +        str: The extracted answer from the BLAST results.
    +
    +    """
    +    prompt = ChatPromptTemplate.from_messages(
    +        [
    +            (
    +                "system",
    +                "You are a world class molecular biologist who knows everything about NCBI and BLAST results.",
    +            ),
    +            ("user", "{input}"),
    +        ],
    +    )
    +    summary_prompt = BLAST_SUMMARY_PROMPT.format(
    +        question=question,
    +        context=response_text,
    +    )
    +    output_parser = StrOutputParser()
    +    conversation = conversation_factory()
    +    chain = prompt | conversation.chat | output_parser
    +    answer = chain.invoke({"input": {summary_prompt}})
    +    return answer
     
    @@ -4589,7 +4206,11 @@

    158 159 160 -161

    class BlastQueryBuilder(BaseQueryBuilder):
    +161
    +162
    +163
    +164
    +165
    class BlastQueryBuilder(BaseQueryBuilder):
         """A class for building a BlastQuery object."""
     
         def create_runnable(
    @@ -4597,54 +4218,58 @@ 

    query_parameters: "BlastQueryParameters", conversation: "Conversation", ) -> Callable: - """ - Creates a runnable object for executing queries using the LangChain - `create_structured_output_runnable` method. - - Args: + """Creates a runnable object for executing queries using the LangChain + `create_structured_output_runnable` method. + + Args: + ---- query_parameters: A Pydantic data model that specifies the fields of the API that should be queried. conversation: A BioChatter conversation object. Returns: - A Callable object that can execute the query. - """ - return create_structured_output_runnable( - output_schema=query_parameters, - llm=conversation.chat, - prompt=self.structured_output_prompt, - ) - - def parameterise_query( - self, - question: str, - conversation: "Conversation", - ) -> BlastQueryParameters: - """ - Generates a BlastQuery object based on the given question, prompt, and - BioChatter conversation. Uses a Pydantic model to define the API fields. - Creates a runnable that can be invoked on LLMs that are qualified to - parameterise functions. - - Args: - question (str): The question to be answered. - - conversation: The conversation object used for parameterising the - BlastQuery. - - Returns: - BlastQuery: the parameterised query object (Pydantic model) - """ - runnable = self.create_runnable( - query_parameters=BlastQueryParameters, - conversation=conversation, - ) - blast_call_obj = runnable.invoke( - {"input": f"Answer:\n{question} based on:\n {BLAST_QUERY_PROMPT}"} - ) - blast_call_obj.question_uuid = str(uuid.uuid4()) - return blast_call_obj + ------- + A Callable object that can execute the query. + + """ + return create_structured_output_runnable( + output_schema=query_parameters, + llm=conversation.chat, + prompt=self.structured_output_prompt, + ) + + def parameterise_query( + self, + question: str, + conversation: "Conversation", + ) -> BlastQueryParameters: + """Generates a BlastQuery object based on the given question, prompt, and + BioChatter conversation. Uses a Pydantic model to define the API fields. + Creates a runnable that can be invoked on LLMs that are qualified to + parameterise functions. + + Args: + ---- + question (str): The question to be answered. + + conversation: The conversation object used for parameterising the + BlastQuery. + + Returns: + ------- + BlastQuery: the parameterised query object (Pydantic model) + + """ + runnable = self.create_runnable( + query_parameters=BlastQueryParameters, + conversation=conversation, + ) + blast_call_obj = runnable.invoke( + {"input": f"Answer:\n{question} based on:\n {BLAST_QUERY_PROMPT}"}, + ) + blast_call_obj.question_uuid = str(uuid.uuid4()) + return blast_call_obj

    @@ -4673,77 +4298,15 @@

    query_parameters: A Pydantic data model that specifies the fields of
    +    the API that should be queried.
     
    -
    -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - query_parameters - - BlastQueryParameters - -
    -

    A Pydantic data model that specifies the fields of -the API that should be queried.

    -
    -
    - required -
    - conversation - - Conversation - -
    -

    A BioChatter conversation object.

    -
    -
    - required -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    TypeDescription
    - Callable - -
    -

    A Callable object that can execute the query.

    -
    -
    +conversation: A BioChatter conversation object. +

    +
    +
    A Callable object that can execute the query.
    +
    Source code in biochatter/api_agent/blast.py @@ -4769,29 +4332,33 @@

    128 129 130 -131

    def create_runnable(
    +131
    +132
    +133
    def create_runnable(
         self,
         query_parameters: "BlastQueryParameters",
         conversation: "Conversation",
     ) -> Callable:
    -    """
    -    Creates a runnable object for executing queries using the LangChain
    -    `create_structured_output_runnable` method.
    -
    -    Args:
    +    """Creates a runnable object for executing queries using the LangChain
    +    `create_structured_output_runnable` method.
    +
    +    Args:
    +    ----
             query_parameters: A Pydantic data model that specifies the fields of
                 the API that should be queried.
     
             conversation: A BioChatter conversation object.
     
         Returns:
    -        A Callable object that can execute the query.
    -    """
    -    return create_structured_output_runnable(
    -        output_schema=query_parameters,
    -        llm=conversation.chat,
    -        prompt=self.structured_output_prompt,
    -    )
    +    -------
    +        A Callable object that can execute the query.
    +
    +    """
    +    return create_structured_output_runnable(
    +        output_schema=query_parameters,
    +        llm=conversation.chat,
    +        prompt=self.structured_output_prompt,
    +    )
     
    @@ -4813,83 +4380,19 @@

    question (str): The question to be answered.
     
    -
    -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - question - - str - -
    -

    The question to be answered.

    -
    -
    - required -
    - conversation - - Conversation - -
    -

    The conversation object used for parameterising the -BlastQuery.

    -
    -
    - required -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    Name TypeDescription
    BlastQuery - BlastQueryParameters - -
    -

    the parameterised query object (Pydantic model)

    -
    -
    +conversation: The conversation object used for parameterising the + BlastQuery. +
    +
    +
    BlastQuery: the parameterised query object (Pydantic model)
    +
    Source code in biochatter/api_agent/blast.py -
    133
    -134
    -135
    +              
    135
     136
     137
     138
    @@ -4915,35 +4418,41 @@ 

    158 159 160 -161

    def parameterise_query(
    -    self,
    -    question: str,
    -    conversation: "Conversation",
    -) -> BlastQueryParameters:
    -    """
    -    Generates a BlastQuery object based on the given question, prompt, and
    -    BioChatter conversation. Uses a Pydantic model to define the API fields.
    -    Creates a runnable that can be invoked on LLMs that are qualified to
    -    parameterise functions.
    -
    -    Args:
    -        question (str): The question to be answered.
    -
    -        conversation: The conversation object used for parameterising the
    -            BlastQuery.
    -
    -    Returns:
    -        BlastQuery: the parameterised query object (Pydantic model)
    -    """
    -    runnable = self.create_runnable(
    -        query_parameters=BlastQueryParameters,
    -        conversation=conversation,
    -    )
    -    blast_call_obj = runnable.invoke(
    -        {"input": f"Answer:\n{question} based on:\n {BLAST_QUERY_PROMPT}"}
    -    )
    -    blast_call_obj.question_uuid = str(uuid.uuid4())
    -    return blast_call_obj
    +161
    +162
    +163
    +164
    +165
    def parameterise_query(
    +    self,
    +    question: str,
    +    conversation: "Conversation",
    +) -> BlastQueryParameters:
    +    """Generates a BlastQuery object based on the given question, prompt, and
    +    BioChatter conversation. Uses a Pydantic model to define the API fields.
    +    Creates a runnable that can be invoked on LLMs that are qualified to
    +    parameterise functions.
    +
    +    Args:
    +    ----
    +        question (str): The question to be answered.
    +
    +        conversation: The conversation object used for parameterising the
    +            BlastQuery.
    +
    +    Returns:
    +    -------
    +        BlastQuery: the parameterised query object (Pydantic model)
    +
    +    """
    +    runnable = self.create_runnable(
    +        query_parameters=BlastQueryParameters,
    +        conversation=conversation,
    +    )
    +    blast_call_obj = runnable.invoke(
    +        {"input": f"Answer:\n{question} based on:\n {BLAST_QUERY_PROMPT}"},
    +    )
    +    blast_call_obj.question_uuid = str(uuid.uuid4())
    +    return blast_call_obj
     
    @@ -4971,7 +4480,7 @@

    - Bases: BaseModel

    + Bases: BaseModel

    BlastQuery is a Pydantic model for the parameters of a BLAST query request, @@ -5044,61 +4553,61 @@

    101 102 103

    class BlastQueryParameters(BaseModel):
    -    """
    -
    -    BlastQuery is a Pydantic model for the parameters of a BLAST query request,
    -    used for configuring and sending a request to the NCBI BLAST query API. The
    -    fields are dynamically configured by the LLM based on the user's question.
    +    """BlastQuery is a Pydantic model for the parameters of a BLAST query request,
    +    used for configuring and sending a request to the NCBI BLAST query API. The
    +    fields are dynamically configured by the LLM based on the user's question.
    +
    +    """
     
    -    """
    -
    -    url: Optional[str] = Field(
    -        default="https://blast.ncbi.nlm.nih.gov/Blast.cgi?",
    -        description="ALWAYS USE DEFAULT, DO NOT CHANGE",
    -    )
    -    cmd: Optional[str] = Field(
    -        default="Put",
    -        description="Command to execute, 'Put' for submitting query, 'Get' for retrieving results.",
    -    )
    -    program: Optional[str] = Field(
    -        default="blastn",
    -        description="BLAST program to use, e.g., 'blastn' for nucleotide-nucleotide BLAST, 'blastp' for protein-protein BLAST.",
    -    )
    -    database: Optional[str] = Field(
    -        default="nt",
    -        description="Database to search, e.g., 'nt' for nucleotide database, 'nr' for non redundant protein database, pdb the Protein Data Bank database, which is used specifically for protein structures, 'refseq_rna' and 'refseq_genomic': specialized databases for RNA sequences and genomic sequences",
    -    )
    -    query: Optional[str] = Field(
    -        None,
    -        description="Nucleotide or protein sequence for the BLAST or blat query, make sure to always keep the entire sequence given.",
    -    )
    -    format_type: Optional[str] = Field(
    -        default="Text",
    -        description="Format of the BLAST results, e.g., 'Text', 'XML'.",
    -    )
    -    rid: Optional[str] = Field(
    -        None, description="Request ID for retrieving BLAST results."
    -    )
    -    other_params: Optional[dict] = Field(
    -        default={"email": "user@example.com"},
    -        description="Other optional BLAST parameters, including user email.",
    -    )
    -    max_hits: Optional[int] = Field(
    -        default=15,
    -        description="Maximum number of hits to return in the BLAST results.",
    -    )
    -    sort_by: Optional[str] = Field(
    -        default="score",
    -        description="Criterion to sort BLAST results by, e.g., 'score', 'evalue'.",
    -    )
    -    megablast: Optional[str] = Field(
    -        default="on", description="Set to 'on' for human genome alignemnts"
    +    url: str | None = Field(
    +        default="https://blast.ncbi.nlm.nih.gov/Blast.cgi?",
    +        description="ALWAYS USE DEFAULT, DO NOT CHANGE",
    +    )
    +    cmd: str | None = Field(
    +        default="Put",
    +        description="Command to execute, 'Put' for submitting query, 'Get' for retrieving results.",
    +    )
    +    program: str | None = Field(
    +        default="blastn",
    +        description="BLAST program to use, e.g., 'blastn' for nucleotide-nucleotide BLAST, 'blastp' for protein-protein BLAST.",
    +    )
    +    database: str | None = Field(
    +        default="nt",
    +        description="Database to search, e.g., 'nt' for nucleotide database, 'nr' for non redundant protein database, pdb the Protein Data Bank database, which is used specifically for protein structures, 'refseq_rna' and 'refseq_genomic': specialized databases for RNA sequences and genomic sequences",
    +    )
    +    query: str | None = Field(
    +        None,
    +        description="Nucleotide or protein sequence for the BLAST or blat query, make sure to always keep the entire sequence given.",
    +    )
    +    format_type: str | None = Field(
    +        default="Text",
    +        description="Format of the BLAST results, e.g., 'Text', 'XML'.",
    +    )
    +    rid: str | None = Field(
    +        None,
    +        description="Request ID for retrieving BLAST results.",
    +    )
    +    other_params: dict | None = Field(
    +        default={"email": "user@example.com"},
    +        description="Other optional BLAST parameters, including user email.",
    +    )
    +    max_hits: int | None = Field(
    +        default=15,
    +        description="Maximum number of hits to return in the BLAST results.",
    +    )
    +    sort_by: str | None = Field(
    +        default="score",
    +        description="Criterion to sort BLAST results by, e.g., 'score', 'evalue'.",
    +    )
    +    megablast: str | None = Field(
    +        default="on",
    +        description="Set to 'on' for human genome alignemnts",
         )
    -    question_uuid: Optional[str] = Field(
    +    question_uuid: str | None = Field(
             default_factory=lambda: str(uuid.uuid4()),
             description="Unique identifier for the question.",
         )
    -    full_url: Optional[str] = Field(
    +    full_url: str | None = Field(
             default="TBF",
             description="Full URL to be used to submit the BLAST query",
         )
    @@ -5184,13 +4693,7 @@ 

    Source code in biochatter/api_agent/oncokb.py -
    248
    -249
    -250
    -251
    -252
    -253
    -254
    +                
    254
     255
     256
     257
    @@ -5223,46 +4726,60 @@ 

    284 285 286 -287

    class OncoKBFetcher(BaseFetcher):
    -    """
    -    A class for retrieving API results from OncoKB given a parameterized
    -    OncoKBQuery.
    -    """
    -
    -    def __init__(self, api_token="demo"):
    -        self.headers = {
    -            "Authorization": f"Bearer {api_token}",
    -            "Accept": "application/json",
    -        }
    -        self.base_url = "https://demo.oncokb.org/api/v1"
    -
    -    def fetch_results(
    -        self, request_data: OncoKBQueryParameters, retries: Optional[int] = 3
    -    ) -> str:
    -        """Function to submit the OncoKB query and fetch the results directly.
    -        No multi-step procedure, thus no wrapping of submission and retrieval in
    -        this case.
    -
    -        Args:
    -            request_data: OncoKBQuery object (Pydantic model) containing the
    -                OncoKB query parameters.
    -
    -        Returns:
    -            str: The results of the OncoKB query.
    -        """
    -        # Submit the query and get the URL
    -        params = request_data.dict(exclude_unset=True)
    -        endpoint = params.pop("endpoint")
    -        params.pop("question_uuid")
    -        full_url = f"{self.base_url}/{endpoint}"
    -        response = requests.get(full_url, headers=self.headers, params=params)
    -        response.raise_for_status()
    -
    -        # Fetch the results from the URL
    -        results_response = requests.get(response.url, headers=self.headers)
    -        results_response.raise_for_status()
    -
    -        return results_response.text
    +287
    +288
    +289
    +290
    +291
    +292
    +293
    +294
    +295
    +296
    +297
    class OncoKBFetcher(BaseFetcher):
    +    """A class for retrieving API results from OncoKB given a parameterized
    +    OncoKBQuery.
    +    """
    +
    +    def __init__(self, api_token="demo"):
    +        self.headers = {
    +            "Authorization": f"Bearer {api_token}",
    +            "Accept": "application/json",
    +        }
    +        self.base_url = "https://demo.oncokb.org/api/v1"
    +
    +    def fetch_results(
    +        self,
    +        request_data: OncoKBQueryParameters,
    +        retries: int | None = 3,
    +    ) -> str:
    +        """Function to submit the OncoKB query and fetch the results directly.
    +        No multi-step procedure, thus no wrapping of submission and retrieval in
    +        this case.
    +
    +        Args:
    +        ----
    +            request_data: OncoKBQuery object (Pydantic model) containing the
    +                OncoKB query parameters.
    +
    +        Returns:
    +        -------
    +            str: The results of the OncoKB query.
    +
    +        """
    +        # Submit the query and get the URL
    +        params = request_data.dict(exclude_unset=True)
    +        endpoint = params.pop("endpoint")
    +        params.pop("question_uuid")
    +        full_url = f"{self.base_url}/{endpoint}"
    +        response = requests.get(full_url, headers=self.headers, params=params)
    +        response.raise_for_status()
    +
    +        # Fetch the results from the URL
    +        results_response = requests.get(response.url, headers=self.headers)
    +        results_response.raise_for_status()
    +
    +        return results_response.text
     
    @@ -5292,70 +4809,17 @@

    Parameters:

    - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - request_data - - OncoKBQueryParameters - -
    -

    OncoKBQuery object (Pydantic model) containing the -OncoKB query parameters.

    -
    -
    - required -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    Name TypeDescription
    str - str - -
    -

    The results of the OncoKB query.

    -
    -
    +
    +
    request_data: OncoKBQuery object (Pydantic model) containing the
    +    OncoKB query parameters.
    +
    +
    +
    str: The results of the OncoKB query.
    +
    Source code in biochatter/api_agent/oncokb.py -
    261
    -262
    -263
    -264
    -265
    -266
    +              
    266
     267
     268
     269
    @@ -5376,33 +4840,48 @@ 

    284 285 286 -287

    def fetch_results(
    -    self, request_data: OncoKBQueryParameters, retries: Optional[int] = 3
    -) -> str:
    -    """Function to submit the OncoKB query and fetch the results directly.
    -    No multi-step procedure, thus no wrapping of submission and retrieval in
    -    this case.
    -
    -    Args:
    -        request_data: OncoKBQuery object (Pydantic model) containing the
    -            OncoKB query parameters.
    -
    -    Returns:
    -        str: The results of the OncoKB query.
    -    """
    -    # Submit the query and get the URL
    -    params = request_data.dict(exclude_unset=True)
    -    endpoint = params.pop("endpoint")
    -    params.pop("question_uuid")
    -    full_url = f"{self.base_url}/{endpoint}"
    -    response = requests.get(full_url, headers=self.headers, params=params)
    -    response.raise_for_status()
    -
    -    # Fetch the results from the URL
    -    results_response = requests.get(response.url, headers=self.headers)
    -    results_response.raise_for_status()
    -
    -    return results_response.text
    +287
    +288
    +289
    +290
    +291
    +292
    +293
    +294
    +295
    +296
    +297
    def fetch_results(
    +    self,
    +    request_data: OncoKBQueryParameters,
    +    retries: int | None = 3,
    +) -> str:
    +    """Function to submit the OncoKB query and fetch the results directly.
    +    No multi-step procedure, thus no wrapping of submission and retrieval in
    +    this case.
    +
    +    Args:
    +    ----
    +        request_data: OncoKBQuery object (Pydantic model) containing the
    +            OncoKB query parameters.
    +
    +    Returns:
    +    -------
    +        str: The results of the OncoKB query.
    +
    +    """
    +    # Submit the query and get the URL
    +    params = request_data.dict(exclude_unset=True)
    +    endpoint = params.pop("endpoint")
    +    params.pop("question_uuid")
    +    full_url = f"{self.base_url}/{endpoint}"
    +    response = requests.get(full_url, headers=self.headers, params=params)
    +    response.raise_for_status()
    +
    +    # Fetch the results from the URL
    +    results_response = requests.get(response.url, headers=self.headers)
    +    results_response.raise_for_status()
    +
    +    return results_response.text
     
    @@ -5440,17 +4919,7 @@

    Source code in biochatter/api_agent/oncokb.py -
    290
    -291
    -292
    -293
    -294
    -295
    -296
    -297
    -298
    -299
    -300
    +                
    300
     301
     302
     303
    @@ -5478,45 +4947,59 @@ 

    325 326 327 -328

    class OncoKBInterpreter(BaseInterpreter):
    -    def summarise_results(
    -        self,
    -        question: str,
    -        conversation_factory: Callable,
    -        response_text: str,
    -    ) -> str:
    -        """
    -        Function to extract the answer from the BLAST results.
    -
    -        Args:
    -            question (str): The question to be answered.
    -            conversation_factory: A BioChatter conversation object.
    -            response_text (str): The response.text returned by OncoKB.
    -
    -        Returns:
    -            str: The extracted answer from the BLAST results.
    -
    -        """
    -        prompt = ChatPromptTemplate.from_messages(
    -            [
    -                (
    -                    "system",
    -                    "You are a world class molecular biologist who knows "
    -                    "everything about OncoKB and cancer genomics. Your task is "
    -                    "to interpret results from OncoKB API calls and summarise "
    -                    "them for the user.",
    -                ),
    -                ("user", "{input}"),
    -            ]
    -        )
    -        summary_prompt = ONCOKB_SUMMARY_PROMPT.format(
    -            question=question, context=response_text
    -        )
    -        output_parser = StrOutputParser()
    -        conversation = conversation_factory()
    -        chain = prompt | conversation.chat | output_parser
    -        answer = chain.invoke({"input": {summary_prompt}})
    -        return answer
    +328
    +329
    +330
    +331
    +332
    +333
    +334
    +335
    +336
    +337
    +338
    +339
    +340
    class OncoKBInterpreter(BaseInterpreter):
    +    def summarise_results(
    +        self,
    +        question: str,
    +        conversation_factory: Callable,
    +        response_text: str,
    +    ) -> str:
    +        """Function to extract the answer from the BLAST results.
    +
    +        Args:
    +        ----
    +            question (str): The question to be answered.
    +            conversation_factory: A BioChatter conversation object.
    +            response_text (str): The response.text returned by OncoKB.
    +
    +        Returns:
    +        -------
    +            str: The extracted answer from the BLAST results.
    +
    +        """
    +        prompt = ChatPromptTemplate.from_messages(
    +            [
    +                (
    +                    "system",
    +                    "You are a world class molecular biologist who knows "
    +                    "everything about OncoKB and cancer genomics. Your task is "
    +                    "to interpret results from OncoKB API calls and summarise "
    +                    "them for the user.",
    +                ),
    +                ("user", "{input}"),
    +            ],
    +        )
    +        summary_prompt = ONCOKB_SUMMARY_PROMPT.format(
    +            question=question,
    +            context=response_text,
    +        )
    +        output_parser = StrOutputParser()
    +        conversation = conversation_factory()
    +        chain = prompt | conversation.chat | output_parser
    +        answer = chain.invoke({"input": {summary_prompt}})
    +        return answer
     
    @@ -5544,106 +5027,18 @@

    Function to extract the answer from the BLAST results.

    - - -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - question - - str - -
    -

    The question to be answered.

    -
    -
    - required -
    - conversation_factory - - Callable - -
    -

    A BioChatter conversation object.

    -
    -
    - required -
    - response_text - - str - -
    -

    The response.text returned by OncoKB.

    -
    -
    - required -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    Name TypeDescription
    str - str - -
    -

    The extracted answer from the BLAST results.

    -
    -
    +
    +
    question (str): The question to be answered.
    +conversation_factory: A BioChatter conversation object.
    +response_text (str): The response.text returned by OncoKB.
    +
    +
    +
    str: The extracted answer from the BLAST results.
    +
    Source code in biochatter/api_agent/oncokb.py -
    291
    -292
    -293
    -294
    -295
    -296
    -297
    -298
    -299
    -300
    -301
    +              
    301
     302
     303
     304
    @@ -5670,44 +5065,58 @@ 

    325 326 327 -328

    def summarise_results(
    -    self,
    -    question: str,
    -    conversation_factory: Callable,
    -    response_text: str,
    -) -> str:
    -    """
    -    Function to extract the answer from the BLAST results.
    -
    -    Args:
    -        question (str): The question to be answered.
    -        conversation_factory: A BioChatter conversation object.
    -        response_text (str): The response.text returned by OncoKB.
    -
    -    Returns:
    -        str: The extracted answer from the BLAST results.
    -
    -    """
    -    prompt = ChatPromptTemplate.from_messages(
    -        [
    -            (
    -                "system",
    -                "You are a world class molecular biologist who knows "
    -                "everything about OncoKB and cancer genomics. Your task is "
    -                "to interpret results from OncoKB API calls and summarise "
    -                "them for the user.",
    -            ),
    -            ("user", "{input}"),
    -        ]
    -    )
    -    summary_prompt = ONCOKB_SUMMARY_PROMPT.format(
    -        question=question, context=response_text
    -    )
    -    output_parser = StrOutputParser()
    -    conversation = conversation_factory()
    -    chain = prompt | conversation.chat | output_parser
    -    answer = chain.invoke({"input": {summary_prompt}})
    -    return answer
    +328
    +329
    +330
    +331
    +332
    +333
    +334
    +335
    +336
    +337
    +338
    +339
    +340
    def summarise_results(
    +    self,
    +    question: str,
    +    conversation_factory: Callable,
    +    response_text: str,
    +) -> str:
    +    """Function to extract the answer from the BLAST results.
    +
    +    Args:
    +    ----
    +        question (str): The question to be answered.
    +        conversation_factory: A BioChatter conversation object.
    +        response_text (str): The response.text returned by OncoKB.
    +
    +    Returns:
    +    -------
    +        str: The extracted answer from the BLAST results.
    +
    +    """
    +    prompt = ChatPromptTemplate.from_messages(
    +        [
    +            (
    +                "system",
    +                "You are a world class molecular biologist who knows "
    +                "everything about OncoKB and cancer genomics. Your task is "
    +                "to interpret results from OncoKB API calls and summarise "
    +                "them for the user.",
    +            ),
    +            ("user", "{input}"),
    +        ],
    +    )
    +    summary_prompt = ONCOKB_SUMMARY_PROMPT.format(
    +        question=question,
    +        context=response_text,
    +    )
    +    output_parser = StrOutputParser()
    +    conversation = conversation_factory()
    +    chain = prompt | conversation.chat | output_parser
    +    answer = chain.invoke({"input": {summary_prompt}})
    +    return answer
     
    @@ -5747,9 +5156,7 @@

    Source code in biochatter/api_agent/oncokb.py -
    190
    -191
    -192
    +                
    192
     193
     194
     195
    @@ -5802,62 +5209,72 @@ 

    242 243 244 -245

    class OncoKBQueryBuilder(BaseQueryBuilder):
    -    """A class for building an OncoKBQuery object."""
    -
    -    def create_runnable(
    -        self,
    -        query_parameters: "OncoKBQueryParameters",
    -        conversation: "Conversation",
    -    ) -> Callable:
    -        """
    -        Creates a runnable object for executing queries using the LangChain
    -        `create_structured_output_runnable` method.
    -
    -        Args:
    -            query_parameters: A Pydantic data model that specifies the fields of
    -                the API that should be queried.
    -
    -            conversation: A BioChatter conversation object.
    +245
    +246
    +247
    +248
    +249
    +250
    +251
    class OncoKBQueryBuilder(BaseQueryBuilder):
    +    """A class for building an OncoKBQuery object."""
    +
    +    def create_runnable(
    +        self,
    +        query_parameters: "OncoKBQueryParameters",
    +        conversation: "Conversation",
    +    ) -> Callable:
    +        """Creates a runnable object for executing queries using the LangChain
    +        `create_structured_output_runnable` method.
    +
    +        Args:
    +        ----
    +            query_parameters: A Pydantic data model that specifies the fields of
    +                the API that should be queried.
     
    -        Returns:
    -            A Callable object that can execute the query.
    -        """
    -        return create_structured_output_runnable(
    -            output_schema=query_parameters,
    -            llm=conversation.chat,
    -            prompt=self.structured_output_prompt,
    -        )
    -
    -    def parameterise_query(
    -        self,
    -        question: str,
    -        conversation: "Conversation",
    -    ) -> OncoKBQueryParameters:
    -        """
    -        Generates an OncoKBQuery object based on the given question, prompt, and
    -        BioChatter conversation. Uses a Pydantic model to define the API fields.
    -        Creates a runnable that can be invoked on LLMs that are qualified to
    -        parameterise functions.
    -
    -        Args:
    -            question (str): The question to be answered.
    +            conversation: A BioChatter conversation object.
    +
    +        Returns:
    +        -------
    +            A Callable object that can execute the query.
    +
    +        """
    +        return create_structured_output_runnable(
    +            output_schema=query_parameters,
    +            llm=conversation.chat,
    +            prompt=self.structured_output_prompt,
    +        )
    +
    +    def parameterise_query(
    +        self,
    +        question: str,
    +        conversation: "Conversation",
    +    ) -> OncoKBQueryParameters:
    +        """Generates an OncoKBQuery object based on the given question, prompt, and
    +        BioChatter conversation. Uses a Pydantic model to define the API fields.
    +        Creates a runnable that can be invoked on LLMs that are qualified to
    +        parameterise functions.
     
    -            conversation: The conversation object used for parameterising the
    -                OncoKBQuery.
    -
    -        Returns:
    -            OncoKBQueryParameters: the parameterised query object (Pydantic model)
    -        """
    -        runnable = self.create_runnable(
    -            query_parameters=OncoKBQueryParameters,
    -            conversation=conversation,
    -        )
    -        oncokb_call_obj = runnable.invoke(
    -            {"input": f"Answer:\n{question} based on:\n {ONCOKB_QUERY_PROMPT}"}
    -        )
    -        oncokb_call_obj.question_uuid = str(uuid.uuid4())
    -        return oncokb_call_obj
    +        Args:
    +        ----
    +            question (str): The question to be answered.
    +
    +            conversation: The conversation object used for parameterising the
    +                OncoKBQuery.
    +
    +        Returns:
    +        -------
    +            OncoKBQueryParameters: the parameterised query object (Pydantic model)
    +
    +        """
    +        runnable = self.create_runnable(
    +            query_parameters=OncoKBQueryParameters,
    +            conversation=conversation,
    +        )
    +        oncokb_call_obj = runnable.invoke(
    +            {"input": f"Answer:\n{question} based on:\n {ONCOKB_QUERY_PROMPT}"},
    +        )
    +        oncokb_call_obj.question_uuid = str(uuid.uuid4())
    +        return oncokb_call_obj
     
    @@ -5886,83 +5303,19 @@

    query_parameters: A Pydantic data model that specifies the fields of
    +    the API that should be queried.
     
    -
    -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - query_parameters - - OncoKBQueryParameters - -
    -

    A Pydantic data model that specifies the fields of -the API that should be queried.

    -
    -
    - required -
    - conversation - - Conversation - -
    -

    A BioChatter conversation object.

    -
    -
    - required -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    TypeDescription
    - Callable - -
    -

    A Callable object that can execute the query.

    -
    -
    +conversation: A BioChatter conversation object. +

    +
    +
    A Callable object that can execute the query.
    +
    Source code in biochatter/api_agent/oncokb.py -
    193
    -194
    -195
    +              
    195
     196
     197
     198
    @@ -5982,29 +5335,35 @@ 

    212 213 214 -215

    def create_runnable(
    -    self,
    -    query_parameters: "OncoKBQueryParameters",
    -    conversation: "Conversation",
    -) -> Callable:
    -    """
    -    Creates a runnable object for executing queries using the LangChain
    -    `create_structured_output_runnable` method.
    -
    -    Args:
    -        query_parameters: A Pydantic data model that specifies the fields of
    -            the API that should be queried.
    -
    -        conversation: A BioChatter conversation object.
    +215
    +216
    +217
    +218
    +219
    def create_runnable(
    +    self,
    +    query_parameters: "OncoKBQueryParameters",
    +    conversation: "Conversation",
    +) -> Callable:
    +    """Creates a runnable object for executing queries using the LangChain
    +    `create_structured_output_runnable` method.
    +
    +    Args:
    +    ----
    +        query_parameters: A Pydantic data model that specifies the fields of
    +            the API that should be queried.
     
    -    Returns:
    -        A Callable object that can execute the query.
    -    """
    -    return create_structured_output_runnable(
    -        output_schema=query_parameters,
    -        llm=conversation.chat,
    -        prompt=self.structured_output_prompt,
    -    )
    +        conversation: A BioChatter conversation object.
    +
    +    Returns:
    +    -------
    +        A Callable object that can execute the query.
    +
    +    """
    +    return create_structured_output_runnable(
    +        output_schema=query_parameters,
    +        llm=conversation.chat,
    +        prompt=self.structured_output_prompt,
    +    )
     
    @@ -6026,85 +5385,19 @@

    +
    +
    question (str): The question to be answered.
     
    -
    -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - question - - str - -
    -

    The question to be answered.

    -
    -
    - required -
    - conversation - - Conversation - -
    -

    The conversation object used for parameterising the -OncoKBQuery.

    -
    -
    - required -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    Name TypeDescription
    OncoKBQueryParameters - OncoKBQueryParameters - -
    -

    the parameterised query object (Pydantic model)

    -
    -
    +conversation: The conversation object used for parameterising the + OncoKBQuery. +
    +
    +
    OncoKBQueryParameters: the parameterised query object (Pydantic model)
    +
    Source code in biochatter/api_agent/oncokb.py -
    217
    -218
    -219
    -220
    -221
    +              
    221
     222
     223
     224
    @@ -6128,35 +5421,43 @@ 

    242 243 244 -245

    def parameterise_query(
    -    self,
    -    question: str,
    -    conversation: "Conversation",
    -) -> OncoKBQueryParameters:
    -    """
    -    Generates an OncoKBQuery object based on the given question, prompt, and
    -    BioChatter conversation. Uses a Pydantic model to define the API fields.
    -    Creates a runnable that can be invoked on LLMs that are qualified to
    -    parameterise functions.
    -
    -    Args:
    -        question (str): The question to be answered.
    +245
    +246
    +247
    +248
    +249
    +250
    +251
    def parameterise_query(
    +    self,
    +    question: str,
    +    conversation: "Conversation",
    +) -> OncoKBQueryParameters:
    +    """Generates an OncoKBQuery object based on the given question, prompt, and
    +    BioChatter conversation. Uses a Pydantic model to define the API fields.
    +    Creates a runnable that can be invoked on LLMs that are qualified to
    +    parameterise functions.
     
    -        conversation: The conversation object used for parameterising the
    -            OncoKBQuery.
    -
    -    Returns:
    -        OncoKBQueryParameters: the parameterised query object (Pydantic model)
    -    """
    -    runnable = self.create_runnable(
    -        query_parameters=OncoKBQueryParameters,
    -        conversation=conversation,
    -    )
    -    oncokb_call_obj = runnable.invoke(
    -        {"input": f"Answer:\n{question} based on:\n {ONCOKB_QUERY_PROMPT}"}
    -    )
    -    oncokb_call_obj.question_uuid = str(uuid.uuid4())
    -    return oncokb_call_obj
    +    Args:
    +    ----
    +        question (str): The question to be answered.
    +
    +        conversation: The conversation object used for parameterising the
    +            OncoKBQuery.
    +
    +    Returns:
    +    -------
    +        OncoKBQueryParameters: the parameterised query object (Pydantic model)
    +
    +    """
    +    runnable = self.create_runnable(
    +        query_parameters=OncoKBQueryParameters,
    +        conversation=conversation,
    +    )
    +    oncokb_call_obj = runnable.invoke(
    +        {"input": f"Answer:\n{question} based on:\n {ONCOKB_QUERY_PROMPT}"},
    +    )
    +    oncokb_call_obj.question_uuid = str(uuid.uuid4())
    +    return oncokb_call_obj
     
    diff --git a/api-docs/index.html b/api-docs/index.html index a309fa73..71148ff6 100644 --- a/api-docs/index.html +++ b/api-docs/index.html @@ -18,7 +18,7 @@ - + diff --git a/api-docs/kg/index.html b/api-docs/kg/index.html index 06871bc5..f089303f 100644 --- a/api-docs/kg/index.html +++ b/api-docs/kg/index.html @@ -18,7 +18,7 @@ - + @@ -1493,7 +1493,8 @@

    Source code in biochatter/prompts.py -
     12
    +                
     11
    + 12
      13
      14
      15
    @@ -2122,666 +2123,637 @@ 

    638 639 640 -641 -642 -643 -644 -645 -646 -647 -648 -649 -650 -651 -652 -653 -654 -655 -656

    class BioCypherPromptEngine:
    -    def __init__(
    -        self,
    -        schema_config_or_info_path: Optional[str] = None,
    -        schema_config_or_info_dict: Optional[dict] = None,
    -        model_name: str = "gpt-3.5-turbo",
    -        conversation_factory: Optional[Callable] = None,
    -    ) -> None:
    -        """
    -
    -        Given a biocypher schema configuration, extract the entities and
    -        relationships, and for each extract their mode of representation (node
    -        or edge), properties, and identifier namespace. Using these data, allow
    -        the generation of prompts for a large language model, informing it of
    -        the schema constituents and their properties, to enable the
    -        parameterisation of function calls to a knowledge graph.
    -
    -        Args:
    -            schema_config_or_info_path: Path to a biocypher schema configuration
    -                file or the extended schema information output generated by
    -                BioCypher's `write_schema_info` function (preferred).
    -
    -            schema_config_or_info_dict: A dictionary containing the schema
    -                configuration file or the extended schema information output
    -                generated by BioCypher's `write_schema_info` function
    -                (preferred).
    -
    -            model_name: The name of the model to use for the conversation.
    -                DEPRECATED: This should now be set in the conversation factory.
    -
    -            conversation_factory: A function used to create a conversation for
    -                creating the KG query. If not provided, a default function is
    -                used (creating an OpenAI conversation with the specified model,
    -                see `_get_conversation`).
    -        """
    -
    -        if not schema_config_or_info_path and not schema_config_or_info_dict:
    -            raise ValueError(
    -                "Please provide the schema configuration or schema info as a "
    -                "path to a file or as a dictionary."
    -            )
    -
    -        if schema_config_or_info_path and schema_config_or_info_dict:
    -            raise ValueError(
    -                "Please provide the schema configuration or schema info as a "
    -                "path to a file or as a dictionary, not both."
    -            )
    +641
    class BioCypherPromptEngine:
    +    def __init__(
    +        self,
    +        schema_config_or_info_path: str | None = None,
    +        schema_config_or_info_dict: dict | None = None,
    +        model_name: str = "gpt-3.5-turbo",
    +        conversation_factory: Callable | None = None,
    +    ) -> None:
    +        """Given a biocypher schema configuration, extract the entities and
    +        relationships, and for each extract their mode of representation (node
    +        or edge), properties, and identifier namespace. Using these data, allow
    +        the generation of prompts for a large language model, informing it of
    +        the schema constituents and their properties, to enable the
    +        parameterisation of function calls to a knowledge graph.
    +
    +        Args:
    +        ----
    +            schema_config_or_info_path: Path to a biocypher schema configuration
    +                file or the extended schema information output generated by
    +                BioCypher's `write_schema_info` function (preferred).
    +
    +            schema_config_or_info_dict: A dictionary containing the schema
    +                configuration file or the extended schema information output
    +                generated by BioCypher's `write_schema_info` function
    +                (preferred).
    +
    +            model_name: The name of the model to use for the conversation.
    +                DEPRECATED: This should now be set in the conversation factory.
    +
    +            conversation_factory: A function used to create a conversation for
    +                creating the KG query. If not provided, a default function is
    +                used (creating an OpenAI conversation with the specified model,
    +                see `_get_conversation`).
    +
    +        """
    +        if not schema_config_or_info_path and not schema_config_or_info_dict:
    +            raise ValueError(
    +                "Please provide the schema configuration or schema info as a path to a file or as a dictionary.",
    +            )
    +
    +        if schema_config_or_info_path and schema_config_or_info_dict:
    +            raise ValueError(
    +                "Please provide the schema configuration or schema info as a "
    +                "path to a file or as a dictionary, not both.",
    +            )
    +
    +        # set conversation factory or use default
    +        self.conversation_factory = conversation_factory if conversation_factory is not None else self._get_conversation
     
    -        # set conversation factory or use default
    -        self.conversation_factory = (
    -            conversation_factory
    -            if conversation_factory is not None
    -            else self._get_conversation
    -        )
    +        if schema_config_or_info_path:
    +            # read the schema configuration
    +            with open(schema_config_or_info_path) as f:
    +                schema_config = yaml.safe_load(f)
    +        elif schema_config_or_info_dict:
    +            schema_config = schema_config_or_info_dict
     
    -        if schema_config_or_info_path:
    -            # read the schema configuration
    -            with open(schema_config_or_info_path, "r") as f:
    -                schema_config = yaml.safe_load(f)
    -        elif schema_config_or_info_dict:
    -            schema_config = schema_config_or_info_dict
    -
    -        # check whether it is the original schema config or the output of
    -        # biocypher info
    -        is_schema_info = schema_config.get("is_schema_info", False)
    -
    -        # extract the entities and relationships: each top level key that has
    -        # a 'represented_as' key
    -        self.entities = {}
    -        self.relationships = {}
    -        if not is_schema_info:
    -            for key, value in schema_config.items():
    -                # hacky, better with biocypher output
    -                name_indicates_relationship = (
    -                    "interaction" in key.lower() or "association" in key.lower()
    -                )
    -                if "represented_as" in value:
    -                    if (
    -                        value["represented_as"] == "node"
    -                        and not name_indicates_relationship
    -                    ):
    -                        self.entities[sentencecase_to_pascalcase(key)] = value
    -                    elif (
    -                        value["represented_as"] == "node"
    -                        and name_indicates_relationship
    -                    ):
    -                        self.relationships[sentencecase_to_pascalcase(key)] = (
    -                            value
    -                        )
    -                    elif value["represented_as"] == "edge":
    -                        self.relationships[sentencecase_to_pascalcase(key)] = (
    -                            value
    -                        )
    -        else:
    -            for key, value in schema_config.items():
    -                if not isinstance(value, dict):
    -                    continue
    -                if value.get("present_in_knowledge_graph", None) == False:
    -                    continue
    -                if value.get("is_relationship", None) == False:
    -                    self.entities[sentencecase_to_pascalcase(key)] = value
    -                elif value.get("is_relationship", None) == True:
    -                    value = self._capitalise_source_and_target(value)
    -                    self.relationships[sentencecase_to_pascalcase(key)] = value
    -
    -        self.question = ""
    -        self.selected_entities = []
    -        self.selected_relationships = []  # used in property selection
    -        self.selected_relationship_labels = {}  # copy to deal with labels that
    -        # are not the same as the relationship name, used in query generation
    -        # dictionary to also include source and target types
    -        self.rel_directions = {}
    -        self.model_name = model_name
    -
    -    def _capitalise_source_and_target(self, relationship: dict) -> dict:
    -        """
    -        Make sources and targets PascalCase to match the entities. Sources and
    -        targets can be strings or lists of strings.
    -        """
    -        if "source" in relationship:
    -            if isinstance(relationship["source"], str):
    -                relationship["source"] = sentencecase_to_pascalcase(
    -                    relationship["source"]
    -                )
    -            elif isinstance(relationship["source"], list):
    -                relationship["source"] = [
    -                    sentencecase_to_pascalcase(s)
    -                    for s in relationship["source"]
    -                ]
    -        if "target" in relationship:
    -            if isinstance(relationship["target"], str):
    -                relationship["target"] = sentencecase_to_pascalcase(
    -                    relationship["target"]
    -                )
    -            elif isinstance(relationship["target"], list):
    -                relationship["target"] = [
    -                    sentencecase_to_pascalcase(t)
    -                    for t in relationship["target"]
    -                ]
    -        return relationship
    -
    -    def _select_graph_entities_from_question(
    -        self, question: str, conversation: Conversation
    -    ) -> str:
    -        conversation.reset()
    -        success1 = self._select_entities(
    -            question=question, conversation=conversation
    -        )
    -        if not success1:
    -            raise ValueError(
    -                "Entity selection failed. Please try again with a different "
    -                "question."
    -            )
    -        conversation.reset()
    -        success2 = self._select_relationships(conversation=conversation)
    -        if not success2:
    -            raise ValueError(
    -                "Relationship selection failed. Please try again with a "
    -                "different question."
    -            )
    -        conversation.reset()
    -        success3 = self._select_properties(conversation=conversation)
    -        if not success3:
    -            raise ValueError(
    -                "Property selection failed. Please try again with a different "
    -                "question."
    -            )
    +        # check whether it is the original schema config or the output of
    +        # biocypher info
    +        is_schema_info = schema_config.get("is_schema_info", False)
    +
    +        # extract the entities and relationships: each top level key that has
    +        # a 'represented_as' key
    +        self.entities = {}
    +        self.relationships = {}
    +        if not is_schema_info:
    +            for key, value in schema_config.items():
    +                # hacky, better with biocypher output
    +                name_indicates_relationship = "interaction" in key.lower() or "association" in key.lower()
    +                if "represented_as" in value:
    +                    if value["represented_as"] == "node" and not name_indicates_relationship:
    +                        self.entities[sentencecase_to_pascalcase(key)] = value
    +                    elif (value["represented_as"] == "node" and name_indicates_relationship) or value[
    +                        "represented_as"
    +                    ] == "edge":
    +                        self.relationships[sentencecase_to_pascalcase(key)] = value
    +        else:
    +            for key, value in schema_config.items():
    +                if not isinstance(value, dict):
    +                    continue
    +                if value.get("present_in_knowledge_graph", None) == False:
    +                    continue
    +                if value.get("is_relationship", None) == False:
    +                    self.entities[sentencecase_to_pascalcase(key)] = value
    +                elif value.get("is_relationship", None) == True:
    +                    value = self._capitalise_source_and_target(value)
    +                    self.relationships[sentencecase_to_pascalcase(key)] = value
    +
    +        self.question = ""
    +        self.selected_entities = []
    +        self.selected_relationships = []  # used in property selection
    +        self.selected_relationship_labels = {}  # copy to deal with labels that
    +        # are not the same as the relationship name, used in query generation
    +        # dictionary to also include source and target types
    +        self.rel_directions = {}
    +        self.model_name = model_name
    +
    +    def _capitalise_source_and_target(self, relationship: dict) -> dict:
    +        """Make sources and targets PascalCase to match the entities. Sources and
    +        targets can be strings or lists of strings.
    +        """
    +        if "source" in relationship:
    +            if isinstance(relationship["source"], str):
    +                relationship["source"] = sentencecase_to_pascalcase(
    +                    relationship["source"],
    +                )
    +            elif isinstance(relationship["source"], list):
    +                relationship["source"] = [sentencecase_to_pascalcase(s) for s in relationship["source"]]
    +        if "target" in relationship:
    +            if isinstance(relationship["target"], str):
    +                relationship["target"] = sentencecase_to_pascalcase(
    +                    relationship["target"],
    +                )
    +            elif isinstance(relationship["target"], list):
    +                relationship["target"] = [sentencecase_to_pascalcase(t) for t in relationship["target"]]
    +        return relationship
    +
    +    def _select_graph_entities_from_question(
    +        self,
    +        question: str,
    +        conversation: Conversation,
    +    ) -> str:
    +        conversation.reset()
    +        success1 = self._select_entities(
    +            question=question,
    +            conversation=conversation,
    +        )
    +        if not success1:
    +            raise ValueError(
    +                "Entity selection failed. Please try again with a different question.",
    +            )
    +        conversation.reset()
    +        success2 = self._select_relationships(conversation=conversation)
    +        if not success2:
    +            raise ValueError(
    +                "Relationship selection failed. Please try again with a different question.",
    +            )
    +        conversation.reset()
    +        success3 = self._select_properties(conversation=conversation)
    +        if not success3:
    +            raise ValueError(
    +                "Property selection failed. Please try again with a different question.",
    +            )
    +
    +    def _generate_query_prompt(
    +        self,
    +        entities: list,
    +        relationships: dict,
    +        properties: dict,
    +        query_language: str | None = "Cypher",
    +    ) -> str:
    +        """Generate a prompt for a large language model to generate a database
    +        query based on the selected entities, relationships, and properties.
    +
    +        Args:
    +        ----
    +            entities: A list of entities that are relevant to the question.
    +
    +            relationships: A list of relationships that are relevant to the
    +                question.
    +
    +            properties: A dictionary of properties that are relevant to the
    +                question.
    +
    +            query_language: The language of the query to generate.
    +
    +        Returns:
    +        -------
    +            A prompt for a large language model to generate a database query.
     
    -    def _generate_query_prompt(
    -        self,
    -        entities: list,
    -        relationships: dict,
    -        properties: dict,
    -        query_language: Optional[str] = "Cypher",
    -    ) -> str:
    -        """
    -        Generate a prompt for a large language model to generate a database
    -        query based on the selected entities, relationships, and properties.
    -
    -        Args:
    -            entities: A list of entities that are relevant to the question.
    -
    -            relationships: A list of relationships that are relevant to the
    -                question.
    -
    -            properties: A dictionary of properties that are relevant to the
    -                question.
    -
    -            query_language: The language of the query to generate.
    +        """
    +        msg = (
    +            f"Generate a database query in {query_language} that answers "
    +            f"the user's question. "
    +            f"You can use the following entities: {entities}, "
    +            f"relationships: {list(relationships.keys())}, and "
    +            f"properties: {properties}. "
    +        )
    +
    +        for relationship, values in relationships.items():
    +            self._expand_pairs(relationship, values)
    +
    +        if self.rel_directions:
    +            msg += "Given the following valid combinations of source, relationship, and target: "
    +            for key, value in self.rel_directions.items():
    +                for pair in value:
    +                    msg += f"'(:{pair[0]})-(:{key})->(:{pair[1]})', "
    +            msg += f"generate a {query_language} query using one of these combinations. "
    +
    +        msg += "Only return the query, without any additional text, symbols or characters --- just the query statement."
    +        return msg
     
    -        Returns:
    -            A prompt for a large language model to generate a database query.
    -        """
    -        msg = (
    -            f"Generate a database query in {query_language} that answers "
    -            f"the user's question. "
    -            f"You can use the following entities: {entities}, "
    -            f"relationships: {list(relationships.keys())}, and "
    -            f"properties: {properties}. "
    -        )
    -
    -        for relationship, values in relationships.items():
    -            self._expand_pairs(relationship, values)
    -
    -        if self.rel_directions:
    -            msg += "Given the following valid combinations of source, relationship, and target: "
    -            for key, value in self.rel_directions.items():
    -                for pair in value:
    -                    msg += f"'(:{pair[0]})-(:{key})->(:{pair[1]})', "
    -            msg += f"generate a {query_language} query using one of these combinations. "
    -
    -        msg += "Only return the query, without any additional text, symbols or characters --- just the query statement."
    -        return msg
    -
    -    def generate_query_prompt(
    -        self, question: str, query_language: Optional[str] = "Cypher"
    -    ) -> str:
    -        """
    -        Generate a prompt for a large language model to generate a database
    -        query based on the user's question and class attributes informing about
    -        the schema.
    +    def generate_query_prompt(
    +        self,
    +        question: str,
    +        query_language: str | None = "Cypher",
    +    ) -> str:
    +        """Generate a prompt for a large language model to generate a database
    +        query based on the user's question and class attributes informing about
    +        the schema.
    +
    +        Args:
    +        ----
    +            question: A user's question.
    +
    +            query_language: The language of the query to generate.
    +
    +        Returns:
    +        -------
    +            A prompt for a large language model to generate a database query.
    +
    +        """
    +        self._select_graph_entities_from_question(
    +            question,
    +            self.conversation_factory(),
    +        )
    +        msg = self._generate_query_prompt(
    +            self.selected_entities,
    +            self.selected_relationship_labels,
    +            self.selected_properties,
    +            query_language,
    +        )
    +        return msg
     
    -        Args:
    -            question: A user's question.
    -
    -            query_language: The language of the query to generate.
    -
    -        Returns:
    -            A prompt for a large language model to generate a database query.
    -        """
    -        self._select_graph_entities_from_question(
    -            question, self.conversation_factory()
    -        )
    -        msg = self._generate_query_prompt(
    -            self.selected_entities,
    -            self.selected_relationship_labels,
    -            self.selected_properties,
    -            query_language,
    -        )
    -        return msg
    -
    -    def generate_query(
    -        self, question: str, query_language: Optional[str] = "Cypher"
    -    ) -> str:
    -        """
    -        Wrap entity and property selection and query generation; return the
    -        generated query.
    -
    -        Args:
    -            question: A user's question.
    -
    -            query_language: The language of the query to generate.
    -
    -        Returns:
    -            A database query that could answer the user's question.
    -        """
    -
    -        self._select_graph_entities_from_question(
    -            question, self.conversation_factory()
    -        )
    +    def generate_query(
    +        self,
    +        question: str,
    +        query_language: str | None = "Cypher",
    +    ) -> str:
    +        """Wrap entity and property selection and query generation; return the
    +        generated query.
    +
    +        Args:
    +        ----
    +            question: A user's question.
    +
    +            query_language: The language of the query to generate.
    +
    +        Returns:
    +        -------
    +            A database query that could answer the user's question.
    +
    +        """
    +        self._select_graph_entities_from_question(
    +            question,
    +            self.conversation_factory(),
    +        )
    +
    +        return self._generate_query(
    +            question=question,
    +            entities=self.selected_entities,
    +            relationships=self.selected_relationship_labels,
    +            properties=self.selected_properties,
    +            query_language=query_language,
    +            conversation=self.conversation_factory(),
    +        )
    +
    +    def _get_conversation(
    +        self,
    +        model_name: str | None = None,
    +    ) -> "Conversation":
    +        """Create a conversation object given a model name.
     
    -        return self._generate_query(
    -            question=question,
    -            entities=self.selected_entities,
    -            relationships=self.selected_relationship_labels,
    -            properties=self.selected_properties,
    -            query_language=query_language,
    -            conversation=self.conversation_factory(),
    -        )
    -
    -    def _get_conversation(
    -        self, model_name: Optional[str] = None
    -    ) -> "Conversation":
    -        """
    -        Create a conversation object given a model name.
    -
    -        Args:
    -            model_name: The name of the model to use for the conversation.
    -
    -        Returns:
    -            A BioChatter Conversation object for connecting to the LLM.
    -
    -        Todo:
    -            Genericise to models outside of OpenAI.
    -        """
    -
    -        conversation = GptConversation(
    -            model_name=model_name or self.model_name,
    -            prompts={},
    -            correct=False,
    -        )
    -        conversation.set_api_key(
    -            api_key=os.getenv("OPENAI_API_KEY"), user="test_user"
    -        )
    -        return conversation
    -
    -    def _select_entities(
    -        self, question: str, conversation: "Conversation"
    -    ) -> bool:
    -        """
    +        Args:
    +        ----
    +            model_name: The name of the model to use for the conversation.
    +
    +        Returns:
    +        -------
    +            A BioChatter Conversation object for connecting to the LLM.
    +
    +        Todo:
    +        ----
    +            Genericise to models outside of OpenAI.
    +
    +        """
    +        conversation = GptConversation(
    +            model_name=model_name or self.model_name,
    +            prompts={},
    +            correct=False,
    +        )
    +        conversation.set_api_key(
    +            api_key=os.getenv("OPENAI_API_KEY"),
    +            user="test_user",
    +        )
    +        return conversation
    +
    +    def _select_entities(
    +        self,
    +        question: str,
    +        conversation: "Conversation",
    +    ) -> bool:
    +        """Given a question, select the entities that are relevant to the question
    +        and store them in `selected_entities` and `selected_relationships`. Use
    +        LLM conversation to do this.
    +
    +        Args:
    +        ----
    +            question: A user's question.
    +
    +            conversation: A BioChatter Conversation object for connecting to the
    +                LLM.
     
    -        Given a question, select the entities that are relevant to the question
    -        and store them in `selected_entities` and `selected_relationships`. Use
    -        LLM conversation to do this.
    +        Returns:
    +        -------
    +            True if at least one entity was selected, False otherwise.
     
    -        Args:
    -            question: A user's question.
    +        """
    +        self.question = question
     
    -            conversation: A BioChatter Conversation object for connecting to the
    -                LLM.
    -
    -        Returns:
    -            True if at least one entity was selected, False otherwise.
    -
    -        """
    -
    -        self.question = question
    -
    -        conversation.append_system_message(
    -            (
    -                "You have access to a knowledge graph that contains "
    -                f"these entity types: {', '.join(self.entities)}. Your task is "
    -                "to select the entity types that are relevant to the user's question "
    -                "for subsequent use in a query. Only return the entity types, "
    -                "comma-separated, without any additional text. Do not return "
    -                "entity names, relationships, or properties."
    -            )
    -        )
    -
    -        msg, token_usage, correction = conversation.query(question)
    -
    -        result = msg.split(",") if msg else []
    -        # TODO: do we go back and retry if no entities were selected? or ask for
    -        # a reason? offer visual selection of entities and relationships by the
    -        # user?
    +        conversation.append_system_message(
    +            "You have access to a knowledge graph that contains "
    +            f"these entity types: {', '.join(self.entities)}. Your task is "
    +            "to select the entity types that are relevant to the user's question "
    +            "for subsequent use in a query. Only return the entity types, "
    +            "comma-separated, without any additional text. Do not return "
    +            "entity names, relationships, or properties.",
    +        )
    +
    +        msg, token_usage, correction = conversation.query(question)
    +
    +        result = msg.split(",") if msg else []
    +        # TODO: do we go back and retry if no entities were selected? or ask for
    +        # a reason? offer visual selection of entities and relationships by the
    +        # user?
    +
    +        if result:
    +            for entity in result:
    +                entity = entity.strip()
    +                if entity in self.entities:
    +                    self.selected_entities.append(entity)
    +
    +        return bool(result)
    +
    +    def _select_relationships(self, conversation: "Conversation") -> bool:
    +        """Given a question and the preselected entities, select relationships for
    +        the query.
     
    -        if result:
    -            for entity in result:
    -                entity = entity.strip()
    -                if entity in self.entities:
    -                    self.selected_entities.append(entity)
    -
    -        return bool(result)
    -
    -    def _select_relationships(self, conversation: "Conversation") -> bool:
    -        """
    -        Given a question and the preselected entities, select relationships for
    -        the query.
    -
    -        Args:
    -            conversation: A BioChatter Conversation object for connecting to the
    -                LLM.
    -
    -        Returns:
    -            True if at least one relationship was selected, False otherwise.
    -
    -        Todo:
    -            Now we have the problem that we discard all relationships that do
    -            not have a source and target, if at least one relationship has a
    -            source and target. At least communicate this all-or-nothing
    -            behaviour to the user.
    -        """
    +        Args:
    +        ----
    +            conversation: A BioChatter Conversation object for connecting to the
    +                LLM.
    +
    +        Returns:
    +        -------
    +            True if at least one relationship was selected, False otherwise.
    +
    +        Todo:
    +        ----
    +            Now we have the problem that we discard all relationships that do
    +            not have a source and target, if at least one relationship has a
    +            source and target. At least communicate this all-or-nothing
    +            behaviour to the user.
    +
    +        """
    +        if not self.question:
    +            raise ValueError(
    +                "No question found. Please make sure to run entity selection first.",
    +            )
    +
    +        if not self.selected_entities:
    +            raise ValueError(
    +                "No entities found. Please run the entity selection step first.",
    +            )
     
    -        if not self.question:
    -            raise ValueError(
    -                "No question found. Please make sure to run entity selection "
    -                "first."
    -            )
    -
    -        if not self.selected_entities:
    -            raise ValueError(
    -                "No entities found. Please run the entity selection step first."
    -            )
    -
    -        rels = {}
    -        source_and_target_present = False
    -        for key, value in self.relationships.items():
    -            if "source" in value and "target" in value:
    -                # if source or target is a list, expand to single pairs
    -                source = ensure_iterable(value["source"])
    -                target = ensure_iterable(value["target"])
    -                pairs = []
    -                for s in source:
    -                    for t in target:
    -                        pairs.append(
    -                            (
    -                                sentencecase_to_pascalcase(s),
    -                                sentencecase_to_pascalcase(t),
    -                            )
    -                        )
    -                rels[key] = pairs
    -                source_and_target_present = True
    -            else:
    -                rels[key] = {}
    +        rels = {}
    +        source_and_target_present = False
    +        for key, value in self.relationships.items():
    +            if "source" in value and "target" in value:
    +                # if source or target is a list, expand to single pairs
    +                source = ensure_iterable(value["source"])
    +                target = ensure_iterable(value["target"])
    +                pairs = []
    +                for s in source:
    +                    for t in target:
    +                        pairs.append(
    +                            (
    +                                sentencecase_to_pascalcase(s),
    +                                sentencecase_to_pascalcase(t),
    +                            ),
    +                        )
    +                rels[key] = pairs
    +                source_and_target_present = True
    +            else:
    +                rels[key] = {}
    +
    +        # prioritise relationships that have source and target, and discard
    +        # relationships that do not have both source and target, if at least one
    +        # relationship has both source and target. keep relationships that have
    +        # either source or target, if none of the relationships have both source
    +        # and target.
    +
    +        if source_and_target_present:
    +            # First, separate the relationships into two groups: those with both
    +            # source and target in the selected entities, and those with either
    +            # source or target but not both.
     
    -        # prioritise relationships that have source and target, and discard
    -        # relationships that do not have both source and target, if at least one
    -        # relationship has both source and target. keep relationships that have
    -        # either source or target, if none of the relationships have both source
    -        # and target.
    -
    -        if source_and_target_present:
    -            # First, separate the relationships into two groups: those with both
    -            # source and target in the selected entities, and those with either
    -            # source or target but not both.
    -
    -            rels_with_both = {}
    -            rels_with_either = {}
    -            for key, value in rels.items():
    -                for pair in value:
    -                    if pair[0] in self.selected_entities:
    -                        if pair[1] in self.selected_entities:
    -                            rels_with_both[key] = value
    -                        else:
    -                            rels_with_either[key] = value
    -                    elif pair[1] in self.selected_entities:
    -                        rels_with_either[key] = value
    -
    -            # If there are any relationships with both source and target,
    -            # discard the others.
    -
    -            if rels_with_both:
    -                rels = rels_with_both
    -            else:
    -                rels = rels_with_either
    -
    -            selected_rels = []
    -            for key, value in rels.items():
    -                if not value:
    -                    continue
    -
    -                for pair in value:
    -                    if (
    -                        pair[0] in self.selected_entities
    -                        or pair[1] in self.selected_entities
    -                    ):
    -                        selected_rels.append((key, pair))
    -
    -            rels = json.dumps(selected_rels)
    -        else:
    -            rels = json.dumps(self.relationships)
    -
    -        msg = (
    -            "You have access to a knowledge graph that contains "
    -            f"these entities: {', '.join(self.selected_entities)}. "
    -            "Your task is to select the relationships that are relevant "
    -            "to the user's question for subsequent use in a query. Only "
    -            "return the relationships without their sources or targets, "
    -            "comma-separated, and without any additional text. Here are the "
    -            "possible relationships and their source and target entities: "
    -            f"{rels}."
    -        )
    -
    -        conversation.append_system_message(msg)
    -
    -        res, token_usage, correction = conversation.query(self.question)
    -
    -        result = res.split(",") if msg else []
    -
    -        if result:
    -            for relationship in result:
    -                relationship = relationship.strip()
    -                if relationship in self.relationships:
    -                    self.selected_relationships.append(relationship)
    -                    rel_dict = self.relationships[relationship]
    -                    label = rel_dict.get("label_as_edge", relationship)
    -                    if "source" in rel_dict and "target" in rel_dict:
    -                        self.selected_relationship_labels[label] = {
    -                            "source": rel_dict["source"],
    -                            "target": rel_dict["target"],
    -                        }
    -                    else:
    -                        self.selected_relationship_labels[label] = {
    -                            "source": None,
    -                            "target": None,
    -                        }
    -
    -        # if we selected relationships that have either source or target which
    -        # is not in the selected entities, we add those entities to the selected
    -        # entities.
    -
    -        if self.selected_relationship_labels:
    -            for key, value in self.selected_relationship_labels.items():
    -                sources = ensure_iterable(value["source"])
    -                targets = ensure_iterable(value["target"])
    -                for source in sources:
    -                    if source is None:
    -                        continue
    -                    if source not in self.selected_entities:
    -                        self.selected_entities.append(
    -                            sentencecase_to_pascalcase(source)
    -                        )
    -                for target in targets:
    -                    if target is None:
    -                        continue
    -                    if target not in self.selected_entities:
    -                        self.selected_entities.append(
    -                            sentencecase_to_pascalcase(target)
    -                        )
    -
    -        return bool(result)
    -
    -    @staticmethod
    -    def _validate_json_str(json_str: str):
    -        json_str = json_str.strip()
    -        if json_str.startswith("```json"):
    -            json_str = json_str[7:]
    -        if json_str.endswith("```"):
    -            json_str = json_str[:-3]
    -        return json_str.strip()
    -
    -    def _select_properties(self, conversation: "Conversation") -> bool:
    -        """
    -
    -        Given a question (optionally provided, but in the standard use case
    -        reused from the entity selection step) and the selected entities, select
    -        the properties that are relevant to the question and store them in
    -        the dictionary `selected_properties`.
    -
    -        Returns:
    -            True if at least one property was selected, False otherwise.
    -
    -        """
    -
    -        if not self.question:
    -            raise ValueError(
    -                "No question found. Please make sure to run entity and "
    -                "relationship selection first."
    -            )
    -
    -        if not self.selected_entities and not self.selected_relationships:
    -            raise ValueError(
    -                "No entities or relationships provided, and none available "
    -                "from entity selection step. Please provide "
    -                "entities/relationships or run the entity selection "
    -                "(`select_entities()`) step first."
    -            )
    -
    -        e_props = {}
    -        for entity in self.selected_entities:
    -            if self.entities[entity].get("properties"):
    -                e_props[entity] = list(
    -                    self.entities[entity]["properties"].keys()
    -                )
    -
    -        r_props = {}
    -        for relationship in self.selected_relationships:
    -            if self.relationships[relationship].get("properties"):
    -                r_props[relationship] = list(
    -                    self.relationships[relationship]["properties"].keys()
    -                )
    -
    -        msg = (
    -            "You have access to a knowledge graph that contains entities and "
    -            "relationships. They have the following properties. Entities:"
    -            f"{e_props}, Relationships: {r_props}. "
    -            "Your task is to select the properties that are relevant to the "
    -            "user's question for subsequent use in a query. Only return the "
    -            "entities and relationships with their relevant properties in compact "
    -            "JSON format, without any additional text. Return the "
    -            "entities/relationships as top-level dictionary keys, and their "
    -            "properties as dictionary values. "
    -            "Do not return properties that are not relevant to the question."
    -        )
    -
    -        conversation.append_system_message(msg)
    -
    -        msg, token_usage, correction = conversation.query(self.question)
    -        msg = BioCypherPromptEngine._validate_json_str(msg)
    -
    -        try:
    -            self.selected_properties = json.loads(msg) if msg else {}
    -        except json.decoder.JSONDecodeError:
    -            self.selected_properties = {}
    -
    -        return bool(self.selected_properties)
    -
    -    def _generate_query(
    -        self,
    -        question: str,
    -        entities: list,
    -        relationships: dict,
    -        properties: dict,
    -        query_language: str,
    -        conversation: "Conversation",
    -    ) -> str:
    -        """
    -        Generate a query in the specified query language that answers the user's
    -        question.
    -
    -        Args:
    -            question: A user's question.
    +            rels_with_both = {}
    +            rels_with_either = {}
    +            for key, value in rels.items():
    +                for pair in value:
    +                    if pair[0] in self.selected_entities:
    +                        if pair[1] in self.selected_entities:
    +                            rels_with_both[key] = value
    +                        else:
    +                            rels_with_either[key] = value
    +                    elif pair[1] in self.selected_entities:
    +                        rels_with_either[key] = value
    +
    +            # If there are any relationships with both source and target,
    +            # discard the others.
    +
    +            if rels_with_both:
    +                rels = rels_with_both
    +            else:
    +                rels = rels_with_either
    +
    +            selected_rels = []
    +            for key, value in rels.items():
    +                if not value:
    +                    continue
    +
    +                for pair in value:
    +                    if pair[0] in self.selected_entities or pair[1] in self.selected_entities:
    +                        selected_rels.append((key, pair))
    +
    +            rels = json.dumps(selected_rels)
    +        else:
    +            rels = json.dumps(self.relationships)
    +
    +        msg = (
    +            "You have access to a knowledge graph that contains "
    +            f"these entities: {', '.join(self.selected_entities)}. "
    +            "Your task is to select the relationships that are relevant "
    +            "to the user's question for subsequent use in a query. Only "
    +            "return the relationships without their sources or targets, "
    +            "comma-separated, and without any additional text. Here are the "
    +            "possible relationships and their source and target entities: "
    +            f"{rels}."
    +        )
    +
    +        conversation.append_system_message(msg)
    +
    +        res, token_usage, correction = conversation.query(self.question)
    +
    +        result = res.split(",") if msg else []
    +
    +        if result:
    +            for relationship in result:
    +                relationship = relationship.strip()
    +                if relationship in self.relationships:
    +                    self.selected_relationships.append(relationship)
    +                    rel_dict = self.relationships[relationship]
    +                    label = rel_dict.get("label_as_edge", relationship)
    +                    if "source" in rel_dict and "target" in rel_dict:
    +                        self.selected_relationship_labels[label] = {
    +                            "source": rel_dict["source"],
    +                            "target": rel_dict["target"],
    +                        }
    +                    else:
    +                        self.selected_relationship_labels[label] = {
    +                            "source": None,
    +                            "target": None,
    +                        }
    +
    +        # if we selected relationships that have either source or target which
    +        # is not in the selected entities, we add those entities to the selected
    +        # entities.
    +
    +        if self.selected_relationship_labels:
    +            for key, value in self.selected_relationship_labels.items():
    +                sources = ensure_iterable(value["source"])
    +                targets = ensure_iterable(value["target"])
    +                for source in sources:
    +                    if source is None:
    +                        continue
    +                    if source not in self.selected_entities:
    +                        self.selected_entities.append(
    +                            sentencecase_to_pascalcase(source),
    +                        )
    +                for target in targets:
    +                    if target is None:
    +                        continue
    +                    if target not in self.selected_entities:
    +                        self.selected_entities.append(
    +                            sentencecase_to_pascalcase(target),
    +                        )
    +
    +        return bool(result)
    +
    +    @staticmethod
    +    def _validate_json_str(json_str: str):
    +        json_str = json_str.strip()
    +        if json_str.startswith("```json"):
    +            json_str = json_str[7:]
    +        if json_str.endswith("```"):
    +            json_str = json_str[:-3]
    +        return json_str.strip()
    +
    +    def _select_properties(self, conversation: "Conversation") -> bool:
    +        """Given a question (optionally provided, but in the standard use case
    +        reused from the entity selection step) and the selected entities, select
    +        the properties that are relevant to the question and store them in
    +        the dictionary `selected_properties`.
    +
    +        Returns
    +        -------
    +            True if at least one property was selected, False otherwise.
    +
    +        """
    +        if not self.question:
    +            raise ValueError(
    +                "No question found. Please make sure to run entity and relationship selection first.",
    +            )
    +
    +        if not self.selected_entities and not self.selected_relationships:
    +            raise ValueError(
    +                "No entities or relationships provided, and none available "
    +                "from entity selection step. Please provide "
    +                "entities/relationships or run the entity selection "
    +                "(`select_entities()`) step first.",
    +            )
    +
    +        e_props = {}
    +        for entity in self.selected_entities:
    +            if self.entities[entity].get("properties"):
    +                e_props[entity] = list(
    +                    self.entities[entity]["properties"].keys(),
    +                )
    +
    +        r_props = {}
    +        for relationship in self.selected_relationships:
    +            if self.relationships[relationship].get("properties"):
    +                r_props[relationship] = list(
    +                    self.relationships[relationship]["properties"].keys(),
    +                )
    +
    +        msg = (
    +            "You have access to a knowledge graph that contains entities and "
    +            "relationships. They have the following properties. Entities:"
    +            f"{e_props}, Relationships: {r_props}. "
    +            "Your task is to select the properties that are relevant to the "
    +            "user's question for subsequent use in a query. Only return the "
    +            "entities and relationships with their relevant properties in compact "
    +            "JSON format, without any additional text. Return the "
    +            "entities/relationships as top-level dictionary keys, and their "
    +            "properties as dictionary values. "
    +            "Do not return properties that are not relevant to the question."
    +        )
    +
    +        conversation.append_system_message(msg)
    +
    +        msg, token_usage, correction = conversation.query(self.question)
    +        msg = BioCypherPromptEngine._validate_json_str(msg)
    +
    +        try:
    +            self.selected_properties = json.loads(msg) if msg else {}
    +        except json.decoder.JSONDecodeError:
    +            self.selected_properties = {}
    +
    +        return bool(self.selected_properties)
    +
    +    def _generate_query(
    +        self,
    +        question: str,
    +        entities: list,
    +        relationships: dict,
    +        properties: dict,
    +        query_language: str,
    +        conversation: "Conversation",
    +    ) -> str:
    +        """Generate a query in the specified query language that answers the user's
    +        question.
    +
    +        Args:
    +        ----
    +            question: A user's question.
    +
    +            entities: A list of entities that are relevant to the question.
    +
    +            relationships: A list of relationships that are relevant to the
    +                question.
    +
    +            properties: A dictionary of properties that are relevant to the
    +                question.
    +
    +            query_language: The language of the query to generate.
    +
    +            conversation: A BioChatter Conversation object for connecting to the
    +                LLM.
    +
    +        Returns:
    +        -------
    +            A database query that could answer the user's question.
     
    -            entities: A list of entities that are relevant to the question.
    -
    -            relationships: A list of relationships that are relevant to the
    -                question.
    -
    -            properties: A dictionary of properties that are relevant to the
    -                question.
    +        """
    +        msg = self._generate_query_prompt(
    +            entities,
    +            relationships,
    +            properties,
    +            query_language,
    +        )
     
    -            query_language: The language of the query to generate.
    +        conversation.append_system_message(msg)
     
    -            conversation: A BioChatter Conversation object for connecting to the
    -                LLM.
    -
    -        Returns:
    -            A database query that could answer the user's question.
    -        """
    -        msg = self._generate_query_prompt(
    -            entities,
    -            relationships,
    -            properties,
    -            query_language,
    -        )
    -
    -        conversation.append_system_message(msg)
    -
    -        out_msg, token_usage, correction = conversation.query(question)
    -
    -        return out_msg.strip()
    -
    -    def _expand_pairs(self, relationship, values) -> None:
    -        if not self.rel_directions.get(relationship):
    -            self.rel_directions[relationship] = []
    -        if isinstance(values["source"], list):
    -            for source in values["source"]:
    -                if isinstance(values["target"], list):
    -                    for target in values["target"]:
    -                        self.rel_directions[relationship].append(
    -                            (source, target)
    -                        )
    -                else:
    -                    self.rel_directions[relationship].append(
    -                        (source, values["target"])
    -                    )
    -        elif isinstance(values["target"], list):
    -            for target in values["target"]:
    -                self.rel_directions[relationship].append(
    -                    (values["source"], target)
    -                )
    -        else:
    -            self.rel_directions[relationship].append(
    -                (values["source"], values["target"])
    -            )
    +        out_msg, token_usage, correction = conversation.query(question)
    +
    +        return out_msg.strip()
    +
    +    def _expand_pairs(self, relationship, values) -> None:
    +        if not self.rel_directions.get(relationship):
    +            self.rel_directions[relationship] = []
    +        if isinstance(values["source"], list):
    +            for source in values["source"]:
    +                if isinstance(values["target"], list):
    +                    for target in values["target"]:
    +                        self.rel_directions[relationship].append(
    +                            (source, target),
    +                        )
    +                else:
    +                    self.rel_directions[relationship].append(
    +                        (source, values["target"]),
    +                    )
    +        elif isinstance(values["target"], list):
    +            for target in values["target"]:
    +                self.rel_directions[relationship].append(
    +                    (values["source"], target),
    +                )
    +        else:
    +            self.rel_directions[relationship].append(
    +                (values["source"], values["target"]),
    +            )
     
    @@ -2814,98 +2786,29 @@

    schema_config_or_info_path: Path to a biocypher schema configuration
    +    file or the extended schema information output generated by
    +    BioCypher's `write_schema_info` function (preferred).
     
    +schema_config_or_info_dict: A dictionary containing the schema
    +    configuration file or the extended schema information output
    +    generated by BioCypher's `write_schema_info` function
    +    (preferred).
     
    -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - schema_config_or_info_path - - Optional[str] - -
    -

    Path to a biocypher schema configuration -file or the extended schema information output generated by -BioCypher's write_schema_info function (preferred).

    -
    -
    - None -
    - schema_config_or_info_dict - - Optional[dict] - -
    -

    A dictionary containing the schema -configuration file or the extended schema information output -generated by BioCypher's write_schema_info function -(preferred).

    -
    -
    - None -
    - model_name - - str - -
    -

    The name of the model to use for the conversation. -DEPRECATED: This should now be set in the conversation factory.

    -
    -
    - 'gpt-3.5-turbo' -
    - conversation_factory - - Optional[Callable] - -
    -

    A function used to create a conversation for -creating the KG query. If not provided, a default function is -used (creating an OpenAI conversation with the specified model, -see _get_conversation).

    -
    -
    - None -
    +model_name: The name of the model to use for the conversation. + DEPRECATED: This should now be set in the conversation factory. + +conversation_factory: A function used to create a conversation for + creating the KG query. If not provided, a default function is + used (creating an OpenAI conversation with the specified model, + see `_get_conversation`). +

    Source code in biochatter/prompts.py -
     13
    +              
     12
    + 13
      14
      15
      16
    @@ -2997,137 +2900,100 @@ 

    102 103 104 -105 -106 -107 -108 -109 -110 -111 -112 -113 -114 -115 -116 -117 -118 -119 -120 -121 -122 -123 -124

    def __init__(
    -    self,
    -    schema_config_or_info_path: Optional[str] = None,
    -    schema_config_or_info_dict: Optional[dict] = None,
    -    model_name: str = "gpt-3.5-turbo",
    -    conversation_factory: Optional[Callable] = None,
    -) -> None:
    -    """
    -
    -    Given a biocypher schema configuration, extract the entities and
    -    relationships, and for each extract their mode of representation (node
    -    or edge), properties, and identifier namespace. Using these data, allow
    -    the generation of prompts for a large language model, informing it of
    -    the schema constituents and their properties, to enable the
    -    parameterisation of function calls to a knowledge graph.
    -
    -    Args:
    -        schema_config_or_info_path: Path to a biocypher schema configuration
    -            file or the extended schema information output generated by
    -            BioCypher's `write_schema_info` function (preferred).
    -
    -        schema_config_or_info_dict: A dictionary containing the schema
    -            configuration file or the extended schema information output
    -            generated by BioCypher's `write_schema_info` function
    -            (preferred).
    -
    -        model_name: The name of the model to use for the conversation.
    -            DEPRECATED: This should now be set in the conversation factory.
    -
    -        conversation_factory: A function used to create a conversation for
    -            creating the KG query. If not provided, a default function is
    -            used (creating an OpenAI conversation with the specified model,
    -            see `_get_conversation`).
    -    """
    -
    -    if not schema_config_or_info_path and not schema_config_or_info_dict:
    -        raise ValueError(
    -            "Please provide the schema configuration or schema info as a "
    -            "path to a file or as a dictionary."
    -        )
    -
    -    if schema_config_or_info_path and schema_config_or_info_dict:
    -        raise ValueError(
    -            "Please provide the schema configuration or schema info as a "
    -            "path to a file or as a dictionary, not both."
    -        )
    +105
    def __init__(
    +    self,
    +    schema_config_or_info_path: str | None = None,
    +    schema_config_or_info_dict: dict | None = None,
    +    model_name: str = "gpt-3.5-turbo",
    +    conversation_factory: Callable | None = None,
    +) -> None:
    +    """Given a biocypher schema configuration, extract the entities and
    +    relationships, and for each extract their mode of representation (node
    +    or edge), properties, and identifier namespace. Using these data, allow
    +    the generation of prompts for a large language model, informing it of
    +    the schema constituents and their properties, to enable the
    +    parameterisation of function calls to a knowledge graph.
    +
    +    Args:
    +    ----
    +        schema_config_or_info_path: Path to a biocypher schema configuration
    +            file or the extended schema information output generated by
    +            BioCypher's `write_schema_info` function (preferred).
    +
    +        schema_config_or_info_dict: A dictionary containing the schema
    +            configuration file or the extended schema information output
    +            generated by BioCypher's `write_schema_info` function
    +            (preferred).
    +
    +        model_name: The name of the model to use for the conversation.
    +            DEPRECATED: This should now be set in the conversation factory.
    +
    +        conversation_factory: A function used to create a conversation for
    +            creating the KG query. If not provided, a default function is
    +            used (creating an OpenAI conversation with the specified model,
    +            see `_get_conversation`).
    +
    +    """
    +    if not schema_config_or_info_path and not schema_config_or_info_dict:
    +        raise ValueError(
    +            "Please provide the schema configuration or schema info as a path to a file or as a dictionary.",
    +        )
    +
    +    if schema_config_or_info_path and schema_config_or_info_dict:
    +        raise ValueError(
    +            "Please provide the schema configuration or schema info as a "
    +            "path to a file or as a dictionary, not both.",
    +        )
    +
    +    # set conversation factory or use default
    +    self.conversation_factory = conversation_factory if conversation_factory is not None else self._get_conversation
     
    -    # set conversation factory or use default
    -    self.conversation_factory = (
    -        conversation_factory
    -        if conversation_factory is not None
    -        else self._get_conversation
    -    )
    +    if schema_config_or_info_path:
    +        # read the schema configuration
    +        with open(schema_config_or_info_path) as f:
    +            schema_config = yaml.safe_load(f)
    +    elif schema_config_or_info_dict:
    +        schema_config = schema_config_or_info_dict
     
    -    if schema_config_or_info_path:
    -        # read the schema configuration
    -        with open(schema_config_or_info_path, "r") as f:
    -            schema_config = yaml.safe_load(f)
    -    elif schema_config_or_info_dict:
    -        schema_config = schema_config_or_info_dict
    -
    -    # check whether it is the original schema config or the output of
    -    # biocypher info
    -    is_schema_info = schema_config.get("is_schema_info", False)
    -
    -    # extract the entities and relationships: each top level key that has
    -    # a 'represented_as' key
    -    self.entities = {}
    -    self.relationships = {}
    -    if not is_schema_info:
    -        for key, value in schema_config.items():
    -            # hacky, better with biocypher output
    -            name_indicates_relationship = (
    -                "interaction" in key.lower() or "association" in key.lower()
    -            )
    -            if "represented_as" in value:
    -                if (
    -                    value["represented_as"] == "node"
    -                    and not name_indicates_relationship
    -                ):
    -                    self.entities[sentencecase_to_pascalcase(key)] = value
    -                elif (
    -                    value["represented_as"] == "node"
    -                    and name_indicates_relationship
    -                ):
    -                    self.relationships[sentencecase_to_pascalcase(key)] = (
    -                        value
    -                    )
    -                elif value["represented_as"] == "edge":
    -                    self.relationships[sentencecase_to_pascalcase(key)] = (
    -                        value
    -                    )
    -    else:
    -        for key, value in schema_config.items():
    -            if not isinstance(value, dict):
    -                continue
    -            if value.get("present_in_knowledge_graph", None) == False:
    -                continue
    -            if value.get("is_relationship", None) == False:
    -                self.entities[sentencecase_to_pascalcase(key)] = value
    -            elif value.get("is_relationship", None) == True:
    -                value = self._capitalise_source_and_target(value)
    -                self.relationships[sentencecase_to_pascalcase(key)] = value
    -
    -    self.question = ""
    -    self.selected_entities = []
    -    self.selected_relationships = []  # used in property selection
    -    self.selected_relationship_labels = {}  # copy to deal with labels that
    -    # are not the same as the relationship name, used in query generation
    -    # dictionary to also include source and target types
    -    self.rel_directions = {}
    -    self.model_name = model_name
    +    # check whether it is the original schema config or the output of
    +    # biocypher info
    +    is_schema_info = schema_config.get("is_schema_info", False)
    +
    +    # extract the entities and relationships: each top level key that has
    +    # a 'represented_as' key
    +    self.entities = {}
    +    self.relationships = {}
    +    if not is_schema_info:
    +        for key, value in schema_config.items():
    +            # hacky, better with biocypher output
    +            name_indicates_relationship = "interaction" in key.lower() or "association" in key.lower()
    +            if "represented_as" in value:
    +                if value["represented_as"] == "node" and not name_indicates_relationship:
    +                    self.entities[sentencecase_to_pascalcase(key)] = value
    +                elif (value["represented_as"] == "node" and name_indicates_relationship) or value[
    +                    "represented_as"
    +                ] == "edge":
    +                    self.relationships[sentencecase_to_pascalcase(key)] = value
    +    else:
    +        for key, value in schema_config.items():
    +            if not isinstance(value, dict):
    +                continue
    +            if value.get("present_in_knowledge_graph", None) == False:
    +                continue
    +            if value.get("is_relationship", None) == False:
    +                self.entities[sentencecase_to_pascalcase(key)] = value
    +            elif value.get("is_relationship", None) == True:
    +                value = self._capitalise_source_and_target(value)
    +                self.relationships[sentencecase_to_pascalcase(key)] = value
    +
    +    self.question = ""
    +    self.selected_entities = []
    +    self.selected_relationships = []  # used in property selection
    +    self.selected_relationship_labels = {}  # copy to deal with labels that
    +    # are not the same as the relationship name, used in query generation
    +    # dictionary to also include source and target types
    +    self.rel_directions = {}
    +    self.model_name = model_name
     
    @@ -3147,80 +3013,37 @@

    question: A user's question.
     
    -
    -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - question - - str - -
    -

    A user's question.

    -
    -
    - required -
    - query_language - - Optional[str] - -
    -

    The language of the query to generate.

    -
    -
    - 'Cypher' -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    TypeDescription
    - str - -
    -

    A database query that could answer the user's question.

    -
    -
    +query_language: The language of the query to generate. +
    +
    +
    A database query that could answer the user's question.
    +
    Source code in biochatter/prompts.py -
    253
    +              
    234
    +235
    +236
    +237
    +238
    +239
    +240
    +241
    +242
    +243
    +244
    +245
    +246
    +247
    +248
    +249
    +250
    +251
    +252
    +253
     254
     255
     256
    @@ -3232,49 +3055,38 @@ 

    262 263 264 -265 -266 -267 -268 -269 -270 -271 -272 -273 -274 -275 -276 -277 -278 -279 -280

    def generate_query(
    -    self, question: str, query_language: Optional[str] = "Cypher"
    -) -> str:
    -    """
    -    Wrap entity and property selection and query generation; return the
    -    generated query.
    -
    -    Args:
    -        question: A user's question.
    -
    -        query_language: The language of the query to generate.
    -
    -    Returns:
    -        A database query that could answer the user's question.
    -    """
    -
    -    self._select_graph_entities_from_question(
    -        question, self.conversation_factory()
    -    )
    -
    -    return self._generate_query(
    -        question=question,
    -        entities=self.selected_entities,
    -        relationships=self.selected_relationship_labels,
    -        properties=self.selected_properties,
    -        query_language=query_language,
    -        conversation=self.conversation_factory(),
    -    )
    +265
    def generate_query(
    +    self,
    +    question: str,
    +    query_language: str | None = "Cypher",
    +) -> str:
    +    """Wrap entity and property selection and query generation; return the
    +    generated query.
    +
    +    Args:
    +    ----
    +        question: A user's question.
    +
    +        query_language: The language of the query to generate.
    +
    +    Returns:
    +    -------
    +        A database query that could answer the user's question.
    +
    +    """
    +    self._select_graph_entities_from_question(
    +        question,
    +        self.conversation_factory(),
    +    )
    +
    +    return self._generate_query(
    +        question=question,
    +        entities=self.selected_entities,
    +        relationships=self.selected_relationship_labels,
    +        properties=self.selected_properties,
    +        query_language=query_language,
    +        conversation=self.conversation_factory(),
    +    )
     
    @@ -3295,130 +3107,78 @@

    question: A user's question.
     
    -
    -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - question - - str - -
    -

    A user's question.

    -
    -
    - required -
    - query_language - - Optional[str] - -
    -

    The language of the query to generate.

    -
    -
    - 'Cypher' -
    - - -

    Returns:

    - - - - - - - - - - - - - -
    TypeDescription
    - str - -
    -

    A prompt for a large language model to generate a database query.

    -
    -
    +query_language: The language of the query to generate. +
    +
    +
    A prompt for a large language model to generate a database query.
    +
    Source code in biochatter/prompts.py -
    226
    +              
    def generate_query_prompt(
    -    self, question: str, query_language: Optional[str] = "Cypher"
    -) -> str:
    -    """
    -    Generate a prompt for a large language model to generate a database
    -    query based on the user's question and class attributes informing about
    -    the schema.
    -
    -    Args:
    -        question: A user's question.
    -
    -        query_language: The language of the query to generate.
    -
    -    Returns:
    -        A prompt for a large language model to generate a database query.
    -    """
    -    self._select_graph_entities_from_question(
    -        question, self.conversation_factory()
    -    )
    -    msg = self._generate_query_prompt(
    -        self.selected_entities,
    -        self.selected_relationship_labels,
    -        self.selected_properties,
    -        query_language,
    -    )
    -    return msg
    +232
    def generate_query_prompt(
    +    self,
    +    question: str,
    +    query_language: str | None = "Cypher",
    +) -> str:
    +    """Generate a prompt for a large language model to generate a database
    +    query based on the user's question and class attributes informing about
    +    the schema.
    +
    +    Args:
    +    ----
    +        question: A user's question.
    +
    +        query_language: The language of the query to generate.
    +
    +    Returns:
    +    -------
    +        A prompt for a large language model to generate a database query.
    +
    +    """
    +    self._select_graph_entities_from_question(
    +        question,
    +        self.conversation_factory(),
    +    )
    +    msg = self._generate_query_prompt(
    +        self.selected_entities,
    +        self.selected_relationship_labels,
    +        self.selected_properties,
    +        query_language,
    +    )
    +    return msg
     
    @@ -3487,7 +3247,8 @@

    Source code in biochatter/database_agent.py -
     13
    +                
     12
    + 13
      14
      15
      16
    @@ -3650,36 +3411,30 @@ 

    173 174 175 -176 -177 -178 -179 -180 -181 -182 -183

    class DatabaseAgent:
    -    def __init__(
    -        self,
    -        model_name: str,
    -        connection_args: dict,
    -        schema_config_or_info_dict: dict,
    -        conversation_factory: Callable,
    -        use_reflexion: bool,
    -    ) -> None:
    -        """
    -        Create a DatabaseAgent analogous to the VectorDatabaseAgentMilvus class,
    -        which can return results from a database using a query engine. Currently
    -        limited to Neo4j for development.
    -
    -        Args:
    -            connection_args (dict): A dictionary of arguments to connect to the
    -                database. Contains database name, URI, user, and password.
    -
    -            conversation_factory (Callable): A function to create a conversation
    -                for creating the KG query.
    -
    -            use_reflexion (bool): Whether to use the ReflexionAgent to generate
    -                the query.
    +176
    class DatabaseAgent:
    +    def __init__(
    +        self,
    +        model_name: str,
    +        connection_args: dict,
    +        schema_config_or_info_dict: dict,
    +        conversation_factory: Callable,
    +        use_reflexion: bool,
    +    ) -> None:
    +        """Create a DatabaseAgent analogous to the VectorDatabaseAgentMilvus class,
    +        which can return results from a database using a query engine. Currently
    +        limited to Neo4j for development.
    +
    +        Args:
    +        ----
    +            connection_args (dict): A dictionary of arguments to connect to the
    +                database. Contains database name, URI, user, and password.
    +
    +            conversation_factory (Callable): A function to create a conversation
    +                for creating the KG query.
    +
    +            use_reflexion (bool): Whether to use the ReflexionAgent to generate
    +                the query.
    +
             """
             self.conversation_factory = conversation_factory
             self.prompt_engine = BioCypherPromptEngine(
    @@ -3692,142 +3447,135 @@ 

    self.use_reflexion = use_reflexion def connect(self) -> None: - """ - Connect to the database and authenticate. - """ - db_name = self.connection_args.get("db_name") - uri = f"{self.connection_args.get('host')}:{self.connection_args.get('port')}" - uri = uri if uri.startswith("bolt://") else "bolt://" + uri - user = self.connection_args.get("user") - password = self.connection_args.get("password") - self.driver = nu.Driver( - db_name=db_name or "neo4j", - db_uri=uri, - user=user, - password=password, - ) - - def is_connected(self) -> bool: - return not self.driver is None - - def _generate_query(self, query: str): - if self.use_reflexion: - agent = KGQueryReflexionAgent( - self.conversation_factory, - self.connection_args, - ) - query_prompt = self.prompt_engine.generate_query_prompt(query) - agent_result = agent.execute(query, query_prompt) - tool_result = ( - [agent_result.tool_result] - if agent_result.tool_result is not None - else None - ) - return agent_result.answer, tool_result - else: - query = self.prompt_engine.generate_query(query) - results = self.driver.query(query=query) - return query, results - - def _build_response( - self, - results: List[Dict], - cypher_query: str, - results_num: Optional[int] = 3, - ) -> List[Document]: - if len(results) == 0: - return [ - Document( - page_content=( - "I didn't find any result in knowledge graph, " - f"but here is the query I used: {cypher_query}. " - "You can ask user to refine the question. " - "Note: please ensure to include the query in a code " - "block in your response so that the user can refine " - "their question effectively." - ), - metadata={"cypher_query": cypher_query}, - ) - ] - - clipped_results = results[:results_num] if results_num > 0 else results - results_dump = json.dumps(clipped_results) - - return [ - Document( - page_content=( - "The results retrieved from knowledge graph are: " - f"{results_dump}. " - f"The query used is: {cypher_query}. " - "Note: please ensure to include the query in a code block " - "in your response so that the user can refine " - "their question effectively." - ), - metadata={"cypher_query": cypher_query}, - ) - ] - - def get_query_results(self, query: str, k: int = 3) -> list[Document]: - """ - Generate a query using the prompt engine and return the results. - Replicates vector database similarity search API. Results are returned - as a list of Document objects to align with the vector database agent. - - Args: - query (str): A query string. - - k (int): The number of results to return. - - Returns: - List[Document]: A list of Document objects. The page content values - are the literal dictionaries returned by the query, the metadata - values are the cypher query used to generate the results, for - now. - """ - (cypher_query, tool_result) = self._generate_query( - query - ) # self.prompt_engine.generate_query(query) - # TODO some logic if it fails? - if tool_result is not None: - # If _generate_query() already returned tool_result, we won't connect - # to graph database to query result any more - results = tool_result - else: - results = self.driver.query(query=cypher_query) - - # return first k results - # returned nodes can have any formatting, and can also be empty or fewer - # than k - if results is None or len(results) == 0 or results[0] is None: - return [] - return self._build_response( - results=results[0], cypher_query=cypher_query, results_num=k - ) - - def get_description(self): - result = self.driver.query("MATCH (n:Schema_info) RETURN n LIMIT 1") - - if result[0]: - schema_info_node = result[0][0]["n"] - schema_dict_content = schema_info_node["schema_info"][ - :MAX_AGENT_DESC_LENGTH - ] # limit to 1000 characters - return ( - f"the graph database contains the following nodes and edges: \n\n" - f"{schema_dict_content}" - ) - - # schema_info is not found in database - nodes_query = "MATCH (n) RETURN DISTINCT labels(n) LIMIT 300" - node_results = self.driver.query(query=nodes_query) - edges_query = "MATCH (n) RETURN DISTINCT type(n) LIMIT 300" - edge_results = self.driver.query(query=edges_query) - desc = ( - f"The graph database contains the following nodes and edges: \n" - f"nodes: \n{node_results}" - f"edges: \n{edge_results}" - ) - return desc[:MAX_AGENT_DESC_LENGTH] + """Connect to the database and authenticate.""" + db_name = self.connection_args.get("db_name") + uri = f"{self.connection_args.get('host')}:{self.connection_args.get('port')}" + uri = uri if uri.startswith("bolt://") else "bolt://" + uri + user = self.connection_args.get("user") + password = self.connection_args.get("password") + self.driver = nu.Driver( + db_name=db_name or "neo4j", + db_uri=uri, + user=user, + password=password, + ) + + def is_connected(self) -> bool: + return self.driver is not None + + def _generate_query(self, query: str): + if self.use_reflexion: + agent = KGQueryReflexionAgent( + self.conversation_factory, + self.connection_args, + ) + query_prompt = self.prompt_engine.generate_query_prompt(query) + agent_result = agent.execute(query, query_prompt) + tool_result = [agent_result.tool_result] if agent_result.tool_result is not None else None + return agent_result.answer, tool_result + else: + query = self.prompt_engine.generate_query(query) + results = self.driver.query(query=query) + return query, results + + def _build_response( + self, + results: list[dict], + cypher_query: str, + results_num: int | None = 3, + ) -> list[Document]: + if len(results) == 0: + return [ + Document( + page_content=( + "I didn't find any result in knowledge graph, " + f"but here is the query I used: {cypher_query}. " + "You can ask user to refine the question. " + "Note: please ensure to include the query in a code " + "block in your response so that the user can refine " + "their question effectively." + ), + metadata={"cypher_query": cypher_query}, + ), + ] + + clipped_results = results[:results_num] if results_num > 0 else results + results_dump = json.dumps(clipped_results) + + return [ + Document( + page_content=( + "The results retrieved from knowledge graph are: " + f"{results_dump}. " + f"The query used is: {cypher_query}. " + "Note: please ensure to include the query in a code block " + "in your response so that the user can refine " + "their question effectively." + ), + metadata={"cypher_query": cypher_query}, + ), + ] + + def get_query_results(self, query: str, k: int = 3) -> list[Document]: + """Generate a query using the prompt engine and return the results. + Replicates vector database similarity search API. Results are returned + as a list of Document objects to align with the vector database agent. + + Args: + ---- + query (str): A query string. + + k (int): The number of results to return. + + Returns: + ------- + List[Document]: A list of Document objects. The page content values + are the literal dictionaries returned by the query, the metadata + values are the cypher query used to generate the results, for + now. + + """ + (cypher_query, tool_result) = self._generate_query( + query, + ) # self.prompt_engine.generate_query(query) + # TODO some logic if it fails? + if tool_result is not None: + # If _generate_query() already returned tool_result, we won't connect + # to graph database to query result any more + results = tool_result + else: + results = self.driver.query(query=cypher_query) + + # return first k results + # returned nodes can have any formatting, and can also be empty or fewer + # than k + if results is None or len(results) == 0 or results[0] is None: + return [] + return self._build_response( + results=results[0], + cypher_query=cypher_query, + results_num=k, + ) + + def get_description(self): + result = self.driver.query("MATCH (n:Schema_info) RETURN n LIMIT 1") + + if result[0]: + schema_info_node = result[0][0]["n"] + schema_dict_content = schema_info_node["schema_info"][:MAX_AGENT_DESC_LENGTH] # limit to 1000 characters + return f"the graph database contains the following nodes and edges: \n\n{schema_dict_content}" + + # schema_info is not found in database + nodes_query = "MATCH (n) RETURN DISTINCT labels(n) LIMIT 300" + node_results = self.driver.query(query=nodes_query) + edges_query = "MATCH (n) RETURN DISTINCT type(n) LIMIT 300" + edge_results = self.driver.query(query=edges_query) + desc = ( + f"The graph database contains the following nodes and edges: \n" + f"nodes: \n{node_results}" + f"edges: \n{edge_results}" + ) + return desc[:MAX_AGENT_DESC_LENGTH]

    @@ -3857,76 +3605,21 @@

    connection_args (dict): A dictionary of arguments to connect to the
    +    database. Contains database name, URI, user, and password.
     
    +conversation_factory (Callable): A function to create a conversation
    +    for creating the KG query.
     
    -

    Parameters:

    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    NameTypeDescriptionDefault
    - connection_args - - dict - -
    -

    A dictionary of arguments to connect to the -database. Contains database name, URI, user, and password.

    -
    -
    - required -
    - conversation_factory - - Callable - -
    -

    A function to create a conversation -for creating the KG query.

    -
    -
    - required -
    - use_reflexion - - bool - -
    -

    Whether to use the ReflexionAgent to generate -the query.

    -
    -
    - required -
    +use_reflexion (bool): Whether to use the ReflexionAgent to generate + the query. +

    Source code in biochatter/database_agent.py -
    14
    +              
    13
    +14
     15
     16
     17
    @@ -3957,28 +3650,29 @@ 

    42 43 44 -45

    def __init__(
    -    self,
    -    model_name: str,
    -    connection_args: dict,
    -    schema_config_or_info_dict: dict,
    -    conversation_factory: Callable,
    -    use_reflexion: bool,
    -) -> None:
    -    """
    -    Create a DatabaseAgent analogous to the VectorDatabaseAgentMilvus class,
    -    which can return results from a database using a query engine. Currently
    -    limited to Neo4j for development.
    -
    -    Args:
    -        connection_args (dict): A dictionary of arguments to connect to the
    -            database. Contains database name, URI, user, and password.
    -
    -        conversation_factory (Callable): A function to create a conversation
    -            for creating the KG query.
    -
    -        use_reflexion (bool): Whether to use the ReflexionAgent to generate
    -            the query.
    +45
    def __init__(
    +    self,
    +    model_name: str,
    +    connection_args: dict,
    +    schema_config_or_info_dict: dict,
    +    conversation_factory: Callable,
    +    use_reflexion: bool,
    +) -> None:
    +    """Create a DatabaseAgent analogous to the VectorDatabaseAgentMilvus class,
    +    which can return results from a database using a query engine. Currently
    +    limited to Neo4j for development.
    +
    +    Args:
    +    ----
    +        connection_args (dict): A dictionary of arguments to connect to the
    +            database. Contains database name, URI, user, and password.
    +
    +        conversation_factory (Callable): A function to create a conversation
    +            for creating the KG query.
    +
    +        use_reflexion (bool): Whether to use the ReflexionAgent to generate
    +            the query.
    +
         """
         self.conversation_factory = conversation_factory
         self.prompt_engine = BioCypherPromptEngine(
    @@ -4022,23 +3716,19 @@ 

    56 57 58 -59 -60 -61

    def connect(self) -> None:
    -    """
    -    Connect to the database and authenticate.
    -    """
    -    db_name = self.connection_args.get("db_name")
    -    uri = f"{self.connection_args.get('host')}:{self.connection_args.get('port')}"
    -    uri = uri if uri.startswith("bolt://") else "bolt://" + uri
    -    user = self.connection_args.get("user")
    -    password = self.connection_args.get("password")
    -    self.driver = nu.Driver(
    -        db_name=db_name or "neo4j",
    -        db_uri=uri,
    -        user=user,
    -        password=password,
    -    )
    +59
    def connect(self) -> None:
    +    """Connect to the database and authenticate."""
    +    db_name = self.connection_args.get("db_name")
    +    uri = f"{self.connection_args.get('host')}:{self.connection_args.get('port')}"
    +    uri = uri if uri.startswith("bolt://") else "bolt://" + uri
    +    user = self.connection_args.get("user")
    +    password = self.connection_args.get("password")
    +    self.driver = nu.Driver(
    +        db_name=db_name or "neo4j",
    +        db_uri=uri,
    +        user=user,
    +        password=password,
    +    )
     
    @@ -4059,83 +3749,27 @@