@@ -140,26 +140,32 @@ def floordiv(a, b):
     return aten.div.Tensor_mode(a, b, rounding_mode="floor")


-def get_padded_length(x):
-    if x % config.alignment_size == 0:
+def get_alignment_size(x):
+    if x.dtype == torch.float16 or x.dtype == torch.half or x.dtype == torch.bfloat16:
+        return 8
+    elif x.dtype == torch.float32 or x.dtype == torch.float:
+        return 4
+    else:
+        return 0
+
+
+def check_device(a: Tensor, b: Tensor):
+    return a.is_cuda and b.is_cuda
+
+
+def get_padded_length(x, alignment_size):
+    if alignment_size == 0 or x % alignment_size == 0:
         return 0
-    return int((x // config.alignment_size + 1) * config.alignment_size) - x
+    return int((x // alignment_size + 1) * alignment_size) - x


 def pad_dim(x, padded_length, dim):
+    if padded_length == 0:
+        return x
     pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
     return torch.cat([x, pad], dim=dim)


-def check_device_dtype(a: Tensor, b: Tensor):
-    return (
-        a.is_cuda
-        and b.is_cuda
-        and a.dtype in (torch.float32, torch.float16, torch.bfloat16)
-        and b.dtype in (torch.float32, torch.float16, torch.bfloat16)
-    )
-
-
 @register_decomposition([aten.addmm])
 def addmm(input, mat1, mat2, *, beta=1, alpha=1):
     if config.triton.mm != "aten":
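
Note (illustrative sketch, not part of the patch): the three helpers above are meant to compose like this; the shape, dtype, and device below are made up for the example.

# Assumes a CUDA build of PyTorch and the helpers defined in this file.
x = torch.randn(30, 1024, device="cuda", dtype=torch.float16)
align = get_alignment_size(x)                # 8 for fp16/bf16, 4 for fp32, 0 otherwise (skip padding)
pad = get_padded_length(x.shape[0], align)   # next multiple of 8 above 30 is 32, so pad == 2
x_padded = pad_dim(x, pad, 0)                # zero-pads dim 0; returns x unchanged when pad == 0
assert x_padded.shape == (32, 1024)
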
@@ -172,57 +178,59 @@ def addmm(input, mat1, mat2, *, beta=1, alpha=1):

     if (
         config.shape_padding
-        and check_device_dtype(mat1, mat2)
+        and check_device(mat1, mat2)
         and should_pad_bench(mat1, mat2, torch.ops.aten.addmm, input=input)
     ):
-        m_padded_length = get_padded_length(mat1.shape[0])
-        k_padded_length = get_padded_length(mat1.shape[1])
-        n_padded_length = get_padded_length(mat2.shape[1])
-
-        if k_padded_length != 0:
-            mat1 = pad_dim(mat1, k_padded_length, 1)
-            mat2 = pad_dim(mat2, k_padded_length, 0)
-        elif m_padded_length != 0:
-            mat1 = pad_dim(mat1, m_padded_length, 0)
-        elif n_padded_length != 0:
-            mat2 = pad_dim(mat2, n_padded_length, 1)
-
-        if input is not None and k_padded_length == 0:
-            if m_padded_length != 0 and input.dim() == 2:
-                input = pad_dim(input, m_padded_length, 0)
-            elif n_padded_length != 0:
-                if input.dim() == 2:
-                    input = pad_dim(input, n_padded_length, 1)
-                elif input.dim() == 1:
-                    input = pad_dim(input, n_padded_length, 0)
-
-        if k_padded_length != 0:
-            return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
-        elif m_padded_length != 0:
-            return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)[
-                :-m_padded_length, :
-            ]
-        elif n_padded_length != 0:
-            return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)[
-                :, :-n_padded_length
-            ]
+        m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
+        k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+        n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
+        if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0:
+            return pad_addmm(
+                input, mat1, mat2, m_padded_length, k_padded_length, n_padded_length, beta=beta, alpha=alpha
+            )

     return NotImplemented  # go directly to lowering


+def pad_addmm(input, mat1, mat2, m_padded_length, k_padded_length, n_padded_length, beta=1, alpha=1):
+    if k_padded_length != 0:
+        mat1 = pad_dim(mat1, k_padded_length, 1)
+        mat2 = pad_dim(mat2, k_padded_length, 0)
+    elif n_padded_length != 0:
+        mat2 = pad_dim(mat2, n_padded_length, 1)
+    elif m_padded_length != 0:
+        mat1 = pad_dim(mat1, m_padded_length, 0)
+
+    if input is not None and k_padded_length == 0:
+        if n_padded_length != 0:
+            if input.dim() == 2:
+                input = pad_dim(input, n_padded_length, 1)
+            elif input.dim() == 1:
+                input = pad_dim(input, n_padded_length, 0)
+        elif m_padded_length != 0 and input.dim() == 2:
+            input = pad_dim(input, m_padded_length, 0)
+
+    if k_padded_length != 0:
+        return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
+    elif n_padded_length != 0:
+        return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)[:, :-n_padded_length]
+    else:
+        return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)[:-m_padded_length, :]
+
+
 def should_pad_bench(mat1, mat2, op, input=None):
     assert utils.has_triton()
     from triton.testing import do_bench

     with no_dispatch():
         if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
-            m_padded_length = get_padded_length(mat1.shape[0])
-            k_padded_length = get_padded_length(mat1.shape[1])
-            n_padded_length = get_padded_length(mat2.shape[1])
+            m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
+            k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+            n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
         elif op is torch.ops.aten.bmm:
-            m_padded_length = get_padded_length(mat1.shape[1])
-            k_padded_length = get_padded_length(mat1.shape[2])
-            n_padded_length = get_padded_length(mat2.shape[2])
+            m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+            k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
+            n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
         else:
             return False

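
Note (illustrative, not part of the patch): whichever branch pad_addmm takes, it should return the same values as the unpadded op; a quick spot-check with made-up fp32 shapes:

# Assumes CUDA; 30 is unaligned for fp32 (alignment 4), so K gets padded to 32 here.
bias = torch.randn(30, device="cuda")
a = torch.randn(30, 30, device="cuda")
b = torch.randn(30, 30, device="cuda")
out = pad_addmm(bias, a, b, m_padded_length=2, k_padded_length=2, n_padded_length=2)
torch.testing.assert_close(out, torch.addmm(bias, a, b))  # zero padding along K cancels out
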
@@ -244,85 +252,123 @@ def should_pad_bench(mat1, mat2, op, input=None):
             lambda: op(input, mat1, mat2), warmup=warmup, rep=rep, fast_flush=True
         )[0]

-        mat1_pad = mat1.new_empty([get_padded_length(i) + i for i in mat1.shape])
-        mat2_pad = mat2.new_empty([get_padded_length(i) + i for i in mat2.shape])
+        mat1_pad = torch.randn_like(mat1)
+        mat2_pad = torch.randn_like(mat2)
+
         if op is torch.ops.aten.addmm:
             input_pad = None
-            if input is not None and input.is_cuda and input.dtype == torch.float32:
-                input_pad = input.new_empty(
-                    [get_padded_length(i) + i for i in input.shape]
-                )
+            if input is not None and input.is_cuda:
+                input_pad = torch.randn_like(input)
+            pad_time = do_bench(
+                lambda: pad_addmm(
+                    input_pad,
+                    mat1_pad,
+                    mat2_pad,
+                    m_padded_length,
+                    k_padded_length,
+                    n_padded_length,
+                ),
+                warmup=warmup,
+                rep=rep,
+                fast_flush=True,
+            )[0]
+        elif op is torch.ops.aten.mm:
             pad_time = do_bench(
-                lambda: op(input_pad, mat1_pad, mat2_pad),
+                lambda: pad_mm(
+                    mat1_pad,
+                    mat2_pad,
+                    m_padded_length,
+                    k_padded_length,
+                    n_padded_length,
+                ),
                 warmup=warmup,
                 rep=rep,
                 fast_flush=True,
             )[0]
         else:
+            if k_padded_length == 0 and not config.shape_padding_bmm:
+                return False
             pad_time = do_bench(
-                lambda: op(mat1_pad, mat2_pad), warmup=warmup, rep=rep, fast_flush=True
+                lambda: pad_bmm(
+                    mat1_pad,
+                    mat2_pad,
+                    m_padded_length,
+                    k_padded_length,
+                    n_padded_length,
+                ),
+                warmup=warmup,
+                rep=rep,
+                fast_flush=True,
             )[0]

-        # Shape padding introduces addtional memory ops. Based on microbenchmarks, 1.3x for
-        # aten.mm and aten.addmm and 2x for aten.bmm represent a reasonable tradeoff between
-        # performance improvement from shape padding and overhead from addtional memory ops
+        # Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable
+        # tradeoff between performance improvement from shape padding and overhead from additional memory ops
         # TODO: Build a learned model which would be better than this heuristic
-        if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
-            return ori_time > pad_time * 1.3
-        else:
-            return ori_time > pad_time * 2
+        return ori_time > pad_time * 1.1


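Note (illustrative): with the single 1.1x threshold above, padding is chosen only when the padded benchmark beats the original op by at least roughly 10%, e.g.:

# Hypothetical timings (ms) as returned by triton's do_bench; not measured values.
ori_time, pad_time = 0.95, 0.82
should_pad = ori_time > pad_time * 1.1   # 0.95 > 0.902 -> True, so padding pays off
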
 @register_decomposition([aten.mm])
 def mm_decomp(mat1, mat2):
     if (
         config.shape_padding
-        and check_device_dtype(mat1, mat2)
+        and check_device(mat1, mat2)
         and should_pad_bench(mat1, mat2, torch.ops.aten.mm)
     ):
-        m_padded_length = get_padded_length(mat1.shape[0])
-        k_padded_length = get_padded_length(mat1.shape[1])
-        n_padded_length = get_padded_length(mat2.shape[1])
-
-        if k_padded_length != 0:
-            mat1 = pad_dim(mat1, k_padded_length, 1)
-            mat2 = pad_dim(mat2, k_padded_length, 0)
-            return torch.ops.aten.mm(mat1, mat2)
-        elif m_padded_length != 0:
-            mat1 = pad_dim(mat1, m_padded_length, 0)
-            return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :]
-        elif n_padded_length != 0:
-            mat2 = pad_dim(mat2, n_padded_length, 1)
-            return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length]
+        m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
+        k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+        n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
+
+        if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0:
+            return pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length)

     return NotImplemented  # go directly to lowering


+def pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
+    if k_padded_length != 0:
+        mat1 = pad_dim(mat1, k_padded_length, 1)
+        mat2 = pad_dim(mat2, k_padded_length, 0)
+        return torch.ops.aten.mm(mat1, mat2)
+    elif n_padded_length != 0:
+        mat2 = pad_dim(mat2, n_padded_length, 1)
+        return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length]
+    else:
+        mat1 = pad_dim(mat1, m_padded_length, 0)
+        return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :]
+
+
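Note (illustrative sketch, not part of the patch): pad_mm pads at most one dimension per call and slices the padded rows/columns back off the result; a made-up example of the M-only case:

# Assumes CUDA and fp32; M=30 is unaligned (alignment 4), K and N are already aligned.
a = torch.randn(30, 32, device="cuda")
b = torch.randn(32, 32, device="cuda")
out = pad_mm(a, b, m_padded_length=2, k_padded_length=0, n_padded_length=0)
assert out.shape == (30, 32)              # padded rows are sliced off again
torch.testing.assert_close(out, a @ b)
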
 @register_decomposition([aten.bmm])
 def bmm_decomp(mat1, mat2):
     if (
         config.shape_padding
-        and check_device_dtype(mat1, mat2)
+        and check_device(mat1, mat2)
         and should_pad_bench(mat1, mat2, torch.ops.aten.bmm)
     ):
-        m_padded_length = get_padded_length(mat1.shape[1])
-        k_padded_length = get_padded_length(mat1.shape[2])
-        n_padded_length = get_padded_length(mat2.shape[2])
-
-        if k_padded_length != 0:
-            mat1 = pad_dim(mat1, k_padded_length, 2)
-            mat2 = pad_dim(mat2, k_padded_length, 1)
-            return torch.ops.aten.bmm(mat1, mat2)
-        elif m_padded_length != 0:
-            mat1 = pad_dim(mat1, m_padded_length, 1)
-            return torch.ops.aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous()
-        elif n_padded_length != 0:
-            mat2 = pad_dim(mat2, n_padded_length, 2)
-            return torch.ops.aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous()
+        m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+        k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
+        n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
+
+        if k_padded_length != 0 or (
+            config.shape_padding_bmm and (n_padded_length != 0 or m_padded_length != 0)
+        ):
+            return pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length)

     return NotImplemented  # go directly to lowering


+def pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
+    if k_padded_length != 0:
+        mat1 = pad_dim(mat1, k_padded_length, 2)
+        mat2 = pad_dim(mat2, k_padded_length, 1)
+        return torch.ops.aten.bmm(mat1, mat2)
+    elif config.shape_padding_bmm and n_padded_length != 0:
+        mat2 = pad_dim(mat2, n_padded_length, 2)
+        return torch.ops.aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous()
+    else:
+        mat1 = pad_dim(mat1, m_padded_length, 1)
+        return torch.ops.aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous()
+
+
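Note (illustrative sketch, not part of the patch): the batched path behaves the same way; with made-up fp32 shapes where only K is unaligned:

# Assumes CUDA; K=30 pads to 32 on both operands, so no output slicing is needed.
a = torch.randn(8, 64, 30, device="cuda")
b = torch.randn(8, 30, 64, device="cuda")
out = pad_bmm(a, b, m_padded_length=0, k_padded_length=2, n_padded_length=0)
torch.testing.assert_close(out, torch.bmm(a, b))
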
 @register_decomposition([aten.convolution_backward])
 def convolution_backward(
     grad_output,