[mlir][linalg] Add mixed precision folding pattern in vectorize_children_and_apply_patterns TD Op #148684
Conversation
With mixed precision inputs, the inputs are generally cast to match the output type, which introduces arith.extFOp/extIOp instructions. Folding such patterns into vector.contract is desirable for hardware with mixed precision ISA support. This patch adds optional folding of the mixed precision pattern into vector.contract, which can be enabled via the attribute 'vectorize_mixed_precision'.
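For illustration, this is roughly what the folding does at the vector level (a hand-written sketch, not output from this patch; shapes and value names are made up to mirror the tests below):

```mlir
// Before folding: the bf16 operands are extended to f32 before the contraction.
%lhs = arith.extf %a : vector<8x16xbf16> to vector<8x16xf32>
%rhs = arith.extf %b : vector<16x32xbf16> to vector<16x32xf32>
%res = vector.contract {
         indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                          affine_map<(m, n, k) -> (k, n)>,
                          affine_map<(m, n, k) -> (m, n)>],
         iterator_types = ["parallel", "parallel", "reduction"],
         kind = #vector.kind<add>}
       %lhs, %rhs, %acc
  : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>

// After folding: the extensions are absorbed into the contraction, which
// takes bf16 operands and accumulates into f32.
%res = vector.contract {
         indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                          affine_map<(m, n, k) -> (k, n)>,
                          affine_map<(m, n, k) -> (m, n)>],
         iterator_types = ["parallel", "parallel", "reduction"],
         kind = #vector.kind<add>}
       %a, %b, %acc
  : vector<8x16xbf16>, vector<16x32xbf16> into vector<8x32xf32>
```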
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Md Asghar Ahmad Shahid (shahidact). Full diff: https://github.com/llvm/llvm-project/pull/148684.diff (3 files affected):
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b4dde776822a1..dc4e6718907f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2347,6 +2347,9 @@ def VectorizeChildrenAndApplyPatternsOp :
operation that is contained inside the vectorization target.
This transformation supports the following attributes:
+ - `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
+ of ops that have mixed precision types. This enables the folding of
+ arith.extFOp/arith.extIOp into vector.contract with mixed precision.
- `vectorize_padding`: a `UnitAttr` to activate the vectorization of
`tensor.pad` ops. Different pipelines may prefer to lower such ops to
loops.
@@ -2367,6 +2370,7 @@ def VectorizeChildrenAndApplyPatternsOp :
}];
let arguments = (ins TransformHandleTypeInterface:$target,
+ UnitAttr:$vectorize_mixed_precision,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
UnitAttr:$flatten_1d_depthwise_conv,
@@ -2380,6 +2384,7 @@ def VectorizeChildrenAndApplyPatternsOp :
let builders = [
OpBuilder<(ins "Value":$target,
+ CArg<"bool", "false">:$vectorizeMixedPrecision,
CArg<"bool", "false">:$vectorizePadding,
CArg<"bool", "false">:$vectorizeNDExtract,
CArg<"bool", "false">:$flatten1DDepthwise)>
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d5f9de465561..c8f256cf38c9d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3784,8 +3784,15 @@ LogicalResult TileUsingForallOp::verify() {
void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
- bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
+ bool vectorizeMixedPrecision, bool vectorizePadding, bool vectorizeExtract,
+ bool flatten1DDepthwiseConv) {
result.addOperands(target);
+ if (vectorizeMixedPrecision) {
+ result.addAttribute(
+ VectorizeChildrenAndApplyPatternsOp::getVectorizeMixedPrecisionAttrName(
+ result.name),
+ builder.getUnitAttr());
+ }
if (vectorizePadding) {
result.addAttribute(
VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
@@ -3876,6 +3883,10 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
patterns.add<CopyVectorizationPattern>(ctx);
+ if (getVectorizeMixedPrecision()) {
+ vector::populateFoldArithExtensionPatterns(patterns);
+ }
+
if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
// This creates an alternative path for lowering tensor.pad - by
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 0d59dbba8940d..96f89653d20ca 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -190,3 +190,92 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Mixed Precision vetorization tests.
+
+// CHECK-LABEL: func @mixed_precision_generic_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
+ %C: memref<8x32xf32>) {
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+ }
+ ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
+ outs(%C : memref<8x32xf32>) {
+ ^bb(%in: bf16, %in_0: bf16, %c: f32) :
+ %a = arith.extf %in : bf16 to f32
+ %b = arith.extf %in_0 : bf16 to f32
+ %d = arith.mulf %a, %b: f32
+ %e = arith.addf %c, %d: f32
+ linalg.yield %e : f32
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
+ %B: tensor<12x25xbf16>,
+ %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>]
+ ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
+ outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+ func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @contraction_matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
+ linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
+ outs(%C: memref<1584x1584xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
I am a bit concerned about expanding `transform.structured.vectorize_children_and_apply_patterns` like this - when/where do we stop?
Have you considered creating a TD op for `populateFoldArithExtensionPatterns` instead? That would make more sense to me TBH.
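For reference, a standalone application of these patterns might look roughly like this (a sketch, assuming the vector transform-dialect extension exposes the folding patterns as `transform.apply_patterns.vector.fold_arith_extension`; the matched op and handle names are illustrative):

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    // Vectorize the target op, then run only the arith-extension folding
    // patterns on the enclosing function.
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %root
      : (!transform.any_op) -> !transform.any_op
    %func = transform.get_parent_op %matmul {isolated_from_above}
      : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %matmul : !transform.any_op
    transform.apply_patterns to %func {
      transform.apply_patterns.vector.fold_arith_extension
    } : !transform.any_op
    transform.yield
  }
}
```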
    This transformation supports the following attributes:
    - `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
Suggested change:
- `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
+ `fold_mixed_precision_into_contract`: a `UnitAttr` to activate the vectorization
IIUC, vectorization will happen regardless of this attribute. Could you update the comment accordingly? Thanks!

Do you know some history behind this? Why does it exist in the first place? This transform already looks like a convenience wrapper for vectorization plus optional additions that are integral to it.
// Mixed Precision vetorization tests. |
nit: typo
// CHECK-LABEL: func @mixed_precision_generic_as_contract
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
Could you also add a case for `extsi`?
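Something along these lines could cover the integer path (a sketch mirroring the bf16 test above; the i8/i32 element types and shapes are illustrative, and whether the extension actually folds depends on what `populateFoldArithExtensionPatterns` supports for integer extensions):

```mlir
// CHECK-LABEL: func @mixed_precision_generic_as_contract_int
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extsi
// CHECK: vector.contract
// CHECK: vector.transfer_write
func.func @mixed_precision_generic_as_contract_int(%A: memref<8x16xi8>, %B: memref<16x32xi8>,
                                                    %C: memref<8x32xi32>) {
  linalg.generic {
    indexing_maps = [
      affine_map<(m, n, k) -> (m, k)>,
      affine_map<(m, n, k) -> (k, n)>,
      affine_map<(m, n, k) -> (m, n)>
    ],
    iterator_types = ["parallel", "parallel", "reduction"]
  }
  ins(%A, %B : memref<8x16xi8>, memref<16x32xi8>)
  outs(%C : memref<8x32xi32>) {
    ^bb(%in: i8, %in_0: i8, %c: i32) :
      // The sign-extensions below are what the folding would absorb into
      // the resulting vector.contract.
      %a = arith.extsi %in : i8 to i32
      %b = arith.extsi %in_0 : i8 to i32
      %d = arith.muli %a, %b : i32
      %e = arith.addi %c, %d : i32
      linalg.yield %e : i32
  }
  return
}
```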
  if (getVectorizeMixedPrecision()) {
    vector::populateFoldArithExtensionPatterns(patterns);
  }
nit: doesn't need braces
This is one of the earliest TD ops, introduced at a time when the intended use of TD was still evolving. Since then, the op has grown organically without a clearly defined direction. It may be worth auditing - or even considering deprecating - it. To briefly summarise (apologies if you're already familiar):

Conceptually, I have some concerns about `transform.structured.vectorize_children_and_apply_patterns` in its current form - it represents just one specific collection of patterns, but it's not necessarily the canonical set IMHO.

One concern is the set of patterns it implicitly applies. These are just some points for discussion, should anyone be interested. Btw, if you want to use …
Thanks for the insights.

Fair enough. Sounds like this transform doesn't have any particular design vision.

In defense of the proposed change, the transform already applies a lot of rewrites, so this addition doesn't particularly go against its "design". Furthermore, the extension is optional, so one knows what to expect (with better option naming 😉). The naming …
@tkarna, pulling you into the discussion as our target user - any thoughts or preferences?
Edit: for a better suggestion, see "Final edit" below.

Here's my drive-by suggestion (which could be done in a follow-up PR, as there isn't really anything wrong with the current PR, seeing as it keeps to an existing way of adding more patterns): merge … where … where you specify that certain "without" patterns shouldn't be added. Just my two cents.

Edit: upon reflection, maybe the conceptually cleaner approach would be to just expose … and replace all instances of … That would keep the normal …

Final edit: how about … as a direct replacement for …? With that in place, @shahidact's PR would reduce to just exposing a pattern to the Transform dialect.
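Purely as an illustration of that general direction (not what this PR implements, and the pattern-group op name below is hypothetical): the monolithic op could give way to a user-assembled `transform.apply_patterns` body in which the arith-extension folding is just one more entry.

```mlir
// Hypothetical sketch: `apply_patterns.linalg.vectorize_children` does not
// exist; it stands in for whatever pattern-group op would replace the
// monolithic transform. The fold_arith_extension entry corresponds to the
// patterns this PR enables via the new attribute.
transform.apply_patterns to %func {
  transform.apply_patterns.linalg.vectorize_children      // hypothetical
  transform.apply_patterns.vector.fold_arith_extension
  transform.apply_patterns.canonicalization
} : !transform.any_op
```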
I'm OK with extending … My only outstanding request for this PR is to move the tests to: …

I feel the current TD-specific test files can be removed - we already test TD Ops extensively elsewhere.

@rolfmorel Thanks for sharing all the ideas! Looks like we have quite a few directions to consider. We should probably move this discussion to a more suitable place - maybe a dedicated GitHub issue? Two high-level concerns I have about combining the existing vectorize Ops into e.g. …

Thanks again to everyone for the great discussion so far! 🙏🏻
Another fly-by: I am generally in line with @rolfmorel's later edits as a direction. The Transform dialect was intended to make transforms more composable, and it initially replaced a monolithic "LinalgCodegenStrategy" object. Let's not replicate that object with the transform dialect. The current granularity of pattern exposure being low does not mean we can't have pattern groups exposed as well. Maybe we can consider having lowerings within the transform dialect (the op for a pattern group lowers to a set of ops for individual patterns/small groups), or introduce a library/macro mechanism similar to how we have …
With my Area Team hat on, I think the right venue for this discussion is the forum. GitHub issues are not visible enough.