Skip to content

Conversation

kvederni
Copy link
Contributor

This change adds more MMA intrinsics for F8F6F4 and FP64 types. The implementation is based on PTX ISA version 9.0.
New restrictions were added for dtype/ctype combinations for MMA and sparse MMA intrinsics. MLIR restrictions for dtype/ctype MMA intrinsics were aligned with NVVM IR.

[NVPTX] Added restrictions for dtype/ctype combinations.
[MLIR] Aligned MMA restrictions with NVVM IR.

MMA description in PTX ISA 9.0 is at https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma
@llvmbot
Copy link
Member

llvmbot commented Aug 29, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-backend-nvptx
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-llvm-ir

Author: Kirill Vedernikov (kvederni)

Changes

This change adds more MMA intrinsics for F8F6F4 and FP64 types. The implementation is based on PTX ISA version 9.0.
New restrictions were added for dtype/ctype combinations for MMA and sparse MMA intrinsics. MLIR restrictions for dtype/ctype MMA intrinsics were aligned with NVVM IR.


Patch is 22.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156040.diff

5 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+90-12)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+21-9)
  • (modified) llvm/test/CodeGen/NVPTX/wmma.py (+104-11)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+3-2)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (-26)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 7b40841e45d0d..9015245f99983 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -272,6 +272,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
       !eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
       !eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
       !eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),
+      !eq(gft,"m16n8k32:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+      !eq(gft,"m16n8k32:c:f32") : !listsplat(llvm_float_ty, 4),
+      !eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+      !eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
 
       // wmma fp16 -> fp16/fp32 @  m16n16k16/m8n32k16/m32n8k16
       // All other supported geometries use the same fragment format for f32 and
@@ -298,6 +302,21 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
       !eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
       !eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),
 
+      !eq(gft,"m16n8k4:a:f64") : !listsplat(llvm_double_ty, 2),
+      !eq(gft,"m16n8k4:b:f64") : [llvm_double_ty],
+      !eq(gft,"m16n8k4:c:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k4:d:f64") : !listsplat(llvm_double_ty, 4),
+
+      !eq(gft,"m16n8k8:a:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k8:b:f64") : !listsplat(llvm_double_ty, 2),
+      !eq(gft,"m16n8k8:c:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k8:d:f64") : !listsplat(llvm_double_ty, 4),
+
+      !eq(gft,"m16n8k16:a:f64") : !listsplat(llvm_double_ty, 8),
+      !eq(gft,"m16n8k16:b:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k16:c:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k16:d:f64") : !listsplat(llvm_double_ty, 4),
+
       // wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
       !eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
       !eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
@@ -378,6 +397,26 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
       !eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
       !eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
 
+      // mma e4m3/e5m2 -> f16/f32 @ m16n8k16
+      !eq(gft,"m16n8k16:a:e4m3") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k16:a:e5m2") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k16:b:e4m3") : [llvm_i32_ty],
+      !eq(gft,"m16n8k16:b:e5m2") : [llvm_i32_ty],
+      // mma e4m3/e5m2/e3m2/e2m3/e2m1 -> f32 @ m16n8k32
+      !eq(gft,"m16n8k32:a:e4m3") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:a:e5m2") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:a:e3m2") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:a:e2m3") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:a:e2m1") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:b:e4m3") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k32:b:e5m2") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k32:b:e3m2") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k32:b:e2m3") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k32:b:e2m1") : !listsplat(llvm_i32_ty, 2),
+      // mma e2m1 -> f32 @m16n8k64
+      !eq(gft,"m16n8k64:a:e2m1") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k64:b:e2m1") : !listsplat(llvm_i32_ty, 2),
+
       // wmma/mma b1 -> s32 @ m8n8k128(b1)
       !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
       !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
@@ -468,7 +507,7 @@ class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, strin
                   # !if(Satfinite, "_satfinite", "");
 }
 
-class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
+class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, string Kind,
                WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   string signature = MMA_SIGNATURE<A, B, C, D>.ret;
   string record = "int_nvvm_mma"
@@ -476,6 +515,7 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
                   # "_" # A.geom
                   # "_" # ALayout
                   # "_" # BLayout
+                  # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
                   # !if(Satfinite, "_satfinite", "")
                   # signature;
 }
