
Commit 4a1633c

jiawenliu64 authored and pytorchmergebot committed
[Inductor] GEMM Shape Padding Optimization (pytorch#90425)
Summary: Optimize shape padding from the following perspectives:
- Add BFloat16 support for AMP training and Float16 support for inference
- Optimize the microbenchmark to avoid peak-memory issues, and include profiling of the memory ops so the padding decision is more accurate
- Add a flag to turn padding of dims N and M in `torch.bmm` on/off, since the `.contiguous()` memory copy it requires is expensive and can cause peak-memory issues in internal models

Test Plan: CI

Differential Revision: D41724868

Pull Request resolved: pytorch#90425
Approved by: https://github.com/jianyuh
1 parent b7dfbf8 commit 4a1633c
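
For context, the feature is driven by two environment variables used by the config change below. A minimal sketch of how one might exercise the padding path (an assumed workflow, not part of this commit; it presumes a CUDA device and the TorchDynamo/Inductor entry point):

import os

# The flags are read when torch._inductor.config is imported, so set them first.
os.environ["TORCHINDUCTOR_SHAPE_PADDING"] = "1"      # consider padding mm/addmm/bmm inputs
os.environ["TORCHINDUCTOR_SHAPE_PADDING_BMM"] = "1"  # also consider padding N and M of bmm

import torch
import torch._dynamo


def f(a, b):
    return torch.mm(a, b)


compiled = torch._dynamo.optimize("inductor")(f)
a = torch.randn(1022, 1022, device="cuda", dtype=torch.float16)
b = torch.randn(1022, 1022, device="cuda", dtype=torch.float16)
out = compiled(a, b)  # 1022 may be padded up to 1024 if the benchmark says it pays off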

2 files changed: +144, -96 lines

torch/_inductor/config.py

Lines changed: 3 additions & 1 deletion
@@ -92,7 +92,9 @@ def is_fbcode():
 
 # Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
 shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "0") == "1"
-alignment_size = 4
+
+# Pad input tensors in dimension N and M of bmm to leverage Tensor Cores in NVIDIA GPUs
+shape_padding_bmm = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING_BMM", "1") == "1"
 
 # Fx-based linear/matmul/bmm + permute/transpose vertical fusion
 permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"
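
The same switches can also be flipped programmatically; a brief sketch (note the defaults above: shape_padding is off unless TORCHINDUCTOR_SHAPE_PADDING=1, while shape_padding_bmm defaults to on):

import torch._inductor.config as inductor_config

inductor_config.shape_padding = True      # consider padding mm/addmm/bmm inputs
inductor_config.shape_padding_bmm = True  # also allow padding dims N and M of bmm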

torch/_inductor/decomposition.py

Lines changed: 141 additions & 95 deletions
@@ -140,26 +140,32 @@ def floordiv(a, b):
     return aten.div.Tensor_mode(a, b, rounding_mode="floor")


-def get_padded_length(x):
-    if x % config.alignment_size == 0:
+def get_alignment_size(x):
+    if x.dtype == torch.float16 or x.dtype == torch.half or x.dtype == torch.bfloat16:
+        return 8
+    elif x.dtype == torch.float32 or x.dtype == torch.float:
+        return 4
+    else:
+        return 0
+
+
+def check_device(a: Tensor, b: Tensor):
+    return a.is_cuda and b.is_cuda
+
+
+def get_padded_length(x, alignment_size):
+    if alignment_size == 0 or x % alignment_size == 0:
         return 0
-    return int((x // config.alignment_size + 1) * config.alignment_size) - x
+    return int((x // alignment_size + 1) * alignment_size) - x


 def pad_dim(x, padded_length, dim):
+    if padded_length == 0:
+        return x
     pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
     return torch.cat([x, pad], dim=dim)


-def check_device_dtype(a: Tensor, b: Tensor):
-    return (
-        a.is_cuda
-        and b.is_cuda
-        and a.dtype in (torch.float32, torch.float16, torch.bfloat16)
-        and b.dtype in (torch.float32, torch.float16, torch.bfloat16)
-    )
-
-
 @register_decomposition([aten.addmm])
 def addmm(input, mat1, mat2, *, beta=1, alpha=1):
     if config.triton.mm != "aten":
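
For reference, the padding arithmetic introduced above: fp16/bf16 operands are aligned to multiples of 8 and fp32 to multiples of 4, with get_padded_length returning how many zero rows/columns pad_dim should append. A standalone sketch mirroring those helpers (illustrative only, not an import from torch._inductor):

import torch


def get_alignment_size(x: torch.Tensor) -> int:
    # 8-element alignment for half/bfloat16, 4 for float32, 0 (never pad) otherwise
    if x.dtype in (torch.float16, torch.bfloat16):
        return 8
    elif x.dtype == torch.float32:
        return 4
    return 0


def get_padded_length(dim: int, alignment: int) -> int:
    # number of elements to append so that dim becomes a multiple of alignment
    if alignment == 0 or dim % alignment == 0:
        return 0
    return (dim // alignment + 1) * alignment - dim


x = torch.randn(1022, 30, dtype=torch.float16)
align = get_alignment_size(x)                # 8 for float16
print(get_padded_length(x.shape[0], align))  # 2  (1022 -> 1024)
print(get_padded_length(x.shape[1], align))  # 2  (30 -> 32)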
@@ -172,57 +178,59 @@ def addmm(input, mat1, mat2, *, beta=1, alpha=1):

     if (
         config.shape_padding
-        and check_device_dtype(mat1, mat2)
+        and check_device(mat1, mat2)
         and should_pad_bench(mat1, mat2, torch.ops.aten.addmm, input=input)
     ):
-        m_padded_length = get_padded_length(mat1.shape[0])
-        k_padded_length = get_padded_length(mat1.shape[1])
-        n_padded_length = get_padded_length(mat2.shape[1])
-
-        if k_padded_length != 0:
-            mat1 = pad_dim(mat1, k_padded_length, 1)
-            mat2 = pad_dim(mat2, k_padded_length, 0)
-        elif m_padded_length != 0:
-            mat1 = pad_dim(mat1, m_padded_length, 0)
-        elif n_padded_length != 0:
-            mat2 = pad_dim(mat2, n_padded_length, 1)
-
-        if input is not None and k_padded_length == 0:
-            if m_padded_length != 0 and input.dim() == 2:
-                input = pad_dim(input, m_padded_length, 0)
-            elif n_padded_length != 0:
-                if input.dim() == 2:
-                    input = pad_dim(input, n_padded_length, 1)
-                elif input.dim() == 1:
-                    input = pad_dim(input, n_padded_length, 0)
-
-        if k_padded_length != 0:
-            return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
-        elif m_padded_length != 0:
-            return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)[
-                :-m_padded_length, :
-            ]
-        elif n_padded_length != 0:
-            return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)[
-                :, :-n_padded_length
-            ]
+        m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
+        k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+        n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
+        if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0:
+            return pad_addmm(
+                input, mat1, mat2, m_padded_length, n_padded_length, k_padded_length
+            )

     return NotImplemented  # go directly to lowering


+def pad_addmm(input, mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
+    if k_padded_length != 0:
+        mat1 = pad_dim(mat1, k_padded_length, 1)
+        mat2 = pad_dim(mat2, k_padded_length, 0)
+    elif n_padded_length != 0:
+        mat2 = pad_dim(mat2, n_padded_length, 1)
+    elif m_padded_length != 0:
+        mat1 = pad_dim(mat1, m_padded_length, 0)
+
+    if input is not None and k_padded_length == 0:
+        if n_padded_length != 0:
+            if input.dim() == 2:
+                input = pad_dim(input, n_padded_length, 1)
+            elif input.dim() == 1:
+                input = pad_dim(input, n_padded_length, 0)
+        elif m_padded_length != 0 and input.dim() == 2:
+            input = pad_dim(input, m_padded_length, 0)
+
+    if k_padded_length != 0:
+        return torch.ops.aten.addmm(input, mat1, mat2)
+    elif n_padded_length != 0:
+        return torch.ops.aten.addmm(input, mat1, mat2)[:, :-n_padded_length]
+    else:
+        return torch.ops.aten.addmm(input, mat1, mat2)[:-m_padded_length, :]
+
+
 def should_pad_bench(mat1, mat2, op, input=None):
     assert utils.has_triton()
     from triton.testing import do_bench

     with no_dispatch():
         if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
-            m_padded_length = get_padded_length(mat1.shape[0])
-            k_padded_length = get_padded_length(mat1.shape[1])
-            n_padded_length = get_padded_length(mat2.shape[1])
+            m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
+            k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+            n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
         elif op is torch.ops.aten.bmm:
-            m_padded_length = get_padded_length(mat1.shape[1])
-            k_padded_length = get_padded_length(mat1.shape[2])
-            n_padded_length = get_padded_length(mat2.shape[2])
+            m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+            k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
+            n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
         else:
            return False

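A note on why pad_addmm (and pad_mm below) can return the aten result unsliced when only K is padded: zeros appended to the shared K dimension contribute nothing to the product, whereas padding M or N grows the output, so the extra rows/columns must be sliced off. A small illustrative check (not part of the commit):

import torch

mat1 = torch.randn(5, 6)
mat2 = torch.randn(6, 7)

# Pad K on both operands with zeros: the product is unchanged, no slicing needed.
k_pad = 2
mat1_k = torch.cat([mat1, mat1.new_zeros(5, k_pad)], dim=1)
mat2_k = torch.cat([mat2, mat2.new_zeros(k_pad, 7)], dim=0)
assert torch.allclose(mat1 @ mat2, mat1_k @ mat2_k, atol=1e-6)

# Pad N on mat2: the output gains n_pad extra columns that must be sliced away.
n_pad = 3
mat2_n = torch.cat([mat2, mat2.new_zeros(6, n_pad)], dim=1)
assert torch.allclose(mat1 @ mat2, (mat1 @ mat2_n)[:, :-n_pad], atol=1e-6)
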
@@ -244,85 +252,123 @@ def should_pad_bench(mat1, mat2, op, input=None):
             lambda: op(input, mat1, mat2), warmup=warmup, rep=rep, fast_flush=True
         )[0]

-        mat1_pad = mat1.new_empty([get_padded_length(i) + i for i in mat1.shape])
-        mat2_pad = mat2.new_empty([get_padded_length(i) + i for i in mat2.shape])
+        mat1_pad = torch.randn_like(mat1)
+        mat2_pad = torch.randn_like(mat2)
+
         if op is torch.ops.aten.addmm:
             input_pad = None
-            if input is not None and input.is_cuda and input.dtype == torch.float32:
-                input_pad = input.new_empty(
-                    [get_padded_length(i) + i for i in input.shape]
-                )
+            if input is not None and input.is_cuda:
+                input_pad = torch.randn_like(input)
+            pad_time = do_bench(
+                lambda: pad_addmm(
+                    input_pad,
+                    mat1_pad,
+                    mat2_pad,
+                    m_padded_length,
+                    k_padded_length,
+                    n_padded_length,
+                ),
+                warmup=warmup,
+                rep=rep,
+                fast_flush=True,
+            )[0]
+        elif op is torch.ops.aten.mm:
             pad_time = do_bench(
-                lambda: op(input_pad, mat1_pad, mat2_pad),
+                lambda: pad_mm(
+                    mat1_pad,
+                    mat2_pad,
+                    m_padded_length,
+                    k_padded_length,
+                    n_padded_length,
+                ),
                 warmup=warmup,
                 rep=rep,
                 fast_flush=True,
             )[0]
         else:
+            if k_padded_length == 0 and not config.shape_padding_bmm:
+                return False
             pad_time = do_bench(
-                lambda: op(mat1_pad, mat2_pad), warmup=warmup, rep=rep, fast_flush=True
+                lambda: pad_bmm(
+                    mat1_pad,
+                    mat2_pad,
+                    m_padded_length,
+                    k_padded_length,
+                    n_padded_length,
+                ),
+                warmup=warmup,
+                rep=rep,
+                fast_flush=True,
             )[0]

-        # Shape padding introduces addtional memory ops. Based on microbenchmarks, 1.3x for
-        # aten.mm and aten.addmm and 2x for aten.bmm represent a reasonable tradeoff between
-        # performance improvement from shape padding and overhead from addtional memory ops
+        # Shape padding introduces addtional memory ops. Based on microbenchmarks, 1.1x represents a reasonable
+        # tradeoff between performance improvement from shape padding and overhead from addtional memory ops
         # TODO: Build a learned model which would be better than this heuristic
-        if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
-            return ori_time > pad_time * 1.3
-        else:
-            return ori_time > pad_time * 2
+        return ori_time > pad_time * 1.1


 @register_decomposition([aten.mm])
 def mm_decomp(mat1, mat2):
     if (
         config.shape_padding
-        and check_device_dtype(mat1, mat2)
+        and check_device(mat1, mat2)
         and should_pad_bench(mat1, mat2, torch.ops.aten.mm)
     ):
-        m_padded_length = get_padded_length(mat1.shape[0])
-        k_padded_length = get_padded_length(mat1.shape[1])
-        n_padded_length = get_padded_length(mat2.shape[1])
-
-        if k_padded_length != 0:
-            mat1 = pad_dim(mat1, k_padded_length, 1)
-            mat2 = pad_dim(mat2, k_padded_length, 0)
-            return torch.ops.aten.mm(mat1, mat2)
-        elif m_padded_length != 0:
-            mat1 = pad_dim(mat1, m_padded_length, 0)
-            return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :]
-        elif n_padded_length != 0:
-            mat2 = pad_dim(mat2, n_padded_length, 1)
-            return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length]
+        m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
+        k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+        n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
+
+        if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0:
+            return pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length)

     return NotImplemented  # go directly to lowering


+def pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
+    if k_padded_length != 0:
+        mat1 = pad_dim(mat1, k_padded_length, 1)
+        mat2 = pad_dim(mat2, k_padded_length, 0)
+        return torch.ops.aten.mm(mat1, mat2)
+    elif n_padded_length != 0:
+        mat2 = pad_dim(mat2, n_padded_length, 1)
+        return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length]
+    else:
+        mat1 = pad_dim(mat1, m_padded_length, 0)
+        return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :]
+
+
 @register_decomposition([aten.bmm])
 def bmm_decomp(mat1, mat2):
     if (
         config.shape_padding
-        and check_device_dtype(mat1, mat2)
+        and check_device(mat1, mat2)
         and should_pad_bench(mat1, mat2, torch.ops.aten.bmm)
     ):
-        m_padded_length = get_padded_length(mat1.shape[1])
-        k_padded_length = get_padded_length(mat1.shape[2])
-        n_padded_length = get_padded_length(mat2.shape[2])
-
-        if k_padded_length != 0:
-            mat1 = pad_dim(mat1, k_padded_length, 2)
-            mat2 = pad_dim(mat2, k_padded_length, 1)
-            return torch.ops.aten.bmm(mat1, mat2)
-        elif m_padded_length != 0:
-            mat1 = pad_dim(mat1, m_padded_length, 1)
-            return torch.ops.aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous()
-        elif n_padded_length != 0:
-            mat2 = pad_dim(mat2, n_padded_length, 2)
-            return torch.ops.aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous()
+        m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
+        k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
+        n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
+
+        if k_padded_length != 0 or (
+            config.shape_padding_bmm and (n_padded_length != 0 or m_padded_length != 0)
+        ):
+            pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length)

     return NotImplemented  # go directly to lowering


+def pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
+    if k_padded_length != 0:
+        mat1 = pad_dim(mat1, k_padded_length, 2)
+        mat2 = pad_dim(mat2, k_padded_length, 1)
+        return torch.ops.aten.bmm(mat1, mat2)
+    elif config.shape_padding_bmm and n_padded_length != 0:
+        mat2 = pad_dim(mat2, n_padded_length, 2)
+        return torch.ops.aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous()
+    else:
+        mat1 = pad_dim(mat1, m_padded_length, 1)
+        return torch.ops.aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous()
+
+
 @register_decomposition([aten.convolution_backward])
 def convolution_backward(
     grad_output,