This repository has been archived by the owner on Aug 31, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 187
/
test_zero_even_op.py
120 lines (98 loc) · 4.2 KB
/
test_zero_even_op.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
##############################################################
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import unittest
from caffe2.proto import caffe2_pb2
from caffe2.python import core
from caffe2.python import workspace
import utils.c2
class ZeroEvenOpTest(unittest.TestCase):
    """Unit tests for the custom ``ZeroEven`` Caffe2 operator.

    Every check runs once on the default device (``_run_zero_even_op``) and
    once on CUDA device 0 (``_run_zero_even_op_gpu``). The behavior pinned
    down here: for a 1-D float32 input ``X`` the op produces ``Y`` with
    ``Y[0::2] == 0`` and ``Y[1::2] == X[1::2]``; an input that is not 1-D
    makes the op raise ``RuntimeError`` whose message matches
    ``X.ndim() == 1``.
    """

    # Raw string: '\.' and '\(' are regex escapes, not Python string escapes
    # (as a plain string literal they are invalid escape sequences and warn
    # on Python 3).
    _NON_1D_ERROR_RE = r'X\.ndim\(\) == 1'

    # ------------------------------------------------------------------
    # Operator runners
    # ------------------------------------------------------------------

    def _run_zero_even_op(self, X):
        """Run ZeroEven once on ``X`` under the current device scope.

        Feeds ``X`` into workspace blob 'X', runs the op writing blob 'Y',
        and returns the fetched 'Y'.
        """
        op = core.CreateOperator('ZeroEven', ['X'], ['Y'])
        workspace.FeedBlob('X', X)
        workspace.RunOperatorOnce(op)
        return workspace.FetchBlob('Y')

    def _run_zero_even_op_gpu(self, X):
        """Run ZeroEven once on ``X`` on CUDA device 0 and return blob 'Y'.

        Operator creation, feeding and fetching all happen inside the CUDA
        ``DeviceScope``. When no CUDA device is usable, this raises
        ``unittest.SkipTest`` so the calling test is reported as skipped
        rather than failed.
        """
        if not self._gpu_available():
            self.skipTest('ZeroEven GPU test requires a CUDA device')
        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, 0)):
            return self._run_zero_even_op(X)

    @staticmethod
    def _gpu_available():
        """Return True if this Caffe2 build has GPU support and sees a device."""
        # getattr: tolerate Caffe2 builds that do not expose the flag at all.
        if not getattr(workspace, 'has_gpu_support', False):
            return False
        return workspace.NumCudaDevices() > 0

    # ------------------------------------------------------------------
    # Assertion helpers
    # ------------------------------------------------------------------

    def _assert_raises_regex(self, exc_type, pattern):
        """Return an assertRaisesRegex context manager on Python 2 and 3.

        Python 2.7 only provides ``assertRaisesRegexp``; Python 3 renamed
        it to ``assertRaisesRegex`` and removed the old alias in 3.12.
        """
        assert_fn = getattr(self, 'assertRaisesRegex', None)
        if assert_fn is None:
            assert_fn = self.assertRaisesRegexp
        return assert_fn(exc_type, pattern)

    @staticmethod
    def _zero_even_reference(X):
        """NumPy reference result: a copy of ``X`` with even indices zeroed."""
        Y = np.copy(X)
        Y[0::2] = 0.0
        return Y

    # ------------------------------------------------------------------
    # Device-independent checks; ``run_op`` maps an input array to the
    # op's output (one of the runners above).
    # ------------------------------------------------------------------

    def _check_throws_on_non_1D_arrays(self, run_op):
        X = np.zeros((2, 2), dtype=np.float32)
        with self._assert_raises_regex(RuntimeError, self._NON_1D_ERROR_RE):
            run_op(X)

    def _check_handles_empty_arrays(self, run_op):
        X = np.array([], dtype=np.float32)
        np.testing.assert_allclose(run_op(X), np.copy(X))

    def _check_sets_vals_at_even_inds_to_zero(self, run_op):
        # No input value is already zero, so every even index must actually
        # be written by the op for this check to pass.
        X = np.array([1, 2, 3, 4, 5], dtype=np.float32)
        Y_exp = np.array([0, 2, 0, 4, 0], dtype=np.float32)
        np.testing.assert_allclose(run_op(X)[0::2], Y_exp[0::2])

    def _check_preserves_vals_at_odd_inds(self, run_op):
        X = np.array([1, 2, 3, 4, 5], dtype=np.float32)
        Y_exp = np.array([0, 2, 0, 4, 0], dtype=np.float32)
        np.testing.assert_allclose(run_op(X)[1::2], Y_exp[1::2])

    def _check_handles_even_length_arrays(self, run_op):
        # Fixed seed so any failure is reproducible from run to run.
        X = np.random.RandomState(0).rand(64).astype(np.float32)
        np.testing.assert_allclose(run_op(X), self._zero_even_reference(X))

    def _check_handles_odd_length_arrays(self, run_op):
        X = np.random.RandomState(1).randn(77).astype(np.float32)
        np.testing.assert_allclose(run_op(X), self._zero_even_reference(X))

    # ------------------------------------------------------------------
    # CPU / default-device tests
    # ------------------------------------------------------------------

    def test_throws_on_non_1D_arrays(self):
        self._check_throws_on_non_1D_arrays(self._run_zero_even_op)

    def test_handles_empty_arrays(self):
        self._check_handles_empty_arrays(self._run_zero_even_op)

    def test_sets_vals_at_even_inds_to_zero(self):
        self._check_sets_vals_at_even_inds_to_zero(self._run_zero_even_op)

    def test_preserves_vals_at_odd_inds(self):
        self._check_preserves_vals_at_odd_inds(self._run_zero_even_op)

    def test_handles_even_length_arrays(self):
        self._check_handles_even_length_arrays(self._run_zero_even_op)

    def test_handles_odd_length_arrays(self):
        self._check_handles_odd_length_arrays(self._run_zero_even_op)

    # ------------------------------------------------------------------
    # GPU tests (skipped when no CUDA device is available)
    # ------------------------------------------------------------------

    def test_gpu_throws_on_non_1D_arrays(self):
        self._check_throws_on_non_1D_arrays(self._run_zero_even_op_gpu)

    def test_gpu_handles_empty_arrays(self):
        self._check_handles_empty_arrays(self._run_zero_even_op_gpu)

    def test_gpu_sets_vals_at_even_inds_to_zero(self):
        self._check_sets_vals_at_even_inds_to_zero(self._run_zero_even_op_gpu)

    def test_gpu_preserves_vals_at_odd_inds(self):
        self._check_preserves_vals_at_odd_inds(self._run_zero_even_op_gpu)

    def test_gpu_handles_even_length_arrays(self):
        self._check_handles_even_length_arrays(self._run_zero_even_op_gpu)

    def test_gpu_handles_odd_length_arrays(self):
        self._check_handles_odd_length_arrays(self._run_zero_even_op_gpu)
if __name__ == '__main__':
    # Initialize Caffe2's global state (flags such as the log level) before
    # any operator is created.
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
    # ZeroEven is a custom op; presumably import_custom_ops() loads the
    # library that registers it (confirm in utils/c2.py).
    utils.c2.import_custom_ops()
    # Explicit check rather than `assert`, so it still runs under `python -O`
    # and fails with an actionable message before any test executes.
    if 'ZeroEven' not in workspace.RegisteredOperators():
        raise RuntimeError(
            'ZeroEven operator is not registered; '
            'utils.c2.import_custom_ops() did not load it'
        )
    unittest.main()