Skip to content

Commit 627733b

Browse files
committed
[mlir][vector] Extend vector distribution to all elementwise and contract
Uses elementwise interface to generalize canonicalization pattern and add a new pattern for vector.contract case. Differential Revision: https://reviews.llvm.org/D104343
1 parent 0c400e8 commit 627733b

File tree

5 files changed

+147
-45
lines changed

5 files changed

+147
-45
lines changed

mlir/include/mlir/Dialect/Vector/VectorOps.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
105105
// a sequence of vector.reduction ops.
106106
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
107107

108+
/// Collect a set of patterns to propagate insert_map/extract_map in the ssa
109+
/// chain.
110+
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
111+
108112
/// An attribute that specifies the combining function for `vector.contract`,
109113
/// and `vector.reduction`.
110114
class CombiningKindAttr

mlir/include/mlir/Dialect/Vector/VectorTransforms.h

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -251,26 +251,6 @@ Optional<DistributeOps>
251251
distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
252252
ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
253253
const AffineMap &map);
254-
/// Canonicalize an extra element using the result of a pointwise operation.
255-
/// Transforms:
256-
/// %v = addf %a, %b : vector32xf32>
257-
/// %dv = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
258-
/// to:
259-
/// %da = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
260-
/// %db = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
261-
/// %dv = addf %da, %db : vector<1xf32>
262-
struct PointwiseExtractPattern : public OpRewritePattern<ExtractMapOp> {
263-
using FilterConstraintType = std::function<LogicalResult(ExtractMapOp op)>;
264-
PointwiseExtractPattern(
265-
MLIRContext *context, FilterConstraintType constraint =
266-
[](ExtractMapOp op) { return success(); })
267-
: OpRewritePattern<ExtractMapOp>(context), filter(constraint) {}
268-
LogicalResult matchAndRewrite(ExtractMapOp extract,
269-
PatternRewriter &rewriter) const override;
270-
271-
private:
272-
FilterConstraintType filter;
273-
};
274254

275255
/// Implements transfer op write to read forwarding and dead transfer write
276256
/// optimizations.

mlir/lib/Dialect/Vector/VectorTransforms.cpp

Lines changed: 93 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2793,25 +2793,6 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
27932793
return failure();
27942794
}
27952795

2796-
LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
2797-
ExtractMapOp extract, PatternRewriter &rewriter) const {
2798-
Operation *definedOp = extract.vector().getDefiningOp();
2799-
if (!definedOp || definedOp->getNumResults() != 1)
2800-
return failure();
2801-
// TODO: Create an interfaceOp for elementwise operations.
2802-
if (!isa<AddFOp>(definedOp))
2803-
return failure();
2804-
Location loc = extract.getLoc();
2805-
SmallVector<Value, 4> extractOperands;
2806-
for (OpOperand &operand : definedOp->getOpOperands())
2807-
extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
2808-
loc, extract.getResultType(), operand.get(), extract.ids()));
2809-
Operation *newOp = cloneOpWithOperandsAndTypes(
2810-
rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
2811-
rewriter.replaceOp(extract, newOp->getResult(0));
2812-
return success();
2813-
}
2814-
28152796
Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
28162797
OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
28172798
ArrayRef<int64_t> multiplicity, const AffineMap &map) {
@@ -2843,6 +2824,91 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
28432824
return ops;
28442825
}
28452826

