feat: introduce tool selector agent (#183)
* introduce DeciderAgent

* introduce decider agent

* add test for decider agent

* remove test code

* add comments

* phrasing and formatting

* more intuitive naming of classes ("selector")

* remove ChatGSE mentions

* get last tool_result from message list and introduce rag agent selector to conversation

* fix test code

* fix minor errors

* conditional to not run coverage badge and push in PR CI

* formatting, phrasing

* formatting

* pre-commit

* add missing type `List`

* fix merge error

* restore method get_description

* use `List` from typing for type hints consistently

* use `Tuple` class for type hinting

* use `Dict` from typing for type hints

* pre-commit

* add missing type `List`

* missed some `Dict`s

---------

Co-authored-by: fengsh <[email protected]>
Co-authored-by: slobentanzer <[email protected]>
3 people authored Aug 1, 2024
1 parent 793ca55 commit f058565
Showing 41 changed files with 1,230 additions and 400 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
@@ -33,10 +33,12 @@ jobs:
shell: bash

- name: Generate coverage badge
if: (github.event_name == 'push' && github.ref == 'refs/heads/main') || (github.event_name == 'pull_request' && github.event.action == 'closed' && github.base_ref == 'refs/heads/main' && github.event.pull_request.merged == true)
run: poetry run coverage-badge -f -o docs/coverage/coverage.svg
shell: bash

- name: Commit changes
if: (github.event_name == 'push' && github.ref == 'refs/heads/main') || (github.event_name == 'pull_request' && github.event.action == 'closed' && github.base_ref == 'refs/heads/main' && github.event.pull_request.merged == true)
uses: s0/git-publish-subdir-action@develop
env:
REPO: self
2 changes: 1 addition & 1 deletion .gitignore
@@ -18,4 +18,4 @@ benchmark/data/*
serve.sh
.blast/*
.api_results/*
*.coverage
*.coverage
1 change: 1 addition & 0 deletions benchmark/benchmark_utils.py
@@ -1,5 +1,6 @@
from datetime import datetime
import re

from nltk.corpus import wordnet
import pytest
import importlib_metadata
15 changes: 8 additions & 7 deletions benchmark/load_dataset.py
@@ -5,6 +5,7 @@
import json
import hashlib
import itertools
from typing import Dict

from cryptography.fernet import Fernet
import rsa
@@ -13,7 +14,7 @@
import pandas as pd


def get_benchmark_dataset() -> dict[str, pd.DataFrame | dict[str, str]]:
def get_benchmark_dataset() -> Dict[str, pd.DataFrame | Dict[str, str]]:
"""
Get benchmark dataset:
@@ -34,7 +35,7 @@ def get_benchmark_dataset() -> dict[str, pd.DataFrame | dict[str, str]]:
return test_data


def _load_hold_out_test_dataset() -> dict[str, pd.DataFrame | dict[str, str]]:
def _load_hold_out_test_dataset() -> Dict[str, pd.DataFrame | Dict[str, str]]:
"""Load hold out test dataset.
Returns:
@@ -208,7 +209,7 @@ def _get_private_key_from_env_variable() -> rsa.PrivateKey:
return private_key


def _get_encrypted_test_data() -> dict[str, dict[str, str]]:
def _get_encrypted_test_data() -> Dict[str, Dict[str, str]]:
"""Get encrypted test data.
currently from manually copied file benchmark/encrypted_llm_test_data.json
TODO: automatically load test dataset (from github releases)?
@@ -255,11 +256,11 @@ def _decrypt_data(
return decrypted_test_data


def _decrypt(payload: dict[str, str], private_key: rsa.PrivateKey) -> str:
def _decrypt(payload: Dict[str, str], private_key: rsa.PrivateKey) -> str:
"""Decrypt a payload.
Args:
payload (dict[str, str]): Payload with key and data to decrypt.
payload (Dict[str, str]): Payload with key and data to decrypt.
private_key (rsa.PrivateKey): Private key to decrypt the payload.
Returns:
@@ -280,7 +281,7 @@ def _apply_literal_eval(df: pd.DataFrame, columns: list[str]):
Args:
df (pd.DataFrame): Dataframe.
columns (list[str]): Columns to apply literal_eval to.
columns (List[str]): Columns to apply literal_eval to.
"""
for col_name in columns:
if col_name in df.columns:
@@ -296,7 +297,7 @@ def _get_all_files(directory: str) -> list[str]:
directory (str): Path to directory.
Returns:
list[str]: List of file paths.
List[str]: List of file paths.
"""
all_files = []
for root, dirs, files in os.walk(directory):
42 changes: 21 additions & 21 deletions benchmark/results/medical_exam_failure_modes.csv

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion benchmark/test_api_calling.py
@@ -1,5 +1,5 @@
import inspect
from urllib.parse import urlencode
import inspect

import pytest

11 changes: 7 additions & 4 deletions benchmark/test_text_image_multimodality.py
@@ -6,17 +6,20 @@

# Load the data
import os
import pytest
import inspect
import hashlib
import inspect

import pytest

import numpy as np

from benchmark.conftest import calculate_bool_vector_score
from .benchmark_utils import (
get_confidence_file_path,
skip_if_already_run,
get_result_file_path,
write_confidence_to_file,
write_results_to_file,
get_confidence_file_path,
write_confidence_to_file,
)


14 changes: 7 additions & 7 deletions benchmark/test_user_interaction.py
@@ -7,10 +7,10 @@
from biochatter._misc import ensure_iterable
from .conftest import calculate_bool_vector_score
from .benchmark_utils import (
categorize_failure_modes,
skip_if_already_run,
get_result_file_path,
write_results_to_file,
categorize_failure_modes,
get_failure_mode_file_path,
write_failure_modes_to_file,
)
@@ -23,13 +23,13 @@ def test_medical_exam(
multiple_testing,
):
"""Test medical exam data by the model.
The user input is a medical question with answer options. The system prompt
has the guidelines to answer the question, and the expected answer is the
information that the model should reply from the given question. If the case
contains the word 'regex', the test is successful if the extracted information
occurs in the words in response. If it is a different question, the test is
successful if the extracted information matches the expected answer exactly.
For all false answers, also calculate the failure mode of the answer.
"""
# Downloads the natural language synonym toolkit; only needs to be done once per device
# nltk.download()
11 changes: 5 additions & 6 deletions biochatter/_image.py
@@ -1,12 +1,11 @@
# Functions for image encoding

import subprocess
import base64
import io
import os
import pdf2image
from PIL import Image
import base64
import tempfile # needed for test
import subprocess

from PIL import Image
import pdf2image


def convert_and_resize_image(image: Image, max_size: int = 1024) -> Image:
10 changes: 5 additions & 5 deletions biochatter/api_agent/__init__.py
@@ -1,9 +1,9 @@
from .abc import BaseQueryBuilder, BaseFetcher, BaseInterpreter
from .api_agent import APIAgent
from .abc import BaseFetcher, BaseInterpreter, BaseQueryBuilder
from .blast import (
BlastQueryParameters,
BlastQueryBuilder,
BlastFetcher,
BlastInterpreter,
BlastQueryBuilder,
BlastQueryParameters,
)
from .oncokb import OncoKBQueryBuilder, OncoKBFetcher, OncoKBInterpreter
from .oncokb import OncoKBFetcher, OncoKBInterpreter, OncoKBQueryBuilder
from .api_agent import APIAgent
3 changes: 2 additions & 1 deletion biochatter/api_agent/abc.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Optional
from collections.abc import Callable

from pydantic import BaseModel
@@ -94,6 +94,7 @@ class BaseFetcher(ABC):
def fetch_results(
self,
query_model,
retries: Optional[int] = 3,
):
"""
Fetches results by submitting a query. Can implement a multi-step
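The new optional `retries` argument on the abstract `fetch_results` lets concrete fetchers re-submit requests against flaky endpoints. Below is a minimal sketch of how a subclass might honour it; the `ExampleFetcher` class, endpoint URL, and use of `requests` are assumptions for illustration and not part of this diff (it also assumes `fetch_results` is the only abstract method of `BaseFetcher`):

```python
from typing import Optional

import requests  # generic HTTP client, assumed available; not prescribed by the diff

from biochatter.api_agent import BaseFetcher


class ExampleFetcher(BaseFetcher):
    """Hypothetical fetcher that retries a flaky endpoint a few times."""

    def fetch_results(self, query_model, retries: Optional[int] = 3) -> str:
        url = "https://api.example.org/search"  # placeholder endpoint
        last_error: Optional[Exception] = None
        for _ in range(retries or 1):
            try:
                # query_model is assumed to be a pydantic model, as in the diff
                response = requests.get(url, params=dict(query_model), timeout=10)
                response.raise_for_status()
                return response.text
            except requests.RequestException as exc:
                last_error = exc  # remember the failure and try again
        raise RuntimeError(f"all {retries} attempts failed") from last_error
```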
6 changes: 6 additions & 0 deletions biochatter/api_agent/api_agent.py
@@ -136,3 +136,9 @@ def execute(self, question: str) -> Optional[str]:

self.final_answer = final_answer
return final_answer

def get_description(self, tool_name: str, tool_desc: str):
return (
f"This API agent interacts with {tool_name}'s API for querying and "
f"fetching results. {tool_desc}"
)
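`get_description` gives the new selector agent a natural-language summary of what each API tool can do. The following self-contained sketch reproduces the method on a stub class and shows how a selector could collect one description per tool; the tool blurbs and the `_StubAgent` class are made up for illustration, and the diff does not show the selector's actual interface:

```python
class _StubAgent:
    """Stand-in for an APIAgent; only mirrors the new get_description method."""

    def get_description(self, tool_name: str, tool_desc: str) -> str:
        return (
            f"This API agent interacts with {tool_name}'s API for querying and "
            f"fetching results. {tool_desc}"
        )


# Hypothetical tool blurbs; a selector could feed the resulting descriptions to
# the LLM that decides which agent should handle a given question.
tools = {
    "BLAST": "Searches for sequence similarity in nucleotide/protein databases.",
    "OncoKB": "Provides clinical annotations for cancer gene variants.",
}
descriptions = {
    name: _StubAgent().get_description(name, desc) for name, desc in tools.items()
}
for name, desc in descriptions.items():
    print(f"{name}: {desc}")
```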
2 changes: 1 addition & 1 deletion biochatter/api_agent/blast.py
@@ -263,7 +263,7 @@ def _fetch_results(
)

def fetch_results(
self, query_model: BlastQueryParameters, retries: int = 10000
self, query_model: BlastQueryParameters, retries: int = 20
) -> str:
"""
Submit request and fetch results from BLAST API. Wraps individual
4 changes: 3 additions & 1 deletion biochatter/api_agent/oncokb.py
@@ -258,7 +258,9 @@ def __init__(self, api_token="demo"):
}
self.base_url = "https://demo.oncokb.org/api/v1"

def fetch_results(self, request_data: OncoKBQueryParameters) -> str:
def fetch_results(
self, request_data: OncoKBQueryParameters, retries: Optional[int] = 3
) -> str:
"""Function to submit the OncoKB query and fetch the results directly.
No multi-step procedure, thus no wrapping of submission and retrieval in
this case.
1 change: 1 addition & 0 deletions biochatter/constants.py
@@ -0,0 +1 @@
MAX_AGENT_DESC_LENGTH = 1000
41 changes: 36 additions & 5 deletions biochatter/database_agent.py
@@ -5,6 +5,7 @@
import neo4j_utils as nu

from .prompts import BioCypherPromptEngine
from .constants import MAX_AGENT_DESC_LENGTH
from .kg_langgraph_agent import KGQueryReflexionAgent


@@ -62,8 +63,8 @@ def _generate_query(self, query: str):
self.connection_args,
)
query_prompt = self.prompt_engine.generate_query_prompt(query)
cypher_query = agent.execute(query, query_prompt)
return cypher_query
agent_result = agent.execute(query, query_prompt)
return agent_result.answer, agent_result.tool_result

def get_query_results(self, query: str, k: int = 3) -> list[Document]:
"""
@@ -77,16 +78,21 @@ def get_query_results(self, query: str, k: int = 3) -> list[Document]:
k (int): The number of results to return.
Returns:
list[Document]: A list of Document objects. The page content values
List[Document]: A list of Document objects. The page content values
are the literal dictionaries returned by the query, the metadata
values are the cypher query used to generate the results, for
now.
"""
cypher_query = self._generate_query(
(cypher_query, tool_result) = self._generate_query(
query
) # self.prompt_engine.generate_query(query)
# TODO some logic if it fails?
results = self.driver.query(query=cypher_query)
if tool_result is not None:
# If _generate_query() already returned a tool_result, we do not need to
# query the graph database again
results = [tool_result]
else:
results = self.driver.query(query=cypher_query)

documents = []
# return first k results
@@ -107,3 +113,28 @@ def get_query_results(self, query: str, k: int = 3) -> list[Document]:
break

return documents

def get_description(self):
result = self.driver.query("MATCH (n:Schema_info) RETURN n LIMIT 1")

if result[0]:
schema_info_node = result[0][0]["n"]
schema_dict_content = schema_info_node["schema_info"][
:MAX_AGENT_DESC_LENGTH
] # limit to 1000 characters
return (
f"the graph database contains the following nodes and edges: \n\n"
f"{schema_dict_content}"
)

# schema_info is not found in database
nodes_query = "MATCH (n) RETURN DISTINCT labels(n) LIMIT 300"
node_results = self.driver.query(query=nodes_query)
edges_query = "MATCH (n) RETURN DISTINCT type(n) LIMIT 300"
edge_results = self.driver.query(query=edges_query)
desc = (
f"The graph database contains the following nodes and edges: \n"
f"nodes: \n{node_results}"
f"edges: \n{edge_results}"
)
return desc[:MAX_AGENT_DESC_LENGTH]
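With every agent exposing `get_description` (from the `Schema_info` node when present, otherwise from sampled labels and relationship types, capped at `MAX_AGENT_DESC_LENGTH`), a selector only needs to match the user question against those descriptions. The diff does not show the selector implementation; the following is a deliberately naive, self-contained sketch of the idea using keyword overlap, whereas the real selector would typically delegate the choice to the LLM:

```python
def select_agent(question: str, descriptions: dict[str, str]) -> str:
    """Pick the agent whose description shares the most words with the question.

    Naive stand-in for the LLM-based selection introduced in this commit.
    """
    question_words = set(question.lower().split())

    def overlap(desc: str) -> int:
        return len(question_words & set(desc.lower().split()))

    return max(descriptions, key=lambda name: overlap(descriptions[name]))


# Example with made-up descriptions:
descriptions = {
    "blast": "sequence similarity search against nucleotide and protein databases",
    "oncokb": "clinical annotations for cancer gene variants",
    "database": "the graph database contains the following nodes and edges: ...",
}
print(select_agent("Which database node types exist in the graph?", descriptions))
```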