@@ -601,7 +641,7 @@ class NVVM_MMA_OPS {
             ["m16n8k16", "m16n8k8"],
             ["bf16"], [], ["f32"], []>.ret;
   list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
-            ["m8n8k4"],
+            ["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16"],
             ["f64"], [], ["f64"], []>.ret;
   list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
             ["m8n8k4", "m16n8k8", "m16n8k16"],
@@ -609,6 +649,18 @@ class NVVM_MMA_OPS {
   list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
             ["m8n8k16", "m16n8k16", "m16n8k32"],
             ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+  // m16n8k32 fp8 variants are intersected with f8f6f4 variants
+  // and processed there
+  list<list<WMMA_REGS>> fp8_mma_ops = MMA_OPS<
+            ["m16n8k16"],
+            ["e4m3", "e5m2"], ["e4m3", "e5m2"],
+            ["f16", "f32"], ["f16", "f32"]>.ret;
+  // it also contains e4m3/e5m2 from fp8 variants
+  list<list<WMMA_REGS>> f8f6f4_mma_ops = MMA_OPS<
+            ["m16n8k32"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f16", "f32"], ["f16", "f32"]>.ret;
   list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
             ["m8n8k32", "m16n8k32", "m16n8k64"],
             ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
@@ -617,7 +669,8 @@ class NVVM_MMA_OPS {
             ["b1"], [], ["s32"], []>.ret;
   list<list<WMMA_REGS>> all_mma_ops = !listconcat(
             tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
-            fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
+            fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
+            int_mma_ops, subint_mma_ops, bit_mma_ops);
 
   list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
             ["m16n8k16", "m16n8k32"],
@@ -770,7 +823,8 @@ class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
 // if NVVM_MMA_SUPPORTED<...>.ret then
 //   def : FOO<>; // The record will only be defined for supported ops.
 //
-class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
+class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b,
+                         string kind, int satf> {
   // MMA ops check both layouts.
   string layout = layout_a # ":" # layout_b;
   string a_type = frags[0].ptx_elt_type;
@@ -805,10 +859,31 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
          !or(!ne(a_type, b_type),
              !ne(c_type, d_type))): false,
 
-    // m16n8k8 requires C and D to be the same type.
-    !and(!eq(geom, "m16n8k8"),
+    // m16n8k16/m16n8k32 requires C and D to be the same type
+    !and(!or(!eq(geom, "m16n8k16"),
+             !eq(geom, "m16n8k32")),
          !ne(c_type, d_type)): false,
 
+    // Limit kind to valid types and geometries
+    !and(!ne(kind, ""),
+         !or(!ne(geom, "m16n8k32"),
+             !and(!ne(a_type, "e4m3"),
+                  !ne(a_type, "e5m2"),
+                  !ne(a_type, "e3m2"),
+                  !ne(a_type, "e2m3"),
+                  !ne(a_type, "e2m1")))): false,
+
+    // Limit m16n8k16/m16n8k32 with no kind to valid types
+    !and(!eq(kind, ""),
+         !or(!eq(geom, "m16n8k16"),
+             !eq(geom, "m16n8k32")),
+             !or(!eq(a_type, "e3m2"),
+                 !eq(a_type, "e2m3"),
+                 !eq(a_type, "e2m1"),
+                 !eq(b_type, "e3m2"),
+                 !eq(b_type, "e2m3"),
+                 !eq(b_type, "e2m1"))): false,
+
     // All other are OK.
     true: true
   );
@@ -882,9 +957,10 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
              !eq(a_type, "tf32")),
          !ne(a_type, b_type)): false,
 
-    // m16n8k16 and m16n8k32 requires C and D to be the same type.
+    // m16n8k16, m16n8k32 and m16n8k64 requires C and D to be the same type.
     !and(!or(!eq(geom, "m16n8k16"),
-             !eq(geom, "m16n8k32")),
+             !eq(geom, "m16n8k32"),
+             !eq(geom, "m16n8k64")),
          !ne(c_type, d_type)): false,
 
     !and(!eq(kind, ""),
@@ -2143,10 +2219,12 @@ foreach layout_a = ["row", "col"] in {
     foreach satf = [0, 1] in {
       foreach op = NVVM_MMA_OPS.all_mma_ops in {
         foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
-          if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-            def MMA_NAME<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>.record
-              : NVVM_MMA<op[0], op[1], op[2], op[3]>;
-          }
+          foreach kind = ["", "kind::f8f6f4"] in {
+            if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
+                def MMA_NAME<layout_a, layout_b, satf, b1op, kind, op[0], op[1], op[2], op[3]>.record
+                : NVVM_MMA<op[0], op[1], op[2], op[3]>;
+            }
+          } // kind
         } // b1op
       } // op
     } // satf
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c544911bdf1e3..8f58c31d7e1c7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4461,6 +4461,10 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
         !eq(ptx_elt_type, "e2m1"),
         !ne(kind, "")) : [hasSM120a, hasPTX<87>],
 
