Skip to content

Commit 92d0f04

Browse files
[TOSA] Add logit, log1p, log10 and add promote type to unary fponly ops (llvm#3900)
* Add Torch to TOSA legalization for the following ops: - torch.aten.logit - torch.aten.log1p - torch.aten.log10 * Add promote to FP to FP-only TOSA ops like log and exp * Update xfail with new e2e results * Add new LIT tests to basic.mlir Change-Id: I1cd7ec6964373dbaf08d419a806b3d735b830655 Signed-off-by: Justin Ngo <[email protected]>
1 parent 8711d3e commit 92d0f04

File tree

3 files changed

+387
-35
lines changed

3 files changed

+387
-35
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

+200-24
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ using namespace mlir::torch::Torch;
3434

3535
namespace {
3636

37-
// These legalizations are for unary ops with only for floating point datatypes.
38-
// There is no supported quantized integer mode for these.
37+
// These legalizations are for unary ops with promoting input to floating-point
38+
// datatypes only. There is no supported quantized integer mode for these.
3939
template <typename AtenOpT, typename TosaOpT>
40-
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
40+
class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern<AtenOpT> {
4141
public:
4242
using OpConversionPattern<AtenOpT>::OpConversionPattern;
4343
using OpAdaptor = typename AtenOpT::Adaptor;
@@ -51,17 +51,22 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
5151
return rewriter.notifyMatchFailure(op,
5252
"Only Tensor types supported in TOSA");
5353

54-
if (isa<mlir::FloatType>(selfTy.getElementType())) {
55-
rewriter.replaceOpWithNewOp<TosaOpT>(
56-
op,
57-
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
58-
op.getType()),
59-
self);
60-
return success();
61-
} else {
54+
auto resultTy = dyn_cast<TensorType>(
55+
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
56+
op.getType()));
57+
58+
if (!isa<mlir::FloatType>(resultTy.getElementType()))
6259
return rewriter.notifyMatchFailure(
63-
op, "Only floating-point datatype legalization supported");
64-
}
60+
op, "Only floating-point datatype result types are supported");
61+
62+
// Non floating point inputs are not supported in TOSA so we cast the input
63+
// to result type
64+
if (!isa<mlir::FloatType>(selfTy.getElementType()))
65+
self = tosa::promoteType(rewriter, self, resultTy);
66+
67+
rewriter.replaceOpWithNewOp<TosaOpT>(op, resultTy, self);
68+
69+
return success();
6570
}
6671
};
6772

@@ -2922,24 +2927,32 @@ template <>
29222927
LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
29232928
AtenLog2Op op, OpAdaptor adaptor,
29242929
ConversionPatternRewriter &rewriter) const {
2930+
auto self = adaptor.getSelf();
29252931

29262932
// Not a tensor type.
2927-
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
2933+
auto selfType = dyn_cast<TensorType>(self.getType());
29282934
if (!selfType)
29292935
return rewriter.notifyMatchFailure(
29302936
op, "Only tensor types are currently supported");
29312937

2938+
auto outType =
2939+
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
2940+
2941+
// If input is not a float type then cast it to output type
2942+
auto selfElemTy = selfType.getElementType();
2943+
if (!isa<mlir::FloatType>(selfElemTy))
2944+
self = tosa::promoteType(rewriter, self, outType);
2945+
29322946
// Constant value of ln2.
29332947
SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
29342948
auto ln2Op = tosa::getConstTensor<float>(rewriter, op, {0.69314718056f},
2935-
ln2Shape, selfType.getElementType())
2949+
ln2Shape, outType.getElementType())
29362950
.value();
2951+
29372952
auto rcpOp =
29382953
rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);
29392954

2940-
auto outType = getTypeConverter()->convertType(op.getType());
2941-
auto logOp =
2942-
rewriter.create<tosa::LogOp>(op.getLoc(), outType, adaptor.getSelf());
2955+
auto logOp = rewriter.create<tosa::LogOp>(op.getLoc(), outType, self);
29432956
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, outType, logOp, rcpOp,
29442957
/*shift=*/0);
29452958

@@ -8025,6 +8038,166 @@ class ConvertUpsampleNearest2dForward : public OpConversionPattern<AtenOpT> {
80258038
}
80268039
};
80278040

