@@ -2793,25 +2793,6 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
2793
2793
return failure ();
2794
2794
}
2795
2795
2796
- LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite (
2797
- ExtractMapOp extract, PatternRewriter &rewriter) const {
2798
- Operation *definedOp = extract.vector ().getDefiningOp ();
2799
- if (!definedOp || definedOp->getNumResults () != 1 )
2800
- return failure ();
2801
- // TODO: Create an interfaceOp for elementwise operations.
2802
- if (!isa<AddFOp>(definedOp))
2803
- return failure ();
2804
- Location loc = extract.getLoc ();
2805
- SmallVector<Value, 4 > extractOperands;
2806
- for (OpOperand &operand : definedOp->getOpOperands ())
2807
- extractOperands.push_back (rewriter.create <vector::ExtractMapOp>(
2808
- loc, extract.getResultType (), operand.get (), extract.ids ()));
2809
- Operation *newOp = cloneOpWithOperandsAndTypes (
2810
- rewriter, loc, definedOp, extractOperands, extract.getResult ().getType ());
2811
- rewriter.replaceOp (extract, newOp->getResult (0 ));
2812
- return success ();
2813
- }
2814
-
2815
2796
Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp (
2816
2797
OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
2817
2798
ArrayRef<int64_t > multiplicity, const AffineMap &map) {
@@ -2843,6 +2824,91 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
2843
2824
return ops;
2844
2825
}
2845
2826
2827
+ // / Canonicalize an extract_map using the result of a pointwise operation.
2828
+ // / Transforms:
2829
+ // / %v = addf %a, %b : vector32xf32>
2830
+ // / %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
2831
+ // / to:
2832
+ // / %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
2833
+ // / %db = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
2834
+ // / %dv = addf %da, %db : vector<1xf32>
2835
+ struct PointwiseExtractPattern : public OpRewritePattern <vector::ExtractMapOp> {
2836
+ using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
2837
+ LogicalResult matchAndRewrite (vector::ExtractMapOp extract,
2838
+ PatternRewriter &rewriter) const override {
2839
+ Operation *definedOp = extract.vector ().getDefiningOp ();
2840
+ if (!definedOp || !OpTrait::hasElementwiseMappableTraits (definedOp) ||
2841
+ definedOp->getNumResults () != 1 )
2842
+ return failure ();
2843
+ Location loc = extract.getLoc ();
2844
+ SmallVector<Value, 4 > extractOperands;
2845
+ for (OpOperand &operand : definedOp->getOpOperands ()) {
2846
+ auto vecType = operand.get ().getType ().template dyn_cast <VectorType>();
2847
+ if (!vecType) {
2848
+ extractOperands.push_back (operand.get ());
2849
+ continue ;
2850
+ }
2851
+ extractOperands.push_back (rewriter.create <vector::ExtractMapOp>(
2852
+ loc,
2853
+ VectorType::get (extract.getResultType ().getShape (),
2854
+ vecType.getElementType ()),
2855
+ operand.get (), extract.ids ()));
2856
+ }
2857
+ Operation *newOp = cloneOpWithOperandsAndTypes (
2858
+ rewriter, loc, definedOp, extractOperands, extract.getResultType ());
2859
+ rewriter.replaceOp (extract, newOp->getResult (0 ));
2860
+ return success ();
2861
+ }
2862
+ };
2863
+
2864
+ // / Canonicalize an extract_map using the result of a contract operation.
2865
+ // / This propagate the extract_map to operands.
2866
+ struct ContractExtractPattern : public OpRewritePattern <vector::ExtractMapOp> {
2867
+ using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
2868
+ LogicalResult matchAndRewrite (vector::ExtractMapOp extract,
2869
+ PatternRewriter &rewriter) const override {
2870
+ Operation *definedOp = extract.vector ().getDefiningOp ();
2871
+ auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
2872
+ if (!contract)
2873
+ return failure ();
2874
+ Location loc = contract.getLoc ();
2875
+ unsigned accIndex = vector::ContractionOp::getAccOperandIndex ();
2876
+ AffineMap affineMap = contract.getIndexingMaps ()[accIndex];
2877
+ // Create a map of the dimensions distributed based on the acc affine map.
2878
+ // Only parallel dimensions are being distributed, reduction dimensions are
2879
+ // untouched.
2880
+ DenseMap<int64_t , int64_t > map;
2881
+ for (unsigned i : llvm::seq (unsigned (0 ), affineMap.getNumResults ()))
2882
+ map[affineMap.getDimPosition (i)] = extract.getResultType ().getDimSize (i);
2883
+ SmallVector<Value, 4 > extractOperands;
2884
+ for (auto it : llvm::enumerate (contract.getIndexingMaps ())) {
2885
+ // For each operands calculate the new vector type after distribution.
2886
+ Value operand = contract->getOperand (it.index ());
2887
+ auto vecType = operand.getType ().cast <VectorType>();
2888
+ SmallVector<int64_t > operandShape (vecType.getShape ().begin (),
2889
+ vecType.getShape ().end ());
2890
+ for (unsigned i : llvm::seq (unsigned (0 ), it.value ().getNumResults ())) {
2891
+ unsigned dim = it.value ().getDimPosition (i);
2892
+ auto distributedDim = map.find (dim);
2893
+ // If the dimension is not in the map it means it is a reduction and
2894
+ // doesn't get distributed.
2895
+ if (distributedDim == map.end ())
2896
+ continue ;
2897
+ operandShape[i] = distributedDim->second ;
2898
+ }
2899
+ VectorType newVecType =
2900
+ VectorType::get (operandShape, vecType.getElementType ());
2901
+ extractOperands.push_back (rewriter.create <vector::ExtractMapOp>(
2902
+ loc, newVecType, operand, extract.ids ()));
2903
+ }
2904
+ Operation *newOp =
2905
+ cloneOpWithOperandsAndTypes (rewriter, loc, definedOp, extractOperands,
2906
+ extract.getResult ().getType ());
2907
+ rewriter.replaceOp (extract, newOp->getResult (0 ));
2908
+ return success ();
2909
+ }
2910
+ };
2911
+
2846
2912
// / Converts TransferRead op used by ExtractMap op into a smaller dimension
2847
2913
// / TransferRead.
2848
2914
// / Example:
@@ -4100,8 +4166,7 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
4100
4166
// TODO: Add this as DRR pattern.
4101
4167
void mlir::vector::populateVectorToVectorTransformationPatterns (
4102
4168
RewritePatternSet &patterns) {
4103
- patterns.add <ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp,
4104
- TransferReadExtractPattern, TransferWriteInsertPattern>(
4169
+ patterns.add <ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp>(
4105
4170
patterns.getContext ());
4106
4171
}
4107
4172
@@ -4112,6 +4177,13 @@ void mlir::vector::populateSplitVectorTransferPatterns(
4112
4177
ignoreFilter);
4113
4178
}
4114
4179
4180
+ void mlir::vector::populatePropagateVectorDistributionPatterns (
4181
+ RewritePatternSet &patterns) {
4182
+ patterns.add <PointwiseExtractPattern, ContractExtractPattern,
4183
+ TransferReadExtractPattern, TransferWriteInsertPattern>(
4184
+ patterns.getContext ());
4185
+ }
4186
+
4115
4187
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns (
4116
4188
RewritePatternSet &patterns) {
4117
4189
patterns.add <CastAwayExtractStridedSliceLeadingOneDim,
0 commit comments