Skip to content

[MLIR][NVVM] Support generating all the ldmatrix intrinsics from NVVM ops #148783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

Pecco-314
Copy link

Previously, the NVVM dialect's ldmatrix operation could only generate a limited subset of the available NVVM ldmatrix intrinsics. The intrinsics generating new ops introduced in BlackWell are not accessible through the NVVM ops. This commit extends the ldmatrix operation to support all available ldmatrix intrinsics.

@Pecco-314 Pecco-314 requested a review from grypp as a code owner July 15, 2025 06:30
Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Jul 15, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Pecco (Pecco-314)

Changes

Previously, the NVVM dialect's ldmatrix operation could only generate a limited subset of the available NVVM ldmatrix intrinsics. The intrinsics generating new ops introduced in BlackWell are not accessible through the NVVM ops. This commit extends the ldmatrix operation to support all available ldmatrix intrinsics.


Patch is 22.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148783.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+34-2)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+2-1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+8-4)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp (+79-22)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+2-2)
  • (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+13-4)
  • (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (-11)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+37-7)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 45a8904375e2b..cfb21e8331d05 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,6 +1990,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
+def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">;
+def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">;
+def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">;
+def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">;
+
+def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix",
+  [LdStMatrixShapeM8N8, LdStMatrixShapeM8N16, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def LdStMatrixShapeAttr : EnumAttr<NVVM_Dialect, LdStMatrixShape, "ld_st_matrix_shape"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">;
+def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">;
+def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">;
+def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">;
+
+def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix",
+  [LdStMatrixEltTypeB16, LdStMatrixEltTypeB8,
+   LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elttype"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, 
   Arguments<(ins LLVM_PointerShared:$ptr, 
                  Variadic<I32>:$sources, 
@@ -2021,13 +2050,16 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
 
 def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
   Results<(outs AnyType:$res)>,
-  Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
+  Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr: $num,
+                 MMALayoutAttr: $layout,
+                 LdStMatrixShapeAttr: $shape,
+                 LdStMatrixEltTypeAttr: $elttype)> {
 
   let summary = "cooperative matrix load";
 
   string llvmBuilder = [{
       auto operands = moduleTranslation.lookupValues(opInst.getOperands());
-      auto intId = getLdMatrixIntrinsicId($layout, $num);
+      auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $elttype);
       $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
   }];
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 80b3d85488495..470dc2512a9ad 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -289,7 +289,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
         ldMatrixResultType, srcPtr,
         /*num=*/op.getNumTiles(),
         /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
-                                     : NVVM::MMALayout::row);
+                                     : NVVM::MMALayout::row,
+        NVVM::LdStMatrixShape::M8N8, NVVM::LdStMatrixEltType::B16);
 
     // The ldmatrix operation returns either a single i32 value or a struct of
     // i32 values. Here we unpack those values and cast them back to their
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 6e29b129e8835..93c155b67fb5c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -806,14 +806,18 @@ LogicalResult NVVM::LdMatrixOp::verify() {
     return emitOpError("expected num attribute to be 1, 2 or 4");
 
   Type i32 = IntegerType::get(getContext(), 32);
-  if (getNum() == 1 && getType() != i32)
+  uint32_t num = getNum();
+  if (getShape() == LdStMatrixShape::M16N16) {
+    num *= 2;
+  }
+  if (num == 1 && getType() != i32)
     return emitOpError("expected destination type is i32");
-  if (getNum() == 2 || getNum() == 4) {
+  if (num == 2 || num == 4) {
     Type dstType = LLVM::LLVMStructType::getLiteral(
-        getContext(), SmallVector<Type>(getNum(), i32));
+        getContext(), SmallVector<Type>(num, i32));
     if (getType() != dstType)
       return emitOpError("expected destination type is a structure of ")
-             << getNum() << " elements of type i32";
+             << num << " elements of type i32";
   }
   return success();
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index eecca64c4bf81..5d13933519c54 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -134,33 +134,90 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
   llvm_unreachable("unsupported vote kind");
 }
 
-/// Return the intrinsic ID associated with ldmatrix for the given paramters.
-static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
-                                                  int32_t num) {
+static llvm::Intrinsic::ID
+getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
+                       NVVM::LdStMatrixShape shape,
+                       NVVM::LdStMatrixEltType elttype) {
   if (layout == NVVM::MMALayout::row) {
-    switch (num) {
-    case 1:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
-    case 2:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
-    case 4:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
-    default:
-      llvm_unreachable("unsupported number of matrix");
+    if (shape == NVVM::LdStMatrixShape::M8N8 &&
+        elttype == NVVM::LdStMatrixEltType::B16) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
+      case 2:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
+      case 4:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
+      case 2:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
+      case 4:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
+      case 2:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
+      case 4:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
+      }
     }
-
   } else {
-    switch (num) {
-    case 1:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
-    case 2:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
-    case 4:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
-    default:
-      llvm_unreachable("unsupported number of matrix");
+    if (shape == NVVM::LdStMatrixShape::M8N8 &&
+        elttype == NVVM::LdStMatrixEltType::B16) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
+      case 2:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
+      case 4:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
+      case 2:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
+      case 2:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
+      case 2:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
+      }
     }
   }
+  llvm_unreachable("unsupported matrix configuration");
 }
 
 /// Return the intrinsic ID associated with st.bulk for the given address type.
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index d0bc806e0aa8c..75a556f471373 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
 // CHECK-LABEL: @ldmatrix_x4
 func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
   %c0  = arith.constant 0 : index
-  // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
+  // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
   %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
   // CHECK: llvm.extractvalue
   // CHECK: llvm.bitcast
@@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
 // CHECK-LABEL: @ldmatrix_x1
 func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) ->  vector<1x2xf16> {
   %c0  = arith.constant 0 : index
-  // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+  // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> i32
   %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
   // CHECK: llvm.bitcast
   // CHECK: llvm.insertvalue
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index bd1106e304c60..f9def0877d71a 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
   llvm.return
 }
 
@@ -1124,7 +1124,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
-  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
@@ -1132,7 +1132,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
   llvm.return
 }
 
@@ -1140,10 +1140,19 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   llvm.return
 }
 
+// -----
+
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
+
+
 // -----
 
 llvm.func @caller() {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c7fa41c98ac92..6a4edd0d22a08 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -385,17 +385,6 @@ llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
   llvm.return
 }
 
-// CHECK-LABEL: llvm.func @ld_matrix
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<3>) -> i32
-  %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
-  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
-  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
-  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-  llvm.return
-}
-
 // CHECK-LABEL: llvm.func @redux_sync
 llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
   // CHECK: nvvm.redux.sync  add %{{.*}}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f86a04186f512..89429a762db92 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -559,17 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
 // CHECK-LABEL: @ld_matrix
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
-  %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
+  %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}})
-  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}})
-  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> i32
+  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
+  %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr add...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants