Skip to content

Commit

Permalink
added integration scope to complex nbs
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudbergeron committed Jan 31, 2024
1 parent fbe0039 commit e5aa1f8
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions NumGI/Loss/LossDataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import math
import random

import numpy as np
import sympy as sp
import torch

Expand All @@ -18,7 +18,7 @@ class LossDataset:

def __init__(self, eq_dataset: DatasetTokenizer):
self.eq_dataset = eq_dataset
self.grid_size = (-100, 100, 1000)
self.grid_size = (-1, 1, 1000)
self.max_integral_value = 10e10 # we can play with this value
self.var_dict = self.create_var_dict()

Expand Down Expand Up @@ -61,10 +61,11 @@ def calculate_n_pairwise_loss(self, N, ell_norm):

loss[0, i] = sol_sympy_1[1]
loss[1, i] = sol_sympy_2[1]
if integral < self.max_integral_value:
integral_val = integral.item()
if np.abs(integral_val) < self.max_integral_value:
loss[2, i] = integral.item()
else:
loss[2, i] = torch.inf
loss[2, i] = np.sign(integral_val) * self.max_integral_value

for i in range(second_batch):
chosen_symbols = random.sample(possible_symbols, 2)
Expand All @@ -81,15 +82,19 @@ def compute_integral(self, sympy_eq):
func, symbols = self.eq_dataset.sympy_to_torch(sympy_eq)
grids = self.create_discrete_grids(symbols)
_arg = {sym: _grid for sym, _grid in zip(symbols, grids)}
result = torch.mean(func(**_arg))
complex_result = func(**_arg)
result = (complex_result * complex_result.conj()) ** 0.5
result = torch.nanmean(result.real)
del grids
return result

def create_discrete_grids(self, symbols):
grid_low, grid_high, num_grid = self.grid_size
# scale grid down with dimesion
num_grid = int(num_grid * math.exp(-len(symbols)))
grid = torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device)
num_grid = int(num_grid * np.exp(-len(symbols)))
grid_real = torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device)
grid_im = torch.linspace(grid_low, grid_high, num_grid, device=self.eq_dataset.device)
grid = torch.complex(grid_real, grid_im)
grids = [grid for i in symbols]
mesh = torch.meshgrid(grids)
return mesh

0 comments on commit e5aa1f8

Please sign in to comment.