
[mlir][xegpu] cleanup the print format for TensorDesc #149182


Open · wants to merge 2 commits into main

Conversation

chencha3 (Contributor)
PR #145916 changed all parameters of BlockTensorDescAttr to DefaultValuedParameter. As a result, when every parameter holds its default value, BlockTensorDescAttr is printed as an empty attribute, e.g. #xegpu.block_tdesc_attr<>, which is unnecessary. This PR cleans up the printing of BlockTensorDescAttr: when all of its fields have default values, the attribute is omitted from the printed TensorDescType.
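
For illustration, a sketch of the printed-form change, using the 8x16xf32 descriptor shape that appears in the updated tests:

  // Before this PR: an all-default block encoding was still printed.
  !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<>>
  // After this PR: the all-default encoding is elided from the type.
  !xegpu.tensor_desc<8x16xf32>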

@llvmbot (Member) commented Jul 16, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Chao Chen (chencha3)

Changes

PR #145916 changed all parameters of BlockTensorDescAttr to DefaultValuedParameter. As a result, when every parameter holds its default value, BlockTensorDescAttr is printed as an empty attribute, e.g. #xegpu.block_tdesc_attr<>, which is unnecessary. This PR cleans up the printing of BlockTensorDescAttr: when all of its fields have default values, the attribute is omitted from the printed TensorDescType.
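
For contrast, an encoding with any non-default field is still printed. A minimal sketch, assuming the usual block_tdesc_attr syntax and a non-default memory_space value:

  // Illustrative only: a non-default field keeps the encoding in the printed type.
  !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>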


Full diff: https://github.com/llvm/llvm-project/pull/149182.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+6)
  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+21-32)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+12-4)
  • (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+1-1)
  • (modified) mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir (+3-3)
  • (modified) mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir (+3-3)
  • (modified) mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir (+1-1)
  • (modified) mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 42b5b7a0d4e3f..d022361d1e376 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -64,6 +64,12 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
     )>
   ];
 
+  let extraClassDeclaration = [{
+    // return true if all fields of the BlockTensorDescAttr are set with
+    // default values.
+    bool hasDefaultsOnly();
+  }];
+
 }
 
 def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scatter_tdesc_attr"> {
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 277158ac85409..c3ab8c9a1b73a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -131,12 +131,12 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
     }
 
-    BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
-      return llvm::dyn_cast_if_present<BlockTensorDescAttr>(getEncoding());
-    }
-
-    ScatterTensorDescAttr getEncodingAsScatterTensorDescAttr() const {
-      return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
+    template <typename T,
+              typename = std::enable_if_t<
+                            std::is_same_v<T, BlockTensorDescAttr> ||
+                            std::is_same_v<T, ScatterTensorDescAttr>>>
+    T getEncodingOfType() const {
+      return llvm::dyn_cast_if_present<T>(getEncoding());
     }
 
     LayoutAttr getLayoutAttr() const {
@@ -144,49 +144,38 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     }
 
     xegpu::MemorySpace getMemorySpace() const {
-      auto block_attr = getEncodingAsBlockTensorDescAttr();
-      if (block_attr && block_attr.getMemorySpace())
-        return block_attr.getMemorySpace().getValue();
+      if (auto attr = getEncodingOfType<BlockTensorDescAttr>())
+        return attr.getMemorySpace().getValue();
 
-      auto scatter_attr = getEncodingAsScatterTensorDescAttr();
-      if (scatter_attr && scatter_attr.getMemorySpace())
-        return scatter_attr.getMemorySpace().getValue();
+      if (auto attr = getEncodingOfType<ScatterTensorDescAttr>())
+        return attr.getMemorySpace().getValue();
 
-      // return default value
+      llvm_unreachable("invalid encoding");
       return MemorySpace::Global;
     }
 
     // get the ArrayLength for blocked TensorDesc
     int getArrayLength() {
-      auto attr = getEncoding();
-      auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
-      assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
-      if (block_attr && block_attr.getArrayLength())
-        return block_attr.getArrayLength().getInt();
-      // return default value
-      return 1;
+      auto attr = getEncodingOfType<BlockTensorDescAttr>();
+      assert(attr && "invalid on non BlockTensorDescAttr.");
+      return attr.getArrayLength().getInt();
     }
 
     bool getBoundaryCheck() {
-      auto attr = getEncoding();
-      auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
-      assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
-      if (block_attr && block_attr.getBoundaryCheck())
-        return block_attr.getBoundaryCheck().getValue();
-      // return default value
-      return true;
+      auto attr = getEncodingOfType<BlockTensorDescAttr>();
+      assert(attr && "invalid on non BlockTensorDescAttr.");
+      return attr.getBoundaryCheck().getValue();
     }
 
     bool isScattered() {
-      return bool(getEncodingAsScatterTensorDescAttr());
+      return bool(getEncodingOfType<ScatterTensorDescAttr>());
     }
 
     // get the ChunkSize for scattered TensorDesc
     int getChunkSizeAsInt() {
-      auto attr = getEncoding();
-      auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
-      assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
-      return scatter_attr.getChunkSizeAsInt();
+      auto attr = getEncodingOfType<ScatterTensorDescAttr>();
+      assert(attr && "invalid on non ScatterTensorDescAttr.");
+      return attr.getChunkSizeAsInt();
     }
 
     /// Helper to drop all layout information from the TensorDesc type.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 642c393cbc2c8..8ab404d52eab4 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -112,6 +112,11 @@ BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
   return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
 }
 
+bool BlockTensorDescAttr::hasDefaultsOnly() {
+  return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
+         getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_ScatterTensorDescAttr
 //===----------------------------------------------------------------------===//
@@ -253,10 +258,11 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
   if (parser.parseGreater())
     return {};
 
+  MLIRContext *ctxt = parser.getContext();
   return TensorDescType::getChecked(
-      [&]() { return parser.emitError(parser.getNameLoc()); },
-      parser.getContext(), shape, elementType,
-      encoding.value_or(mlir::Attribute()), layout.value_or(mlir::Attribute()));
+      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
+      elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
+      layout.value_or(mlir::Attribute()));
 }
 
 void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -273,7 +279,9 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
 
   printer << getElementType();
 
-  if (auto encoding = getEncoding())
+  auto encoding = getEncoding();
+  auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
+  if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
     printer << ", " << encoding;
 
   if (auto layout = getLayout())
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 6b85a66a8bd36..b6df1f00c2462 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -54,7 +54,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
                                 std::multiplies<int64_t>());
 
   // Case 1: regular loads/stores
-  auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
+  auto scatterAttr = tdescTy.getEncodingOfType<ScatterTensorDescAttr>();
   if (scatterAttr) {
     auto chunkSize = scatterAttr.getChunkSize().getInt();
     // Verify if the first dimension of the tensor descriptor shape is
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 4af7061a4f8a3..fe7bc9f2395de 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -30,7 +30,7 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
 // CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
@@ -55,7 +55,7 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
 // CHECK-SAME:    [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
@@ -73,7 +73,7 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
 // CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index d68a02b54e967..53b5699e376b3 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -32,7 +32,7 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
 // CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
 
 // -----
@@ -57,7 +57,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
 // CHECK-SAME:    [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
 
 // -----
@@ -75,7 +75,7 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
 // CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
 
 // -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index c2f760b29afc4..8fad4af6608fd 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -51,7 +51,7 @@ func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<32x64xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 8de6c2283b37c..58db6d6bb418b 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -80,7 +80,7 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
 // CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
 
 // -----


⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off the "Keep my email addresses private" setting in your account.
See the LLVM Developer Policy and LLVM Discourse for more information.

chencha3 requested a review from charithaintc on July 17, 2025 at 14:49