forked from lim271/MultiOutputIHGP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexample.py
65 lines (61 loc) · 2.07 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
from moihgp import MOIHGPOnlineLearning
from time import time
import matplotlib.pyplot as plt
# Time step between consecutive samples; used both to Euler-integrate the
# simulated velocities and as the first argument to MOIHGPOnlineLearning.
dt = 0.1
# Passed through to MOIHGPOnlineLearning(gamma=...); see moihgp for its meaning.
gamma = 0.9
# Passed through to MOIHGPOnlineLearning(windowsize=...); see moihgp for its meaning.
windowsize = 2


def simulate_trajectories(num_steps=20, noise=0.1, rng=None):
    """Simulate four noisy 2-D point trajectories by forward-Euler integration.

    Agents 0 and 1 share the base velocity (1.1, 0.9) and agents 2 and 3 share
    (-0.9, -1.1). Each agent adds its own sinusoidal offset (the same scalar on
    both components) plus Gaussian noise to that velocity at every step.

    Args:
        num_steps: number of integration steps; the result has num_steps + 1 rows.
        noise: standard deviation of the per-component velocity noise.
        rng: object with a ``standard_normal(size)`` method (e.g. a
            ``np.random.Generator``). Defaults to the global ``np.random`` state,
            which is what the original script drew from.

    Returns:
        Float array of shape (num_steps + 1, 8). Columns 2k and 2k+1 hold the
        coordinates of agent k, and row 0 holds the initial positions.
    """
    if rng is None:
        rng = np.random
    v1 = np.array([1.1, 0.9])
    v2 = np.array([-0.9, -1.1])
    # (initial position, base velocity, time-varying scalar velocity offset)
    agents = [
        (np.array([-1.1, -0.9]), v1, lambda t: 0.3 * np.sin(t)),
        (np.array([-0.9, -1.1]), v1, lambda t: 0.3 * np.cos(t)),
        (np.array([1.1, 0.9]), v2, lambda t: 0.3 * np.sin(0.3 * t)),
        (np.array([0.9, 1.1]), v2, lambda t: 0.3 * np.cos(0.3 * t)),
    ]
    tracks = [[start] for start, _, _ in agents]
    for t in range(num_steps):
        # Noise is drawn agent by agent within each step, preserving the
        # original script's random-draw order.
        for track, (_, base, offset) in zip(tracks, agents):
            velocity = base + offset(t) + noise * rng.standard_normal(2)
            track.append(track[-1] + velocity * dt)
    return np.hstack([np.asarray(track) for track in tracks])


def block_correlation(cov, block_size=2):
    """Return the trace-normalised correlation between square blocks of *cov*.

    *cov* is split into consecutive ``block_size x block_size`` blocks C_ij, and
    ``corr[i, j] = tr(C_ij) / sqrt(tr(C_ii) * tr(C_jj))``.

    The trace equals the sum of a block's eigenvalues, but it stays real for
    the non-symmetric off-diagonal blocks, whose eigenvalues may be complex.
    Trailing rows/columns that do not fill a whole block are ignored.

    Args:
        cov: square 2-D array-like.
        block_size: side length of each block (2 = one (x, y) pair per agent).

    Returns:
        Real array of shape (n_blocks, n_blocks).
    """
    cov = np.asarray(cov)
    num_blocks = cov.shape[0] // block_size
    size = num_blocks * block_size
    # blocks[a, k, b, l] == cov[a*block_size + k, b*block_size + l]
    blocks = cov[:size, :size].reshape(num_blocks, block_size, num_blocks, block_size)
    traces = np.trace(blocks, axis1=1, axis2=3)
    diag = np.diag(traces)
    # sqrt of the product (not product of sqrts) keeps the diagonal exactly 1.
    return traces / np.sqrt(np.outer(diag, diag))


def plot_results(data, yhat):
    """Scatter the observed trajectories and overlay the predicted ones.

    Both arrays are indexed as (time, 2 * agent + coordinate). All scatters are
    drawn before all lines, in agent order.
    """
    num_agents = data.shape[1] // 2
    plt.figure(1)
    for k in range(num_agents):
        plt.scatter(data[:, 2 * k], data[:, 2 * k + 1])
    for k in range(num_agents):
        plt.plot(yhat[:, 2 * k], yhat[:, 2 * k + 1])
    plt.show()


def main():
    """Run the example end to end.

    Simulates the trajectories and feeds them row by row to
    MOIHGPOnlineLearning.step, printing the wall time of each step. It then
    prints the block correlation of ``gp.covariance`` and plots the results.
    """
    from time import perf_counter  # monotonic clock intended for intervals

    data = simulate_trajectories()
    _, num_output = data.shape
    num_latent = num_output // 2
    gp = MOIHGPOnlineLearning(
        dt, num_output, num_latent,
        gamma=gamma, windowsize=windowsize, threading=False
    )
    yhat = []
    for y in data:
        tic = perf_counter()
        yhat.append(gp.step(y))
        toc = perf_counter()
        print("elapsed time per step:", toc - tic)
    # NOTE(review): assumes gp.covariance is a (num_output, num_output) array,
    # as the block slicing in the original script implied.
    print(block_correlation(gp.covariance))
    # Plotting assumes each gp.step(y) returns num_output values.
    plot_results(data, np.array(yhat))


if __name__ == '__main__':
    main()