Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stablehlo] support aten_adaptive_max_pool1d lowering #3728

Merged
merged 3 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7078,6 +7078,35 @@ def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
}];
}

// ODS definition of the `torch.aten.max_pool1d_with_indices` op.
//
// Operands (6): the input tensor `self` followed by the `kernel_size`,
// `stride`, `padding` and `dilation` int lists and the `ceil_mode` bool.
// Results (2): `result0` and `result1`, matching the `(Tensor, Tensor)` return
// of the schema quoted in the summary.
//
// The custom assembly hooks defer to the shared default Torch op parser and
// printer; the trailing `6, 2` arguments are the operand and result counts
// declared above.
//
// NOTE(review): the summary marks this as a generated op, so this definition is
// presumably emitted by the ODS generator script rather than maintained by
// hand -- confirm before editing it directly.
def Torch_AtenMaxPool1dWithIndicesOp : Torch_Op<"aten.max_pool1d_with_indices", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$ceil_mode
);
let results = (outs
AnyTorchOptionalTensorType:$result0,
AnyTorchOptionalTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxPool1dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 2);
}
void AtenMaxPool1dWithIndicesOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
}

def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
158 changes: 157 additions & 1 deletion lib/Conversion/TorchToStablehlo/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,

// Max pooling
if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
AtenMaxPool2dWithIndicesOp>(op)) {
AtenMaxPool1dWithIndicesOp, AtenMaxPool2dWithIndicesOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType,
Expand All @@ -73,6 +73,161 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
return nullptr;
}

