fix: update thoughts adding controller.py and thought_generation.py
nobu007 committed May 26, 2024
1 parent 77c5be5 commit 4f43dc8
Showing 6 changed files with 386 additions and 10 deletions.
131 changes: 131 additions & 0 deletions src/codeinterpreterapi/thoughts/base.py
@@ -0,0 +1,131 @@
from __future__ import annotations

from textwrap import indent
from typing import Any, Dict, List, Optional, Type

from langchain.base_language import BaseLanguageModel
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
from langchain_experimental.pydantic_v1 import Extra
from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.controller import ToTController
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import Thought, ThoughtValidity
from langchain_experimental.tot.thought_generation import BaseThoughtGenerationStrategy, ProposePromptStrategy


class MyToTChain(ToTChain):
"""
Chain implementing the Tree of Thought (ToT).
"""

llm: BaseLanguageModel
"""
Language model to use. It must be set to produce different variations for
the same prompt.
"""
checker: ToTChecker
"""ToT Checker to use."""
output_key: str = "response" #: :meta private:
k: int = 10
"""The maximum number of conversation rounds"""
c: int = 3
"""The number of children to explore at each node"""
tot_memory: ToTDFSMemory = ToTDFSMemory()
tot_controller: ToTController = ToTController()
tot_strategy_class: Type[BaseThoughtGenerationStrategy] = ProposePromptStrategy
verbose_llm: bool = False

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
arbitrary_types_allowed = True

@classmethod
def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:
"""
Create a ToTChain from a language model.
:param llm: The language model to use.
:param kwargs: Additional arguments to pass to the ToTChain constructor.
"""
return cls(llm=llm, **kwargs)

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.tot_controller.c = self.c

@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return ["problem_description"]

@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]

def log_thought(
self,
thought: Thought,
level: int,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> None:
if run_manager:
colors = {
ThoughtValidity.VALID_FINAL: "green",
ThoughtValidity.VALID_INTERMEDIATE: "yellow",
ThoughtValidity.INVALID: "red",
}
text = indent(f"Thought: {thought.text}\n", prefix=" " * level)
run_manager.on_text(text=text, color=colors[thought.validity], verbose=self.verbose)

def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
if run_manager:
run_manager.on_text(text="Starting the ToT solve procedure.\n")

problem_description = inputs["problem_description"]
checker_inputs = {"problem_description": problem_description}
thoughts_path: tuple[str, ...] = ()
thought_generator = self.tot_strategy_class(llm=self.llm, c=self.c, verbose=self.verbose_llm)

level = 0
for _ in range(self.k):
level = self.tot_memory.level
thought_text = thought_generator.next_thought(
problem_description, thoughts_path, callbacks=_run_manager.get_child()
)
checker_inputs["thoughts"] = thoughts_path + (thought_text,)
thought_validity = self.checker(checker_inputs, callbacks=_run_manager.get_child())["validity"]
thought = Thought(text=thought_text, validity=thought_validity)
if thought.validity == ThoughtValidity.VALID_FINAL:
self.log_thought(thought, level, run_manager)
return {self.output_key: thought.text}
self.tot_memory.store(thought)
self.log_thought(thought, level, run_manager)
thoughts_path = self.tot_controller(self.tot_memory)

return {self.output_key: "No solution found"}

async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
raise NotImplementedError("Async not implemented yet")

@property
def _chain_type(self) -> str:
return "tot"
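For reference, a minimal usage sketch (not part of this commit): the chain needs an LLM plus a `ToTChecker` subclass whose `evaluate` classifies each partial thought path. `MyChecker` below is a hypothetical placeholder, and the LLM lines are left commented out since any `BaseLanguageModel` will do.

```python
# A minimal usage sketch, not part of this commit. `MyChecker` is a
# hypothetical stand-in: a real checker must implement `evaluate` and
# return a ThoughtValidity for each candidate thought path.
from typing import Tuple

from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.thought import ThoughtValidity

from codeinterpreterapi.thoughts.base import MyToTChain


class MyChecker(ToTChecker):
    def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity:
        # Toy rule: accept a thought that claims a solution; a real checker
        # would validate the partial solution against the problem.
        if thoughts and "solution" in thoughts[-1].lower():
            return ThoughtValidity.VALID_FINAL
        return ThoughtValidity.VALID_INTERMEDIATE


# llm = ...  # any BaseLanguageModel sampled with temperature > 0
# chain = MyToTChain.from_llm(llm=llm, checker=MyChecker(), k=10, c=3)
# print(chain.run(problem_description="..."))
```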
@@ -163,13 +163,13 @@ def prepare_test_llm():
model=model,
temperature=0,
google_api_key=os.environ.get("GEMINI_API_KEY"),
- max_output_tokens=1024,
+ max_output_tokens=1024 * 4,
)
return llm


def test_create():
- tot_chain = create_tot_chain_from_llm(prepare_test_llm())
+ tot_chain = create_tot_chain_from_llm(llm=prepare_test_llm(), is_simple=True)
tot_chain.run(problem_description=sudoku_problem_description)


49 changes: 49 additions & 0 deletions src/codeinterpreterapi/thoughts/controller.py
@@ -0,0 +1,49 @@
from typing import Tuple

from langchain_experimental.tot.controller import ToTController
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import ThoughtValidity


class MyToTController(ToTController):
"""
Tree of Thought (ToT) controller.
This is a version of a ToT controller, dubbed in the paper as a "Simple
Controller".
It has one parameter `c` which is the number of children to explore for each
thought.
"""

def __init__(self, c: int = 3):
"""
Initialize the controller.
Args:
c: The number of children to explore at each node.
"""
self.c = c

def __call__(self, memory: ToTDFSMemory) -> Tuple[str, ...]:
next_thought = memory.top()
parent_thought = memory.top_parent()
validity = ThoughtValidity.VALID_INTERMEDIATE if next_thought is None else next_thought.validity

# 1. If the current partial solution is invalid, backtrack to the parent
# thought.
if validity == ThoughtValidity.INVALID:
memory.pop()
next_thought = memory.top()
if next_thought and len(next_thought.children) >= self.c:
memory.pop()

# 2. If the current partial solution is valid but C children were
# explored and yet failed to find a final solution, backtrack to the
# parent thought.
elif (
validity == ThoughtValidity.VALID_INTERMEDIATE and parent_thought and len(parent_thought.children) >= self.c
):
memory.pop(2)

return tuple(thought.text for thought in memory.current_path())
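A small sketch (assuming the standard `ToTDFSMemory` and `Thought` semantics) of backtracking rule 1 above: an `INVALID` thought on top of the stack is popped, so the returned path no longer contains it.

```python
# Sketch of rule 1: the INVALID top thought is popped before the path is
# returned. Assumes the standard ToTDFSMemory/Thought API imported below.
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import Thought, ThoughtValidity

from codeinterpreterapi.thoughts.controller import MyToTController

controller = MyToTController(c=3)
memory = ToTDFSMemory()
memory.store(Thought(text="step 1", validity=ThoughtValidity.VALID_INTERMEDIATE))
memory.store(Thought(text="dead end", validity=ThoughtValidity.INVALID))

print(controller(memory))  # -> ('step 1',); the invalid thought was discarded
```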
18 changes: 16 additions & 2 deletions src/codeinterpreterapi/thoughts/prompts.py
@@ -114,10 +114,12 @@ def get_propose_prompt_ja() -> PromptTemplate:
output_parser=JSONListOutputParser(),
template=dedent(
"""
- あなたは、思考ツリーの設定で思考を生成するインテリジェントなエージェントです
+ あなたは、Tree-of-Thought (ToT) で思考を生成するインテリジェントなエージェントです
与えられた PROBLEM から思考を生成して出力してください。
出力は、先頭に "```json"、末尾に "```" を含む、JSON形式の文字列リストとしてマークダウンコードスニペットにフォーマットしてください。
出力例を示します。
```json
[
"<thought-1>",
@@ -126,6 +128,7 @@ def get_propose_prompt_ja() -> PromptTemplate:
]
```
問題と作業指示は以下の通りです。
PROBLEM
{{ problem_description }}
@@ -144,7 +147,18 @@
{%- endif -%}
ガイドライン:
- 簡単な問題には最終的な回答1回で思考を完結してください。
- 複雑な問題には一度に {{ n }} 個のステップバイステップで連続した思考をしてください。
- 思考を生成するために、問題と前回の思考を注意深く分析してください。
- 思考は、問題解決に向けた明確なステップや洞察を提供するものでなければなりません。
- 思考がループしている場合は、直ちに気づいて修正してください。あなたは頻繁に思考ループに陥る傾向があります。
- 各思考は簡潔にまとめ、問題に直接関連するようにしてください。
- 思考の質を向上させるために、必要に応じて問題をさらに分析し、追加の情報を検討してください。
- より少ない思考で問題を解決できるように最善を尽くしてください。
- 生成された思考のリストを、指定されたJSON形式で出力してください。
それでは開始してください。
"""
).strip(),
)
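The prompt asks the model for a fenced JSON list, which the template then parses with `JSONListOutputParser`. That parser's implementation is not shown in this diff; the sketch below only illustrates how such a fenced JSON list could be extracted and decoded, and is not the repository's actual parser.

```python
# Illustrative only: a minimal parser for the fenced-JSON list format the
# prompt requests. The repository's JSONListOutputParser may differ.
import json
import re
from typing import List


def parse_json_list(text: str) -> List[str]:
    # Pull out the fenced "json" block if present, then decode it as a JSON list.
    match = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
    payload = match.group(1) if match else text
    return [str(item) for item in json.loads(payload)]


print(parse_json_list('```json\n["<thought-1>", "<thought-2>"]\n```'))
# -> ['<thought-1>', '<thought-2>']
```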
130 changes: 130 additions & 0 deletions src/codeinterpreterapi/thoughts/thought_generation.py
@@ -0,0 +1,130 @@
# LCEL version of:
# https://github.com/langchain-ai/langchain/blob/master/libs/experimental/langchain_experimental/tot/thought_generation.py
from typing import Any, Dict, List, Tuple

from langchain.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_experimental.tot.thought_generation import ProposePromptStrategy, SampleCoTStrategy

from codeinterpreterapi.thoughts.prompts import (
get_cot_prompt,
get_cot_prompt_ja,
get_propose_prompt,
get_propose_prompt_ja,
)


class MySampleCoTStrategy(SampleCoTStrategy):
"""
Sample strategy from a Chain-of-Thought (CoT) prompt.
This strategy works better when the thought space is rich, such as when each
thought is a paragraph. Independent and identically distributed samples
lead to diversity, which helps to avoid repetition.
"""

prompt: BasePromptTemplate = Field(default_factory=get_cot_prompt)

def next_thought(
self,
problem_description: str,
thoughts_path: Tuple[str, ...] = (),
**kwargs: Any,
) -> str:
response_text = self.predict_and_parse(
problem_description=problem_description, thoughts=thoughts_path, **kwargs
)
return response_text if isinstance(response_text, str) else ""


class MySampleCoTStrategyJa(SampleCoTStrategy):
"""
Sample strategy from a Chain-of-Thought (CoT) prompt.
This strategy works better when the thought space is rich, such as when each
thought is a paragraph. Independent and identically distributed samples
lead to diversity, which helps to avoid repetition.
"""

prompt: BasePromptTemplate = Field(default_factory=get_cot_prompt_ja)

def next_thought(
self,
problem_description: str,
thoughts_path: Tuple[str, ...] = (),
**kwargs: Any,
) -> str:
response_text = self.predict_and_parse(
problem_description=problem_description, thoughts=thoughts_path, **kwargs
)
return response_text if isinstance(response_text, str) else ""


class MyProposePromptStrategy(ProposePromptStrategy):
"""
Strategy that sequentially uses a "propose prompt".
This strategy works better when the thought space is more constrained, such
as when each thought is just a word or a line. Proposing different thoughts
in the same prompt completion helps to avoid duplication.
"""

prompt: BasePromptTemplate = Field(default_factory=get_propose_prompt)
tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict)

def next_thought(
self,
problem_description: str,
thoughts_path: Tuple[str, ...] = (),
**kwargs: Any,
) -> str:
if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]:
new_thoughts = self.predict_and_parse(
problem_description=problem_description,
thoughts=thoughts_path,
n=self.c,
**kwargs,
)
print("new_thoughts=", new_thoughts)
if not new_thoughts:
return ""
if isinstance(new_thoughts, list):
self.tot_memory[thoughts_path] = new_thoughts[::-1]
else:
return ""
return self.tot_memory[thoughts_path].pop()


class MyProposePromptStrategyJa(ProposePromptStrategy):
"""
Strategy that sequentially uses a "propose prompt".
This strategy works better when the thought space is more constrained, such
as when each thought is just a word or a line. Proposing different thoughts
in the same prompt completion helps to avoid duplication.
"""

prompt: BasePromptTemplate = Field(default_factory=get_propose_prompt_ja)
tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict)

def next_thought(
self,
problem_description: str,
thoughts_path: Tuple[str, ...] = (),
**kwargs: Any,
) -> str:
if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]:
new_thoughts = self.predict_and_parse(
problem_description=problem_description,
thoughts=thoughts_path,
n=self.c,
**kwargs,
)
print("new_thoughts=", new_thoughts)
if not new_thoughts:
return ""
if isinstance(new_thoughts, list):
self.tot_memory[thoughts_path] = new_thoughts[::-1]
else:
return ""
return self.tot_memory[thoughts_path].pop()
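To make the caching scheme in `next_thought` concrete, here is a toy illustration (no LLM involved): proposals for a given `thoughts_path` are generated once, stored reversed, and then served one per call via `pop()` until the cache is exhausted.

```python
# Toy illustration of the propose-strategy cache above; the list comprehension
# is a stand-in for the predict_and_parse LLM call.
from typing import Dict, List, Tuple

cache: Dict[Tuple[str, ...], List[str]] = {}


def next_thought(path: Tuple[str, ...] = ()) -> str:
    if path not in cache or not cache[path]:
        proposals = [f"thought-{i}" for i in range(1, 4)]  # stand-in LLM output
        cache[path] = proposals[::-1]  # reversed so pop() serves original order
    return cache[path].pop()


print([next_thought() for _ in range(3)])  # -> ['thought-1', 'thought-2', 'thought-3']
```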
