fix: generate_response_stream accept BaseMessageContent
nobu007 committed Aug 21, 2024
1 parent 3d72a8e commit 3a94afa
Showing 3 changed files with 87 additions and 45 deletions.
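
For context: the commit adds a BaseMessageContent alias in src/codeinterpreterapi/session.py and widens the session entry points to accept it. A minimal sketch of the content shapes the alias admits; the dict part layout follows the usual LangChain message-content convention and is an assumption, not taken from this diff:

from typing import Dict, List, Union

# Same alias as the one added to session.py in this commit.
BaseMessageContent = Union[str, List[Union[str, Dict]]]

# Illustrative values that satisfy the alias:
as_text: BaseMessageContent = "Summarize the uploaded CSV."
as_parts: BaseMessageContent = [
    "Summarize the uploaded CSV.",
    {"type": "text", "text": "Focus on the numeric columns."},  # assumed dict shape
]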
src/codeinterpreterapi/brain/brain.py: 15 changes (10 additions & 5 deletions)
@@ -97,11 +97,16 @@ def prepare_input(self, input: Input):
         del input['intermediate_steps']

         # set agent_results
-        input['agent_executor_result'] = self.agent_executor_result
-        input['plan_list'] = self.plan_list
-        input['supervisor_result'] = self.supervisor_result
-        input['thought_result'] = self.thought_result
-        input['crew_result'] = self.crew_result
+        if isinstance(input, list):
+            last_input = input[-1]
+        else:
+            last_input = input
+
+        last_input['agent_executor_result'] = self.agent_executor_result
+        last_input['plan_list'] = self.plan_list
+        last_input['supervisor_result'] = self.supervisor_result
+        last_input['thought_result'] = self.thought_result
+        last_input['crew_result'] = self.crew_result
         return input

     def run(self, input: Input, runnable_config: Optional[RunnableConfig] = None) -> Output:
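
The prepare_input() change above normalizes the incoming payload before attaching the agent results: when input is a list of message dicts, the results go on the last element; otherwise they go on the dict itself. A standalone sketch of that pattern (function and variable names are illustrative, not part of the project API):

from typing import Any, Dict, List, Union

def attach_results(input: Union[Dict[str, Any], List[Dict[str, Any]]], results: Dict[str, Any]):
    # Pick the last element when given a list of messages, else use the dict directly.
    last_input = input[-1] if isinstance(input, list) else input
    last_input.update(results)
    return input  # the original container is returned, mutated in place

# Only the last message gains the result keys.
messages = [{"input": "plan"}, {"input": "execute"}]
attach_results(messages, {"plan_list": ["step 1", "step 2"]})
assert "plan_list" in messages[-1] and "plan_list" not in messages[0]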
src/codeinterpreterapi/crew/crew_agent.py: 23 changes (17 additions & 6 deletions)
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Dict, List, Union

 from crewai import Agent, Crew, Task

@@ -69,15 +69,26 @@ def create_task(self, final_goal: str, plan: CodeInterpreterPlan) -> Task:
         print("WARN: no task found plan.agent_name=", plan.agent_name)
         return None

-    def run(self, inputs: Dict, plan_list: CodeInterpreterPlanList):
+    def run(self, inputs: Union[Dict, List[Dict]], plan_list: CodeInterpreterPlanList):
         # update task description
         if plan_list is None:
             return {}
-        tasks = self.create_tasks(final_goal=inputs["input"], plan_list=plan_list)
-        crew_inputs = {"input": inputs.get("input", "")}
+
+        if isinstance(inputs, list):
+            last_input = inputs[-1]
+        else:
+            last_input = inputs
+        if "input" in inputs:
+            final_goal = last_input["input"]
+        elif "content" in inputs:
+            final_goal = last_input["content"]
+        else:
+            final_goal = "ユーザの指示に従って最終的な回答をしてください"
+
+        tasks = self.create_tasks(final_goal=final_goal, plan_list=plan_list)
         my_crew = Crew(agents=self.agents, tasks=tasks)
-        print("CodeInterpreterCrew.kickoff() crew_inputs=", crew_inputs)
-        result = my_crew.kickoff(inputs=crew_inputs)
+        print("CodeInterpreterCrew.kickoff() crew_inputs=", last_input)
+        result = my_crew.kickoff(inputs=last_input)
         return result

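
The reworked run() above resolves the crew's final goal from the last input element, preferring an "input" key, then a "content" key, then a fixed default instruction (the Japanese literal roughly means "follow the user's instructions and give a final answer"). A minimal sketch of that fallback order; note that in this sketch the key checks are done on the normalized last_input and the default is rendered in English, purely for illustration:

from typing import Any, Dict, List, Union

DEFAULT_GOAL = "Follow the user's instructions and produce a final answer."  # English stand-in for the committed Japanese default

def resolve_final_goal(inputs: Union[Dict[str, Any], List[Dict[str, Any]]]) -> str:
    # Use the last element of a message list, or the dict itself.
    last_input = inputs[-1] if isinstance(inputs, list) else inputs
    if "input" in last_input:
        return last_input["input"]
    if "content" in last_input:
        return last_input["content"]
    return DEFAULT_GOAL

print(resolve_final_goal([{"content": "hello"}, {"content": "plot the data"}]))  # -> plot the data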
src/codeinterpreterapi/session.py: 94 changes (60 additions & 34 deletions)
@@ -19,14 +19,15 @@

 from codeinterpreterapi.brain.brain import CodeInterpreterBrain
 from codeinterpreterapi.brain.params import CodeInterpreterParams
-from codeinterpreterapi.callbacks.markdown.callbacks import MarkdownFileCallbackHandler
 from codeinterpreterapi.chains import aremove_download_link
 from codeinterpreterapi.chat_history import CodeBoxChatMessageHistory
 from codeinterpreterapi.config import settings
 from codeinterpreterapi.llm.llm import CodeInterpreterLlm
 from codeinterpreterapi.schema import CodeInterpreterResponse, File, SessionStatus, UserRequest
 from codeinterpreterapi.utils.multi_converter import MultiConverter

+BaseMessageContent = Union[str, List[Union[str, Dict]]]
+

 def _handle_deprecated_kwargs(kwargs: dict) -> None:
     settings.MODEL = kwargs.get("model", settings.MODEL)
@@ -105,7 +106,8 @@ def __init__(
         configurable = {"session_id": init_session_id}  # TODO: set session_id
         runnable_config = RunnableConfig(
             configurable=configurable,
-            callbacks=[AgentCallbackHandler(self._output_handler), MarkdownFileCallbackHandler("langchain_log.md")],
+            callbacks=[],
+            # callbacks=[AgentCallbackHandler(self._output_handler), MarkdownFileCallbackHandler("langchain_log.md")],
         )

         # ci_params = {}
@@ -199,34 +201,61 @@ def _history_backend(self) -> BaseChatMessageHistory:
             )
         )

-    def _input_handler(self, request: UserRequest) -> None:
+    def _input_message_prepare(self, request: UserRequest) -> None:
+        # set return input_message
+        if isinstance(request.content, str):
+            input_message = {"input": request.content, "agent_scratchpad": ""}
+        else:
+            input_message = request.content
+        return input_message
+
+    def _input_handler_common(self, request: UserRequest, add_content_str: str) -> None:
         """Callback function to handle user input."""
         # TODO: variables as context to the agent
         # TODO: current files as context to the agent
         if not request.files:
-            return
+            return self._input_message_prepare(request)
         if not request.content:
-            request.content = "I uploaded, just text me back and confirm that you got the file(s)."
-        assert isinstance(request.content, str), "TODO: implement image support"
-        request.content += "\n**The user uploaded the following files: **\n"
+            add_content_str = "I uploaded, just text me back and confirm that you got the file(s).\n" + add_content_str
+        assert isinstance(request.content, BaseMessageContent)  # "TODO: implement image support"
+
+        # set request.content
+        if isinstance(request.content, str):
+            request.content += add_content_str
+        elif isinstance(request.content, list):
+            last_content = request.content[-1]
+            if isinstance(last_content, str):
+                last_content += add_content_str
+            else:
+                pass
+                # TODO impl it.
+                # last_content["input"] += add_content_str
+        else:
+            # Dict
+            pass
+            # TODO impl it.
+            # request.content["input"] += add_content_str
+        return self._input_message_prepare(request)
+
+    def _input_handler(self, request: UserRequest) -> None:
+        """Callback function to handle user input."""
+        add_content_str = "\n**The user uploaded the following files: **\n"
         for file in request.files:
             self.input_files.append(file)
-            request.content += f"[Attachment: {file.name}]\n"
+            add_content_str += f"[Attachment: {file.name}]\n"
             self.ci_params.codebox.upload(file.name, file.content)
-        request.content += "**File(s) are now available in the cwd. **\n"
+        add_content_str += "**File(s) are now available in the cwd. **\n"
+        return self._input_handler_common(request, add_content_str)

     async def _ainput_handler(self, request: UserRequest) -> None:
-        # TODO: variables as context to the agent
-        # TODO: current files as context to the agent
-        if not request.files:
-            return
-        if not request.content:
-            request.content = "I uploaded, just text me back and confirm that you got the file(s)."
-        assert isinstance(request.content, str), "TODO: implement image support"
-        request.content += "\n**The user uploaded the following files: **\n"
+        """Callback function to handle user input."""
+        add_content_str = "\n**The user uploaded the following files: **\n"
         for file in request.files:
             self.input_files.append(file)
-            request.content += f"[Attachment: {file.name}]\n"
+            add_content_str += f"[Attachment: {file.name}]\n"
             await self.ci_params.codebox.aupload(file.name, file.content)
-        request.content += "**File(s) are now available in the cwd. **\n"
+        add_content_str += "**File(s) are now available in the cwd. **\n"
+        return self._input_handler_common(request, add_content_str)

     def _output_handler_pre(self, response: Any) -> str:
         print("_output_handler_pre response(type)=", type(response))
@@ -304,7 +333,7 @@ async def _aoutput_handler(self, response: str) -> CodeInterpreterResponse:

     def generate_response_sync(
         self,
-        user_msg: str,
+        user_msg: BaseMessageContent,
         files: list[File] = [],
     ) -> CodeInterpreterResponse:
         print("DEPRECATION WARNING: Use generate_response for sync generation.\n")
@@ -315,16 +344,16 @@ def generate_response_sync(

     def generate_response(
         self,
-        user_msg: str,
-        files: list[File] = [],
+        user_msg: BaseMessageContent,
+        files: list[File] = None,
     ) -> CodeInterpreterResponse:
         """Generate a Code Interpreter response based on the user's input."""
+        if files is None:
+            files = []
         user_request = UserRequest(content=user_msg, files=files)
         try:
-            self._input_handler(user_request)
+            input_message = self._input_handler(user_request)
             print("generate_response type(user_request.content)=", type(user_request.content))
-            agent_scratchpad = ""
-            input_message = {"input": user_request.content, "agent_scratchpad": agent_scratchpad}
             # ======= ↓↓↓↓ LLM invoke ↓↓↓↓ #=======
             response = self.brain.invoke(input=input_message)
             # ======= ↑↑↑↑ LLM invoke ↑↑↑↑ #=======
@@ -348,15 +377,15 @@ def generate_response(

     async def agenerate_response(
         self,
-        user_msg: str,
+        user_msg: BaseMessageContent,
         files: list[File] = None,
     ) -> CodeInterpreterResponse:
         """Generate a Code Interpreter response based on the user's input."""
         if files is None:
-            files = None
+            files = []
         user_request = UserRequest(content=user_msg, files=files)
         try:
-            await self._ainput_handler(user_request)
+            input_message = await self._ainput_handler(user_request)
-            agent_scratchpad = ""
-            input_message = {"input": user_request.content, "agent_scratchpad": agent_scratchpad}
             # ======= ↓↓↓↓ LLM invoke ↓↓↓↓ #=======
@@ -380,18 +409,15 @@ async def agenerate_response(

     def generate_response_stream(
         self,
-        user_msg: str,
+        user_msg: BaseMessageContent,
         files: list[File] = None,
     ) -> Iterator[str]:
         """Generate a Code Interpreter response based on the user's input."""
         if files is None:
             files = []
         user_request = UserRequest(content=user_msg, files=files)
         try:
-            self._input_handler(user_request)
-            agent_scratchpad = ""
-            input_message = {"input": user_request.content, "agent_scratchpad": agent_scratchpad}
-            print("generate_response_stream type(user_request.content)=", user_request.content)
+            input_message = self._input_handler(user_request)
             # ======= ↓↓↓↓ LLM invoke ↓↓↓↓ #=======
             response_stream = self.brain.stream(input=input_message)
             # ======= ↑↑↑↑ LLM invoke ↑↑↑↑ #=======
@@ -417,7 +443,7 @@ def generate_response_stream(

     async def agenerate_response_stream(
         self,
-        user_msg: str,
+        user_msg: BaseMessageContent,
         files: list[File] = None,
     ) -> AsyncGenerator[CodeInterpreterResponse, None]:
         """Generate a Code Interpreter response based on the user's input."""
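
With these changes, generate_response_stream (and the other generate_response_* entry points) should accept either a plain string or structured message content. A hypothetical usage sketch; CodeInterpreterSession, File, start()/stop(), and File.from_path() are assumed from the package's public API and are not shown in this diff:

from codeinterpreterapi import CodeInterpreterSession, File  # assumed exports

session = CodeInterpreterSession()
session.start()

# Plain-string content, as before.
for chunk in session.generate_response_stream("Plot y = x**2 for x in 0..10"):
    print(chunk, end="")

# Structured content (a list of message parts) is now accepted as well.
content = [
    "Describe the uploaded file.",
    {"type": "text", "text": "Keep the answer short."},  # assumed dict shape
]
for chunk in session.generate_response_stream(content, files=[File.from_path("data.csv")]):
    print(chunk, end="")

session.stop()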
