Commit cd38ecf

Add Scalarization Patterns for AtenToDtypeOp, AtenNegOp, AtenRemainderTensorOp (llvm#3861)

1. Adds a lowering for `aten.neg.int` and `aten.remainder.int` to arith.
2. Adds scalarization patterns for the `aten.neg` and `aten.remainder.Tensor` ops.
3. Improves folding of `aten.mul.int`.
4. Adds a scalarization pattern for `aten.to.dtype`, which relies on scalar cast ops and basic C++ casting between `double` and `int64_t`.
5. Improves rank-0 case handling in `FoldAtenTensorSplatPattern`.
6. Fixes a bug where the `aten.unflatten.int` decomposition incorrectly generated a constant size int from a dynamic shape.
7. Simplifies the dim list for `aten.unflatten.int` ops generated from the `aten.view` canonicalization in scalarize shapes.

All of these changes were necessary to unblock iree-org/iree#18899.

1 parent 889a836 · commit cd38ecf

6 files changed (+302 −30 lines)

lib/Conversion/TorchToArith/TorchToArith.cpp (+24 −2)

@@ -82,6 +82,25 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
 };
 } // namespace
 
+namespace {
+class ConvertAtenNegIntOp : public OpConversionPattern<AtenNegIntOp> {
+public:
+  using OpConversionPattern<AtenNegIntOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenNegIntOp op,
+                  typename OpConversionPattern<AtenNegIntOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value a = adaptor.getA();
+    rewriter.replaceOpWithNewOp<arith::SubIOp>(
+        op,
+        rewriter.create<arith::ConstantIntOp>(op.getLoc(), /*value=*/0,
+                                              /*bitwidth=*/64),
+        a);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 template <typename AtenOp, typename UnaryOp>
 class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
@@ -465,11 +484,14 @@ class ConvertTorchToArith
 
     target.addIllegalOp<AtenAddOp>();
     patterns.add<ConvertAtenAddOp>(typeConverter, context);
-
+    target.addIllegalOp<AtenNegIntOp>();
+    patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
     target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
-                        AtenMulIntOp>();
+                        AtenMulIntOp, AtenRemainderIntOp>();
     patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
         typeConverter, context);
+    patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
+        typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenAddFloatIntOp, arith::AddFOp>>(
         typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
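For context, both new lowerings map onto existing `arith` ops: `aten.neg.int` becomes a subtraction from zero via `arith.subi`, and `aten.remainder.int` maps to `arith.remsi`, the C-style signed remainder. A minimal standalone sketch of the integer semantics involved (illustrative C++, not torch-mlir code):

```cpp
#include <cassert>
#include <cstdint>

// neg(a) expressed as 0 - a, as the new pattern does with arith.subi.
// Subtracting on the unsigned representation sidesteps signed-overflow UB
// for INT64_MIN and matches two's-complement wrap-around semantics.
int64_t negViaSub(int64_t a) {
  return static_cast<int64_t>(UINT64_C(0) - static_cast<uint64_t>(a));
}

int main() {
  assert(negViaSub(7) == -7);
  assert(negViaSub(-7) == 7);
  // C++'s % mirrors arith.remsi: the result takes the sign of the dividend.
  assert(7 % 3 == 1);
  assert(-7 % 3 == -1);
}
```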

lib/Dialect/Torch/IR/TorchOps.cpp (+4)

@@ -4068,6 +4068,10 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
   int64_t lhs, rhs;
   bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
   bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
+  if (lConstant && lhs == 1)
+    return getOperand(1);
+  if (rConstant && rhs == 1)
+    return getOperand(0);
   if ((lConstant && lhs == 0) || (rConstant && rhs == 0))
     return getI64IntegerAttr(getContext(), 0);
   if (lConstant && rConstant)
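The added lines extend the fold with the multiplicative identity, so `x * 1` and `1 * x` now fold to `x` even when the other operand is not a constant; the existing zero and fully-constant cases are unchanged. A standalone sketch of the resulting fold order (the helper below is hypothetical, not a torch-mlir API):

```cpp
#include <cstdint>
#include <optional>

// Outcome of folding mul.int: a known constant, an operand index to
// forward (0 or 1), or no fold at all.
struct FoldOutcome {
  std::optional<int64_t> constant;
  int forwardOperand = -1;
};

FoldOutcome foldMulInt(std::optional<int64_t> lhs, std::optional<int64_t> rhs) {
  if (lhs == 1) return {std::nullopt, 1};   // 1 * x -> x (new)
  if (rhs == 1) return {std::nullopt, 0};   // x * 1 -> x (new)
  if (lhs == 0 || rhs == 0) return {0, -1}; // x * 0 -> 0
  if (lhs && rhs) return {*lhs * *rhs, -1}; // both constant: fold fully
  return {};
}
```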

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp (+5)

@@ -4587,6 +4587,11 @@ class DecomposeAtenUnflattenIntOp
     if (!isValidDim(dimInt, inputRank))
       return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
 
+    if (inputShape[dimInt] == Torch::kUnknownSize &&
+        llvm::count(sizesInts, -1) > 0)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: dynamic unflatten dim with an inferred size.");
+
     SmallVector<Value> sizesTorchInt;
     if (!getListConstructElements(op.getSizes(), sizesTorchInt))
       return rewriter.notifyMatchFailure(
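The guard bails out when the dim being unflattened has an unknown extent and the requested sizes contain a `-1` placeholder: inferring the `-1` entry means dividing the dim's total extent by the product of the known sizes, and with a dynamic extent there is no constant to divide. A short sketch of that arithmetic (illustrative names, not the decomposition's code):

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Inferring the -1 entry requires the total extent of the unflattened dim.
std::optional<int64_t> inferUnflattenSize(std::optional<int64_t> dimExtent,
                                          const std::vector<int64_t> &sizes) {
  int64_t knownProduct = 1;
  for (int64_t s : sizes)
    if (s != -1)
      knownProduct *= s;
  if (!dimExtent)
    return std::nullopt; // dynamic extent: no constant to emit, hence the guard
  return *dimExtent / knownProduct;
}
```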

lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp (+216 −12)

@@ -714,7 +714,7 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     // Rank 0 item op prop
-    if (selfTy.getSizes().size() == 0) {
+    if (selfTy.getSizes().empty()) {
       auto numToTensor = self.getDefiningOp<Torch::PrimNumToTensorScalarOp>();
       auto squeezeDim = self.getDefiningOp<AtenSqueezeDimOp>();
       if (!squeezeDim && !numToTensor)
@@ -746,6 +746,109 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
 };
 } // namespace
 
+namespace {
+
+LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b,
+                                   SmallVector<OpFoldResult> &converted,
+                                   SmallVector<OpFoldResult> &elements,
+                                   Type inputDtype, Type resultDtype) {
+  auto inputIsInt = dyn_cast<mlir::IntegerType>(inputDtype);
+  auto resultIsInt = dyn_cast<mlir::IntegerType>(resultDtype);
+  if (!inputIsInt && !isa<mlir::FloatType>(inputDtype))
+    return failure();
+  if (!resultIsInt && !isa<mlir::FloatType>(resultDtype))
+    return failure();
+
+  // if dtypes are both int or both float, no conversion needed
+  if (static_cast<bool>(inputIsInt) == static_cast<bool>(resultIsInt)) {
+    converted = elements;
+    return success();
+  }
+
+  if (resultIsInt) {
+    for (auto &e : elements) {
+      auto eValue = dyn_cast<Value>(e);
+      if (eValue) {
+        converted.push_back(b.createOrFold<AtenIntScalarOp>(eValue));
+        continue;
+      }
+      auto eAttr = dyn_cast<Attribute>(e);
+      auto eFloatAttr = dyn_cast_or_null<FloatAttr>(eAttr);
+      if (!eFloatAttr)
+        return failure();
+
+      converted.push_back(IntegerAttr::get(
+          resultDtype, static_cast<int64_t>(eFloatAttr.getValueAsDouble())));
+    }
+    return success();
+  }
+
+  // result is float
+  for (auto &e : elements) {
+    auto eValue = dyn_cast<Value>(e);
+    if (eValue) {
+      converted.push_back(b.createOrFold<AtenFloatScalarOp>(eValue));
+      continue;
+    }
+    auto eAttr = dyn_cast<Attribute>(e);
+    auto eIntAttr = dyn_cast<IntegerAttr>(eAttr);
+    if (!eIntAttr)
+      return failure();
+
+    auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue()
+                                        : eIntAttr.getValue().getZExtValue();
+    converted.push_back(FloatAttr::get(resultDtype, static_cast<double>(eInt)));
+  }
+  return success();
+}
+
+class PropagateAtenToDtypePattern : public OpRewritePattern<AtenToDtypeOp> {
+public:
+  using OpRewritePattern<AtenToDtypeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenToDtypeOp op,
+                                PatternRewriter &rewriter) const override {
+    bool nonBlocking, copyArg;
+    // The non_blocking arg must be `False`.
+    if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
+        nonBlocking)
+      return failure();
+    // The copy arg must be `False`.
+    if (!matchPattern(op.getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
+      return failure();
+    // The memory_format arg must be `none`.
+    if (!isa<Torch::NoneType>(op.getMemoryFormat().getType()))
+      return failure();
+
+    auto inputType = dyn_cast<ValueTensorType>(op.getSelf().getType());
+    auto resultType = dyn_cast<ValueTensorType>(op.getType());
+    if (!inputType || !resultType || !inputType.hasDtype() ||
+        !resultType.hasDtype())
+      return failure();
+    auto inputDtype = inputType.getDtype();
+    auto resultDtype = resultType.getDtype();
+
+    SmallVector<OpFoldResult> elements;
+    if (failed(getListFromTensor(op.getSelf(), elements)))
+      return failure();
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<OpFoldResult> converted;
+    if (failed(convertOpFoldResults(b, converted, elements, inputDtype,
+                                    resultDtype)))
+      return rewriter.notifyMatchFailure(
+          op, "Unhandled attribute type encountered.");
+
+    SmallVector<Value> vals;
+    if (failed(materializeFolds(b, converted, vals)))
+      return failure();
+
+    Value result = constructAtenTensorOpFromList(b, op.getType(), vals);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 template <typename AtenViewLikeOp>
 class PropagateAtenViewLikePattern : public OpRewritePattern<AtenViewLikeOp> {
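As the commit message notes, the attribute branch of `convertOpFoldResults` relies on plain C++ casts between `double` and `int64_t`. The float-to-int direction truncates toward zero, which differs from floor for negative values; a minimal standalone check:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  assert(static_cast<int64_t>(2.9) == 2);   // truncates toward zero
  assert(static_cast<int64_t>(-2.9) == -2); // not floor(-2.9) == -3
  assert(static_cast<double>(INT64_C(7)) == 7.0);
}
```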
@@ -828,7 +931,7 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
     if (failed(materializeFolds(b, resultFolds, resultVals)))
       return failure();
 
-    if (resultTy.getSizes().size() == 0) {
+    if (resultTy.getSizes().empty()) {
       rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
           op, resultTy, resultVals.front());
       return success();
@@ -841,6 +944,48 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
 };
 } // namespace
 
+namespace {
+template <typename OpTy, typename ScalarOpTy>
+class PropagateAtenUnaryPattern : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Check type
+    auto resultTy = cast<ValueTensorType>(op.getType());
+    if (resultTy.getSizes().size() > 1)
+      return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
+    if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
+      return rewriter.notifyMatchFailure(op, "not an int type");
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<OpFoldResult> selfFold;
+    if (failed(getListFromTensor(op.getSelf(), selfFold)))
+      return failure();
+    SmallVector<Value> selfVals;
+    if (failed(materializeFolds(b, selfFold, selfVals)))
+      return failure();
+    SmallVector<OpFoldResult> resultFolds;
+    for (uint64_t i = 0; i < selfVals.size(); i++) {
+      resultFolds.push_back(
+          b.createOrFold<ScalarOpTy>(selfVals[i].getType(), selfVals[i]));
+    }
+    SmallVector<Value> resultVals;
+    if (failed(materializeFolds(b, resultFolds, resultVals)))
+      return failure();
+
+    if (resultTy.getSizes().size() == 0) {
+      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
+          op, resultTy, resultVals.front());
+      return success();
+    }
+
+    Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
 /// ------ Fold Patterns ------ ///
 // These are shape-specific folding patterns

@@ -915,19 +1060,22 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
     auto resultTy = cast<BaseTensorType>(op.getType());
     if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
       return rewriter.notifyMatchFailure(op, "dynamic output shape");
+    if (resultTy.getSizes().size() == 0) {
+      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
+          op, op.getType(), elements.front());
+      return success();
+    }
 
     auto loc = op.getLoc();
     SmallVector<Value> sizes;
     for (auto size : resultTy.getSizes())
       sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr(size)));
 
-    Value one = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getType<Torch::IntType>(), 1);
     Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
         loc,
         rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
-        one);
+        sizes);
 
     Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
     Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
@@ -1031,6 +1179,24 @@ class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
 };
 } // namespace
 
+namespace {
+// Fold redundant cast chains like size.int -> float.scalar -> int.scalar.
+class FoldAtenIntScalarPattern : public OpRewritePattern<AtenIntScalarOp> {
+public:
+  using OpRewritePattern<AtenIntScalarOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenIntScalarOp op,
+                                PatternRewriter &rewriter) const override {
+    auto floatScalarOp = op.getA().getDefiningOp<AtenFloatScalarOp>();
+    if (!floatScalarOp)
+      return failure();
+    auto sizeOp = floatScalarOp.getA().getDefiningOp<AtenSizeIntOp>();
+    if (!sizeOp)
+      return failure();
+    rewriter.replaceOp(op, floatScalarOp.getA());
+    return success();
+  }
+};
+} // namespace
 namespace {
 class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
 public:
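This fold is sound because the chain is rooted at `aten.size.int`: tensor dims are nonnegative and far smaller than 2^53, so the int-to-double-to-int round trip the pattern removes is exact. A standalone check of that invariant:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Every integer with magnitude up to 2^53 survives a round trip through
  // double unchanged, which covers any realistic tensor dimension.
  for (int64_t n : {INT64_C(0), INT64_C(1), INT64_C(4096), INT64_C(1) << 52})
    assert(static_cast<int64_t>(static_cast<double>(n)) == n);
}
```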
@@ -1182,8 +1348,29 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
     if (inputUnmatched == 1 && outputUnmatched > 1) {
       Value dimVal =
           rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
-      ArrayRef<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
-                                     viewSizes.end() - rightMatchEnd);
+      SmallVector<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
+                                        viewSizes.end() - rightMatchEnd);
+      // try to convert a single dynamic size input to -1
+      int64_t dynCount = 0;
+      int64_t dynIdx = 0;
+      for (auto [i, v] : llvm::enumerate(unflattenSizes)) {
+        int64_t szeInt;
+        if (!matchPattern(v, m_TorchConstantInt(&szeInt))) {
+          dynCount++;
+          dynIdx = i;
+          continue;
+        }
+        // if we have a -1 already, make dynCount invalid and break
+        if (szeInt == -1) {
+          dynCount = -1;
+          break;
+        }
+      }
+      // if only one size is dynamic, make it -1
+      if (dynCount == 1)
+        unflattenSizes[dynIdx] =
+            rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
+
       Value unflattenList = rewriter.create<Torch::PrimListConstructOp>(
           op.getLoc(), op.getSize().getType(), unflattenSizes);
       rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
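The scan encodes a simple rule: if exactly one entry in the unflatten size list is non-constant and no entry is already `-1`, the dynamic entry can be replaced with `-1` and left to unflatten's size inference. A standalone sketch of the same logic over plain optionals (illustrative, not the pattern's API):

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// nullopt models a size that is not a compile-time constant.
void markSingleDynamicAsInferred(std::vector<std::optional<int64_t>> &sizes) {
  int64_t dynCount = 0;
  size_t dynIdx = 0;
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (!sizes[i]) {
      dynCount++;
      dynIdx = i;
      continue;
    }
    if (*sizes[i] == -1)
      return; // a -1 is already present; leave the list untouched
  }
  if (dynCount == 1)
    sizes[dynIdx] = -1; // exactly one dynamic size: let unflatten infer it
}
```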
@@ -1227,6 +1414,18 @@ template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
 
 namespace {
 
+bool isItemForSliceOp(Operation *op) {
+  auto itemOp = dyn_cast_or_null<AtenItemOp>(op);
+  if (!itemOp)
+    return false;
+  for (OpOperand &use : op->getUses()) {
+    Operation *userOp = use.getOwner();
+    if (isa<AtenSliceTensorOp>(userOp))
+      return true;
+  }
+  return false;
+}
+
 bool isSourceOpForShapeScalarization(Operation *op) {
   return llvm::isa<AtenSizeIntOp, Torch::ConstantIntOp, Torch::ConstantBoolOp,
                    Aten_ShapeAsTensorOp, Torch::ValueTensorLiteralOp>(op);
@@ -1244,7 +1443,7 @@ bool isPrimListOfInts(Operation *op) {
 
 bool isAnchorOp(Operation *op) {
   return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
-         isPrimListOfInts(op);
+         isPrimListOfInts(op) || isItemForSliceOp(op);
 }
 
 // The argument to this function, op, is the use of some source op, srcOp. If
@@ -1278,9 +1477,9 @@ bool isInvalidValidViewConsumer(Operation *op,
 void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
   patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
                   FoldAtenSqueezePattern<AtenSqueezeDimOp>,
-                  FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
-                  FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
-      patterns.getContext());
+                  FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern,
+                  FoldAtenWhereSelf, FoldAtenTensorSplatPattern,
+                  FoldAtenEqIntPattern>(patterns.getContext());
 }
 
 void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
@@ -1303,24 +1502,29 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
       PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
       PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
       PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
-      PropagateAtenTransposeIntPattern,
+      PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern,
+      PropagateAtenUnaryPattern<AtenNegOp, AtenNegIntOp>,
       PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
       PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
       PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
+      PropagateAtenArithmeticPattern<AtenRemainderTensorOp, AtenRemainderIntOp>,
       PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
       patterns.getContext());
 }
 
 void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
   patterns.insert<RemoveUnusedPattern<Torch::AtenIntBoolOp>,
                   RemoveUnusedPattern<Torch::AtenEqIntOp>,
+                  RemoveUnusedPattern<Torch::AtenToDtypeOp>,
                   RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
                   RemoveUnusedPattern<Torch::AtenFullOp>,
                   RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
                   RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
                   RemoveUnusedPattern<Torch::AtenSizeIntOp>,
                   RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
                   RemoveUnusedPattern<Torch::AtenTensorOp>,
+                  RemoveUnusedPattern<Torch::AtenFloatScalarOp>,
+                  RemoveUnusedPattern<Torch::AtenIntScalarOp>,
                   RemoveUnusedPattern<Torch::PrimListConstructOp>>(
       patterns.getContext());
 }
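For orientation, these `populateScalarization*Patterns` hooks are the registration points the scalarize-shapes pass draws from. A hypothetical sketch of how a driver composes them (this function is illustrative and not part of the commit; the populate declarations are assumed to be in scope):

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative driver: collect all four scalarization pattern sets and
// apply them greedily until a fixpoint is reached.
LogicalResult runScalarizationPatterns(ModuleOp module) {
  RewritePatternSet patterns(module.getContext());
  populateScalarizationFoldPatterns(patterns);
  populateScalarizationCanonicalizePatterns(patterns);
  populateScalarizationPropagationPatterns(patterns);
  populateScalarizationRemovePatterns(patterns);
  return applyPatternsAndFoldGreedily(module, std::move(patterns));
}
```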
