Commit

bug fixes
arnaudbergeron committed May 6, 2024
1 parent 5df4ccf commit f540014
Showing 5 changed files with 40 additions and 118 deletions.
Binary file modified NumGI/.DS_Store
Binary file not shown.
4 changes: 2 additions & 2 deletions NumGI/LoadTokenizer.py
@@ -8,7 +8,7 @@
 class LoadTokenizer(DatasetTokenizer):
     """The tokenizer used when loading data from files."""

-    def __init__(self, x_files, y_files):
+    def __init__(self, x_files, y_files, useDefaultTokenizer=True):
         default_tokenized_x = []
         default_tokenized_y = []

@@ -37,4 +37,4 @@ def __init__(self, x_files, y_files):
         new_x = [tempTokenizer.tokens_to_list(i) for i in default_combined_x_torch.tolist()]
         new_y = [tempTokenizer.tokens_to_list(i) for i in default_combined_y_torch.tolist()]

-        super().__init__(new_x, new_y, useDefaultTokenizer=False, isSympy=False)
+        super().__init__(new_x, new_y, useDefaultTokenizer=useDefaultTokenizer, isSympy=False)
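
For context, a minimal usage sketch of the changed constructor; the shard file names below are placeholders, and the only behaviour assumed is what this hunk shows (the flag is now forwarded to the DatasetTokenizer parent instead of being hard-coded to False):

    from NumGI.LoadTokenizer import LoadTokenizer

    # Hypothetical shard paths -- not files from this repository.
    x_files = ["x_shard_0.pt"]
    y_files = ["y_shard_0.pt"]

    tok_default = LoadTokenizer(x_files, y_files)                            # useDefaultTokenizer=True
    tok_custom = LoadTokenizer(x_files, y_files, useDefaultTokenizer=False)  # previous behaviour
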
48 changes: 30 additions & 18 deletions NumGI/Loss/LossDataset.py
@@ -36,22 +36,28 @@ def create_var_dict(self):
             sol = self.eq_dataset.tokens_to_sympy(eq)
             self.solutions.append(sol)
             if frozenset(sol.free_symbols) not in var_dict:
-                var_dict[frozenset(sol.free_symbols)] = [[sol, i]]
+                if len(sol.free_symbols) < 5:
+                    var_dict[frozenset(sol.free_symbols)] = [[sol, i]]
             else:
                 var_dict[frozenset(sol.free_symbols)].append([sol, i])
         return var_dict

     def calculate_n_pairwise_loss(self, N, ell_norm):
         loss = torch.zeros((3, N))
         possible_symbols = self.var_dict.keys()
-        possible_symbols = [i for i in possible_symbols if len(i) >= 1]
+        possible_symbols = [i for i in possible_symbols if len(self.var_dict[i]) > 1]
+        max_len = 0
+        for i in possible_symbols:
+            if len(i) > max_len:
+                max_len = len(i)

-        first_batch = int(0.95 * N)
-        second_batch = N - first_batch
+        self.generate_grids(max_len)
+
+        first_batch = N  # int(0.95 * N)
         for i in range(first_batch):
             chosen_symbols = random.choice(list(possible_symbols))
-            if len(self.var_dict[chosen_symbols]) < 1:
+            if len(self.var_dict[chosen_symbols]) <= 1:
                 continue
             possible_equations = {i[1] for i in self.var_dict[chosen_symbols]}
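
To make the create_var_dict change above concrete, here is a small self-contained sketch of the dictionary it builds, keyed by the frozenset of a solution's free symbols and now capped at fewer than 5 variables; the example expression and index are illustrative only:

    import sympy as sp

    x, y = sp.symbols("x y")
    sol = sp.sin(x) * y            # stand-in for a tokenized solution converted to sympy
    i = 0                          # stand-in for its index in the dataset

    var_dict = {}
    key = frozenset(sol.free_symbols)       # frozenset({x, y}) -- order-independent key
    if key not in var_dict:
        if len(sol.free_symbols) < 5:       # new guard: skip solutions with 5+ free symbols
            var_dict[key] = [[sol, i]]
    else:
        var_dict[key].append([sol, i])
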

@@ -64,22 +70,22 @@ def calculate_n_pairwise_loss(self, N, ell_norm):
             if integral is None:
                 continue

-            loss[0, i] = sol_sympy_1[1]
-            loss[1, i] = sol_sympy_2[1]
             integral_val = integral.item()
             if np.abs(integral_val) < self.max_integral_value:
+                loss[0, i] = sol_sympy_1[1]
+                loss[1, i] = sol_sympy_2[1]
                 loss[2, i] = integral.item()
             else:
                 loss[2, i] = np.sign(integral_val) * self.max_integral_value
                 continue

-        for i in range(second_batch):
-            chosen_symbols = random.sample(possible_symbols, 2)
-            sol_sympy_1 = random.choice(self.var_dict[chosen_symbols[0]])
-            sol_sympy_2 = random.choice(self.var_dict[chosen_symbols[1]])
+        # for i in range(second_batch):
+        # chosen_symbols = random.sample(possible_symbols, 2)
+        # sol_sympy_1 = random.choice(self.var_dict[chosen_symbols[0]])
+        # sol_sympy_2 = random.choice(self.var_dict[chosen_symbols[1]])

-            loss[0, i] = sol_sympy_1[1]
-            loss[1, i] = sol_sympy_2[1]
-            loss[2, i] = torch.inf
+        # loss[0, i] = sol_sympy_1[1]
+        # loss[1, i] = sol_sympy_2[1]
+        # loss[2, i] = torch.inf

         self.loss = loss
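
As far as the hunk above shows, loss is a 3 x N tensor whose columns hold (index of the first equation, index of the second, integral distance), and the fix records the indices only when the distance survives the clamp. A minimal sketch of the clamping rule in isolation; max_integral_value here is a placeholder, not a value taken from the repository:

    import numpy as np
    import torch

    def clamped_distance(integral: torch.Tensor, max_integral_value: float) -> float:
        # Mirrors the branch in calculate_n_pairwise_loss: pass the raw value through
        # when it is inside the allowed range, otherwise saturate at +/- max.
        integral_val = integral.item()
        if np.abs(integral_val) < max_integral_value:
            return integral_val
        return float(np.sign(integral_val) * max_integral_value)

    print(clamped_distance(torch.tensor(3.5), 100.0))    # 3.5
    print(clamped_distance(torch.tensor(-1e6), 100.0))   # -100.0
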

@@ -88,7 +94,7 @@ def compute_integral(self, sympy_eq):

         if len(symbols) < 1:
             return torch.tensor(torch.nan)
-        grids = self.create_discrete_grids(symbols)
+        grids = self.grids[len(symbols) - 1]
         _arg = {sym: _grid for sym, _grid in zip(symbols, grids)}
         try:
             complex_result = func(**_arg)

@@ -102,12 +108,18 @@
             del grids
         return result

+    def generate_grids(self, N_grids):
+        self.grids = []
+        for i in range(N_grids):
+            symbols = [i for i in range(i + 1)]
+            self.grids.append(self.create_discrete_grids(symbols))
+
     def create_discrete_grids(self, symbols):
         grid_low, grid_high, num_grid = self.grid_size
         # scale grid down with dimesion
-        num_grid = int(num_grid * np.exp(0.75 * (1 - len(symbols))))
+        num_grid = int(num_grid * np.exp(0.95 * (1 - len(symbols))))
         grid_real = torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device)
-        grid_im = 1j * torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device)
+        grid_im = torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device) * 1j
         grid = grid_real[:, None] + grid_im[None, :]
         grids = [grid.flatten() for i in symbols]
         mesh = torch.meshgrid(grids)
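
A standalone sketch of the grid construction and caching introduced above, so the lookup in compute_integral (self.grids[len(symbols) - 1]) is easier to follow; the grid_size default, device, and everything else not shown in the diff are assumptions for illustration:

    import numpy as np
    import torch

    def create_discrete_grids(n_symbols, grid_size=(-1.0, 1.0, 32), device="cpu"):
        # Same construction as the diff: a flattened complex plane per symbol,
        # with points per axis shrinking exponentially as the dimension grows.
        grid_low, grid_high, num_grid = grid_size
        num_grid = int(num_grid * np.exp(0.95 * (1 - n_symbols)))
        grid_real = torch.linspace(grid_low, grid_high, num_grid, device=device)
        grid_im = torch.linspace(grid_low, grid_high, num_grid, device=device) * 1j
        grid = grid_real[:, None] + grid_im[None, :]        # (num_grid, num_grid) complex plane
        flat = [grid.flatten() for _ in range(n_symbols)]
        return torch.meshgrid(*flat, indexing="ij")         # one mesh per symbol

    # generate_grids caches one entry per dimensionality up to max_len, so that
    # compute_integral can index self.grids[len(symbols) - 1] instead of
    # rebuilding the mesh for every pair of equations.
    grids = [create_discrete_grids(d + 1) for d in range(3)]
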
90 changes: 0 additions & 90 deletions NumGI/ParallelEquationGenerator.py

This file was deleted.

16 changes: 8 additions & 8 deletions NumGI/SolutionGenerator.py
@@ -38,7 +38,7 @@ def generate_solution_dataset(
            sol, used_vars = self.generate_solution(num_ops_sol, vars, funcs, ops)
            equation = self.generate_equation(used_vars, ops_eq, ops, sol)

-           func_sol = sp.Function("f")(*[sp.Symbol(var) for var in used_vars])
+           func_sol = sp.Function("f")(*[sp.Symbol(var) for var in ["x"]])
            sol_eq = sp.Eq(func_sol, sol)
            dataset.append((sol_eq, equation))
        return dataset

@@ -99,14 +99,14 @@ def choose_variable(self, new_vars: list | None, used_vars: list | None):
        Probabilities for when to choose which need to be initialized somewhere.
        """
-        if used_vars is None or len(used_vars) <= 0 or random.random() < self.PROB_NEW_SYMBOL:
-            var = self.pop_random(new_vars)
-            used_vars.append(var)
-            return sp.symbols(var)
-        return sp.symbols(random.choice(used_vars))
+        # if used_vars is None or len(used_vars) <= 0 or random.random() < self.PROB_NEW_SYMBOL:
+        # var = self.pop_random(new_vars)
+        # used_vars.append(var)
+        return sp.symbols("x")
+        # return sp.symbols(random.choice(used_vars))

     def choose_used_variable(self, used_vars: list):
-        return sp.symbols(random.choice(used_vars))
+        return sp.symbols("x") # sp.symbols(random.choice(used_vars))

     def choose_operation(
         self,

@@ -133,7 +133,7 @@ def tree_to_equation(
        """Converts a tree to a sympy equation."""
        root = tree.root

-       vars = [sp.Symbol(var) for var in used_vars]
+       vars = [sp.Symbol(var) for var in ["x"]]
        func = sp.Function("f")(*vars)
        try:
            expression = self.tree_to_eq_helper(root, sol, used_vars)
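
Taken together, these hunks pin the generator to a single variable: the solution function, the chosen symbols, and the equation tree all use x regardless of used_vars. A minimal sympy illustration of the resulting target form (the right-hand side is a made-up example, not output of this code):

    import sympy as sp

    x = sp.Symbol("x")
    func_sol = sp.Function("f")(x)     # what sp.Function("f")(*[sp.Symbol(v) for v in ["x"]]) builds
    sol = sp.sin(x) + x**2             # placeholder solution expression
    sol_eq = sp.Eq(func_sol, sol)      # the dataset entry's solution equation: f(x) = x**2 + sin(x)
    print(sol_eq)
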

0 comments on commit f540014
