@@ -714,7 +714,7 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     // Rank 0 item op prop
-    if (selfTy.getSizes().size() == 0) {
+    if (selfTy.getSizes().empty()) {
       auto numToTensor = self.getDefiningOp<Torch::PrimNumToTensorScalarOp>();
       auto squeezeDim = self.getDefiningOp<AtenSqueezeDimOp>();
       if (!squeezeDim && !numToTensor)
@@ -746,6 +746,109 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
 };
 } // namespace
 
+namespace {
+
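+// Converts each scalarized element (a constant Attribute or an SSA Value)
+// from `inputDtype` to `resultDtype`. Only integer and float dtypes are
+// supported; the match fails for anything else.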
+LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b,
+                                   SmallVector<OpFoldResult> &converted,
+                                   SmallVector<OpFoldResult> &elements,
+                                   Type inputDtype, Type resultDtype) {
+  auto inputIsInt = dyn_cast<mlir::IntegerType>(inputDtype);
+  auto resultIsInt = dyn_cast<mlir::IntegerType>(resultDtype);
+  if (!inputIsInt && !isa<mlir::FloatType>(inputDtype))
+    return failure();
+  if (!resultIsInt && !isa<mlir::FloatType>(resultDtype))
+    return failure();
+
+  // if dtypes are both int or both float, no conversion needed
+  if (static_cast<bool>(inputIsInt) == static_cast<bool>(resultIsInt)) {
+    converted = elements;
+    return success();
+  }
+
+  if (resultIsInt) {
+    for (auto &e : elements) {
+      auto eValue = dyn_cast<Value>(e);
+      if (eValue) {
+        converted.push_back(b.createOrFold<AtenIntScalarOp>(eValue));
+        continue;
+      }
+      auto eAttr = dyn_cast<Attribute>(e);
+      auto eFloatAttr = dyn_cast_or_null<FloatAttr>(eAttr);
+      if (!eFloatAttr)
+        return failure();
+
+      converted.push_back(IntegerAttr::get(
+          resultDtype, static_cast<int64_t>(eFloatAttr.getValueAsDouble())));
+    }
+    return success();
+  }
+
+  // result is float
+  for (auto &e : elements) {
+    auto eValue = dyn_cast<Value>(e);
+    if (eValue) {
+      converted.push_back(b.createOrFold<AtenFloatScalarOp>(eValue));
+      continue;
+    }
+    auto eAttr = dyn_cast<Attribute>(e);
+    auto eIntAttr = dyn_cast<IntegerAttr>(eAttr);
+    if (!eIntAttr)
+      return failure();
+
+    auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue()
+                                        : eIntAttr.getValue().getZExtValue();
+    converted.push_back(FloatAttr::get(resultDtype, static_cast<double>(eInt)));
+  }
+  return success();
+}
+
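+// Scalarizes aten.to.dtype when the input tensor's elements can be gathered
+// as a scalar list: the elements are converted to the result dtype and the
+// result is rebuilt with an aten.tensor op.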
+class PropagateAtenToDtypePattern : public OpRewritePattern<AtenToDtypeOp> {
+public:
+  using OpRewritePattern<AtenToDtypeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenToDtypeOp op,
+                                PatternRewriter &rewriter) const override {
+    bool nonBlocking, copyArg;
+    // The non_blocking arg must be `False`.
+    if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
+        nonBlocking)
+      return failure();
+    // The copy arg must be `False`.
+    if (!matchPattern(op.getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
+      return failure();
+    // The memory_format arg must be `none`.
+    if (!isa<Torch::NoneType>(op.getMemoryFormat().getType()))
+      return failure();
+
+    auto inputType = dyn_cast<ValueTensorType>(op.getSelf().getType());
+    auto resultType = dyn_cast<ValueTensorType>(op.getType());
+    if (!inputType || !resultType || !inputType.hasDtype() ||
+        !resultType.hasDtype())
+      return failure();
+    auto inputDtype = inputType.getDtype();
+    auto resultDtype = resultType.getDtype();
+
+    SmallVector<OpFoldResult> elements;
+    if (failed(getListFromTensor(op.getSelf(), elements)))
+      return failure();
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<OpFoldResult> converted;
+    if (failed(convertOpFoldResults(b, converted, elements, inputDtype,
+                                    resultDtype)))
+      return rewriter.notifyMatchFailure(
+          op, "Unhandled attribute type encountered.");
+
+    SmallVector<Value> vals;
+    if (failed(materializeFolds(b, converted, vals)))
+      return failure();
+
+    Value result = constructAtenTensorOpFromList(b, op.getType(), vals);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 template <typename AtenViewLikeOp>
 class PropagateAtenViewLikePattern : public OpRewritePattern<AtenViewLikeOp> {
@@ -828,7 +931,7 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
     if (failed(materializeFolds(b, resultFolds, resultVals)))
       return failure();
 
-    if (resultTy.getSizes().size() == 0) {
+    if (resultTy.getSizes().empty()) {
       rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
           op, resultTy, resultVals.front());
       return success();
@@ -841,6 +944,48 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
 };
 } // namespace
 
+namespace {
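+// Scalarizes elementwise unary ops on int tensors of rank 0 or 1 by applying
+// the corresponding scalar op (e.g. aten.neg -> aten.neg.int) to each element.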
+template <typename OpTy, typename ScalarOpTy>
+class PropagateAtenUnaryPattern : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Check type
+    auto resultTy = cast<ValueTensorType>(op.getType());
+    if (resultTy.getSizes().size() > 1)
+      return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
+    if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
+      return rewriter.notifyMatchFailure(op, "not an int type");
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<OpFoldResult> selfFold;
+    if (failed(getListFromTensor(op.getSelf(), selfFold)))
+      return failure();
+    SmallVector<Value> selfVals;
+    if (failed(materializeFolds(b, selfFold, selfVals)))
+      return failure();
+    SmallVector<OpFoldResult> resultFolds;
+    for (uint64_t i = 0; i < selfVals.size(); i++) {
+      resultFolds.push_back(
+          b.createOrFold<ScalarOpTy>(selfVals[i].getType(), selfVals[i]));
+    }
+    SmallVector<Value> resultVals;
+    if (failed(materializeFolds(b, resultFolds, resultVals)))
+      return failure();
+
+    if (resultTy.getSizes().size() == 0) {
+      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
+          op, resultTy, resultVals.front());
+      return success();
+    }
+
+    Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
 /// ------ Fold Patterns ------ ///
 // These are shape-specific folding patterns
 
@@ -915,19 +1060,22 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
     auto resultTy = cast<BaseTensorType>(op.getType());
     if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
       return rewriter.notifyMatchFailure(op, "dynamic output shape");
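+    // Rank-0 case: replace the splat with a direct scalar-to-tensor op.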
+    if (resultTy.getSizes().size() == 0) {
+      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
+          op, op.getType(), elements.front());
+      return success();
+    }
 
     auto loc = op.getLoc();
     SmallVector<Value> sizes;
     for (auto size : resultTy.getSizes())
       sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr(size)));
 
-    Value one = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getType<Torch::IntType>(), 1);
     Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
         loc,
         rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
-        one);
+        sizes);
 
     Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
     Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
@@ -1031,6 +1179,24 @@ class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
 };
 } // namespace
 
+namespace {
+// fold ridiculous patterns like size.int -> float.scalar -> int.scalar
+class FoldAtenIntScalarPattern : public OpRewritePattern<AtenIntScalarOp> {
+public:
+  using OpRewritePattern<AtenIntScalarOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenIntScalarOp op,
+                                PatternRewriter &rewriter) const override {
+    auto floatScalarOp = op.getA().getDefiningOp<AtenFloatScalarOp>();
+    if (!floatScalarOp)
+      return failure();
+    auto sizeOp = floatScalarOp.getA().getDefiningOp<AtenSizeIntOp>();
+    if (!sizeOp)
+      return failure();
+    rewriter.replaceOp(op, floatScalarOp.getA());
+    return success();
+  }
+};
+} // namespace
 namespace {
 class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
 public:
@@ -1182,8 +1348,29 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
     if (inputUnmatched == 1 && outputUnmatched > 1) {
       Value dimVal =
           rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
-      ArrayRef<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
-                                     viewSizes.end() - rightMatchEnd);
+      SmallVector<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
+                                        viewSizes.end() - rightMatchEnd);
+      // try to convert a single dynamic size input to -1
+      int64_t dynCount = 0;
+      int64_t dynIdx = 0;
+      for (auto [i, v] : llvm::enumerate(unflattenSizes)) {
+        int64_t szeInt;
+        if (!matchPattern(v, m_TorchConstantInt(&szeInt))) {
+          dynCount++;
+          dynIdx = i;
+          continue;
+        }
+        // if we have a -1 already, make dynCount invalid and break
+        if (szeInt == -1) {
+          dynCount = -1;
+          break;
+        }
+      }
+      // if only one size is dynamic, make it -1
+      if (dynCount == 1)
+        unflattenSizes[dynIdx] =
+            rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
+
       Value unflattenList = rewriter.create<Torch::PrimListConstructOp>(
           op.getLoc(), op.getSize().getType(), unflattenSizes);
       rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
@@ -1227,6 +1414,18 @@ template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
 
 namespace {
 
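+// Returns true if `op` is an aten.item op whose result is consumed by an
+// aten.slice.Tensor op.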
+bool isItemForSliceOp(Operation *op) {
+  auto itemOp = dyn_cast_or_null<AtenItemOp>(op);
+  if (!itemOp)
+    return false;
+  for (OpOperand &use : op->getUses()) {
+    Operation *userOp = use.getOwner();
+    if (isa<AtenSliceTensorOp>(userOp))
+      return true;
+  }
+  return false;
+}
+
 bool isSourceOpForShapeScalarization(Operation *op) {
   return llvm::isa<AtenSizeIntOp, Torch::ConstantIntOp, Torch::ConstantBoolOp,
                    Aten_ShapeAsTensorOp, Torch::ValueTensorLiteralOp>(op);
@@ -1244,7 +1443,7 @@ bool isPrimListOfInts(Operation *op) {
 
 bool isAnchorOp(Operation *op) {
   return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
-         isPrimListOfInts(op);
+         isPrimListOfInts(op) || isItemForSliceOp(op);
 }
 
 // The argument to this function, op, is the use of some source op, srcOp. If
@@ -1278,9 +1477,9 @@ bool isInvalidValidViewConsumer(Operation *op,
 void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
   patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
                   FoldAtenSqueezePattern<AtenSqueezeDimOp>,
-                  FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
-                  FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
-      patterns.getContext());
+                  FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern,
+                  FoldAtenWhereSelf, FoldAtenTensorSplatPattern,
+                  FoldAtenEqIntPattern>(patterns.getContext());
 }
 
 void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
@@ -1303,24 +1502,29 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
       PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
       PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
       PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
-      PropagateAtenTransposeIntPattern,
+      PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern,
+      PropagateAtenUnaryPattern<AtenNegOp, AtenNegIntOp>,
       PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
       PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
      PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
+      PropagateAtenArithmeticPattern<AtenRemainderTensorOp, AtenRemainderIntOp>,
       PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
       patterns.getContext());
 }
 
 void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
   patterns.insert<RemoveUnusedPattern<Torch::AtenIntBoolOp>,
                   RemoveUnusedPattern<Torch::AtenEqIntOp>,
+                  RemoveUnusedPattern<Torch::AtenToDtypeOp>,
                   RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
                   RemoveUnusedPattern<Torch::AtenFullOp>,
                   RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
                   RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
                   RemoveUnusedPattern<Torch::AtenSizeIntOp>,
                   RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
                   RemoveUnusedPattern<Torch::AtenTensorOp>,
+                  RemoveUnusedPattern<Torch::AtenFloatScalarOp>,
+                  RemoveUnusedPattern<Torch::AtenIntScalarOp>,
                   RemoveUnusedPattern<Torch::PrimListConstructOp>>(
       patterns.getContext());
 }