From 8b776d97ef06c0bc7af7e04bd89c0539013fd9c4 Mon Sep 17 00:00:00 2001
From: jinno
Date: Sun, 26 May 2024 01:44:48 +0900
Subject: [PATCH] fix: update CodeInterpreterToT and call it from session.py

---
 requirements.txt                            |   1 +
 src/codeinterpreterapi/session.py           |  13 +-
 src/codeinterpreterapi/thoughts/thoughts.py | 145 +++++++++++++++++++-
 src/codeinterpreterapi/thoughts/tot.py      |   8 +-
 4 files changed, 154 insertions(+), 13 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 673dd5bb..c0e8f060 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ langchain_community
 langchain_experimental
 langchainhub
 langsmith
+https://huggingface.co/spacy/en_core_web_md/resolve/main/en_core_web_md-any-py3-none-any.whl

diff --git a/src/codeinterpreterapi/session.py b/src/codeinterpreterapi/session.py
index 187ef52f..3424b429 100644
--- a/src/codeinterpreterapi/session.py
+++ b/src/codeinterpreterapi/session.py
@@ -21,6 +21,7 @@
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 
+from codeinterpreterapi.agents.agents import CodeInterpreterAgent
 from codeinterpreterapi.chains import (
     aget_file_modifications,
     aremove_download_link,
@@ -29,10 +30,10 @@
 )
 from codeinterpreterapi.chat_history import CodeBoxChatMessageHistory
 from codeinterpreterapi.config import settings
+from codeinterpreterapi.llm.llm import CodeInterpreterLlm
 from codeinterpreterapi.schema import CodeInterpreterResponse, File, SessionStatus, UserRequest
+from codeinterpreterapi.thoughts.thoughts import CodeInterpreterToT
 
-from .agents.agents import CodeInterpreterAgent
-from .llm.llm import CodeInterpreterLlm
 from .planners.planners import CodeInterpreterPlanner
 from .supervisors.supervisors import CodeInterpreterSupervisor
 from .tools.tools import CodeInterpreterTools
@@ -77,6 +78,7 @@ def __init__(
         self.agent_executor: Optional[Runnable] = None
         self.llm_planner: Optional[Runnable] = None
         self.supervisor: Optional[AgentExecutor] = None
+        self.thought: Optional[Runnable] = None
         self.input_files: list[File] = []
         self.output_files: list[File] = []
         self.code_log: list[tuple[str, str]] = []
@@ -96,6 +98,7 @@ def initialize(self):
         self.initialize_agent_executor()
         self.initialize_llm_planner()
         self.initialize_supervisor()
+        self.initialize_thought()
 
     def initialize_agent_executor(self):
         is_experimental = False
@@ -127,6 +130,9 @@ def initialize_supervisor(self):
             verbose=self.verbose,
         )
 
+    def initialize_thought(self):
+        self.thought = CodeInterpreterToT.get_runnable_tot_chain(llm=self.llm)
+
     def start(self) -> SessionStatus:
         print("start")
         status = SessionStatus.from_codebox_status(self.codebox.start())
@@ -404,7 +410,8 @@ def generate_response(
         # ======= ↓↓↓↓ LLM invoke ↓↓↓↓ #=======
         # response = self.agent_executor.invoke(input=input_message)
         # response = self.llm_planner.invoke(input=input_message)
-        response = self.supervisor.invoke(input=input_message)
+        # response = self.supervisor.invoke(input=input_message)
+        response = self.thought.invoke(input=input_message)
         # ======= ↑↑↑↑ LLM invoke ↑↑↑↑ #=======
         print("response(type)=", type(response))
         print("response=", response)
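Usage note (illustrative, not part of the diff): with the session.py change above,
generate_response() now routes every request through the new ToT runnable instead
of the supervisor. A minimal sketch of the resulting call path; the session
construction, generate_response() signature, and CodeInterpreterResponse.content
are assumed from the upstream codeinterpreterapi API, so treat this as a sketch:

    # Hypothetical usage sketch: exercising the new self.thought path.
    from codeinterpreterapi.session import CodeInterpreterSession

    session = CodeInterpreterSession()
    session.initialize()  # assumed entry point; this patch extends it with initialize_thought()
    session.start()
    # Internally, generate_response() builds input_message and calls
    # self.thought.invoke(input=input_message); CodeInterpreterToT.run()
    # then reads input["input"] as the ToT problem description.
    response = session.generate_response("Print the value of pi in Python.")
    print(response.content)
    session.stop()
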
diff --git a/src/codeinterpreterapi/thoughts/thoughts.py b/src/codeinterpreterapi/thoughts/thoughts.py
index 40e1dd8d..1f7f43cf 100644
--- a/src/codeinterpreterapi/thoughts/thoughts.py
+++ b/src/codeinterpreterapi/thoughts/thoughts.py
@@ -1,14 +1,145 @@
-from gui_agent_loop_core.thoughts.prompts import get_propose_prompt, get_propose_prompt_ja
+# LCEL version of:
+# https://github.com/langchain-ai/langchain/blob/master/libs/experimental/langchain_experimental/tot/thought_generation.py
+from typing import Any, Dict, List, Optional, Tuple, Union
+
 from langchain.prompts.base import BasePromptTemplate
-from langchain_experimental.pydantic_v1 import Field
+from langchain_core.pydantic_v1 import Field
+from langchain_core.runnables import RunnableSequence, RunnableSerializable
+from langchain_core.runnables.config import RunnableConfig
+from langchain_core.runnables.utils import Input, Output
+from langchain_experimental.tot.base import ToTChain
+from langchain_experimental.tot.prompts import get_cot_prompt, get_propose_prompt
+
+from codeinterpreterapi.thoughts.tot import create_tot_chain_from_llm
+
+
+class CodeInterpreterToT(RunnableSerializable):
+    tot_chain: Optional[ToTChain] = None
+
+    def __init__(self, llm=None):
+        super().__init__()
+        self.tot_chain = create_tot_chain_from_llm(llm=llm)
+
+    def run(self, input: Input):
+        problem_description = input["input"]
+        return self.tot_chain.run(problem_description=problem_description)
+
+    def __call__(self, input: Input) -> Dict[str, str]:
+        return self.run(input)
+
+    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
+        return self.run(input)
+
+    def batch(self, inputs: List[Dict[str, str]]) -> List[Dict[str, str]]:
+        return [self.run(input_item) for input_item in inputs]
+
+    async def ainvoke(self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Output:
+        raise NotImplementedError("Async not implemented yet")
+
+    async def abatch(
+        self,
+        inputs: List[Input],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        **kwargs: Optional[Any],
+    ) -> List[Output]:
+        raise NotImplementedError("Async not implemented yet")
+
+    @classmethod
+    def get_runnable_tot_chain(
+        cls,
+        llm=None,
+    ):
+        # Create a runnable instance of the ToT chain
+        tot_chain = cls(
+            llm=llm,
+        )
+        return tot_chain
+
 
-# from langchain_experimental.tot.prompts import get_cot_prompt, get_propose_prompt
-from langchain_experimental.tot.thought_generation import ProposePromptStrategy
+class BaseThoughtGenerationStrategyRunnableSequence(RunnableSequence):
+    """
+    Base class for a thought generation strategy.
+    """
+
+    c: int = 3
+    """The number of children thoughts to propose at each step."""
+
+    def next_thought(
+        self,
+        problem_description: str,
+        thoughts_path: Tuple[str, ...] = (),
+        **kwargs: Any,
+    ) -> str:
+        """
+        Generate the next thought given the problem description and the thoughts
+        generated so far.
+        """
+        return ""
+
+
+class SampleCoTStrategyRunnableSequence(BaseThoughtGenerationStrategyRunnableSequence):
+    """
+    Sample strategy from a Chain-of-Thought (CoT) prompt.
+
+    This strategy works better when the thought space is rich, such as when each
+    thought is a paragraph. Independent and identically distributed samples
+    lead to diversity, which helps to avoid repetition.
+    """
+
+    prompt: BasePromptTemplate = Field(default_factory=get_cot_prompt)
+
+    def next_thought(
+        self,
+        problem_description: str,
+        thoughts_path: Tuple[str, ...] = (),
+        **kwargs: Any,
+    ) -> str:
+        # RunnableSequence has no predict_and_parse(); pass the prompt
+        # variables through invoke() as a single input dict instead.
+        response_text = self.invoke(
+            {"problem_description": problem_description, "thoughts": thoughts_path}
+        )
+        return response_text if isinstance(response_text, str) else ""
+
+
+class ProposePromptStrategyRunnableSequence(SampleCoTStrategyRunnableSequence):
+    """
+    Strategy that sequentially uses a "propose prompt".
+
+    This strategy works better when the thought space is more constrained, such
+    as when each thought is just a word or a line. Proposing different thoughts
+    in the same prompt completion helps to avoid duplication.
+    """
 
-class MyProposePromptStrategy(ProposePromptStrategy):
     prompt: BasePromptTemplate = Field(default_factory=get_propose_prompt)
+    tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict)
+
+    def next_thought(
+        self,
+        problem_description: str,
+        thoughts_path: Tuple[str, ...] = (),
+        **kwargs: Any,
+    ) -> str:
+        # Propose a fresh batch of thoughts only when the cache for this
+        # path is empty; otherwise pop the next cached proposal.
+        if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]:
+            new_thoughts = self.invoke(
+                {
+                    "problem_description": problem_description,
+                    "thoughts": thoughts_path,
+                    "n": self.c,
+                }
+            )
+            if not new_thoughts:
+                return ""
+            if isinstance(new_thoughts, list):
+                self.tot_memory[thoughts_path] = new_thoughts[::-1]
+            else:
+                return ""
+        return self.tot_memory[thoughts_path].pop()
+
+
+def test():
+    tot_chain = CodeInterpreterToT.get_runnable_tot_chain()
+    tot_chain.invoke({"input": "Run a Python program that prints the value of pi."})
+
 
-class MyProposePromptStrategyJa(ProposePromptStrategy):
-    prompt: BasePromptTemplate = Field(default_factory=get_propose_prompt_ja)
+if __name__ == "__main__":
+    test()
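Usage note (illustrative, not part of the diff): the strategy classes in
thoughts.py subclass RunnableSequence, so each instance is expected to be a
composed prompt -> LLM -> parser pipeline that accepts one input dict. A sketch
of one plausible wiring; the ChatOpenAI model and the StrOutputParser stand-in
are assumptions not found in this patch (a real propose strategy needs a parser
that splits the completion into a list of thoughts):

    # Hypothetical wiring sketch for a propose-style sequence.
    from langchain_core.output_parsers import StrOutputParser
    from langchain_experimental.tot.prompts import get_propose_prompt
    from langchain_openai import ChatOpenAI  # example model, not in requirements.txt

    pipeline = get_propose_prompt() | ChatOpenAI() | StrOutputParser()
    completion = pipeline.invoke(
        {"problem_description": "Print pi in Python.", "thoughts": (), "n": 3}
    )
    print(completion)  # raw completion; next_thought() expects it split into proposals
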
diff --git a/src/codeinterpreterapi/thoughts/tot.py b/src/codeinterpreterapi/thoughts/tot.py
index 5880a07c..1fedaad5 100644
--- a/src/codeinterpreterapi/thoughts/tot.py
+++ b/src/codeinterpreterapi/thoughts/tot.py
@@ -31,7 +31,7 @@
 #######
-class MyChecker(ToTChecker):
+class MyToTChecker(ToTChecker):
     llm: Optional[BaseMemory] = None
     prompt: PromptTemplate = PromptTemplate(
         input_variables=["problem_description", "thoughts"],
@@ -121,8 +121,10 @@ def test_checker():
 
 #######
-def create_tot_chain_from_llm(llm):
-    checker = MyChecker()
+def create_tot_chain_from_llm(llm=None):
+    checker = MyToTChecker()
+    if llm is None:
+        llm = prepare_test_llm()
     checker.llm = llm
     tot_chain = ToTChain.from_llm(llm=llm, checker=checker, k=30, c=5, verbose=True, verbose_llm=False)
     return tot_chain
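Usage note (illustrative, not part of the diff): end to end, the new pieces chain
as session.thought -> CodeInterpreterToT -> ToTChain -> MyToTChecker. A sketch of
driving the factory with an explicit model; ChatOpenAI is only an example, and
prepare_test_llm(), the fallback used above, is defined in the part of tot.py
this patch does not show:

    # Hypothetical smoke test for the new ToT factory.
    from langchain_openai import ChatOpenAI  # any LangChain chat model should work here

    from codeinterpreterapi.thoughts.tot import create_tot_chain_from_llm

    tot_chain = create_tot_chain_from_llm(llm=ChatOpenAI())
    # ToTChain explores up to k=30 thought steps with c=5 proposals per step,
    # and MyToTChecker grades each intermediate thought.
    tot_chain.run(problem_description="Run a Python program that prints the value of pi.")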