Skip to content

Commit

Permalink
fix: add checker_support_prompt for MyToTChecker
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed May 28, 2024
1 parent 4f43dc8 commit ff96db8
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 59 deletions.
12 changes: 9 additions & 3 deletions src/codeinterpreterapi/thoughts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,17 @@ class MyToTChain(ToTChain):
tot_controller: ToTController = ToTController()
tot_strategy_class: Type[BaseThoughtGenerationStrategy] = ProposePromptStrategy
verbose_llm: bool = False
thought_generator: BaseThoughtGenerationStrategy = None

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
arbitrary_types_allowed = True

def initialize_thought_generator(self):
self.thought_generator = self.tot_strategy_class(llm=self.llm, c=self.c, verbose=self.verbose_llm)

@classmethod
def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:
"""
Expand Down Expand Up @@ -96,16 +100,18 @@ def _call(
if run_manager:
run_manager.on_text(text="Starting the ToT solve procedure.\n")

if self.thought_generator is None:
self.initialize_thought_generator()

problem_description = inputs["problem_description"]
checker_inputs = {"problem_description": problem_description}
thoughts_path: tuple[str, ...] = ()
thought_generator = self.tot_strategy_class(llm=self.llm, c=self.c, verbose=self.verbose_llm)

level = 0
for _ in range(self.k):
level = self.tot_memory.level
thought_text = thought_generator.next_thought(
problem_description, thoughts_path, callbacks=_run_manager.get_child()
thought_text = self.thought_generator.next_thought(
problem_description, thoughts_path, callbacks=_run_manager.get_child(), tot_checker=self.checker
)
checker_inputs["thoughts"] = thoughts_path + (thought_text,)
thought_validity = self.checker(checker_inputs, callbacks=_run_manager.get_child())["validity"]
Expand Down
237 changes: 189 additions & 48 deletions src/codeinterpreterapi/thoughts/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,76 +2,201 @@
from typing import Optional, Tuple

import spacy
from codeinterpreterapi.thoughts.thought_generation import (
MyProposePromptStrategy,
MyProposePromptStrategyJa,
MySampleCoTStrategy,
MySampleCoTStrategyJa,
)
from langchain.prompts import PromptTemplate
from langchain.schema import BaseMemory
from langchain_core.output_parsers import StrOutputParser
from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.thought import ThoughtValidity
from spacy import Language

sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1"
from codeinterpreterapi.thoughts.base import MyToTChain
from codeinterpreterapi.thoughts.thought_generation import (
MyProposePromptStrategy,
MyProposePromptStrategyJa,
MySampleCoTStrategy,
MySampleCoTStrategyJa,
)

sudoku_puzzle_sample = "3,x,x,x|1,x,3,x|x,1,x,3|4,x,x,1"
sudoku_puzzle = "3,x,x,x|1,x,3,x|x,1,x,3|4,x,x,1"
sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1"
sudoku_problem_description = f"""
### problem
次の数独パズルを解いてください。
sudoku_puzzle:
{sudoku_puzzle}
- This is a 4x4 Sudoku puzzle.
- The * represents a cell to be filled.
- The | character separates rows.
- At each step, replace one or more * with digits 1-4.
- |で区切られた4つの数字は、first row, second row, third row, 4th row をそれぞれ表します。
- 各rowとcolumnの4つの数字は重複していけません。例えば1,2,3,4や1,2,4,3はOKですが、1,2,2,3はNGです。
- さらに、2x2 のサブグリッド(全4個)でも4つの数字は重複してはいけません。
- Keep the known digits from previous valid thoughts in place.
- Each thought can be a partial or the final solution.
### description
これは有効な4x4の数独パズルを表す文字列です。この初期パズルは、数独のルールに完全に従っています。
パズルには1から4までの数字を使用します。
xは、空欄を表す特別な記号です。xを1から4のいずれかの数字で埋めてください。
パイプ文字(|)は行の区切りを表しています。
例えば、
1,2,3,4|2,3,4,1|3,4,1,2|4,1,2,3
は、以下の4x4の数独パズルを表しています。
1,2,3,4
2,3,4,1
3,4,1,2
4,1,2,3
パズルでは各行、各列、各2x2サブグリッドで数字が重複してはいけません。
ルールの詳細です。
a) 各行の|で区切られた4つの数字に、1から4の数字がそれぞれ一度だけ出現しなければなりません。
b) 各列でも、縦に見たときの4つの数字に、1から4の数字がそれぞれ一度だけ出現しなければなりません。
c) 4つの2x2サブグリッドのそれぞれにおいても、1から4の数字がそれぞれ一度だけ出現しなければなりません。
d) xは数字ではないため、重複してはいけない数字に含まれません。
e) xが1つも残ってない状態になればクリアです。
解答を進める際は、前の有効な解答で確定した数字の位置は変えないでください。
""".strip()


#######
# The following code implements an LLM-based checker
#######

checker_prompt = """
## Order
You are a Three-of-Thought (ToT) strategy correctness checker.
As an expert in this genre, you can make accurate judgments without missing even the slightest information.
class MyToTChecker(ToTChecker):
llm: Optional[BaseMemory] = None
prompt: PromptTemplate = PromptTemplate(
input_variables=["problem_description", "thoughts"],
template="""
次の Problem Description に示す問題を解決するための thoughts について、
[VALID_FINAL|VALID_INTERMEDIATE|INVALID]のどれか1つだけを選んで1行目に回答してください。
そして、選択した理由を2行目以降でできるだけ詳しく説明してください。
Based on the problem and thoughts, please judge the correctness of the latest thought (only one).
## Problem Description
### problem
{problem_description}
## Solution Attempt
The following is the Thought generated by the answering LLM.
{thoughts}
## Output
Evaluate whether this answer is correct or not and output in the following format.
<judgement>: this should be in first line. like <judgement>VALID_FINAL</judgement>
<explanation>
Choose one of the following for <judgement>:
- VALID_FINAL: The answer is completely correct, and the puzzle is solved.
- VALID_INTERMEDIATE: The answer is partially correct but not yet complete.
- INVALID: The answer violates the rules or has obvious errors.
In <explanation>, explain in detail the reason for choosing the judgement. Especially in the case of INVALID, point out specifically which rule is violated or where the error is.
When evaluating the answer, strictly follow the rules shown in the Problem Description.
Problem Description: 解決するべき問題です。
Please again, pay deep attention for details. Think step by step. Answer correctly.
"""

checker_prompt_ja = """
## Order
あなたはThree-of-Thought (ToT) 戦略の正誤チェッカーです。
問題と思考から、最新の思考(1つ)について正誤判断をしてください。
## Problem Description
{problem_description}
Thoughts: 解決の手続きについての思考です。
## Solution Attempt
以下は、回答LLMが生成したThoughts(思考列)です。
### IMPORTANT! please check this section ###
{thoughts}
### IMPORTANT! please check this section ###
## Support analysis result
以下は、サポートLLMが生成した解析結果です。
{support_result}
## Output
この解答が正しいかどうかを評価し、以下のフォーマットで出力してください。
<judgement>
<explanation>
<judgement>は以下のいずれかを選択してください。
- VALID_FINAL: 思考が完全に正しい場合、現時点で必要な思考がなされた場合。
- VALID_INTERMEDIATE: 思考は部分的に正しいが、まだ完全ではない場合。
- INVALID: 思考がルールに違反しているか、明らかな誤りがある場合。ループが発生した場合。
<explanation>では、judgementを選択した理由を詳しく説明してください。特に、INVALIDの場合は、具体的にどのルールに違反しているか、どこに誤りがあるかを指摘してください。
解答の評価を行う際は、Problem Descriptionで示されたルールに厳密に従ってください。
"""

checker_support_prompt_ja = """
## Order
あなたはThree-of-Thought (ToT) 戦略の正誤チェッカーのサポートツールです。
問題と思考から、最新の思考(1つ)を正誤判断するための明確な判断基準を最大10個の箇条書きで整理してください。
判断基準は「~であること。」や「~でないこと。」というOK/NGを明示する形式で書いてください。
「~を確認する」や「~を判断する」という形式では書かないでください。
## Problem Description
{problem_description}
## Solution Attempt
以下は、回答LLMが生成したThoughts(思考列)です。
{thoughts}
正しい回答を選ぶために下記のガイドラインに従ってください。
- VALID_FINAL: 最終的な thoughts が問題の解決に最適であると確認したときに選びます。
- VALID_INTERMEDIATE: 最終的な thoughts が問題の解決に適しているが、解決には思考をさらに進める必要があるときに選びます。
- INVALID: 最終的な thoughts が問題を解決できないとき、内容がルールに違反していたとき、明らな間違いがあり思考をやり直す必要があるときに選びます。
## Output
以下に判断基準を示します。
Evaluation:
""",
"""


class MyToTChecker(ToTChecker):
llm: Optional[BaseMemory] = None
prompt: PromptTemplate = PromptTemplate(
input_variables=["problem_description", "thoughts", "support_result"],
template=checker_prompt_ja,
)
prompt_support: PromptTemplate = PromptTemplate(
input_variables=["problem_description", "thoughts"],
template=checker_support_prompt_ja,
)
nlp: Language = spacy.load("en_core_web_md")

def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity:
thoughts = self.pre_evaluate(thoughts)
print("thoughts=", thoughts)
evaluation = self.prompt | self.llm | StrOutputParser()
llm_output = evaluation.invoke({"problem_description": problem_description, "thoughts": thoughts})

support_result = self.prompt_support | self.llm | StrOutputParser()
support_result = support_result.invoke({"problem_description": problem_description, "thoughts": thoughts})
print("support_result=", support_result)

evaluation = self.prompt | self.llm | StrOutputParser()
print(
"prompt_item=",
self.prompt.invoke(
{"problem_description": problem_description, "thoughts": thoughts, "support_result": support_result}
),
)
llm_output = evaluation.invoke(
{"problem_description": problem_description, "thoughts": thoughts, "support_result": support_result}
)
print("llm_output=", llm_output)
final_judge = self.judge_llm_output(llm_output)
print("final_judge=", final_judge)
return final_judge

def pre_evaluate(self, thoughts: Tuple[str, ...]) -> str:
thoughts_str = ""
for i, thought in enumerate(thoughts):
thoughts_str += f"thought_{i}: "
thoughts_str += thought
thoughts_str += "\n"
return thoughts_str

def judge_llm_output(self, llm_output) -> ThoughtValidity:
llm_output_1st_line = llm_output.split("\n")[0]
thought_validity_candidates = ["VALID_FINAL", "VALID_INTERMEDIATE", "INVALID"]
Expand Down Expand Up @@ -106,15 +231,15 @@ def test_checker():
tot_chain = create_tot_chain_from_llm(prepare_test_llm())
checker = tot_chain.checker
assert (
checker.evaluate(sudoku_problem_description, ("3,*,1,2|1,*,3,*|*,1,*,3|4,*,*,1",))
checker.evaluate(sudoku_problem_description, ("3,x,1,2|1,x,3,x|x,1,x,3|4,x,x,1",))
== ThoughtValidity.VALID_INTERMEDIATE
)
assert (
checker.evaluate(sudoku_problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",))
== ThoughtValidity.VALID_FINAL
)
assert (
checker.evaluate(sudoku_problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",))
checker.evaluate(sudoku_problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,x,1",))
== ThoughtValidity.VALID_INTERMEDIATE
)
assert checker.evaluate(sudoku_problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,2,3,1",)) == ThoughtValidity.INVALID
Expand Down Expand Up @@ -143,7 +268,7 @@ def create_tot_chain_from_llm(llm=None, is_ja=True, is_simple=False):
else:
tot_strategy_class = MyProposePromptStrategy

tot_chain = ToTChain.from_llm(
tot_chain = MyToTChain.from_llm(
llm=llm,
checker=checker,
k=20,
Expand All @@ -156,23 +281,39 @@ def create_tot_chain_from_llm(llm=None, is_ja=True, is_simple=False):


def prepare_test_llm():
from langchain_google_genai import ChatGoogleGenerativeAI # type: ignore

model = "gemini-1.5-flash-latest"
llm = ChatGoogleGenerativeAI(
model=model,
temperature=0,
google_api_key=os.environ.get("GEMINI_API_KEY"),
max_output_tokens=1024 * 4,
)
# model = "gemini-1.5-flash-latest"
# model = "gemini-1.5-pro-latest"
model = "gemini-1.0-pro"
# model = "claude-3-haiku-20240307"
# model = "claude-3-sonnet-20240229"
if "gemini" in model:
from langchain_google_genai import ChatGoogleGenerativeAI # type: ignore

llm = ChatGoogleGenerativeAI(
model=model,
temperature=0.1,
google_api_key=os.environ.get("GEMINI_API_KEY"),
max_output_tokens=1024 * 4,
)
else:
from langchain_anthropic import ChatAnthropic # type: ignore

llm = ChatAnthropic(
model_name=model,
temperature=0.1,
anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY"),
max_tokens=1024 * 4,
)

return llm


def test_create():
tot_chain = create_tot_chain_from_llm(llm=prepare_test_llm(), is_simple=True)
tot_chain = create_tot_chain_from_llm(llm=prepare_test_llm(), is_simple=False)
tot_chain.run(problem_description=sudoku_problem_description)


if __name__ == "__main__":
test_checker()
test_create()
test_create()
12 changes: 6 additions & 6 deletions src/codeinterpreterapi/thoughts/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def get_propose_prompt_ja() -> PromptTemplate:
"""
あなたは、Three-of-Thought (ToT) で思考を生成するインテリジェントなエージェントです。
与えられた PROBLEM から思考を生成して出力してください。
問題と、解決のための思考プロセス(VALID THOUGHTS)が確定しています。
VALID THOUGHTS に続く直近の思考候補を {{ n }} 種類 出力してください。
出力は、先頭に "```json"、末尾に "```" を含む、JSON形式の文字列リストとしてマークダウンコードスニペットにフォーマットしてください。
出力例を示します。
Expand All @@ -128,7 +129,7 @@ def get_propose_prompt_ja() -> PromptTemplate:
]
```
問題と作業指示は以下の通りです
問題は以下の通りです
PROBLEM
{{ problem_description }}
Expand All @@ -140,16 +141,15 @@ def get_propose_prompt_ja() -> PromptTemplate:
{{ thought }}
{% endfor %}
上記の思考を参考にして、次の {{ n }} 個の思考を出力してください
上記の思考を参考にして、次の {{ n }} 種類の思考を出力してください
{% else %}
上記の PROBLEM を参考にして、次の {{ n }} 個の思考を出力してください
上記の PROBLEM を参考にして、次の {{ n }} 種類の思考を出力してください
{%- endif -%}
ガイドライン:
- 簡単な問題には最終的な回答1回で思考を完結してください。
- 複雑な問題には一度に {{ n }} 個のステップバイステップで連続した思考をしてください。
- 簡単な問題などで候補が完全に同じになってもかまいません。
- 思考を生成するために、問題と前回の思考を注意深く分析してください。
- 思考は、問題解決に向けた明確なステップや洞察を提供するものでなければなりません。
- 思考がループしている場合は、直ちに気づいて修正してください。あなたは頻繁に思考ループに陥る傾向があります。
Expand Down
Loading

0 comments on commit ff96db8

Please sign in to comment.