@@ -38,7 +38,7 @@ class ConvDescriptor(Structure):
38
38
infiniopConvDescriptor_t = POINTER (ConvDescriptor )
39
39
40
40
41
- def conv (x , w , stride , padding , dilation ):
41
+ def conv (x , w , b , stride , padding , dilation ):
42
42
ndim = len (x .shape ) - 2
43
43
conv_func_map = {
44
44
1 : F .conv1d ,
@@ -54,10 +54,10 @@ def conv(x, w, stride, padding, dilation):
54
54
conv_func = conv_func_map [ndim ]
55
55
56
56
if PROFILE :
57
- ans = conv_func (x , w , stride = stride , padding = padding , dilation = dilation )
57
+ ans = conv_func (x , w , b , stride = stride , padding = padding , dilation = dilation )
58
58
torch .cuda .synchronize ()
59
59
return ans
60
- return conv_func (x , w , stride = stride , padding = padding , dilation = dilation )
60
+ return conv_func (x , w , b , stride = stride , padding = padding , dilation = dilation )
61
61
62
62
63
63
# infer the shape of the output given the inputs for a N-ary convolution
@@ -98,31 +98,34 @@ def test(
98
98
pads ,
99
99
strides ,
100
100
dilations ,
101
- tensor_stride = None ,
101
+ add_bias ,
102
102
tensor_dtype = torch .float16 ,
103
103
):
104
104
assert len (pads ) == len (strides ) == len (dilations )
105
105
print (
106
- f"Testing Conv on { torch_device } with x_shape: { x_shape } , w_shape: { w_shape } , b_shape: { w_shape [0 ]} , pads: { pads } , strides: { strides } , dilations: { dilations } , x_stride: { tensor_stride } dtype:{ tensor_dtype } "
106
+ f"Testing Conv on { torch_device } with x_shape: { x_shape } , w_shape: { w_shape } , add_bias: { add_bias } , "
107
+ f"b_shape: { w_shape [0 ]} , pads: { pads } , strides: { strides } , dilations: { dilations } , dtype:{ tensor_dtype } "
107
108
)
108
109
x = torch .rand (x_shape , dtype = tensor_dtype ).to (torch_device )
109
110
w = torch .rand (w_shape , dtype = tensor_dtype ).to (torch_device )
111
+ b = torch .round ((torch .rand (w_shape [0 ], dtype = tensor_dtype ).to (torch_device ) * 2 - 1 ) * 1000 ) / 1000 if add_bias else None
110
112
y = torch .zeros (
111
113
inferShape (x .shape , w .shape , pads , strides , dilations ), dtype = tensor_dtype
112
114
).to (torch_device )
113
115
114
116
for i in range (NUM_PRERUN if PROFILE else 1 ):
115
- ans = conv (x , w , strides , pads , dilations )
117
+ ans = conv (x , w , b , strides , pads , dilations )
116
118
if PROFILE :
117
119
start_time = time .time ()
118
120
for i in range (NUM_ITERATIONS ):
119
- _ = conv (x , w , strides , pads , dilations )
121
+ _ = conv (x , w , b , strides , pads , dilations )
120
122
elapsed = (time .time () - start_time ) / NUM_ITERATIONS
121
123
print (f"pytorch time: { elapsed :6f} " )
122
124
123
125
124
126
x_tensor = to_tensor (x , lib )
125
127
w_tensor = to_tensor (w , lib )
128
+ b_tensor = to_tensor (b , lib ) if b is not None else None
126
129
y_tensor = to_tensor (y , lib )
127
130
descriptor = infiniopConvDescriptor_t ()
128
131
@@ -133,6 +136,7 @@ def test(
133
136
y_tensor .descriptor ,
134
137
x_tensor .descriptor ,
135
138
w_tensor .descriptor ,
139
+ b_tensor .descriptor if b_tensor else None ,
136
140
tuple_to_void_p (pads ),
137
141
tuple_to_void_p (strides ),
138
142
tuple_to_void_p (dilations ),
@@ -147,27 +151,33 @@ def test(
147
151
workspace_ptr = ctypes .cast (workspace .data_ptr (), ctypes .POINTER (ctypes .c_uint8 ))
148
152
149
153
for i in range (NUM_PRERUN if PROFILE else 1 ):
150
- lib .infiniopConv (
151
- descriptor ,
152
- workspace_ptr ,
153
- workspaceSize ,
154
- y_tensor .data ,
155
- x_tensor .data ,
156
- w_tensor .data ,
157
- None ,
158
- )
159
- if PROFILE :
160
- start_time = time .time ()
161
- for i in range (NUM_ITERATIONS ):
154
+ check_error (
162
155
lib .infiniopConv (
163
156
descriptor ,
164
157
workspace_ptr ,
165
158
workspaceSize ,
166
159
y_tensor .data ,
167
160
x_tensor .data ,
168
161
w_tensor .data ,
162
+ b_tensor .data if b_tensor else None ,
169
163
None ,
170
164
)
165
+ )
166
+ if PROFILE :
167
+ start_time = time .time ()
168
+ for i in range (NUM_ITERATIONS ):
169
+ check_error (
170
+ lib .infiniopConv (
171
+ descriptor ,
172
+ workspace_ptr ,
173
+ workspaceSize ,
174
+ y_tensor .data ,
175
+ x_tensor .data ,
176
+ w_tensor .data ,
177
+ b_tensor .data if b_tensor else None ,
178
+ None ,
179
+ )
180
+ )
171
181
elapsed = (time .time () - start_time ) / NUM_ITERATIONS
172
182
print (f" lib time: { elapsed :6f} " )
173
183
@@ -181,18 +191,18 @@ def test(
181
191
def test_cpu (lib , test_cases ):
182
192
device = DeviceEnum .DEVICE_CPU
183
193
handle = create_handle (lib , device )
184
- for x_shape , w_shape , pads , strides , dilations , x_strides in test_cases :
185
- test (lib , handle , "cpu" , x_shape , w_shape , pads , strides , dilations , x_strides , tensor_dtype = torch .float16 )
186
- test (lib , handle , "cpu" , x_shape , w_shape , pads , strides , dilations , x_strides , tensor_dtype = torch .float32 )
194
+ for x_shape , w_shape , pads , strides , dilations , add_bias in test_cases :
195
+ test (lib , handle , "cpu" , x_shape , w_shape , pads , strides , dilations , add_bias , tensor_dtype = torch .float16 )
196
+ test (lib , handle , "cpu" , x_shape , w_shape , pads , strides , dilations , add_bias , tensor_dtype = torch .float32 )
187
197
destroy_handle (lib , handle )
188
198
189
199
190
200
def test_cuda (lib , test_cases ):
191
201
device = DeviceEnum .DEVICE_CUDA
192
202
handle = create_handle (lib , device )
193
- for x_shape , w_shape , pads , strides , dilations , x_strides in test_cases :
194
- test (lib , handle , "cuda" , x_shape , w_shape , pads , strides , dilations , x_strides , tensor_dtype = torch .float16 )
195
- test (lib , handle , "cuda" , x_shape , w_shape , pads , strides , dilations , x_strides , tensor_dtype = torch .float32 )
203
+ for x_shape , w_shape , pads , strides , dilations , add_bias in test_cases :
204
+ test (lib , handle , "cuda" , x_shape , w_shape , pads , strides , dilations , add_bias , tensor_dtype = torch .float16 )
205
+ test (lib , handle , "cuda" , x_shape , w_shape , pads , strides , dilations , add_bias , tensor_dtype = torch .float32 )
196
206
destroy_handle (lib , handle )
197
207
198
208
@@ -201,54 +211,62 @@ def test_bang(lib, test_cases):
201
211
202
212
device = DeviceEnum .DEVICE_BANG
203
213
handle = create_handle (lib , device )
204
- for x_shape , w_shape , pads , strides , dilations , x_strides in test_cases :
205
- test (lib , handle , "mlu" , x_shape , w_shape , pads , strides , dilations , x_strides , tensor_dtype = torch .float16 )
206
- test (lib , handle , "mlu" , x_shape , w_shape , pads , strides , dilations , x_strides , tensor_dtype = torch .float32 )
214
+ for x_shape , w_shape , pads , strides , dilations , add_bias in test_cases :
215
+ test (lib , handle , "mlu" , x_shape , w_shape , pads , strides , dilations , add_bias , tensor_dtype = torch .float16 )
216
+ test (lib , handle , "mlu" , x_shape , w_shape , pads , strides , dilations , add_bias , tensor_dtype = torch .float32 )
207
217
destroy_handle (lib , handle )
208
218
209
219
210
220
if __name__ == "__main__" :
211
221
test_cases = [
212
- # x_shape, w_shape, pads, strides, dilations, x_strides
222
+ # x_shape, w_shape, pads, strides, dilations, add_bias
213
223
(
214
224
(32 , 3 , 4 ),
215
225
(32 , 3 , 5 ),
216
226
(1 ,),
217
227
(1 ,),
218
228
(1 ,),
219
- None ,
229
+ False ,
230
+ ),
231
+ (
232
+ (3 , 7 , 4 ),
233
+ (3 , 7 , 5 ),
234
+ (1 ,),
235
+ (1 ,),
236
+ (1 ,),
237
+ True ,
220
238
),
221
239
(
222
240
(1 , 3 , 4 , 4 ),
223
241
(2 , 3 , 3 , 3 ),
224
242
(1 , 1 ),
225
243
(1 , 2 ),
226
244
(2 , 1 ),
227
- None ,
245
+ True ,
228
246
),
229
247
(
230
248
(32 , 3 , 128 , 128 ),
231
249
(64 , 3 , 5 , 5 ),
232
250
(2 , 2 ),
233
251
(2 , 2 ),
234
252
(1 , 1 ),
235
- None ,
253
+ False ,
236
254
),
237
255
(
238
256
(1 , 1 , 4 , 4 , 4 ),
239
257
(1 , 1 , 5 , 5 , 5 ),
240
258
(1 , 1 , 1 ),
241
259
(1 , 1 , 1 ),
242
260
(1 , 1 , 1 ),
243
- None ,
261
+ True ,
244
262
),
245
263
(
246
264
(32 , 3 , 32 , 32 , 32 ),
247
265
(64 , 3 , 5 , 5 , 5 ),
248
266
(3 , 2 , 2 ),
249
267
(4 , 3 , 3 ),
250
268
(2 , 2 , 1 ),
251
- None ,
269
+ False ,
252
270
),
253
271
]
254
272
args = get_args ()
@@ -260,6 +278,7 @@ def test_bang(lib, test_cases):
260
278
infiniopTensorDescriptor_t ,
261
279
infiniopTensorDescriptor_t ,
262
280
infiniopTensorDescriptor_t ,
281
+ infiniopTensorDescriptor_t ,
263
282
c_void_p ,
264
283
c_void_p ,
265
284
c_void_p ,
@@ -274,6 +293,7 @@ def test_bang(lib, test_cases):
274
293
c_void_p ,
275
294
c_void_p ,
276
295
c_void_p ,
296
+ c_void_p ,
277
297
]
278
298
lib .infiniopDestroyConvDescriptor .restype = c_int32
279
299
lib .infiniopDestroyConvDescriptor .argtypes = [
0 commit comments