[stablehlo] Support aten.any and aten.all lowering #3217

Merged: 9 commits, Apr 25, 2024
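For reference, a minimal eager-mode sketch (not part of this patch) of the behavior the new lowerings target: called without a dim argument, aten.all and aten.any reduce across every dimension and return a 0-d boolean tensor, with nonzero elements counting as True.

# Illustrative only; mirrors the op semantics that the lowerings below implement.
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 0.0]])
print(torch.ops.aten.all(x))  # tensor(False) -- one element is zero
print(torch.ops.aten.any(x))  # tensor(True)  -- at least one element is nonzero

b = torch.zeros(2, 3, dtype=torch.bool)
print(torch.ops.aten.all(b))  # tensor(False)
print(torch.ops.aten.any(b))  # tensor(False)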
25 changes: 18 additions & 7 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -341,10 +341,14 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
isa<AtenNormScalarOp>(op))
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));

if (isa<AtenAllDimOp>(op)) {
if (isa<AtenAllOp, AtenAllDimOp>(op)) {
return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
}

if (isa<AtenAnyOp>(op)) {
return b.create<arith::ConstantOp>(loc, b.getBoolAttr(false));
}

op->emitError("unimplemented lowering in createInitElementForReduceOp");
return nullptr;
}
@@ -439,11 +443,16 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType);
auto pow = b.create<math::PowFOp>(loc, abs, ord);
return b.create<arith::AddFOp>(loc, pow, result);
} else if (isa<AtenAllDimOp>(op)) {
} else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
Value elem = payloadArgs[0];
Value result = payloadArgs[1];
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
return b.create<arith::AndIOp>(loc, self, result);
} else if (isa<AtenAnyOp>(op)) {
Value elem = payloadArgs[0];
Value result = payloadArgs[1];
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
return b.create<arith::OrIOp>(loc, self, result);
}
op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
return nullptr;
@@ -510,13 +519,13 @@ class ConvertReductionOp : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};

if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp, AtenNormScalarOp>(
op)) {
if (isa<AtenAnyOp, AtenAllOp, AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp,
AtenNormScalarOp>(op)) {
opInfo.tensorOperand = operands[0];
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();

// `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and `AtenMinOp` each reduce
// along all the dimensions of the input tensor.
// `AtenAnyOp`, `AtenAllOp`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and
// `AtenMinOp` each reduce along all the dimensions of the input tensor.
for (int64_t i = 0; i < inputType.getRank(); i++)
opInfo.dimSet.insert(i);

@@ -715,6 +724,8 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
target.addIllegalOp<AtenMinDimOp>();
patterns.add<ConvertAtenMinMaxDimOp<AtenMinDimOp>>(typeConverter, context);
target.addIllegalOp<AtenSumOp>();
target.addIllegalOp<AtenAnyOp>();
target.addIllegalOp<AtenAllOp>();
target.addIllegalOp<AtenSumDimIntListOp>();
target.addIllegalOp<AtenProdOp>();
target.addIllegalOp<AtenProdDimIntOp>();
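The TorchToLinalg changes above reuse the existing reduction recipe: seed the accumulator with an identity element (true for all, false for any) and fold every input element into it with arith.andi or arith.ori. A rough Python sketch of that recipe, assuming the elements have already been converted to bool:

# Sketch of the init-element + payload pattern used by the linalg lowering
# (assumes inputs are already bool, as after convertScalarToDtype).
from functools import reduce

def reduce_all(elems):
    # init element: True, payload: logical AND (arith.andi)
    return reduce(lambda acc, e: acc and bool(e), elems, True)

def reduce_any(elems):
    # init element: False, payload: logical OR (arith.ori)
    return reduce(lambda acc, e: acc or bool(e), elems, False)

print(reduce_all([1, 1, 0]))  # False
print(reduce_any([0, 0, 1]))  # True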
158 changes: 158 additions & 0 deletions lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -104,6 +104,18 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
}
}

if (isa<AtenAllOp>(op)) {
auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}

if (isa<AtenAnyOp>(op)) {
auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}

op->emitError("unimplemented lowering in "
"createInitialValueForReduceOp");
return nullptr;
@@ -463,6 +475,150 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
}
} // namespace

// AtenAllOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenAllOp>::matchAndRewrite(
AtenAllOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();

// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenAllOp to StableHLO");
}
auto outTy = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();

if (inputElemTy != outTy.getElementType()) {
// Use output bool type as computation type.
auto dstElemTy = outTy.getElementType();
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>();
inputElemTy = inputTy.getElementType();
}

SmallVector<int64_t> dims;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
dims.push_back(i);
}

Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());

block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());

auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();

{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value allResult = rewriter.create<stablehlo::AndOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), allResult);
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResults());
return success();
}
} // namespace

// AtenAnyOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenAnyOp>::matchAndRewrite(
AtenAnyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();

// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenAllOp to StableHLO");
}
auto outTy = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();

if (inputElemTy != outTy.getElementType()) {
// Use output bool type as computation type.
auto dstElemTy = outTy.getElementType();
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>();
inputElemTy = inputTy.getElementType();
}

SmallVector<int64_t> dims;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
dims.push_back(i);
}

Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());

block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());

auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();

{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value anyResult = rewriter.create<stablehlo::OrOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), anyResult);
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResults());
return success();
}
} // namespace

// AtenProdOp
namespace {
template <>
@@ -1052,6 +1208,8 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAllOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
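Both new StableHLO patterns emit the same structure: convert the input to the result's bool element type when it differs, then build a stablehlo.reduce over every dimension whose region combines elements with stablehlo.and (for all) or stablehlo.or (for any), starting from the constant produced by createInitialValueForReduceOp. A loose, PyTorch-only approximation of what the emitted IR computes (for intuition, not a reimplementation):

# Approximation of the emitted stablehlo.reduce for aten.all: cast to bool,
# reduce all dimensions with AND starting from the init constant `true`,
# and return a rank-0 result.
import torch

def stablehlo_all_like(x: torch.Tensor) -> torch.Tensor:
    x = x.to(torch.bool)          # stablehlo.convert to the output element type
    acc = torch.tensor(True)      # init value (APInt(1, 1))
    for v in x.flatten():         # reduce over dims [0, ..., rank-1]
        acc = acc & v             # stablehlo.and in the reduce region
    return acc                    # 0-d tensor, then cast to the converted result type

x = torch.tensor([[1, 2], [3, 0]], dtype=torch.int32)
assert stablehlo_all_like(x).item() == torch.ops.aten.all(x).item()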
9 changes: 9 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1227,6 +1227,12 @@
"RandIntLowModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"ReduceAllFloatModule_basic",
"ReduceAllIntModule_basic",
"ReduceAllBoolModule_basic",
"ReduceAnyFloatModule_basic",
"ReduceAnyIntModule_basic",
"ReduceAnyBoolModule_basic",
"ReduceAmaxMultiDim_basic",
"ReduceAmaxOutOfOrderDim_basic",
"ReduceAmaxSingleDim_basic",
@@ -1809,6 +1815,8 @@
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"ReduceAllBoolModule_basic",
"ReduceAnyBoolModule_basic",
"ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic",
"ReduceSumDimIntListKeepDimFloatModule_basic",
@@ -2715,6 +2723,7 @@
"MaskedFillTensorFloatValueModule_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"ReduceAnyFloatModule_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
}
114 changes: 114 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -124,6 +124,120 @@ def ReduceProdElementTypeBoolModule_basic(module, tu: TestUtils):

# ==============================================================================

class ReduceAllFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.all(a)


@register_test_case(module_factory=lambda: ReduceAllFloatModule())
def ReduceAllFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class ReduceAllIntModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.int32, True),
])
def forward(self, a):
return torch.ops.aten.all(a)


@register_test_case(module_factory=lambda: ReduceAllIntModule())
def ReduceAllIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))

# ==============================================================================

class ReduceAllBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.bool, True),
])
def forward(self, a):
return torch.ops.aten.all(a)


@register_test_case(module_factory=lambda: ReduceAllBoolModule())
def ReduceAllBoolModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, high=2).to(torch.bool))

# ==============================================================================

class ReduceAnyFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.any(a)


@register_test_case(module_factory=lambda: ReduceAnyFloatModule())
def ReduceAnyFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class ReduceAnyIntModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.int32, True),
])
def forward(self, a):
return torch.ops.aten.any(a)


@register_test_case(module_factory=lambda: ReduceAnyIntModule())
def ReduceAnyIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))

# ==============================================================================

class ReduceAnyBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.bool, True),
])
def forward(self, a):
return torch.ops.aten.any(a)


@register_test_case(module_factory=lambda: ReduceAnyBoolModule())
def ReduceAnyBoolModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, high=2).to(torch.bool))

# ==============================================================================

class ReduceSumDimIntListFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
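The new e2e cases feed float32, int32, and bool inputs through the compiled module and compare against eager PyTorch. A small standalone check that loosely mirrors what these tests assert (plain PyTorch only; the check helper below is made up for illustration):

# Standalone sanity check echoing the ReduceAll*/ReduceAny* tests: the lowered
# result must agree with eager aten.all / aten.any, where nonzero means True.
import torch

def check(a: torch.Tensor) -> None:
    assert torch.ops.aten.all(a) == (a != 0).all()
    assert torch.ops.aten.any(a) == (a != 0).any()

check(torch.rand(3, 4, 5))                               # float32 (ReduceAllFloatModule)
check(torch.randint(0, 2, (3, 4, 5)).to(torch.int32))    # int32   (ReduceAllIntModule)
check(torch.randint(0, 2, (3, 4)).to(torch.bool))        # bool    (ReduceAllBoolModule)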