-
Notifications
You must be signed in to change notification settings - Fork 505
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[tosa] Add Torch reduction operators
- Supports variants with multiple dims, one dim, all dime - Leverages legalize_common and legalize_utils code from TensorFlow-TOSA work Signed-off-by: Suraj Sudhir <[email protected]>
- Loading branch information
Showing
8 changed files
with
854 additions
and
8 deletions.
There are no files selected for viewing
64 changes: 64 additions & 0 deletions
64
include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H | ||
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H | ||
|
||
#include "mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/Support/LLVM.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace tosa { | ||
|
||
// Lowers ReduceAll to a sequence of TOSA ops. | ||
llvm::Optional<Value> | ||
convertReduceAllOp(PatternRewriter &rewriter, Operation *op, | ||
RankedTensorType output_type, Value input_value, | ||
ElementsAttr axes_elems, bool keep_dims); | ||
|
||
// Lowers ReduceAny to a sequence of TOSA ops. | ||
llvm::Optional<Value> | ||
convertReduceAnyOp(PatternRewriter &rewriter, Operation *op, | ||
RankedTensorType output_type, Value input_value, | ||
ElementsAttr axes_elems, bool keep_dims); | ||
|
||
// Lowers ReduceMin to a sequence of TOSA ops. | ||
llvm::Optional<Value> | ||
convertReduceMinOp(PatternRewriter &rewriter, Operation *op, | ||
RankedTensorType output_type, Value input_value, | ||
ElementsAttr axes_elems, bool keep_dims); | ||
|
||
// Lowers ReduceMax to a sequence of TOSA ops. | ||
llvm::Optional<Value> | ||
convertReduceMaxOp(PatternRewriter &rewriter, Operation *op, | ||
RankedTensorType output_type, Value input_value, | ||
ElementsAttr axes_elems, bool keep_dims); | ||
|
||
// Lowers ReduceProd to a sequence of TOSA ops. | ||
llvm::Optional<Value> | ||
convertReduceProdOp(PatternRewriter &rewriter, Operation *op, | ||
RankedTensorType output_type, Value input_value, | ||
ElementsAttr axes_elems, bool keep_dims); | ||
|
||
// Lowers ReduceSum to a sequence of TOSA ops. | ||
llvm::Optional<Value> | ||
convertReduceSumOp(PatternRewriter &rewriter, Operation *op, | ||
RankedTensorType output_type, Value input_value, | ||
ElementsAttr axes_elems, bool keep_dims); | ||
|
||
// Lowers ReduceMean to a sequence of TOSA ops. | ||
llvm::Optional<Value> | ||
convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, | ||
RankedTensorType output_type, Value input_value, | ||
ElementsAttr axes_elems, bool keep_dims); | ||
|
||
} // namespace tosa | ||
} // namespace mlir | ||
|
||
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H |
96 changes: 96 additions & 0 deletions
96
include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H | ||
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H | ||
|
||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project | ||
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project | ||
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project | ||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project | ||
#include "mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project | ||
#include "mlir/Support/LLVM.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace tosa { | ||
|
||
// Create a TOSA rescale op from input framework scaling, zero points and | ||
// rounding mode | ||
Value buildRescale(PatternRewriter &rewriter, Operation *op, | ||
ShapedType output_type, Value input_val, double scale, | ||
int64_t input_zp, int64_t output_zp, bool double_round, | ||
bool scale32); | ||
|
||
// Creates TOSA rescale op with int32 output | ||
Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op, | ||
Value input_val, double input_scale, | ||
int64_t input_zp); | ||
|
||
// Create a 32-bit float constant operator from a float | ||
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, | ||
float val); | ||
|
||
// Creates a TOSA operation and performs shape inference on the individual | ||
// op. This allows shape inference during the framework to TOSA lowering. | ||
template <typename TosaOp, typename... Args> | ||
TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty, | ||
Args &&... args) { | ||
auto op = rewriter.create<TosaOp>(loc, result_ty, args...); | ||
|
||
InferShapedTypeOpInterface shapeInterface = | ||
dyn_cast<InferShapedTypeOpInterface>(op.getOperation()); | ||
if (!shapeInterface) | ||
return op; | ||
|
||
SmallVector<ShapedTypeComponents> returnedShapes; | ||
if (shapeInterface | ||
.inferReturnTypeComponents(op.getContext(), op.getLoc(), | ||
op->getOperands(), op->getAttrDictionary(), | ||
op->getRegions(), returnedShapes) | ||
.failed()) | ||
return op; | ||
|
||
// We need to use the element type of the existing result type to generate | ||
// the new result shaped type. This is because rescale can include a cast to | ||
// different bit-width types and does not have a TypeAttr to define the | ||
// target type. | ||
auto result = op->getResult(0); | ||
auto predictedShape = returnedShapes[0]; | ||
auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty); | ||
|
||
// Compute the knowledge based on the inferred type. | ||
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); | ||
inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType(); | ||
inferredKnowledge.hasRank = predictedShape.hasRank(); | ||
if (predictedShape.hasRank()) { | ||
for (auto dim : predictedShape.getDims()) { | ||
inferredKnowledge.sizes.push_back(dim); | ||
} | ||
} | ||
|
||
// Compute the new type based on the joined version. | ||
auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge); | ||
auto new_ty = newKnowledge.getType(); | ||
result.setType(new_ty); | ||
return op; | ||
} | ||
|
||
template <typename TosaOp, typename... Args> | ||
void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op, | ||
Type result_ty, Args &&... args) { | ||
auto result = | ||
CreateOpAndInfer<TosaOp>(rewriter, op->getLoc(), result_ty, args...); | ||
rewriter.replaceOp(op, result->getResults()); | ||
} | ||
|
||
} // namespace tosa | ||
} // namespace mlir | ||
|
||
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.