Skip to content

Commit

Permalink
fix: update thoughts/tot.py for llm ToTChecker
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed May 25, 2024
1 parent 2207ead commit 2c2b8d4
Showing 1 changed file with 49 additions and 18 deletions.
67 changes: 49 additions & 18 deletions src/codeinterpreterapi/thoughts/tot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import re
from typing import Tuple
from typing import Optional, Tuple

from langchain.prompts import PromptTemplate
from langchain.schema import BaseMemory, HumanMessage
from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.thought import ThoughtValidity
Expand All @@ -22,19 +23,42 @@
print(problem_description)

#######
# The following code implement a simple rule based checker for
# a specific 4x4 sudoku puzzle.
# The following code implements an LLM-based checker
#######


class MyChecker(ToTChecker):
llm: Optional[BaseMemory] = None
prompt: PromptTemplate = PromptTemplate(
input_variables=["problem_description", "thoughts"],
template="""
Given the following problem description and a series of thoughts, evaluate the validity of the last thought.
Problem Description:
{problem_description}
Thoughts:
{thoughts}
Evaluate the last thought and return one of the following:
- VALID_FINAL if the last thought is a valid final solution to the problem.
- VALID_INTERMEDIATE if the last thought is a valid intermediate step towards the solution.
- INVALID if the last thought is invalid or contradicts the problem description.
Evaluation:
""",
)

def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity:
last_thought = thoughts[-1]
clean_solution = last_thought.replace(" ", "").replace('"', "")
regex_solution = clean_solution.replace("*", ".").replace("|", "\\|")
if sudoku_solution in clean_solution:
prompt = self.prompt.format(problem_description=problem_description, thoughts="\n".join(thoughts))
message = HumanMessage(content=prompt)
evaluation = self.llm([message]).content.strip().upper()

print("evaluation=", evaluation)

if evaluation == "VALID_FINAL":
return ThoughtValidity.VALID_FINAL
elif re.search(regex_solution, sudoku_solution):
elif evaluation == "VALID_INTERMEDIATE":
return ThoughtValidity.VALID_INTERMEDIATE
else:
return ThoughtValidity.INVALID
Expand All @@ -43,23 +67,31 @@ def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) ->
#######
# Testing the MyChecker class above:
#######
def test_checker():
checker = MyChecker()
assert checker.evaluate("", ("3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1",)) == ThoughtValidity.VALID_INTERMEDIATE
assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",)) == ThoughtValidity.VALID_FINAL
assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",)) == ThoughtValidity.VALID_INTERMEDIATE
assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1",)) == ThoughtValidity.INVALID
def test_checker(tot_chain):
checker = tot_chain.checker
assert (
checker.evaluate(problem_description, ("3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1",))
== ThoughtValidity.VALID_INTERMEDIATE
)
assert checker.evaluate(problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",)) == ThoughtValidity.VALID_FINAL
assert (
checker.evaluate(problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",))
== ThoughtValidity.VALID_INTERMEDIATE
)
assert checker.evaluate(problem_description, ("3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1",)) == ThoughtValidity.INVALID


#######
# Initialize and run the ToT chain,
# with maximum number of interactions k set to 30 and
# the maximum number child thoughts c set to 8.
# the maximum number of child thoughts c set to 8.
#######


def create(llm):
tot_chain = ToTChain(llm=llm, checker=MyChecker(), k=30, c=5, verbose=True, verbose_llm=False)
checker = MyChecker()
checker.llm = llm
tot_chain = ToTChain.from_llm(llm=llm, checker=checker, k=30, c=5, verbose=True, verbose_llm=False)
tot_chain.run(problem_description=problem_description)
return tot_chain

Expand All @@ -78,5 +110,4 @@ def test_create():


if __name__ == "__main__":
test_checker()
test_create()

0 comments on commit 2c2b8d4

Please sign in to comment.