2827+
/// Canonicalize an extract_map using the result of a pointwise operation.
2828+
/// Transforms:
2829+
/// %v = addf %a, %b : vector<32xf32>
2830+
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
2831+
/// to:
2832+
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
2833+
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
2834+
/// %dv = addf %da, %db : vector<1xf32>
2835+
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
2836+
using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
2837+
LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
2838+
PatternRewriter &rewriter) const override {
2839+
Operation *definedOp = extract.vector().getDefiningOp();
2840+
if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
2841+
definedOp->getNumResults() != 1)
2842+
return failure();
2843+
Location loc = extract.getLoc();
2844+
SmallVector<Value, 4> extractOperands;
2845+
for (OpOperand &operand : definedOp->getOpOperands()) {
2846+
auto vecType = operand.get().getType().template dyn_cast<VectorType>();
2847+
if (!vecType) {
2848+
extractOperands.push_back(operand.get());
2849+
continue;
2850+
}
2851+
extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
2852+
loc,
2853+
VectorType::get(extract.getResultType().getShape(),
2854+
vecType.getElementType()),
2855+
operand.get(), extract.ids()));
2856+
}
2857+
Operation *newOp = cloneOpWithOperandsAndTypes(
2858+
rewriter, loc, definedOp, extractOperands, extract.getResultType());
2859+
rewriter.replaceOp(extract, newOp->getResult(0));
2860+
return success();
2861+
}
2862+
};
2863+
2864+
/// Canonicalize an extract_map using the result of a contract operation.
2865+
/// This propagates the extract_map to the operands.
2866+
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
2867+
using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
2868+
LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
2869+
PatternRewriter &rewriter) const override {
2870+
Operation *definedOp = extract.vector().getDefiningOp();
2871+
auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
2872+
if (!contract)
2873+
return failure();
2874+
Location loc = contract.getLoc();
2875+
unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
2876+
AffineMap affineMap = contract.getIndexingMaps()[accIndex];
2877+
// Create a map of the dimensions distributed based on the acc affine map.
2878+
// Only parallel dimensions are being distributed, reduction dimensions are
2879+
// untouched.
2880+
DenseMap<int64_t, int64_t> map;
2881+
for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
2882+
map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
2883+
SmallVector<Value, 4> extractOperands;
2884+
for (auto it : llvm::enumerate(contract.getIndexingMaps())) {
2885+
// For each operand, calculate the new vector type after distribution.
2886+
Value operand = contract->getOperand(it.index());
2887+
auto vecType = operand.getType().cast<VectorType>();
2888+
SmallVector<int64_t> operandShape(vecType.getShape().begin(),
2889+
vecType.getShape().end());
2890+
for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
2891+
unsigned dim = it.value().getDimPosition(i);
2892+
auto distributedDim = map.find(dim);
2893+
// If the dimension is not in the map it means it is a reduction and
2894+
// doesn't get distributed.
2895+
if (distributedDim == map.end())
2896+
continue;
2897+
operandShape[i] = distributedDim->second;
2898+
}
2899+
VectorType newVecType =
2900+
VectorType::get(operandShape, vecType.getElementType());
2901+
extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
2902+
loc, newVecType, operand, extract.ids()));
2903+
}
2904+
Operation *newOp =
2905+
cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
2906+
extract.getResult().getType());
2907+
rewriter.replaceOp(extract, newOp->getResult(0));
2908+
return success();
2909+
}
2910+
};
2911+
28462912
/// Converts TransferRead op used by ExtractMap op into a smaller dimension
28472913
/// TransferRead.
28482914
/// Example:
@@ -4100,8 +4166,7 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
41004166
// TODO: Add this as DRR pattern.
41014167
void mlir::vector::populateVectorToVectorTransformationPatterns(
41024168
RewritePatternSet &patterns) {
4103-
patterns.add<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp,
4104-
TransferReadExtractPattern, TransferWriteInsertPattern>(
4169+
patterns.add<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp>(
41054170
patterns.getContext());
41064171
}
41074172

@@ -4112,6 +4177,13 @@ void mlir::vector::populateSplitVectorTransferPatterns(
41124177
ignoreFilter);
41134178
}
41144179

