Skip to content

Commit

Permalink
new dataset gen test
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudbergeron committed May 6, 2024
1 parent d1bef13 commit 3c5f8ab
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/system_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
os: [ubuntu-latest, windows-latest] # macos-latest]
python-version: ['3.11']

runs-on: ${{ matrix.os }}
Expand Down
86 changes: 60 additions & 26 deletions NumGI/SolutionGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,34 +27,71 @@ def __init__(
self.USED_VARS = []

def generate_solution_dataset(
self, ops_sol: tuple, ops_eq: tuple, num_eqs: int, vars: list, funcs: list, ops: list
self,
num_ops_range_sol: tuple,
num_ops_range_eq: tuple,
num_eqs: int,
vars: list,
funcs: list,
ops: list,
verbose=False,
) -> list:
"""Call to generate dataset of equations."""
"""Call to generate dataset of equations.
Args:
num_ops_range_sol (tuple): Range of number of operations for solution.
num_ops_range_eq (tuple): Range of number of operations for equation.
num_eqs (int): Number of equations to generate.
vars (list): List of variables to use.
funcs (list): List of functions to use.
ops (list): List of operations to use.
verbose (bool, optional): Whether to print progress. Defaults to False.
Returns:
list: List of tuples of format (solution, differential equations).
"""
dataset = []
for _ in range(num_eqs):
# if _ % 1_0 == 0:
# print(f"Generating equation {_} of {num_eqs}")
num_ops_sol = random.randint(ops_sol[0], ops_sol[1])
if verbose and (_ % 1_0 == 0):
print(f"Generating equation {_} of {num_eqs}")

# pick a random number of operations for the solution and generate it
num_ops_sol = random.randint(num_ops_range_sol[0], num_ops_range_sol[1])
sol, used_vars = self.generate_solution(num_ops_sol, vars, funcs, ops)
equation = self.generate_equation(used_vars, ops_eq, ops, sol)
sol = sol.doit() # might be problematic

actually_used_vars = [_symbol.name for _symbol in sol.free_symbols]
if len(actually_used_vars) == 0:
continue

func_sol = sp.Function("f")(*[sp.Symbol(var) for var in ["x"]])
# format the solution as a sympy equation and add it to the dataset
func_sol = sp.Function("f")(*[sp.Symbol(var) for var in actually_used_vars])
sol_eq = sp.Eq(func_sol, sol)

# generate an equation from the created solution
equation = self.generate_equation(actually_used_vars, num_ops_range_eq, ops, sol_eq)

dataset.append((sol_eq, equation))

return dataset

def generate_equation(self, used_vars, ops_eqs, ops, sol):
def generate_equation(self, used_vars, num_ops_range_eqs, ops, sol):
"""Generate an equation from a solution."""
tree = self.generate_equation_tree(random.randint(ops_eqs[0], ops_eqs[1]), ops)
tree = self.generate_equation_tree(
random.randint(num_ops_range_eqs[0], num_ops_range_eqs[1]), ops
)
return self.tree_to_equation(tree, sol, used_vars)

def generate_solution(self, num_ops: int, vars: list, funcs: list, ops: list):
"""Generate a list of solution equations with a specific number of operations."""
"""Generate a list of solution equations with a specific number of operations.
Do we want to use trees for this too?
"""
used_vars = []
new_vars = vars.copy()
for i in range(num_ops):
for ops_number in range(num_ops):
op = self.choose_operation(ops)
if i == 0:
if ops_number == 0:
var = self.choose_variable(new_vars, used_vars)
f = self.choose_function(funcs)
f1 = f(var)
Expand Down Expand Up @@ -99,14 +136,14 @@ def choose_variable(self, new_vars: list | None, used_vars: list | None):
Probabilities for when to choose which need to be initialized somewhere.
"""
# if used_vars is None or len(used_vars) <= 0 or random.random() < self.PROB_NEW_SYMBOL:
# var = self.pop_random(new_vars)
# used_vars.append(var)
return sp.symbols("x")
# return sp.symbols(random.choice(used_vars))
if used_vars is None or len(used_vars) <= 0 or random.random() < self.PROB_NEW_SYMBOL:
var = self.pop_random(new_vars)
used_vars.append(var)

return sp.Symbol(random.choice(used_vars))

def choose_used_variable(self, used_vars: list):
return sp.symbols("x") # sp.symbols(random.choice(used_vars))
return sp.Symbol(random.choice(used_vars))

def choose_operation(
self,
Expand All @@ -132,14 +169,11 @@ def tree_to_equation(
):
"""Converts a tree to a sympy equation."""
root = tree.root

vars = [sp.Symbol(var) for var in ["x"]]
func = sp.Function("f")(*vars)
try:
expression = self.tree_to_eq_helper(root, sol, used_vars)
rhs = expression.doit()
expression = self.tree_to_eq_helper(root, sol.lhs, used_vars)
replaced_exp = expression.subs(sol.lhs, sol.rhs)
rhs = replaced_exp.doit()
equation = sp.Eq(expression, rhs)
equation = equation.replace(sol, func)
except ValueError:
print(expression)
return equation
Expand Down Expand Up @@ -312,8 +346,8 @@ def choose_op_noarithmetic(self, ops: list):
if __name__ == "__main__":
sg = SolutionGenerator()
eqs = sg.generate_solution_dataset(
ops_sol=(3, 5),
ops_eq=(2, 5),
num_ops_range_sol=(3, 5),
num_ops_range_eq=(2, 5),
num_eqs=1000,
vars=VARIABLES,
funcs=DIFFERENTIAL_FUNCTIONS,
Expand Down
37 changes: 37 additions & 0 deletions test/EquationTests/test_EquationGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import pytest

from NumGI.ConstantDictionaries import DIFFERENTIAL_FUNCTIONS
from NumGI.ConstantDictionaries import OPERATIONS
from NumGI.ConstantDictionaries import VARIABLES
from NumGI.SolutionGenerator import SolutionGenerator


Expand Down Expand Up @@ -34,3 +37,37 @@ def test_insert_tree(
assert tree.level == 2
assert tree.root.right.right.level == 2
assert tree.root.right.left.level == 2


def test_not_always_true():
    """Check that generated equations are non-trivial and solved by their solution.

    A dataset of (solution, equation) pairs is generated and two ratios are
    measured over the pairs that were actually produced:

    * the share of equations that already evaluate to ``True`` on their own
      (i.e. trivially true identities) must stay below 20%;
    * the share of equations whose truth value cannot be determined after
      substituting the solution's right-hand side for its left-hand side
      must stay below 10%.

    The ratios are taken over ``len(dataset)`` rather than the requested
    ``num_eqs`` because the generator may skip candidates, so the dataset can
    be shorter than requested; dividing by ``num_eqs`` would understate both
    ratios and let an empty dataset pass vacuously.
    """
    sg = SolutionGenerator()
    num_eqs = 200
    dataset = sg.generate_solution_dataset(
        num_ops_range_sol=(3, 5),
        num_ops_range_eq=(4, 6),
        num_eqs=num_eqs,
        vars=VARIABLES,
        funcs=DIFFERENTIAL_FUNCTIONS,
        ops=OPERATIONS,
    )

    # Guard against a vacuous pass (and a division by zero) on an empty dataset.
    assert dataset, "generate_solution_dataset returned no equations"
    num_generated = len(dataset)

    always_true_cnt = 0
    unverified_cnt = 0

    for sol_eq, diff_eq in dataset:
        # Plug the solution into the equation: f(...) -> its closed form.
        substituted_eq = diff_eq.replace(sol_eq.lhs, sol_eq.rhs)

        # Evaluating the truth value raises when the relation cannot be
        # decided; that is counted as "could not confirm the solution".
        try:
            bool(substituted_eq.doit())
        except Exception:
            unverified_cnt += 1

        # Without substitution, an undecidable relation is the expected,
        # healthy case, so a raised exception is deliberately ignored here
        # and only equations that evaluate to True are counted.
        try:
            if diff_eq.doit():
                always_true_cnt += 1
        except Exception:
            pass

    assert always_true_cnt / num_generated < 0.2
    assert unverified_cnt / num_generated < 0.1

0 comments on commit 3c5f8ab

Please sign in to comment.