Skip to content

Commit 531c724

Browse files
authored
Add block sizes for Qwen/Qwen2.5-32B-Instruct (#9516)
1 parent: 2820f7c; commit: 531c724

File tree

2 files changed: 142 additions (+142) and 50 deletions (-50).

torch_xla/experimental/custom_kernel.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import os
33
import math
44
import warnings
5+
import logging
56

67
import torch
78
from torch.library import impl, custom_op
@@ -1097,6 +1098,9 @@ def quantized_matmul_int8(
10971098
"out_block_size": out_block_size,
10981099
"in_block_size": in_block_size,
10991100
})
1101+
logging.warning(
1102+
f"Couldn't find w8a8 quantized matmul kernel block sizes for {bs=}, {n_out_features=}, {n_in_features=}, {jnp.dtype(jax_dtype).name=}, {quantize_activation=}, falling back to XLA quantized matmul kernel."
1103+
)
11001104
from torch_xla.experimental.xla_quantized_matmul import quantized_matmul_xla
11011105
return quantized_matmul_xla(
11021106
x, w, scalar, quantize_activation=quantize_activation).to(x.dtype)

torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py

Lines changed: 138 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -324,70 +324,158 @@ def quantized_matmul_int8(
324324
# - out_block_size
325325
# - in_block_size
326326
TUNED_BLOCK_SIZES = {
327-
(6, 128, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
328-
(6, 128, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
329-
(6, 2048, 6144, 4096, 'bfloat16', True): (2048, 512, 4096),
330-
(6, 2048, 4096, 4096, 'bfloat16', True): (2048, 512, 4096),
331-
(6, 2048, 4096, 14336, 'bfloat16', True): (2048, 4096, 512),
332-
(6, 128, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
333-
(6, 128, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
334-
(6, 2048, 28672, 4096, 'bfloat16', True): (2048, 1024, 4096),
335-
(6, 16, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
336-
(6, 16, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
337-
(6, 64, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
338-
(6, 64, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
339-
(6, 256, 6144, 4096, 'bfloat16', True): (256, 512, 4096),
340-
(6, 256, 4096, 4096, 'bfloat16', True): (256, 512, 4096),
341-
(6, 256, 28672, 4096, 'bfloat16', True): (256, 2048, 4096),
342-
(6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 512),
343-
(6, 16, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
344-
(6, 512, 6144, 4096, 'bfloat16', True): (512, 1024, 4096),
345-
(6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096),
346-
(6, 512, 28672, 4096, 'bfloat16', True): (512, 2048, 4096),
347-
(6, 512, 4096, 14336, 'bfloat16', True): (512, 256, 14336),
348-
(6, 1024, 6144, 4096, 'bfloat16', True): (1024, 768, 4096),
349-
(6, 1024, 4096, 4096, 'bfloat16', True): (1024, 512, 4096),
327+
(6, 1024, 1280, 8192, 'bfloat16', True): (1024, 256, 8192),
328+
(6, 1024, 13824, 5120, 'bfloat16', True): (1024, 768, 5120),
329+
(6, 1024, 1792, 5120, 'bfloat16', True): (1024, 256, 5120),
350330
(6, 1024, 28672, 4096, 'bfloat16', True): (1024, 2048, 4096),
351331
(6, 1024, 4096, 14336, 'bfloat16', True): (1024, 256, 14336),
352-
(6, 16, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
353-
(6, 32, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
354-
(6, 32, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
355-
(6, 32, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
356-
(6, 32, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
357-
(6, 64, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
358-
(6, 64, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
359-
(6, 16, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
360-
(6, 16, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
361-
(6, 64, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
362-
(6, 64, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
332+
(6, 1024, 4096, 4096, 'bfloat16', True): (1024, 512, 4096),
333+
(6, 1024, 5120, 1280, 'bfloat16', True): (1024, 1280, 1280),
334+
(6, 1024, 5120, 3456, 'bfloat16', True): (1024, 1024, 3456),
335+
(6, 1024, 5120, 640, 'bfloat16', True): (256, 5120, 640),
336+
(6, 1024, 5120, 6912, 'bfloat16', True): (1024, 512, 6912),
337+
(6, 1024, 6144, 4096, 'bfloat16', True): (1024, 768, 4096),
338+
(6, 1024, 6912, 5120, 'bfloat16', True): (1024, 768, 5120),
339+
(6, 1024, 7168, 8192, 'bfloat16', True): (1024, 512, 8192),
340+
(6, 1024, 8192, 1024, 'bfloat16', True): (1024, 4096, 1024),
341+
(6, 1024, 8192, 3584, 'bfloat16', True): (1024, 1024, 3584),
342+
(6, 1024, 896, 5120, 'bfloat16', True): (1024, 896, 2560),
363343
(6, 128, 1280, 8192, 'bfloat16', True): (128, 1280, 2048),
364-
(6, 128, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
344+
(6, 128, 13824, 5120, 'bfloat16', True): (128, 512, 5120),
345+
(6, 128, 1792, 5120, 'bfloat16', True): (128, 1792, 1280),
346+
(6, 128, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
347+
(6, 128, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
348+
(6, 128, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
349+
(6, 128, 5120, 1280, 'bfloat16', True): (128, 1280, 1280),
350+
(6, 128, 5120, 3456, 'bfloat16', True): (128, 640, 3456),
351+
(6, 128, 5120, 640, 'bfloat16', True): (128, 2560, 640),
352+
(6, 128, 5120, 6912, 'bfloat16', True): (128, 2560, 1152),
353+
(6, 128, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
354+
(6, 128, 6912, 5120, 'bfloat16', True): (128, 1152, 2560),
365355
(6, 128, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
356+
(6, 128, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
366357
(6, 128, 8192, 3584, 'bfloat16', True): (128, 8192, 512),
367-
(6, 256, 1280, 8192, 'bfloat16', True): (256, 256, 8192),
368-
(6, 256, 8192, 1024, 'bfloat16', True): (256, 2048, 1024),
369-
(6, 256, 7168, 8192, 'bfloat16', True): (256, 512, 8192),
370-
(6, 256, 8192, 3584, 'bfloat16', True): (256, 8192, 512),
358+
(6, 128, 896, 5120, 'bfloat16', True): (128, 896, 2560),
359+
(6, 16, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
360+
(6, 16, 13824, 5120, 'bfloat16', True): (128, 512, 5120),
361+
(6, 16, 1792, 5120, 'bfloat16', True): (128, 896, 2560),
362+
(6, 16, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
363+
(6, 16, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
364+
(6, 16, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
365+
(6, 16, 5120, 1280, 'bfloat16', True): (128, 1280, 1280),
366+
(6, 16, 5120, 3456, 'bfloat16', True): (128, 640, 3456),
367+
(6, 16, 5120, 640, 'bfloat16', True): (128, 2560, 640),
368+
(6, 16, 5120, 6912, 'bfloat16', True): (128, 1280, 2304),
369+
(6, 16, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
370+
(6, 16, 6912, 5120, 'bfloat16', True): (128, 1152, 2560),
371371
(6, 16, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
372-
(6, 512, 1280, 8192, 'bfloat16', True): (512, 256, 8192),
373-
(6, 512, 8192, 1024, 'bfloat16', True): (512, 4096, 1024),
374-
(6, 512, 7168, 8192, 'bfloat16', True): (512, 512, 8192),
375-
(6, 512, 8192, 3584, 'bfloat16', True): (512, 2048, 3584),
376-
(6, 1024, 1280, 8192, 'bfloat16', True): (1024, 256, 8192),
377-
(6, 1024, 8192, 1024, 'bfloat16', True): (1024, 4096, 1024),
378-
(6, 1024, 7168, 8192, 'bfloat16', True): (1024, 512, 8192),
379-
(6, 1024, 8192, 3584, 'bfloat16', True): (1024, 1024, 3584),
380-
(6, 2048, 1280, 8192, 'bfloat16', True): (2048, 256, 8192),
381-
(6, 2048, 8192, 1024, 'bfloat16', True): (256, 8192, 1024),
372+
(6, 16, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
382373
(6, 16, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
374+
(6, 16, 896, 5120, 'bfloat16', True): (128, 896, 2560),
375+
(6, 16384, 13824, 5120, 'bfloat16', True): (2048, 1536, 5120),
376+
(6, 16384, 1792, 5120, 'bfloat16', True): (1024, 1792, 5120),
377+
(6, 16384, 5120, 1280, 'bfloat16', True): (512, 5120, 1280),
378+
(6, 16384, 5120, 3456, 'bfloat16', True): (512, 5120, 3456),
379+
(6, 16384, 5120, 640, 'bfloat16', True): (512, 5120, 640),
380+
(6, 16384, 5120, 6912, 'bfloat16', True): (512, 5120, 6912),
381+
(6, 16384, 6912, 5120, 'bfloat16', True): (512, 6912, 5120),
382+
(6, 16384, 896, 5120, 'bfloat16', True): (1024, 896, 5120),
383+
(6, 2048, 1280, 8192, 'bfloat16', True): (2048, 256, 8192),
384+
(6, 2048, 13824, 5120, 'bfloat16', True): (2048, 768, 5120),
385+
(6, 2048, 1792, 5120, 'bfloat16', True): (2048, 256, 5120),
386+
(6, 2048, 28672, 4096, 'bfloat16', True): (2048, 1024, 4096),
387+
(6, 2048, 4096, 14336, 'bfloat16', True): (2048, 4096, 512),
388+
(6, 2048, 4096, 4096, 'bfloat16', True): (2048, 512, 4096),
389+
(6, 2048, 5120, 1280, 'bfloat16', True): (256, 5120, 1280),
390+
(6, 2048, 5120, 3456, 'bfloat16', True): (2048, 512, 3456),
391+
(6, 2048, 5120, 640, 'bfloat16', True): (256, 5120, 640),
392+
(6, 2048, 5120, 6912, 'bfloat16', True): (2048, 512, 6912),
393+
(6, 2048, 6144, 4096, 'bfloat16', True): (2048, 512, 4096),
394+
(6, 2048, 6912, 5120, 'bfloat16', True): (2048, 768, 5120),
383395
(6, 2048, 7168, 8192, 'bfloat16', True): (2048, 256, 8192),
396+
(6, 2048, 8192, 1024, 'bfloat16', True): (256, 8192, 1024),
384397
(6, 2048, 8192, 3584, 'bfloat16', True): (2048, 512, 3584),
398+
(6, 2048, 896, 5120, 'bfloat16', True): (1024, 896, 5120),
399+
(6, 256, 1280, 8192, 'bfloat16', True): (256, 256, 8192),
400+
(6, 256, 13824, 5120, 'bfloat16', True): (256, 512, 5120),
401+
(6, 256, 1792, 5120, 'bfloat16', True): (256, 1792, 1280),
402+
(6, 256, 28672, 4096, 'bfloat16', True): (256, 2048, 4096),
403+
(6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 512),
404+
(6, 256, 4096, 4096, 'bfloat16', True): (256, 512, 4096),
405+
(6, 256, 5120, 1280, 'bfloat16', True): (256, 2560, 1280),
406+
(6, 256, 5120, 3456, 'bfloat16', True): (256, 1024, 3456),
407+
(6, 256, 5120, 640, 'bfloat16', True): (256, 2560, 640),
408+
(6, 256, 5120, 6912, 'bfloat16', True): (256, 5120, 768),
409+
(6, 256, 6144, 4096, 'bfloat16', True): (256, 512, 4096),
410+
(6, 256, 6912, 5120, 'bfloat16', True): (256, 6912, 512),
411+
(6, 256, 7168, 8192, 'bfloat16', True): (256, 512, 8192),
412+
(6, 256, 8192, 1024, 'bfloat16', True): (256, 2048, 1024),
413+
(6, 256, 8192, 3584, 'bfloat16', True): (256, 8192, 512),
414+
(6, 256, 896, 5120, 'bfloat16', True): (256, 896, 2560),
385415
(6, 32, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
386-
(6, 32, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
416+
(6, 32, 13824, 5120, 'bfloat16', True): (128, 512, 5120),
417+
(6, 32, 1792, 5120, 'bfloat16', True): (128, 896, 2560),
418+
(6, 32, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
419+
(6, 32, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
420+
(6, 32, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
421+
(6, 32, 5120, 1280, 'bfloat16', True): (128, 1280, 1280),
422+
(6, 32, 5120, 3456, 'bfloat16', True): (128, 640, 3456),
423+
(6, 32, 5120, 640, 'bfloat16', True): (128, 2560, 640),
424+
(6, 32, 5120, 6912, 'bfloat16', True): (128, 1280, 2304),
425+
(6, 32, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
426+
(6, 32, 6912, 5120, 'bfloat16', True): (128, 2304, 1280),
387427
(6, 32, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
428+
(6, 32, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
388429
(6, 32, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
430+
(6, 32, 896, 5120, 'bfloat16', True): (128, 896, 2560),
431+
(6, 4096, 13824, 5120, 'bfloat16', True): (2048, 1536, 5120),
432+
(6, 4096, 1792, 5120, 'bfloat16', True): (512, 1792, 5120),
433+
(6, 4096, 5120, 1280, 'bfloat16', True): (256, 5120, 1280),
434+
(6, 4096, 5120, 3456, 'bfloat16', True): (4096, 512, 3456),
435+
(6, 4096, 5120, 640, 'bfloat16', True): (256, 5120, 640),
436+
(6, 4096, 5120, 6912, 'bfloat16', True): (256, 5120, 6912),
437+
(6, 4096, 6912, 5120, 'bfloat16', True): (256, 6912, 5120),
438+
(6, 4096, 896, 5120, 'bfloat16', True): (256, 896, 5120),
439+
(6, 512, 1280, 8192, 'bfloat16', True): (512, 256, 8192),
440+
(6, 512, 13824, 5120, 'bfloat16', True): (512, 13824, 512),
441+
(6, 512, 1792, 5120, 'bfloat16', True): (512, 1792, 1280),
442+
(6, 512, 28672, 4096, 'bfloat16', True): (512, 2048, 4096),
443+
(6, 512, 4096, 14336, 'bfloat16', True): (512, 256, 14336),
444+
(6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096),
445+
(6, 512, 5120, 1280, 'bfloat16', True): (512, 2560, 1280),
446+
(6, 512, 5120, 3456, 'bfloat16', True): (512, 1280, 3456),
447+
(6, 512, 5120, 640, 'bfloat16', True): (512, 2560, 640),
448+
(6, 512, 5120, 6912, 'bfloat16', True): (512, 512, 6912),
449+
(6, 512, 6144, 4096, 'bfloat16', True): (512, 1024, 4096),
450+
(6, 512, 6912, 5120, 'bfloat16', True): (512, 768, 5120),
451+
(6, 512, 7168, 8192, 'bfloat16', True): (512, 512, 8192),
452+
(6, 512, 8192, 1024, 'bfloat16', True): (512, 4096, 1024),
453+
(6, 512, 8192, 3584, 'bfloat16', True): (512, 2048, 3584),
454+
(6, 512, 896, 5120, 'bfloat16', True): (512, 896, 2560),
389455
(6, 64, 1280, 8192, 'bfloat16', True): (128, 256, 8192),
456+
(6, 64, 13824, 5120, 'bfloat16', True): (128, 512, 5120),
457+
(6, 64, 1792, 5120, 'bfloat16', True): (128, 896, 2560),
458+
(6, 64, 28672, 4096, 'bfloat16', True): (128, 28672, 256),
459+
(6, 64, 4096, 14336, 'bfloat16', True): (128, 4096, 896),
460+
(6, 64, 4096, 4096, 'bfloat16', True): (128, 512, 4096),
461+
(6, 64, 5120, 1280, 'bfloat16', True): (128, 1280, 1280),
462+
(6, 64, 5120, 3456, 'bfloat16', True): (128, 1024, 3456),
463+
(6, 64, 5120, 640, 'bfloat16', True): (128, 2560, 640),
464+
(6, 64, 5120, 6912, 'bfloat16', True): (128, 1280, 2304),
465+
(6, 64, 6144, 4096, 'bfloat16', True): (128, 768, 4096),
466+
(6, 64, 6912, 5120, 'bfloat16', True): (128, 2304, 1280),
467+
(6, 64, 7168, 8192, 'bfloat16', True): (128, 256, 8192),
390468
(6, 64, 8192, 1024, 'bfloat16', True): (128, 2048, 1024),
469+
(6, 64, 8192, 3584, 'bfloat16', True): (128, 1024, 3584),
470+
(6, 64, 896, 5120, 'bfloat16', True): (128, 896, 2560),
471+
(6, 8192, 13824, 5120, 'bfloat16', True): (2048, 1536, 5120),
472+
(6, 8192, 1792, 5120, 'bfloat16', True): (512, 1792, 5120),
473+
(6, 8192, 5120, 1280, 'bfloat16', True): (256, 5120, 1280),
474+
(6, 8192, 5120, 3456, 'bfloat16', True): (512, 5120, 3456),
475+
(6, 8192, 5120, 640, 'bfloat16', True): (512, 5120, 640),
476+
(6, 8192, 5120, 6912, 'bfloat16', True): (512, 5120, 6912),
477+
(6, 8192, 6912, 5120, 'bfloat16', True): (512, 6912, 5120),
478+
(6, 8192, 896, 5120, 'bfloat16', True): (512, 896, 5120),
391479
}
392480

393481

0 commit comments

Comments
 (0)