8041+
// Legalization for aten.logit
8042+
template <>
8043+
LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
8044+
AtenLogitOp op, OpAdaptor adaptor,
8045+
ConversionPatternRewriter &rewriter) const {
8046+
// Logit formula:
8047+
// result = log(zi / (1 - zi))
8048+
// Where: if eps is not None:
8049+
// zi = input clampled to [eps, 1 - eps]
8050+
// else:
8051+
// zi = input
8052+
auto self = adaptor.getSelf();
8053+
8054+
auto selfType = dyn_cast<TensorType>(self.getType());
8055+
if (!selfType)
8056+
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
8057+
8058+
auto resultType =
8059+
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
8060+
auto resultElemTy = resultType.getElementType();
8061+
8062+
if (!isa<mlir::FloatType>(resultElemTy))
8063+
return rewriter.notifyMatchFailure(
8064+
op, "Only floating-point datatype result types are supported");
8065+
8066+
// If input is not a float type then cast it to result element type
8067+
auto selfElemTy = selfType.getElementType();
8068+
if (!isa<mlir::FloatType>(selfElemTy))
8069+
self = tosa::promoteType(rewriter, self, resultType);
8070+
8071+
bool isEpsNone = isa<Torch::NoneType>(op.getEps().getType());
8072+
8073+
double eps;
8074+
if (!isEpsNone && !matchPattern(op.getEps(), m_TorchConstantFloat(&eps)))
8075+
return rewriter.notifyMatchFailure(op,
8076+
"Non-const eps value is not supported");
8077+
8078+
auto zi = self;
8079+
8080+
// Clamp input to [eps, 1 - eps] when eps is not None
8081+
if (!isEpsNone) {
8082+
zi = rewriter
8083+
.create<tosa::ClampOp>(
8084+
op->getLoc(), resultType, self,
8085+
rewriter.getI64IntegerAttr(static_cast<int64_t>(eps)),
8086+
rewriter.getI64IntegerAttr(static_cast<int64_t>(1 - eps)),
8087+
rewriter.getF32FloatAttr(static_cast<float>(eps)),
8088+
rewriter.getF32FloatAttr(static_cast<float>(1 - eps)))
8089+
.getResult();
8090+
}
8091+
8092+
auto one =
8093+
tosa::getConstTensor<float>(rewriter, op, 1.0f, {}, resultElemTy).value();
8094+
8095+
auto oneMinusZi =
8096+
rewriter.create<tosa::SubOp>(op->getLoc(), resultType, one, zi);
8097+
8098+
auto oneMinusZiReciprocal = rewriter.create<tosa::ReciprocalOp>(
8099+
op->getLoc(), resultType, oneMinusZi.getResult());
8100+
8101+
auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), resultType, zi,
8102+
oneMinusZiReciprocal.getResult(),
8103+
/*shift=*/0);
8104+
8105+
auto result =
8106+
rewriter.create<tosa::LogOp>(op->getLoc(), resultType, mulOp.getResult());
8107+
8108+
rewriter.replaceOp(op, {result.getResult()});
8109+
8110+
return success();
8111+
}
8112+
8113+
// Legalization for aten.log1p
8114+
template <>
8115+
LogicalResult ConvertAtenOp<AtenLog1pOp>::matchAndRewrite(
8116+
AtenLog1pOp op, OpAdaptor adaptor,
8117+
ConversionPatternRewriter &rewriter) const {
8118+
// log1p formula:
8119+
// yi = log(xi + 1)
8120+
auto self = adaptor.getSelf();
8121+
8122+
auto selfType = dyn_cast<TensorType>(self.getType());
8123+
if (!selfType)
8124+
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
8125+
8126+
auto resultType =
8127+
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
8128+
auto resultElemTy = resultType.getElementType();
8129+
8130+
if (!isa<mlir::FloatType>(resultElemTy))
8131+
return rewriter.notifyMatchFailure(
8132+
op, "Only floating-point datatype result types are supported");
8133+
8134+
// If input is not a float type then cast it to result element type
8135+
auto selfElemTy = selfType.getElementType();
8136+
if (!isa<mlir::FloatType>(selfElemTy))
8137+
self = tosa::promoteType(rewriter, self, resultType);
8138+
8139+
auto one =
8140+
tosa::getConstTensor<float>(rewriter, op, 1.0f, {}, resultElemTy).value();
8141+
8142+
auto addOp =
8143+
rewriter.create<tosa::AddOp>(op->getLoc(), resultType, self, one);
8144+
8145+
auto result =
8146+
rewriter.create<tosa::LogOp>(op->getLoc(), resultType, addOp.getResult());
8147+
8148+
rewriter.replaceOp(op, {result.getResult()});
8149+
8150+
return success();
8151+
}
8152+
8153+
// Legalization for aten.log10
8154+
template <>
8155+
LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
8156+
AtenLog10Op op, OpAdaptor adaptor,
8157+
ConversionPatternRewriter &rewriter) const {
8158+
// log10 formula (using log base changing formula since TOSA doesn't have a
8159+
// builtin log10 op):
8160+
// yi = log(xi) / log(10)
8161+
auto self = adaptor.getSelf();
8162+
8163+
auto selfType = dyn_cast<TensorType>(self.getType());
8164+
if (!selfType)
8165+
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
8166+
8167+
auto resultType =
8168+
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
8169+
auto resultElemTy = resultType.getElementType();
8170+
8171+
if (!isa<mlir::FloatType>(resultElemTy))
8172+
return rewriter.notifyMatchFailure(
8173+
op, "Only floating-point datatype result types are supported");
8174+
8175+
// If input is not a float type then cast it to result element type
8176+
auto selfElemTy = selfType.getElementType();
8177+
if (!isa<mlir::FloatType>(selfElemTy))
8178+
self = tosa::promoteType(rewriter, self, resultType);
8179+
8180+
auto ten = tosa::getConstTensor<float>(rewriter, op, 10.0f, {}, resultElemTy)
8181+
.value();
8182+
8183+
auto logOfSelf = rewriter.create<tosa::LogOp>(op->getLoc(), resultType, self);
8184+
8185+
auto constType = RankedTensorType::get({}, resultElemTy);
8186+
8187+
auto logOfTen = rewriter.create<tosa::LogOp>(op->getLoc(), constType, ten);
8188+
8189+
auto reciprocalOp = rewriter.create<tosa::ReciprocalOp>(
8190+
op->getLoc(), constType, logOfTen.getResult());
8191+
8192+
auto result = rewriter.create<tosa::MulOp>(
8193+
op->getLoc(), resultType, logOfSelf.getResult(), reciprocalOp.getResult(),
8194+
/*shift=*/0);
8195+
8196+
rewriter.replaceOp(op, {result.getResult()});
8197+
8198+
return success();
8199+
}
8200+
80288201
} // namespace
80298202