+    !and(!or(!eq(ptx_elt_type,"e4m3"),
+             !eq(ptx_elt_type,"e5m2")),
+         !eq(geom, "m16n8k16")) : [hasSM<89>, hasPTX<87>],
+
     !or(!eq(ptx_elt_type, "e4m3"),
         !eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
 
@@ -4476,6 +4480,11 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
     !and(!eq(geom, "m8n8k4"),
          !eq(ptx_elt_type, "f64")) : [hasSM<80>, hasPTX<70>],
 
+    !and(!or(!eq(geom, "m16n8k4"),
+             !eq(geom, "m16n8k8"),
+             !eq(geom, "m16n8k16")),
+         !eq(ptx_elt_type, "f64")) : [hasSM<90>, hasPTX<78>],
+
     // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
     !and(!or(!eq(geom, "m8n32k16"),
              !eq(geom, "m32n8k16")),
@@ -4760,8 +4769,8 @@ defset list<WMMA_INSTR> WMMAs  = {
 // MMA
 class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                WMMA_REGINFO FragC, WMMA_REGINFO FragD,
-               string ALayout, string BLayout, int Satfinite, string b1op>
-  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
+               string ALayout, string BLayout, int Satfinite, string b1op, string Kind>
+  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record,
                         [FragA.Ins, FragB.Ins, FragC.Ins]>,
     // Requires does not seem to have effect on Instruction w/o Patterns.
     // We set it here anyways and propagate to the Pat<> we construct below.
@@ -4776,6 +4785,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragA.geom
                   # "." # ALayout
                   # "." # BLayout
+                  # !if(!ne(Kind, ""), "." # Kind, "")
                   # !if(Satfinite, ".satfinite", "")
                   # TypeList
                   # b1op # "\n\t\t"
@@ -4792,13 +4802,15 @@ defset list<WMMA_INSTR> MMAs  = {
       foreach satf = [0, 1] in {
         foreach op = NVVM_MMA_OPS.all_mma_ops in {
           foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
-            if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-              def : MMA<WMMA_REGINFO<op[0], "mma">,
-                        WMMA_REGINFO<op[1], "mma">,
-                        WMMA_REGINFO<op[2], "mma">,
-                        WMMA_REGINFO<op[3], "mma">,
-                        layout_a, layout_b, satf, b1op>;
-            }
+            foreach kind = ["", "kind::f8f6f4"] in {
+              if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
+                def : MMA<WMMA_REGINFO<op[0], "mma", "", kind>,
+                          WMMA_REGINFO<op[1], "mma", "", kind>,
+                          WMMA_REGINFO<op[2], "mma", "", kind>,
+                          WMMA_REGINFO<op[3], "mma", "", kind>,
+                          layout_a, layout_b, satf, b1op, kind>;
+              }
+            } // kind
           } // b1op
         } // op
       } // satf
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 6d73bce46da7c..1c32856c1ce20 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -90,6 +90,21 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
             "m16n8k32:b:s8": 2,
             "m16n8k32:c:s32": 4,
             "m16n8k32:d:s32": 4,
+            # e4m3/e5m2/e3m2/e2m3/e2m1 -> f16/f32 @ m16n8k16/m16n8k32
+            "m16n8k16:a:e4m3": 2,
+            "m16n8k16:a:e5m2": 2,
+            "m16n8k32:a:e4m3": 4,
+            "m16n8k32:a:e5m2": 4,
+            "m16n8k32:a:e3m2": 4,
+            "m16n8k32:a:e2m3": 4,
+            "m16n8k32:a:e2m1": 4,
+            "m16n8k16:b:e4m3": 1,
+            "m16n8k16:b:e5m2": 1,
+            "m16n8k32:b:e4m3": 2,
+            "m16n8k32:b:e5m2": 2,
+            "m16n8k32:b:e3m2": 2,
+            "m16n8k32:b:e2m3": 2,
+            "m16n8k32:b:e2m1": 2,
             # mma sp
             "m16n8k32:a:bf16": 4,
             "m16n8k32:a:f16": 4,
@@ -182,6 +197,18 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
             "m8n8k4:b:f64": 1,
             "m8n8k4:c:f64": 2,
             "m8n8k4:d:f64": 2,
+            "m16n8k4:a:f64": 2,
+            "m16n8k4:b:f64": 1,
+            "m16n8k4:c:f64": 4,
+            "m16n8k4:d:f64": 4,
+            "m16n8k8:a:f64": 4,
+            "m16n8k8:b:f64": 2,
+            "m16n8k8:c:f64": 4,
+            "m16n8k8:d:f64": 4,
+            "m16n8k16:a:f64": 8,
+            "m16n8k16:b:f64": 4,
+            "m16n8k16:c:f64": 4,
+            "m16n8k16:d:f64": 4,
             # tf32 -> s32 @ m16n16k8
             "m16n16k8:a:tf32": 4,
             "m16n16k8:b:tf32": 4,
@@ -324,7 +351,9 @@ def get_wmma_ops():
 
 def get_mma_ops():
     return (
-        make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
+        make_mma_ops(
+            ["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16"], ["f64"], [], ["f64"], []
+        )
         + make_mma_ops(["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], [])
         + make_mma_ops(["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], [])
         + make_mma_ops(
@@ -341,6 +370,20 @@ def get_mma_ops():
             ["m8n8k32", "m16n8k32", "m16n8k64"], ["s4", "u4"], ["s4", "u4"], ["s32"], []
         )
         + make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], [])
+        + make_mma_ops(
+            ["m16n8k16"],
+            ["e4m3", "e5m2"],
+            ["e4m3", "e5m2"],
+            ["f16", "f32"],
+            ["f16", "f32"],
+        )
+        + make_mma_ops(
+            ["m16n8k32"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f16", "f32"],
+            ["f16", "f32"],
+        )
     )
 
 
@@ -492,7 +535,7 @@ def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
     return True
 
 
-def is_mma_variant_supported(op, layout_a, layout_b, satf):
+def is_mma_variant_supported(op, layout_a, layout_b, kind, satf):
     if not (
         is_type_supported(op.a.mma_type.ptx_type) and is_mma_geom_supported(op.a.geom)
     ):
@@ -516,13 +559,49 @@ def is_mma_variant_supported(op, layout_a, layout_b, satf):
     ):
         return False
 
+    if (
+        op.a.geom != "m8n8k4"
+        and op.a.mma_type.ptx_type == "f64"
+        and (ptx_version < 78 or gpu_arch < 90)
+    ):
+        return False
+
     # C and D type must be the same
-    if op.a.geom == "m16n8k16" and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type:
+    if (
+        op.a.geom in ["m16n8k16", "m16n8k32"]
+        and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
+    ):
+        return False
+
+    if (
+        op.a.geom in ["m16n8k16", "m16n8k32"]
+        and any(x in ["e4m3", "e5m2"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+        and ptx_version < 87
+    ):
+        return False
+
+    if kind != "" and (ptx_version < 87 or gpu_arch < 120 or not aa):
+        return False
+
+    if (
+        kind != ""
+        and (
+            op.a.geom != "m16n8k32"
+            or op.a.mma_type.ptx_type not in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
+        )
+    ):
+        return False
+
+    if (kind == ""
+        and op.a.geom in ["m16n8k16", "m16n8k32"]
+        and any(x in ["e3m2", "e2m3", "e2m1"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+    ):
         return False
 
     # Require row/col layout for all MMA except m8n8k4 on FP16
     if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
         return layout_a == "row" and layout_b == "col"
+
     return True
 
 
@@ -937,7 +1016,12 @@ def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
 """
 
     test_params = params
-    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+    test_params["intrinsic"] = (
+        Template(intrinsic_template)
+        .substitute(params)
+        .replace("::", ".")
+        .replace("_", ".")
+    )
     test_params["function"] = test_params["intrinsic"].replace(".", "_")
     test_params["instruction"] = Template(instruction_template).substitute(params)
     test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
@@ -1002,16 +1086,24 @@ def gen_wmma_mma_tests():
 
 
 def gen_mma_tests():
-    mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
-    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"
+    mma_intrinsic_template = (
+        "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${kind}${satf}.${intrinsic_signature}"
+    )
+    mma_instruction_template = (
+        "mma.sync${aligned}.${geom}.${alayout}.${blayout}${kind}${satf}.${ptx_signature}${b1op}"
+    )
 
     generated_items = []
 
-    for op, alayout, blayout, satf in product(
-        get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]
+    for op, alayout, blayout, kind, satf in product(
+        get_mma_ops(),
+        ["row", "col"],
+        ["row", "col"],
+        ["", ".kind::f8f6f4"],
+        [".satfinite", ""],
     ):
 
-        if not is_mma_variant_supported(op, alayout, blayout, satf):
+        if not is_mma_variant_supported(op, alayout, blayout, kind, satf):
             continue
 
         for b1op in get_b1_ops(op.a.mma_type.ptx_type):
@@ -1024,6 +1116,7 @@ def gen_mma_tests():
                 "satf": satf,
                 "geom": op.a.geom,
                 "b1op": b1op,
+                "kind": kind,
             }
 
             intrinsic_template = mma_intrinsic_template
@@ -1105,9 +1198,9 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
     ):
         return False
 
-    # C and D type must be the same for m16n8k16/m16n8k32
+    # C and D type must be the same for m16n8k16/m16n8k32/m16n8k64
     if (
-        op.a.geom in ["m16n8k16", "m16n8k32"]
+        op.a.geom in ["m16n8k16", "m16n8k32", "m16n8k64"]
         and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
     ):
         return False
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9528da05c9fd6..c1da1cf5d0c28 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1763,8 +1763,9 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
          !or(!ne(a_type, b_type),
              !ne(c_type, d_type))): false,
 
-    // m16n8k8 requires C and D to be the same type.
-    !and(!eq(geom, "m16n8k8"),
+    // m16n8k16/m16n8k32 requires C and D to be the same type
+    !and(!or(!eq(geom, "m16n8k16"),
+             !...
[truncated]

Copy link

github-actions bot commented Aug 29, 2025

✅ With the latest revision this PR passed the Python code formatter.

"m16n8k32:a:u8": 2 if is_mma_sparse else 4,
"m16n8k32:a:s8": 2 if is_mma_sparse else 4,
"m16n8k32:b:u8": 2,
"m16n8k32:b:s8": 2,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Re: line +3]

You should generate those tests, include them into this PR, and make sure they pass with PTXAS testing enabled via LLVM_PTXAS_EXECUTABLE=/path/to/ptxas with ptxas from both cuda-13, and a few older versions (e.g. with 12.9, and 11.8) to make sure instructions are properly constrained and are accepted by ptxas.

See this comment inline on Graphite.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Artem, thanks for feedback.
The tests are in mma.fp8.ptx87.sm120a.ll.txt. Those tests have been verified with ptxas from cuda-13.0 and cuda-12.8.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point is that the tests for the code changes in the patch should be included with the patch, and that you need to test not only the new stuff, but also the existing tests generated by wmma.py, to make sure they still work.

In this case, the actual test file is supposed to be generated during the test run. E.g. see

# Check all variants of instructions supported by PTX86 on SM120a
# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_120a.ll
# RUN: %if ptxas-sm_120a && ptxas-isa-8.6 %{ \
# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
# RUN: | %ptxas-verify -arch=sm_120a \
# RUN: %}
import wmma
wmma.main()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course, I haven't tested the new tests only, I run all the tests using cmake --build $REPO/build --target check-llvm -- -j 64 with ptxas from cuda-13.0 and cuda-12.8. The tests passed. WMMA tests are generated from llvm-project/llvm/test/CodeGen/NVPTX/wmma-ptxX-smY.py. Should I include all generated build/test/CodeGen/NVPTX/Output/wmma-ptxX-smY.py.tmp-ptxX-sm_Y.ll files up to 30K lines each to this review?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good news and bad news.

Good news is that there's already a wmma-ptx87-sm120a.py test variant that will generate test cases for sm_87 and sm_120a. I've missed it when I wrote the previous comment.

The bad news is that the test has not been updated to use new way lit specifies ptxas features, so the ptxas tests we not actually run. You do need to update that test with %if ptxas-sm_120a && ptxas-isa-8.7 and rerun that test. No need to include the generated files.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out. It has been updated in 2b34832. The tests passed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Artem-B do you have any objections/comments about this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants