Skip to content

Commit

Permalink
[tosa] Add Torch reduction operators
Browse files Browse the repository at this point in the history
- Supports variants with multiple dims, one dim, all dime
- Leverages legalize_common and legalize_utils code from
TensorFlow-TOSA work

Signed-off-by: Suraj Sudhir <[email protected]>
  • Loading branch information
sjarus authored and silvasean committed Dec 3, 2021
1 parent ab62111 commit c9c9b68
Show file tree
Hide file tree
Showing 8 changed files with 854 additions and 8 deletions.
64 changes: 64 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H

#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace tosa {

// Lowers ReduceAll to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);

// Lowers ReduceAny to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceAnyOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);

// Lowers ReduceMin to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMinOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);

// Lowers ReduceMax to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMaxOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);

// Lowers ReduceProd to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);

// Lowers ReduceSum to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);

// Lowers ReduceMean to a sequence of TOSA ops.
llvm::Optional<Value>
convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType output_type, Value input_value,
ElementsAttr axes_elems, bool keep_dims);

} // namespace tosa
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
96 changes: 96 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H

#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace tosa {

// Create a TOSA rescale op from input framework scaling, zero points and
// rounding mode
Value buildRescale(PatternRewriter &rewriter, Operation *op,
ShapedType output_type, Value input_val, double scale,
int64_t input_zp, int64_t output_zp, bool double_round,
bool scale32);

// Creates TOSA rescale op with int32 output
Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
Value input_val, double input_scale,
int64_t input_zp);

// Create a 32-bit float constant operator from a float
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
Args &&... args) {
auto op = rewriter.create<TosaOp>(loc, result_ty, args...);

InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
if (!shapeInterface)
return op;

SmallVector<ShapedTypeComponents> returnedShapes;
if (shapeInterface
.inferReturnTypeComponents(op.getContext(), op.getLoc(),
op->getOperands(), op->getAttrDictionary(),
op->getRegions(), returnedShapes)
.failed())
return op;

// We need to use the element type of the existing result type to generate
// the new result shaped type. This is because rescale can include a cast to
// different bit-width types and does not have a TypeAttr to define the
// target type.
auto result = op->getResult(0);
auto predictedShape = returnedShapes[0];
auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);

// Compute the knowledge based on the inferred type.
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
inferredKnowledge.sizes.push_back(dim);
}
}

// Compute the new type based on the joined version.
auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
auto new_ty = newKnowledge.getType();
result.setType(new_ty);
return op;
}

template <typename TosaOp, typename... Args>
void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op,
Type result_ty, Args &&... args) {
auto result =
CreateOpAndInfer<TosaOp>(rewriter, op->getLoc(), result_ty, args...);
rewriter.replaceOp(op, result->getResults());
}

} // namespace tosa
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
2 changes: 2 additions & 0 deletions lib/Conversion/TorchToTosa/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
add_mlir_conversion_library(TorchMLIRTorchToTosa
TorchToTosa.cpp
TosaLegalizeUtils.cpp
TosaLegalizeCommon.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa
Expand Down
182 changes: 182 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
Expand Down Expand Up @@ -252,6 +254,156 @@ LogicalResult ConvertAtenOp<AtenDivTensorOp>::matchAndRewrite(
return success();
}

using ReductionConvFunc = llvm::Optional<Value> (*)(PatternRewriter &,
Operation *,
RankedTensorType, Value,
ElementsAttr, bool);

// They all constitute a common form invoking the appropriate
// converion function in TosaLegalizeCommon.cpp
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;

// Each variant must implement corresponding parameter parsing options
virtual LogicalResult readReduceDimsAndKeepDims(
AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
ElementsAttr &reduceDimsAttr, bool &keepDims) const {
return rewriter.notifyMatchFailure(
op, "Unimplemented reduce_dims and keep_dims parsing function");
}

// Common rewriter for all reduction ops, calls the specific implementation of
// readReduceDimsAndKeepDims() needed for the op variant.
LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.self();
auto selfTy = self.getType().cast<TensorType>();

if (!selfTy)
return op.emitError("Only Tensor types supported in TOSA");

auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();
if (!outputTy)
return op.emitError(
"Only ranked tensor type outputs permitted for reduce_mean");

ElementsAttr reduceDimsAttr;
bool keepDims;

if (failed(readReduceDimsAndKeepDims(op, adaptor, rewriter, reduceDimsAttr,
keepDims)))
return failure();

llvm::Optional<Value> result =
ConversionFuncT(rewriter, op, outputTy, self, reduceDimsAttr, keepDims);

if (!result)
return failure();

// TBD - support dtype casting.

rewriter.replaceOp(op, {result.getValue()});

return success();
}
};

