-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_tests.py
70 lines (61 loc) · 1.46 KB
/
run_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
from engine import Value
def test_sanity_check():
    """Check the forward value and gradient of a small expression against PyTorch.

    Builds ``y`` from ``x = -4.0`` using addition, multiplication and
    ``tanh``. This is done once with ``Value`` and once with a float64 torch
    tensor. Both graphs are backpropagated, and the test asserts that ``y``
    and ``dy/dx`` agree within an absolute tolerance.
    """
    # Absolute tolerance, matching the one used by test_more_ops.
    tol = 1e-6

    # micrograd graph
    x = Value(-4.0)
    z = 2 * x + 2 + x
    q = z.tanh() + z * x
    h = (z * z).tanh()
    y = h + q + q * x
    y.backward()
    xmg, ymg = x, y

    # PyTorch reference graph, in float64 to match Python floats.
    x = torch.Tensor([-4.0]).double()
    x.requires_grad = True
    z = 2 * x + 2 + x
    q = z.tanh() + z * x
    h = (z * z).tanh()
    y = h + q + q * x
    y.backward()
    xpt, ypt = x, y

    # Value.tanh and torch.tanh are separate implementations, and their
    # results are not guaranteed to match bit-for-bit. Compare with a
    # tolerance rather than exact equality.
    # forward pass went well
    assert abs(ymg.data - ypt.data.item()) < tol, (
        f"forward mismatch: {ymg.data} vs {ypt.data.item()}"
    )
    # backward pass went well
    assert abs(xmg.grad - xpt.grad.item()) < tol, (
        f"gradient mismatch: {xmg.grad} vs {xpt.grad.item()}"
    )
def test_more_ops():
    """Compare a longer expression graph against a PyTorch reference.

    The same sequence of operations (add, mul, pow, negation, subtraction,
    division, tanh and accumulation into existing names) is evaluated from
    ``a = -4.0`` and ``b = 2.0``. It is run once with ``Value`` and once with
    float64 torch tensors. Both graphs are backpropagated from ``g``. The
    output ``g`` and the gradients with respect to ``a`` and ``b`` must agree
    within an absolute tolerance of 1e-6.
    """
    # micrograd graph
    a = Value(-4.0)
    b = Value(2.0)
    c = a + b
    d = a * b + b**3
    c += c + 1
    c += 1 + c + (-a)
    d += d * 2 + (b + a).tanh()
    d += 3 * d + (b - a).tanh()
    e = c - d
    f = e**2
    g = f / 2.0
    g += 10.0 / f
    g.backward()

    # PyTorch reference graph. The leaves are float64 tensors that require grad.
    pa = torch.tensor([-4.0], dtype=torch.double, requires_grad=True)
    pb = torch.tensor([2.0], dtype=torch.double, requires_grad=True)
    pc = pa + pb
    pd = pa * pb + pb**3
    pc = pc + pc + 1
    pc = pc + 1 + pc + (-pa)
    pd = pd + pd * 2 + (pb + pa).tanh()
    pd = pd + 3 * pd + (pb - pa).tanh()
    pe = pc - pd
    pf = pe**2
    pg = pf / 2.0
    pg = pg + 10.0 / pf
    pg.backward()

    tol = 1e-6
    comparisons = (
        (g.data, pg.data.item()),  # forward pass
        (a.grad, pa.grad.item()),  # backward pass: dg/da
        (b.grad, pb.grad.item()),  # backward pass: dg/db
    )
    for mine, reference in comparisons:
        assert abs(mine - reference) < tol
if __name__ == "__main__":
    # Run the checks only when executed as a script (python run_tests.py).
    # Importing this module no longer triggers them as a side effect.
    test_sanity_check()
    test_more_ops()