Skip to content

Commit

Permalink
Added decaying exp scaling to step size in integral to remove oom errors
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudbergeron committed Jan 12, 2024
1 parent ecb3687 commit fbe0039
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions NumGI/Loss/LossDataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import math
import random

import sympy as sp
Expand All @@ -15,12 +16,11 @@ class LossDataset:
DatasetTokenizer (DatasetTokenizer): DatasetTokenizer to create loss dataset from.
"""

def __init__(self, eq_dataset: DatasetTokenizer, N: int, ell_norm: int = 1):
def __init__(self, eq_dataset: DatasetTokenizer):
self.eq_dataset = eq_dataset
self.grid_size = (100, 100, 1000)
self.grid_size = (-100, 100, 1000)
self.max_integral_value = 10e10 # we can play with this value
self.var_dict = self.create_var_dict()
self.loss = self.calculate_n_pairwise_loss(N, ell_norm)

def create_var_dict(self):
"""Creates a dictionary of different variables and their corresponding equations.
Expand Down Expand Up @@ -50,7 +50,6 @@ def calculate_n_pairwise_loss(self, N, ell_norm):
second_batch = N - first_batch
for i in range(first_batch):
chosen_symbols = random.choice(list(possible_symbols))

possible_equations = {i[1] for i in self.var_dict[chosen_symbols]}

idx_sympy_1, idx_sympy_2 = random.sample(possible_equations, 2)
Expand All @@ -76,7 +75,7 @@ def calculate_n_pairwise_loss(self, N, ell_norm):
loss[1, i] = sol_sympy_2[1]
loss[2, i] = torch.inf

return loss
self.loss = loss

def compute_integral(self, sympy_eq):
func, symbols = self.eq_dataset.sympy_to_torch(sympy_eq)
Expand All @@ -87,7 +86,10 @@ def compute_integral(self, sympy_eq):
return result

def create_discrete_grids(self, symbols):
grid = torch.linspace(*self.grid_size, device=self.eq_dataset.device)
grid_low, grid_high, num_grid = self.grid_size
# scale grid down with dimesion
num_grid = int(num_grid * math.exp(-len(symbols)))
grid = torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device)
grids = [grid for i in symbols]
mesh = torch.meshgrid(grids)
return mesh

0 comments on commit fbe0039

Please sign in to comment.