// This reduction op legalization template handles op variants that have
// explicit reduce_dims dimensions (provided as a list) and keep_dims
// parameters.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenMultipleDimsReductionOp
: public ConvertAtenReductionOp<AtenOpT, ConversionFuncT> {
using ConvertAtenReductionOp<AtenOpT,
ConversionFuncT>::ConvertAtenReductionOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
ElementsAttr &reduceDimsAttr,
bool &keepDims) const {
SmallVector<int64_t, 4> reduceDims;
if (!matchPattern(op.dim(), m_TorchConstantIntList(reduceDims)))
return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported");
int64_t N = reduceDims.size();
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
llvm::makeArrayRef(reduceDims));

keepDims = false;
if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDims)))
return rewriter.notifyMatchFailure(
op, "non-const keepdim parameter unsupported");

return success();
}
};

// This reduction op legalization template handles op variants that reduce in
// only one explicit dim which is provided as a number (rather than a list), and
// a keep_dims parameter.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenOneDimReductionOp
: public ConvertAtenReductionOp<AtenOpT, ConversionFuncT> {
using ConvertAtenReductionOp<AtenOpT,
ConversionFuncT>::ConvertAtenReductionOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
ElementsAttr &reduceDimsAttr,
bool &keepDims) const {
int64_t reduceDim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&reduceDim)))
return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported");
auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type());
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
llvm::makeArrayRef({reduceDim}));

keepDims = false;
if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDims)))
return rewriter.notifyMatchFailure(
op, "non-const keepdim parameter unsupported");

return success();
}
};

// This reduction op legalization template handles op variants that reduce all
// dims does not keep dims.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenAllDimsReductionOp
: public ConvertAtenReductionOp<AtenOpT, ConversionFuncT> {
public:
using ConvertAtenReductionOp<AtenOpT,
ConversionFuncT>::ConvertAtenReductionOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
ElementsAttr &reduceDimsAttr,
bool &keepDims) const {
auto self = adaptor.self();
auto selfTy = self.getType().template cast<RankedTensorType>();

// Select all dims to reduce
SmallVector<int64_t, 4> reduceDims;
for (int64_t i = 0; i < selfTy.getRank(); i++)
reduceDims.push_back(i);
int64_t N = selfTy.getRank();
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
llvm::makeArrayRef(reduceDims));
keepDims = false;

return success();
}
};

} // namespace

// -----------------------------------------------------------------------------
Expand Down Expand Up @@ -300,6 +452,36 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
#undef INSERT_BINARY_ADDSUB_PATTERN

#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp,
mlir::tosa::convertReduceMeanOp)
INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp,
mlir::tosa::convertReduceSumOp)
#undef INSERT_NDIMS_REDUCTION_OP_PATTERN

#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOneDimReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
mlir::tosa::convertReduceAnyOp)
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN

#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAllDimsReductionOp<AtenOp, ConversionFunc>>( \
typeConverter, context);
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp,
mlir::tosa::convertReduceAllOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp,
mlir::tosa::convertReduceAnyOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
mlir::tosa::convertReduceSumOp)
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN

#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
Expand Down
Loading

0 comments on commit c9c9b68

Please sign in to comment.