From 2207ead10332d11d1566c47ef93314f467519b45 Mon Sep 17 00:00:00 2001 From: jinno Date: Sat, 25 May 2024 14:51:26 +0900 Subject: [PATCH] fix: add sample thoughts/tot.py --- src/codeinterpreterapi/thoughts/tot.py | 82 ++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 src/codeinterpreterapi/thoughts/tot.py diff --git a/src/codeinterpreterapi/thoughts/tot.py b/src/codeinterpreterapi/thoughts/tot.py new file mode 100644 index 00000000..47dd2b0f --- /dev/null +++ b/src/codeinterpreterapi/thoughts/tot.py @@ -0,0 +1,82 @@ +import os +import re +from typing import Tuple + +from langchain_experimental.tot.base import ToTChain +from langchain_experimental.tot.checker import ToTChecker +from langchain_experimental.tot.thought import ThoughtValidity + +sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1" +sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1" +problem_description = f""" +{sudoku_puzzle} + +- This is a 4x4 Sudoku puzzle. +- The * represents a cell to be filled. +- The | character separates rows. +- At each step, replace one or more * with digits 1-4. +- There must be no duplicate digits in any row, column or 2x2 subgrid. +- Keep the known digits from previous valid thoughts in place. +- Each thought can be a partial or the final solution. +""".strip() +print(problem_description) + +####### +# The following code implement a simple rule based checker for +# a specific 4x4 sudoku puzzle. +####### + + +class MyChecker(ToTChecker): + def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity: + last_thought = thoughts[-1] + clean_solution = last_thought.replace(" ", "").replace('"', "") + regex_solution = clean_solution.replace("*", ".").replace("|", "\\|") + if sudoku_solution in clean_solution: + return ThoughtValidity.VALID_FINAL + elif re.search(regex_solution, sudoku_solution): + return ThoughtValidity.VALID_INTERMEDIATE + else: + return ThoughtValidity.INVALID + + +####### +# Testing the MyChecker class above: +####### +def test_checker(): + checker = MyChecker() + assert checker.evaluate("", ("3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1",)) == ThoughtValidity.VALID_INTERMEDIATE + assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",)) == ThoughtValidity.VALID_FINAL + assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",)) == ThoughtValidity.VALID_INTERMEDIATE + assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1",)) == ThoughtValidity.INVALID + + +####### +# Initialize and run the ToT chain, +# with maximum number of interactions k set to 30 and +# the maximum number child thoughts c set to 8. +####### + + +def create(llm): + tot_chain = ToTChain(llm=llm, checker=MyChecker(), k=30, c=5, verbose=True, verbose_llm=False) + tot_chain.run(problem_description=problem_description) + return tot_chain + + +def test_create(): + from langchain_google_genai import ChatGoogleGenerativeAI # type: ignore + + model = "gemini-1.5-flash-latest" + llm = ChatGoogleGenerativeAI( + model=model, + temperature=0, + google_api_key=os.environ.get("GEMINI_API_KEY"), + max_output_tokens=1024, + ) + create(llm) + + +if __name__ == "__main__": + test_checker() + test_create()