removed previous test
arnaudbergeron committed May 6, 2024
1 parent f540014 commit d1bef13
Showing 2 changed files with 96 additions and 88 deletions.
96 changes: 96 additions & 0 deletions NumGI/ParallelEquationGenerator.py
@@ -0,0 +1,96 @@
from __future__ import annotations

import multiprocessing as mp
import os

import sympy as sp
import torch

from NumGI.ConstantDictionaries import DIFFERENTIAL_FUNCTIONS
from NumGI.ConstantDictionaries import OPERATIONS
from NumGI.DatasetTokenizer import DatasetTokenizer
from NumGI.EquationTokenizer import EquationTokenizer
from NumGI.SolutionGenerator import SolutionGenerator


def worker(args):
    # args = [SolutionGenerator, *generate_solution_dataset arguments, output file index]
    sols = args[0].generate_solution_dataset(*args[1:-1])
    generate_tokenized_lists(sols, args[-1])
    return [args[-1]]


def generate_tokenized_lists(sols, num):
    """Generates tokenized lists of equations and solutions. Saves them to disk."""
    x = []
    y = []
    # Keep only (equation, solution) pairs whose solution is not trivially True/False.
    for i in sols:
        if not isinstance(i[1], sp.logic.boolalg.BooleanTrue) and not isinstance(
            i[1], sp.logic.boolalg.BooleanFalse
        ):
            x.append(i[0].doit())
            y.append(i[1])

    tok = EquationTokenizer()
    y_list = [tok.sympy_to_list(i) for i in y]
    x_list = [tok.sympy_to_list(i) for i in x]
    x_nozoo = []
    y_nozoo = []
    # Drop pairs that are too long or contain sympy's complex infinity ("zoo").
    for idx, i in enumerate(y_list):
        if len(i) < 200 and len(i) > 10:
            if "zoo" not in [str(j) for j in i]:
                try:
                    if len(x_list[idx]) < 100:
                        x_nozoo.append(x[idx])
                        y_nozoo.append(y[idx])
                except Exception as e:
                    print(e)
                    continue
    try:
        dataset = DatasetTokenizer(x_nozoo, y_nozoo, useDefaultTokenizer=True)
        dataset.device = "cpu"
        torch.save(dataset.x_tokenized.to("cpu"), f"data/x_var_6/x_{num}.pt")
        torch.save(dataset.y_tokenized.to("cpu"), f"data/x_var_6/y_{num}.pt")
    except KeyError as e:
        print(f"nan in dataset: {e}")
    except ValueError:
        print(len(x_nozoo))


def generate_eq_parallel(gen_args: list, path: str, num_thousands: int):
    """Generates equations in parallel.

    Some equations are discarded because they are too long, so this will not
    produce the exact number of expected equations.

    Args:
        path (str): path to save the equations to
        num_thousands (int): number of thousands of equations to generate
    """
    pool = mp.Pool(mp.cpu_count() - 1)
    shift = 0
    solgen = SolutionGenerator()

    # Resume numbering after the highest existing file index in the output directory.
    for i in os.listdir(path):
        new_i = (i.split("_")[1]).split(".")[0]
        shift = max(int(new_i), shift)

    shift += 1
    # Define the parameters for each call to generate_solution_dataset
    parameters = [([solgen] + gen_args + [shift + _]) for _ in range(num_thousands)]

    pool.map(worker, parameters)


if __name__ == "__main__":
diff_func = DIFFERENTIAL_FUNCTIONS
ops = OPERATIONS
vars = ["x"]
gen_args = [
(3, 4),
(3, 5),
1_000,
vars,
diff_func,
ops,
]
generate_eq_parallel(gen_args, "data/x_var_6", 10000)
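
For reference, a minimal sketch (not part of this commit) of how the shards written by generate_tokenized_lists could be loaded downstream; the shard index 1 is illustrative, and the data/x_var_6 directory matches the path hard-coded above.

import torch

# Load one shard of tokenized equations (x) and solutions (y) saved by the worker.
x_shard = torch.load("data/x_var_6/x_1.pt")
y_shard = torch.load("data/x_var_6/y_1.pt")
print(x_shard.shape, y_shard.shape)  # shapes depend on DatasetTokenizer padding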
88 changes: 0 additions & 88 deletions test/EquationTests/test_numpy_sympy_torch.py

This file was deleted.
