Skip to content

Commit

Permalink
fix: add checker_support_prompt for MyToTChecker
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed May 28, 2024
1 parent 4f43dc8 commit 16a30a6
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 65 deletions.
12 changes: 9 additions & 3 deletions src/codeinterpreterapi/thoughts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,17 @@ class MyToTChain(ToTChain):
tot_controller: ToTController = ToTController()
tot_strategy_class: Type[BaseThoughtGenerationStrategy] = ProposePromptStrategy
verbose_llm: bool = False
thought_generator: BaseThoughtGenerationStrategy = None

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
arbitrary_types_allowed = True

def initialize_thought_generator(self):
self.thought_generator = self.tot_strategy_class(llm=self.llm, c=self.c, verbose=self.verbose_llm)

@classmethod
def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:
"""
Expand Down Expand Up @@ -96,16 +100,18 @@ def _call(
if run_manager:
run_manager.on_text(text="Starting the ToT solve procedure.\n")

if self.thought_generator is None:
self.initialize_thought_generator()

problem_description = inputs["problem_description"]
checker_inputs = {"problem_description": problem_description}
thoughts_path: tuple[str, ...] = ()
thought_generator = self.tot_strategy_class(llm=self.llm, c=self.c, verbose=self.verbose_llm)

level = 0
for _ in range(self.k):
level = self.tot_memory.level
thought_text = thought_generator.next_thought(
problem_description, thoughts_path, callbacks=_run_manager.get_child()
thought_text = self.thought_generator.next_thought(
problem_description, thoughts_path, callbacks=_run_manager.get_child(), tot_checker=self.checker
)
checker_inputs["thoughts"] = thoughts_path + (thought_text,)
thought_validity = self.checker(checker_inputs, callbacks=_run_manager.get_child())["validity"]
Expand Down
248 changes: 194 additions & 54 deletions src/codeinterpreterapi/thoughts/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,76 +2,201 @@
from typing import Optional, Tuple

import spacy
from codeinterpreterapi.thoughts.thought_generation import (
MyProposePromptStrategy,
MyProposePromptStrategyJa,
MySampleCoTStrategy,
MySampleCoTStrategyJa,
)
from langchain.prompts import PromptTemplate
from langchain.schema import BaseMemory
from langchain_core.output_parsers import StrOutputParser
from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.thought import ThoughtValidity
from spacy import Language

sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1"
from codeinterpreterapi.thoughts.base import MyToTChain
from codeinterpreterapi.thoughts.thought_generation import (
MyProposePromptStrategy,
MyProposePromptStrategyJa,
MySampleCoTStrategy,
MySampleCoTStrategyJa,
)

sudoku_puzzle_sample = "3,x,x,x|1,x,3,x|x,1,x,3|4,x,x,1"
sudoku_puzzle = "3,x,x,x|1,x,3,x|x,1,x,3|4,x,x,1"
sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1"
sudoku_problem_description = f"""
### problem
次の数独パズルを解いてください。
sudoku_puzzle:
{sudoku_puzzle}
- This is a 4x4 Sudoku puzzle.
- The * represents a cell to be filled.
- The | character separates rows.
- At each step, replace one or more * with digits 1-4.
- |で区切られた4つの数字は、first row, second row, third row, 4th row をそれぞれ表します。
- 各rowとcolumnの4つの数字は重複していけません。例えば1,2,3,4や1,2,4,3はOKですが、1,2,2,3はNGです。
- さらに、2x2 のサブグリッド(全4個)でも4つの数字は重複してはいけません。
- Keep the known digits from previous valid thoughts in place.
- Each thought can be a partial or the final solution.
### description
これは有効な4x4の数独パズルを表す文字列です。この初期パズルは、数独のルールに完全に従っています。
パズルには1から4までの数字を使用します。
xは、空欄を表す特別な記号です。xを1から4のいずれかの数字で埋めてください。
パイプ文字(|)は行の区切りを表しています。
例えば、
1,2,3,4|2,3,4,1|3,4,1,2|4,1,2,3
は、以下の4x4の数独パズルを表しています。
1,2,3,4
2,3,4,1
3,4,1,2
4,1,2,3
パズルでは各行、各列、各2x2サブグリッドで数字が重複してはいけません。
ルールの詳細です。
a) 各行の|で区切られた4つの数字に、1から4の数字がそれぞれ一度だけ出現しなければなりません。
b) 各列でも、縦に見たときの4つの数字に、1から4の数字がそれぞれ一度だけ出現しなければなりません。
c) 4つの2x2サブグリッドのそれぞれにおいても、1から4の数字がそれぞれ一度だけ出現しなければなりません。
d) xは数字ではないため、重複してはいけない数字に含まれません。
e) xが1つも残ってない状態になればクリアです。
解答を進める際は、前の有効な解答で確定した数字の位置は変えないでください。
""".strip()


#######
# The following code implements an LLM-based checker
#######

checker_prompt = """
## Order
You are a Three-of-Thought (ToT) strategy correctness checker.
As an expert in this genre, you can make accurate judgments without missing even the slightest information.
class MyToTChecker(ToTChecker):
llm: Optional[BaseMemory] = None
prompt: PromptTemplate = PromptTemplate(
input_variables=["problem_description", "thoughts"],
template="""
次の Problem Description に示す問題を解決するための thoughts について、
[VALID_FINAL|VALID_INTERMEDIATE|INVALID]のどれか1つだけを選んで1行目に回答してください。
そして、選択した理由を2行目以降でできるだけ詳しく説明してください。
Based on the problem and thoughts, please judge the correctness of the latest thought (only one).
## Problem Description
### problem
{problem_description}
## Solution Attempt
The following is the Thought generated by the answering LLM.
{thoughts}
## Output
Evaluate whether this answer is correct or not and output in the following format.
<judgement>: this should be in first line. like <judgement>VALID_FINAL</judgement>
<explanation>
Choose one of the following for <judgement>:
- VALID_FINAL: The answer is completely correct, and the puzzle is solved.
- VALID_INTERMEDIATE: The answer is partially correct but not yet complete.
- INVALID: The answer violates the rules or has obvious errors.
In <explanation>, explain in detail the reason for choosing the judgement. Especially in the case of INVALID, point out specifically which rule is violated or where the error is.
Problem Description: 解決するべき問題です。
When evaluating the answer, strictly follow the rules shown in the Problem Description.
Please again, pay deep attention for details. Think step by step. Answer correctly.
"""

checker_prompt_ja = """
## Order
あなたはThree-of-Thought (ToT) 戦略の正誤チェッカーです。
問題と思考から、最新の思考(1つ)について正誤判断をしてください。
## Problem Description
{problem_description}
## Solution Attempt
以下は、回答LLMが生成したThoughts(思考列)です。
### IMPORTANT! please check this section ###
{thoughts}
### IMPORTANT! please check this section ###
## Support analysis result
以下は、サポートLLMが生成した解析結果です。
{support_result}
## Output
この解答が正しいかどうかを評価し、以下のフォーマットで出力してください。
<judgement>
<explanation>
<judgement>は以下のいずれかを選択してください。
- VALID_FINAL: 思考が完全に正しい場合、現時点で必要な思考がなされた場合。
- VALID_INTERMEDIATE: 思考は部分的に正しいが、まだ完全ではない場合。
- INVALID: 思考がルールに違反しているか、明らかな誤りがある場合。ループが発生した場合。
<explanation>では、judgementを選択した理由を詳しく説明してください。特に、INVALIDの場合は、具体的にどのルールに違反しているか、どこに誤りがあるかを指摘してください。
解答の評価を行う際は、Problem Descriptionで示されたルールに厳密に従ってください。
"""

checker_support_prompt_ja = """
## Order
あなたはThree-of-Thought (ToT) 戦略の正誤チェッカーのサポートツールです。
問題と思考から、最新の思考(1つ)を正誤判断するための明確な判断基準を最大10個の箇条書きで整理してください。
判断基準は「~であること。」や「~でないこと。」というOK/NGを明示する形式で書いてください。
「~を確認する」や「~を判断する」という形式では書かないでください。
## Problem Description
{problem_description}
Thoughts: 解決の手続きについての思考です。
## Solution Attempt
以下は、回答LLMが生成したThoughts(思考列)です。
{thoughts}
正しい回答を選ぶために下記のガイドラインに従ってください。
- VALID_FINAL: 最終的な thoughts が問題の解決に最適であると確認したときに選びます。
- VALID_INTERMEDIATE: 最終的な thoughts が問題の解決に適しているが、解決には思考をさらに進める必要があるときに選びます。
- INVALID: 最終的な thoughts が問題を解決できないとき、内容がルールに違反していたとき、明らな間違いがあり思考をやり直す必要があるときに選びます。
## Output
以下に判断基準を示します。
"""


Evaluation:
""",
class MyToTChecker(ToTChecker):
llm: Optional[BaseMemory] = None
prompt: PromptTemplate = PromptTemplate(
input_variables=["problem_description", "thoughts", "support_result"],
template=checker_prompt_ja,
)
prompt_support: PromptTemplate = PromptTemplate(
input_variables=["problem_description", "thoughts"],
template=checker_support_prompt_ja,
)
nlp: Language = spacy.load("en_core_web_md")

def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity:
print("thoughts=", thoughts)
evaluation = self.prompt | self.llm | StrOutputParser()
llm_output = evaluation.invoke({"problem_description": problem_description, "thoughts": thoughts})
thoughts = self.pre_evaluate(thoughts)
# print("MyToTChecker thoughts=", thoughts)

print("llm_output=", llm_output)
final_judge = self.judge_llm_output(llm_output)
print("final_judge=", final_judge)
support_result = self.prompt_support | self.llm | StrOutputParser()
support_result = support_result.invoke({"problem_description": problem_description, "thoughts": thoughts})
# print("MyToTChecker support_result=", support_result)

evaluation = self.prompt | self.llm | StrOutputParser()
# print(
# "prompt_item=",
# self.prompt.invoke(
# {"problem_description": problem_description, "thoughts": thoughts, "support_result": support_result}
# ),
# )
evaluation_output = evaluation.invoke(
{"problem_description": problem_description, "thoughts": thoughts, "support_result": support_result}
)
print("MyToTChecker evaluation_output=", evaluation_output)
final_judge = self.judge_llm_output(evaluation_output)
print("MyToTChecker final_judge=", final_judge)
return final_judge

def pre_evaluate(self, thoughts: Tuple[str, ...]) -> str:
thoughts_str = ""
for i, thought in enumerate(thoughts):
thoughts_str += f"thought_{i}: "
thoughts_str += thought
thoughts_str += "\n"
return thoughts_str

def judge_llm_output(self, llm_output) -> ThoughtValidity:
llm_output_1st_line = llm_output.split("\n")[0]
thought_validity_candidates = ["VALID_FINAL", "VALID_INTERMEDIATE", "INVALID"]
Expand All @@ -83,11 +208,11 @@ def judge_llm_output(self, llm_output) -> ThoughtValidity:
actual = self.nlp(llm_output_1st_line)
options_nlp = ["FINAL", "INTERMEDIATE", "INVALID"]
similarities = [actual.similarity(self.nlp(option)) for option in options_nlp]
print("similarities=", similarities)
print("MyToTChecker similarities=", similarities)
best_match_index = similarities.index(max(similarities))
best_match = thought_validity_candidates[best_match_index]

print(f"Best match: {best_match} with similarity {similarities[best_match_index]}")
print(f"MyToTChecker Best match: {best_match} with similarity {similarities[best_match_index]}")
return self.get_thought_validity(best_match)

def get_thought_validity(self, thought_validity) -> ThoughtValidity:
Expand All @@ -106,15 +231,15 @@ def test_checker():
tot_chain = create_tot_chain_from_llm(prepare_test_llm())
checker = tot_chain.checker
assert (
checker.evaluate(sudoku_problem_description, ("3,*,1,2|1,*,3,*|*,1,*,3|4,*,*,1",))
checker.evaluate(sudoku_problem_description, ("3,x,1,2|1,x,3,x|x,1,x,3|4,x,x,1",))
== ThoughtValidity.VALID_INTERMEDIATE
)
assert (
checker.evaluate(sudoku_problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",))
== ThoughtValidity.VALID_FINAL
)
assert (
checker.evaluate(sudoku_problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",))
checker.evaluate(sudoku_problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,x,1",))
== ThoughtValidity.VALID_INTERMEDIATE
)
assert checker.evaluate(sudoku_problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,2,3,1",)) == ThoughtValidity.INVALID
Expand Down Expand Up @@ -143,7 +268,7 @@ def create_tot_chain_from_llm(llm=None, is_ja=True, is_simple=False):
else:
tot_strategy_class = MyProposePromptStrategy

tot_chain = ToTChain.from_llm(
tot_chain = MyToTChain.from_llm(
llm=llm,
checker=checker,
k=20,
Expand All @@ -156,20 +281,35 @@ def create_tot_chain_from_llm(llm=None, is_ja=True, is_simple=False):


def prepare_test_llm():
from langchain_google_genai import ChatGoogleGenerativeAI # type: ignore

model = "gemini-1.5-flash-latest"
llm = ChatGoogleGenerativeAI(
model=model,
temperature=0,
google_api_key=os.environ.get("GEMINI_API_KEY"),
max_output_tokens=1024 * 4,
)
# model = "gemini-1.5-flash-latest"
# model = "gemini-1.5-pro-latest"
model = "gemini-1.0-pro"
# model = "claude-3-haiku-20240307"
# model = "claude-3-sonnet-20240229"
if "gemini" in model:
from langchain_google_genai import ChatGoogleGenerativeAI # type: ignore

llm = ChatGoogleGenerativeAI(
model=model,
temperature=0.1,
google_api_key=os.environ.get("GEMINI_API_KEY"),
max_output_tokens=1024 * 4,
)
else:
from langchain_anthropic import ChatAnthropic # type: ignore

llm = ChatAnthropic(
model_name=model,
temperature=0.1,
anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY"),
max_tokens=1024 * 4,
)

return llm


def test_create():
tot_chain = create_tot_chain_from_llm(llm=prepare_test_llm(), is_simple=True)
tot_chain = create_tot_chain_from_llm(llm=prepare_test_llm(), is_simple=False)
tot_chain.run(problem_description=sudoku_problem_description)


Expand Down
Loading

0 comments on commit 16a30a6

Please sign in to comment.