Skip to content

Commit 73636fc

Browse files
authored
[AutoBump] Merge with fixes of a58e774 (Jan 28, tosa.mul gets third operand) (1) (#700)
2 parents 7e04812 + 60a1aba commit 73636fc

File tree

19 files changed

+457
-148
lines changed

19 files changed

+457
-148
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_DIALECT_TOSA_IR_TOSAOPS_H
1515

1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
17+
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
1718
#include "mlir/Dialect/Traits.h"
1819
#include "mlir/IR/OpDefinition.h"
1920
#include "mlir/IR/OpImplementation.h"
@@ -53,34 +54,43 @@ class MulOperandsAndResultElementType
5354
: public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
5455
public:
5556
static LogicalResult verifyTrait(Operation *op) {
56-
auto resElemType = getElementTypeOrSelf(op->getResult(0));
57-
58-
// In cases of floating point type, op requires the same element
59-
// type for all operands and result.
60-
if (llvm::isa<FloatType>(resElemType))
61-
return impl::verifySameOperandsAndResultElementType(op);
62-
57+
// Check we have a single result.
58+
if (failed(impl::verifyOneResult(op)))
59+
return failure();
60+
Type resElemType = getElementTypeOrSelf(op->getResult(0));
61+
62+
// Check we have lhs and rhs.
63+
if (failed(impl::verifyAtLeastNOperands(op, 2)))
64+
return failure();
65+
66+
Type lhsElemType = getElementTypeOrSelf(op->getOperand(0));
67+
Type rhsElemType = getElementTypeOrSelf(op->getOperand(1));
68+
69+
// Check that for i32 a shift has been explicitly provided.
70+
if (lhsElemType.isInteger(32) && failed(impl::verifyNOperands(op, 3)))
71+
return failure();
72+
73+
// Verify operands type match (ignoring the shift parameter which will
74+
// always be i8).
75+
if (lhsElemType != rhsElemType)
76+
return op->emitOpError("requires the same element type for all operands");
77+
78+
// Though the spec requires the element type of result to be i32, a more
79+
// relaxed way is provided at dialect level for easier cooperating with
80+
// other dialects.
6381
if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
64-
IntegerType lhsIntType =
65-
cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
66-
IntegerType rhsIntType =
67-
cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
68-
if (lhsIntType != rhsIntType)
69-
return op->emitOpError(
70-
"requires the same element type for all operands");
71-
72-
// Though the spec requires the element type of result to be i32, a more
73-
// relaxed way is provided at dialect level for easier cooperating with
74-
// other dialects.
82+
auto lhsIntType = cast<IntegerType>(lhsElemType);
7583
if (lhsIntType.getWidth() > resIntType.getWidth())
7684
return op->emitOpError("invalid data type size for operands or result");
77-
78-
return success();
85+
} else {
86+
// In cases of floating point type or quant types, op requires the same
87+
// element type for all operands and result (excluding shift).
88+
if (resElemType != lhsElemType)
89+
return op->emitOpError(
90+
"requires the same element type for all operands and results");
7991
}
8092

81-
// In cases of all other types, op requires the same element
82-
// type for all operands and result.
83-
return impl::verifySameOperandsAndResultElementType(op);
93+
return llvm::success();
8494
}
8595
};
8696

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -800,9 +800,11 @@ def MulOperandsAndResultElementType :
800800
//===----------------------------------------------------------------------===//
801801
// Operator: mul
802802
//===----------------------------------------------------------------------===//
803-
def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
803+
def Tosa_MulOp : Tosa_Op<"mul", [
804+
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
805+
["inferReturnTypeComponents"]>,
804806
Commutative,
805-
MulOperandsAndResultElementType]> {
807+
Pure]> {
806808
let summary = "Multiplication operator";
807809

808810
let description = [{
@@ -814,7 +816,8 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
814816
let arguments = (ins
815817
Tosa_Tensor:$input1,
816818
Tosa_Tensor:$input2,
817-
I8Attr:$shift
819+
// Apply right shift on i32_t input data only
820+
Tosa_ScalarInt8Tensor:$shift
818821
);
819822

820823
let results = (outs
@@ -823,6 +826,9 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
823826

824827
let hasFolder = 1;
825828
let hasVerifier = 1;
829+
830+
let assemblyFormat =
831+
"operands attr-dict `:` functional-type(operands, results)";
826832
}
827833

828834
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def HasNo0Dimensions : And<[
9393
IsRankedTensorTypePred,
9494
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
9595

96+
def AllDimensionsAreSizeOne : And<[
97+
IsRankedTensorTypePred,
98+
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;
99+
96100
// AMD: removed HasNo0Dimensions constraint below to allow lowerings
97101
// in onnx-mlir like onnx.Split.
98102
class TosaTensorOf<
@@ -111,6 +115,11 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
111115
[HasAnyRankOfPred<ranks>],
112116
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
113117

118+
class TosaScalarTensorOf<list<Type> allowedTypes, list<int> ranks>
119+
: TosaRankedTensorOf<allowedTypes,
120+
[HasAnyRankOfPred<ranks>, AllDimensionsAreSizeOne],
121+
"tosa-conformant scalar tensor">;
122+
114123
//===----------------------------------------------------------------------===//
115124
// Tensor types
116125
//===----------------------------------------------------------------------===//
@@ -139,8 +148,8 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
139148
// Tensor types with constrained ranks.
140149
//===----------------------------------------------------------------------===//
141150

142-
// Rank-0 (scalar) tensor
143151
def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
152+
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
144153

145154
// We include unranked tensors as a supported type for all possible tosa
146155
// Tensors as unranked does not guarantee invalid. If unranked tensors exist

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 56 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -100,43 +100,59 @@ static Value createLinalgBodyCalculationForElementwiseOp(
100100
}
101101

102102
// tosa::MulOp
103-
if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
104-
return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
105-
106-
if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
107-
Value a = args[0];
108-
Value b = args[1];
109-
auto shift =
110-
cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
111-
if (shift > 0) {
112-
auto shiftConst =
113-
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
114-
if (!a.getType().isInteger(32))
115-
a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
116-
117-
if (!b.getType().isInteger(32))
118-
b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
119-
120-
auto result = rewriter.create<tosa::ApplyScaleOp>(
121-
loc, rewriter.getI32Type(), a, b, shiftConst,
122-
rewriter.getBoolAttr(false));
123-
124-
if (elementTy.isInteger(32))
125-
return result;
126-
127-
return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
103+
if (isa<tosa::MulOp>(op)) {
104+
auto shift_val = cast<tosa::MulOp>(op).getShift();
105+
ElementsAttr shift_elem;
106+
if (!shift_val.getImpl() ||
107+
!matchPattern(shift_val, m_Constant(&shift_elem))) {
108+
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
128109
}
129110

130-
int aWidth = a.getType().getIntOrFloatBitWidth();
131-
int bWidth = b.getType().getIntOrFloatBitWidth();
132-
int cWidth = resultTypes[0].getIntOrFloatBitWidth();
111+
int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
133112

134-
if (aWidth < cWidth)
135-
a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
136-
if (bWidth < cWidth)
137-
b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
113+
if (isa<FloatType>(elementTy)) {
114+
if (shift != 0) {
115+
(void)rewriter.notifyMatchFailure(op,
116+
"Cannot have shift value for float");
117+
return nullptr;
118+
}
119+
return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
120+
}
121+
122+
if (isa<IntegerType>(elementTy)) {
123+
Value a = args[0];
124+
Value b = args[1];
125+
126+
if (shift > 0) {
127+
auto shiftConst =
128+
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
129+
if (!a.getType().isInteger(32))
130+
a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
131+
132+
if (!b.getType().isInteger(32))
133+
b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
134+
135+
auto result = rewriter.create<tosa::ApplyScaleOp>(
136+
loc, rewriter.getI32Type(), a, b, shiftConst,
137+
rewriter.getBoolAttr(false));
138+
139+
if (elementTy.isInteger(32))
140+
return result;
138141

139-
return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
142+
return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
143+
}
144+
145+
int aWidth = a.getType().getIntOrFloatBitWidth();
146+
int bWidth = b.getType().getIntOrFloatBitWidth();
147+
int cWidth = resultTypes[0].getIntOrFloatBitWidth();
148+
149+
if (aWidth < cWidth)
150+
a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
151+
if (bWidth < cWidth)
152+
b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
153+
154+
return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
155+
}
140156
}
141157

142158
// tosa::NegateOp
@@ -990,7 +1006,13 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
9901006
auto loc = operation->getLoc();
9911007
auto rank =
9921008
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
993-
auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
1009+
// For the mul op we need to avoid expanding the rank of the optional shift
1010+
// input.
1011+
auto operandsToExpand =
1012+
isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
1013+
1014+
auto expandedOperands =
1015+
expandInputRanks(rewriter, loc, operandsToExpand, rank);
9941016
auto [targetShape, masterOperands] =
9951017
computeTargetShape(rewriter, loc, indexPool, expandedOperands);
9961018
auto broadcastOperands = broadcastDynamicDimensions(

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,7 +1230,18 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
12301230
auto rhsAttr =
12311231
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
12321232

1233-
const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
1233+
// Result right shift on i32_t data type only. For simplification, synthesize
1234+
// a zero shift for other data type.
1235+
int32_t shift = 0;
1236+
if (resultETy.isInteger(32)) {
1237+
ElementsAttr shift_elem;
1238+
if (getShift().getImpl()) {
1239+
if (!matchPattern(getShift(), m_Constant(&shift_elem)))
1240+
// cannot be folded when the shift value is unknown.
1241+
return {};
1242+
shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1243+
}
1244+
}
12341245

12351246
if (rhsTy == resultTy) {
12361247
if (isSplatZero(resultETy, lhsAttr))
@@ -1245,7 +1256,7 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
12451256
return lhs;
12461257
}
12471258

1248-
return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
1259+
return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
12491260
}
12501261

12511262
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {

0 commit comments

Comments
 (0)