forked from shroominic/codeinterpreter-api
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
82 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import os | ||
import re | ||
from typing import Tuple | ||
|
||
from langchain_experimental.tot.base import ToTChain | ||
from langchain_experimental.tot.checker import ToTChecker | ||
from langchain_experimental.tot.thought import ThoughtValidity | ||
|
||
sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1" | ||
sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1" | ||
problem_description = f""" | ||
{sudoku_puzzle} | ||
- This is a 4x4 Sudoku puzzle. | ||
- The * represents a cell to be filled. | ||
- The | character separates rows. | ||
- At each step, replace one or more * with digits 1-4. | ||
- There must be no duplicate digits in any row, column or 2x2 subgrid. | ||
- Keep the known digits from previous valid thoughts in place. | ||
- Each thought can be a partial or the final solution. | ||
""".strip() | ||
print(problem_description) | ||
|
||
####### | ||
# The following code implement a simple rule based checker for | ||
# a specific 4x4 sudoku puzzle. | ||
####### | ||
|
||
|
||
class MyChecker(ToTChecker): | ||
def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity: | ||
last_thought = thoughts[-1] | ||
clean_solution = last_thought.replace(" ", "").replace('"', "") | ||
regex_solution = clean_solution.replace("*", ".").replace("|", "\\|") | ||
if sudoku_solution in clean_solution: | ||
return ThoughtValidity.VALID_FINAL | ||
elif re.search(regex_solution, sudoku_solution): | ||
return ThoughtValidity.VALID_INTERMEDIATE | ||
else: | ||
return ThoughtValidity.INVALID | ||
|
||
|
||
####### | ||
# Testing the MyChecker class above: | ||
####### | ||
def test_checker(): | ||
checker = MyChecker() | ||
assert checker.evaluate("", ("3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1",)) == ThoughtValidity.VALID_INTERMEDIATE | ||
assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",)) == ThoughtValidity.VALID_FINAL | ||
assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",)) == ThoughtValidity.VALID_INTERMEDIATE | ||
assert checker.evaluate("", ("3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1",)) == ThoughtValidity.INVALID | ||
|
||
|
||
####### | ||
# Initialize and run the ToT chain, | ||
# with maximum number of interactions k set to 30 and | ||
# the maximum number child thoughts c set to 8. | ||
####### | ||
|
||
|
||
def create(llm): | ||
tot_chain = ToTChain(llm=llm, checker=MyChecker(), k=30, c=5, verbose=True, verbose_llm=False) | ||
tot_chain.run(problem_description=problem_description) | ||
return tot_chain | ||
|
||
|
||
def test_create(): | ||
from langchain_google_genai import ChatGoogleGenerativeAI # type: ignore | ||
|
||
model = "gemini-1.5-flash-latest" | ||
llm = ChatGoogleGenerativeAI( | ||
model=model, | ||
temperature=0, | ||
google_api_key=os.environ.get("GEMINI_API_KEY"), | ||
max_output_tokens=1024, | ||
) | ||
create(llm) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_checker() | ||
test_create() |