80308203
// -----------------------------------------------------------------------------
@@ -8069,13 +8242,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
80698242

80708243
RewritePatternSet patterns(context);
80718244

8072-
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \
8245+
#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, TosaOp) \
80738246
target.addIllegalOp<AtenOp>(); \
8074-
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter, \
8075-
context);
8076-
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp)
8077-
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp)
8078-
#undef INSERT_UNARY_FPONLY_PATTERN
8247+
patterns.add<ConvertAtenUnaryPromoteToFPOp<AtenOp, TosaOp>>(typeConverter, \
8248+
context);
8249+
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp)
8250+
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp)
8251+
#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN
80798252

80808253
#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \
80818254
target.addIllegalOp<AtenOp>(); \
@@ -8364,6 +8537,9 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
83648537
INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp);
83658538
INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
83668539
INSERT_ATENOP_PATTERN(AtenOuterOp);
8540+
INSERT_ATENOP_PATTERN(AtenLogitOp);
8541+
INSERT_ATENOP_PATTERN(AtenLog1pOp);
8542+
INSERT_ATENOP_PATTERN(AtenLog10Op);
83678543
#undef INSERT_ATENOP_PATTERN
83688544

83698545
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

projects/pt1/e2e_testing/xfail_sets.py

+24-11
Original file line numberDiff line numberDiff line change
@@ -1729,6 +1729,22 @@
17291729
"RandIntPinMemoryModule_basic",
17301730
"RenormModuleFloat16_basic",
17311731
"SplitDimStaticModule_basic",
1732+
"Deg2radModule_basic",
1733+
"ElementwiseExpIntModule_basic",
1734+
"ElementwiseLog10IntModule_basic",
1735+
"ElementwiseLog10Module_basic",
1736+
"ElementwiseLog1pModule_basic",
1737+
"ElementwiseLog2IntModule_basic",
1738+
"ElementwiseLogIntModule_basic",
1739+
"ElementwiseLogitModule_basic",
1740+
"ElementwiseMishModule_basic",
1741+
"L1LossMeanReductionModule_basic",
1742+
"L1LossNoReductionModule_basic",
1743+
"L1LossSumReductionModule_basic",
1744+
"RandIntLowModule_basic",
1745+
"RandIntModule_basic",
1746+
"RandIntPinMemoryModule_basic",
1747+
"SoftplusModule_basic",
17321748
"ReflectionPad1dModule2dInput_Right",
17331749
"ReflectionPad1dModule2dInput_basic",
17341750
"ReflectionPad1dModule3dInput_Left",
@@ -3416,6 +3432,8 @@
34163432
}
34173433

