diff --git a/config.yaml b/config.yaml
new file mode 100644
index 0000000..a5881bc
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,23 @@
+tools:
+  - name: "constitution_tool"
+    description: "Answers questions about the U.S. Constitution."
+    url: "https://my.app/v1/completions"
+    config:
+      method: 'POST'
+      headers:
+        'Content-Type': 'application/json'
+        'Authorization': 'Basic 12345'
+      body:
+        prompt: '{{prompt}}'
+      responseParser: 'json.answer'
+      responseMetadata:
+        - name: 'sources'
+          loc: 'json.sources'
+      responseFormat:
+        agent: '{{response}}'
+        json:
+          - "response"
+          - "sources"
+    examples:
+      - "What is the definition of a citizen in the U.S. Constitution?"
+      - "What article describes the power of the judiciary branch?"
\ No newline at end of file
diff --git a/react_agent/__init__.py b/react_agent/__init__.py
new file mode 100644
index 0000000..8e90d03
--- /dev/null
+++ b/react_agent/__init__.py
@@ -0,0 +1,3 @@
+"""Init."""
+
+__version__ = "0.1.0"
\ No newline at end of file
diff --git a/react_agent/agent/__init__.py b/react_agent/agent/__init__.py
new file mode 100644
index 0000000..d7826b4
--- /dev/null
+++ b/react_agent/agent/__init__.py
@@ -0,0 +1,18 @@
+"""Define agents."""
+
+from langchain.agents import AgentExecutor, create_react_agent
+from langchain_core.prompts import PromptTemplate
+
+from react_agent.constants import REACT_PROMPT
+from react_agent.common.llm import completion_llm
+from react_agent.tools import import_tools
+
+
+def react_agent():
+    """Create a ReAct agent."""
+    response_format = "agent"
+    tools = import_tools(common_tools_kwargs={"response_format": response_format})
+    prompt = PromptTemplate.from_template(REACT_PROMPT)
+    agent = create_react_agent(completion_llm, tools, prompt)
+    agent_executor = AgentExecutor(name="ReActAgent", agent=agent, tools=tools, handle_parsing_errors=True)
+    return agent_executor
\ No newline at end of file
diff --git a/react_agent/api.py b/react_agent/api.py
new file mode 100644
index 0000000..93b198c
--- /dev/null
+++ b/react_agent/api.py
@@ -0,0 +1,59 @@
+from dotenv import load_dotenv
+
+# Load environment variables before importing modules that read them
+load_dotenv()
+
+import logging
+from contextlib import asynccontextmanager
+
+import mlflow
+import uvicorn
+from fastapi import FastAPI
+
+from react_agent.agent import react_agent
+from react_agent.apispec import ReActRequest, ReActResponse
+from react_agent.constants import APP_HOST, APP_PORT
+
+logger = logging.getLogger(__name__)
+
+# Enable MLflow tracing of LangChain calls
+mlflow.langchain.autolog(log_traces=True)
+
+agents = {}
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Run startup sequence."""
+    # Register the agents as application components
+    agents["react"] = react_agent()
+
+    logger.info("Startup sequence successful")
+    yield
+    agents.clear()
+
+
+app = FastAPI(lifespan=lifespan)
+
+
+@app.post("/react", response_model=ReActResponse)
+async def react(request: ReActRequest):
+    """Interact with the ReAct agent."""
+    agent = agents["react"]
+
+    # Forward the prompt to the agent
+    agent_response = agent.invoke({"input": request.prompt, "chat_history": []})
+    answer = agent_response["output"]
+    response = ReActResponse(answer=answer)
+    return response
+
+
+@app.get("/health")
+def health():
+    """Perform a service health check."""
+    return {"status": "ok"}
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s - %(message)s")
+    uvicorn.run(app, port=APP_PORT, host=APP_HOST)
\ No newline at end of file
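As a sanity check of the API above, a client can hit `/health` and then POST a prompt to `/react`. A minimal sketch (the port and routes come from the code above; the prompt itself is just an example):

```python
import requests

# Assumes the API is running locally on the default port from constants.py.
API_ROOT = "http://localhost:2113"

# The service should report healthy before accepting prompts.
health = requests.get(f"{API_ROOT}/health")
health.raise_for_status()

# Send a prompt to the ReAct agent and print the final answer.
response = requests.post(f"{API_ROOT}/react", json={"prompt": "What is the square of 12?"})
response.raise_for_status()
print(response.json()["answer"])
```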
diff --git a/react_agent/apispec.py b/react_agent/apispec.py
new file mode 100644
index 0000000..553eee2
--- /dev/null
+++ b/react_agent/apispec.py
@@ -0,0 +1,15 @@
+from pydantic import BaseModel
+
+
+class ReActRequest(BaseModel):
+    """Request for ReAct endpoint."""
+
+    prompt: str
+    tools: list[str] = []
+
+
+class ReActResponse(BaseModel):
+    """Response for ReAct endpoint."""
+
+    answer: str | dict
+    tools_used: list[str] = []
\ No newline at end of file
diff --git a/react_agent/common/llm.py b/react_agent/common/llm.py
new file mode 100644
index 0000000..ba1a34a
--- /dev/null
+++ b/react_agent/common/llm.py
@@ -0,0 +1,50 @@
+from typing import Any, List, Mapping, Optional
+
+import httpx
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+from langchain_core.messages import HumanMessage
+from langchain_openai.chat_models import ChatOpenAI
+
+from react_agent.constants import OPENAI_IGNORE_SSL, OPENAI_MODEL, OPENAI_URI
+
+
+class CustomOpenAI(LLM):
+    """Interact with a hosted OpenAI-compatible instance at a specified URI, optionally without SSL verification."""
+
+    base_url: str
+    model: str
+    api_key: str
+    http_client: Optional[httpx.Client] = None
+    temperature: float = 0.8
+
+    @property
+    def _llm_type(self) -> str:
+        return "custom"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        # Wrap the completion-style prompt in a single chat message
+        llm = ChatOpenAI(
+            base_url=self.base_url, model=self.model, api_key=self.api_key, http_client=self.http_client, temperature=self.temperature
+        )
+        request = [HumanMessage(content=prompt)]
+        response = llm.invoke(request, stop=stop, **kwargs).content
+        return response
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {"base_url": self.base_url, "model": self.model}
+
+
+verify = not OPENAI_IGNORE_SSL
+http_client = httpx.Client(verify=verify)
+
+chat_llm = ChatOpenAI(base_url=OPENAI_URI, model=OPENAI_MODEL, http_client=http_client, temperature=0, api_key="NONE")
+completion_llm = CustomOpenAI(base_url=OPENAI_URI, model=OPENAI_MODEL, http_client=http_client, temperature=0, api_key="NONE")
\ No newline at end of file
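For context, `completion_llm` wraps a completion-style prompt in a single chat message so the completion-oriented ReAct agent can talk to chat-only OpenAI-compatible backends. A minimal usage sketch, assuming such a backend (e.g. a local Ollama instance serving `mistral`) is reachable at the configured `OPENAI_URI`:

```python
from react_agent.common.llm import chat_llm, completion_llm

# LLM subclasses are Runnables: invoke() takes a prompt string and returns a string.
text = completion_llm.invoke("Name one article of the U.S. Constitution.")
print(text)

# The chat model returns a message object instead; the text lives on .content.
message = chat_llm.invoke("Name one article of the U.S. Constitution.")
print(message.content)
```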
diff --git a/react_agent/constants.py b/react_agent/constants.py
new file mode 100644
index 0000000..0d62c46
--- /dev/null
+++ b/react_agent/constants.py
@@ -0,0 +1,57 @@
+import os
+import pathlib
+
+import yaml
+
+DIRECTORY_PATH = pathlib.Path(__file__).resolve().parent.parent
+
+with open(DIRECTORY_PATH / "config.yaml") as f:
+    CONFIG = yaml.safe_load(f)
+
+APP_HOST = os.environ.get("APP_HOST", "0.0.0.0")
+APP_PORT = int(os.environ.get("APP_PORT", "2113"))
+
+OPENAI_URI = os.environ.get("OPENAI_URI", "http://localhost:11434/v1")
+OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "mistral")
+OPENAI_IGNORE_SSL = os.environ.get("OPENAI_IGNORE_SSL", "False").lower() in ("true", "1", "yes")
+
+ERROR_MESSAGE = "Unable to process request, please try again later."
+
+REACT_PROMPT = """Assistant is a large language model trained by OpenAI.
+
+Assistant is designed to be able to assist with a wide range of tasks. As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
+
+Assistant is constantly learning and improving, and its capabilities are constantly evolving. It is able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. Additionally, Assistant is able to generate its own text based on the input it receives, allowing it to engage in discussions and provide explanations and descriptions on a wide range of topics.
+
+Overall, Assistant is a powerful tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, Assistant is here to assist.
+
+TOOLS:
+------
+
+Assistant has access to the following tools:
+
+{tools}
+
+To use a tool, please use the following format:
+
+```
+Thought: Do I need to use a tool? Yes
+Action: the action to take, should be one of [{tool_names}]
+Action Input: the input to the action
+Observation: the result of the action
+```
+
+When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
+
+```
+Thought: Do I need to use a tool? No
+Final Answer: [your response here]
+```
+
+Begin!
+
+Previous conversation history:
+{chat_history}
+
+New input: {input}
+{agent_scratchpad}"""
\ No newline at end of file
diff --git a/react_agent/tools/__init__.py b/react_agent/tools/__init__.py
new file mode 100644
index 0000000..669b6a9
--- /dev/null
+++ b/react_agent/tools/__init__.py
@@ -0,0 +1,34 @@
+"""Define and import tools."""
+
+import logging
+from typing import Optional
+
+from react_agent.constants import CONFIG
+from react_agent.tools.common import CommonTool
+from react_agent.tools.math import ComputeSquareTool
+
+logger = logging.getLogger(__name__)
+
+
+def create_common_tools(**kwargs):
+    """Create common tools."""
+    logger.info("Creating Common Tools from config.yaml")
+    tool_configs = CONFIG.get("tools", [])
+    tools = []
+    for tool_config in tool_configs:
+        tool = CommonTool(**tool_config, **kwargs)
+        tools.append(tool)
+        logger.info(f"Created {tool_config['name']} tool")
+    return tools
+
+
+def import_tools(all_return_direct: bool = False, common_tools_kwargs: Optional[dict] = None):
+    """Gather tools."""
+    common_tools_kwargs = common_tools_kwargs or {}
+    base_tools = [ComputeSquareTool()]
+    common_tools = create_common_tools(**common_tools_kwargs)
+    tools = base_tools + common_tools
+    if all_return_direct:
+        for tool in tools:
+            tool.return_direct = True
+    return tools
\ No newline at end of file
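The toolbox the agent receives can be inspected without running the service. A sketch, assuming `config.yaml` sits at the repository root where `constants.py` expects it:

```python
from react_agent.tools import import_tools

# CommonTool requires response_format, so it must be threaded through the factory.
tools = import_tools(common_tools_kwargs={"response_format": "agent"})

for tool in tools:
    print(f"{tool.name}: {tool.description}")
# Expected output (with the sample config):
#   compute_square_tool: Compute the square of a number
#   constitution_tool: Answers questions about the U.S. Constitution.
```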
diff --git a/react_agent/tools/common.py b/react_agent/tools/common.py
new file mode 100644
index 0000000..745ddc0
--- /dev/null
+++ b/react_agent/tools/common.py
@@ -0,0 +1,104 @@
+import json
+import logging
+from typing import Optional, Type
+
+import requests
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForToolRun,
+    CallbackManagerForToolRun,
+)
+from langchain_core.tools import BaseTool
+from pydantic import BaseModel, Field
+
+from react_agent.constants import ERROR_MESSAGE
+
+logger = logging.getLogger(__name__)
+
+
+# From https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
+class DotDict(dict):
+    """Use dot notation to access dictionary attributes."""
+
+    __getattr__ = dict.get
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
+
+class CommonToolInput(BaseModel):
+    """Tool input structure."""
+
+    prompt: str = Field(description="should be a prompt for an AI")
+
+
+class CommonTool(BaseTool):
+    """Tool for interacting with external API instances."""
+
+    name: str = Field(description="Tool name")
+    description: str = Field(description="Tool description")
+    args_schema: Type[BaseModel] = CommonToolInput
+    response_format: str
+    url: str
+    config: dict
+
+    # TODO: Implement validation
+
+    def request(self, prompt: str):
+        """Send a request to the API endpoint."""
+        method = self.config["method"]
+        headers = self.config["headers"]
+        body = self.config["body"].copy()
+        for key, value in body.items():
+            if isinstance(value, str) and "{{prompt}}" in value:
+                body[key] = value.replace("{{prompt}}", prompt)
+        url_response = requests.request(method=method, url=self.url, headers=headers, json=body)
+        url_response.raise_for_status()
+
+        # Parse the response
+        url_response_json = url_response.json()
+        url_response_dotdict = DotDict(url_response_json)
+
+        # Get the answer
+        response_loc = self.config["responseParser"]
+        response_loc = ".".join(response_loc.split(".")[1:]) if response_loc.startswith("json") else response_loc
+        response = getattr(url_response_dotdict, response_loc).strip()
+
+        # Create metadata
+        metadata = {}
+        metadata_fields = self.config.get("responseMetadata", None)
+        if metadata_fields:
+            for field in metadata_fields:
+                name = field["name"]
+                loc = field["loc"]
+                loc = ".".join(loc.split(".")[1:]) if loc.startswith("json") else loc
+                metadata[name] = getattr(url_response_dotdict, loc)
+        return response, metadata
+
+    def format_response(self, response: str, metadata: dict):
+        """Format the response."""
+        response_format = self.config["responseFormat"][self.response_format]
+        available_fields = {"response": response} | metadata
+        if self.response_format == "json":
+            formatted_response = {}
+            for key in response_format:
+                formatted_response[key] = available_fields.get(key, None)
+            formatted_response = json.dumps(formatted_response)
+        else:
+            formatted_response = response_format
+            for key, value in available_fields.items():
+                value = json.dumps(value) if isinstance(value, dict) else str(value)
+                formatted_response = formatted_response.replace("{{" + key + "}}", value)
+
+        return formatted_response
+
+    def _run(self, prompt: str, run_manager: Optional[CallbackManagerForToolRun] = None):
+        """Use the tool."""
+        try:
+            response, metadata = self.request(prompt=prompt)
+            return self.format_response(response, metadata)
+        except Exception as e:
+            logger.error(e)
+            return ERROR_MESSAGE
+
+    async def _arun(self, prompt: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None) -> str:
+        """Use the tool asynchronously."""
+        raise NotImplementedError("Tool does not support async")
\ No newline at end of file
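The parse/format pipeline in `CommonTool` is easiest to see with a mock payload: `responseParser: 'json.answer'` selects the answer field, `responseMetadata` lifts `sources`, and the `json` response format serializes both. A sketch with hypothetical payload values and no network call:

```python
from react_agent.tools.common import CommonTool

tool = CommonTool(
    name="constitution_tool",
    description="Answers questions about the U.S. Constitution.",
    url="https://my.app/v1/completions",
    response_format="json",
    config={
        "method": "POST",
        "headers": {"Content-Type": "application/json"},
        "body": {"prompt": "{{prompt}}"},
        "responseParser": "json.answer",
        "responseMetadata": [{"name": "sources", "loc": "json.sources"}],
        "responseFormat": {"agent": "{{response}}", "json": ["response", "sources"]},
    },
)

# Hypothetical values standing in for what request() would pull from the API.
answer, metadata = "Article III.", {"sources": ["Article III, Section 1"]}
print(tool.format_response(answer, metadata))
# {"response": "Article III.", "sources": ["Article III, Section 1"]}
```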
number.""" + + name = "compute_square_tool" + description = "Compute the square of a number" + args_schema: Type[BaseModel] = ComputeSquareInput + + def _run(self, number: str, run_manager: Optional[CallbackManagerForToolRun] = None): + """Use the tool.""" + time.sleep(1) # TODO: Remove this + return float(number) ** 2 + + async def _arun(self, number: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None) -> str: + """Use the tool asynchronously.""" + raise NotImplementedError("Tool does not support async") \ No newline at end of file diff --git a/sample.env b/sample.env new file mode 100644 index 0000000..250ba47 --- /dev/null +++ b/sample.env @@ -0,0 +1,7 @@ +APP_HOST="0.0.0.0" +APP_PORT=2113 +OPENAI_URI="https://localhost:11434/v1" +OPENAI_MODEL="mistral" +OPENAI_IGNORE_SSL="True" +MLFLOW_TRACKING_URI="http://localhost:5000" +MLFLOW_EXPERIMENT_NAME="Tool Based Agent" diff --git a/streamlit/.streamlit/config.toml b/streamlit/.streamlit/config.toml new file mode 100644 index 0000000..216cdf0 --- /dev/null +++ b/streamlit/.streamlit/config.toml @@ -0,0 +1,2 @@ +[theme] +base="light" \ No newline at end of file diff --git a/streamlit/intro.py b/streamlit/intro.py new file mode 100644 index 0000000..159ad7a --- /dev/null +++ b/streamlit/intro.py @@ -0,0 +1,8 @@ +import streamlit as st +from util import HIDE_MENU, read_markdown_file + +st.set_page_config(page_title="Introduction") + +st.markdown(HIDE_MENU, unsafe_allow_html=True) +st.image("webapp/static/images/red_hat_banner_dark.png") +st.markdown(read_markdown_file("webapp/static/markdown/introduction.md")) \ No newline at end of file diff --git a/streamlit/pages/app.py b/streamlit/pages/app.py new file mode 100644 index 0000000..d5f0e6a --- /dev/null +++ b/streamlit/pages/app.py @@ -0,0 +1,30 @@ +import json +import os +import requests +import streamlit as st +from util import HIDE_MENU + +st.set_page_config(page_title="App") + +st.markdown(HIDE_MENU, unsafe_allow_html=True) +st.image("webapp/static/images/red_hat_banner_dark.png") + +host = os.environ.get("APP_HOST", "0.0.0.0") +port = os.environ.get("APP_PORT", "2113") +api_root = f"http://{host}:{port}" + +prompt = st.text_area("Prompt:") +url = api_root + "/react" +ask = st.button("Ask") + +if ask: + st.subheader("Answer") + with st.spinner("Generating answer..."): + request_json = {"prompt": prompt} + response = requests.post(url=url, json=request_json) + if response.status_code != 200: + st.error("Request failed. Is the API running?") + st.stop() + response_json = response.json() + answer = response_json["answer"] + st.write(answer) \ No newline at end of file diff --git a/streamlit/static/images/red_hat_banner_dark.png b/streamlit/static/images/red_hat_banner_dark.png new file mode 100644 index 0000000..3340884 Binary files /dev/null and b/streamlit/static/images/red_hat_banner_dark.png differ diff --git a/streamlit/static/markdown/introduction.md b/streamlit/static/markdown/introduction.md new file mode 100644 index 0000000..3e86fee --- /dev/null +++ b/streamlit/static/markdown/introduction.md @@ -0,0 +1,5 @@ +## Introduction + +## Contact + +Please report any issues, comments, or concerns with the application to the team. 
diff --git a/sample.env b/sample.env
new file mode 100644
index 0000000..250ba47
--- /dev/null
+++ b/sample.env
@@ -0,0 +1,7 @@
+APP_HOST="0.0.0.0"
+APP_PORT=2113
+OPENAI_URI="https://localhost:11434/v1"
+OPENAI_MODEL="mistral"
+OPENAI_IGNORE_SSL="True"
+MLFLOW_TRACKING_URI="http://localhost:5000"
+MLFLOW_EXPERIMENT_NAME="Tool Based Agent"
diff --git a/streamlit/.streamlit/config.toml b/streamlit/.streamlit/config.toml
new file mode 100644
index 0000000..216cdf0
--- /dev/null
+++ b/streamlit/.streamlit/config.toml
@@ -0,0 +1,2 @@
+[theme]
+base="light"
\ No newline at end of file
diff --git a/streamlit/intro.py b/streamlit/intro.py
new file mode 100644
index 0000000..159ad7a
--- /dev/null
+++ b/streamlit/intro.py
@@ -0,0 +1,8 @@
+import streamlit as st
+from util import HIDE_MENU, read_markdown_file
+
+st.set_page_config(page_title="Introduction")
+
+st.markdown(HIDE_MENU, unsafe_allow_html=True)
+st.image("streamlit/static/images/red_hat_banner_dark.png")
+st.markdown(read_markdown_file("streamlit/static/markdown/introduction.md"))
\ No newline at end of file
diff --git a/streamlit/pages/app.py b/streamlit/pages/app.py
new file mode 100644
index 0000000..d5f0e6a
--- /dev/null
+++ b/streamlit/pages/app.py
@@ -0,0 +1,30 @@
+import os
+
+import requests
+import streamlit as st
+from util import HIDE_MENU
+
+st.set_page_config(page_title="App")
+
+st.markdown(HIDE_MENU, unsafe_allow_html=True)
+st.image("streamlit/static/images/red_hat_banner_dark.png")
+
+host = os.environ.get("APP_HOST", "0.0.0.0")
+port = os.environ.get("APP_PORT", "2113")
+api_root = f"http://{host}:{port}"
+
+prompt = st.text_area("Prompt:")
+url = api_root + "/react"
+ask = st.button("Ask")
+
+if ask:
+    st.subheader("Answer")
+    with st.spinner("Generating answer..."):
+        request_json = {"prompt": prompt}
+        response = requests.post(url=url, json=request_json)
+        if response.status_code != 200:
+            st.error("Request failed. Is the API running?")
+            st.stop()
+        response_json = response.json()
+        answer = response_json["answer"]
+        st.write(answer)
\ No newline at end of file
diff --git a/streamlit/static/images/red_hat_banner_dark.png b/streamlit/static/images/red_hat_banner_dark.png
new file mode 100644
index 0000000..3340884
Binary files /dev/null and b/streamlit/static/images/red_hat_banner_dark.png differ
diff --git a/streamlit/static/markdown/introduction.md b/streamlit/static/markdown/introduction.md
new file mode 100644
index 0000000..3e86fee
--- /dev/null
+++ b/streamlit/static/markdown/introduction.md
@@ -0,0 +1,5 @@
+## Introduction
+
+## Contact
+
+Please report any issues, comments, or concerns with the application to the team.
\ No newline at end of file
diff --git a/streamlit/util.py b/streamlit/util.py
new file mode 100644
index 0000000..188e08a
--- /dev/null
+++ b/streamlit/util.py
@@ -0,0 +1,14 @@
+from pathlib import Path
+
+HIDE_MENU = """
+<style>
+#MainMenu {visibility: hidden;}
+header {visibility: hidden;}
+footer {visibility: hidden;}
+</style>
+"""
+
+
+def read_markdown_file(md_file) -> str:
+    """Read a markdown file given its path."""
+    return Path(md_file).read_text()
\ No newline at end of file