@@ -324,70 +324,158 @@ def quantized_matmul_int8(
324
324
# - out_block_size
325
325
# - in_block_size
326
326
TUNED_BLOCK_SIZES = {
327
- (6 , 128 , 28672 , 4096 , 'bfloat16' , True ): (128 , 28672 , 256 ),
328
- (6 , 128 , 4096 , 14336 , 'bfloat16' , True ): (128 , 4096 , 896 ),
329
- (6 , 2048 , 6144 , 4096 , 'bfloat16' , True ): (2048 , 512 , 4096 ),
330
- (6 , 2048 , 4096 , 4096 , 'bfloat16' , True ): (2048 , 512 , 4096 ),
331
- (6 , 2048 , 4096 , 14336 , 'bfloat16' , True ): (2048 , 4096 , 512 ),
332
- (6 , 128 , 6144 , 4096 , 'bfloat16' , True ): (128 , 768 , 4096 ),
333
- (6 , 128 , 4096 , 4096 , 'bfloat16' , True ): (128 , 512 , 4096 ),
334
- (6 , 2048 , 28672 , 4096 , 'bfloat16' , True ): (2048 , 1024 , 4096 ),
335
- (6 , 16 , 6144 , 4096 , 'bfloat16' , True ): (128 , 768 , 4096 ),
336
- (6 , 16 , 4096 , 4096 , 'bfloat16' , True ): (128 , 512 , 4096 ),
337
- (6 , 64 , 28672 , 4096 , 'bfloat16' , True ): (128 , 28672 , 256 ),
338
- (6 , 64 , 4096 , 14336 , 'bfloat16' , True ): (128 , 4096 , 896 ),
339
- (6 , 256 , 6144 , 4096 , 'bfloat16' , True ): (256 , 512 , 4096 ),
340
- (6 , 256 , 4096 , 4096 , 'bfloat16' , True ): (256 , 512 , 4096 ),
341
- (6 , 256 , 28672 , 4096 , 'bfloat16' , True ): (256 , 2048 , 4096 ),
342
- (6 , 256 , 4096 , 14336 , 'bfloat16' , True ): (256 , 4096 , 512 ),
343
- (6 , 16 , 28672 , 4096 , 'bfloat16' , True ): (128 , 28672 , 256 ),
344
- (6 , 512 , 6144 , 4096 , 'bfloat16' , True ): (512 , 1024 , 4096 ),
345
- (6 , 512 , 4096 , 4096 , 'bfloat16' , True ): (512 , 1024 , 4096 ),
346
- (6 , 512 , 28672 , 4096 , 'bfloat16' , True ): (512 , 2048 , 4096 ),
347
- (6 , 512 , 4096 , 14336 , 'bfloat16' , True ): (512 , 256 , 14336 ),
348
- (6 , 1024 , 6144 , 4096 , 'bfloat16' , True ): (1024 , 768 , 4096 ),
349
- (6 , 1024 , 4096 , 4096 , 'bfloat16' , True ): (1024 , 512 , 4096 ),
327
+ (6 , 1024 , 1280 , 8192 , 'bfloat16' , True ): (1024 , 256 , 8192 ),
328
+ (6 , 1024 , 13824 , 5120 , 'bfloat16' , True ): (1024 , 768 , 5120 ),
329
+ (6 , 1024 , 1792 , 5120 , 'bfloat16' , True ): (1024 , 256 , 5120 ),
350
330
(6 , 1024 , 28672 , 4096 , 'bfloat16' , True ): (1024 , 2048 , 4096 ),
351
331
(6 , 1024 , 4096 , 14336 , 'bfloat16' , True ): (1024 , 256 , 14336 ),
352
- (6 , 16 , 4096 , 14336 , 'bfloat16' , True ): (128 , 4096 , 896 ),
353
- (6 , 32 , 6144 , 4096 , 'bfloat16' , True ): (128 , 768 , 4096 ),
354
- (6 , 32 , 4096 , 4096 , 'bfloat16' , True ): (128 , 512 , 4096 ),
355
- (6 , 32 , 28672 , 4096 , 'bfloat16' , True ): (128 , 28672 , 256 ),
356
- (6 , 32 , 4096 , 14336 , 'bfloat16' , True ): (128 , 4096 , 896 ),
357
- (6 , 64 , 6144 , 4096 , 'bfloat16' , True ): (128 , 768 , 4096 ),
358
- (6 , 64 , 4096 , 4096 , 'bfloat16' , True ): (128 , 512 , 4096 ),
359
- (6 , 16 , 1280 , 8192 , 'bfloat16' , True ): (128 , 256 , 8192 ),
360
- (6 , 16 , 8192 , 1024 , 'bfloat16' , True ): (128 , 2048 , 1024 ),
361
- (6 , 64 , 7168 , 8192 , 'bfloat16' , True ): (128 , 256 , 8192 ),
362
- (6 , 64 , 8192 , 3584 , 'bfloat16' , True ): (128 , 1024 , 3584 ),
332
+ (6 , 1024 , 4096 , 4096 , 'bfloat16' , True ): (1024 , 512 , 4096 ),
333
+ (6 , 1024 , 5120 , 1280 , 'bfloat16' , True ): (1024 , 1280 , 1280 ),
334
+ (6 , 1024 , 5120 , 3456 , 'bfloat16' , True ): (1024 , 1024 , 3456 ),
335
+ (6 , 1024 , 5120 , 640 , 'bfloat16' , True ): (256 , 5120 , 640 ),
336
+ (6 , 1024 , 5120 , 6912 , 'bfloat16' , True ): (1024 , 512 , 6912 ),
337
+ (6 , 1024 , 6144 , 4096 , 'bfloat16' , True ): (1024 , 768 , 4096 ),
338
+ (6 , 1024 , 6912 , 5120 , 'bfloat16' , True ): (1024 , 768 , 5120 ),
339
+ (6 , 1024 , 7168 , 8192 , 'bfloat16' , True ): (1024 , 512 , 8192 ),
340
+ (6 , 1024 , 8192 , 1024 , 'bfloat16' , True ): (1024 , 4096 , 1024 ),
341
+ (6 , 1024 , 8192 , 3584 , 'bfloat16' , True ): (1024 , 1024 , 3584 ),
342
+ (6 , 1024 , 896 , 5120 , 'bfloat16' , True ): (1024 , 896 , 2560 ),
363
343
(6 , 128 , 1280 , 8192 , 'bfloat16' , True ): (128 , 1280 , 2048 ),
364
- (6 , 128 , 8192 , 1024 , 'bfloat16' , True ): (128 , 2048 , 1024 ),
344
+ (6 , 128 , 13824 , 5120 , 'bfloat16' , True ): (128 , 512 , 5120 ),
345
+ (6 , 128 , 1792 , 5120 , 'bfloat16' , True ): (128 , 1792 , 1280 ),
346
+ (6 , 128 , 28672 , 4096 , 'bfloat16' , True ): (128 , 28672 , 256 ),
347
+ (6 , 128 , 4096 , 14336 , 'bfloat16' , True ): (128 , 4096 , 896 ),
348
+ (6 , 128 , 4096 , 4096 , 'bfloat16' , True ): (128 , 512 , 4096 ),
349
+ (6 , 128 , 5120 , 1280 , 'bfloat16' , True ): (128 , 1280 , 1280 ),
350
+ (6 , 128 , 5120 , 3456 , 'bfloat16' , True ): (128 , 640 , 3456 ),
351
+ (6 , 128 , 5120 , 640 , 'bfloat16' , True ): (128 , 2560 , 640 ),
352
+ (6 , 128 , 5120 , 6912 , 'bfloat16' , True ): (128 , 2560 , 1152 ),
353
+ (6 , 128 , 6144 , 4096 , 'bfloat16' , True ): (128 , 768 , 4096 ),
354
+ (6 , 128 , 6912 , 5120 , 'bfloat16' , True ): (128 , 1152 , 2560 ),
365
355
(6 , 128 , 7168 , 8192 , 'bfloat16' , True ): (128 , 256 , 8192 ),
356
+ (6 , 128 , 8192 , 1024 , 'bfloat16' , True ): (128 , 2048 , 1024 ),
366
357
(6 , 128 , 8192 , 3584 , 'bfloat16' , True ): (128 , 8192 , 512 ),
367
- (6 , 256 , 1280 , 8192 , 'bfloat16' , True ): (256 , 256 , 8192 ),
368
- (6 , 256 , 8192 , 1024 , 'bfloat16' , True ): (256 , 2048 , 1024 ),
369
- (6 , 256 , 7168 , 8192 , 'bfloat16' , True ): (256 , 512 , 8192 ),
370
- (6 , 256 , 8192 , 3584 , 'bfloat16' , True ): (256 , 8192 , 512 ),
358
+ (6 , 128 , 896 , 5120 , 'bfloat16' , True ): (128 , 896 , 2560 ),
359
+ (6 , 16 , 1280 , 8192 , 'bfloat16' , True ): (128 , 256 , 8192 ),
360
+ (6 , 16 , 13824 , 5120 , 'bfloat16' , True ): (128 , 512 , 5120 ),
361
+ (6 , 16 , 1792 , 5120 , 'bfloat16' , True ): (128 , 896 , 2560 ),
362
+ (6 , 16 , 28672 , 4096 , 'bfloat16' , True ): (128 , 28672 , 256 ),
363
+ (6 , 16 , 4096 , 14336 , 'bfloat16' , True ): (128 , 4096 , 896 ),
364
+ (6 , 16 , 4096 , 4096 , 'bfloat16' , True ): (128 , 512 , 4096 ),
365
+ (6 , 16 , 5120 , 1280 , 'bfloat16' , True ): (128 , 1280 , 1280 ),
366
+ (6 , 16 , 5120 , 3456 , 'bfloat16' , True ): (128 , 640 , 3456 ),
367
+ (6 , 16 , 5120 , 640 , 'bfloat16' , True ): (128 , 2560 , 640 ),
368
+ (6 , 16 , 5120 , 6912 , 'bfloat16' , True ): (128 , 1280 , 2304 ),
369
+ (6 , 16 , 6144 , 4096 , 'bfloat16' , True ): (128 , 768 , 4096 ),
370
+ (6 , 16 , 6912 , 5120 , 'bfloat16' , True ): (128 , 1152 , 2560 ),
371
371
(6 , 16 , 7168 , 8192 , 'bfloat16' , True ): (128 , 256 , 8192 ),
372
- (6 , 512 , 1280 , 8192 , 'bfloat16' , True ): (512 , 256 , 8192 ),
373
- (6 , 512 , 8192 , 1024 , 'bfloat16' , True ): (512 , 4096 , 1024 ),
374
- (6 , 512 , 7168 , 8192 , 'bfloat16' , True ): (512 , 512 , 8192 ),
375
- (6 , 512 , 8192 , 3584 , 'bfloat16' , True ): (512 , 2048 , 3584 ),
376
- (6 , 1024 , 1280 , 8192 , 'bfloat16' , True ): (1024 , 256 , 8192 ),
377
- (6 , 1024 , 8192 , 1024 , 'bfloat16' , True ): (1024 , 4096 , 1024 ),
378
- (6 , 1024 , 7168 , 8192 , 'bfloat16' , True ): (1024 , 512 , 8192 ),
379
- (6 , 1024 , 8192 , 3584 , 'bfloat16' , True ): (1024 , 1024 , 3584 ),
380
- (6 , 2048 , 1280 , 8192 , 'bfloat16' , True ): (2048 , 256 , 8192 ),
381
- (6 , 2048 , 8192 , 1024 , 'bfloat16' , True ): (256 , 8192 , 1024 ),
372
+ (6 , 16 , 8192 , 1024 , 'bfloat16' , True ): (128 , 2048 , 1024 ),
382
373
(6 , 16 , 8192 , 3584 , 'bfloat16' , True ): (128 , 1024 , 3584 ),
374
+ (6 , 16 , 896 , 5120 , 'bfloat16' , True ): (128 , 896 , 2560 ),
375
+ (6 , 16384 , 13824 , 5120 , 'bfloat16' , True ): (2048 , 1536 , 5120 ),
376
+ (6 , 16384 , 1792 , 5120 , 'bfloat16' , True ): (1024 , 1792 , 5120 ),
377
+ (6 , 16384 , 5120 , 1280 , 'bfloat16' , True ): (512 , 5120 , 1280 ),
378
+ (6 , 16384 , 5120 , 3456 , 'bfloat16' , True ): (512 , 5120 , 3456 ),
379
+ (6 , 16384 , 5120 , 640 , 'bfloat16' , True ): (512 , 5120 , 640 ),
380
+ (6 , 16384 , 5120 , 6912 , 'bfloat16' , True ): (512 , 5120 , 6912 ),
381
+ (6 , 16384 , 6912 , 5120 , 'bfloat16' , True ): (512 , 6912 , 5120 ),
382
+ (6 , 16384 , 896 , 5120 , 'bfloat16' , True ): (1024 , 896 , 5120 ),
383
+ (6 , 2048 , 1280 , 8192 , 'bfloat16' , True ): (2048 , 256 , 8192 ),
384
+ (6 , 2048 , 13824 , 5120 , 'bfloat16' , True ): (2048 , 768 , 5120 ),
385
+ (6 , 2048 , 1792 , 5120 , 'bfloat16' , True ): (2048 , 256 , 5120 ),
386
+ (6 , 2048 , 28672 , 4096 , 'bfloat16' , True ): (2048 , 1024 , 4096 ),
387
+ (6 , 2048 , 4096 , 14336 , 'bfloat16' , True ): (2048 , 4096 , 512 ),
388
+ (6 , 2048 , 4096 , 4096 , 'bfloat16' , True ): (2048 , 512 , 4096 ),
389
+ (6 , 2048 , 5120 , 1280 , 'bfloat16' , True ): (256 , 5120 , 1280 ),
390
+ (6 , 2048 , 5120 , 3456 , 'bfloat16' , True ): (2048 , 512 , 3456 ),
391
+ (6 , 2048 , 5120 , 640 , 'bfloat16' , True ): (256 , 5120 , 640 ),
392
+ (6 , 2048 , 5120 , 6912 , 'bfloat16' , True ): (2048 , 512 , 6912 ),
393
+ (6 , 2048 , 6144 , 4096 , 'bfloat16' , True ): (2048 , 512 , 4096 ),
394
+ (6 , 2048 , 6912 , 5120 , 'bfloat16' , True ): (2048 , 768 , 5120 ),
383
395
(6 , 2048 , 7168 , 8192 , 'bfloat16' , True ): (2048 , 256 , 8192 ),
396
+ (6 , 2048 , 8192 , 1024 , 'bfloat16' , True ): (256 , 8192 , 1024 ),
384
397
(6 , 2048 , 8192 , 3584 , 'bfloat16' , True ): (2048 , 512 , 3584 ),
398
+ (6 , 2048 , 896 , 5120 , 'bfloat16' , True ): (1024 , 896 , 5120 ),
399
+ (6 , 256 , 1280 , 8192 , 'bfloat16' , True ): (256 , 256 , 8192 ),
400
+ (6 , 256 , 13824 , 5120 , 'bfloat16' , True ): (256 , 512 , 5120 ),
401
+ (6 , 256 , 1792 , 5120 , 'bfloat16' , True ): (256 , 1792 , 1280 ),
402
+ (6 , 256 , 28672 , 4096 , 'bfloat16' , True ): (256 , 2048 , 4096 ),
403
+ (6 , 256 , 4096 , 14336 , 'bfloat16' , True ): (256 , 4096 , 512 ),
404
+ (6 , 256 , 4096 , 4096 , 'bfloat16' , True ): (256 , 512 , 4096 ),
405
+ (6 , 256 , 5120 , 1280 , 'bfloat16' , True ): (256 , 2560 , 1280 ),
406
+ (6 , 256 , 5120 , 3456 , 'bfloat16' , True ): (256 , 1024 , 3456 ),
407
+ (6 , 256 , 5120 , 640 , 'bfloat16' , True ): (256 , 2560 , 640 ),
408
+ (6 , 256 , 5120 , 6912 , 'bfloat16' , True ): (256 , 5120 , 768 ),
409
+ (6 , 256 , 6144 , 4096 , 'bfloat16' , True ): (256 , 512 , 4096 ),
410
+ (6 , 256 , 6912 , 5120 , 'bfloat16' , True ): (256 , 6912 , 512 ),
411
+ (6 , 256 , 7168 , 8192 , 'bfloat16' , True ): (256 , 512 , 8192 ),
412
+ (6 , 256 , 8192 , 1024 , 'bfloat16' , True ): (256 , 2048 , 1024 ),
413
+ (6 , 256 , 8192 , 3584 , 'bfloat16' , True ): (256 , 8192 , 512 ),
414
+ (6 , 256 , 896 , 5120 , 'bfloat16' , True ): (256 , 896 , 2560 ),
385
415
(6 , 32 , 1280 , 8192 , 'bfloat16' , True ): (128 , 256 , 8192 ),
386
- (6 , 32 , 8192 , 1024 , 'bfloat16' , True ): (128 , 2048 , 1024 ),
416
+ (6 , 32 , 13824 , 5120 , 'bfloat16' , True ): (128 , 512 , 5120 ),
417
+ (6 , 32 , 1792 , 5120 , 'bfloat16' , True ): (128 , 896 , 2560 ),
418
+ (6 , 32 , 28672 , 4096 , 'bfloat16' , True ): (128 , 28672 , 256 ),
419
+ (6 , 32 , 4096 , 14336 , 'bfloat16' , True ): (128 , 4096 , 896 ),
420
+ (6 , 32 , 4096 , 4096 , 'bfloat16' , True ): (128 , 512 , 4096 ),
421
+ (6 , 32 , 5120 , 1280 , 'bfloat16' , True ): (128 , 1280 , 1280 ),
422
+ (6 , 32 , 5120 , 3456 , 'bfloat16' , True ): (128 , 640 , 3456 ),
423
+ (6 , 32 , 5120 , 640 , 'bfloat16' , True ): (128 , 2560 , 640 ),
424
+ (6 , 32 , 5120 , 6912 , 'bfloat16' , True ): (128 , 1280 , 2304 ),
425
+ (6 , 32 , 6144 , 4096 , 'bfloat16' , True ): (128 , 768 , 4096 ),
426
+ (6 , 32 , 6912 , 5120 , 'bfloat16' , True ): (128 , 2304 , 1280 ),
387
427
(6 , 32 , 7168 , 8192 , 'bfloat16' , True ): (128 , 256 , 8192 ),
428
+ (6 , 32 , 8192 , 1024 , 'bfloat16' , True ): (128 , 2048 , 1024 ),
388
429
(6 , 32 , 8192 , 3584 , 'bfloat16' , True ): (128 , 1024 , 3584 ),
430
+ (6 , 32 , 896 , 5120 , 'bfloat16' , True ): (128 , 896 , 2560 ),
431
+ (6 , 4096 , 13824 , 5120 , 'bfloat16' , True ): (2048 , 1536 , 5120 ),
432
+ (6 , 4096 , 1792 , 5120 , 'bfloat16' , True ): (512 , 1792 , 5120 ),
433
+ (6 , 4096 , 5120 , 1280 , 'bfloat16' , True ): (256 , 5120 , 1280 ),
434
+ (6 , 4096 , 5120 , 3456 , 'bfloat16' , True ): (4096 , 512 , 3456 ),
435
+ (6 , 4096 , 5120 , 640 , 'bfloat16' , True ): (256 , 5120 , 640 ),
436
+ (6 , 4096 , 5120 , 6912 , 'bfloat16' , True ): (256 , 5120 , 6912 ),
437
+ (6 , 4096 , 6912 , 5120 , 'bfloat16' , True ): (256 , 6912 , 5120 ),
438
+ (6 , 4096 , 896 , 5120 , 'bfloat16' , True ): (256 , 896 , 5120 ),
439
+ (6 , 512 , 1280 , 8192 , 'bfloat16' , True ): (512 , 256 , 8192 ),
440
+ (6 , 512 , 13824 , 5120 , 'bfloat16' , True ): (512 , 13824 , 512 ),
441
+ (6 , 512 , 1792 , 5120 , 'bfloat16' , True ): (512 , 1792 , 1280 ),
442
+ (6 , 512 , 28672 , 4096 , 'bfloat16' , True ): (512 , 2048 , 4096 ),
443
+ (6 , 512 , 4096 , 14336 , 'bfloat16' , True ): (512 , 256 , 14336 ),
444
+ (6 , 512 , 4096 , 4096 , 'bfloat16' , True ): (512 , 1024 , 4096 ),
445
+ (6 , 512 , 5120 , 1280 , 'bfloat16' , True ): (512 , 2560 , 1280 ),
446
+ (6 , 512 , 5120 , 3456 , 'bfloat16' , True ): (512 , 1280 , 3456 ),
447
+ (6 , 512 , 5120 , 640 , 'bfloat16' , True ): (512 , 2560 , 640 ),
448
+ (6 , 512 , 5120 , 6912 , 'bfloat16' , True ): (512 , 512 , 6912 ),
449
+ (6 , 512 , 6144 , 4096 , 'bfloat16' , True ): (512 , 1024 , 4096 ),
450
+ (6 , 512 , 6912 , 5120 , 'bfloat16' , True ): (512 , 768 , 5120 ),
451
+ (6 , 512 , 7168 , 8192 , 'bfloat16' , True ): (512 , 512 , 8192 ),
452
+ (6 , 512 , 8192 , 1024 , 'bfloat16' , True ): (512 , 4096 , 1024 ),
453
+ (6 , 512 , 8192 , 3584 , 'bfloat16' , True ): (512 , 2048 , 3584 ),
454
+ (6 , 512 , 896 , 5120 , 'bfloat16' , True ): (512 , 896 , 2560 ),
389
455
(6 , 64 , 1280 , 8192 , 'bfloat16' , True ): (128 , 256 , 8192 ),
456
+ (6 , 64 , 13824 , 5120 , 'bfloat16' , True ): (128 , 512 , 5120 ),
457
+ (6 , 64 , 1792 , 5120 , 'bfloat16' , True ): (128 , 896 , 2560 ),
458
+ (6 , 64 , 28672 , 4096 , 'bfloat16' , True ): (128 , 28672 , 256 ),
459
+ (6 , 64 , 4096 , 14336 , 'bfloat16' , True ): (128 , 4096 , 896 ),
460
+ (6 , 64 , 4096 , 4096 , 'bfloat16' , True ): (128 , 512 , 4096 ),
461
+ (6 , 64 , 5120 , 1280 , 'bfloat16' , True ): (128 , 1280 , 1280 ),
462
+ (6 , 64 , 5120 , 3456 , 'bfloat16' , True ): (128 , 1024 , 3456 ),
463
+ (6 , 64 , 5120 , 640 , 'bfloat16' , True ): (128 , 2560 , 640 ),
464
+ (6 , 64 , 5120 , 6912 , 'bfloat16' , True ): (128 , 1280 , 2304 ),
465
+ (6 , 64 , 6144 , 4096 , 'bfloat16' , True ): (128 , 768 , 4096 ),
466
+ (6 , 64 , 6912 , 5120 , 'bfloat16' , True ): (128 , 2304 , 1280 ),
467
+ (6 , 64 , 7168 , 8192 , 'bfloat16' , True ): (128 , 256 , 8192 ),
390
468
(6 , 64 , 8192 , 1024 , 'bfloat16' , True ): (128 , 2048 , 1024 ),
469
+ (6 , 64 , 8192 , 3584 , 'bfloat16' , True ): (128 , 1024 , 3584 ),
470
+ (6 , 64 , 896 , 5120 , 'bfloat16' , True ): (128 , 896 , 2560 ),
471
+ (6 , 8192 , 13824 , 5120 , 'bfloat16' , True ): (2048 , 1536 , 5120 ),
472
+ (6 , 8192 , 1792 , 5120 , 'bfloat16' , True ): (512 , 1792 , 5120 ),
473
+ (6 , 8192 , 5120 , 1280 , 'bfloat16' , True ): (256 , 5120 , 1280 ),
474
+ (6 , 8192 , 5120 , 3456 , 'bfloat16' , True ): (512 , 5120 , 3456 ),
475
+ (6 , 8192 , 5120 , 640 , 'bfloat16' , True ): (512 , 5120 , 640 ),
476
+ (6 , 8192 , 5120 , 6912 , 'bfloat16' , True ): (512 , 5120 , 6912 ),
477
+ (6 , 8192 , 6912 , 5120 , 'bfloat16' , True ): (512 , 6912 , 5120 ),
478
+ (6 , 8192 , 896 , 5120 , 'bfloat16' , True ): (512 , 896 , 5120 ),
391
479
}
392
480
393
481
0 commit comments