Skip to content

Commit

Permalink
fix: update thoughts/tot.py for spacy judge logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed May 25, 2024
1 parent 2c2b8d4 commit 01b35d2
Showing 1 changed file with 66 additions and 28 deletions.
94 changes: 66 additions & 28 deletions src/codeinterpreterapi/thoughts/tot.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
import os
from typing import Optional, Tuple

import spacy
from langchain.prompts import PromptTemplate
from langchain.schema import BaseMemory, HumanMessage
from langchain.schema import BaseMemory
from langchain_core.output_parsers import StrOutputParser
from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.thought import ThoughtValidity
from spacy import Language

sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1"
sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1"
problem_description = f"""
sudoku_problem_description = f"""
{sudoku_puzzle}
- This is a 4x4 Sudoku puzzle.
- The * represents a cell to be filled.
- The | character separates rows.
- At each step, replace one or more * with digits 1-4.
- There must be no duplicate digits in any row, column or 2x2 subgrid.
- |で区切られた4つの数字は、first row, second row, third row, 4th row をそれぞれ表します。
- 各rowとcolumnの4つの数字は重複していけません。例えば1,2,3,4や1,2,4,3はOKですが、1,2,2,3はNGです。
- さらに、2x2 のサブグリッド(全4個)でも4つの数字は重複してはいけません。
- Keep the known digits from previous valid thoughts in place.
- Each thought can be a partial or the final solution.
""".strip()
print(problem_description)

#######
# The following code implements an LLM-based checker
Expand All @@ -32,33 +36,58 @@ class MyChecker(ToTChecker):
# Judge prompt: the LLM must answer with exactly one validity label on line 1
# and its reasoning on the following lines. The superseded English-only
# template lines left over from the previous revision are removed.
prompt: PromptTemplate = PromptTemplate(
    input_variables=["problem_description", "thoughts"],
    template="""
