Commit 140cad5

Add More Scalarize Shapes Patterns (llvm#3810)
### new patterns:

1. Propagates `aten.broadcast_to` ops of a single value to an `aten.full` op.
2. Propagates arithmetic operations through a templated class which associates some tensor arithmetic ops to their integer-scalar counterparts. These are a major blocker right now, since some models have a bunch of rank 0 arithmetic being done with tensor ops. See the lit test for an interesting example that pads an input to the smallest shape which will become divisible by twelve in `dim0`. If you think this is convoluted, you haven't been staring at ONNX-generated IR long enough.
3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to `false`. See the comment in that conversion pattern for more justification as to why it is acceptable to make this assumption here. This is another major blocker for models, since this lack of folding propagates to a lack of folding for subsequent `where.self` operations.
4. Adds `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern`.

### other changes:

1. Adds two new anchor ops: `AtenArangeStartStepOp` and `Torch::RuntimeAssertOp`. I've checked all possible sources of the runtime assert ops and they are always shape related. The arange op only takes int inputs, and these are all shape related. Also adds a size check when getting a list from literal ops.
2. Improves folders for int arithmetic ops to fold some common patterns.
3. Adds the ability to get some values from scalar-tensor ops to `getListFromTensor`.
4. Further cleans up `getListFromTensor` for readability.

### points to scrutinize:

1. I made the choice to scalarize `div.Tensor` (int dtype result) to `floordiv.int`. This is because our shape computations involving this kind of arithmetic are never negative in practice, and we don't have a "round towards zero" scalar int divide counterpart.
2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if someone happens to add a runtime assert in the future that doesn't boil down to shapes, then it would add to the worklist considerably. We might be able to get around this by adding "NoMemoryEffect" to ops which are "ReadOnly" so that the inputs for the runtime asserts get CSE'd with existing elements of the worklist before we even get to this pass.
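For a concrete sense of what the rank-0 arithmetic scalarization targets, here is a hand-written sketch (not the lit test from this commit; `%arg0`, the value names, and the shapes are made up) of the kind of before/after IR involved:

```mlir
// Before: a dynamic dim is wrapped into rank-0 tensors and added with tensor ops.
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int12 = torch.constant.int 12
%dim = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
%lhs = torch.prim.NumToTensor.Scalar %dim : !torch.int -> !torch.vtensor<[],si64>
%rhs = torch.prim.NumToTensor.Scalar %int12 : !torch.int -> !torch.vtensor<[],si64>
%sum = torch.aten.add.Tensor %lhs, %rhs, %int1 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
%val = torch.aten.item %sum : !torch.vtensor<[],si64> -> !torch.int

// After scalarization, the tensor add collapses to scalar integer arithmetic:
%val_scalar = torch.aten.add.int %dim, %int12 : !torch.int, !torch.int -> !torch.int
```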
1 parent a83e106 commit 140cad5

File tree: 4 files changed, +330 −24 lines

lib/Dialect/Torch/IR/TorchOps.cpp

+9
```diff
@@ -3700,6 +3700,12 @@ OpFoldResult AtenRemainderScalarOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
+  auto intLhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
+  auto intRhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
+  if (intRhs && intRhs.getValue().getSExtValue() == 0)
+    return getA();
+  if (intLhs && intLhs.getValue().getSExtValue() == 0)
+    return getB();
   return atenBinaryIntOperatorFoldHelper(
       adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; });
 }
@@ -3709,6 +3715,9 @@ OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
+  if (getA() == getB())
+    return IntegerAttr::get(
+        IntegerType::get(getContext(), 64, IntegerType::Signless), 0);
   return atenBinaryIntOperatorFoldHelper(
       adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
 }
```
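Sketched in IR terms (illustrative only, assuming a dynamically shaped input `%arg0`), the strengthened folders now simplify patterns like these:

```mlir
%int0 = torch.constant.int 0
%x = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
// x + 0 (or 0 + x) now folds directly to %x:
%a = torch.aten.add.int %x, %int0 : !torch.int, !torch.int -> !torch.int
// x - x now folds to the signless 64-bit constant 0:
%b = torch.aten.sub.int %x, %x : !torch.int, !torch.int -> !torch.int
```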

lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp

+226 −22
```diff
@@ -86,42 +86,62 @@ LogicalResult getListFromTensor(Value value, SmallVector<OpFoldResult> &vals) {
                         getAsOpFoldResult(full.getFillValue()));
     return success();
   }
-  // TODO: Add a case for unsqueeze of a primnumtotensorscalarop?
+
+  if (auto unsqueeze = value.getDefiningOp<Torch::AtenUnsqueezeOp>()) {
+    Value usqSelf = unsqueeze.getSelf();
+    if (auto numToTensor =
+            usqSelf.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
+      vals.push_back(getAsOpFoldResult(numToTensor.getA()));
+      return success();
+    }
+  }
+
+  // A common rank 0 tensor producer
+  if (auto numToTensor =
+          value.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
+    vals.push_back(getAsOpFoldResult(numToTensor.getA()));
+    return success();
+  }
 
   // Last supported case: ValueTensorLiteralOp
   auto literalOp = value.getDefiningOp<Torch::ValueTensorLiteralOp>();
   if (!literalOp)
     return failure();
 
-  // Check the type. We make sure the type is not unsigned here before trying to
-  // materialize
+  // Check the type.
   auto ty = cast<ValueTensorType>(literalOp.getType());
   if (!ty.hasSizes() || ty.getSizes().size() > 1)
     return failure();
-  int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;
+  // make sure the type is not unsigned here before trying to materialize
   auto intTy = dyn_cast_or_null<IntegerType>(ty.getDtype());
   if (!intTy || intTy.isUnsigned())
     return failure();
 
+  // if we have a rank 0 literal, we will be adding one element to the list
+  int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;
+
+  if (listSize > kMaxFold)
+    return failure();
+
+  // check for a splat or dense attr
   auto splattr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
   auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(literalOp.getValue());
 
   if (!splattr && !denseAttr)
     return failure();
 
+  // These are not mutually exclusive, so try splat first.
   if (splattr) {
     auto attr = splattr.getSplatValue<Attribute>();
     vals.resize((int64_t)vals.size() + listSize, attr);
+    return success();
   }
 
-  if (denseAttr && !splattr) {
-    for (auto e : denseAttr.getValues<Attribute>())
-      vals.push_back(e);
-  }
-
-  if ((int64_t)vals.size() != listSize)
+  // remaining case: denseAttr
+  if ((int64_t)denseAttr.getValues<Attribute>().size() != listSize)
     return failure();
-
+  for (auto e : denseAttr.getValues<Attribute>())
+    vals.push_back(e);
   return success();
 }
```
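In IR terms, a minimal sketch (value names assumed) of the new cases `getListFromTensor` can see through:

```mlir
%int0 = torch.constant.int 0
%dim = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int
// Rank-0 producer: getListFromTensor(%scalar) now yields [%dim].
%scalar = torch.prim.NumToTensor.Scalar %dim : !torch.int -> !torch.vtensor<[],si64>
// Unsqueeze of that producer: getListFromTensor(%vec) also yields [%dim].
%vec = torch.aten.unsqueeze %scalar, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
```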

```diff
@@ -143,6 +163,45 @@ Value constructAtenTensorOpFromList(ImplicitLocOpBuilder b, mlir::Type resultTy,
 // [scalarOpsA -> ListA -> TensorA] -> OpB. Then OpB will be able to
 // getListFromTensor(A), and further propagate scalarization.
 
+namespace {
+class PropagateAtenBroadcastToPattern
+    : public OpRewritePattern<AtenBroadcastToOp> {
+public:
+  using OpRewritePattern<AtenBroadcastToOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenBroadcastToOp op,
+                                PatternRewriter &rewriter) const override {
+    constexpr int64_t kMaxFold = 16;
+    // for tensor<si64>, or tensor<1xsi64>, broadcasted to tensor<nxsi64>, grab
+    // the element and convert to a full op.
+    auto ty = cast<ValueTensorType>(op.getType());
+    if (!ty.areAllSizesKnown() || ty.getSizes().size() != 1)
+      return failure();
+
+    if (ty.getSizes()[0] > kMaxFold)
+      return failure();
+
+    SmallVector<OpFoldResult> fillFold;
+    if (failed(getListFromTensor(op.getSelf(), fillFold)) ||
+        fillFold.size() != 1)
+      return failure();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<Value, 1> fillVals;
+    if (failed(materializeFolds(b, fillFold, fillVals)))
+      return failure();
+
+    Value size = b.create<Torch::ConstantIntOp>(ty.getSizes().front());
+    Value sizeList = b.create<Torch::PrimListConstructOp>(
+        rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
+        size);
+    Value none = b.create<Torch::ConstantNoneOp>();
+    Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
+    rewriter.replaceOpWithNewOp<AtenFullOp>(op, ty, sizeList, fillVals.front(),
+                                            none, none, none, cstFalse);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class PropagateAtenShapeToTensorPattern
     : public OpRewritePattern<Aten_ShapeAsTensorOp> {
```
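A rough before/after sketch of what this pattern does, with assumed value names (`%src` is the single-element input and `%fill` is the scalar recovered from it via `getListFromTensor`):

```mlir
// Before: a single-element tensor broadcast to a small static 1-D shape.
%int4 = torch.constant.int 4
%sizes = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%bcast = torch.aten.broadcast_to %src, %sizes : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4],si64>

// After: rewritten to an aten.full whose fill value can keep propagating.
%none = torch.constant.none
%false = torch.constant.bool false
%full = torch.aten.full %sizes, %fill, %none, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
```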
```diff
@@ -541,9 +600,128 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
 };
 } // namespace
 
+namespace {
+
+template <typename OpTy> struct ArithmeticHelper {
+  static LogicalResult getAlphaAndVerify(OpTy &op, int64_t &alpha) {
+    alpha = 1;
+    return success();
+  }
+};
+
+template <> struct ArithmeticHelper<AtenAddTensorOp> {
+  static LogicalResult getAlphaAndVerify(AtenAddTensorOp &op, int64_t &alpha) {
+    if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
+      return failure();
+    return success();
+  }
+};
+
+template <> struct ArithmeticHelper<AtenSubTensorOp> {
+  static LogicalResult getAlphaAndVerify(AtenSubTensorOp &op, int64_t &alpha) {
+    if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
+      return failure();
+    return success();
+  }
+};
+
+template <typename OpTy, typename ScalarOpTy>
+class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Check type
+    auto resultTy = cast<ValueTensorType>(op.getType());
+    if (resultTy.getSizes().size() > 1)
+      return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
+    if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
+      return rewriter.notifyMatchFailure(op, "not an int type");
+
+    int64_t alpha;
+    if (failed(ArithmeticHelper<OpTy>::getAlphaAndVerify(op, alpha)))
+      return rewriter.notifyMatchFailure(op, "alpha must be 1");
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<OpFoldResult> selfFold, otherFold;
+    if (failed(getListFromTensor(op.getSelf(), selfFold)) ||
+        failed(getListFromTensor(op.getOther(), otherFold)) ||
+        selfFold.size() != otherFold.size())
+      return failure();
+    SmallVector<Value> selfVals, otherVals;
+    if (failed(materializeFolds(b, selfFold, selfVals)) ||
+        failed(materializeFolds(b, otherFold, otherVals)))
+      return failure();
+    SmallVector<OpFoldResult> resultFolds;
+    for (uint64_t i = 0; i < selfVals.size(); i++) {
+      resultFolds.push_back(b.createOrFold<ScalarOpTy>(
+          selfVals[i].getType(), selfVals[i], otherVals[i]));
+    }
+    SmallVector<Value> resultVals;
+    if (failed(materializeFolds(b, resultFolds, resultVals)))
+      return failure();
+
+    if (resultTy.getSizes().size() == 0) {
+      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
+          op, resultTy, resultVals.front());
+      return success();
+    }
+
+    Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
```
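As a sketch (assuming `%d0` and `%d1` are `!torch.int` shape values), the templated pattern turns rank-0 integer tensor arithmetic into its scalar counterpart and rebuilds the rank-0 result:

```mlir
// Before
%a = torch.prim.NumToTensor.Scalar %d0 : !torch.int -> !torch.vtensor<[],si64>
%b = torch.prim.NumToTensor.Scalar %d1 : !torch.int -> !torch.vtensor<[],si64>
%prod = torch.aten.mul.Tensor %a, %b : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>

// After
%prod_scalar = torch.aten.mul.int %d0, %d1 : !torch.int, !torch.int -> !torch.int
%prod_tensor = torch.prim.NumToTensor.Scalar %prod_scalar : !torch.int -> !torch.vtensor<[],si64>
```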
```diff
 /// ------ Fold Patterns ------ ///
 // These are shape-specific folding patterns
 
+namespace {
+class FoldAtenEqIntPattern : public OpRewritePattern<AtenEqIntOp> {
+public:
+  using OpRewritePattern<AtenEqIntOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenEqIntOp op,
+                                PatternRewriter &rewriter) const override {
+    // replaces (size.int == 0) with false and adds an assert
+    // these comparisons are getting generated because onnx.Reshape considers 0
+    // to mean "don't change this dim". However, if the size we are passing to
+    // onnx.Reshape is a tensor dim, this is definitely never supposed to be
+    // interpreted as "don't change this dim".
+    int64_t otherInt;
+    if (!matchPattern(op.getB(), m_TorchConstantInt(&otherInt)) ||
+        otherInt != 0)
+      return failure();
+
+    // in case the shape is a product of two ints, check each
+    if (auto mulOp = op.getA().getDefiningOp<AtenMulIntOp>()) {
+      Value self = mulOp.getA();
+      Value other = mulOp.getB();
+      Value selfEq = rewriter.create<AtenEqIntOp>(op.getLoc(), self, op.getB());
+      Value otherEq =
+          rewriter.create<AtenEqIntOp>(op.getLoc(), other, op.getB());
+      rewriter.replaceOpWithNewOp<Aten__Or__BoolOp>(op, selfEq, otherEq);
+      return success();
+    }
+
+    // if lhs is size.int op, assert size > 0 and replace with false.
+    if (auto sizeOp = op.getA().getDefiningOp<AtenSizeIntOp>()) {
+      Value selfGtOther = rewriter.create<AtenGtIntOp>(
+          op.getLoc(), op.getType(), op.getA(), op.getB());
+      rewriter.create<Torch::RuntimeAssertOp>(
+          op.getLoc(), selfGtOther,
+          rewriter.getStringAttr("Expected dim size > 0."));
+      Value cstFalse =
+          rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
+      rewriter.replaceOp(op, cstFalse);
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
 namespace {
 class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
 public:
```
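A minimal sketch of the new fold, assuming `%arg0` is a dynamically shaped input: comparing a tensor dim against 0 becomes a runtime assert plus the constant `false`, which unblocks folding of downstream `where.self` ops.

```mlir
// Before
%int0 = torch.constant.int 0
%dim = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int
%eq = torch.aten.eq.int %dim, %int0 : !torch.int, !torch.int -> !torch.bool

// After
%gt = torch.aten.gt.int %dim, %int0 : !torch.int, !torch.int -> !torch.bool
torch.runtime.assert %gt, "Expected dim size > 0."
%false = torch.constant.bool false
```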
```diff
@@ -594,16 +772,24 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
 } // namespace
 
 namespace {
-class FoldAtenSqueezePattern : public OpRewritePattern<AtenSqueezeOp> {
+template <typename SqueezeOp>
+class FoldAtenSqueezePattern : public OpRewritePattern<SqueezeOp> {
 public:
-  using OpRewritePattern<AtenSqueezeOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenSqueezeOp op,
+  using OpRewritePattern<SqueezeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SqueezeOp op,
                                 PatternRewriter &rewriter) const override {
     auto resultTy = cast<ValueTensorType>(op.getType());
     if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
       return rewriter.notifyMatchFailure(op, "Unknown result shape");
 
-    if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
+    Value self = op.getSelf();
+    if (auto atenFull = self.getDefiningOp<AtenFullOp>()) {
+      // in the rank 0 case, just return the rank 0 scalar
+      if (resultTy.getSizes().size() == 0) {
+        rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
+            op, resultTy, atenFull.getFillValue());
+        return success();
+      }
       SmallVector<Value> sizes;
       for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i)
         sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
```
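Sketched as IR (assuming `%fill : !torch.int`), the new rank-0 branch handles a one-element full that is immediately squeezed back to rank 0:

```mlir
// Before
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%none = torch.constant.none
%false = torch.constant.bool false
%sizes = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%full = torch.aten.full %sizes, %fill, %none, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
%sq = torch.aten.squeeze.dim %full, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>

// After: the rank-0 result comes straight from the fill value.
%sq_folded = torch.prim.NumToTensor.Scalar %fill : !torch.int -> !torch.vtensor<[],si64>
```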
```diff
@@ -874,9 +1060,16 @@ bool isPrimListOfInts(Operation *op) {
   return llvm::isa<Torch::IntType>(listType.getContainedType());
 }
 
+bool isAnchorOp(Operation *op) {
+  return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
+         isPrimListOfInts(op);
+}
+
 void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
-  patterns.insert<FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
-                  FoldAtenWhereSelf, FoldAtenTensorSplatPattern>(
+  patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
+                  FoldAtenSqueezePattern<AtenSqueezeDimOp>,
+                  FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
+                  FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
       patterns.getContext());
 }
```

```diff
@@ -885,10 +1078,21 @@ void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
 }
 
 void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
-  patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
-                  PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
-                  PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
-                  PropagateAtenWhereSelfPattern>(patterns.getContext());
+  // A note on division: onnx.Div from int, int -> int types rounds towards
+  // zero. The torch DivTensorOp actually doesn't allow returning an int dtype,
+  // but this was artificially plummbed through. Unfortunately, there is no
+  // scalar trunc div op in torch; however, we can safely assume all operands
+  // are positive so floor divide should be a sufficient scalar replacement.
+  patterns.insert<
+      PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
+      PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
+      PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
+      PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
+      PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
+      PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
+      PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
+      PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
+      patterns.getContext());
 }
 
 void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
```
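The division caveat from the comment above, sketched as IR (assuming `%d0` and `%d1` are non-negative `!torch.int` shape values): integer `div.Tensor` is scalarized to `floordiv.int`, which only matches trunc-toward-zero division because the operands are assumed non-negative.

```mlir
// Before
%a = torch.prim.NumToTensor.Scalar %d0 : !torch.int -> !torch.vtensor<[],si64>
%b = torch.prim.NumToTensor.Scalar %d1 : !torch.int -> !torch.vtensor<[],si64>
%q = torch.aten.div.Tensor %a, %b : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>

// After
%q_scalar = torch.aten.floordiv.int %d0, %d1 : !torch.int, !torch.int -> !torch.int
%q_tensor = torch.prim.NumToTensor.Scalar %q_scalar : !torch.int -> !torch.vtensor<[],si64>
```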
```diff
@@ -940,7 +1144,7 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
         [&](Operation *op) {
           // Walking bottom-up, start adding ops when we reach an anchor point
           // (a prim list of ints)
-          if (isPrimListOfInts(op)) {
+          if (isAnchorOp(op)) {
             shapeCalculationOps.insert(op);
             return;
           }
```
