forked from PGelss/scikit_tt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_ulam.py
66 lines (47 loc) · 2.18 KB
/
test_ulam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# -*- coding: utf-8 -*-
import unittest as ut
from unittest import TestCase
import scikit_tt.data_driven.ulam as ulam
import numpy as np
import os
class TestPF(TestCase):
    """Compare the TT-format Ulam operators with dense reference tensors.

    Each test builds an operator with ``ulam.ulam_2d`` / ``ulam.ulam_3d``,
    rebuilds the same operator densely from the raw transition list, and
    checks that the summed absolute entrywise difference is below ``self.tol``.
    """

    @staticmethod
    def _load_transitions(directory, file_name):
        """Return the ``'transitions'`` array stored in ``examples/data/<file_name>``.

        ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that keeps the
        underlying file open, so it is used as a context manager; indexing it
        reads the member into memory before the handle is closed.
        """
        path = os.path.join(directory, 'examples', 'data', file_name)
        with np.load(path) as archive:
            return archive['transitions']

    @staticmethod
    def _dense_operator(transitions, boxes, normalization):
        """Build the dense reference operator from a 1-based transition list.

        Parameters
        ----------
        transitions : ndarray
            Integer array with ``2 * d`` rows, one column per transition. The
            first ``d`` rows hold the 1-based box indices of one point and the
            last ``d`` rows those of the other (presumably start and end box).
        boxes : sequence of int
            Number of boxes per dimension (length ``d``).
        normalization : number
            Every transition count is multiplied by ``1 / normalization``.

        Returns
        -------
        ndarray
            Tensor of shape ``(*boxes, *boxes)`` whose entry at
            ``[last-d-indices..., first-d-indices...]`` is the normalized count.
        """
        dimension = len(boxes)
        operator = np.zeros(list(boxes) * 2)
        # shift the stored 1-based box indices to 0-based array indices
        indices = np.asarray(transitions) - 1
        # the last d rows address the leading axes, the first d rows the trailing
        # ones; np.add.at accumulates repeated index tuples (plain fancy-index
        # `+=` would count each duplicate only once)
        np.add.at(operator, tuple(indices[dimension:]) + tuple(indices[:dimension]), 1)
        operator *= np.true_divide(1, normalization)
        return operator

    def setUp(self):
        """Consider triple- and quadruple-well potentials for testing."""
        # tolerance for the summed absolute error
        self.tol = 1e-10
        # the data lives under examples/data in the parent of this file's directory
        directory = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        self.transitions_2d = self._load_transitions(directory, 'triple_well_transitions.npz')
        self.transitions_3d = self._load_transitions(directory, 'quadruple_well_transitions.npz')
        # coarse-grain 3d data: ceil(k / 5) maps 1-based boxes 1..5 -> 1, 6..10 -> 2, ...
        self.transitions_3d = np.int64(np.ceil(np.true_divide(self.transitions_3d, 5)))

    def test_ulam_2d(self):
        """Test ulam_2d against a dense reference operator."""
        # NOTE(review): `normalization` is the third ulam_2d argument; presumably
        # the number of simulations per box — confirm against scikit_tt.
        boxes, normalization = (50, 50), 500
        # construct transition operator in TT format
        operator = ulam.ulam_2d(self.transitions_2d, list(boxes), normalization)
        # construct full operator
        operator_full = self._dense_operator(self.transitions_2d, boxes, normalization)
        # compute error and check that it is smaller than the tolerance
        error = np.abs(operator.full() - operator_full).sum()
        self.assertLess(error, self.tol)

    def test_ulam_3d(self):
        """Test ulam_3d against a dense reference operator."""
        # NOTE(review): see test_ulam_2d regarding the meaning of `normalization`.
        boxes, normalization = (5, 5, 5), 12500
        # construct transition operator in TT format
        operator = ulam.ulam_3d(self.transitions_3d, list(boxes), normalization)
        # construct full operator
        operator_full = self._dense_operator(self.transitions_3d, boxes, normalization)
        # compute error and check that it is smaller than the tolerance
        error = np.abs(operator.full() - operator_full).sum()
        self.assertLess(error, self.tol)
if __name__ == "__main__":
    # Run every test case in this module when the file is executed directly.
    ut.main()