fix: update CodeInterpreterToT and call from session.py

nobu007 committed May 25, 2024
1 parent 01b35d2 commit 8b776d9
Showing 4 changed files with 154 additions and 13 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,3 +5,4 @@ langchain_community
 langchain_experimental
 langchainhub
 langsmith
+https://huggingface.co/spacy/en_core_web_md/resolve/main/en_core_web_md-any-py3-none-any.whl
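
The new requirements line pins the spaCy en_core_web_md model wheel by URL, so a plain pip install -r requirements.txt fetches the model with the other dependencies, with no separate python -m spacy download step. A minimal sanity check, assuming spaCy itself lands as a dependency of the model wheel:

import spacy

# Loads the model installed from the pinned wheel; raises OSError if it is missing.
nlp = spacy.load("en_core_web_md")
doc = nlp("Tree of Thoughts explores branching reasoning paths.")
print(doc[0].vector.shape)  # token vectors are 300-dimensional for this model
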
13 changes: 10 additions & 3 deletions src/codeinterpreterapi/session.py
@@ -21,6 +21,7 @@
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 
+from codeinterpreterapi.agents.agents import CodeInterpreterAgent
 from codeinterpreterapi.chains import (
     aget_file_modifications,
     aremove_download_link,
@@ -29,10 +30,10 @@
 )
 from codeinterpreterapi.chat_history import CodeBoxChatMessageHistory
 from codeinterpreterapi.config import settings
+from codeinterpreterapi.llm.llm import CodeInterpreterLlm
 from codeinterpreterapi.schema import CodeInterpreterResponse, File, SessionStatus, UserRequest
+from codeinterpreterapi.thoughts.thoughts import CodeInterpreterToT
 
-from .agents.agents import CodeInterpreterAgent
-from .llm.llm import CodeInterpreterLlm
 from .planners.planners import CodeInterpreterPlanner
 from .supervisors.supervisors import CodeInterpreterSupervisor
 from .tools.tools import CodeInterpreterTools
@@ -77,6 +78,7 @@ def __init__(
         self.agent_executor: Optional[Runnable] = None
         self.llm_planner: Optional[Runnable] = None
         self.supervisor: Optional[AgentExecutor] = None
+        self.thought: Optional[Runnable] = None
         self.input_files: list[File] = []
         self.output_files: list[File] = []
         self.code_log: list[tuple[str, str]] = []
@@ -96,6 +98,7 @@ def initialize(self):
         self.initialize_agent_executor()
         self.initialize_llm_planner()
         self.initialize_supervisor()
+        self.initialize_thought()
 
     def initialize_agent_executor(self):
         is_experimental = False
@@ -127,6 +130,9 @@ def initialize_supervisor(self):
             verbose=self.verbose,
         )
 
+    def initialize_thought(self):
+        self.thought = CodeInterpreterToT.get_runnable_tot_chain(llm=self.llm)
+
     def start(self) -> SessionStatus:
         print("start")
         status = SessionStatus.from_codebox_status(self.codebox.start())
@@ -404,7 +410,8 @@ def generate_response(
         # ======= ↓↓↓↓ LLM invoke ↓↓↓↓ #=======
         # response = self.agent_executor.invoke(input=input_message)
         # response = self.llm_planner.invoke(input=input_message)
-        response = self.supervisor.invoke(input=input_message)
+        # response = self.supervisor.invoke(input=input_message)
+        response = self.thought.invoke(input=input_message)
         # ======= ↑↑↑↑ LLM invoke ↑↑↑↑ #=======
         print("response(type)=", type(response))
         print("response=", response)
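
Taken together, the session.py changes register a fourth runnable (self.thought) alongside the agent executor, planner, and supervisor, and reroute generate_response through it. A sketch of the resulting call path; the constructor arguments and the generate_response signature are assumptions based on the surrounding code, not shown in this diff:

# Hypothetical usage only: names outside the diff are assumed.
session = CodeInterpreterSession(llm=llm)
session.initialize()                      # now also calls initialize_thought()
response = session.generate_response("Print pi in Python.")
print(response)
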
145 changes: 138 additions & 7 deletions src/codeinterpreterapi/thoughts/thoughts.py
@@ -1,14 +1,145 @@
-from gui_agent_loop_core.thoughts.prompts import get_propose_prompt, get_propose_prompt_ja
+# LCEL version of:
+# https://github.com/langchain-ai/langchain/blob/master/libs/experimental/langchain_experimental/tot/thought_generation.py
+from typing import Any, Dict, List, Optional, Tuple, Union
+
 from langchain.prompts.base import BasePromptTemplate
-from langchain_experimental.pydantic_v1 import Field
+from langchain_core.pydantic_v1 import Field
+from langchain_core.runnables import RunnableSequence, RunnableSerializable
+from langchain_core.runnables.config import RunnableConfig
+from langchain_core.runnables.utils import Input, Output
+from langchain_experimental.tot.base import ToTChain
+from langchain_experimental.tot.prompts import get_cot_prompt, get_propose_prompt
 
+from codeinterpreterapi.thoughts.tot import create_tot_chain_from_llm
+
+
+class CodeInterpreterToT(RunnableSerializable):
+    tot_chain: ToTChain = None
+
+    def __init__(self, llm=None):
+        super().__init__()
+        self.tot_chain = create_tot_chain_from_llm(llm=llm)
+
+    def run(self, input: Input):
+        problem_description = input["input"]
+        return self.tot_chain.run(problem_description=problem_description)
+
+    def __call__(self, input: Input) -> Dict[str, str]:
+        return self.run(input)
+
+    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
+        return self.run(input)
+
+    def batch(self, inputs: List[Dict[str, str]]) -> List[Dict[str, str]]:
+        return [self.run(input_item) for input_item in inputs]
+
+    async def ainvoke(self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Output:
+        raise NotImplementedError("Async not implemented yet")
+
+    async def abatch(
+        self,
+        inputs: List[Input],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        **kwargs: Optional[Any],
+    ) -> List[Output]:
+        raise NotImplementedError("Async not implemented yet")
+
+    @classmethod
+    def get_runnable_tot_chain(
+        cls,
+        llm=None,
+    ):
+        # Create an instance of this runnable; it builds the ToT chain internally
+        tot_chain = cls(
+            llm=llm,
+        )
+        return tot_chain
 
 
-# from langchain_experimental.tot.prompts import get_cot_prompt, get_propose_prompt
-from langchain_experimental.tot.thought_generation import ProposePromptStrategy
+class BaseThoughtGenerationStrategyRunnableSequence(RunnableSequence):
+    """
+    Base class for a thought generation strategy.
+    """
+
+    c: int = 3
+    """The number of children thoughts to propose at each step."""
+
+    def next_thought(
+        self,
+        problem_description: str,
+        thoughts_path: Tuple[str, ...] = (),
+        **kwargs: Any,
+    ) -> str:
+        """
+        Generate the next thought given the problem description and the thoughts
+        generated so far.
+        """
+        return ""
+
+
+class SampleCoTStrategyRunnableSequence(BaseThoughtGenerationStrategyRunnableSequence):
+    """
+    Sample strategy from a Chain-of-Thought (CoT) prompt.
+    This strategy works better when the thought space is rich, such as when each
+    thought is a paragraph. Independent and identically distributed samples
+    lead to diversity, which helps to avoid repetition.
+    """
+
+    prompt: BasePromptTemplate = Field(default_factory=get_cot_prompt)
+
+    def next_thought(
+        self,
+        problem_description: str,
+        thoughts_path: Tuple[str, ...] = (),
+        **kwargs: Any,
+    ) -> str:
+        response_text = self.predict_and_parse(
+            problem_description=problem_description, thoughts=thoughts_path, **kwargs
+        )
+        return response_text if isinstance(response_text, str) else ""
+
+
+class ProposePromptStrategyRunnableSequence(SampleCoTStrategyRunnableSequence):
+    """
+    Strategy that is sequentially using a "propose prompt".
+    This strategy works better when the thought space is more constrained, such
+    as when each thought is just a word or a line. Proposing different thoughts
+    in the same prompt completion helps to avoid duplication.
+    """
+
-class MyProposePromptStrategy(ProposePromptStrategy):
     prompt: BasePromptTemplate = Field(default_factory=get_propose_prompt)
+    tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict)
+
+    def next_thought(
+        self,
+        problem_description: str,
+        thoughts_path: Tuple[str, ...] = (),
+        **kwargs: Any,
+    ) -> str:
+        if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]:
+            new_thoughts = self.invoke(
+                problem_description=problem_description,
+                thoughts=thoughts_path,
+                n=self.c,
+                **kwargs,
+            )
+            if not new_thoughts:
+                return ""
+            if isinstance(new_thoughts, list):
+                self.tot_memory[thoughts_path] = new_thoughts[::-1]
+            else:
+                return ""
+        return self.tot_memory[thoughts_path].pop()
+
+
+def test():
+    tot_chain = CodeInterpreterToT.get_runnable_tot_chain()
+    tot_chain.invoke({"input": "Run a Python program that prints pi."})
 
 
-class MyProposePromptStrategyJa(ProposePromptStrategy):
-    prompt: BasePromptTemplate = Field(default_factory=get_propose_prompt_ja)
+if __name__ == "__main__":
+    test()
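
CodeInterpreterToT is essentially an adapter: it wraps the legacy ToTChain (a classic Chain keyed on problem_description) behind the LCEL Runnable interface, which is why session.py can swap it in where the supervisor runnable was invoked. Because it implements invoke, it also composes with other runnables; a hypothetical composition, not part of the commit, with llm assumed to be an already-configured model:

from langchain_core.runnables import RunnableLambda

tot = CodeInterpreterToT.get_runnable_tot_chain(llm=llm)
# Adapt a bare question string into the {"input": ...} dict the wrapper expects.
pipeline = RunnableLambda(lambda question: {"input": question}) | tot
print(pipeline.invoke("Print pi in Python."))
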
8 changes: 5 additions & 3 deletions src/codeinterpreterapi/thoughts/tot.py
@@ -31,7 +31,7 @@
 #######
 
 
-class MyChecker(ToTChecker):
+class MyToTChecker(ToTChecker):
     llm: Optional[BaseMemory] = None
     prompt: PromptTemplate = PromptTemplate(
         input_variables=["problem_description", "thoughts"],
@@ -121,8 +121,10 @@ def test_checker():
 #######
 
 
-def create_tot_chain_from_llm(llm):
-    checker = MyChecker()
+def create_tot_chain_from_llm(llm=None):
+    checker = MyToTChecker()
+    if llm is None:
+        llm = prepare_test_llm()
     checker.llm = llm
     tot_chain = ToTChain.from_llm(llm=llm, checker=checker, k=30, c=5, verbose=True, verbose_llm=False)
     return tot_chain
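
create_tot_chain_from_llm now falls back to prepare_test_llm() when no model is passed, which is what lets CodeInterpreterToT be constructed without arguments in the test above. In langchain_experimental's ToTChain, k caps the number of thought-generation rounds and c is the number of child thoughts explored per node, so k=30, c=5 buys a fairly wide search. A hedged usage sketch with an explicit model; ChatOpenAI is an assumption here, and any LangChain chat model should fit:

from langchain_openai import ChatOpenAI

from codeinterpreterapi.thoughts.tot import create_tot_chain_from_llm

# Build the checker-backed ToT chain around a real chat model.
chain = create_tot_chain_from_llm(llm=ChatOpenAI(model="gpt-4o-mini"))
result = chain.run(problem_description="Print pi in Python.")
print(result)
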
