Skip to content

Commit

Permalink
fix: update planners.py for PydanticOutputParser
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed Jul 29, 2024
1 parent 59052ea commit 4c7c543
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
55 changes: 45 additions & 10 deletions src/codeinterpreterapi/planners/planners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from langchain.agents import AgentExecutor
from langchain.agents.agent import RunnableAgent
from langchain.output_parsers import OutputFixingParser
from langchain.schema import AIMessage, Generation
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field

Expand Down Expand Up @@ -32,6 +35,43 @@ class CodeInterpreterPlanList(BaseModel):
)


class Metadata(BaseModel):
id: str = Field(description="The ID of the content")
name: str = Field(description="The name of the content")
type: str = Field(description="The type of the content")


class ResponseMetadata(BaseModel):
id: str = Field(description="The ID of the response")
model: str = Field(description="The model used for the response")
stop_reason: str = Field(description="The reason for stopping")
stop_sequence: str | None = Field(description="The stop sequence, if any")
usage: dict = Field(description="The token usage information")


class CodeInterpreterPlanOutput(BaseModel):
agent_task_list: List[CodeInterpreterPlan] = Field(description="List of agent tasks")
metadata: Metadata = Field(description="Metadata of the content")
response_metadata: ResponseMetadata = Field(description="Metadata of the response")


class CustomPydanticOutputParser(PydanticOutputParser):
def parse(self, text) -> CodeInterpreterPlanList:
print("parse text=", text)
input_data = self.preprocess_input(text)
return super().parse(input_data)

def preprocess_input(self, input_data) -> str:
if isinstance(input_data, AIMessage):
return input_data.content
elif isinstance(input_data, Generation):
return input_data.text
elif isinstance(input_data, str):
return input_data
else:
raise ValueError(f"Unexpected input type: {type(input_data)}")


class CodeInterpreterPlanner:
@staticmethod
def choose_planner(ci_params: CodeInterpreterParams) -> Union[Runnable, AgentExecutor]:
Expand All @@ -52,19 +92,14 @@ def choose_planner(ci_params: CodeInterpreterParams) -> Union[Runnable, AgentExe
# structured_llm
structured_llm = ci_params.llm.bind_tools(tools=[CodeInterpreterPlanList])

# parser
parser = CustomPydanticOutputParser(pydantic_object=CodeInterpreterPlanList)
new_parser = OutputFixingParser.from_llm(parser=parser, llm=ci_params.llm)
# runnable
runnable = (
create_complement_input(prompt) | prompt | structured_llm
# | StrOutputParser()
# | PlannerSingleOutputParser()
)
# runnable = assign_runnable_history(runnable, ci_params.runnable_config)
runnable = create_complement_input(prompt) | prompt | structured_llm | new_parser

# agent
# planner_agent = create_react_agent(ci_params.llm_fast, ci_params.tools, prompt)
print("choose_planner prompt.input_variables=", prompt.input_variables)
remapped_inputs = create_complement_input(prompt).invoke({})
agent = RunnableAgent(runnable=runnable, input_keys=list(remapped_inputs.keys()))
agent = RunnableAgent(runnable=runnable, input_keys=list(prompt.input_variables))

# return executor or runnable
return_as_executor = False
Expand Down
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/supervisors/supervisors.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def route(next_action: str) -> str:
llm_with_structured_output = ci_params.llm.with_structured_output(RouteSchema)
# supervisor_agent = prompt | ci_params.llm | output_parser
# supervisor_agent_for_executor = prompt | ci_params.llm
supervisor_agent_structured_output = prompt | llm_with_structured_output
supervisor_agent_structured_output = planner | prompt | llm_with_structured_output

ci_params.supervisor_agent = supervisor_agent_structured_output

Expand Down

0 comments on commit 4c7c543

Please sign in to comment.