From d94667e30ee46759bb1d3d5e4965099e32751b41 Mon Sep 17 00:00:00 2001 From: Arnaud Bergeron Date: Fri, 2 Feb 2024 14:35:48 -0500 Subject: [PATCH] fixed integration --- NumGI/Loss/LossDataset.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/NumGI/Loss/LossDataset.py b/NumGI/Loss/LossDataset.py index 910a9bf..4297281 100644 --- a/NumGI/Loss/LossDataset.py +++ b/NumGI/Loss/LossDataset.py @@ -19,7 +19,8 @@ class LossDataset: def __init__(self, eq_dataset: DatasetTokenizer): self.eq_dataset = eq_dataset self.grid_size = (-1, 1, 1000) - self.max_integral_value = 10e10 # we can play with this value + self.grid_length = self.grid_size[1] - self.grid_size[0] + self.max_integral_value = 1e2 # we can play with this value self.var_dict = self.create_var_dict() def create_var_dict(self): @@ -50,6 +51,8 @@ def calculate_n_pairwise_loss(self, N, ell_norm): second_batch = N - first_batch for i in range(first_batch): chosen_symbols = random.choice(list(possible_symbols)) + if len(self.var_dict[chosen_symbols]) < 1: + continue possible_equations = {i[1] for i in self.var_dict[chosen_symbols]} idx_sympy_1, idx_sympy_2 = random.sample(possible_equations, 2) @@ -58,6 +61,8 @@ def calculate_n_pairwise_loss(self, N, ell_norm): integrand = sp.Abs(sol_sympy_1[0].rhs - sol_sympy_2[0].rhs) ** ell_norm integral = self.compute_integral(integrand) + if integral is None: + continue loss[0, i] = sol_sympy_1[1] loss[1, i] = sol_sympy_2[1] @@ -80,21 +85,30 @@ def calculate_n_pairwise_loss(self, N, ell_norm): def compute_integral(self, sympy_eq): func, symbols = self.eq_dataset.sympy_to_torch(sympy_eq) + + if len(symbols) < 1: + return torch.tensor(torch.nan) grids = self.create_discrete_grids(symbols) _arg = {sym: _grid for sym, _grid in zip(symbols, grids)} - complex_result = func(**_arg) + try: + complex_result = func(**_arg) + except Exception as e: + print(f"Error in sympy_to_torch {e}") + return None result = (complex_result * 
complex_result.conj()) ** 0.5 - result = torch.nanmean(result.real) + grid_size = self.grid_length / (grids[0].shape[0] ** (0.5)) + results_weighted = result.real * (grid_size ** (2 * len(symbols))) + result = torch.nansum(results_weighted) del grids return result def create_discrete_grids(self, symbols): grid_low, grid_high, num_grid = self.grid_size # scale grid down with dimesion - num_grid = int(num_grid * np.exp(-len(symbols))) + num_grid = int(num_grid * np.exp(-(1 - len(symbols)))) grid_real = torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device) - grid_im = torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device) + grid_im = 1j * torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device) - grid = torch.complex(grid_real, grid_im) + grid = grid_real[:, None] + grid_im[None, :] - grids = [grid for i in symbols] + grids = [grid.flatten() for i in symbols] mesh = torch.meshgrid(grids) return mesh