Skip to content

Commit

Permalink
fixed integration
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudbergeron committed Feb 2, 2024
1 parent 01b6ab0 commit d94667e
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions NumGI/Loss/LossDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class LossDataset:
def __init__(self, eq_dataset: DatasetTokenizer):
    """Build a loss dataset on top of a tokenized equation dataset.

    Args:
        eq_dataset: tokenizer/dataset providing `sympy_to_torch` and the
            torch `device` used for grid construction.
    """
    self.eq_dataset = eq_dataset
    # (low, high, base number of samples per axis) for the integration grid.
    self.grid_size = (-1, 1, 1000)
    self.grid_length = self.grid_size[1] - self.grid_size[0]
    self.max_integral_value = 1e2  # we can play with this value
    self.var_dict = self.create_var_dict()

def create_var_dict(self):
Expand Down Expand Up @@ -50,6 +51,8 @@ def calculate_n_pairwise_loss(self, N, ell_norm):
second_batch = N - first_batch
for i in range(first_batch):
chosen_symbols = random.choice(list(possible_symbols))
if len(self.var_dict[chosen_symbols]) < 1:
continue
possible_equations = {i[1] for i in self.var_dict[chosen_symbols]}

idx_sympy_1, idx_sympy_2 = random.sample(possible_equations, 2)
Expand All @@ -58,6 +61,8 @@ def calculate_n_pairwise_loss(self, N, ell_norm):

integrand = sp.Abs(sol_sympy_1[0].rhs - sol_sympy_2[0].rhs) ** ell_norm
integral = self.compute_integral(integrand)
if integral is None:
continue

loss[0, i] = sol_sympy_1[1]
loss[1, i] = sol_sympy_2[1]
Expand All @@ -80,21 +85,30 @@ def calculate_n_pairwise_loss(self, N, ell_norm):

def compute_integral(self, sympy_eq):
    """Numerically integrate the expression over a discrete complex grid.

    Args:
        sympy_eq: sympy expression to integrate (typically an |f - g|**p
            integrand built by the pairwise-loss routine).

    Returns:
        A scalar torch tensor with the weighted sum over the grid,
        a NaN tensor when the expression has no free symbols, or
        ``None`` when the compiled function fails to evaluate (callers
        skip such pairs).
    """
    func, symbols = self.eq_dataset.sympy_to_torch(sympy_eq)

    # A constant expression has nothing to integrate over.
    if len(symbols) < 1:
        return torch.tensor(torch.nan)
    grids = self.create_discrete_grids(symbols)
    _arg = {sym: _grid for sym, _grid in zip(symbols, grids)}
    try:
        complex_result = func(**_arg)
    except Exception as e:
        # Some generated functions fail at evaluation time (domain errors,
        # unsupported ops); report and signal "skip" with None.
        print(f"Error in sympy_to_torch {e}")
        return None
    # |z| = sqrt(z * conj(z)); keep the real part (imaginary is ~0).
    result = (complex_result * complex_result.conj()) ** 0.5
    result = torch.nanmean(result.real)
    # Per-axis cell width: each flattened complex axis holds num_grid**2
    # points, so the per-axis resolution is sqrt(#points).
    grid_size = self.grid_length / (grids[0].shape[0] ** (0.5))
    results_weighted = result.real * (grid_size ** (2 * len(symbols)))
    result = torch.nansum(results_weighted)
    del grids  # release the (potentially large) meshgrid tensors promptly
    return result

def create_discrete_grids(self, symbols):
    """Create one flattened complex integration axis per symbol, meshed.

    Each axis samples a num_grid x num_grid rectangle of the complex
    plane (real x imaginary over [grid_low, grid_high]) and is flattened
    to 1-D; the axes are then combined with ``torch.meshgrid``.

    Args:
        symbols: iterable of free symbols of the integrand.

    Returns:
        Tuple of meshgrid tensors, one per symbol.
    """
    grid_low, grid_high, num_grid = self.grid_size
    # Scale grid resolution with dimension.
    # NOTE(review): exp(-(1 - n)) = exp(n - 1) *grows* with the number of
    # symbols, so this increases (not decreases) the per-axis sample
    # count in higher dimensions — confirm this is intended.
    num_grid = int(num_grid * np.exp(-(1 - len(symbols))))
    grid_real = torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device)
    # Same range sampled along the imaginary axis.
    grid_im = 1j * torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device)
    # Outer sum -> full complex rectangle, flattened into one 1-D axis.
    grid = grid_real[:, None] + grid_im[None, :]
    grids = [grid.flatten() for i in symbols]
    mesh = torch.meshgrid(grids)
    return mesh

0 comments on commit d94667e

Please sign in to comment.