34183434
FX_IMPORTER_TOSA_XFAIL_SET = {
3435+
"AtenFftRfft2DLastDim_basic",
3436+
"AtenFftRfft2DMiddleDim_basic",
34193437
"IsInfiniteModule_basic",
34203438
"LayerNormFwAndBwModule_basic",
34213439
"LayerNormManualFwAndBwModule_basic",
@@ -3627,17 +3645,9 @@
36273645
"ElementwiseDequantizePerChannelModule_basic",
36283646
"ElementwiseDequantizePerTensorModule_basic",
36293647
"ElementwiseErfIntModule_basic",
3630-
"ElementwiseExpIntModule_basic",
36313648
"ElementwiseExpm1IntModule_basic",
36323649
"ElementwiseExpm1Module_basic",
36333650
"ElementwiseIntTensorLtFloatScalarModule_basic",
3634-
"ElementwiseLog10IntModule_basic",
3635-
"ElementwiseLog10Module_basic",
3636-
"ElementwiseLog1pModule_basic",
3637-
"ElementwiseLog2IntModule_basic",
3638-
"ElementwiseLogIntModule_basic",
3639-
"ElementwiseLogitModule_basic",
3640-
"ElementwiseMishModule_basic",
36413651
"ElementwiseMulTensorComplexDiffModule_basic",
36423652
"ElementwiseMulTensorComplexModule_basic",
36433653
"ElementwiseQuantizePerTensorModule_basic",
@@ -3755,6 +3765,7 @@
37553765
"NumelModule_basic",
37563766
"NumelZeroRankModule_basic",
37573767
"OnesLikeModule_falsePinMemory",
3768+
"PowIntIntModule_basic",
37583769
"PowIntFloatModule_basic",
37593770
"PrimMaxIntModule_basic",
37603771
"PrimMinIntDynamicModule_basic",
@@ -3822,7 +3833,6 @@
38223833
"SliceOutOfLowerBoundEndIndexModule_basic",
38233834
"SliceOutOfLowerBoundStartIndexModule_basic",
38243835
"SliceSizeTwoStepModule_basic",
3825-
"SoftplusModule_basic",
38263836
"SortIntListReverse_basic",
38273837
"SortIntList_basic",
38283838
"SortTensorDescending_basic",
@@ -3902,6 +3912,11 @@
39023912
}
39033913

39043914
ONNX_TOSA_XFAIL_SET = {
3915+
"AtenFftRfft2DLastDim_basic",
3916+
"AtenFftRfft2DMiddleDim_basic",
3917+
"PowFloatIntModule_basic",
3918+
"PowIntFloatModule_basic",
3919+
"PowIntIntModule_basic",
39053920
"ColumnStack0dModule_basic",
39063921
"ColumnStack1dModule_basic",
39073922
"ColumnStackBasicIntModule_basic",
@@ -4311,7 +4326,6 @@
43114326
"ElementwiseLog2IntModule_basic",
43124327
"ElementwiseLogIntModule_basic",
43134328
"ElementwiseLtDiffWidthScalarModule_basic",
4314-
"ElementwiseMishModule_basic",
43154329
"ElementwiseMulScalarModule_basic",
43164330
"ElementwiseMulTensorComplexDiffModule_basic",
43174331
"ElementwiseMulTensorComplexModule_basic",
@@ -4755,7 +4769,6 @@
47554769
"SoftmaxIntModule_basic",
47564770
"SoftmaxIntNegDimModule_basic",
47574771
"SoftmaxIntNonNoneDtypeModule_basic",
4758-
"SoftplusModule_basic",
47594772
"SortIntListReverse_basic",
47604773
"SortIntList_basic",
47614774
"SortTensorDescending_basic",

0 commit comments

Comments
 (0)