-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunit_test.py
111 lines (76 loc) · 3.08 KB
/
unit_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse, os, sys
import tensorflow as tf
import numpy as np
def parse_arg():
    """Parse the command line for this test script.

    Accepts a single optional ``--target`` option restricted to
    ``"rgb_to_lab"`` or ``"rgb_to_xyz"`` (default ``"rgb_to_lab"``).
    The parsed namespace is echoed to stdout before being returned.

    Returns:
        argparse.Namespace: namespace with a ``target`` attribute.
    """
    available = ["rgb_to_lab", "rgb_to_xyz"]
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--target",
        choices=available,
        default=available[0],
        help="Test targets",
    )
    parsed = cli.parse_args()
    print(parsed)
    return parsed
def test_rgb_to_lab(iterations=10000, batch_size=16, tolerance=1e-5):
    """Compare ``tf_rgb_to_lab`` against scikit-image's ``rgb2lab``.

    Feeds a float32 placeholder of shape ``[batch_size, 64, 64, 3]``
    through ``tf_rgb_to_lab(..., srgb=True)``. For ``iterations`` rounds it
    draws uniformly random values in ``[0, 1)``. It then compares the
    TensorFlow result with ``skimage.color.rgb2lab`` applied image by image.

    Args:
        iterations: Number of random batches to compare.
        batch_size: Number of images per batch.
        tolerance: A batch passes when the mean squared difference between
            the two implementations is strictly below this value.

    On the first failing batch it prints the MSE and the top-left 3x3 pixels
    of the first image from both implementations. It then exits the process
    with status 1.
    """
    # Import the submodule explicitly; a bare `import skimage` may not
    # expose `skimage.color` depending on the installed version.
    import skimage.color
    from tf_utils.image_utils import tf_rgb_to_lab

    h, w, c = 64, 64, 3
    rgb_placeholder = tf.placeholder(tf.float32, shape=[batch_size, h, w, c],
                                     name="rgb_placeholder")
    tf_lab_tensor = tf_rgb_to_lab(rgb_placeholder, srgb=True)
    with tf.Session() as sess:
        for i in range(iterations):
            rgb_images = np.random.random((batch_size, h, w, c)).astype(np.float32)
            ski_lab = np.asarray([skimage.color.rgb2lab(img) for img in rgb_images])
            tf_lab = sess.run(tf_lab_tensor, feed_dict={rgb_placeholder: rgb_images})
            mse = np.mean(np.square(ski_lab - tf_lab))
            if mse < tolerance:  # results equal
                sys.stdout.write("\r>> Checked %d/%d " % (i + 1, iterations))
                sys.stdout.flush()
            else:
                print("\nError : %f \nToo bad." % (mse))
                print(ski_lab[0, :3, :3, :])
                print(tf_lab[0, :3, :3, :])
                # quit() is meant for the interactive shell and exits with
                # status 0; use a non-zero status so callers see the failure.
                sys.exit(1)
    print("\nDone.")
def test_rgb_to_xyz(iterations=10000, batch_size=16, tolerance=1e-5):
    """Compare ``tf_rgb_to_xyz`` against scikit-image's ``rgb2xyz``.

    Feeds a float32 placeholder of shape ``[batch_size, 64, 64, 3]``
    through ``tf_rgb_to_xyz``. For ``iterations`` rounds it draws uniformly
    random values in ``[0, 1)``. It then compares the TensorFlow result with
    ``skimage.color.rgb2xyz`` applied image by image.

    Args:
        iterations: Number of random batches to compare.
        batch_size: Number of images per batch.
        tolerance: A batch passes when the mean squared difference between
            the two implementations is strictly below this value.

    On the first failing batch it prints the MSE and the top-left 3x3 pixels
    of the first image from both implementations. It then exits the process
    with status 1.
    """
    # Import the submodule explicitly; a bare `import skimage` may not
    # expose `skimage.color` depending on the installed version.
    import skimage.color
    from tf_utils.image_utils import tf_rgb_to_xyz

    h, w, c = 64, 64, 3
    rgb_placeholder = tf.placeholder(tf.float32, shape=[batch_size, h, w, c],
                                     name="rgb_placeholder")
    tf_xyz_tensor = tf_rgb_to_xyz(rgb_placeholder)
    with tf.Session() as sess:
        for i in range(iterations):
            rgb_images = np.random.random((batch_size, h, w, c)).astype(np.float32)
            ski_xyz = np.asarray([skimage.color.rgb2xyz(img) for img in rgb_images])
            tf_xyz = sess.run(tf_xyz_tensor, feed_dict={rgb_placeholder: rgb_images})
            mse = np.mean(np.square(ski_xyz - tf_xyz))
            if mse < tolerance:  # results equal
                sys.stdout.write("\r>> Checked %d/%d " % (i + 1, iterations))
                sys.stdout.flush()
            else:
                print("\nError : %f \nToo bad." % (mse))
                # Print sample values, as the Lab test does, to aid debugging.
                print(ski_xyz[0, :3, :3, :])
                print(tf_xyz[0, :3, :3, :])
                # quit() is meant for the interactive shell and exits with
                # status 0; use a non-zero status so callers see the failure.
                sys.exit(1)
    print("\nDone.")
def show_on_tb(dir_name="show_graph_logs"):
    """Write the current default TensorFlow graph to ``dir_name``.

    Opens a session and runs the global and then the local variable
    initializers. It writes the session graph with a
    ``tf.summary.FileWriter``, which is closed immediately. Finally it prints
    a shell command that starts TensorBoard on ``dir_name`` at port 32424.
    """
    with tf.Session() as session:
        global_init = tf.global_variables_initializer()
        session.run(global_init)
        local_init = tf.local_variables_initializer()
        session.run(local_init)
        graph_writer = tf.summary.FileWriter(dir_name, session.graph)
        graph_writer.close()
    command = "python3 -m tensorboard.main --logdir=" + dir_name + " --port=32424"
    print(command)
if __name__ == "__main__":
    # Run the test function selected by --target, e.g. "rgb_to_lab" ->
    # test_rgb_to_lab().
    args = parse_arg()
    print("args.target =", args.target)
    target_func = "test_" + args.target
    test_fn = globals().get(target_func)
    if callable(test_fn):
        test_fn()
    else:
        print("Function %s NOT FOUND" % (target_func))
        # quit() is intended for the interactive shell and exits with
        # status 0; report a missing target with a non-zero exit status.
        sys.exit(1)