[stablehlo] Support aten.any and aten.all lowering
yangxinyu committed Apr 23, 2024
1 parent e48d36c commit d0aedff
Showing 1 changed file with 158 additions and 0 deletions.
lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -89,6 +89,18 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
    }
  }

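  // Logical AND folds from its identity element, true.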
  if (isa<AtenAllOp>(op)) {
    auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)});
    return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                  constAttr);
  }

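  // Logical OR folds from its identity element, false.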
  if (isa<AtenAnyOp>(op)) {
    auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)});
    return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                  constAttr);
  }

  op->emitError("unimplemented lowering in "
                "createInitialValueForReduceOp");
  return nullptr;
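Aside: the two constants above are the identity elements of the reductions built below. An AND-fold must start from true and an OR-fold from false, so that reducing over every element (including the empty tensor) matches torch.all/torch.any semantics. A minimal plain-C++ sketch of the same folds; the allOf/anyOf helpers are illustrative and not part of this patch:

#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

static bool allOf(const std::vector<bool> &v) {
  // AND-fold seeded with its identity, true (cf. APInt(1, 1) above).
  return std::accumulate(v.begin(), v.end(), true, std::logical_and<>());
}
static bool anyOf(const std::vector<bool> &v) {
  // OR-fold seeded with its identity, false (cf. APInt(1, 0) above).
  return std::accumulate(v.begin(), v.end(), false, std::logical_or<>());
}

int main() {
  assert(allOf({true, true}) && !allOf({true, false}));
  assert(anyOf({false, true}) && !anyOf({false, false}));
  // The empty input yields the identity, matching torch.all/torch.any.
  assert(allOf({}) && !anyOf({}));
  return 0;
}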
@@ -448,6 +460,150 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
}
} // namespace

// AtenAllOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenAllOp>::matchAndRewrite(
    AtenAllOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.getSelf();
  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
  if (!inputTy) {
    return rewriter.notifyMatchFailure(
        op, "only Tensor types supported in StableHLO");
  }
  auto inputElemTy = inputTy.getElementType();

  // Currently, (u)int8 dtype is not supported
  if (isa<mlir::IntegerType>(inputElemTy) &&
      inputElemTy.getIntOrFloatBitWidth() == 8) {
    return rewriter.notifyMatchFailure(
        op, "IntegerType with bitwidth 8 unsupported in conversion from "
            "AtenAllOp to StableHLO");
  }
  auto outTy = getTypeConverter()
                   ->convertType(op.getType())
                   .template dyn_cast<RankedTensorType>();

  if (inputElemTy != outTy.getElementType()) {
    // Use output bool type as computation type.
    auto dstElemTy = outTy.getElementType();
    input =
        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
    inputTy = input.getType().dyn_cast<RankedTensorType>();
    inputElemTy = inputTy.getElementType();
  }

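  // The dim-less aten.all reduces across every dimension, yielding a 0-d
  // tensor.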
  SmallVector<int64_t> dims;
  for (int64_t i = 0; i < inputTy.getRank(); i++) {
    dims.push_back(i);
  }

  Value initValue =
      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
  if (!initValue)
    return failure();
  llvm::sort(dims.begin(), dims.end());
  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
      rewriter.getDenseI64ArrayAttr(dims));

  Block &block = stablehloReduceOp.getBody().emplaceBlock();
  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());

  block.addArgument(blockArgumentTy, op->getLoc());
  block.addArgument(blockArgumentTy, op->getLoc());

  auto *firstArgument = block.args_begin();
  auto secondArgument = block.args_rbegin();

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);
    Value allResult = rewriter.create<stablehlo::AndOp>(
        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), allResult);
  }
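  // For, e.g., a 2x3 i1 input, the reduce built above prints roughly as:
  //   %0 = stablehlo.reduce(%input init: %true) applies stablehlo.and
  //       across dimensions = [0, 1] : (tensor<2x3xi1>, tensor<i1>) -> tensor<i1>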

  rewriter.replaceOpWithNewOp<tensor::CastOp>(
      op, getTypeConverter()->convertType(op.getType()),
      stablehloReduceOp.getResults());
  return success();
}
} // namespace

// AtenAnyOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenAnyOp>::matchAndRewrite(
    AtenAnyOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.getSelf();
  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
  if (!inputTy) {
    return rewriter.notifyMatchFailure(
        op, "only Tensor types supported in StableHLO");
  }
  auto inputElemTy = inputTy.getElementType();

  // Currently, (u)int8 dtype is not supported
  if (isa<mlir::IntegerType>(inputElemTy) &&
      inputElemTy.getIntOrFloatBitWidth() == 8) {
    return rewriter.notifyMatchFailure(
        op, "IntegerType with bitwidth 8 unsupported in conversion from "
            "AtenAnyOp to StableHLO");
  }
  auto outTy = getTypeConverter()
                   ->convertType(op.getType())
                   .template dyn_cast<RankedTensorType>();

  if (inputElemTy != outTy.getElementType()) {
    // Use output bool type as computation type.
    auto dstElemTy = outTy.getElementType();
    input =
        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
    inputTy = input.getType().dyn_cast<RankedTensorType>();
    inputElemTy = inputTy.getElementType();
  }

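  // As in the aten.all lowering above: reduce across every dimension.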
  SmallVector<int64_t> dims;
  for (int64_t i = 0; i < inputTy.getRank(); i++) {
    dims.push_back(i);
  }

  Value initValue =
      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
  if (!initValue)
    return failure();
  llvm::sort(dims.begin(), dims.end());
  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
      rewriter.getDenseI64ArrayAttr(dims));

  Block &block = stablehloReduceOp.getBody().emplaceBlock();
  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());

  block.addArgument(blockArgumentTy, op->getLoc());
  block.addArgument(blockArgumentTy, op->getLoc());

  auto *firstArgument = block.args_begin();
  auto secondArgument = block.args_rbegin();

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);
    Value anyResult = rewriter.create<stablehlo::OrOp>(
        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), anyResult);
  }

  rewriter.replaceOpWithNewOp<tensor::CastOp>(
      op, getTypeConverter()->convertType(op.getType()),
      stablehloReduceOp.getResults());
  return success();
}
} // namespace

// AtenMaxOp
namespace {
template <>
@@ -963,6 +1119,8 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAllOp);
  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyOp);
  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp);
  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
