-
Notifications
You must be signed in to change notification settings - Fork 2
/
ebs_main.py
34 lines (24 loc) · 1.08 KB
/
ebs_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import dill as pickle
import pandas as pd
import numpy as np
import torch
import os
from func import load_data
from epde_general import epde_equations
from bamt_general import bs_experiment
from solver_general import solver_equations
from func import confidence_region as conf_plt
if __name__ == '__main__':
tasks = {
'wave_equation': load_data.wave_equation,
'burgers_equation': load_data.burgers_equation,
'burgers_equation_small_grid': load_data.burgers_equation_small_grid,
'KdV_equation': load_data.KdV_equation
}
title = list(tasks.keys())[0] # name of the problem/equation
u, grid_u, derivatives, cfg, params, b_conds = tasks[title]()
for variance in cfg.params["global_config"]["variance_arr"]:
df_main, epde_search_obj = epde_equations(u, grid_u, derivatives, cfg, variance, title)
equations = bs_experiment(df_main, cfg, title)
u_main, grid_main = solver_equations(cfg, params, b_conds, equations, epde_search_obj, title)
conf_plt.confidence_region_print(u, cfg, params, u_main, grid_main, variance)