Skip to content

Commit

Permalink
fix: update files for pydantic_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed Sep 21, 2024
1 parent 45b27b4 commit d8a210a
Show file tree
Hide file tree
Showing 18 changed files with 84 additions and 41 deletions.
6 changes: 3 additions & 3 deletions src/codeinterpreterapi/brain/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from codeboxapi import CodeBox # type: ignore
from gui_agent_loop_core.schema.agent.schema import AgentDefinition
from langchain.base_language import BaseLanguageModel
from langchain_core.language_models import BaseLanguageModel
from langchain.callbacks.base import Callbacks
from langchain.tools import BaseTool
from langchain_core.tools import BaseTool
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import tool

Expand Down
8 changes: 7 additions & 1 deletion src/codeinterpreterapi/callbacks/markdown/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from uuid import UUID

from langchain.callbacks import FileCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult


from langchain_core.agents import AgentFinish

from langchain_core.agents import AgentAction

from langchain_core.outputs import LLMResult


class MarkdownFileCallbackHandler(FileCallbackHandler):
Expand Down
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/callbacks/stdout/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks import StdOutCallbackHandler
from langchain_core.callbacks import StdOutCallbackHandler
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage

Expand Down
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional

from langchain_core.messages import SystemMessage
from langchain_core.pydantic_v1 import BaseSettings, SecretStr

from pydantic import SecretStr
from pydantic_settings import BaseSettings
from codeinterpreterapi.prompts import code_interpreter_system_message, code_interpreter_system_message_ja


Expand Down
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/graphs/agent_wrapper_tool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Optional, Type

from langchain.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import BaseTool

Expand Down
6 changes: 4 additions & 2 deletions src/codeinterpreterapi/graphs/tool_node/tool_node.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Sequence

from langchain.agents import AgentExecutor
from langchain.chat_models.base import BaseChatModel
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.tools import BaseTool, tool
from langchain_core.tools import tool
from pydantic import BaseTool

from langgraph.graph import MessageGraph, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

Expand Down
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from google.generativeai.types.content_types import FunctionDeclarationType # type: ignore[import]
from google.generativeai.types.content_types import ToolDict # type: ignore[import]
from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
from langchain.chat_models.base import BaseChatModel
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_google_genai import ChatGoogleGenerativeAI # type: ignore
Expand Down
6 changes: 5 additions & 1 deletion src/codeinterpreterapi/planners/planners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from langchain.agents import AgentExecutor
from langchain.agents.agent import RunnableAgent
from langchain.schema import AIMessage, Generation


from langchain_core.messages import AIMessage

from langchain_core.outputs import Generation
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field
Expand Down
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from codeboxapi.schema import CodeBoxStatus
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool

ToolsRenderer = Callable[[List[BaseTool]], str]
Expand Down
4 changes: 3 additions & 1 deletion src/codeinterpreterapi/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from codeboxapi import CodeBox # type: ignore
from codeboxapi.schema import CodeBoxStatus # type: ignore
from gui_agent_loop_core.schema.message.schema import BaseMessageContent
from langchain.callbacks.base import BaseCallbackHandler, Callbacks
from langchain.callbacks.base import Callbacks

from langchain_core.callbacks import BaseCallbackHandler
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain_community.chat_message_histories.redis import RedisChatMessageHistory
Expand Down
21 changes: 18 additions & 3 deletions src/codeinterpreterapi/supervisors/supervisors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from typing import Any, Dict

from langchain.agents import AgentExecutor
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
from langchain_core.runnables import Runnable
from langchain_core.runnables.utils import Input
from langchain_core.prompts import PromptTemplate

from codeinterpreterapi.agents.agents import CodeInterpreterAgent
from codeinterpreterapi.brain.params import CodeInterpreterParams
Expand All @@ -18,11 +19,13 @@
from codeinterpreterapi.test_prompts.test_prompt import TestPrompt
from codeinterpreterapi.utils.multi_converter import MultiConverter

from codeinterpreterapi.tools.tools import CodeInterpreterTools