// AtenMaxPool1dWithIndicesOp
//
// Lowers `aten.max_pool1d_with_indices` to a variadic
// `stablehlo.reduce_window` that reduces (value, index) pairs. The value
// operand is the input tensor; the index operand is an i64 iota along the last
// dimension of the input. The reduction body keeps the larger value and, when
// the two values compare equal, the smaller index.
//
// The rewrite reports a match failure when:
//   * the input rank is <= 1,
//   * kernel_size / stride / padding / dilation / ceil_mode are not constants,
//   * kernel_size or padding do not hold exactly one element, or stride or
//     dilation hold more than one,
//   * the element type is neither float nor integer, or no initial value can
//     be created for it.
//
// NOTE(review): `ceil_mode` is only required to be a constant; its value does
// not influence the window that is built here. The result shapes come from the
// converted result types of the op -- confirm this is intended for
// ceil_mode=true.
template <>
LogicalResult ConvertAtenOp<AtenMaxPool1dWithIndicesOp>::matchAndRewrite(
    AtenMaxPool1dWithIndicesOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto loc = op->getLoc();
  Value input = adaptor.getSelf();
  auto inputTy = cast<RankedTensorType>(input.getType());
  auto inputElemTy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();
  auto inputRank = inputTy.getRank();

  auto outValTy =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
  auto outIdxTy =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));

  if (inputRank <= 1) {
    return rewriter.notifyMatchFailure(
        op, "max_pooling1d only supports inputs with rank higher than 1");
  }

  // Pick the comparison flavour up front so that unsupported element types are
  // rejected instead of building a compare op with a null attribute.
  stablehlo::ComparisonTypeAttr compareTypeAttr;
  if (isa<mlir::FloatType>(inputElemTy)) {
    compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
        rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
  } else if (isa<mlir::IntegerType>(inputElemTy)) {
    compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
        rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
  } else {
    return rewriter.notifyMatchFailure(
        op, "only float and integer element types are supported");
  }

  SmallVector<int64_t, 1> padding, kernelSize, stride, dilation;
  bool ceilMode = false;

  if (!(matchPattern(op.getKernelSize(),
                     m_TorchListOfConstantInts(kernelSize)))) {
    return rewriter.notifyMatchFailure(
        op, "non-const int kernel size unsupported!");
  }
  if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
  }
  if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int padding unsupported!");
  }
  if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int dilation unsupported!");
  }
  if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const bool ceil_mode unsupported!");
  }

  // Every parameter applies to the single pooled dimension. Validate the list
  // lengths before indexing into them: a longer list would otherwise be
  // written past the end of the per-dimension vectors below, and an empty
  // padding list would be read out of bounds.
  if (kernelSize.size() != 1 || padding.size() != 1) {
    return rewriter.notifyMatchFailure(
        op, "expected kernel_size and padding to have exactly one element");
  }
  if (stride.size() > 1 || dilation.size() > 1) {
    return rewriter.notifyMatchFailure(
        op, "expected stride and dilation to have at most one element");
  }

  // An empty stride list means "use the kernel size" in PyTorch; an absent
  // dilation keeps the neutral value 1.
  const int64_t kernel = kernelSize[0];
  const int64_t strideVal = stride.empty() ? kernel : stride[0];
  const int64_t dilationVal = dilation.empty() ? 1 : dilation[0];
  const int64_t paddingVal = padding[0];

  // All leading dimensions use a neutral window (size 1, stride 1, dilation 1,
  // no padding); only the last dimension is pooled.
  SmallVector<int64_t> stablehloStride(inputRank, 1);
  SmallVector<int64_t> stablehloDilation(inputRank, 1);
  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);

  stablehloStride[inputRank - 1] = strideVal;
  stablehloDilation[inputRank - 1] = dilationVal;
  stablehloKernelSize[inputRank - 1] = kernel;
  // Padding is stored as (low, high) pairs per dimension; pad both sides of
  // the last dimension symmetrically.
  stablehloPadding[stablehloPadding.size() - 1] = paddingVal;
  stablehloPadding[stablehloPadding.size() - 2] = paddingVal;

  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
  if (!initVal) {
    return rewriter.notifyMatchFailure(
        op, "failed to create the initial value for max pooling");
  }

  auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
  auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
  auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
          rewriter.getI64Type()),
      stablehloPadding);
  // Left as a null attribute: no base dilation is requested.
  DenseI64ArrayAttr baseDilations;

  auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
  if (failed(inputShapeInfo)) {
    return rewriter.notifyMatchFailure(
        op, "failed to get dimension sizes of the input");
  }
  auto inputShapeVec = *inputShapeInfo;
  auto inputShapeTensor =
      rewriter.create<mlir::tensor::FromElementsOp>(loc, inputShapeVec);

  // No reshape is needed for 1-d pooling: the iota runs along the last
  // dimension (inputRank - 1), which is the only dimension being pooled, so
  // each element carries its own position along that dimension.
  auto indexTensor =
      rewriter
          .create<stablehlo::DynamicIotaOp>(
              loc, RankedTensorType::get(inputShape, rewriter.getI64Type()),
              inputShapeTensor, static_cast<uint64_t>(inputRank - 1))
          .getResult();
  Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();

  auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
      loc, mlir::TypeRange{outValTy, outIdxTy},
      mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
      windowDimensions, windowStrides, baseDilations, windowDilations, pad);

  // Build the reduction body. Its arguments are laid out as
  // (lhs value, lhs index, rhs value, rhs index), all rank-0 tensors.
  Block &block = reduceWindowOp.getBody().emplaceBlock();
  auto blockValArgumentType = RankedTensorType::get({}, inputElemTy);
  auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
  auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
  Value firstVal = block.addArgument(blockValArgumentType, loc);
  Value firstIdx = block.addArgument(blockIdxArgumentType, loc);
  Value secondVal = block.addArgument(blockValArgumentType, loc);
  Value secondIdx = block.addArgument(blockIdxArgumentType, loc);

  stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
      stablehlo::ComparisonDirectionAttr::get(
          rewriter.getContext(), stablehlo::ComparisonDirection::GE);
  stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
      stablehlo::ComparisonDirectionAttr::get(
          rewriter.getContext(), stablehlo::ComparisonDirection::EQ);

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);

    // Value result: keep the left operand when it is >= the right one.
    Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
        loc, compareResultType, firstVal, secondVal, compareGeDirectionAttr,
        compareTypeAttr);
    Value retValResult = rewriter.create<stablehlo::SelectOp>(
        loc, compareGeResult, firstVal, secondVal);

    // Get smaller index if compared values are equal.
    Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
        loc, compareResultType, firstVal, secondVal, compareEqDirectionAttr,
        compareTypeAttr);
    Value minIdx = rewriter.create<stablehlo::MinOp>(loc, firstIdx, secondIdx);
    Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
        loc, compareGeResult, firstIdx, secondIdx);
    Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
        loc, compareEqResult, minIdx, idxWithGeVal);

    rewriter.create<stablehlo::ReturnOp>(
        loc, mlir::ValueRange{retValResult, retIdxResult});
  }

  rewriter.replaceOp(op, reduceWindowOp.getResults());
  return success();
}

// AtenMaxPool2dWithIndicesOp
template <>
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
Expand Down Expand Up @@ -657,6 +812,7 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool1dWithIndicesOp);
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
#undef INSERT_ATEN_POOLING_PATTERN
Expand Down
80 changes: 80 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7298,6 +7298,85 @@ class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
};
} // namespace

