From 77c5be5e53531fb2f091b6ebbe730f8a16c830a1 Mon Sep 17 00:00:00 2001 From: jinno Date: Sun, 26 May 2024 13:53:06 +0900 Subject: [PATCH] fix: update thoughts prompt for is_ja=True --- src/codeinterpreterapi/session.py | 2 +- src/codeinterpreterapi/thoughts/prompts.py | 13 +-- src/codeinterpreterapi/thoughts/thoughts.py | 100 ++------------------ src/codeinterpreterapi/thoughts/tot.py | 29 +++++- 4 files changed, 39 insertions(+), 105 deletions(-) diff --git a/src/codeinterpreterapi/session.py b/src/codeinterpreterapi/session.py index 3424b429..d9757807 100644 --- a/src/codeinterpreterapi/session.py +++ b/src/codeinterpreterapi/session.py @@ -131,7 +131,7 @@ def initialize_supervisor(self): ) def initialize_thought(self): - self.thought = CodeInterpreterToT.get_runnable_tot_chain(llm=self.llm) + self.thought = CodeInterpreterToT.get_runnable_tot_chain(llm=self.llm, is_ja=self.is_ja, is_simple=False) def start(self) -> SessionStatus: print("start") diff --git a/src/codeinterpreterapi/thoughts/prompts.py b/src/codeinterpreterapi/thoughts/prompts.py index cfbcbce0..1b4a4d43 100644 --- a/src/codeinterpreterapi/thoughts/prompts.py +++ b/src/codeinterpreterapi/thoughts/prompts.py @@ -131,23 +131,20 @@ def get_propose_prompt_ja() -> PromptTemplate: {{ problem_description }} {% if thoughts %} - VALID THOUGHTS + VALID THOUGHTS(思考) {% for thought in thoughts %} {{ thought }} {% endfor %} - Possible next {{ n }} valid thoughts based on the last valid thought: + 上記の思考を参考にして、次の {{ n }} 個の思考を出力してください。 {% else %} - Possible next {{ n }} valid thoughts based on the PROBLEM: + 上記の PROBLEM を参考にして、次の {{ n }} 個の思考を出力してください。 + {%- endif -%} -次の思考を生成するために、問題と有効な思考を注意深く分析してください。 -思考は、問題解決に向けた明確なステップや洞察を提供するものでなければなりません。 -各思考は簡潔にまとめ、問題に直接関連するようにしてください。 -思考の質を向上させるために、必要に応じて問題をさらに分析し、追加の情報を検討してください。 -生成された思考のリストを、指定されたJSON形式で出力してください。 + """ ).strip(), ) diff --git a/src/codeinterpreterapi/thoughts/thoughts.py b/src/codeinterpreterapi/thoughts/thoughts.py index 1f7f43cf..4d1ed901 100644 --- a/src/codeinterpreterapi/thoughts/thoughts.py +++ b/src/codeinterpreterapi/thoughts/thoughts.py @@ -1,14 +1,9 @@ -# LCEL version of. -# https://github.com/langchain-ai/langchain/blob/master/libs/experimental/langchain_experimental/tot/thought_generation.py -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union -from langchain.prompts.base import BasePromptTemplate -from langchain_core.pydantic_v1 import Field -from langchain_core.runnables import RunnableSequence, RunnableSerializable +from langchain_core.runnables import RunnableSerializable from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.utils import Input, Output from langchain_experimental.tot.base import ToTChain -from langchain_experimental.tot.prompts import get_cot_prompt, get_propose_prompt from codeinterpreterapi.thoughts.tot import create_tot_chain_from_llm @@ -16,9 +11,9 @@ class CodeInterpreterToT(RunnableSerializable): tot_chain: ToTChain = None - def __init__(self, llm=None): + def __init__(self, llm=None, is_ja=True, is_simple=False): super().__init__() - self.tot_chain = create_tot_chain_from_llm(llm=llm) + self.tot_chain = create_tot_chain_from_llm(llm=llm, is_ja=is_ja, is_simple=is_simple) def run(self, input: Input): problem_description = input["input"] @@ -47,95 +42,12 @@ async def abatch( raise NotImplementedError("Async not implemented yet") @classmethod - def get_runnable_tot_chain( - cls, - llm=None, - ): + def get_runnable_tot_chain(cls, llm=None, is_ja=True, is_simple=False): # ToTChainのインスタンスを作成 - tot_chain = cls( - llm=llm, - ) + tot_chain = cls(llm=llm, is_ja=is_ja, is_simple=is_simple) return tot_chain -class BaseThoughtGenerationStrategyRunnableSequence(RunnableSequence): - """ - Base class for a thought generation strategy. - """ - - c: int = 3 - """The number of children thoughts to propose at each step.""" - - def next_thought( - self, - problem_description: str, - thoughts_path: Tuple[str, ...] = (), - **kwargs: Any, - ) -> str: - """ - Generate the next thought given the problem description and the thoughts - generated so far. - """ - return "" - - -class SampleCoTStrategyRunnableSequence(BaseThoughtGenerationStrategyRunnableSequence): - """ - Sample strategy from a Chain-of-Thought (CoT) prompt. - - This strategy works better when the thought space is rich, such as when each - thought is a paragraph. Independent and identically distributed samples - lead to diversity, which helps to avoid repetition. - """ - - prompt: BasePromptTemplate = Field(default_factory=get_cot_prompt) - - def next_thought( - self, - problem_description: str, - thoughts_path: Tuple[str, ...] = (), - **kwargs: Any, - ) -> str: - response_text = self.predict_and_parse( - problem_description=problem_description, thoughts=thoughts_path, **kwargs - ) - return response_text if isinstance(response_text, str) else "" - - -class ProposePromptStrategyRunnableSequence(SampleCoTStrategyRunnableSequence): - """ - Strategy that is sequentially using a "propose prompt". - - This strategy works better when the thought space is more constrained, such - as when each thought is just a word or a line. Proposing different thoughts - in the same prompt completion helps to avoid duplication. - """ - - prompt: BasePromptTemplate = Field(default_factory=get_propose_prompt) - tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict) - - def next_thought( - self, - problem_description: str, - thoughts_path: Tuple[str, ...] = (), - **kwargs: Any, - ) -> str: - if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]: - new_thoughts = self.invoke( - problem_description=problem_description, - thoughts=thoughts_path, - n=self.c, - **kwargs, - ) - if not new_thoughts: - return "" - if isinstance(new_thoughts, list): - self.tot_memory[thoughts_path] = new_thoughts[::-1] - else: - return "" - return self.tot_memory[thoughts_path].pop() - - def test(): tot_chain = CodeInterpreterToT.get_runnable_tot_chain() tot_chain.invoke({"input": "pythonで円周率を表示するプログラムを実行してください。"}) diff --git a/src/codeinterpreterapi/thoughts/tot.py b/src/codeinterpreterapi/thoughts/tot.py index 1fedaad5..f1777e59 100644 --- a/src/codeinterpreterapi/thoughts/tot.py +++ b/src/codeinterpreterapi/thoughts/tot.py @@ -2,6 +2,12 @@ from typing import Optional, Tuple import spacy +from codeinterpreterapi.thoughts.thought_generation import ( + MyProposePromptStrategy, + MyProposePromptStrategyJa, + MySampleCoTStrategy, + MySampleCoTStrategyJa, +) from langchain.prompts import PromptTemplate from langchain.schema import BaseMemory from langchain_core.output_parsers import StrOutputParser @@ -121,12 +127,31 @@ def test_checker(): ####### -def create_tot_chain_from_llm(llm=None): +def create_tot_chain_from_llm(llm=None, is_ja=True, is_simple=False): checker = MyToTChecker() if llm is None: llm = prepare_test_llm() checker.llm = llm - tot_chain = ToTChain.from_llm(llm=llm, checker=checker, k=30, c=5, verbose=True, verbose_llm=False) + if is_ja: + if is_simple: + tot_strategy_class = MySampleCoTStrategyJa + else: + tot_strategy_class = MyProposePromptStrategyJa + else: + if is_simple: + tot_strategy_class = MySampleCoTStrategy + else: + tot_strategy_class = MyProposePromptStrategy + + tot_chain = ToTChain.from_llm( + llm=llm, + checker=checker, + k=20, + c=3, + verbose=True, + tot_strategy_class=tot_strategy_class, + verbose_llm=False, + ) return tot_chain