@@ -34,10 +34,10 @@ using namespace mlir::torch::Torch;
34
34
35
35
namespace {
36
36
37
- // These legalizations are for unary ops with only for floating point datatypes.
38
- // There is no supported quantized integer mode for these.
37
+ // These legalizations are for unary ops that promote the input to floating-point
38
+ // datatypes only. There is no supported quantized integer mode for these.
39
39
template <typename AtenOpT, typename TosaOpT>
40
- class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern <AtenOpT> {
40
+ class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern <AtenOpT> {
41
41
public:
42
42
using OpConversionPattern<AtenOpT>::OpConversionPattern;
43
43
using OpAdaptor = typename AtenOpT::Adaptor;
@@ -51,17 +51,22 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
51
51
return rewriter.notifyMatchFailure (op,
52
52
" Only Tensor types supported in TOSA" );
53
53
54
- if (isa<mlir::FloatType>(selfTy.getElementType ())) {
55
- rewriter.replaceOpWithNewOp <TosaOpT>(
56
- op,
57
- OpConversionPattern<AtenOpT>::getTypeConverter ()->convertType (
58
- op.getType ()),
59
- self);
60
- return success ();
61
- } else {
54
+ auto resultTy = dyn_cast<TensorType>(
55
+ OpConversionPattern<AtenOpT>::getTypeConverter ()->convertType (
56
+ op.getType ()));
57
+
58
+ if (!isa<mlir::FloatType>(resultTy.getElementType ()))
62
59
return rewriter.notifyMatchFailure (
63
- op, " Only floating-point datatype legalization supported" );
64
- }
60
+ op, " Only floating-point datatype result types are supported" );
61
+
62
+ // Non floating point inputs are not supported in TOSA so we cast the input
63
+ // to result type
64
+ if (!isa<mlir::FloatType>(selfTy.getElementType ()))
65
+ self = tosa::promoteType (rewriter, self, resultTy);
66
+
67
+ rewriter.replaceOpWithNewOp <TosaOpT>(op, resultTy, self);
68
+
69
+ return success ();
65
70
}
66
71
};
67
72
@@ -2922,24 +2927,32 @@ template <>
2922
2927
LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
2923
2928
AtenLog2Op op, OpAdaptor adaptor,
2924
2929
ConversionPatternRewriter &rewriter) const {
2930
+ auto self = adaptor.getSelf ();
2925
2931
2926
2932
// Not a tensor type.
2927
- auto selfType = dyn_cast<TensorType>(adaptor. getSelf () .getType ());
2933
+ auto selfType = dyn_cast<TensorType>(self .getType ());
2928
2934
if (!selfType)
2929
2935
return rewriter.notifyMatchFailure (
2930
2936
op, " Only tensor types are currently supported" );
2931
2937
2938
+ auto outType =
2939
+ dyn_cast<TensorType>(getTypeConverter ()->convertType (op.getType ()));
2940
+
2941
+ // If input is not a float type then cast it to output type
2942
+ auto selfElemTy = selfType.getElementType ();
2943
+ if (!isa<mlir::FloatType>(selfElemTy))
2944
+ self = tosa::promoteType (rewriter, self, outType);
2945
+
2932
2946
// Constant value of ln2.
2933
2947
SmallVector<int64_t > ln2Shape (selfType.getRank (), 1 );
2934
2948
auto ln2Op = tosa::getConstTensor<float >(rewriter, op, {0 .69314718056f },
2935
- ln2Shape, selfType .getElementType ())
2949
+ ln2Shape, outType .getElementType ())
2936
2950
.value ();
2951
+
2937
2952
auto rcpOp =
2938
2953
rewriter.create <tosa::ReciprocalOp>(op.getLoc (), ln2Op.getType (), ln2Op);
2939
2954
2940
- auto outType = getTypeConverter ()->convertType (op.getType ());
2941
- auto logOp =
2942
- rewriter.create <tosa::LogOp>(op.getLoc (), outType, adaptor.getSelf ());
2955
+ auto logOp = rewriter.create <tosa::LogOp>(op.getLoc (), outType, self);
2943
2956
rewriter.replaceOpWithNewOp <tosa::MulOp>(op, outType, logOp, rcpOp,
2944
2957
/* shift=*/ 0 );
2945
2958
@@ -8025,6 +8038,166 @@ class ConvertUpsampleNearest2dForward : public OpConversionPattern<AtenOpT> {
8025
8038
}
8026
8039
};
8027
8040
8041
+ // Legalization for aten.logit
8042
+ template <>
8043
+ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
8044
+ AtenLogitOp op, OpAdaptor adaptor,
8045
+ ConversionPatternRewriter &rewriter) const {
8046
+ // Logit formula:
8047
+ // result = log(zi / (1 - zi))
8048
+ // Where: if eps is not None:
8049
+ // zi = input clampled to [eps, 1 - eps]
8050
+ // else:
8051
+ // zi = input
8052
+ auto self = adaptor.getSelf ();
8053
+
8054
+ auto selfType = dyn_cast<TensorType>(self.getType ());
8055
+ if (!selfType)
8056
+ return rewriter.notifyMatchFailure (op, " Only tensor types are supported" );
8057
+
8058
+ auto resultType =
8059
+ dyn_cast<TensorType>(typeConverter->convertType (op.getType ()));
8060
+ auto resultElemTy = resultType.getElementType ();
8061
+
8062
+ if (!isa<mlir::FloatType>(resultElemTy))
8063
+ return rewriter.notifyMatchFailure (
8064
+ op, " Only floating-point datatype result types are supported" );
8065
+
8066
+ // If input is not a float type then cast it to result element type
8067
+ auto selfElemTy = selfType.getElementType ();
8068
+ if (!isa<mlir::FloatType>(selfElemTy))
8069
+ self = tosa::promoteType (rewriter, self, resultType);
8070
+
8071
+ bool isEpsNone = isa<Torch::NoneType>(op.getEps ().getType ());
8072
+
8073
+ double eps;
8074
+ if (!isEpsNone && !matchPattern (op.getEps (), m_TorchConstantFloat (&eps)))
8075
+ return rewriter.notifyMatchFailure (op,
8076
+ " Non-const eps value is not supported" );
8077
+
8078
+ auto zi = self;
8079
+
8080
+ // Clamp input to [eps, 1 - eps] when eps is not None
8081
+ if (!isEpsNone) {
8082
+ zi = rewriter
8083
+ .create <tosa::ClampOp>(
8084
+ op->getLoc (), resultType, self,
8085
+ rewriter.getI64IntegerAttr (static_cast <int64_t >(eps)),
8086
+ rewriter.getI64IntegerAttr (static_cast <int64_t >(1 - eps)),
8087
+ rewriter.getF32FloatAttr (static_cast <float >(eps)),
8088
+ rewriter.getF32FloatAttr (static_cast <float >(1 - eps)))
8089
+ .getResult ();
8090
+ }
8091
+
8092
+ auto one =
8093
+ tosa::getConstTensor<float >(rewriter, op, 1 .0f , {}, resultElemTy).value ();
8094
+
8095
+ auto oneMinusZi =
8096
+ rewriter.create <tosa::SubOp>(op->getLoc (), resultType, one, zi);
8097
+
8098
+ auto oneMinusZiReciprocal = rewriter.create <tosa::ReciprocalOp>(
8099
+ op->getLoc (), resultType, oneMinusZi.getResult ());
8100
+
8101
+ auto mulOp = rewriter.create <tosa::MulOp>(op->getLoc (), resultType, zi,
8102
+ oneMinusZiReciprocal.getResult (),
8103
+ /* shift=*/ 0 );
8104
+
8105
+ auto result =
8106
+ rewriter.create <tosa::LogOp>(op->getLoc (), resultType, mulOp.getResult ());
8107
+
8108
+ rewriter.replaceOp (op, {result.getResult ()});
8109
+
8110
+ return success ();
8111
+ }
8112
+
8113
+ // Legalization for aten.log1p
8114
+ template <>
8115
+ LogicalResult ConvertAtenOp<AtenLog1pOp>::matchAndRewrite(
8116
+ AtenLog1pOp op, OpAdaptor adaptor,
8117
+ ConversionPatternRewriter &rewriter) const {
8118
+ // log1p formula:
8119
+ // yi = log(xi + 1)
8120
+ auto self = adaptor.getSelf ();
8121
+
8122
+ auto selfType = dyn_cast<TensorType>(self.getType ());
8123
+ if (!selfType)
8124
+ return rewriter.notifyMatchFailure (op, " Only tensor types are supported" );
8125
+
8126
+ auto resultType =
8127
+ dyn_cast<TensorType>(typeConverter->convertType (op.getType ()));
8128
+ auto resultElemTy = resultType.getElementType ();
8129
+
8130
+ if (!isa<mlir::FloatType>(resultElemTy))
8131
+ return rewriter.notifyMatchFailure (
8132
+ op, " Only floating-point datatype result types are supported" );
8133
+
8134
+ // If input is not a float type then cast it to result element type
8135
+ auto selfElemTy = selfType.getElementType ();
8136
+ if (!isa<mlir::FloatType>(selfElemTy))
8137
+ self = tosa::promoteType (rewriter, self, resultType);
8138
+
8139
+ auto one =
8140
+ tosa::getConstTensor<float >(rewriter, op, 1 .0f , {}, resultElemTy).value ();
8141
+
8142
+ auto addOp =
8143
+ rewriter.create <tosa::AddOp>(op->getLoc (), resultType, self, one);
8144
+
8145
+ auto result =
8146
+ rewriter.create <tosa::LogOp>(op->getLoc (), resultType, addOp.getResult ());
8147
+
8148
+ rewriter.replaceOp (op, {result.getResult ()});
8149
+
8150
+ return success ();
8151
+ }
8152
+
8153
+ // Legalization for aten.log10
8154
+ template <>
8155
+ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
8156
+ AtenLog10Op op, OpAdaptor adaptor,
8157
+ ConversionPatternRewriter &rewriter) const {
8158
+ // log10 formula (using log base changing formula since TOSA doesn't have a
8159
+ // builtin log10 op):
8160
+ // yi = log(xi) / log(10)
8161
+ auto self = adaptor.getSelf ();
8162
+
8163
+ auto selfType = dyn_cast<TensorType>(self.getType ());
8164
+ if (!selfType)
8165
+ return rewriter.notifyMatchFailure (op, " Only tensor types are supported" );
8166
+
8167
+ auto resultType =
8168
+ dyn_cast<TensorType>(typeConverter->convertType (op.getType ()));
8169
+ auto resultElemTy = resultType.getElementType ();
8170
+
8171
+ if (!isa<mlir::FloatType>(resultElemTy))
8172
+ return rewriter.notifyMatchFailure (
8173
+ op, " Only floating-point datatype result types are supported" );
8174
+
8175
+ // If input is not a float type then cast it to result element type
8176
+ auto selfElemTy = selfType.getElementType ();
8177
+ if (!isa<mlir::FloatType>(selfElemTy))
8178
+ self = tosa::promoteType (rewriter, self, resultType);
8179
+
8180
+ auto ten = tosa::getConstTensor<float >(rewriter, op, 10 .0f , {}, resultElemTy)
8181
+ .value ();
8182
+
8183
+ auto logOfSelf = rewriter.create <tosa::LogOp>(op->getLoc (), resultType, self);
8184
+
8185
+ auto constType = RankedTensorType::get ({}, resultElemTy);
8186
+
8187
+ auto logOfTen = rewriter.create <tosa::LogOp>(op->getLoc (), constType, ten);
8188
+
8189
+ auto reciprocalOp = rewriter.create <tosa::ReciprocalOp>(
8190
+ op->getLoc (), constType, logOfTen.getResult ());
8191
+
8192
+ auto result = rewriter.create <tosa::MulOp>(
8193
+ op->getLoc (), resultType, logOfSelf.getResult (), reciprocalOp.getResult (),
8194
+ /* shift=*/ 0 );
8195
+
8196
+ rewriter.replaceOp (op, {result.getResult ()});
8197
+
8198
+ return success ();
8199
+ }
8200
+
8028
8201
} // namespace
8029
8202
8030
8203
// -----------------------------------------------------------------------------
@@ -8069,13 +8242,13 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
8069
8242
8070
8243
RewritePatternSet patterns (context);
8071
8244
8072
- #define INSERT_UNARY_FPONLY_PATTERN (AtenOp, TosaOp ) \
8245
+ #define INSERT_UNARY_PROMOTE_TO_FP_PATTERN (AtenOp, TosaOp ) \
8073
8246
target.addIllegalOp <AtenOp>(); \
8074
- patterns.add <ConvertAtenUnaryFPOnlyOp <AtenOp, TosaOp>>(typeConverter, \
8075
- context);
8076
- INSERT_UNARY_FPONLY_PATTERN (AtenLogOp, tosa::LogOp)
8077
- INSERT_UNARY_FPONLY_PATTERN (AtenExpOp, tosa::ExpOp)
8078
- #undef INSERT_UNARY_FPONLY_PATTERN
8247
+ patterns.add <ConvertAtenUnaryPromoteToFPOp <AtenOp, TosaOp>>(typeConverter, \
8248
+ context);
8249
+ INSERT_UNARY_PROMOTE_TO_FP_PATTERN (AtenLogOp, tosa::LogOp)
8250
+ INSERT_UNARY_PROMOTE_TO_FP_PATTERN (AtenExpOp, tosa::ExpOp)
8251
+ #undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN
8079
8252
8080
8253
#define INSERT_UNARY_PATTERN (AtenOp, TosaOp ) \
8081
8254
target.addIllegalOp <AtenOp>(); \
@@ -8364,6 +8537,9 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
8364
8537
INSERT_ATENOP_PATTERN (AtenReplicationPad2dOp);
8365
8538
INSERT_ATENOP_PATTERN (PrimsSplitDimOp);
8366
8539
INSERT_ATENOP_PATTERN (AtenOuterOp);
8540
+ INSERT_ATENOP_PATTERN (AtenLogitOp);
8541
+ INSERT_ATENOP_PATTERN (AtenLog1pOp);
8542
+ INSERT_ATENOP_PATTERN (AtenLog10Op);
8367
8543
#undef INSERT_ATENOP_PATTERN
8368
8544
8369
8545
#define INSERT_CLONE_ATENOP_PATTERN (AtenOp ) \
0 commit comments