Skip to content

Commit

Permalink
fix: add sample thoughts/tot.py
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed May 25, 2024
1 parent f8c5fff commit 2207ead
Showing 1 changed file with 82 additions and 0 deletions.
82 changes: 82 additions & 0 deletions src/codeinterpreterapi/thoughts/tot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import os
import re
from typing import Tuple

from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.thought import ThoughtValidity

sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1"
sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1"
problem_description = f"""
{sudoku_puzzle}
- This is a 4x4 Sudoku puzzle.
- The * represents a cell to be filled.
- The | character separates rows.
- At each step, replace one or more * with digits 1-4.
- There must be no duplicate digits in any row, column or 2x2 subgrid.
- Keep the known digits from previous valid thoughts in place.
- Each thought can be a partial or the final solution.
""".strip()
print(problem_description)

#######
# The following code implement a simple rule based checker for
# a specific 4x4 sudoku puzzle.
#######


class MyChecker(ToTChecker):
def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity:
last_thought = thoughts[-1]
clean_solution = last_thought.replace(" ", "").replace('"', "")
regex_solution = clean_solution.replace("*", ".").replace("|", "\\|")
if sudoku_solution in clean_solution:
return ThoughtValidity.VALID_FINAL
elif re.search(regex_solution, sudoku_solution):
return ThoughtValidity.VALID_INTERMEDIATE
else:
return ThoughtValidity.INVALID


#######
# Testing the MyChecker class above:
#######
def test_checker():
checker = MyChecker()
assert checker.evaluate("", ("3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1",)) == ThoughtValidity.VALID_INTERMEDIATE
assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",)) == ThoughtValidity.VALID_FINAL
assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",)) == ThoughtValidity.VALID_INTERMEDIATE
assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1",)) == ThoughtValidity.INVALID


#######
# Initialize and run the ToT chain,
# with maximum number of interactions k set to 30 and
# the maximum number child thoughts c set to 8.
#######


def create(llm):
tot_chain = ToTChain(llm=llm, checker=MyChecker(), k=30, c=5, verbose=True, verbose_llm=False)
tot_chain.run(problem_description=problem_description)
return tot_chain


def test_create():
from langchain_google_genai import ChatGoogleGenerativeAI # type: ignore

model = "gemini-1.5-flash-latest"
llm = ChatGoogleGenerativeAI(
model=model,
temperature=0,
google_api_key=os.environ.get("GEMINI_API_KEY"),
max_output_tokens=1024,
)
create(llm)


if __name__ == "__main__":
test_checker()
test_create()

0 comments on commit 2207ead

Please sign in to comment.