@@ -86,42 +86,62 @@ LogicalResult getListFromTensor(Value value, SmallVector<OpFoldResult> &vals) {
                         getAsOpFoldResult(full.getFillValue()));
    return success();
  }
-  // TODO: Add a case for unsqueeze of a primnumtotensorscalarop?
+
+  if (auto unsqueeze = value.getDefiningOp<Torch::AtenUnsqueezeOp>()) {
+    Value usqSelf = unsqueeze.getSelf();
+    if (auto numToTensor =
+            usqSelf.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
+      vals.push_back(getAsOpFoldResult(numToTensor.getA()));
+      return success();
+    }
+  }
+
+  // A common rank 0 tensor producer
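+  // (torch.prim.NumToTensor.Scalar wraps a single scalar into a rank 0
+  // tensor, so the wrapped scalar is the one list element to record.)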
+  if (auto numToTensor =
+          value.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
+    vals.push_back(getAsOpFoldResult(numToTensor.getA()));
+    return success();
+  }

  // Last supported case: ValueTensorLiteralOp
  auto literalOp = value.getDefiningOp<Torch::ValueTensorLiteralOp>();
  if (!literalOp)
    return failure();

-  // Check the type. We make sure the type is not unsigned here before trying to
-  // materialize
+  // Check the type.
  auto ty = cast<ValueTensorType>(literalOp.getType());
  if (!ty.hasSizes() || ty.getSizes().size() > 1)
    return failure();
-  int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;
+  // make sure the type is not unsigned here before trying to materialize
  auto intTy = dyn_cast_or_null<IntegerType>(ty.getDtype());
  if (!intTy || intTy.isUnsigned())
    return failure();

+  // if we have a rank 0 literal, we will be adding one element to the list
+  int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;
+
+  if (listSize > kMaxFold)
+    return failure();
+
+  // check for a splat or dense attr
  auto splattr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
  auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(literalOp.getValue());

  if (!splattr && !denseAttr)
    return failure();

+  // These are not mutually exclusive, so try splat first.
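+  // (A constant whose elements are all equal matches SplatElementsAttr as
+  // well as DenseIntElementsAttr.)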
  if (splattr) {
    auto attr = splattr.getSplatValue<Attribute>();
    vals.resize((int64_t)vals.size() + listSize, attr);
+    return success();
  }

-  if (denseAttr && !splattr) {
-    for (auto e : denseAttr.getValues<Attribute>())
-      vals.push_back(e);
-  }
-
-  if ((int64_t)vals.size() != listSize)
+  // remaining case: denseAttr
+  if ((int64_t)denseAttr.getValues<Attribute>().size() != listSize)
    return failure();
-
+  for (auto e : denseAttr.getValues<Attribute>())
+    vals.push_back(e);
  return success();
}

@@ -143,6 +163,45 @@ Value constructAtenTensorOpFromList(ImplicitLocOpBuilder b, mlir::Type resultTy,
 // [scalarOpsA -> ListA -> TensorA] -> OpB. Then OpB will be able to
 // getListFromTensor(A), and further propagate scalarization.

+namespace {
+class PropagateAtenBroadcastToPattern
+    : public OpRewritePattern<AtenBroadcastToOp> {
+public:
+  using OpRewritePattern<AtenBroadcastToOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenBroadcastToOp op,
+                                PatternRewriter &rewriter) const override {
+    constexpr int64_t kMaxFold = 16;
+    // for tensor<si64>, or tensor<1xsi64>, broadcasted to tensor<nxsi64>, grab
+    // the element and convert to a full op.
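+    // e.g. broadcasting a !torch.vtensor<[],si64> to !torch.vtensor<[4],si64>
+    // becomes an aten.full of size [4] whose fill value is the extracted
+    // scalar.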
+    auto ty = cast<ValueTensorType>(op.getType());
+    if (!ty.areAllSizesKnown() || ty.getSizes().size() != 1)
+      return failure();
+
+    if (ty.getSizes()[0] > kMaxFold)
+      return failure();
+
+    SmallVector<OpFoldResult> fillFold;
+    if (failed(getListFromTensor(op.getSelf(), fillFold)) ||
+        fillFold.size() != 1)
+      return failure();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<Value, 1> fillVals;
+    if (failed(materializeFolds(b, fillFold, fillVals)))
+      return failure();
+
+    Value size = b.create<Torch::ConstantIntOp>(ty.getSizes().front());
+    Value sizeList = b.create<Torch::PrimListConstructOp>(
+        rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
+        size);
+    Value none = b.create<Torch::ConstantNoneOp>();
+    Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
+    rewriter.replaceOpWithNewOp<AtenFullOp>(op, ty, sizeList, fillVals.front(),
+                                            none, none, none, cstFalse);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class PropagateAtenShapeToTensorPattern
     : public OpRewritePattern<Aten_ShapeAsTensorOp> {
@@ -541,9 +600,128 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
 };
 } // namespace

+namespace {
+
+template <typename OpTy> struct ArithmeticHelper {
+  static LogicalResult getAlphaAndVerify(OpTy &op, int64_t &alpha) {
+    alpha = 1;
+    return success();
+  }
+};
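+
+// Ops without an alpha operand (e.g. mul/div) use the primary template above;
+// add and sub specialize below and only match when alpha == 1.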
+
+template <> struct ArithmeticHelper<AtenAddTensorOp> {
+  static LogicalResult getAlphaAndVerify(AtenAddTensorOp &op, int64_t &alpha) {
+    if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
+      return failure();
+    return success();
+  }
+};
+
+template <> struct ArithmeticHelper<AtenSubTensorOp> {
+  static LogicalResult getAlphaAndVerify(AtenSubTensorOp &op, int64_t &alpha) {
+    if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
+      return failure();
+    return success();
+  }
+};
+
+template <typename OpTy, typename ScalarOpTy>
+class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Check type
+    auto resultTy = cast<ValueTensorType>(op.getType());
+    if (resultTy.getSizes().size() > 1)
+      return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
+    if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
+      return rewriter.notifyMatchFailure(op, "not an int type");
+
+    int64_t alpha;
+    if (failed(ArithmeticHelper<OpTy>::getAlphaAndVerify(op, alpha)))
+      return rewriter.notifyMatchFailure(op, "alpha must be 1");
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<OpFoldResult> selfFold, otherFold;
+    if (failed(getListFromTensor(op.getSelf(), selfFold)) ||
+        failed(getListFromTensor(op.getOther(), otherFold)) ||
+        selfFold.size() != otherFold.size())
+      return failure();
+    SmallVector<Value> selfVals, otherVals;
+    if (failed(materializeFolds(b, selfFold, selfVals)) ||
+        failed(materializeFolds(b, otherFold, otherVals)))
+      return failure();
+    SmallVector<OpFoldResult> resultFolds;
+    for (uint64_t i = 0; i < selfVals.size(); i++) {
+      resultFolds.push_back(b.createOrFold<ScalarOpTy>(
+          selfVals[i].getType(), selfVals[i], otherVals[i]));
+    }
+    SmallVector<Value> resultVals;
+    if (failed(materializeFolds(b, resultFolds, resultVals)))
+      return failure();
+
+    if (resultTy.getSizes().size() == 0) {
+      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
+          op, resultTy, resultVals.front());
+      return success();
+    }
+
+    Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 /// ------ Fold Patterns ------ ///
 // These are shape-specific folding patterns

+namespace {
+class FoldAtenEqIntPattern : public OpRewritePattern<AtenEqIntOp> {
+public:
+  using OpRewritePattern<AtenEqIntOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenEqIntOp op,
+                                PatternRewriter &rewriter) const override {
+    // replaces (size.int == 0) with false and adds an assert
+    // these comparisons are getting generated because onnx.Reshape considers 0
+    // to mean "don't change this dim". However, if the size we are passing to
+    // onnx.Reshape is a tensor dim, this is definitely never supposed to be
+    // interpreted as "don't change this dim".
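+    // e.g. `aten.size.int(...) == 0` folds to false (guarded by a runtime
+    // assert that the size is positive), and `a * b == 0` is split into
+    // `a == 0 || b == 0` so each factor can be folded the same way.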
+    int64_t otherInt;
+    if (!matchPattern(op.getB(), m_TorchConstantInt(&otherInt)) ||
+        otherInt != 0)
+      return failure();
+
+    // in case the shape is a product of two ints, check each
+    if (auto mulOp = op.getA().getDefiningOp<AtenMulIntOp>()) {
+      Value self = mulOp.getA();
+      Value other = mulOp.getB();
+      Value selfEq = rewriter.create<AtenEqIntOp>(op.getLoc(), self, op.getB());
+      Value otherEq =
+          rewriter.create<AtenEqIntOp>(op.getLoc(), other, op.getB());
+      rewriter.replaceOpWithNewOp<Aten__Or__BoolOp>(op, selfEq, otherEq);
+      return success();
+    }
+
+    // if lhs is size.int op, assert size > 0 and replace with false.
+    if (auto sizeOp = op.getA().getDefiningOp<AtenSizeIntOp>()) {
+      Value selfGtOther = rewriter.create<AtenGtIntOp>(
+          op.getLoc(), op.getType(), op.getA(), op.getB());
+      rewriter.create<Torch::RuntimeAssertOp>(
+          op.getLoc(), selfGtOther,
+          rewriter.getStringAttr("Expected dim size > 0."));
+      Value cstFalse =
+          rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
+      rewriter.replaceOp(op, cstFalse);
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+

 namespace {
 class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
 public:
@@ -594,16 +772,24 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
 } // namespace

 namespace {
-class FoldAtenSqueezePattern : public OpRewritePattern<AtenSqueezeOp> {
+template <typename SqueezeOp>
+class FoldAtenSqueezePattern : public OpRewritePattern<SqueezeOp> {
 public:
-  using OpRewritePattern<AtenSqueezeOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenSqueezeOp op,
+  using OpRewritePattern<SqueezeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SqueezeOp op,
                                 PatternRewriter &rewriter) const override {
     auto resultTy = cast<ValueTensorType>(op.getType());
     if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
       return rewriter.notifyMatchFailure(op, "Unknown result shape");

-    if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
+    Value self = op.getSelf();
+    if (auto atenFull = self.getDefiningOp<AtenFullOp>()) {
+      // in the rank 0 case, just return the rank 0 scalar
+      if (resultTy.getSizes().size() == 0) {
+        rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
+            op, resultTy, atenFull.getFillValue());
+        return success();
+      }
       SmallVector<Value> sizes;
       for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i)
         sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
@@ -874,9 +1060,16 @@ bool isPrimListOfInts(Operation *op) {
   return llvm::isa<Torch::IntType>(listType.getContainedType());
 }

+bool isAnchorOp(Operation *op) {
+  return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
+         isPrimListOfInts(op);
+}
+
 void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
-  patterns.insert<FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
-                  FoldAtenWhereSelf, FoldAtenTensorSplatPattern>(
+  patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
+                  FoldAtenSqueezePattern<AtenSqueezeDimOp>,
+                  FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
+                  FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
       patterns.getContext());
 }

@@ -885,10 +1078,21 @@ void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
 }

 void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
-  patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
-                  PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
-                  PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
-                  PropagateAtenWhereSelfPattern>(patterns.getContext());
+  // A note on division: onnx.Div from int, int -> int types rounds towards
+  // zero. The torch DivTensorOp actually doesn't allow returning an int dtype,
+  // but this was artificially plumbed through. Unfortunately, there is no
+  // scalar trunc div op in torch; however, we can safely assume all operands
+  // are positive, so floor divide should be a sufficient scalar replacement.
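+  // (Trunc and floor division only disagree when operands have mixed signs,
+  // e.g. -7 / 2 truncates to -3 but floors to -4; for the non-negative shape
+  // values handled here the two agree.)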
+  patterns.insert<
+      PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
+      PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
+      PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
+      PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
+      PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
+      PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
+      PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
+      PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
+      patterns.getContext());
 }

 void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
@@ -940,7 +1144,7 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
           [&](Operation *op) {
             // Walking bottom-up, start adding ops when we reach an anchor point
             // (a prim list of ints)
-            if (isPrimListOfInts(op)) {
+            if (isAnchorOp(op)) {
               shapeCalculationOps.insert(op);
               return;
             }