forked from shroominic/codeinterpreter-api
Commit
fix: update thoughts adding controller.py and thought_generation.py
Showing 6 changed files with 386 additions and 10 deletions.
New file (the MyToTChain chain implementation):

@@ -0,0 +1,131 @@
from __future__ import annotations

from textwrap import indent
from typing import Any, Dict, List, Optional, Type

from langchain.base_language import BaseLanguageModel
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain_experimental.pydantic_v1 import Extra
from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.controller import ToTController
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import Thought, ThoughtValidity
from langchain_experimental.tot.thought_generation import (
    BaseThoughtGenerationStrategy,
    ProposePromptStrategy,
)


class MyToTChain(ToTChain):
    """
    Chain implementing the Tree of Thought (ToT).
    """

    llm: BaseLanguageModel
    """
    Language model to use. It must be set up to produce different variations
    for the same prompt.
    """
    checker: ToTChecker
    """ToT checker to use."""
    output_key: str = "response"  #: :meta private:
    k: int = 10
    """The maximum number of conversation rounds."""
    c: int = 3
    """The number of children to explore at each node."""
    tot_memory: ToTDFSMemory = ToTDFSMemory()
    tot_controller: ToTController = ToTController()
    tot_strategy_class: Type[BaseThoughtGenerationStrategy] = ProposePromptStrategy
    verbose_llm: bool = False

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @classmethod
    def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:
        """
        Create a ToTChain from a language model.

        :param llm: The language model to use.
        :param kwargs: Additional arguments to pass to the ToTChain constructor.
        """
        return cls(llm=llm, **kwargs)

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        # Keep the controller's branching factor in sync with the chain's `c`.
        self.tot_controller.c = self.c

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return ["problem_description"]

    @property
    def output_keys(self) -> List[str]:
        """Will always return the text key.

        :meta private:
        """
        return [self.output_key]

    def log_thought(
        self,
        thought: Thought,
        level: int,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> None:
        if run_manager:
            colors = {
                ThoughtValidity.VALID_FINAL: "green",
                ThoughtValidity.VALID_INTERMEDIATE: "yellow",
                ThoughtValidity.INVALID: "red",
            }
            text = indent(f"Thought: {thought.text}\n", prefix=" " * level)
            run_manager.on_text(
                text=text, color=colors[thought.validity], verbose=self.verbose
            )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        if run_manager:
            run_manager.on_text(text="Starting the ToT solve procedure.\n")

        problem_description = inputs["problem_description"]
        checker_inputs = {"problem_description": problem_description}
        thoughts_path: tuple[str, ...] = ()
        thought_generator = self.tot_strategy_class(
            llm=self.llm, c=self.c, verbose=self.verbose_llm
        )

        # Generate up to k thoughts; the checker grades each one, and the
        # controller decides where in the tree to continue (or backtrack).
        level = 0
        for _ in range(self.k):
            level = self.tot_memory.level
            thought_text = thought_generator.next_thought(
                problem_description, thoughts_path, callbacks=_run_manager.get_child()
            )
            checker_inputs["thoughts"] = thoughts_path + (thought_text,)
            thought_validity = self.checker(
                checker_inputs, callbacks=_run_manager.get_child()
            )["validity"]
            thought = Thought(text=thought_text, validity=thought_validity)
            if thought.validity == ThoughtValidity.VALID_FINAL:
                self.log_thought(thought, level, run_manager)
                return {self.output_key: thought.text}
            self.tot_memory.store(thought)
            self.log_thought(thought, level, run_manager)
            thoughts_path = self.tot_controller(self.tot_memory)

        return {self.output_key: "No solution found"}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        raise NotImplementedError("Async not implemented yet")

    @property
    def _chain_type(self) -> str:
        return "tot"
New file, controller.py per the commit message (MyToTController):

@@ -0,0 +1,49 @@
from typing import Tuple

from langchain_experimental.tot.controller import ToTController
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import ThoughtValidity


class MyToTController(ToTController):
    """
    Tree of Thought (ToT) controller.

    This is a version of a ToT controller, dubbed in the paper as a "Simple
    Controller". It has one parameter `c`, which is the number of children
    to explore for each thought.
    """

    def __init__(self, c: int = 3):
        """
        Initialize the controller.

        Args:
            c: The number of children to explore at each node.
        """
        self.c = c

    def __call__(self, memory: ToTDFSMemory) -> Tuple[str, ...]:
        next_thought = memory.top()
        parent_thought = memory.top_parent()
        validity = (
            ThoughtValidity.VALID_INTERMEDIATE
            if next_thought is None
            else next_thought.validity
        )

        # 1. If the current partial solution is invalid, backtrack to the
        # parent thought.
        if validity == ThoughtValidity.INVALID:
            memory.pop()
            next_thought = memory.top()
            if next_thought and len(next_thought.children) >= self.c:
                memory.pop()

        # 2. If the current partial solution is valid but c children were
        # explored and yet failed to find a final solution, backtrack to the
        # parent thought.
        elif (
            validity == ThoughtValidity.VALID_INTERMEDIATE
            and parent_thought
            and len(parent_thought.children) >= self.c
        ):
            memory.pop(2)

        return tuple(thought.text for thought in memory.current_path())
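For intuition about the backtracking rules, here is a small sketch using the stock ToTDFSMemory and Thought types; the thought texts are made up for illustration:

from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import Thought, ThoughtValidity

controller = MyToTController(c=2)
memory = ToTDFSMemory()
memory.store(Thought(text="step 1", validity=ThoughtValidity.VALID_INTERMEDIATE))
memory.store(Thought(text="step 2", validity=ThoughtValidity.INVALID))

# Rule 1: the top thought is invalid, so the controller pops it and the next
# path to explore ends at the surviving parent.
print(controller(memory))  # -> ("step 1",)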
New file, thought_generation.py per the commit message (CoT and propose-prompt strategies):

@@ -0,0 +1,130 @@
# LCEL version of:
# https://github.com/langchain-ai/langchain/blob/master/libs/experimental/langchain_experimental/tot/thought_generation.py
from typing import Any, Dict, List, Tuple

from langchain.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_experimental.tot.thought_generation import (
    ProposePromptStrategy,
    SampleCoTStrategy,
)

from codeinterpreterapi.thoughts.prompts import (
    get_cot_prompt,
    get_cot_prompt_ja,
    get_propose_prompt,
    get_propose_prompt_ja,
)


class MySampleCoTStrategy(SampleCoTStrategy):
    """
    Sample strategy from a Chain-of-Thought (CoT) prompt.

    This strategy works better when the thought space is rich, such as when
    each thought is a paragraph. Independent and identically distributed
    samples lead to diversity, which helps to avoid repetition.
    """

    prompt: BasePromptTemplate = Field(default_factory=get_cot_prompt)

    def next_thought(
        self,
        problem_description: str,
        thoughts_path: Tuple[str, ...] = (),
        **kwargs: Any,
    ) -> str:
        response_text = self.predict_and_parse(
            problem_description=problem_description, thoughts=thoughts_path, **kwargs
        )
        return response_text if isinstance(response_text, str) else ""


class MySampleCoTStrategyJa(SampleCoTStrategy):
    """
    Japanese-prompt variant of MySampleCoTStrategy: identical sampling
    behavior, but with the Japanese CoT prompt as the default.
    """

    prompt: BasePromptTemplate = Field(default_factory=get_cot_prompt_ja)

    def next_thought(
        self,
        problem_description: str,
        thoughts_path: Tuple[str, ...] = (),
        **kwargs: Any,
    ) -> str:
        response_text = self.predict_and_parse(
            problem_description=problem_description, thoughts=thoughts_path, **kwargs
        )
        return response_text if isinstance(response_text, str) else ""


class MyProposePromptStrategy(ProposePromptStrategy):
    """
    Strategy that sequentially uses a "propose prompt".

    This strategy works better when the thought space is more constrained,
    such as when each thought is just a word or a line. Proposing different
    thoughts in the same prompt completion helps to avoid duplication.
    """

    prompt: BasePromptTemplate = Field(default_factory=get_propose_prompt)
    tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict)

    def next_thought(
        self,
        problem_description: str,
        thoughts_path: Tuple[str, ...] = (),
        **kwargs: Any,
    ) -> str:
        # Ask the LLM for up to c new proposals only when the cache for this
        # path is empty; otherwise serve the next cached proposal.
        if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]:
            new_thoughts = self.predict_and_parse(
                problem_description=problem_description,
                thoughts=thoughts_path,
                n=self.c,
                **kwargs,
            )
            print("new_thoughts=", new_thoughts)
            if not new_thoughts:
                return ""
            if isinstance(new_thoughts, list):
                # Reverse so that pop() returns proposals in generation order.
                self.tot_memory[thoughts_path] = new_thoughts[::-1]
            else:
                return ""
        return self.tot_memory[thoughts_path].pop()


class MyProposePromptStrategyJa(ProposePromptStrategy):
    """
    Japanese-prompt variant of MyProposePromptStrategy: identical caching
    behavior, but with the Japanese propose prompt as the default.
    """

    prompt: BasePromptTemplate = Field(default_factory=get_propose_prompt_ja)
    tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict)

    def next_thought(
        self,
        problem_description: str,
        thoughts_path: Tuple[str, ...] = (),
        **kwargs: Any,
    ) -> str:
        if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]:
            new_thoughts = self.predict_and_parse(
                problem_description=problem_description,
                thoughts=thoughts_path,
                n=self.c,
                **kwargs,
            )
            print("new_thoughts=", new_thoughts)
            if not new_thoughts:
                return ""
            if isinstance(new_thoughts, list):
                self.tot_memory[thoughts_path] = new_thoughts[::-1]
            else:
                return ""
        return self.tot_memory[thoughts_path].pop()
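The propose strategy can also be driven on its own, which makes the per-path caching visible. A rough sketch, assuming an OpenAI chat model as a stand-in for any BaseLanguageModel and that the propose prompt accepts the `n` variable as used above:

from langchain_openai import ChatOpenAI  # assumption: any chat model works here

llm = ChatOpenAI(temperature=1.0)
strategy = MyProposePromptStrategy(llm=llm, c=3)

# The first call asks the model for up to c proposals and caches them under
# the (empty) thoughts_path; the second call pops from that cache instead of
# re-querying the model.
first = strategy.next_thought("Solve 2x + 3 = 11 for x.")
second = strategy.next_thought("Solve 2x + 3 = 11 for x.")
print(first, second)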