diff --git a/README.md b/README.md
index 98479ff..d7e8b58 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,6 @@ export OPENAI_API_KEY="..."
 export QDRANT_KEY="..."
 export QDRANT_URL="..."
 export PG_URL="..."
-export ASSISTANT_ID="..." (used with legacy impl)
 ```
 
 For the slack bot, you also need these:
@@ -76,10 +75,17 @@ Once the process completes, you should have a meta data and vector embeddings in
 
 ## Using the assistant
 
+Web-based:
+
 ```
 streamlit run agent-ui.py
 ```
 
+CLI-based:
+
+```
+python agent-cli.py [-f QUESTION.txt]
+```

[...]

 #### Results
diff --git a/agent-cli.py b/agent-cli.py
new file mode 100644
index 0000000..defc90b
--- /dev/null
+++ b/agent-cli.py
@@ -0,0 +1,148 @@
+import sys
+
+from prompt_toolkit import PromptSession
+from prompt_toolkit.completion import WordCompleter
+from prompt_toolkit.lexers import PygmentsLexer
+from prompt_toolkit.styles import Style
+
+from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
+    AgentTokenBufferMemory,
+)
+
+from langchain.callbacks.base import AsyncCallbackHandler
+from langchain.schema import LLMResult
+from langchain_core.messages import BaseMessage
+from uuid import UUID
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
+
+from core.costs import TokenCostProcess, CostCalcAsyncHandler
+
+from core.agent import agent_executor, agent_llm
+
+import argparse
+
+# ---
+
+style = Style.from_dict(
+    {
+        "completion-menu.completion": "bg:#008888 #ffffff",
+        "completion-menu.completion.current": "bg:#00aaaa #000000",
+        "scrollbar.background": "bg:#88aaaa",
+        "scrollbar.button": "bg:#222222",
+    }
+)
+
+# prints tool activity to the console while the agent works
+class CLIAsyncHandler(AsyncCallbackHandler):
+
+    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+        pass
+
+    async def on_llm_new_token(self, token: str, **kwargs) -> None:
+        pass
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        pass
+
+    async def on_tool_start(
+        self,
+        serialized: Dict[str, Any],
+        input_str: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when tool starts running."""
+        tool_name = serialized["name"]
+        print(f"{tool_name} : {input_str}")
+
+    async def on_tool_end(
+        self,
+        output: str,
+        color: Optional[str] = None,
+        observation_prefix: Optional[str] = None,
+        llm_prefix: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        pass
+
+    async def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run when a chat model starts running."""
+        print("Thinking ...")
+
+
+def main(args):
+
+    memory = AgentTokenBufferMemory(llm=agent_llm)
+
+    session = PromptSession(
+        lexer=None, completer=None, style=style
+    )
+
+    if args.filename is None:
+
+        # enter QA loop
+
+        print("How can I help you?")
+
+        while True:
+            try:
+                prompt_text = session.prompt("> ")
+            except KeyboardInterrupt:
+                continue  # Control-C pressed. Try again.
+            except EOFError:
+                break  # Control-D pressed. Exit.
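+
+            # each turn feeds the new prompt plus the accumulated
+            # conversation buffer back into the agent, so follow-up
+            # questions keep their conversational context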
+            # request chat completion
+            try:
+                response_handle = agent_executor(
+                    {"input": prompt_text, "history": memory.buffer},
+                    callbacks=[CLIAsyncHandler()],
+                    include_run_info=True,
+                )
+
+                memory.save_context({"input": prompt_text}, response_handle)
+
+                print(f"\n{response_handle['output']}")
+
+            except Exception as e:
+                print("Failed to call OpenAI API: ", str(e))
+
+        print("Goodbye!")
+
+    else:
+
+        # enter one-shot prompting from file
+
+        with open(args.filename) as f:
+            prompt = f.read()
+            prompt_text = prompt.replace('\n', ' ').replace('\r', '')
+
+        response_handle = agent_executor(
+            {"input": prompt_text, "history": []},
+            callbacks=[CLIAsyncHandler()],
+            include_run_info=True,
+        )
+
+        print(f"\n{response_handle['output']}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Camel Quickstart Assistant')
+    parser.add_argument('-f', '--filename', help='The input file that will be taken as a prompt', required=False)
+    args = parser.parse_args()
+
+    main(args)
\ No newline at end of file
diff --git a/assistant.py b/assistant.py
deleted file mode 100644
index 9e56cf2..0000000
--- a/assistant.py
+++ /dev/null
@@ -1,470 +0,0 @@
-from openai import OpenAI
-
-from typing import Optional, Type, Any
-
-from util.utils import show_json, as_json
-import time
-import jmespath
-import json
-import httpx
-
-from qdrant_client import QdrantClient
-from langchain.prompts import PromptTemplate
-from langchain.docstore.document import Document
-from langchain_community.vectorstores.utils import maximal_marginal_relevance
-
-from conf.constants import *
-
-from tenacity import (
-    retry,
-    stop_after_attempt,
-    wait_random_exponential,
-)
-
-from statemachine import State
-from statemachine import StateMachine
-
-import argparse
-import cohere
-import numpy as np
-
-import streamlit as st
-from abc import ABC, abstractmethod
-
-# ---
-
-def get_response(client, thread):
-    return client.beta.threads.messages.list(thread_id=thread.id, order="asc")
-
-def pretty_print(messages):
-    print("# Messages")
-    for m in messages:
-        print(f"{m.role}: {m.content[0].text.value}")
-    print()
-
-# primitive wait condition for API requests, needs improvement
-def wait_on_run(client, run, thread):
-    while run.status == "queued" or run.status == "in_progress":
-        run = client.beta.threads.runs.retrieve(
-            thread_id=thread.id,
-            run_id=run.id,
-        )
-        print("Thinking ... ", run.status)
-        #show_json(run)
-        time.sleep(0.5)
-    return run
-
-# fetch the call arguments from an assistant callback
-def get_call_arguments(run):
-    #show_json(run)
-    tool_calls = jmespath.search(
-        "required_action.submit_tool_outputs.tool_calls",
-        as_json(run)
-    )
-
-    call_arguments = []
-    for call in tool_calls:
-        id = jmespath.search("id", call)
-        arguments = jmespath.search("function.arguments", call)
-        call_arguments.append(
-            {
-                "call_id": id,
-                "call_arguments":json.loads(arguments)
-            }
-        )
-    return call_arguments
-
-# search local storage for documentation related to componment
-def fetch_docs(entities):
-    print("Fetching docs for query: ", entities)
-
-    query_results = query_qdrant(entities, collection_name="camel_docs")
-    num_matches = len(query_results)
-
-    # print("First glance matches:")
-    # for i, article in enumerate(query_results):
-    #     print(f'{i + 1}. {article.payload["filename"]} (Score: {round(article.score, 3)})')
-
-    if num_matches > 0:
-
-        docs = []
-        for _, article in enumerate(query_results):
-            with open(article.payload["filename"]) as f:
-                docs.append(f.read())
-
-        # apply reranking
-        co = cohere.Client(os.environ['COHERE_KEY'])
-        rerank_hits = co.rerank(
-            model = 'rerank-english-v2.0',
-            query = entities,
-            documents = docs,
-            top_n = 3
-        )
-
-        print("Reranked matches: ")
-        for hit in rerank_hits:
-            orig_result = query_results[hit.index]
-            print(f'{orig_result.payload["filename"]} (Score: {round(hit.relevance_score, 3)})')
-
-        # TODO: This is wrong and needs to be fixed. it must consider the rerank
-        doc = query_results[0]
-        with open(doc.payload["filename"]) as f:
-            contents = f.read()
-        return contents
-    else:
-        return "No matching file found for "+entities
-
-def document_from_scored_point(
-        scored_point: Any,
-        content_payload_key: str,
-        metadata_payload_key: str,
-    ) -> Document:
-    return Document(
-        page_content=scored_point.payload.get(content_payload_key),
-        metadata=scored_point.payload.get(metadata_payload_key) or {},
-    )
-
-def fetch_and_rerank(entities, collections, feedback):
-
-    # compute the query embedding
-    embedding = get_embedding(text=entities)
-
-    # lookup across multiple vector store
-    results = []
-    for name in collections:
-        intermittent_results = query_qdrant(embedding=embedding, collection_name=name, top_k=15)
-        results.extend(intermittent_results)
-
-    ## The MMR impl used with retriever(search_type='mmr')
-    embeddings = [result.vector for result in results]
-
-    mmr_selected = maximal_marginal_relevance(
-        np.array(embedding), embeddings, k=5, lambda_mult=0.8
-    )
-
-    mmr_results = [
-        (
-            document_from_scored_point(
-                scored_point=results[i],
-                content_payload_key="page_content",
-                metadata_payload_key="metadata"
-            ),
-            results[i].score,
-        )
-        for i in mmr_selected
-    ]
-
-    response_documents = []
-    for i, article in enumerate(mmr_results):
-        doc = article[0]
-        score = article[1]
-        feedback.print(str(round(score, 3))+ ": "+ doc.metadata["page_number"])
-        response_documents.append(doc)
-
-    return response_documents
-
-def query_qdrant(query, collection_name, top_k=5):
-
-    embedded_query = get_embedding(text=query)
-
-    qdrant_client = create_qdrant_client()
-    query_results = qdrant_client.search(
-        collection_name=collection_name,
-        query_vector=(embedded_query),
-        limit=top_k,
-    )
-
-    return query_results
-
-def query_qdrant(embedding, collection_name, top_k=5):
-
-    qdrant_client = create_qdrant_client()
-
-    results = create_qdrant_client().search(
-        collection_name=collection_name,
-        query_vector=(embedding),
-        with_payload=True,
-        with_vectors=True,
-        limit=top_k,
-    )
-
-    return results
-
-@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
-def get_embedding(text, model="text-embedding-ada-002"):
-    start = time.time()
-    text = text.replace("\n", " ")
-    resp = create_openai_client().embeddings.create(input = [text], model=model)
-    print("Embedding ms: ", time.time() - start)
-    return resp.data[0].embedding
-
-def create_openai_client():
-    client = OpenAI(
-        timeout=httpx.Timeout(
-            10.0, read=8.0, write=3.0, connect=3.0
-        )
-    )
-    return client
-
-def create_qdrant_client():
-    client = QdrantClient(
-        QDRANT_URL,
-        api_key=QDRANT_KEY,
-    )
-    return client
-
-
-# rewrite a question using the chat API
-@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(1))
-def rewrite_question(openai_client, text):
-
-    template = PromptTemplate.from_template(
-        """
-        Rephrase the following text:
-
-        Text: \"\"\"{text}\"\"\"
-
-        """
-    )
-
-    response = openai_client.chat.completions.create(
-        model="gpt-3.5-turbo-1106",
-        messages=[
-            {"role": "system", "content": "You are a service used to rewrite text"},
-            {"role": "user", "content": template.format(text=text)}
-        ]
-    )
-
-    return response.choices[0].message.content
-# ---
-
-class StatusStrategy(ABC):
-    @abstractmethod
-    def print(self, message) -> str:
-        pass
-
-    @abstractmethod
-    def set_visible(self, is_visible):
-        pass
-
-    @abstractmethod
-    def set_tagline(self, tagline):
-        pass
-
-class LoggingStatus(StatusStrategy):
-    def print(self, message):
-        print(message)
-
-    def set_visible(self, is_visible):
-        pass
-
-    def set_tagline(self, tagline):
-        pass
-
-class StreamlitStatus(StatusStrategy):
-
-    def __init__(self, st_callback):
-        self.st_callback = st_callback
-        self.st_status = None
-
-    def print(self, message):
-        if(self.st_status is not None):
-            self.st_status.write(message)
-
-    def set_visible(self, is_visible):
-        if (is_visible):
-            self.st_status = self.st_callback.status("Thinking ...")
-        else:
-            self.st_status.update(label="Completed!", state="complete", expanded=False)
-
-    def set_tagline(self, tagline):
-        if(self.st_status is not None):
-            self.st_status.update(label=tagline)
-
-class Assistant(StateMachine):
-    "Assistant state machine"
-
-    feedback = LoggingStatus()
-    status = None
-
-    prompt = State(initial=True)
-    running = State()
-    lookup = State()
-    answered = State(final=True)
-
-    kickoff = prompt.to(running)
-    request_docs = running.to(lookup)
-    docs_supplied = lookup.to(running)
-    resolved = running.to(answered)
-
-    def __init__(self, st_callback=None):
-
-        # streamlit callback, if present
-        self.st_callback = st_callback
-
-        if(self.st_callback is not None):
-            self.feedback = StreamlitStatus(self.st_callback)
-
-        # internal states
-        self.prompt_text = None
-        self.thread = None
-        self.run = None
-
-        self.openai_client = create_openai_client()
-        self.lookups_total = 0
-
-        super().__init__()
-
-    def display_message(self, message):
-        if self.st_callback is not None:
-            st.session_state.messages.append({"role": "assistant", "content": message})
-            with st.chat_message("assistant"):
-                st.markdown(message, unsafe_allow_html=True)
-
-    def display_status(self, message):
-        if(self.status is None):
-            print(message)
-        else:
-            self.status.write('Lookup additional information ...')
-
-    def on_exit_prompt(self, text):
-        self.lookups_total = 0
-        self.prompt_text = text
-
-        # clear screen
-        if(self.st_callback is not None):
-            self.st_callback.empty()
-
-        # display status widget
-        self.feedback.set_visible(True)
-
-        # start a new thread and delete old ones if exist
-        if(self.thread is not None):
-            self.openai_client.beta.threads.delete(self.thread.id)
-
-        self.thread = self.openai_client.beta.threads.create()
-        self.feedback.print("New Thread: " + str(self.thread.id))
-
-        # Add initial message
-        #improved_question = rewrite_question(openai_client = self.openai_client, text=text)
-        #print("Improved question: \n", improved_question)
-
-        message = self.openai_client.beta.threads.messages.create(
-            thread_id=self.thread.id,
-            role="user",
-            content=text,
-        )
-
-        # create a run
-        self.run = self.openai_client.beta.threads.runs.create(
-            thread_id=self.thread.id,
-            assistant_id=ASSISTANT_ID,
-        )
-
-    def on_enter_lookup(self):
-
-        self.feedback.set_tagline("Working ...")
-        self.feedback.print("Lookup additional information ...")
-
-        self.lookups_total = self.lookups_total +1
-
-        # take call arguments and invoke lookup
-        args = get_call_arguments(self.run)
-        outputs=[]
-
-        for a in args:
-            entity_args = a["call_arguments"]["entities"]
-            self.feedback.print("Keywords: " + ' | '.join(entity_args) )
-
-            keywords = ' '.join(entity_args)
-
-            # we may end up with no keywrods at all
-            if(len(keywords)==0 or keywords.isspace()):
-                outputs.append(
-                    {
-                        "tool_call_id": a["call_id"],
-                        "output": "'No additional information found.'"
-                    }
-                )
-                continue
-
-
-            #doc = fetch_pdf_pages(entities=keywords, feedback=self.feedback)
-            docs = fetch_and_rerank(
-                entities=keywords,
-                collections=["rhaetor.github.io_2", "rhaetor.github.io_components_2"],
-                feedback=self.feedback
-            )
-
-            response_content = []
-            response_content = [str(d.page_content) for d in docs]
-
-            outputs.append(
-                {
-                    "tool_call_id": a["call_id"],
-                    "output": "'"+(' '.join(response_content))+"'"
-                }
-            )
-
-        # submit lookup results (aka tool outputs)
-        self.run = self.openai_client.beta.threads.runs.submit_tool_outputs(
-            thread_id=self.thread.id,
-            run_id=self.run.id,
-            tool_outputs=outputs
-        )
-
-        self.docs_supplied()
-        self.feedback.print("Processing new information ...")
-
-    # starting a thinking loop
-    def on_enter_running(self):
-
-        self.feedback.set_tagline("Thinking ...")
-
-        self.run = wait_on_run(self.openai_client, self.run, self.thread)
-
-        if(self.run.status == "requires_action"):
-            self.request_docs()
-        elif(self.run.status == "completed"):
-            self.resolved()
-        else:
-            self.feedback.print("Illegal state: " + self.run.status)
-            print(self.run.last_error)
-            if (self.st_callback is not None):
-                self.st_callback.error(self.run.last_error)
-
-    # the assistant has resolved the question
-    def on_enter_answered(self):
-
-        # thread complete, show answer
-        assistant_response = get_response(self.openai_client, self.thread)
-        for m in assistant_response:
-            if(m.role == "assistant"):
-                self.display_message(m.content[0].text.value)
-
-        pretty_print(assistant_response)
-
-        # delete the thread
-        self.openai_client.beta.threads.delete(self.thread.id)
-        self.feedback.print("Deleted Thread: " + str(self.thread.id))
-        self.feedback.set_visible(False)
-        self.thread = None
-
-# --
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Camel Support Assistant')
-    parser.add_argument('-f', '--filename', help='The input file that will be taken as a prompt', required=False)
-    args = parser.parse_args()
-
-    if(args.filename == None):
-        prompt = input("Prompt: ")
-    else:
-        with open(args.filename) as f:
-            prompt = f.read()
-            prompt = prompt.replace('\n', ' ').replace('\r', '')
-
-
-    sm = Assistant()
-    sm.kickoff(prompt)
-
diff --git a/conf/constants.py b/conf/constants.py
index a226dd6..da51804 100644
--- a/conf/constants.py
+++ b/conf/constants.py
@@ -6,13 +6,6 @@
 TEXT_DIR = "./data/text/"
 PROCESSED_DIR = "./data/processed/"
 
-ASSISTANT_ID = None
-try:
-    ASSISTANT_ID = os.environ['ASSISTANT_ID']
-except KeyError:
-    print('ASSISTANT_ID is missing!')
-    sys.exit()
-
 QDRANT_KEY = None
 try:
     QDRANT_KEY = os.environ['QDRANT_KEY']
diff --git a/core/slack.py b/core/slack.py
index 962e863..fd23aa1 100644
--- a/core/slack.py
+++ b/core/slack.py
@@ -241,8 +241,6 @@ def on_enter_retired(self):
 
 class SlackAsyncHandler(AsyncCallbackHandler):
 
-    websocketaction: str = "appendtext"
-
     def __init__( self, feedback):
         self.feedback = feedback
 
diff --git a/requirements.txt b/requirements.txt
index 56802ba..c021167 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,4 +32,5 @@ aiohttp
 httpx
 apscheduler
 lancedb
-python-dotenv
\ No newline at end of file
+python-dotenv
+prompt_toolkit
\ No newline at end of file
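
For reference, the interactive loop added in `agent-cli.py` follows the standard `prompt_toolkit` REPL pattern. Below is a minimal standalone sketch of that loop, with the agent call replaced by a hypothetical `answer()` stub:

```
from prompt_toolkit import PromptSession

def answer(text):
    # hypothetical stand-in for the real agent_executor call
    return "you said: " + text

def repl():
    session = PromptSession()
    print("How can I help you?")
    while True:
        try:
            prompt_text = session.prompt("> ")
        except KeyboardInterrupt:
            continue  # Ctrl-C discards the current line and re-prompts
        except EOFError:
            break     # Ctrl-D ends the session
        print(answer(prompt_text))
    print("Goodbye!")

if __name__ == "__main__":
    repl()
```

Ctrl-C discards the current input and re-prompts, while Ctrl-D ends the session, mirroring the behavior of the full CLI.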