
[mlir][linalg] Add mixed precision folding pattern in vectorize_children_and_apply_patterns TD Op #148684


Open
wants to merge 1 commit into base: main
Conversation

shahidact
Contributor

In case of mixed precision inputs, the inputs are generally cast to match the output type, thereby introducing arith.extFOp/arith.extIOp instructions.

Folding such patterns into vector.contract is desirable for hardware with mixed precision ISA support.

This patch optionally folds such mixed precision patterns into vector.contract; the folding can be enabled using the attribute 'vectorize_mixed_precision'.
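
As a hedged illustration of the intended rewrite (shapes, value names, and maps below are hypothetical, not taken from this patch), the folding lets vector.contract consume the narrow operands directly instead of their widened copies:

```mlir
// Before folding: vectorization materializes the extensions explicitly.
%a32 = arith.extf %a : vector<8x16xbf16> to vector<8x16xf32>
%b32 = arith.extf %b : vector<16x32xbf16> to vector<16x32xf32>
%d0 = vector.contract {
        indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                         affine_map<(m, n, k) -> (k, n)>,
                         affine_map<(m, n, k) -> (m, n)>],
        iterator_types = ["parallel", "parallel", "reduction"],
        kind = #vector.kind<add>}
      %a32, %b32, %acc
      : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>

// After folding: mixed precision operands feed the contract directly,
// which a mixed precision ISA can lower to a single widening instruction.
%d1 = vector.contract {
        indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                         affine_map<(m, n, k) -> (k, n)>,
                         affine_map<(m, n, k) -> (m, n)>],
        iterator_types = ["parallel", "parallel", "reduction"],
        kind = #vector.kind<add>}
      %a, %b, %acc
      : vector<8x16xbf16>, vector<16x32xbf16> into vector<8x32xf32>
```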

@llvmbot
Member

llvmbot commented Jul 14, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Md Asghar Ahmad Shahid (shahidact)

Changes

In case of mixed precision inputs, the inputs are generally cast to match the output type, thereby introducing arith.extFOp/arith.extIOp instructions.

Folding such patterns into vector.contract is desirable for hardware with mixed precision ISA support.

This patch optionally folds such mixed precision patterns into vector.contract; the folding can be enabled using the attribute 'vectorize_mixed_precision'.


Full diff: https://github.com/llvm/llvm-project/pull/148684.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+5)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+12-1)
  • (modified) mlir/test/Dialect/Linalg/transform-op-vectorize.mlir (+89)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b4dde776822a1..dc4e6718907f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2347,6 +2347,9 @@ def VectorizeChildrenAndApplyPatternsOp :
     operation that is contained inside the vectorization target.
 
     This transformation supports the following attributes:
+    - `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
+      of ops that have mixed precision types. This enables the folding of
+      arith.extFOp/arith.extIOp into vector.contract with mixed precision.
     - `vectorize_padding`: a `UnitAttr` to activate the vectorization of
       `tensor.pad` ops. Different pipelines may prefer to lower such ops to
       loops.
@@ -2367,6 +2370,7 @@ def VectorizeChildrenAndApplyPatternsOp :
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
+                   UnitAttr:$vectorize_mixed_precision,
                    UnitAttr:$vectorize_padding,
                    UnitAttr:$vectorize_nd_extract,
                    UnitAttr:$flatten_1d_depthwise_conv,
@@ -2380,6 +2384,7 @@ def VectorizeChildrenAndApplyPatternsOp :
 
   let builders = [
     OpBuilder<(ins "Value":$target,
+               CArg<"bool", "false">:$vectorizeMixedPrecision,
                CArg<"bool", "false">:$vectorizePadding,
                CArg<"bool", "false">:$vectorizeNDExtract,
                CArg<"bool", "false">:$flatten1DDepthwise)>
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d5f9de465561..c8f256cf38c9d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3784,8 +3784,15 @@ LogicalResult TileUsingForallOp::verify() {
 
 void transform::VectorizeChildrenAndApplyPatternsOp::build(
     OpBuilder &builder, OperationState &result, Value target,
-    bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
+    bool vectorizeMixedPrecision, bool vectorizePadding, bool vectorizeExtract,
+    bool flatten1DDepthwiseConv) {
   result.addOperands(target);
+  if (vectorizeMixedPrecision) {
+    result.addAttribute(
+        VectorizeChildrenAndApplyPatternsOp::getVectorizeMixedPrecisionAttrName(
+            result.name),
+        builder.getUnitAttr());
+  }
   if (vectorizePadding) {
     result.addAttribute(
         VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
@@ -3876,6 +3883,10 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
 
   patterns.add<CopyVectorizationPattern>(ctx);
 
+  if (getVectorizeMixedPrecision()) {
+    vector::populateFoldArithExtensionPatterns(patterns);
+  }
+
   if (getVectorizePadding()) {
     linalg::populatePadOpVectorizationPatterns(patterns);
     // This creates an alternative path for lowering tensor.pad - by
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 0d59dbba8940d..96f89653d20ca 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -190,3 +190,92 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Mixed precision vectorization tests.
+
+// CHECK-LABEL: func @mixed_precision_generic_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT:     arith.extf
+// CHECK:         vector.contract
+// CHECK:         vector.transfer_write
+func.func @mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
+                         %C: memref<8x32xf32>) {
+  linalg.generic {
+    indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,
+    affine_map<(m, n, k) -> (k, n)>,
+    affine_map<(m, n, k) -> (m, n)>
+    ],
+    iterator_types = ["parallel", "parallel", "reduction"]
+  }
+   ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
+   outs(%C : memref<8x32xf32>) {
+    ^bb(%in: bf16, %in_0: bf16, %c: f32) :
+      %a = arith.extf %in : bf16 to f32
+      %b = arith.extf %in_0 : bf16 to f32
+      %d = arith.mulf %a, %b: f32
+      %e = arith.addf %c, %d: f32
+      linalg.yield %e : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1  { vectorize_mixed_precision, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT:     arith.extf
+// CHECK:         vector.contract
+// CHECK:         vector.transfer_write
+func.func @mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
+                              %B: tensor<12x25xbf16>,
+                              %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %0 = linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
+      outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @contraction_matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT:     arith.extf
+// CHECK:         vector.contract
+func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
+  linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
+            outs(%C: memref<1584x1584xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1  { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

@shahidact shahidact requested review from rolfmorel, ftynse, rengolin and adam-smnk and removed request for ftynse and nicolasvasilache July 14, 2025 17:35
@banach-space banach-space changed the title [mlir][linalg] Add mixed precision folding pattern in transform op. [mlir][linalg] Add mixed precision folding pattern in vectorize_children_and_apply_patterns TD Op Jul 15, 2025
Contributor

@banach-space left a comment

I am a bit concerned about expanding transform.structured.vectorize_children_and_apply_patterns like this - when/where do we stop?

Have you considered creating a TD op for populateFoldArithExtensionPatterns instead? That would make more sense to me TBH.

@@ -2347,6 +2347,9 @@ def VectorizeChildrenAndApplyPatternsOp :
operation that is contained inside the vectorization target.

This transformation supports the following attributes:
- `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
Contributor

Suggested change
- `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
- `fold_mixed_precision_into_contract`: a `UnitAttr` to activate the vectorization

IIUC, vectorization will happen regardless of this attribute. Could you update the comment accordingly? Thanks!

@adam-smnk
Contributor

I am a bit concerned about expanding transform.structured.vectorize_children_and_apply_patterns like this - when/where do we stop?

Do you know some history behind this? Why it exists in the first place?

This transform already looks like a convenience wrapper for vectorization and optional additions that are integral to it.
Gathering relevant patterns/options in one place boosts discoverability, so I see some value in having such a "comprehensive" (perhaps overloaded) transform.


// -----

// Mixed Precision vetorization tests.
Contributor

nit: typo


// CHECK-LABEL: func @mixed_precision_generic_as_contract
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
Contributor

Could you also add a case for extsi?
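
A hedged sketch of what such an extsi case might look like (function name, shapes, and types below are hypothetical, mirroring the bf16 test in this PR with signed integers):

```mlir
// CHECK-LABEL: func @mixed_precision_generic_as_contract_int
// CHECK-NOT:     arith.extsi
// CHECK:         vector.contract
func.func @mixed_precision_generic_as_contract_int(%A: memref<8x16xi8>, %B: memref<16x32xi8>,
                                                   %C: memref<8x32xi32>) {
  linalg.generic {
    indexing_maps = [
      affine_map<(m, n, k) -> (m, k)>,
      affine_map<(m, n, k) -> (k, n)>,
      affine_map<(m, n, k) -> (m, n)>
    ],
    iterator_types = ["parallel", "parallel", "reduction"]
  }
   ins(%A, %B : memref<8x16xi8>, memref<16x32xi8>)
   outs(%C : memref<8x32xi32>) {
    ^bb(%in: i8, %in_0: i8, %c: i32):
      // Sign-extend the narrow operands; the folding should absorb these
      // extensions into the resulting vector.contract.
      %a = arith.extsi %in : i8 to i32
      %b = arith.extsi %in_0 : i8 to i32
      %d = arith.muli %a, %b : i32
      %e = arith.addi %c, %d : i32
      linalg.yield %e : i32
  }
  return
}
```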

Comment on lines +3886 to +3888
if (getVectorizeMixedPrecision()) {
vector::populateFoldArithExtensionPatterns(patterns);
}
Contributor

nit: doesn't need braces

@banach-space
Contributor

Do you know some history behind this? Why it exists in the first place?

This is one of the earliest TD ops, introduced at a time when the intended use of TD was still evolving.

Since then, the op has grown organically without a clearly defined direction. It may be worth auditing - or even considering deprecating - it. To briefly summarise (apologies if you're already familiar):

  • transform.structured.vectorize_children_and_apply_patterns calls vectorize and applies some "clean-up" patterns. The generated output tends to be very compact, i.e. is easy for humans to parse, but doesn't really show what vectorize does (because there is so much more going on).
  • transform.structured.vectorize will only call vectorize - it is great to test specifically the transformations behind vectorize.

Conceptually, transform.structured.vectorize_children_and_apply_patterns treats "vectorization" as an abstract notion - vectorize plus additional Vector patterns - whereas transform.structured.vectorize directly encapsulates only vectorize.

I have some concerns about transform.structured.vectorize_children_and_apply_patterns in its current form - it represents just one specific collection of patterns, but it's not necessarily the canonical set IMHO.

This transform already looks like a convenience wrapper for vectorization and optional additions that are integral to it.
Gathering relevant patterns/options in one place boosts discoverability so, I see some value in having such "comprehensive" (perhaps overloaded) transform.

One concern is that it implicitly applies populateFoldArithExtensionPatterns, which can make it harder to trace or discover when analysing transformations.

These are just some points for discussion, should anyone be interested.

Btw, if you want to use transform.structured.vectorize_children_and_apply_patterns for this, could you move tests to one of the files in https://github.com/llvm/llvm-project/tree/main/mlir/test/Dialect/Linalg/vectorization?

@adam-smnk
Contributor

Thanks for the insights.
I'm fine with either approach (extending this op or a separate apply_pattern op). Primarily, I'm trying to figure out if there's any bigger picture or rule of thumb for creating/extending transforms.

I have some concerns about transform.structured.vectorize_children_and_apply_patterns in its current form - it represents just one specific collection of patterns, but it's not necessarily the canonical set IMHO.

Fair enough. Sounds like this transform doesn't have any particular design vision.
It became something of a convenient bag of common patterns.

One concern is that it implicitly applies populateFoldArithExtensionPatterns, which can make it harder to trace or discover when analysing transformations.

In defense of the proposed change, the transform already applies a lot of rewrites. This addition doesn't particularly go against its "design". Furthermore, this extension is optional, so one knows what to expect (with better option naming 😉).

The naming vector::populateFoldArithExtensionPatterns itself is a bit misleading. Today, these patterns only fold arith extension ops into vector.contract. They specifically aim to undo the decomposition of numerical casting from mixed precision linalg contraction ops. As such, it somewhat fits thematically as this transform already tries to reconstruct contraction by default.

@adam-smnk
Contributor

@tkarna to pull you into the discussion as our target user. Any thoughts or preferences?

@rolfmorel
Contributor

rolfmorel commented Jul 16, 2025

Edit: for a better suggestion, see "Final edit" below

Here's my drive-by suggestion (which could be done in a follow-up PR as there isn't really anything wrong with the current PR, seeing as it keeps to an existing way of adding more patterns):

Merge transform.structured.vectorize_children_and_apply_patterns into transform.structured.vectorize modifying the latter to have (something like) the following syntax:

`transform.structured.vectorize`
    (`children_of`)?
    $target
    (`vectorize_nd_extract`)?
    (`flatten_1d_depthwise_conv`)?
    oilist(`vector_sizes` custom<DynamicIndexList>($vector_sizes, $static_vector_sizes, $scalable_sizes)) 
    (`with_patterns` $with_patterns)? attr-dict `:` functional-type(operands, results)

where

  • we could choose that the children_of UnitAttr cannot be specified alongside vector_sizes.
  • $with_patterns would be an optional region where users could hook in their (TD-exposed) patterns, so as to make sure they run in the same rewriting fixpoint. Maybe this only makes sense in case of children_of being provided - hmm.
  • In case we want to have that certain patterns are on by default (e.g. transform.apply_patterns.vector.multi_reduction_to_contract and transform.apply_patterns.vector.transfer_permutation_map_lowering) we could even go so far as to add
   (`without_patterns` $without_patterns)?

where you specify that certain "without" patterns shouldn't be added.

Just my two cents.


Edit: upon reflection, maybe the conceptually cleaner approach would be to just expose vectorize as a pattern through a transform op

`transform.apply_patterns.vector.vectorize` (`vectorize_nd_extract`)?  (`flatten_1d_depthwise_conv`)?

and replace all instances of transform.structured.vectorize_children_and_apply_patterns with transform.apply_patterns with the relevant patterns provided, potentially also additionally as a convenience bundle, so that the replacement would look something like:

transform.apply_patterns to %target {
   transform.apply_patterns.vector.vectorize,
   transform.apply_patterns.vector.vectorize.default_patterns
}

That would keep the normal transform.structured.vectorize op and make it easy to hook in new patterns and it would make it clearer that vectorize_children_and_apply_patterns's patterns are just a default set, not the canonical one. (Also, having transform.apply_patterns.vector.vectorize.default_patterns would mean the replacement for transform.structured.vectorize_children_and_apply_patterns would not bloat the schedule IR all that much.)


Final edit: how about:

transform.apply_patterns to %target {
   transform.apply_patterns.vector.vectorize vectorize_nd_extract flatten_1d_depthwise_conv with_vectorize_padding_patterns without_multi_reduction_to_contract_patterns without_transfer_permutation_map_lowering_patterns
}

as a direct replacement for transform.structured.vectorize_children_and_apply_patterns where each of the following are optional

  • vectorize_nd_extract
  • flatten_1d_depthwise_conv
  • with_vectorize_padding_patterns
  • without_multi_reduction_to_contract_patterns
  • without_transfer_permutation_map_lowering_patterns

and do the same thing as on transform.structured.vectorize_children_and_apply_patterns. To be more idiomatic, with_vectorize_padding_patterns should probably be exposed just as a separate apply_patterns op.

With that in place, @shahidact 's PR would reduce to just exposing a pattern to the Transform dialect.

@banach-space
Contributor

I'm OK with extending transform.structured.vectorize_children_and_apply_patterns for now. However, if we plan to extend it further, we should step back and discuss the design goals of this Op more explicitly. Thanks!

My only outstanding request for this PR is to move the tests to mlir/test/Dialect/Linalg/vectorization.

I feel the current TD-specific test files can be removed - we already test TD Ops extensively elsewhere.


@rolfmorel Thanks for sharing all the ideas! Looks like we have quite a few directions to consider. We should probably move this discussion to a more suitable place - maybe a dedicated GitHub issue?

Two high-level concerns I have about combining the existing vectorize Ops into e.g. transform.apply_patterns.vector.vectorize:

  • Op granularity in Transform dialect is intentionally low - even single-pattern Ops exist. Introducing large meta-Ops may go against that design.

  • The tests in mlir/test/Dialect/Linalg/vectorization currently follow the split between vectorize_children_and_apply_patterns and vectorize. If we unify the Ops, we’ll need to rethink that split and redesign the test organisation - a fair bit of churn. And to be honest, test organisation and discoverability are already weak points.


Thanks again to everyone for the great discussion so far! 🙏🏻

@ftynse
Member

ftynse commented Jul 18, 2025

Another fly-by: I am generally in line with @rolfmorel's later edits as a direction. Transform dialect was intended to make transforms more composable and initially replaced a monolithic "LinalgCodegenStrategy" object. Let's not replicate that object with the transform dialect.

The current granularity of pattern exposure being low does not mean we can't have pattern groups exposed as well. Maybe we can consider having lowerings within transform dialect (the op for a pattern group lowers to a set of ops for individual patterns/small groups) or introduce a library/macro mechanism similar to how we have transform.include for named sequences.

@ftynse
Member

ftynse commented Jul 18, 2025

With my Area Team hat on, I think the right venue for a discussion is the forum. Github issues are not visible enough.

6 participants