class CodeInterpreterSupervisor:
def __init__(self, planner: Runnable, ci_params: CodeInterpreterParams):
self.planner = ci_params.planner_agent
self.ci_params = ci_params
self.ci_params: CodeInterpreterParams = ci_params
self.supervisor_chain = None
self.supervisor_chain_no_agent = None
self.initialize()
Expand Down Expand Up @@ -82,8 +85,20 @@ def get_executor(self) -> AgentExecutor:
# TODO: impl
return self.supervisor_chain

def zoltraak_pre_process(self, input: Input) -> str:
    """Translate the raw user input into a generic language representation
    via the zoltraak tool chain, before it is handed to the planner.

    Args:
        input: the raw planner input (opaque ``Input`` from langchain_core).

    Returns:
        The chain's output. NOTE(review): ``chain.invoke`` on an LLM runnable
        typically yields a message object rather than a plain ``str`` — the
        ``-> str`` annotation may be inaccurate; confirm against callers.
    """
    # Prompt is intentionally in Japanese (runtime string — do not translate).
    prompt_template = PromptTemplate(
        template="zoltraakによる前処理でinputを一般的な汎用言語表現に翻訳してください。: {input}"
    )
    # Expose the zoltraak tool set to the LLM so it can call them during
    # pre-processing.
    tools = CodeInterpreterTools.get_zoltraak_tools(self.ci_params)
    llm_with_tools = self.ci_params.llm.bind_tools(tools)
    chain = prompt_template | llm_with_tools
    pre_processed_input = chain.invoke(input)
    # Debug trace kept as-is; consider switching to logging.
    print("zoltraak_pre_process pre_processed_input=", pre_processed_input)
    return pre_processed_input

def invoke(self, input: Input) -> CodeInterpreterIntermediateResult:
planner_result = self.planner.invoke(input, config=self.ci_params.runnable_config)
pre_processed_input = self.zoltraak_pre_process(input)
planner_result = self.planner.invoke(pre_processed_input, config=self.ci_params.runnable_config)
print("supervisor.invoke type(planner_result)=", type(planner_result))
if isinstance(planner_result, CodeInterpreterPlanList):
plan_list: CodeInterpreterPlanList = planner_result
Expand Down
25 changes: 8 additions & 17 deletions src/codeinterpreterapi/thoughts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from textwrap import indent
from typing import Any, Dict, List, Optional, Type

from langchain.base_language import BaseLanguageModel
from langchain_core.language_models import BaseLanguageModel
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
from langchain_core.runnables import Runnable
from langchain_experimental.pydantic_v1 import Extra
from langchain_experimental.pydantic_v2 import Extra
from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.controller import ToTController
Expand All @@ -21,13 +21,10 @@ class MyToTChain(ToTChain):
"""

llm: Runnable
"""
Language model to use. It must be set to produce different variations for
the same prompt.
"""
"""Language model to use. It must be set to produce different variations for the same prompt."""
checker: ToTChecker
"""ToT Checker to use."""
output_key: str = "response" #: :meta private:
output_key: str = "response"
k: int = 10
"""The maximum number of conversation rounds"""
c: int = 3
Expand All @@ -36,7 +33,7 @@ class MyToTChain(ToTChain):
tot_controller: ToTController = ToTController()
tot_strategy_class: Type[BaseThoughtGenerationStrategy] = ProposePromptStrategy
verbose_llm: bool = False
thought_generator: BaseThoughtGenerationStrategy = None
thought_generator: Optional[BaseThoughtGenerationStrategy] = None

class Config:
"""Configuration for this pydantic object."""
Expand All @@ -50,7 +47,7 @@ def initialize_thought_generator(self):
print("initialize_thought_generator prompt.input_variables=", input_variables)

@classmethod
def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:
def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> MyToTChain:
"""
Create a ToTChain from a language model.
Expand All @@ -65,18 +62,12 @@ def __init__(self, **kwargs: Any):

@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
"""Will be whatever keys the prompt expects."""
return ["problem_description"]

@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
"""Will always return text key."""
return [self.output_key]

