diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index aef156c5f1d05..fc99a8e30ef1f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -798,7 +798,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
This operation decomposes all the scalar elements from a vector. The
decomposed scalar elements are returned in row-major order. The number of
scalar results must match the number of elements in the input vector type.
- All the result elements have the same result type, which must match the
+ All the result elements have the same type, which must match the
element type of the input vector. Scalable vectors are not supported.
Examples:
@@ -813,7 +813,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
// %0#0 = %v1[0]
// %0#1 = %v1[1]
- // Decompose a 2-D.
+ // Decompose a 2-D vector.
%0:6 = vector.to_elements %v2 : vector<2x3xf32>
// %0#0 = %v2[0, 0]
// %0#1 = %v2[0, 1]
@@ -835,6 +835,13 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
let arguments = (ins AnyVectorOfAnyRank:$source);
let results = (outs Variadic<AnyType>:$elements);
+
+
+ let builders = [
+ // Build method that infers the result types from `elements`.
+ OpBuilder<(ins "Value":$elements)>,
+ ];
+
let assemblyFormat = "$source attr-dict `:` type($source)";
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6f0ac6bb58282..cd0516c80377b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2417,6 +2417,15 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
return foldToElementsFromElements(*this, results);
}
+void vector::ToElementsOp::build(OpBuilder &builder, OperationState &result,
+ Value elements) {
+ auto vectorType = cast<VectorType>(elements.getType());
+ Type elementType = vectorType.getElementType();
+ int64_t nbElements = vectorType.getNumElements();
+ SmallVector<Type> scalarTypes(nbElements, elementType);
+ build(builder, result, scalarTypes, elements);
+}
+
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//