From 2c2b8d4de0d8516a6c7a5776501a10b56f1ec387 Mon Sep 17 00:00:00 2001 From: jinno Date: Sat, 25 May 2024 15:53:02 +0900 Subject: [PATCH] fix: update thoughts/tot.py for llm ToTChecker --- src/codeinterpreterapi/thoughts/tot.py | 67 +++++++++++++++++++------- 1 file changed, 49 insertions(+), 18 deletions(-) diff --git a/src/codeinterpreterapi/thoughts/tot.py b/src/codeinterpreterapi/thoughts/tot.py index 47dd2b0f..c9be1594 100644 --- a/src/codeinterpreterapi/thoughts/tot.py +++ b/src/codeinterpreterapi/thoughts/tot.py @@ -1,7 +1,8 @@ import os -import re -from typing import Tuple +from typing import Optional, Tuple +from langchain.prompts import PromptTemplate +from langchain.schema import BaseMemory, HumanMessage from langchain_experimental.tot.base import ToTChain from langchain_experimental.tot.checker import ToTChecker from langchain_experimental.tot.thought import ThoughtValidity @@ -22,19 +23,42 @@ print(problem_description) ####### -# The following code implement a simple rule based checker for -# a specific 4x4 sudoku puzzle. +# The following code implements an LLM-based checker ####### class MyChecker(ToTChecker): + llm: Optional[BaseMemory] = None + prompt: PromptTemplate = PromptTemplate( + input_variables=["problem_description", "thoughts"], + template=""" + Given the following problem description and a series of thoughts, evaluate the validity of the last thought. + + Problem Description: + {problem_description} + + Thoughts: + {thoughts} + + Evaluate the last thought and return one of the following: + - VALID_FINAL if the last thought is a valid final solution to the problem. + - VALID_INTERMEDIATE if the last thought is a valid intermediate step towards the solution. + - INVALID if the last thought is invalid or contradicts the problem description. + + Evaluation: + """, + ) + def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity: - last_thought = thoughts[-1] - clean_solution = last_thought.replace(" ", "").replace('"', "") - regex_solution = clean_solution.replace("*", ".").replace("|", "\\|") - if sudoku_solution in clean_solution: + prompt = self.prompt.format(problem_description=problem_description, thoughts="\n".join(thoughts)) + message = HumanMessage(content=prompt) + evaluation = self.llm([message]).content.strip().upper() + + print("evaluation=", evaluation) + + if evaluation == "VALID_FINAL": return ThoughtValidity.VALID_FINAL - elif re.search(regex_solution, sudoku_solution): + elif evaluation == "VALID_INTERMEDIATE": return ThoughtValidity.VALID_INTERMEDIATE else: return ThoughtValidity.INVALID @@ -43,23 +67,31 @@ def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ####### # Testing the MyChecker class above: ####### -def test_checker(): - checker = MyChecker() - assert checker.evaluate("", ("3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1",)) == ThoughtValidity.VALID_INTERMEDIATE - assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",)) == ThoughtValidity.VALID_FINAL - assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",)) == ThoughtValidity.VALID_INTERMEDIATE - assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1",)) == ThoughtValidity.INVALID +def test_checker(tot_chain): + checker = tot_chain.checker + assert ( + checker.evaluate(problem_description, ("3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1",)) + == ThoughtValidity.VALID_INTERMEDIATE + ) + assert checker.evaluate(problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",)) == ThoughtValidity.VALID_FINAL + assert ( + checker.evaluate(problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",)) + == ThoughtValidity.VALID_INTERMEDIATE + ) + assert checker.evaluate(problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1",)) == ThoughtValidity.INVALID ####### # Initialize and run the ToT chain, # with maximum number of interactions k set to 30 and -# the maximum number child thoughts c set to 8. +# the maximum number of child thoughts c set to 8. ####### def create(llm): - tot_chain = ToTChain(llm=llm, checker=MyChecker(), k=30, c=5, verbose=True, verbose_llm=False) + checker = MyChecker() + checker.llm = llm + tot_chain = ToTChain.from_llm(llm=llm, checker=checker, k=30, c=5, verbose=True, verbose_llm=False) tot_chain.run(problem_description=problem_description) return tot_chain @@ -78,5 +110,4 @@ def test_create(): if __name__ == "__main__": - test_checker() test_create()