diff --git a/src/codeinterpreterapi/planners/planners.py b/src/codeinterpreterapi/planners/planners.py index 7433845a..ec334c5d 100644 --- a/src/codeinterpreterapi/planners/planners.py +++ b/src/codeinterpreterapi/planners/planners.py @@ -2,6 +2,9 @@ from langchain.agents import AgentExecutor from langchain.agents.agent import RunnableAgent +from langchain.output_parsers import OutputFixingParser +from langchain.schema import AIMessage, Generation +from langchain_core.output_parsers import PydanticOutputParser from langchain_core.runnables import Runnable from pydantic import BaseModel, Field @@ -32,6 +35,43 @@ class CodeInterpreterPlanList(BaseModel): ) +class Metadata(BaseModel): + id: str = Field(description="The ID of the content") + name: str = Field(description="The name of the content") + type: str = Field(description="The type of the content") + + +class ResponseMetadata(BaseModel): + id: str = Field(description="The ID of the response") + model: str = Field(description="The model used for the response") + stop_reason: str = Field(description="The reason for stopping") + stop_sequence: str | None = Field(description="The stop sequence, if any") + usage: dict = Field(description="The token usage information") + + +class CodeInterpreterPlanOutput(BaseModel): + agent_task_list: List[CodeInterpreterPlan] = Field(description="List of agent tasks") + metadata: Metadata = Field(description="Metadata of the content") + response_metadata: ResponseMetadata = Field(description="Metadata of the response") + + +class CustomPydanticOutputParser(PydanticOutputParser): + def parse(self, text) -> CodeInterpreterPlanList: + print("parse text=", text) + input_data = self.preprocess_input(text) + return super().parse(input_data) + + def preprocess_input(self, input_data) -> str: + if isinstance(input_data, AIMessage): + return input_data.content + elif isinstance(input_data, Generation): + return input_data.text + elif isinstance(input_data, str): + return input_data + else: + raise ValueError(f"Unexpected input type: {type(input_data)}") + + class CodeInterpreterPlanner: @staticmethod def choose_planner(ci_params: CodeInterpreterParams) -> Union[Runnable, AgentExecutor]: @@ -52,19 +92,14 @@ def choose_planner(ci_params: CodeInterpreterParams) -> Union[Runnable, AgentExe # structured_llm structured_llm = ci_params.llm.bind_tools(tools=[CodeInterpreterPlanList]) + # parser + parser = CustomPydanticOutputParser(pydantic_object=CodeInterpreterPlanList) + new_parser = OutputFixingParser.from_llm(parser=parser, llm=ci_params.llm) # runnable - runnable = ( - create_complement_input(prompt) | prompt | structured_llm - # | StrOutputParser() - # | PlannerSingleOutputParser() - ) - # runnable = assign_runnable_history(runnable, ci_params.runnable_config) + runnable = create_complement_input(prompt) | prompt | structured_llm | new_parser # agent - # planner_agent = create_react_agent(ci_params.llm_fast, ci_params.tools, prompt) - print("choose_planner prompt.input_variables=", prompt.input_variables) - remapped_inputs = create_complement_input(prompt).invoke({}) - agent = RunnableAgent(runnable=runnable, input_keys=list(remapped_inputs.keys())) + agent = RunnableAgent(runnable=runnable, input_keys=list(prompt.input_variables)) # return executor or runnable return_as_executor = False diff --git a/src/codeinterpreterapi/supervisors/supervisors.py b/src/codeinterpreterapi/supervisors/supervisors.py index 56734250..0e80cdd6 100644 --- a/src/codeinterpreterapi/supervisors/supervisors.py +++ b/src/codeinterpreterapi/supervisors/supervisors.py @@ -140,7 +140,7 @@ def route(next_action: str) -> str: llm_with_structured_output = ci_params.llm.with_structured_output(RouteSchema) # supervisor_agent = prompt | ci_params.llm | output_parser # supervisor_agent_for_executor = prompt | ci_params.llm - supervisor_agent_structured_output = prompt | llm_with_structured_output + supervisor_agent_structured_output = planner | prompt | llm_with_structured_output ci_params.supervisor_agent = supervisor_agent_structured_output