Add aten.pool_max3d support to torch-to-linalg (#2735)
Added verification logic to abstract_interpreter_lib_gen.py.
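
As an illustration of the kind of verification added there, here is a hedged Python sketch of a max_pool3d output-shape computation (helper names are hypothetical and this is not the exact code added to abstract_interpreter_lib_gen.py, just the standard PyTorch output-size rule):

```
from typing import List

def _pool_out_size(in_size: int, k: int, stride: int, pad: int, dil: int,
                   ceil_mode: bool) -> int:
    # Standard PyTorch pooling output-size formula for one spatial dimension.
    numerator = in_size + 2 * pad - dil * (k - 1) - 1
    if ceil_mode:
        out = -(-numerator // stride) + 1  # ceiling division
        # PyTorch drops a trailing window that would start entirely in the padding.
        if (out - 1) * stride >= in_size + pad:
            out -= 1
    else:
        out = numerator // stride + 1
    return out

def max_pool3d_output_shape(self: List[int], kernel_size: List[int],
                            stride: List[int], padding: List[int],
                            dilation: List[int], ceil_mode: bool) -> List[int]:
    # An empty stride list means "stride = kernel_size" for aten.max_pool3d.
    stride = stride if stride else kernel_size
    out_dhw = [
        _pool_out_size(s, k, st, p, dl, ceil_mode)
        for s, k, st, p, dl in zip(self[-3:], kernel_size, stride, padding, dilation)
    ]
    return self[:-3] + out_dhw
```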

Also added some unit tests.
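
For reference, a minimal sketch of what such an e2e test might look like, assuming torch-mlir's usual `torch_mlir_e2e_test` harness (the module name and parameters here are illustrative, not the exact tests added in this change):

```
import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export

class MaxPool3dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mp3d = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], stride=[2, 2, 2],
                                       padding=[1, 1, 1], dilation=[1, 1, 1])

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return self.mp3d(x)

@register_test_case(module_factory=lambda: MaxPool3dModule())
def MaxPool3dModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 1, 8, 8, 8, low=-1.0))
```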

Initially, I thought we could use `linalg::pooling_ndhwc_max` to implement this op. However, that named op expects NDHWC layout: on a 5-dimensional tensor it pools over dimensions (2, 3, 4), whereas for the NCDHW layout of `aten.max_pool3d` we want pooling over dimensions (3, 4, 5).

To achieve this, we lower the op with a hand-written `linalg.generic` rather than a named pooling op.

The pooling code in the `linalg` dialect then looks roughly like this:

```
func.func @max_pooling_ncdhw(%I: memref<?x?x?x?x?xf32>, %K: memref<3xindex>, %O: memref<?x?x?x?x?xf32>,
                             %strides: memref<3xindex>, %dilations: memref<3xindex>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c3 = arith.constant 3 : index
    %c4 = arith.constant 4 : index
    %N = memref.dim %I, %c0 : memref<?x?x?x?x?xf32>
    %C = memref.dim %I, %c1 : memref<?x?x?x?x?xf32>
    %D = memref.dim %I, %c2 : memref<?x?x?x?x?xf32>
    %H = memref.dim %I, %c3 : memref<?x?x?x?x?xf32>
    %W = memref.dim %I, %c4 : memref<?x?x?x?x?xf32>

    %kernel_d = memref.load %K[%c0] : memref<3xindex>
    %kernel_h = memref.load %K[%c1] : memref<3xindex>
    %kernel_w = memref.load %K[%c2] : memref<3xindex>
    %stride_d = memref.load %strides[%c0] : memref<3xindex>
    %stride_h = memref.load %strides[%c1] : memref<3xindex>
    %stride_w = memref.load %strides[%c2] : memref<3xindex>
    %dilation_d = memref.load %dilations[%c0] : memref<3xindex>
    %dilation_h = memref.load %dilations[%c1] : memref<3xindex>
    %dilation_w = memref.load %dilations[%c2] : memref<3xindex>

    // NOTE: an affine_map cannot reference SSA values such as %stride_d; in the
    // actual lowering the strides and dilations are compile-time constants that
    // are folded into the indexing maps. They are written symbolically here only
    // to show the index arithmetic.
    linalg.generic {
        indexing_maps = [
            // Input: output index * stride + kernel index * dilation per spatial dim
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d * %stride_d + kd * %dilation_d, h * %stride_h + kh * %dilation_h, w * %stride_w + kw * %dilation_w)>,
            // Kernel: only its shape drives the reduction bounds; its values are never read
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (kd, kh, kw)>,
            // Output
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d, h, w)>
        ],
        iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"],
        doc = "3D max pooling, NCDHW layout, with strides, dilations, and kernel size"
    } ins(%I, %K : memref<?x?x?x?x?xf32>, memref<3xindex>) outs(%O : memref<?x?x?x?x?xf32>) {
        ^bb0(%input_elem: f32, %kernel_elem: index, %output_elem: f32):
            %max_val = arith.maximumf %input_elem, %output_elem : f32
            linalg.yield %max_val : f32
    }
    return
}
```
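
As a sanity check on those indexing maps, here is a small NumPy sketch (an illustration, not part of the change) of the computation the `linalg.generic` performs: `output[n, c, d, h, w]` is the maximum of `input[n, c, d*stride_d + kd*dilation_d, h*stride_h + kh*dilation_h, w*stride_w + kw*dilation_w]` over all `(kd, kh, kw)` in the kernel.

```
import numpy as np

def max_pool3d_ncdhw_reference(x, kernel, stride, dilation):
    """Reference semantics of the linalg.generic above (input assumed already padded)."""
    n, c, d, h, w = x.shape
    kd, kh, kw = kernel
    sd, sh, sw = stride
    dd, dh, dw = dilation
    # Output sizes for a valid pooling window (no extra padding here).
    od = (d - (kd - 1) * dd - 1) // sd + 1
    oh = (h - (kh - 1) * dh - 1) // sh + 1
    ow = (w - (kw - 1) * dw - 1) // sw + 1
    out = np.full((n, c, od, oh, ow), -np.inf, dtype=x.dtype)
    for zd in range(kd):
        for zh in range(kh):
            for zw in range(kw):
                # input index = output index * stride + kernel index * dilation
                window = x[:, :,
                           zd * dd : zd * dd + od * sd : sd,
                           zh * dh : zh * dh + oh * sh : sh,
                           zw * dw : zw * dw + ow * sw : sw]
                out = np.maximum(out, window)
    return out

# Example: 3x3x3 window, stride 2, dilation 1 over a 1x1x8x8x8 input.
x = np.random.rand(1, 1, 8, 8, 8).astype(np.float32)
y = max_pool3d_ncdhw_reference(x, (3, 3, 3), (2, 2, 2), (1, 1, 1))
assert y.shape == (1, 1, 3, 3, 3)
```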

The lowering was implemented based on the named op's source definition, with the adjustments mentioned above:

https://github.com/llvm/llvm-project/blob/4ca1b5e094280ef1af40412e3cfcb62dc3cf15bc/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml#L5647

The related issue is tracked here:

nod-ai/SHARK-ModelDev#324
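
For anyone trying the new lowering end to end, here is a hypothetical usage sketch (not part of this change); it assumes the TorchScript-based `torch_mlir.compile` entry point and the "linalg-on-tensors" output type are available in your build.

```
import torch
import torch_mlir

pool = torch.nn.MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1)
example_input = torch.randn(1, 1, 8, 8, 8)

# With this patch, aten.max_pool3d legalizes to a linalg.generic max pooling
# instead of failing conversion.
module = torch_mlir.compile(pool, example_input, output_type="linalg-on-tensors")
print(module)
```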
wu-s-john authored Jan 19, 2024
1 parent faa4517 commit 704cfda
Showing 6 changed files with 937 additions and 63 deletions.
267 changes: 206 additions & 61 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -72,36 +72,15 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
return success();
}

// Creates a pooling operation based on the type specified by `OpTy` and
// arguments passed.
template <typename OpTy>
static LogicalResult createPoolingOp(
Operation *op, ConversionPatternRewriter &rewriter, Value self,
bool supportNonFPInput, bool ceilMode, int64_t dimensionality,
SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
Location loc = op->getLoc();
static Value computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter,
Value self, int64_t dimensionality, bool ceilMode,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts,
SmallVectorImpl<int64_t> &dilationInts,
SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<Value> &outTensorShape, Value initValue) {
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>() && !supportNonFPInput)
return op->emitError("unimplemented: non-floating point type");

SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
lowPaddingIncludingNC.append(paddingInts);
SmallVector<int64_t> highPaddingIncludingNC = lowPaddingIncludingNC;

if (ceilMode) {
for (int64_t i = 0; i < dimensionality; ++i) {
highPaddingIncludingNC[i + 2] += strideInts[i];
}
}

Value initValue =
rewriter.create<arith::ConstantOp>(loc, cast<TypedAttr>(initValueAttr));
paddedInput = torch_to_linalg::getPaddedTensor(
op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC,
initValue);
Location loc = op->getLoc();

Value N = getDimOp(rewriter, loc, self, 0);
Value C = getDimOp(rewriter, loc, self, 1);
@@ -124,8 +103,54 @@ static LogicalResult createPoolingOp(

// Create output tensor initialized with smallest floating point value.
outTensorShape.insert(outTensorShape.begin(), {N, C});
Value outTensorInitialized =
createInitTensor(rewriter, loc, outTensorShape, elementType, initValue);
return createInitTensor(rewriter, loc, outTensorShape, elementType,
initValue);
}

static Value padInputTensor(Operation *op, ConversionPatternRewriter &rewriter,
Value self, bool ceilMode, int64_t dimensionality,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts,
Value initValue) {
SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
lowPaddingIncludingNC.append(paddingInts);
SmallVector<int64_t> highPaddingIncludingNC = lowPaddingIncludingNC;

if (ceilMode) {
for (int64_t i = 0; i < dimensionality; ++i) {
highPaddingIncludingNC[i + 2] += strideInts[i];
}
}

return torch_to_linalg::getPaddedTensor(op, rewriter, self,
lowPaddingIncludingNC,
highPaddingIncludingNC, initValue);
}

// Creates a pooling operation based on the type specified by `OpTy` and
// arguments passed.
template <typename OpTy>
static LogicalResult createPoolingOp(
Operation *op, ConversionPatternRewriter &rewriter, Value self,
bool supportNonFPInput, bool ceilMode, int64_t dimensionality,
SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
Location loc = op->getLoc();
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>() && !supportNonFPInput)
return op->emitError("unimplemented: non-floating point type");

Value initValue =
rewriter.create<arith::ConstantOp>(loc, cast<TypedAttr>(initValueAttr));

paddedInput = padInputTensor(op, rewriter, self, ceilMode, dimensionality,
strideInts, paddingInts, initValue);

auto outTensorInitialized = computeOutputTensor(
op, rewriter, self, dimensionality, ceilMode, strideInts, paddingInts,
dilationInts, kernelSizeIntValues, outTensorShape, initValue);

auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
@@ -138,57 +163,174 @@ static LogicalResult createPoolingOp(
ValueRange{paddedInput, windowTensor},
outTensorInitialized, stridesAttr, dilationAttr)
.getResult(0);

return success();
}

namespace {
class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
template <typename OpTy>
class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;

private:
template <typename T> struct DimensionTraits;

template <> struct DimensionTraits<AtenMaxPool2dOp> {
static const int64_t Dim = 2;
};

template <> struct DimensionTraits<AtenMaxPool3dOp> {
static const int64_t Dim = 3;
};

static const int64_t Dim = DimensionTraits<OpTy>::Dim;

LogicalResult createPoolingMax3D(AtenMaxPool3dOp &op,
typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter,
SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts,
SmallVectorImpl<int64_t> &dilationInts,
bool ceilMode) const {
SmallVector<Value, 5> outTensorShape;
Value self = adaptor.getSelf();
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true));
Value initValue =
rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);

Value paddedInput = padInputTensor(op, rewriter, self, ceilMode, 3,
strideInts, paddingInts, initValue);

auto outTensorInitialized = computeOutputTensor(
op, rewriter, self, 3, ceilMode, strideInts, paddingInts, dilationInts,
kernelSizeIntValues, outTensorShape, initValue);

auto shape =
castIntVectorToIndexVector(rewriter, op->getLoc(), kernelSizeIntValues);
Value windowTensor = rewriter.create<tensor::EmptyOp>(
op->getLoc(), getAsOpFoldResult(shape), elementType);

MLIRContext *context = rewriter.getContext();

auto mapInput = mlir::AffineMap::get(
8, 0,
{
rewriter.getAffineDimExpr(0), // n
rewriter.getAffineDimExpr(1), // c
// dim_d * stride_d + kernel_d * dilation_d
rewriter.getAffineDimExpr(2) *
getAffineConstantExpr(strideInts[0], context) +
rewriter.getAffineDimExpr(5) *
getAffineConstantExpr(dilationInts[0], context),
// dim_h * stride_h + kernel_h * dilation_h
rewriter.getAffineDimExpr(3) *
getAffineConstantExpr(strideInts[1], context) +
rewriter.getAffineDimExpr(6) *
getAffineConstantExpr(dilationInts[1], context),
// dim_w * stride_w + kernel_w * dilation_w
rewriter.getAffineDimExpr(4) *
getAffineConstantExpr(strideInts[2], context) +
rewriter.getAffineDimExpr(7) *
getAffineConstantExpr(dilationInts[2], context),
},
context);
auto mapKernel =
mlir::AffineMap::get(8, 0,
{
rewriter.getAffineDimExpr(5), // kd
rewriter.getAffineDimExpr(6), // kh
rewriter.getAffineDimExpr(7) // kw
},
context);
auto mapOutput = mlir::AffineMap::get(
8, 0,
{rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1),
rewriter.getAffineDimExpr(2), rewriter.getAffineDimExpr(3),
rewriter.getAffineDimExpr(4)},
context);
auto iteratorTypes =
SmallVector<utils::IteratorType>(5, utils::IteratorType::parallel);
iteratorTypes.append(3, utils::IteratorType::reduction);
SmallVector<AffineMap> indexingMaps = {mapInput, mapKernel, mapOutput};
Value poolingOp =
rewriter
.create<linalg::GenericOp>(
op->getLoc(),
/* result types */ outTensorInitialized.getType(),
/* operands */ ValueRange({paddedInput, windowTensor}),
/* outputs */ outTensorInitialized,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value currentVal = args[0], accMaxValue = args[2];
Value max_result =
b.create<arith::MaximumFOp>(loc, currentVal, accMaxValue);
b.create<linalg::YieldOp>(loc, max_result);
})
.getResult(0);
Type newResultType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, poolingOp);
return success();
}

public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenMaxPool2dOp op, OpAdaptor adaptor,
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

const TypeConverter *typeConverter = getTypeConverter();
const TypeConverter *typeConverter = this->getTypeConverter();
Value self = adaptor.getSelf();
int64_t selfRank = self.getType().cast<RankedTensorType>().getRank();
// TODO: Add support for 3D inputs.
if (selfRank == 3)

if (selfRank != Dim + 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only support 4D input");
op, "unimplemented: Does not support inputs with rank");

bool ceilMode;
SmallVector<Value, 2> kernelSizeIntValues;
SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts;
SmallVector<Value, Dim> kernelSizeIntValues;
SmallVector<int64_t, Dim> strideInts, paddingInts, dilationInts;
if (!matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
if (failed(checkAndGetPoolingParameters<AtenMaxPool2dOp>(
op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
strideInts, paddingInts)))

if (failed(checkAndGetPoolingParameters<OpTy>(op, rewriter, typeConverter,
ceilMode, kernelSizeIntValues,
strideInts, paddingInts)))
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");

Type elementType = self.getType().cast<RankedTensorType>().getElementType();
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true));
SmallVector<Value, 4> outTensorShape;
// `maxpool2d` contains the result of maxpool2d operation over the input.
Value maxPool2d, paddedInput;
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
op, rewriter, self, /*supportNonFPInput=*/false, ceilMode,
/*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts,
dilationInts, smallestFPValueAttr, outTensorShape, paddedInput,
maxPool2d)))
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
return success();

if constexpr (Dim == 2) {
SmallVector<Value, 4> outTensorShape;
// `maxpool2d` contains the result of maxpool2d operation over the input.
Value maxPool2d, paddedInput;
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(
elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true));
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
/*dimensionality=*/2, kernelSizeIntValues, strideInts,
paddingInts, dilationInts, smallestFPValueAttr, outTensorShape,
paddedInput, maxPool2d)))
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
Type newResultType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
return success();
} else {
return createPoolingMax3D(op, adaptor, rewriter,
kernelSizeIntValues, strideInts, paddingInts,
dilationInts, ceilMode);
}
}
};
} // namespace
@@ -650,7 +792,10 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenMaxPool3dOp>();
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dOp>>(typeConverter, context);
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dOp>>(typeConverter, context);

target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp>();