次の Problem Description に示す問題を解決するための thoughts について、
[VALID_FINAL|VALID_INTERMEDIATE|INVALID]のどれか1つだけを選んで1行目に回答してください。
そして、選択した理由を2行目以降でできるだけ詳しく説明してください。
Problem Description: 解決するべき問題です。
{problem_description}
Thoughts: 解決の手続きについての思考です。
{thoughts}
正しい回答を選ぶために下記のガイドラインに従ってください。
- VALID_FINAL: 最終的な thoughts が問題の解決に最適であると確認したときに選びます。
- VALID_INTERMEDIATE: 最終的な thoughts が問題の解決に適しているが、解決には思考をさらに進める必要があるときに選びます。
- INVALID: 最終的な thoughts が問題を解決できないとき、内容がルールに違反していたとき、明らな間違いがあり思考をやり直す必要があるときに選びます。
Evaluation:
""",
)
# spacy model used as a fuzzy fallback when the LLM's verdict line does not
# contain an exact label. NOTE(review): this is an English vector model while
# the prompt asks for Japanese reasoning — confirm the similarity fallback
# behaves as intended. Requires `python -m spacy download en_core_web_md`.
nlp: Language = spacy.load("en_core_web_md")

def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity:
    """Ask the LLM to judge the latest thought and map its answer to a ThoughtValidity.

    Fix: the pre-change HumanMessage/`self.llm([...])` implementation lines were
    left interleaved with the new LCEL version; only the new version is kept.

    Args:
        problem_description: the problem statement the thoughts try to solve.
        thoughts: chain of thoughts produced so far; the last one is judged.

    Returns:
        The validity verdict derived from the LLM output by judge_llm_output().
    """
    # prompt -> llm -> plain string (LCEL pipeline).
    chain = self.prompt | self.llm | StrOutputParser()
    llm_output = chain.invoke({"problem_description": problem_description, "thoughts": thoughts})

    print("llm_output=", llm_output)
    final_judge = self.judge_llm_output(llm_output)
    print("final_judge=", final_judge)
    return final_judge

def judge_llm_output(self, llm_output) -> ThoughtValidity:
    """Map the LLM's free-form answer to a ThoughtValidity.

    First tries an exact label match on the answer's first line; if none is
    found, falls back to spacy vector similarity against the label keywords.
    """
    first_line = llm_output.split("\n")[0]
    labels = ["VALID_FINAL", "VALID_INTERMEDIATE", "INVALID"]

    # Exact match: a label appears verbatim in the first line.
    for label in labels:
        if label in first_line:
            return self.get_thought_validity(label)

    # Fuzzy fallback: pick the label whose keyword is semantically closest.
    answer_doc = self.nlp(first_line)
    keywords = ["FINAL", "INTERMEDIATE", "INVALID"]
    similarities = [answer_doc.similarity(self.nlp(keyword)) for keyword in keywords]
    print("similarities=", similarities)
    best_index = max(range(len(similarities)), key=similarities.__getitem__)
    best_match = labels[best_index]

    print(f"Best match: {best_match} with similarity {similarities[best_index]}")
    return self.get_thought_validity(best_match)

def get_thought_validity(self, thought_validity) -> ThoughtValidity:
    """Convert a validity label string to the ThoughtValidity enum.

    Fix: a stale diff line (`elif evaluation == ...`, bodiless and referencing
    the undefined pre-rename variable `evaluation`) sat between the branches
    and made the method a syntax error; it is removed.

    Any unrecognized label falls through to INVALID.
    """
    if thought_validity == "VALID_FINAL":
        return ThoughtValidity.VALID_FINAL
    elif thought_validity == "VALID_INTERMEDIATE":
        return ThoughtValidity.VALID_INTERMEDIATE
    else:
        return ThoughtValidity.INVALID
Expand All @@ -67,18 +96,22 @@ def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) ->
#######
# Testing the MyChecker class above:
#######
def test_checker():
    """Exercise MyChecker on intermediate, final, and invalid sudoku thoughts.

    Fix: the old-signature `def test_checker(tot_chain):` line and the old
    assertions using the renamed `problem_description` variable were left
    interleaved from the previous revision; only the new version is kept.
    """
    tot_chain = create_tot_chain_from_llm(prepare_test_llm())
    checker = tot_chain.checker
    # Partial fill consistent with the rules -> intermediate step.
    assert (
        checker.evaluate(sudoku_problem_description, ("3,*,1,2|1,*,3,*|*,1,*,3|4,*,*,1",))
        == ThoughtValidity.VALID_INTERMEDIATE
    )
    # Complete, correct grid -> final.
    assert (
        checker.evaluate(sudoku_problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",))
        == ThoughtValidity.VALID_FINAL
    )
    # One cell still open -> intermediate, not final.
    assert (
        checker.evaluate(sudoku_problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",))
        == ThoughtValidity.VALID_INTERMEDIATE
    )
    # Duplicate digits in the last row/columns -> invalid.
    assert checker.evaluate(sudoku_problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,2,3,1",)) == ThoughtValidity.INVALID


#######
Expand All @@ -88,15 +121,14 @@ def test_checker(tot_chain):
#######


def create_tot_chain_from_llm(llm):
    """Build a ToTChain wired with a MyChecker that uses the given LLM.

    Fix: the old `def create(llm):` line and the removed eager
    `tot_chain.run(...)` call were left interleaved from the previous
    revision; construction and execution are now separated.

    Args:
        llm: chat model used both to generate thoughts and to judge them.

    Returns:
        The configured ToTChain (not yet run).
    """
    checker = MyChecker()
    checker.llm = llm
    # k/c are ToTChain search-size parameters — NOTE(review): confirm their
    # exact semantics against langchain_experimental's ToTChain docs.
    tot_chain = ToTChain.from_llm(llm=llm, checker=checker, k=30, c=5, verbose=True, verbose_llm=False)
    return tot_chain


def prepare_test_llm():
    """Create the Gemini chat model used by the tests.

    Reads the API key from the GEMINI_API_KEY environment variable.
    NOTE(review): the `llm = ChatGoogleGenerativeAI(` opener was hidden behind
    a collapsed diff hunk and has been reconstructed — confirm against the
    repository.
    """
    from langchain_google_genai import ChatGoogleGenerativeAI  # type: ignore

    model = "gemini-1.5-flash-latest"
    llm = ChatGoogleGenerativeAI(
        model=model,
        google_api_key=os.environ.get("GEMINI_API_KEY"),
        max_output_tokens=1024,
    )
    return llm


def test_create():
    """Smoke test: build the ToT chain and run it on the sudoku problem."""
    llm = prepare_test_llm()
    chain = create_tot_chain_from_llm(llm)
    chain.run(problem_description=sudoku_problem_description)


if __name__ == "__main__":
test_checker()
test_create()

0 comments on commit 01b35d2

Please sign in to comment.