3
3
from __future__ import print_function
4
4
from __future__ import unicode_literals
5
5
6
+ import unittest
7
+
6
8
from caffe2 .python import core
7
9
from hypothesis import given
8
10
import hypothesis .strategies as st
@@ -44,11 +46,11 @@ def _tensor_splits(draw, add_axis=False):
44
46
45
47
46
48
class TestFunctional (hu .HypothesisTestCase ):
47
- @given (X = hu .tensor (), engine = st .sampled_from (["" , "CUDNN" ]))
48
- def test_relu (self , X , engine ):
49
+ @given (X = hu .tensor (), engine = st .sampled_from (["" , "CUDNN" ]), ** hu . gcs )
50
+ def test_relu (self , X , engine , gc , dc ):
49
51
X += 0.02 * np .sign (X )
50
52
X [X == 0.0 ] += 0.02
51
- output = Functional .Relu (X )
53
+ output = Functional .Relu (X , device_option = gc )
52
54
Y_l = output [0 ]
53
55
Y_d = output ["output_0" ]
54
56
@@ -66,11 +68,11 @@ def test_relu(self, X, engine):
66
68
Y_d , Y_ref , err_msg = 'Functional Relu result mismatch'
67
69
)
68
70
69
- @given (tensor_splits = _tensor_splits ())
70
- def test_concat (self , tensor_splits ):
71
+ @given (tensor_splits = _tensor_splits (), ** hu . gcs )
72
+ def test_concat (self , tensor_splits , gc , dc ):
71
73
# Input Size: 1 -> inf
72
74
axis , _ , splits = tensor_splits
73
- concat_result , split_info = Functional .Concat (* splits , axis = axis )
75
+ concat_result , split_info = Functional .Concat (* splits , axis = axis , device_option = gc )
74
76
75
77
concat_result_ref = np .concatenate (splits , axis = axis )
76
78
split_info_ref = np .array ([a .shape [axis ] for a in splits ])
@@ -87,8 +89,8 @@ def test_concat(self, tensor_splits):
87
89
err_msg = 'Functional Concat split info mismatch'
88
90
)
89
91
90
- @given (tensor_splits = _tensor_splits (), split_as_arg = st .booleans ())
91
- def test_split (self , tensor_splits , split_as_arg ):
92
+ @given (tensor_splits = _tensor_splits (), split_as_arg = st .booleans (), ** hu . gcs )
93
+ def test_split (self , tensor_splits , split_as_arg , gc , dc ):
92
94
# Output Size: 1 - inf
93
95
axis , split_info , splits = tensor_splits
94
96
@@ -100,7 +102,7 @@ def test_split(self, tensor_splits, split_as_arg):
100
102
else :
101
103
input_tensors = [np .concatenate (splits , axis = axis ), split_info ]
102
104
kwargs = dict (axis = axis , num_output = len (splits ))
103
- result = Functional .Split (* input_tensors , ** kwargs )
105
+ result = Functional .Split (* input_tensors , device_option = gc , ** kwargs )
104
106
105
107
def split_ref (input , split = split_info ):
106
108
s = np .cumsum ([0 ] + list (split ))
@@ -114,3 +116,7 @@ def split_ref(input, split=split_info):
114
116
np .testing .assert_array_equal (
115
117
result [i ], ref , err_msg = 'Functional Relu result mismatch'
116
118
)
119
+
120
+
121
+ if __name__ == '__main__' :
122
+ unittest .main ()
0 commit comments