namespace {
// Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices`
// op.
//
// Two configurations are handled, selected by the (constant) output size:
//   * output size == 1: the kernel covers the whole last dimension of the
//     input (a constant when that dimension is static, otherwise the result
//     of `aten.size.int`).
//   * any other output size: the kernel size is 1, and unless strict symbolic
//     shapes are assumed a runtime assert checks that the input size equals
//     the output size.
// In both cases stride is 1, padding is 0, dilation is 1 and ceil_mode is
// false.
//
// Every check that can reject the op runs before the first op is created, so a
// failed match leaves the IR untouched.
class DecomposeAtenAdaptiveMaxPool1dOp
    : public OpRewritePattern<AtenAdaptiveMaxPool1dOp> {
  using OpRewritePattern<AtenAdaptiveMaxPool1dOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    MLIRContext *context = op.getContext();

    Value input = op.getSelf();
    std::optional<unsigned> maybeRank = getTensorRank(input);
    if (!maybeRank) {
      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
    }
    unsigned rank = *maybeRank;
    // `rank - 1` is used as an index below; guard the unsigned underflow.
    if (rank == 0) {
      return rewriter.notifyMatchFailure(
          op, "expected input to have at least one dimension");
    }

    // The output size must come from a list construct with at least one
    // element, otherwise there is nothing to read the target size from.
    SmallVector<Value> outputShapeSizesTorchInt;
    if (!getListConstructElements(op.getOutputSize(),
                                  outputShapeSizesTorchInt) ||
        outputShapeSizesTorchInt.empty()) {
      return rewriter.notifyMatchFailure(
          op, "expected output_size to be a non-empty list construct");
    }
    Value outputSize = outputShapeSizesTorchInt[0];

    int64_t outputSizeInt;
    if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
      return rewriter.notifyMatchFailure(
          op, "the output size of adaptive_max_pool1d must be a constant int");
    }

    // From here on the pattern always succeeds, so it is safe to create ops.
    Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(rank - 1));
    Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);

    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value constantZero = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);

    SmallVector<Value, 1> kernelSize;
    if (outputSizeInt == 1) {
      // A single output element: pool over the entire last dimension.
      BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
      ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
      int64_t lastDimSize = inputShape[rank - 1];
      if (lastDimSize == kUnknownSize) {
        kernelSize.push_back(inputSize);
      } else {
        Value staticSize = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(lastDimSize));
        kernelSize.push_back(staticSize);
      }
    } else {
      if (!isAssumingStrictSymbolicShapes(rewriter)) {
        Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
        rewriter.create<RuntimeAssertOp>(
            loc, cond,
            "unimplemented: only support cases where input and output size are "
            "equal for non-unit output size");
      }
      kernelSize.push_back(constantOne);
    }

    Type intListType = Torch::ListType::get(Torch::IntType::get(context));
    Value kernelSizeList =
        rewriter.create<PrimListConstructOp>(loc, intListType, kernelSize);
    Value strideList = rewriter.create<PrimListConstructOp>(
        loc, intListType, ValueRange{constantOne});
    Value paddingSizeList = rewriter.create<PrimListConstructOp>(
        loc, intListType, ValueRange{constantZero});
    Value dilationList = rewriter.create<PrimListConstructOp>(
        loc, intListType, ValueRange{constantOne});

    rewriter.replaceOpWithNewOp<AtenMaxPool1dWithIndicesOp>(
        op, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
        paddingSizeList, dilationList,
        /*ceil_mode=*/constantFalse);
    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.

Expand Down Expand Up @@ -9801,6 +9880,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveMaxPool1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@
"AdaptiveAvgPool3dDynamic_basic",
"AdaptiveMaxPool1dDynamicNoBatch_basic",
"AdaptiveMaxPool1dDynamic_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"AdaptiveMaxPool1dStatic_basic",
"AdaptiveMaxPool2dDynamicNoBatch_basic",
"AdaptiveMaxPool2dDynamicWithIndices_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,9 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
)
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit(
"aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
)
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
emit(
Expand Down
16 changes: 16 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,6 +1783,22 @@ def AdaptiveMaxPool1dStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 10))


class AdaptiveMaxPool1dDimOneStatic(torch.nn.Module):
    """Adaptive 1-d max pooling to a single output element, indices not returned."""

    def __init__(self):
        super().__init__()
        # A single output element along the last dimension.
        self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=1, return_indices=False)

    @export
    @annotate_args(
        [
            None,
            ([1, 512, 7], torch.float32, True),
        ]
    )
    def forward(self, x):
        pooled = self.amp1d(x)
        return pooled


@register_test_case(module_factory=lambda: AdaptiveMaxPool1dDimOneStatic())
def AdaptiveMaxPool1dDimOneStatic_basic(module, tu: TestUtils):
    # Random input whose shape matches the module's annotated argument.
    sample = tu.rand(1, 512, 7)
    module.forward(sample)


# AdaptiveMaxPool2d


Expand Down
Loading