From fbe0039a05676d53f489c50335568957b64712d0 Mon Sep 17 00:00:00 2001 From: Arnaud Bergeron Date: Fri, 12 Jan 2024 18:01:34 -0500 Subject: [PATCH] Added decaying exp scaling to step size in integral to remove oom errors --- NumGI/Loss/LossDataset.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/NumGI/Loss/LossDataset.py b/NumGI/Loss/LossDataset.py index 3ee44a6..e9af741 100644 --- a/NumGI/Loss/LossDataset.py +++ b/NumGI/Loss/LossDataset.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math import random import sympy as sp @@ -15,12 +16,11 @@ class LossDataset: DatasetTokenizer (DatasetTokenizer): DatasetTokenizer to create loss dataset from. """ - def __init__(self, eq_dataset: DatasetTokenizer, N: int, ell_norm: int = 1): + def __init__(self, eq_dataset: DatasetTokenizer): self.eq_dataset = eq_dataset - self.grid_size = (100, 100, 1000) + self.grid_size = (-100, 100, 1000) self.max_integral_value = 10e10 # we can play with this value self.var_dict = self.create_var_dict() - self.loss = self.calculate_n_pairwise_loss(N, ell_norm) def create_var_dict(self): """Creates a dictionary of different variables and their corresponding equations. @@ -50,7 +50,6 @@ def calculate_n_pairwise_loss(self, N, ell_norm): second_batch = N - first_batch for i in range(first_batch): chosen_symbols = random.choice(list(possible_symbols)) - possible_equations = {i[1] for i in self.var_dict[chosen_symbols]} idx_sympy_1, idx_sympy_2 = random.sample(possible_equations, 2) @@ -76,7 +75,7 @@ def calculate_n_pairwise_loss(self, N, ell_norm): loss[1, i] = sol_sympy_2[1] loss[2, i] = torch.inf - return loss + self.loss = loss def compute_integral(self, sympy_eq): func, symbols = self.eq_dataset.sympy_to_torch(sympy_eq) @@ -87,7 +86,10 @@ def compute_integral(self, sympy_eq): return result def create_discrete_grids(self, symbols): - grid = torch.linspace(*self.grid_size, device=self.eq_dataset.device) + grid_low, grid_high, num_grid = self.grid_size + # scale grid down with dimesion + num_grid = int(num_grid * math.exp(-len(symbols))) + grid = torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device) grids = [grid for i in symbols] mesh = torch.meshgrid(grids) return mesh