From e599adde922cd75ef745d18697a57c583b7ff41c Mon Sep 17 00:00:00 2001 From: jinno Date: Fri, 20 Sep 2024 09:47:05 +0900 Subject: [PATCH] fix: add zoltraak_design tool --- src/codeinterpreterapi/schema.py | 9 ++++- src/codeinterpreterapi/tools/zoltraak.py | 44 ++++++++++++++++++------ 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/src/codeinterpreterapi/schema.py b/src/codeinterpreterapi/schema.py index 64c8a07..cf8ed5c 100644 --- a/src/codeinterpreterapi/schema.py +++ b/src/codeinterpreterapi/schema.py @@ -108,7 +108,14 @@ class BashCommand(BaseModel): class ZoltraakInput(BaseModel): - request: str + prompt: str = Field( + default="このシステムを改善してください。", + description="やりたいこと。曖昧な目標でも動作するし、具体的に指定すればピンポイントで編集や改善もできる。", + ) + name: str = Field( + default="codeinterpreter", + description="処理対象の名前。対象がシステムの場合はディレクトリ名に使われる。対象がpythonファイルの場合はpythonファイル名に使われる。", + ) class UserRequest(HumanMessage): diff --git a/src/codeinterpreterapi/tools/zoltraak.py b/src/codeinterpreterapi/tools/zoltraak.py index 1650b36..45f4028 100644 --- a/src/codeinterpreterapi/tools/zoltraak.py +++ b/src/codeinterpreterapi/tools/zoltraak.py @@ -7,11 +7,17 @@ from codeinterpreterapi.brain.params import CodeInterpreterParams from codeinterpreterapi.llm.llm import prepare_test_llm from codeinterpreterapi.schema import ZoltraakInput +from enum import Enum CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) INVOKE_TASKS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../invoke_tasks")) +class ZoltraakCompilerEnum(Enum): + PYTHON_CODE = "dev_obj" + DESIGN = "general_def" + + class ZoltraakTools: def __init__(self, ci_params: CodeInterpreterParams): self.ci_params = ci_params @@ -23,33 +29,49 @@ def get_tools_zoltraak(cls, ci_params: CodeInterpreterParams) -> None: tools_instance = cls(ci_params=ci_params) tools = [ StructuredTool( - name="zoltraak", - description="あいまいなリクエストを元にプログラムの設計をしてからプロトタイプのソースコード群を作成します。\n" + name="zoltraak_python", + description="あいまいなリクエストを元にプロトタイプのソースコード群を作成します。\n" "このツールは基本的な要件を満たす最低限度の品質を持ったコードを生成できます。\n" "プログラミングを開始するときは、このツールを最初に実行してください。", func=tools_instance.run, coroutine=tools_instance.arun, args_schema=ZoltraakInput, ), + StructuredTool( + name="zoltraak_design", + description="あいまいなリクエストから設計文書を作成します。\n" + "このツールは基本的な要件を満たすための具体的な設計を定義できます。\n" + "設計作業を進めるときは、このツール実行してください。", + func=tools_instance.run_design, + coroutine=tools_instance.arun_design, + args_schema=ZoltraakInput, + ), ] return tools - def run(self, request: str) -> str: - return self._common_run(request) + def run(self, prompt: str, name: str) -> str: + return self._common_run(prompt, name, ZoltraakCompilerEnum.PYTHON_CODE.value) + + async def arun(self, prompt: str, name: str) -> str: + return self._common_run(prompt, name, ZoltraakCompilerEnum.PYTHON_CODE.value) + + def run_design(self, prompt: str, name: str) -> str: + return self._common_run(prompt, name, ZoltraakCompilerEnum.DESIGN.value) - async def arun(self, request: str) -> str: - return self._common_run(request) + async def arun_design(self, prompt: str, name: str) -> str: + return self._common_run(prompt, name, ZoltraakCompilerEnum.DESIGN.value) - def _common_run(self, request: str): + def _common_run(self, prompt: str, name: str, compiler: str): try: # シェルインジェクションを防ぐためにshlexを使用 args = [] args.append('/home/jinno/.pyenv/shims/zoltraak') - args.append(f"\"{request}\"") + args.append(f"\"requirements/{name}.md\"") + args.append('-p') + args.append(f"\"{prompt}\"") args.append('-c') - args.append('dev_func') - # args.append(f"eval \"$(pyenv init -)\";zoltraak \"{request}\" -c dev_func") + args.append(f"\"{compiler}\"") output_content = subprocess.check_output( args, stderr=subprocess.STDOUT, universal_newlines=True, cwd=settings.WORK_DIR ) @@ -71,7 +93,7 @@ def test(): ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) tools_instance = ZoltraakTools(ci_params=ci_params) test_request = "シンプルなpythonのサンプルプログラムを書いてください。テーマはなんでもいいです。" - result = tools_instance.run(test_request) + result = tools_instance.run(test_request, "sample") print("result=", result) # assert "test output" in result