[MLIR][Linalg] Remove elemwise_unary and elemwise_binary #147082

Merged · 6 commits · Jul 7, 2025
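
The substance of the PR, distilled from the hunks below: the `linalg.elemwise_unary` and `linalg.elemwise_binary` ops are removed in favor of the unified `linalg.elementwise` op, whose `kind` attribute selects the function. A minimal before/after sketch (tensor shapes are illustrative, not taken from this PR):

```mlir
// Before: two ops, with the function selected by the `fun` attribute.
%u_old = linalg.elemwise_unary { fun = #linalg.unary_fn<exp> }
    ins(%in : tensor<8xf32>) outs(%init : tensor<8xf32>) -> tensor<8xf32>
%b_old = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
    ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
    outs(%init : tensor<8xf32>) -> tensor<8xf32>

// After: one op, with the function selected by the `kind` attribute.
%u_new = linalg.elementwise kind=#linalg.elementwise_kind<exp>
    ins(%in : tensor<8xf32>) outs(%init : tensor<8xf32>) -> tensor<8xf32>
%b_new = linalg.elementwise kind=#linalg.elementwise_kind<add>
    ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
    outs(%init : tensor<8xf32>) -> tensor<8xf32>
```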
mlir/docs/Tutorials/transform/Ch1.md: 44 changes (22 additions, 22 deletions)
@@ -19,13 +19,13 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>

// Elementwise addition.
- %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
+ %biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

// Elementwise max with 0 (ReLU).
%c0f = arith.constant 0.0 : f32
- %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
+ %relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
ins(%biased, %c0f : tensor<512x512xf32>, f32)
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
func.return %relued : tensor<512x512xf32>
@@ -41,7 +41,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(
%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.matmul">,
- %arg2: !transform.op<"linalg.elemwise_binary">):
+ %arg2: !transform.op<"linalg.elementwise">):
transform.yield
}
}
@@ -72,11 +72,11 @@ To check or debug a transform sequence, it is possible to print various entities
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.matmul">,
- %arg2: !transform.op<"linalg.elemwise_binary">):
+ %arg2: !transform.op<"linalg.elementwise">):
transform.debug.emit_remark_at %arg1, "matmul"
: !transform.op<"linalg.matmul">
transform.debug.emit_remark_at %arg2, "elemwise_binaries"
- : !transform.op<"linalg.elemwise_binary">
+ : !transform.op<"linalg.elementwise">
transform.yield
}
```
@@ -89,24 +89,24 @@ Since we don’t want to recompile the compiler every time we change a transform
```sh
$ mlir-opt sequence.mlir --pass-pipeline="
builtin.module(transform-interpreter{
- debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary})"
+ debug-bind-trailing-args=linalg.matmul,linalg.elementwise})"
```

- The `sequence.mlir` file contains _both_ the payload IR function _and_ the transform IR sequence nested in the same module. The transform interpreter pass will apply the `@__transform_main` named sequence to the anchor operation of the pass. In our case, we also asked the interpreter pass to associate the two extra arguments of the top-level sequence with all `linalg.matmul` and `linalg.elemwise_binary` payload operations through the respective pass options. Running this pass results in the expected remarks:
+ The `sequence.mlir` file contains _both_ the payload IR function _and_ the transform IR sequence nested in the same module. The transform interpreter pass will apply the `@__transform_main` named sequence to the anchor operation of the pass. In our case, we also asked the interpreter pass to associate the two extra arguments of the top-level sequence with all `linalg.matmul` and `linalg.elementwise` payload operations through the respective pass options. Running this pass results in the expected remarks:

```sh
sequence.mlir:7:13: remark: matmul
%matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
^
sequence.mlir:7:13: note: see current operation: %0 = linalg.matmul ins(%arg0, %arg1 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
sequence.mlir:10:13: remark: elemwise_binaries
- %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
+ %biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
^
- sequence.mlir:10:13: note: see current operation: %1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
+ sequence.mlir:10:13: note: see current operation: %1 = linalg.elementwise kind=#linalg.elementwise_kind<add> ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
sequence.mlir:14:13: remark: elemwise_binaries
- %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
+ %relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
^
- sequence.mlir:14:13: note: see current operation: %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>} ins(%1, %cst : tensor<512x512xf32>, f32) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
+ sequence.mlir:14:13: note: see current operation: %2 = linalg.elementwise kind=#linalg.elementwise_kind<max_signed> ins(%1, %cst : tensor<512x512xf32>, f32) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
```

Note that `%arg2` is associated with both elementwise payload operations. Any handle is associated with a list of entities. Individual transformations may or may not care about the order of elements in that list.
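
Because a handle can refer to several payload operations at once, a transformation that must target just one of them first splits the handle; a minimal sketch of the idiom (the same code appears verbatim later in this chapter):

```mlir
// %arg2 references both elementwise ops, in payload order; split_handle
// yields one single-op handle per associated payload operation.
%add, %max = transform.split_handle %arg2
    : (!transform.op<"linalg.elementwise">)
    -> (!transform.any_op, !transform.any_op)
```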
@@ -121,7 +121,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(
%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.matmul">,
- %arg2: !transform.op<"linalg.elemwise_binary">) {
+ %arg2: !transform.op<"linalg.elementwise">) {
// The actual tiling transformation takes tile sizes as attributes.
%loop, %tiled = transform.structured.tile_using_forall %arg1
tile_sizes [4, 32]
@@ -163,10 +163,10 @@ func.func @fc_relu(%arg0: tensor<512x512xf32>,
: tensor<4x32xf32> into tensor<512x512xf32>
}
}
- %1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+ %1 = linalg.elementwise kind=#linalg.elementwise_kind<add>
ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>)
outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
- %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>}
+ %2 = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
ins(%1, %cst : tensor<512x512xf32>, f32)
outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
return %2 : tensor<512x512xf32>
@@ -185,7 +185,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(
%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.matmul">,
- %arg2: !transform.op<"linalg.elemwise_binary">) {
+ %arg2: !transform.op<"linalg.elementwise">) {
// The actual tiling transformation takes tile sizes as attributes.
%loop, %tiled = transform.structured.tile_using_forall %arg1 tile_sizes [4, 32]
: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
@@ -219,7 +219,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(
%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.matmul">,
- %arg2: !transform.op<"linalg.elemwise_binary">) {
+ %arg2: !transform.op<"linalg.elementwise">) {
// We can cast one type to another as long as operations are compatible
// with both types. This creates "aliasing" handles.
%casted = transform.cast %arg1 : !transform.op<"linalg.matmul">
@@ -248,7 +248,7 @@ sequence.mlir:28:3: error: op uses a handle invalidated by a previously executed
transform.debug.emit_remark_at %matmul, "elemwise_binaries" : !transform.op<"linalg.matmul">
^
sequence.mlir:21:29: note: handle to invalidated ops
- ^bb0(%root: !transform.any_op, %matmul: !transform.op<"linalg.matmul">, %elemwise: !transform.op<"linalg.elemwise_binary">):
+ ^bb0(%root: !transform.any_op, %matmul: !transform.op<"linalg.matmul">, %elemwise: !transform.op<"linalg.elementwise">):
^
sequence.mlir:27:19: note: invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them
%loop, %tiled = transform.structured.tile_using_forall %mm tile_sizes [4, 32]
@@ -263,12 +263,12 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(
%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.matmul">,
- %arg2: !transform.op<"linalg.elemwise_binary">) {
+ %arg2: !transform.op<"linalg.elementwise">) {
// Since the %arg2 handle is associated with both elementwise operations,
// we need to split it into two handles so we can target only the second
// elementwise operation.
%add, %max = transform.split_handle %arg2
- : (!transform.op<"linalg.elemwise_binary">)
+ : (!transform.op<"linalg.elementwise">)
-> (!transform.any_op, !transform.any_op)

// The actual tiling transformation takes tile sizes as attributes. It
@@ -308,12 +308,12 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(
%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.matmul">,
- %arg2: !transform.op<"linalg.elemwise_binary">) {
+ %arg2: !transform.op<"linalg.elementwise">) {
// Since the %arg2 handle is associated with both elementwise operations,
// we need to split it into two handles so we can target only the second
// elementwise operation.
%add, %max = transform.split_handle %arg2
- : (!transform.op<"linalg.elemwise_binary">)
+ : (!transform.op<"linalg.elementwise">)
-> (!transform.any_op, !transform.any_op)

// The actual tiling transformation takes tile sizes as attributes. It
@@ -384,7 +384,7 @@ test/Examples/transform/Ch1/invalidation-2.mlir:106:18: note: invalidated by this
%func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
^
test/Examples/transform/Ch1/invalidation-2.mlir:24:13: note: ancestor payload op
- %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
+ %biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
^
test/Examples/transform/Ch1/invalidation-2.mlir:24:13: note: nested payload op
%matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
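
Beyond the `transform.cast` aliasing shown above, a common way to work around handle invalidation is to re-match payload operations from the root after the consuming transform has run; a hedged sketch (assumes `%arg0` is the root handle, as in the tutorial):

```mlir
// Re-acquire fresh handles rather than reusing ones invalidated by tiling.
%fresh = transform.structured.match ops{["linalg.elementwise"]} in %arg0
    : (!transform.any_op) -> !transform.any_op
```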
mlir/docs/Tutorials/transform/Ch2.md: 4 changes (2 additions, 2 deletions)
@@ -290,12 +290,12 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(
%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.matmul">,
- %arg2: !transform.op<"linalg.elemwise_binary">) {
+ %arg2: !transform.op<"linalg.elementwise">) {
// Since the %arg2 handle is associated with both elementwise operations,
// we need to split it into two handles so we can target only the second
// elementwise operation.
%add, %max = transform.split_handle %arg2
- : (!transform.op<"linalg.elemwise_binary">)
+ : (!transform.op<"linalg.elementwise">)
-> (!transform.any_op, !transform.any_op)

// The actual tiling transformation takes tile sizes as attributes. It
mlir/docs/Tutorials/transform/Ch4.md: 18 changes (9 additions, 9 deletions)
@@ -42,13 +42,13 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>

// Elementwise addition.
- %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
+ %biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

// Elementwise max with 0 (ReLU).
%c0f = arith.constant 0.0 : f32
- %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
+ %relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
ins(%biased, %c0f : tensor<512x512xf32>, f32)
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
func.return %relued : tensor<512x512xf32>
@@ -59,7 +59,7 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,

In Chapter 1, we were calling the test transform interpreter pass with
additional arguments, `bind-first-extra-to-ops=linalg.matmul
- bind-second-extra-to-ops=linalg.elemwise_binary`, to provide initial
+ bind-second-extra-to-ops=linalg.elementwise`, to provide initial
associations for operation handles. Instead, we can use match operations to
discover relevant operations in the payload IR. Match operations can be combined
with “regular” transform operations using, e.g., the
@@ -97,7 +97,7 @@ module @transforms attributes { transform.with_named_sequence } {
// rewriter sequence on success.
transform.named_sequence @match_elemwise(
%entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
- transform.match.operation_name %entry ["linalg.elemwise_binary"]
+ transform.match.operation_name %entry ["linalg.elementwise"]
: !transform.any_op
transform.yield %entry : !transform.any_op
}
@@ -127,7 +127,7 @@ This script can be executed using the non-test interpreter pass running on the
This script can be executed using the non-test interpreter pass running on the
root operation of the translation unit without additional flags: `mlir-opt
--transform-interpreter`. It will emit corresponding remarks at
- `linalg.elemwise_binary` and `linalg.matmul` operations. In debug builds, the
+ `linalg.elementwise` and `linalg.matmul` operations. In debug builds, the
infrastructure provides a convenient method to understand the matching process
by passing `-debug-only=transform-matcher` to `mlir-opt` or a derived tool. It
will print the silenceable failure messages produced by the match operations
@@ -169,7 +169,7 @@ transform.named_sequence @match_matmul_elemwise(
%last: !transform.any_op {transform.readonly})
-> (!transform.any_op, !transform.any_op, !transform.any_op) {
// The last operation must be an elementwise binary.
- transform.match.operation_name %last ["linalg.elemwise_binary"]
+ transform.match.operation_name %last ["linalg.elementwise"]
: !transform.any_op
// Its first operand must be defined by another operation, to which we
// will get a handle here. We are guaranteed that the first operand exists
@@ -179,7 +179,7 @@ transform.named_sequence @match_matmul_elemwise(
%middle = transform.get_producer_of_operand %last[0]
: (!transform.any_op) -> !transform.any_op
// The defining operation must itself be an elementwise binary.
- transform.match.operation_name %middle ["linalg.elemwise_binary"]
+ transform.match.operation_name %middle ["linalg.elementwise"]
: !transform.any_op
// And the first operand of that operation must be defined by yet another
// operation.
@@ -399,7 +399,7 @@ transform.named_sequence @match_matmul_elemwise(
-> (!transform.any_op, !transform.any_op, !transform.any_op,
!transform.param<i32>) {
// The last operation must be an elementwise binary.
- transform.match.operation_name %last ["linalg.elemwise_binary"]
+ transform.match.operation_name %last ["linalg.elementwise"]
: !transform.any_op

// One of its operands must be defined by another operation, to which we
@@ -413,7 +413,7 @@ transform.named_sequence @match_matmul_elemwise(
%def = transform.get_defining_op %operand
: (!transform.any_value) -> !transform.any_op
// The defining operation must itself be an elementwise binary.
- transform.match.operation_name %def ["linalg.elemwise_binary"]
+ transform.match.operation_name %def ["linalg.elementwise"]
: !transform.any_op
transform.yield %def : !transform.any_op
}
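
The matcher sequences in the Ch4 hunks above are driven by `transform.foreach_match`, which pairs each matcher with an action sequence; a minimal sketch of the wiring (the `@print_*` action sequences are assumed to be defined alongside the matchers, as in the tutorial):

```mlir
transform.named_sequence @__transform_main(%root: !transform.any_op) {
  // For each payload op, run the first matcher that succeeds, then invoke
  // the paired action on the values the matcher yielded.
  %updated = transform.foreach_match in %root
      @match_elemwise -> @print_elemwise,
      @match_matmul -> @print_matmul
    : (!transform.any_op) -> !transform.any_op
  transform.yield
}
```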
mlir/docs/Tutorials/transform/ChH.md: 2 changes (1 addition, 1 deletion)
@@ -290,7 +290,7 @@ scf.forall (%co) in (2) {
scf.forall (%n, %y, %xo) in (5, 80, 20) {
tensor.extract_slice
// Implicit dimensions [ni=0:1, y=0:1, xi=0:5, ci=0:64]
- %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } // ...
+ %relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed> // ...
scf.forall.in_parallel {
tensor.parallel_insert_slice // ...
}
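
The `// ...` in the ChH hunk elides the operands; spelled out in the style of the tutorial's running example, the tiled ReLU would read roughly as below (the slice value names and shapes are hypothetical):

```mlir
%c0f = arith.constant 0.0 : f32
// Elementwise max with 0 (ReLU) on the extracted slice.
%relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
    ins(%biased_slice, %c0f : tensor<1x1x5x64xf32>, f32)
    outs(%out_slice : tensor<1x1x5x64xf32>) -> tensor<1x1x5x64xf32>
```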
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml: 114 changes (0 additions, 114 deletions)
@@ -44,56 +44,6 @@ structured_op: !LinalgStructuredOpConfig
        - !ScalarExpression
          scalar_arg: I
--- !LinalgOpConfig
- metadata: !LinalgOpMetadata
-   name: elemwise_unary
-   cpp_class_name: ElemwiseUnaryOp
-   doc: |-
-     Applies the unary function fun elementwise.
-
-     Numeric casting is performed on the input operand, promoting it to the same
-     data type as the accumulator/output.
- structured_op: !LinalgStructuredOpConfig
-   args:
-   - !LinalgOperandDefConfig
-     name: I
-     kind: input_tensor
-     type_var: T1
-     shape_map: affine_map<() -> ()>
-   - !LinalgOperandDefConfig
-     name: O
-     kind: output_tensor
-     type_var: U
-     shape_map: affine_map<() -> ()>
-   - !LinalgOperandDefConfig
-     name: fun
-     kind: unary_fn_attr
-     default_fn: exp
-   - !LinalgOperandDefConfig
-     name: cast
-     kind: type_fn_attr
-     default_fn: cast_signed
-   indexing_maps: !LinalgIndexingMapsConfig
-     static_indexing_maps:
-     - affine_map<() -> ()>
-     - affine_map<() -> ()>
-   iterator_types: []
-   assignments:
-   - !ScalarAssign
-     arg: O
-     value: !ScalarExpression
-       scalar_fn:
-         kind: unary
-         attr_name: fun
-         operands:
-         - !ScalarExpression
-           scalar_fn:
-             kind: type
-             attr_name: cast
-             type_var: U
-             operands:
-             - !ScalarExpression
-               scalar_arg: I
- --- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: exp
  cpp_class_name: ExpOp
@@ -549,70 +499,6 @@ structured_op: !LinalgStructuredOpConfig
        - !ScalarExpression
          scalar_arg: I
--- !LinalgOpConfig
- metadata: !LinalgOpMetadata
-   name: elemwise_binary
-   cpp_class_name: ElemwiseBinaryOp
-   doc: |-
-     Applies the binary function fun elementwise.
-
-     Numeric casting is performed on the input operand, promoting it to the same
-     data type as the accumulator/output.
- structured_op: !LinalgStructuredOpConfig
-   args:
-   - !LinalgOperandDefConfig
-     name: lhs
-     kind: input_tensor
-     type_var: T1
-     shape_map: affine_map<() -> ()>
-   - !LinalgOperandDefConfig
-     name: rhs
-     kind: input_tensor
-     type_var: T2
-     shape_map: affine_map<() -> ()>
-   - !LinalgOperandDefConfig
-     name: O
-     kind: output_tensor
-     type_var: U
-     shape_map: affine_map<() -> ()>
-   - !LinalgOperandDefConfig
-     name: fun
-     kind: binary_fn_attr
-     default_fn: add
-   - !LinalgOperandDefConfig
-     name: cast
-     kind: type_fn_attr
-     default_fn: cast_signed
-   indexing_maps: !LinalgIndexingMapsConfig
-     static_indexing_maps:
-     - affine_map<() -> ()>
-     - affine_map<() -> ()>
-     - affine_map<() -> ()>
-   iterator_types: []
-   assignments:
-   - !ScalarAssign
-     arg: O
-     value: !ScalarExpression
-       scalar_fn:
-         kind: binary
-         attr_name: fun
-         operands:
-         - !ScalarExpression
-           scalar_fn:
-             kind: type
-             attr_name: cast
-             type_var: U
-             operands:
-             - !ScalarExpression
-               scalar_arg: lhs
-         - !ScalarExpression
-           scalar_fn:
-             kind: type
-             attr_name: cast
-             type_var: U
-             operands:
-             - !ScalarExpression
-               scalar_arg: rhs
- --- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: add
  cpp_class_name: AddOp
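
For readers tracking what the deleted YAML specified: the `elemwise_binary` assignment computed `O = fun(cast(lhs), cast(rhs))`, i.e. both operands were cast to the output element type `U` before the binary function was applied. A sketch of the equivalent `linalg.generic`, assuming `f16` inputs promoted to an `f32` output with `fun = add` and `cast = cast_signed` (rank-1 for illustration; the removed op was rank-polymorphic):

```mlir
#map = affine_map<(d0) -> (d0)>
%res = linalg.generic
    {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
    ins(%lhs, %rhs : tensor<8xf16>, tensor<8xf16>)
    outs(%init : tensor<8xf32>) {
^bb0(%a: f16, %b: f16, %out: f32):
  %ac = arith.extf %a : f16 to f32   // cast_signed: promote lhs to U
  %bc = arith.extf %b : f16 to f32   // cast_signed: promote rhs to U
  %sum = arith.addf %ac, %bc : f32   // the binary function (add)
  linalg.yield %sum : f32
} -> tensor<8xf32>
```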