From 4f43dc838b2fa4e1311a41499a3f76441d598b3c Mon Sep 17 00:00:00 2001 From: jinno Date: Sun, 26 May 2024 15:49:11 +0900 Subject: [PATCH] fix: update thoughts adding controller.py and thought_generation.py --- src/codeinterpreterapi/thoughts/base.py | 131 ++++++++++++++++++ .../thoughts/{tot.py => checker.py} | 4 +- src/codeinterpreterapi/thoughts/controller.py | 49 +++++++ src/codeinterpreterapi/thoughts/prompts.py | 18 ++- .../thoughts/thought_generation.py | 130 +++++++++++++++++ src/codeinterpreterapi/thoughts/thoughts.py | 64 ++++++++- 6 files changed, 386 insertions(+), 10 deletions(-) create mode 100644 src/codeinterpreterapi/thoughts/base.py rename src/codeinterpreterapi/thoughts/{tot.py => checker.py} (98%) create mode 100644 src/codeinterpreterapi/thoughts/controller.py create mode 100644 src/codeinterpreterapi/thoughts/thought_generation.py diff --git a/src/codeinterpreterapi/thoughts/base.py b/src/codeinterpreterapi/thoughts/base.py new file mode 100644 index 00000000..5385dd59 --- /dev/null +++ b/src/codeinterpreterapi/thoughts/base.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from textwrap import indent +from typing import Any, Dict, List, Optional, Type + +from langchain.base_language import BaseLanguageModel +from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun +from langchain_experimental.pydantic_v1 import Extra +from langchain_experimental.tot.base import ToTChain +from langchain_experimental.tot.checker import ToTChecker +from langchain_experimental.tot.controller import ToTController +from langchain_experimental.tot.memory import ToTDFSMemory +from langchain_experimental.tot.thought import Thought, ThoughtValidity +from langchain_experimental.tot.thought_generation import BaseThoughtGenerationStrategy, ProposePromptStrategy + + +class MyToTChain(ToTChain): + """ + Chain implementing the Tree of Thought (ToT). + """ + + llm: BaseLanguageModel + """ + Language model to use. It must be set to produce different variations for + the same prompt. + """ + checker: ToTChecker + """ToT Checker to use.""" + output_key: str = "response" #: :meta private: + k: int = 10 + """The maximum number of conversation rounds""" + c: int = 3 + """The number of children to explore at each node""" + tot_memory: ToTDFSMemory = ToTDFSMemory() + tot_controller: ToTController = ToTController() + tot_strategy_class: Type[BaseThoughtGenerationStrategy] = ProposePromptStrategy + verbose_llm: bool = False + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @classmethod + def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain: + """ + Create a ToTChain from a language model. + + :param llm: The language model to use. + :param kwargs: Additional arguments to pass to the ToTChain constructor. + """ + return cls(llm=llm, **kwargs) + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + self.tot_controller.c = self.c + + @property + def input_keys(self) -> List[str]: + """Will be whatever keys the prompt expects. + + :meta private: + """ + return ["problem_description"] + + @property + def output_keys(self) -> List[str]: + """Will always return text key. + + :meta private: + """ + return [self.output_key] + + def log_thought( + self, + thought: Thought, + level: int, + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> None: + if run_manager: + colors = { + ThoughtValidity.VALID_FINAL: "green", + ThoughtValidity.VALID_INTERMEDIATE: "yellow", + ThoughtValidity.INVALID: "red", + } + text = indent(f"Thought: {thought.text}\n", prefix=" " * level) + run_manager.on_text(text=text, color=colors[thought.validity], verbose=self.verbose) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + if run_manager: + run_manager.on_text(text="Starting the ToT solve procedure.\n") + + problem_description = inputs["problem_description"] + checker_inputs = {"problem_description": problem_description} + thoughts_path: tuple[str, ...] = () + thought_generator = self.tot_strategy_class(llm=self.llm, c=self.c, verbose=self.verbose_llm) + + level = 0 + for _ in range(self.k): + level = self.tot_memory.level + thought_text = thought_generator.next_thought( + problem_description, thoughts_path, callbacks=_run_manager.get_child() + ) + checker_inputs["thoughts"] = thoughts_path + (thought_text,) + thought_validity = self.checker(checker_inputs, callbacks=_run_manager.get_child())["validity"] + thought = Thought(text=thought_text, validity=thought_validity) + if thought.validity == ThoughtValidity.VALID_FINAL: + self.log_thought(thought, level, run_manager) + return {self.output_key: thought.text} + self.tot_memory.store(thought) + self.log_thought(thought, level, run_manager) + thoughts_path = self.tot_controller(self.tot_memory) + + return {self.output_key: "No solution found"} + + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + raise NotImplementedError("Async not implemented yet") + + @property + def _chain_type(self) -> str: + return "tot" diff --git a/src/codeinterpreterapi/thoughts/tot.py b/src/codeinterpreterapi/thoughts/checker.py similarity index 98% rename from src/codeinterpreterapi/thoughts/tot.py rename to src/codeinterpreterapi/thoughts/checker.py index f1777e59..ec3a90b1 100644 --- a/src/codeinterpreterapi/thoughts/tot.py +++ b/src/codeinterpreterapi/thoughts/checker.py @@ -163,13 +163,13 @@ def prepare_test_llm(): model=model, temperature=0, google_api_key=os.environ.get("GEMINI_API_KEY"), - max_output_tokens=1024, + max_output_tokens=1024 * 4, ) return llm def test_create(): - tot_chain = create_tot_chain_from_llm(prepare_test_llm()) + tot_chain = create_tot_chain_from_llm(llm=prepare_test_llm(), is_simple=True) tot_chain.run(problem_description=sudoku_problem_description) diff --git a/src/codeinterpreterapi/thoughts/controller.py b/src/codeinterpreterapi/thoughts/controller.py new file mode 100644 index 00000000..ca5f71bd --- /dev/null +++ b/src/codeinterpreterapi/thoughts/controller.py @@ -0,0 +1,49 @@ +from typing import Tuple + +from langchain_experimental.tot.controller import ToTController +from langchain_experimental.tot.memory import ToTDFSMemory +from langchain_experimental.tot.thought import ThoughtValidity + + +class MyToTController(ToTController): + """ + Tree of Thought (ToT) controller. + + This is a version of a ToT controller, dubbed in the paper as a "Simple + Controller". + + It has one parameter `c` which is the number of children to explore for each + thought. + """ + + def __init__(self, c: int = 3): + """ + Initialize the controller. + + Args: + c: The number of children to explore at each node. + """ + self.c = c + + def __call__(self, memory: ToTDFSMemory) -> Tuple[str, ...]: + next_thought = memory.top() + parent_thought = memory.top_parent() + validity = ThoughtValidity.VALID_INTERMEDIATE if next_thought is None else next_thought.validity + + # 1 if the current partial solution is invalid, backtrack to the parent + # thought. + if validity == ThoughtValidity.INVALID: + memory.pop() + next_thought = memory.top() + if next_thought and len(next_thought.children) >= self.c: + memory.pop() + + # 2 if the current partial solution is valid but C children were + # explored and yet failed to find a final solution, backtrack to the + # parent thought. + elif ( + validity == ThoughtValidity.VALID_INTERMEDIATE and parent_thought and len(parent_thought.children) >= self.c + ): + memory.pop(2) + + return tuple(thought.text for thought in memory.current_path()) diff --git a/src/codeinterpreterapi/thoughts/prompts.py b/src/codeinterpreterapi/thoughts/prompts.py index 1b4a4d43..fa43c39c 100644 --- a/src/codeinterpreterapi/thoughts/prompts.py +++ b/src/codeinterpreterapi/thoughts/prompts.py @@ -114,10 +114,12 @@ def get_propose_prompt_ja() -> PromptTemplate: output_parser=JSONListOutputParser(), template=dedent( """ -あなたは、思考ツリーの設定で思考を生成するインテリジェントなエージェントです。 +あなたは、Three-of-Thought (ToT) で思考を生成するインテリジェントなエージェントです。 +与えられた PROBLEM から思考を生成して出力してください。 出力は、先頭に "```json"、末尾に "```" を含む、JSON形式の文字列リストとしてマークダウンコードスニペットにフォーマットしてください。 +出力例を示します。 ```json [ "", @@ -126,6 +128,7 @@ def get_propose_prompt_ja() -> PromptTemplate: ] ``` +問題と作業指示は以下の通りです。 PROBLEM {{ problem_description }} @@ -144,7 +147,18 @@ def get_propose_prompt_ja() -> PromptTemplate: {%- endif -%} - +ガイドライン: +- 簡単な問題には最終的な回答1回で思考を完結してください。 +- 複雑な問題には一度に {{ n }} 個のステップバイステップで連続した思考をしてください。 +- 思考を生成するために、問題と前回の思考を注意深く分析してください。 +- 思考は、問題解決に向けた明確なステップや洞察を提供するものでなければなりません。 +- 思考がループしている場合は、直ちに気づいて修正してください。あなたは頻繁に思考ループに陥る傾向があります。 +- 各思考は簡潔にまとめ、問題に直接関連するようにしてください。 +- 思考の質を向上させるために、必要に応じて問題をさらに分析し、追加の情報を検討してください。 +- より少ない思考で問題を解決できるように最善を尽くしてください。 +- 生成された思考のリストを、指定されたJSON形式で出力してください。 + +それでは開始してください。 """ ).strip(), ) diff --git a/src/codeinterpreterapi/thoughts/thought_generation.py b/src/codeinterpreterapi/thoughts/thought_generation.py new file mode 100644 index 00000000..edce5888 --- /dev/null +++ b/src/codeinterpreterapi/thoughts/thought_generation.py @@ -0,0 +1,130 @@ +# LCEL version of. +# https://github.com/langchain-ai/langchain/blob/master/libs/experimental/langchain_experimental/tot/thought_generation.py +from typing import Any, Dict, List, Tuple + +from langchain.prompts.base import BasePromptTemplate +from langchain_core.pydantic_v1 import Field +from langchain_experimental.tot.thought_generation import ProposePromptStrategy, SampleCoTStrategy + +from codeinterpreterapi.thoughts.prompts import ( + get_cot_prompt, + get_cot_prompt_ja, + get_propose_prompt, + get_propose_prompt_ja, +) + + +class MySampleCoTStrategy(SampleCoTStrategy): + """ + Sample strategy from a Chain-of-Thought (CoT) prompt. + + This strategy works better when the thought space is rich, such as when each + thought is a paragraph. Independent and identically distributed samples + lead to diversity, which helps to avoid repetition. + """ + + prompt: BasePromptTemplate = Field(default_factory=get_cot_prompt) + + def next_thought( + self, + problem_description: str, + thoughts_path: Tuple[str, ...] = (), + **kwargs: Any, + ) -> str: + response_text = self.predict_and_parse( + problem_description=problem_description, thoughts=thoughts_path, **kwargs + ) + return response_text if isinstance(response_text, str) else "" + + +class MySampleCoTStrategyJa(SampleCoTStrategy): + """ + Sample strategy from a Chain-of-Thought (CoT) prompt. + + This strategy works better when the thought space is rich, such as when each + thought is a paragraph. Independent and identically distributed samples + lead to diversity, which helps to avoid repetition. + """ + + prompt: BasePromptTemplate = Field(default_factory=get_cot_prompt_ja) + + def next_thought( + self, + problem_description: str, + thoughts_path: Tuple[str, ...] = (), + **kwargs: Any, + ) -> str: + response_text = self.predict_and_parse( + problem_description=problem_description, thoughts=thoughts_path, **kwargs + ) + return response_text if isinstance(response_text, str) else "" + + +class MyProposePromptStrategy(ProposePromptStrategy): + """ + Strategy that is sequentially using a "propose prompt". + + This strategy works better when the thought space is more constrained, such + as when each thought is just a word or a line. Proposing different thoughts + in the same prompt completion helps to avoid duplication. + """ + + prompt: BasePromptTemplate = Field(default_factory=get_propose_prompt) + tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict) + + def next_thought( + self, + problem_description: str, + thoughts_path: Tuple[str, ...] = (), + **kwargs: Any, + ) -> str: + if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]: + new_thoughts = self.predict_and_parse( + problem_description=problem_description, + thoughts=thoughts_path, + n=self.c, + **kwargs, + ) + print("new_thoughts=", new_thoughts) + if not new_thoughts: + return "" + if isinstance(new_thoughts, list): + self.tot_memory[thoughts_path] = new_thoughts[::-1] + else: + return "" + return self.tot_memory[thoughts_path].pop() + + +class MyProposePromptStrategyJa(ProposePromptStrategy): + """ + Strategy that is sequentially using a "propose prompt". + + This strategy works better when the thought space is more constrained, such + as when each thought is just a word or a line. Proposing different thoughts + in the same prompt completion helps to avoid duplication. + """ + + prompt: BasePromptTemplate = Field(default_factory=get_propose_prompt_ja) + tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict) + + def next_thought( + self, + problem_description: str, + thoughts_path: Tuple[str, ...] = (), + **kwargs: Any, + ) -> str: + if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]: + new_thoughts = self.predict_and_parse( + problem_description=problem_description, + thoughts=thoughts_path, + n=self.c, + **kwargs, + ) + print("new_thoughts=", new_thoughts) + if not new_thoughts: + return "" + if isinstance(new_thoughts, list): + self.tot_memory[thoughts_path] = new_thoughts[::-1] + else: + return "" + return self.tot_memory[thoughts_path].pop() diff --git a/src/codeinterpreterapi/thoughts/thoughts.py b/src/codeinterpreterapi/thoughts/thoughts.py index 4d1ed901..fccbd1e7 100644 --- a/src/codeinterpreterapi/thoughts/thoughts.py +++ b/src/codeinterpreterapi/thoughts/thoughts.py @@ -1,15 +1,14 @@ from typing import Any, Dict, List, Optional, Union +from codeinterpreterapi.thoughts.base import MyToTChain +from codeinterpreterapi.thoughts.checker import create_tot_chain_from_llm from langchain_core.runnables import RunnableSerializable from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.utils import Input, Output -from langchain_experimental.tot.base import ToTChain - -from codeinterpreterapi.thoughts.tot import create_tot_chain_from_llm class CodeInterpreterToT(RunnableSerializable): - tot_chain: ToTChain = None + tot_chain: MyToTChain = None def __init__(self, llm=None, is_ja=True, is_simple=False): super().__init__() @@ -49,9 +48,62 @@ def get_runnable_tot_chain(cls, llm=None, is_ja=True, is_simple=False): def test(): - tot_chain = CodeInterpreterToT.get_runnable_tot_chain() - tot_chain.invoke({"input": "pythonで円周率を表示するプログラムを実行してください。"}) + tot_chain = CodeInterpreterToT.get_runnable_tot_chain(is_simple=True) + tot_chain.invoke({"input": sample2}) + + +sample1 = "pythonで円周率を表示するプログラムを実行してください。" +sample2 = """SVG画像を自動生成するプログラムの要件を以下のように定義します。 + +目的: + +電子書籍のヘッダ画像を自動生成すること +別のコンテンツ生成プログラムが出力したSVGファイルを入力として受け取ること +入力SVGファイルを指定の要件に従って加工し、新たなSVGファイルとして出力すること +機能要件: + +グリッドレイアウト機能の実装 + +指定したグリッドサイズ(行数、列数)に基づいて要素を配置できるようにする +グリッドの各セルに対して要素を割り当てられるようにする +グリッドのサイズや間隔を柔軟に設定できるようにする +SVG要素の配置と編集 + +グリッド上の指定した位置にSVG要素(テキスト、図形、画像など)を配置できるようにする +配置する要素の属性(サイズ、色、フォントなど)を柔軟に設定できるようにする +既存のSVG要素を削除、移動、変更できるようにする +外部画像ファイルの読み込みと配置 + +PNGやJPEGなどの外部画像ファイルを読み込んでSVGファイルに埋め込めるようにする +読み込んだ画像をグリッド上の指定した位置に配置できるようにする +画像のサイズを変更できるようにする +SVGファイルの入出力 + +SVGファイルを入力として読み込み、加工後のSVGファイルを出力できるようにする +出力ファイルのファイル名やパスを指定できるようにする +非機能要件: + +Python3とsvgwriteライブラリを使用して実装すること +コードはモジュール化し、再利用性と保守性を高めること +エラーハンドリングを適切に行い、ログ出力を行うこと +コードにはコメントを付けて可読性を高めること +実装の進め方: + +svgwriteを使ったSVGファイルの基本的な読み込み、編集、出力の機能を実装する +グリッドレイアウト機能を実装し、要素を配置できるようにする +外部画像ファイルの読み込みと配置機能を実装する +入力SVGファイルを読み込んで、指定の要件に従って加工し、新たなSVGファイルを出力する一連の処理を実装する +細かい仕様について検討し、機能を拡張する +テストを行い、不具合を修正する +ドキュメントを整備し、コードをリファクタリングする +まずはこの要件定義に基づいて、各機能の実装に着手してください。実装方法や詳細な手順は、要件に合わせて適宜ご判断ください。 + + + +作業フォルダは/app/workを使ってください。 +全ての処理は自動で実施して結果とプログラムだけ報告してください。 +""" if __name__ == "__main__": test()