def log_thought(
Expand Down
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/thoughts/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional, Tuple

import spacy
from langchain.prompts import PromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.runnables import Runnable
from langchain_experimental.tot.checker import ToTChecker
Expand Down
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/thoughts/thought_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# https://github.com/langchain-ai/langchain/blob/master/libs/experimental/langchain_experimental/tot/thought_generation.py
from typing import Any, Dict, List, Tuple

from langchain.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.prompts import BasePromptTemplate
from pydantic import Field
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.thought import ThoughtValidity
from langchain_experimental.tot.thought_generation import ProposePromptStrategy, SampleCoTStrategy
Expand Down
15 changes: 14 additions & 1 deletion src/codeinterpreterapi/tools/tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from langchain_community.tools.shell.tool import BaseTool, ShellTool
from langchain_community.tools.shell.tool import ShellTool

from langchain_core.tools import BaseTool
from langchain_community.tools.tavily_search import TavilySearchResults

from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.tools.bash import BashTools
from codeinterpreterapi.tools.code_checker import CodeChecker
from codeinterpreterapi.tools.python import PythonTools
from codeinterpreterapi.tools.zoltraak import ZoltraakTools
from typing import List


Expand All @@ -21,6 +24,7 @@ def get_all_tools(self) -> list[BaseTool]:
self._additional_tools.extend(CodeChecker.get_tools_code_checker(self._ci_params))
self.add_tools_shell()
self.add_tools_web_search()
self.add_tools_zoltraak()
return self._additional_tools

def add_tools_shell(self) -> None:
Expand All @@ -42,6 +46,15 @@ def add_tools_web_search(self) -> None:
tools = [TavilySearchResults(max_results=1)]
self._additional_tools += tools

def add_tools_zoltraak(self) -> None:
    """Append the zoltraak tool set to this instance's additional tools."""
    zoltraak_tools = ZoltraakTools.get_tools_zoltraak(self._ci_params)
    self._additional_tools.extend(zoltraak_tools)

@staticmethod
def get_zoltraak_tools(ci_params: CodeInterpreterParams) -> List[BaseTool]:
    """Return the zoltraak tool set built from *ci_params*.

    Fix: the original annotated the return type as ``None`` even though a
    list of tools is returned; the annotation now matches the behavior
    (``List`` and ``BaseTool`` are already imported at the top of this file).

    Args:
        ci_params: shared code-interpreter parameters used to build the tools.

    Returns:
        The list of zoltraak tools.
    """
    return ZoltraakTools.get_tools_zoltraak(ci_params)

@staticmethod
def get_agent_tools(agent_tools: str, all_tools: List[BaseTool]) -> None:
selected_tools = []
Expand Down
12 changes: 11 additions & 1 deletion src/codeinterpreterapi/tools/zoltraak.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,14 @@ async def arun_design(self, prompt: str, name: str) -> str:
return self._common_run(prompt, name, ZoltraakCompilerEnum.DESIGN.value)

def _common_run(self, prompt: str, name: str, compiler: str):
# mdファイルを生成して内容をreturnする
inout_md_path = f"requirements/{name}.md"

try:
# シェルインジェクションを防ぐためにshlexを使用
args = []
args.append('/home/jinno/.pyenv/shims/zoltraak')
args.append(f"\"requirements/{name}.md\"")
args.append(f"\"{inout_md_path}\"")
args.append('-p')
args.append(f"\"{prompt}\"")
args.append('-c')
Expand All @@ -77,7 +80,14 @@ def _common_run(self, prompt: str, name: str, compiler: str):
)
print("_common_run output_content=", type(output_content))
print("_common_run output_content=", output_content)

if os.path.isfile(inout_md_path):
with open(inout_md_path, "r") as file:
output_content = file.read()
print("_common_run output_content(zoltraak)=", output_content)
return output_content
self.command_log.append((args, output_content))
print("WARN: no inout_m inout_md_path=", inout_md_path)
return output_content

except subprocess.CalledProcessError as e:
Expand Down
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/utils/runnable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Union

from langchain.prompts.base import BasePromptTemplate
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableSequence


Expand Down
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/utils/runnable_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from langchain_community.chat_message_histories import FileChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory

Expand Down

0 comments on commit d8a210a

Please sign in to comment.