
Commit

fix: update thoughts prompt for is_ja=True
nobu007 committed May 26, 2024
1 parent 8b776d9 commit 77c5be5
Showing 4 changed files with 39 additions and 105 deletions.
src/codeinterpreterapi/session.py (2 changes: 1 addition & 1 deletion)
@@ -131,7 +131,7 @@ def initialize_supervisor(self):
         )
 
     def initialize_thought(self):
-        self.thought = CodeInterpreterToT.get_runnable_tot_chain(llm=self.llm)
+        self.thought = CodeInterpreterToT.get_runnable_tot_chain(llm=self.llm, is_ja=self.is_ja, is_simple=False)
 
     def start(self) -> SessionStatus:
         print("start")
src/codeinterpreterapi/thoughts/prompts.py (13 changes: 5 additions & 8 deletions)
@@ -131,23 +131,20 @@ def get_propose_prompt_ja() -> PromptTemplate:
 {{ problem_description }}
 {% if thoughts %}
-VALID THOUGHTS
+VALID THOUGHTS(思考)
 {% for thought in thoughts %}
 {{ thought }}
 {% endfor %}
-Possible next {{ n }} valid thoughts based on the last valid thought:
+上記の思考を参考にして、次の {{ n }} 個の思考を出力してください。
 {% else %}
-Possible next {{ n }} valid thoughts based on the PROBLEM:
+上記の PROBLEM を参考にして、次の {{ n }} 個の思考を出力してください。
 {%- endif -%}
 次の思考を生成するために、問題と有効な思考を注意深く分析してください。
 思考は、問題解決に向けた明確なステップや洞察を提供するものでなければなりません。
 各思考は簡潔にまとめ、問題に直接関連するようにしてください。
 思考の質を向上させるために、必要に応じて問題をさらに分析し、追加の情報を検討してください。
 生成された思考のリストを、指定されたJSON形式で出力してください。
 """
     ).strip(),
 )
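The template mixes jinja2 control blocks with literal text, so the rendered prompt lists prior valid thoughts when any exist and otherwise falls back to the PROBLEM-only wording. A sketch of how such a jinja2 PromptTemplate renders, using a cut-down stand-in for the body above (assumes the jinja2 package is installed):

```python
from langchain.prompts import PromptTemplate

# Cut-down stand-in for the propose-prompt body shown in the diff.
template = (
    "{{ problem_description }}\n"
    "{% if thoughts %}"
    "VALID THOUGHTS(思考)\n"
    "{% for thought in thoughts %}{{ thought }}\n{% endfor %}"
    "上記の思考を参考にして、次の {{ n }} 個の思考を出力してください。"
    "{% else %}"
    "上記の PROBLEM を参考にして、次の {{ n }} 個の思考を出力してください。"
    "{% endif %}"
)
prompt = PromptTemplate.from_template(template, template_format="jinja2")

# No thoughts yet: the {% else %} branch renders.
print(prompt.format(problem_description="PROBLEM: ...", thoughts=[], n=3))
# With prior thoughts: they are listed before asking for the next n.
print(prompt.format(problem_description="PROBLEM: ...", thoughts=["思考1"], n=3))
```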
src/codeinterpreterapi/thoughts/thoughts.py (100 changes: 6 additions & 94 deletions)
@@ -1,24 +1,19 @@
 # LCEL version of.
 # https://github.com/langchain-ai/langchain/blob/master/libs/experimental/langchain_experimental/tot/thought_generation.py
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
 
 from langchain.prompts.base import BasePromptTemplate
 from langchain_core.pydantic_v1 import Field
-from langchain_core.runnables import RunnableSequence, RunnableSerializable
+from langchain_core.runnables import RunnableSerializable
 from langchain_core.runnables.config import RunnableConfig
 from langchain_core.runnables.utils import Input, Output
 from langchain_experimental.tot.base import ToTChain
 from langchain_experimental.tot.prompts import get_cot_prompt, get_propose_prompt
 
 from codeinterpreterapi.thoughts.tot import create_tot_chain_from_llm
 
 
 class CodeInterpreterToT(RunnableSerializable):
     tot_chain: ToTChain = None
 
-    def __init__(self, llm=None):
+    def __init__(self, llm=None, is_ja=True, is_simple=False):
         super().__init__()
-        self.tot_chain = create_tot_chain_from_llm(llm=llm)
+        self.tot_chain = create_tot_chain_from_llm(llm=llm, is_ja=is_ja, is_simple=is_simple)
 
     def run(self, input: Input):
         problem_description = input["input"]
@@ -47,95 +42,12 @@ async def abatch(
         raise NotImplementedError("Async not implemented yet")
 
     @classmethod
-    def get_runnable_tot_chain(
-        cls,
-        llm=None,
-    ):
+    def get_runnable_tot_chain(cls, llm=None, is_ja=True, is_simple=False):
         # ToTChainのインスタンスを作成
-        tot_chain = cls(
-            llm=llm,
-        )
+        tot_chain = cls(llm=llm, is_ja=is_ja, is_simple=is_simple)
         return tot_chain


-class BaseThoughtGenerationStrategyRunnableSequence(RunnableSequence):
-    """
-    Base class for a thought generation strategy.
-    """
-
-    c: int = 3
-    """The number of children thoughts to propose at each step."""
-
-    def next_thought(
-        self,
-        problem_description: str,
-        thoughts_path: Tuple[str, ...] = (),
-        **kwargs: Any,
-    ) -> str:
-        """
-        Generate the next thought given the problem description and the thoughts
-        generated so far.
-        """
-        return ""
-
-
-class SampleCoTStrategyRunnableSequence(BaseThoughtGenerationStrategyRunnableSequence):
-    """
-    Sample strategy from a Chain-of-Thought (CoT) prompt.
-    This strategy works better when the thought space is rich, such as when each
-    thought is a paragraph. Independent and identically distributed samples
-    lead to diversity, which helps to avoid repetition.
-    """
-
-    prompt: BasePromptTemplate = Field(default_factory=get_cot_prompt)
-
-    def next_thought(
-        self,
-        problem_description: str,
-        thoughts_path: Tuple[str, ...] = (),
-        **kwargs: Any,
-    ) -> str:
-        response_text = self.predict_and_parse(
-            problem_description=problem_description, thoughts=thoughts_path, **kwargs
-        )
-        return response_text if isinstance(response_text, str) else ""
-
-
-class ProposePromptStrategyRunnableSequence(SampleCoTStrategyRunnableSequence):
-    """
-    Strategy that is sequentially using a "propose prompt".
-    This strategy works better when the thought space is more constrained, such
-    as when each thought is just a word or a line. Proposing different thoughts
-    in the same prompt completion helps to avoid duplication.
-    """
-
-    prompt: BasePromptTemplate = Field(default_factory=get_propose_prompt)
-    tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict)
-
-    def next_thought(
-        self,
-        problem_description: str,
-        thoughts_path: Tuple[str, ...] = (),
-        **kwargs: Any,
-    ) -> str:
-        if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]:
-            new_thoughts = self.invoke(
-                problem_description=problem_description,
-                thoughts=thoughts_path,
-                n=self.c,
-                **kwargs,
-            )
-            if not new_thoughts:
-                return ""
-            if isinstance(new_thoughts, list):
-                self.tot_memory[thoughts_path] = new_thoughts[::-1]
-            else:
-                return ""
-        return self.tot_memory[thoughts_path].pop()


 def test():
     tot_chain = CodeInterpreterToT.get_runnable_tot_chain()
     tot_chain.invoke({"input": "pythonで円周率を表示するプログラムを実行してください。"})
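With the multi-line signature collapsed, callers select language and strategy in one call. A hedged usage sketch; `my_llm` is a placeholder for any LangChain-compatible model, not a name from the diff:

```python
from codeinterpreterapi.thoughts.thoughts import CodeInterpreterToT

# Defaults now select the Japanese propose-prompt strategy.
tot_chain = CodeInterpreterToT.get_runnable_tot_chain(llm=my_llm)

# English prompts with the simpler CoT sampling strategy instead.
tot_chain_en = CodeInterpreterToT.get_runnable_tot_chain(
    llm=my_llm, is_ja=False, is_simple=True
)

tot_chain.invoke({"input": "Run a Python program that prints pi."})
```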
src/codeinterpreterapi/thoughts/tot.py (29 changes: 27 additions & 2 deletions)
@@ -2,6 +2,12 @@
 from typing import Optional, Tuple
 
 import spacy
+from codeinterpreterapi.thoughts.thought_generation import (
+    MyProposePromptStrategy,
+    MyProposePromptStrategyJa,
+    MySampleCoTStrategy,
+    MySampleCoTStrategyJa,
+)
 from langchain.prompts import PromptTemplate
 from langchain.schema import BaseMemory
 from langchain_core.output_parsers import StrOutputParser
@@ -121,12 +127,31 @@ def test_checker():
 #######
 
 
-def create_tot_chain_from_llm(llm=None):
+def create_tot_chain_from_llm(llm=None, is_ja=True, is_simple=False):
     checker = MyToTChecker()
     if llm is None:
         llm = prepare_test_llm()
     checker.llm = llm
-    tot_chain = ToTChain.from_llm(llm=llm, checker=checker, k=30, c=5, verbose=True, verbose_llm=False)
+    if is_ja:
+        if is_simple:
+            tot_strategy_class = MySampleCoTStrategyJa
+        else:
+            tot_strategy_class = MyProposePromptStrategyJa
+    else:
+        if is_simple:
+            tot_strategy_class = MySampleCoTStrategy
+        else:
+            tot_strategy_class = MyProposePromptStrategy
+
+    tot_chain = ToTChain.from_llm(
+        llm=llm,
+        checker=checker,
+        k=20,
+        c=3,
+        verbose=True,
+        tot_strategy_class=tot_strategy_class,
+        verbose_llm=False,
+    )
     return tot_chain
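The nested if/else maps the (is_ja, is_simple) flag pair onto one of four strategy classes. For reference, the same dispatch can be written as a lookup table; this is a sketch of an equivalent formulation using the classes imported above, not code from the commit:

```python
# (is_ja, is_simple) -> strategy class; equivalent to the nested
# if/else inside create_tot_chain_from_llm above.
STRATEGY_BY_FLAGS = {
    (True, True): MySampleCoTStrategyJa,
    (True, False): MyProposePromptStrategyJa,
    (False, True): MySampleCoTStrategy,
    (False, False): MyProposePromptStrategy,
}

tot_strategy_class = STRATEGY_BY_FLAGS[(is_ja, is_simple)]
```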