4180+
void mlir::vector::populatePropagateVectorDistributionPatterns(
4181+
RewritePatternSet &patterns) {
4182+
patterns.add<PointwiseExtractPattern, ContractExtractPattern,
4183+
TransferReadExtractPattern, TransferWriteInsertPattern>(
4184+
patterns.getContext());
4185+
}
4186+
41154187
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
41164188
RewritePatternSet &patterns) {
41174189
patterns.add<CastAwayExtractStridedSliceLeadingOneDim,

mlir/test/Dialect/Vector/vector-distribution.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,4 -split-input-file | FileCheck %s --check-prefix=CHECK2D
23

34
// CHECK-LABEL: func @distribute_vector_add
45
// CHECK-SAME: (%[[ID:.*]]: index
@@ -15,6 +16,24 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
1516

1617
// -----
1718

19+
// CHECK-LABEL: func @distribute_vector_add_exp
20+
// CHECK-SAME: (%[[ID:.*]]: index
21+
// CHECK-NEXT: %[[EXPV:.*]] = math.exp %{{.*}} : vector<32xf32>
22+
// CHECK-NEXT: %[[ADDV:.*]] = addf %[[EXPV]], %{{.*}} : vector<32xf32>
23+
// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
24+
// CHECK-NEXT: %[[EXC:.*]] = math.exp %[[EXA]] : vector<1xf32>
25+
// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
26+
// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXC]], %[[EXB]] : vector<1xf32>
27+
// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID]]] : vector<1xf32> into vector<32xf32>
28+
// CHECK-NEXT: return %[[INS]] : vector<32xf32>
29+
func @distribute_vector_add_exp(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
30+
%C = math.exp %A : vector<32xf32>
31+
%0 = addf %C, %B : vector<32xf32>
32+
return %0: vector<32xf32>
33+
}
34+
35+
// -----
36+
1837
// CHECK-LABEL: func @vector_add_read_write
1938
// CHECK-SAME: (%[[ID:.*]]: index
2039
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
@@ -154,3 +173,32 @@ func @vector_add_transfer_permutation(%id0 : index, %id1 : index, %A: memref<?x?
154173
vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
155174
return
156175
}
176+
177+
// -----
178+
179+
// CHECK2D-LABEL: vector_add_contract
180+
// CHECK2D: %[[A:.+]] = vector.transfer_read %arg2[%0, %c0], %cst : memref<?x?xf32>, vector<2x4xf32>
181+
// CHECK2D: %[[B:.+]] = vector.transfer_read %arg3[%2, %c0], %cst : memref<?x?xf32>, vector<16x4xf32>
182+
// CHECK2D: %[[C:.+]] = vector.transfer_read %arg4[%4, %5], %cst : memref<?x?xf32>, vector<2x16xf32>
183+
// CHECK2D: %[[E:.+]] = vector.transfer_read %arg5[%7, %8], %cst : memref<?x?xf32>, vector<2x16xf32>
184+
// CHECK2D: %[[D:.+]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]] : vector<2x4xf32>, vector<16x4xf32> into vector<2x16xf32>
185+
// CHECK2D: %[[R:.+]] = addf %[[D]], %[[E]] : vector<2x16xf32>
186+
// CHECK2D: vector.transfer_write %[[R]], {{.*}} : vector<2x16xf32>, memref<?x?xf32>
187+
func @vector_add_contract(%id0 : index, %id1 : index, %A: memref<?x?xf32>,
188+
%B: memref<?x?xf32>, %C: memref<?x?xf32>, %D: memref<?x?xf32>) {
189+
%c0 = constant 0 : index
190+
%cf0 = constant 0.0 : f32
191+
%a = vector.transfer_read %A[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
192+
%b = vector.transfer_read %B[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
193+
%c = vector.transfer_read %C[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
194+
%d = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
195+
affine_map<(d0, d1, d2) -> (d1, d2)>,
196+
affine_map<(d0, d1, d2) -> (d0, d1)>],
197+
iterator_types = ["parallel", "parallel", "reduction"],
198+
kind = #vector.kind<add>}
199+
%a, %b, %c : vector<64x4xf32>, vector<64x4xf32> into vector<64x64xf32>
200+
%e = vector.transfer_read %D[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
201+
%r = addf %d, %e : vector<64x64xf32>
202+
vector.transfer_write %r, %C[%c0, %c0] : vector<64x64xf32>, memref<?x?xf32>
203+
return
204+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,7 @@ struct TestVectorDistributePatterns
275275
}
276276
}
277277
});
278-
patterns.add<PointwiseExtractPattern>(ctx);
279-
populateVectorToVectorTransformationPatterns(patterns);
278+
populatePropagateVectorDistributionPatterns(patterns);
280279
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
281280
}
282281
};
@@ -339,8 +338,7 @@ struct TestVectorToLoopPatterns
339338
}
340339
return mlir::WalkResult::interrupt();
341340
});
342-
patterns.add<PointwiseExtractPattern>(ctx);
343-
populateVectorToVectorTransformationPatterns(patterns);
341+
populatePropagateVectorDistributionPatterns(patterns);
344342
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
345343
}
346344
};

0 commit comments

Comments
 (0)