
[mlir][linalg] Add mixed precision folding pattern in vectorize_children_and_apply_patterns TD Op #148684

Status: Open. Wants to merge 1 commit into `main`.
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2347,6 +2347,9 @@ def VectorizeChildrenAndApplyPatternsOp :
operation that is contained inside the vectorization target.

This transformation supports the following attributes:
- `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
Review comment (Contributor):

Suggested change:
-  - `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
+  - `fold_mixed_precision_into_contract`: a `UnitAttr` to activate the vectorization

IIUC, vectorization will happen regardless of this attribute. Could you update the comment accordingly? Thanks!

of ops that have mixed precision types. This enables the folding of
arith.extFOp/arith.extIOp into vector.contract with mixed precision.
- `vectorize_padding`: a `UnitAttr` to activate the vectorization of
`tensor.pad` ops. Different pipelines may prefer to lower such ops to
loops.
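
Not part of the diff: a minimal sketch, with illustrative shapes, maps, and a hypothetical function name, of the rewrite that `vector::populateFoldArithExtensionPatterns` performs when this flag is set. Plain vectorization produces explicit extensions feeding an all-f32 contraction; the folding absorbs them so the contract consumes the narrow operands directly.

func.func @fold_ext_sketch(%a: vector<8x16xbf16>, %b: vector<16x32xbf16>,
                           %acc: vector<8x32xf32>) -> vector<8x32xf32> {
  // Explicit widening, as produced by plain vectorization.
  %a32 = arith.extf %a : vector<8x16xbf16> to vector<8x16xf32>
  %b32 = arith.extf %b : vector<16x32xbf16> to vector<16x32xf32>
  %r = vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                                         affine_map<(m, n, k) -> (k, n)>,
                                         affine_map<(m, n, k) -> (m, n)>],
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
    %a32, %b32, %acc : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
  return %r : vector<8x32xf32>
}
// After folding, the extf ops are gone and the contract mixes precisions:
//   %r = vector.contract {...} %a, %b, %acc
//        : vector<8x16xbf16>, vector<16x32xbf16> into vector<8x32xf32>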
@@ -2367,6 +2370,7 @@ def VectorizeChildrenAndApplyPatternsOp :
}];

let arguments = (ins TransformHandleTypeInterface:$target,
UnitAttr:$vectorize_mixed_precision,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
UnitAttr:$flatten_1d_depthwise_conv,
@@ -2380,6 +2384,7 @@ def VectorizeChildrenAndApplyPatternsOp :

let builders = [
OpBuilder<(ins "Value":$target,
CArg<"bool", "false">:$vectorizeMixedPrecision,
CArg<"bool", "false">:$vectorizePadding,
CArg<"bool", "false">:$vectorizeNDExtract,
CArg<"bool", "false">:$flatten1DDepthwise)>
13 changes: 12 additions & 1 deletion mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3784,8 +3784,15 @@ LogicalResult TileUsingForallOp::verify() {

void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
bool vectorizeMixedPrecision, bool vectorizePadding, bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
result.addOperands(target);
if (vectorizeMixedPrecision) {
result.addAttribute(
VectorizeChildrenAndApplyPatternsOp::getVectorizeMixedPrecisionAttrName(
result.name),
builder.getUnitAttr());
}
if (vectorizePadding) {
result.addAttribute(
VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
@@ -3876,6 +3883,10 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(

patterns.add<CopyVectorizationPattern>(ctx);

if (getVectorizeMixedPrecision()) {
vector::populateFoldArithExtensionPatterns(patterns);
}
Review comment (Contributor) on lines +3886 to +3888:

nit: doesn't need braces


if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
// This creates an alternative path for lowering tensor.pad - by
89 changes: 89 additions & 0 deletions mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -190,3 +190,92 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// Mixed Precision vetorization tests.
Review comment (Contributor):

nit: typo


// CHECK-LABEL: func @mixed_precision_generic_as_contract
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
Review comment (Contributor):

Could you also add a case for extsi?

// CHECK: vector.contract
// CHECK: vector.transfer_write
func.func @mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
%C: memref<8x32xf32>) {
linalg.generic {
indexing_maps = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
affine_map<(m, n, k) -> (m, n)>
],
iterator_types = ["parallel", "parallel", "reduction"]
}
ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
outs(%C : memref<8x32xf32>) {
^bb(%in: bf16, %in_0: bf16, %c: f32) :
%a = arith.extf %in : bf16 to f32
%b = arith.extf %in_0 : bf16 to f32
%d = arith.mulf %a, %b: f32
%e = arith.addf %c, %d: f32
linalg.yield %e : f32
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
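
Not part of the commit: a possible shape for the extsi case requested in the review above, mirroring the bf16 test with an i8-to-i32 widening. The test name is hypothetical, and it would need its own `// -----` split with the same transform script as @mixed_precision_generic_as_contract.

// -----

// CHECK-LABEL: func @mixed_precision_generic_extsi
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extsi
// CHECK: vector.contract
// CHECK: vector.transfer_write
func.func @mixed_precision_generic_extsi(%A: memref<8x16xi8>, %B: memref<16x32xi8>,
                                         %C: memref<8x32xi32>) {
  linalg.generic {
    indexing_maps = [
      affine_map<(m, n, k) -> (m, k)>,
      affine_map<(m, n, k) -> (k, n)>,
      affine_map<(m, n, k) -> (m, n)>
    ],
    iterator_types = ["parallel", "parallel", "reduction"]
  }
  ins(%A, %B : memref<8x16xi8>, memref<16x32xi8>)
  outs(%C : memref<8x32xi32>) {
  ^bb(%in: i8, %in_0: i8, %c: i32):
    // Signed widening that the folding should absorb into vector.contract.
    %a = arith.extsi %in : i8 to i32
    %b = arith.extsi %in_0 : i8 to i32
    %d = arith.muli %a, %b : i32
    %e = arith.addi %c, %d : i32
    linalg.yield %e : i32
  }
  return
}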

// -----

// CHECK-LABEL: @mixed_precision_matmul_as_contract
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
// CHECK: vector.contract
// CHECK: vector.transfer_write
func.func @mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
%B: tensor<12x25xbf16>,
%C: tensor<24x25xf32>) -> tensor<24x25xf32> {
%0 = linalg.contract
indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
affine_map<(m, n, k) -> (m, n)>]
ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: @contraction_matmul
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
// CHECK: vector.contract
func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
outs(%C: memref<1584x1584xf32>)
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
transform.yield
}
}