Commit ab62f35

Add more patterns to scalarize-shapes pass (llvm#3781)
- Adds patterns for propagating shapes through AtenWhereSelf and AtenEqTensor
- Adds a fold pattern for a rank-0 squeezeDim of a full op
- Adds support for getting a list from a splat ValueTensorLiteralOp, for materializing scalar comparisons in where.self and eq.tensor

With a bit of hammering, these changes should unblock several IREE inference failures.
1 parent 7b11dfc commit ab62f35
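As a rough illustration of what the new where.self / eq.tensor propagation buys (adapted from the @eq_tensor_and_where_self test added below; the SSA names here are illustrative placeholders, with %dim0/%dim1 standing in for torch.aten.size.int results, and the folded form follows the test's CHECK lines rather than being verbatim pass output):

// Before: a shape tensor built from dynamic dims is compared against a splat
// literal and then selected from element-wise.
%shape = torch.aten.tensor %dims, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
%mask = torch.aten.eq.Tensor %shape, %lit : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1>
%sel = torch.aten.where.self %mask, %ones, %shape : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64>

// After scalarize-shapes (plus the cleanup patterns), the comparison and the
// select happen on individual torch.int values, so the result collapses back
// into a tensor built directly from the known scalars:
%list = torch.prim.ListConstruct %dim0, %int1, %dim1, %dim1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%sel = torch.aten.tensor %list, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>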

File tree

2 files changed (+287, -0):

- lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp (+211)
- test/Dialect/Torch/scalarize-shapes.mlir (+76)

lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp (+211)
@@ -63,6 +63,29 @@ LogicalResult getListOperands(Value value, SmallVector<Value> &vals) {
   return success();
 }
 
+LogicalResult constructListFromLiteral(PatternRewriter &rewriter,
+                                       ValueTensorLiteralOp literalOp,
+                                       SmallVector<Value> &vals) {
+  // only supports splat ValueTensorLiterals for now. TODO: add support for
+  // small non-splat valuetensorliterals.
+  auto ty = dyn_cast<ValueTensorType>(literalOp.getType());
+  if (!ty || !ty.hasSizes())
+    return failure();
+  auto attr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
+  if (!attr)
+    return failure();
+  auto attrInt = dyn_cast<IntegerAttr>(attr.getSplatValue<Attribute>());
+  if (!attrInt)
+    return failure();
+  IntegerType intty = cast<IntegerType>(attrInt.getType());
+  if (!intty.isSignedInteger())
+    return failure();
+  Value materializedVal = rewriter.create<Torch::ConstantIntOp>(
+      literalOp.getLoc(), attrInt.getSInt());
+  vals.resize(vals.size() + ty.getSizes()[0], materializedVal);
+  return success();
+}
+
 LogicalResult getListFromTensor(Value value, SmallVector<Value> &vals) {
   constexpr int64_t kMaxFold = 16;
   if (auto tensor = value.getDefiningOp<Torch::AtenTensorOp>())
@@ -351,6 +374,172 @@ class PropagateAtenSliceTensorPattern
 };
 } // namespace
 
+namespace {
+class PropagateAtenWhereSelfPattern : public OpRewritePattern<AtenWhereSelfOp> {
+public:
+  using OpRewritePattern<AtenWhereSelfOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenWhereSelfOp op,
+                                PatternRewriter &rewriter) const override {
+    Value condition = op.getCondition();
+    Value self = op.getSelf();
+    Value other = op.getOther();
+    auto conditionTy = dyn_cast<Torch::ValueTensorType>(condition.getType());
+    if (!conditionTy || !conditionTy.hasSizes() ||
+        conditionTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad condition type");
+    auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
+    if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad self type");
+    auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
+    if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad other type");
+    int64_t conditionSize = conditionTy.getSizes()[0];
+    int64_t selfSize = selfTy.getSizes()[0];
+    int64_t otherSize = otherTy.getSizes()[0];
+
+    if (selfSize != otherSize || selfSize != conditionSize)
+      return rewriter.notifyMatchFailure(
+          op,
+          "unimplemented: support for propagating with implicit broadcasting.");
+
+    constexpr int64_t kMaxFold = 16;
+    if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold)
+      return rewriter.notifyMatchFailure(op,
+                                         "arguments are dynamic or too big");
+
+    SmallVector<Value> conditionList, selfList, otherList;
+    if (failed(getListFromTensor(condition, conditionList)) ||
+        (int64_t)conditionList.size() != conditionSize)
+      return failure();
+
+    // If one of these tensors is a value tensor literal op, we will need to
+    // create constant ints in the IR to form a list. Before calling
+    // constructListFromLiteral, we must be certain that the conversion can no
+    // longer fail, otherwise we will cause an infinite loop of creating a
+    // constant and removing it.
+    LogicalResult selfFromList = getListFromTensor(self, selfList);
+    LogicalResult otherFromList = getListFromTensor(other, otherList);
+
+    if (failed(selfFromList) && failed(otherFromList))
+      return rewriter.notifyMatchFailure(
+          op, "At least one operand must succeed at constructing a list");
+
+    auto selfLiteral = self.getDefiningOp<Torch::ValueTensorLiteralOp>();
+    auto otherLiteral = other.getDefiningOp<Torch::ValueTensorLiteralOp>();
+    if (succeeded(selfFromList) && otherLiteral &&
+        failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
+      return failure();
+    if (succeeded(otherFromList) && selfLiteral &&
+        failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
+      return failure();
+    if ((int64_t)selfList.size() != selfSize ||
+        (int64_t)otherList.size() != otherSize)
+      // This should only occur if we did not generate IR with
+      // constructListFromLiteral.
+      return failure();
+
+    Location loc = op.getLoc();
+    SmallVector<Value> whereVals;
+    auto rank0IntTy = rewriter.getType<Torch::ValueTensorType>(
+        ArrayRef<int64_t>({}), selfTy.getDtype());
+    auto rank0BoolTy = rewriter.getType<Torch::ValueTensorType>(
+        ArrayRef<int64_t>({}), conditionTy.getDtype());
+    for (uint64_t i = 0; i < selfList.size(); i++) {
+      Value rank0Cond = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+          loc, rank0BoolTy, conditionList[i]);
+      Value rank0Self = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+          loc, rank0IntTy, selfList[i]);
+      Value rank0Other = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+          loc, rank0IntTy, otherList[i]);
+      Value rank0Where = rewriter.create<AtenWhereSelfOp>(
+          loc, rank0IntTy, rank0Cond, rank0Self, rank0Other);
+      whereVals.push_back(rewriter.create<AtenItemOp>(
+          loc, rewriter.getType<Torch::IntType>(), rank0Where));
+    }
+    Value list = rewriter.create<Torch::PrimListConstructOp>(
+        op.getLoc(), Torch::ListType::get(whereVals[0].getType()), whereVals);
+    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+        op.getLoc(), rewriter.getBoolAttr(false));
+    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
+        op, op.getType(), list, cstNone, cstNone, cstFalse);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class PropagateAtenEqTensorPattern : public OpRewritePattern<AtenEqTensorOp> {
+public:
+  using OpRewritePattern<AtenEqTensorOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenEqTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    Value self = op.getSelf();
+    Value other = op.getOther();
+    auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
+    if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad self type");
+    auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
+    if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad other type");
+    int64_t selfSize = selfTy.getSizes()[0];
+    int64_t otherSize = otherTy.getSizes()[0];
+
+    if (selfSize != otherSize)
+      return rewriter.notifyMatchFailure(
+          op,
+          "unimplemented: support for propagating with implicit broadcasting.");
+
+    constexpr int64_t kMaxFold = 16;
+    if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold ||
+        otherSize == Torch::kUnknownSize || otherSize > kMaxFold)
+      return rewriter.notifyMatchFailure(op,
+                                         "self or other is dynamic or too big");
+
+    SmallVector<Value> selfList, otherList;
+    // If one of these tensors is a value tensor literal op, we will need to
+    // create constant ints in the IR to form a list. Before calling
+    // constructListFromLiteral, we must be certain that the conversion can no
+    // longer fail, otherwise we will cause an infinite loop of creating a
+    // constant and removing it.
+    LogicalResult selfFromList = getListFromTensor(self, selfList);
+    LogicalResult otherFromList = getListFromTensor(other, otherList);
+
+    if (failed(selfFromList) && failed(otherFromList))
+      return rewriter.notifyMatchFailure(
+          op, "At least one operand must succeed at constructing a list");
+
+    auto selfLiteral = self.getDefiningOp<Torch::ValueTensorLiteralOp>();
+    auto otherLiteral = other.getDefiningOp<Torch::ValueTensorLiteralOp>();
+    if (succeeded(selfFromList) && otherLiteral &&
+        failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
+      return failure();
+    if (succeeded(otherFromList) && selfLiteral &&
+        failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
+      return failure();
+    if ((int64_t)selfList.size() != selfSize ||
+        (int64_t)otherList.size() != otherSize)
+      // This should only occur if we did not generate IR with
+      // constructListFromLiteral.
+      return failure();
+
+    SmallVector<Value> eqVals;
+    for (uint64_t i = 0; i < selfList.size(); i++) {
+      eqVals.push_back(
+          rewriter.create<AtenEqIntOp>(op.getLoc(), selfList[i], otherList[i]));
+    }
+    Value list = rewriter.create<Torch::PrimListConstructOp>(
+        op.getLoc(), Torch::ListType::get(eqVals[0].getType()), eqVals);
+    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+        op.getLoc(), rewriter.getBoolAttr(false));
+    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
+        op, op.getType(), list, cstNone, cstNone, cstFalse);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
 public:
@@ -454,6 +643,26 @@ class FoldAtenSqueezePattern : public OpRewritePattern<AtenSqueezeOp> {
 };
 } // namespace
 
+namespace {
+class FoldAtenSqueezeDimPattern : public OpRewritePattern<AtenSqueezeDimOp> {
+public:
+  using OpRewritePattern<AtenSqueezeDimOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSqueezeDimOp op,
+                                PatternRewriter &rewriter) const override {
+    auto resultTy = cast<ValueTensorType>(op.getType());
+    if (!resultTy.hasSizes() || resultTy.getSizes().size() != 0)
+      return rewriter.notifyMatchFailure(op, "Unknown result shape");
+
+    if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
+      rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
+          op, resultTy, atenFull.getFillValue());
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
 namespace {
 class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
 public:
@@ -694,6 +903,8 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
         PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
         FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
         FoldAtenWhereSelf, CanonicalizeAtenViewPattern,
+        PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern,
+        FoldAtenSqueezeDimPattern,
         RemoveUnusedPattern<Torch::AtenIntBoolOp>,
         RemoveUnusedPattern<Torch::AtenEqIntOp>,
         RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
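For reference, the local rewrite performed by the new FoldAtenSqueezeDimPattern above, sketched in torch dialect IR (%dims, %fill, and the other SSA names are placeholders; the end-to-end behavior is exercised by the @squeeze_dim_full_fold test below):

// A rank-0 squeeze.dim of an aten.full simply re-materializes the fill value:
%full = torch.aten.full %dims, %fill, %none, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
%squeezed = torch.aten.squeeze.dim %full, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>

// ...is rewritten to:
%squeezed = torch.prim.NumToTensor.Scalar %fill : !torch.int -> !torch.vtensor<[],si64>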

test/Dialect/Torch/scalarize-shapes.mlir (+76)
@@ -160,3 +160,79 @@ func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !t
   %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
   return %14 : !torch.int
 }
+
+
+// -----
+
+// CHECK-LABEL: @eq_tensor_and_where_self
+func.func @eq_tensor_and_where_self(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[4],si64> {
+  // CHECK-DAG: %[[false:.*]] = torch.constant.bool false
+  // CHECK-DAG: %[[none:.*]] = torch.constant.none
+  // CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I0:.*]] = torch.constant.int 0
+  // CHECK-DAG: %[[DIM1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  // CHECK-DAG: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[I1]], %[[DIM1]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
+  // CHECK: return %[[TENSOR]] : !torch.vtensor<[4],si64>
+  %none = torch.constant.none
+  %0 = torch.vtensor.literal(dense<-1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %1 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  %3 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  %4 = torch.prim.ListConstruct %3, %int1, %2, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.aten.tensor %4, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
+  %6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1>
+  %7 = torch.aten.where.self %6, %1, %5 : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64>
+  return %7 : !torch.vtensor<[4],si64>
+}
+
+
+// -----
+
+// CHECK-LABEL: @eq_tensor_from_tensor_and_literal
+func.func @eq_tensor_from_tensor_and_literal(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[4],i1> {
+  // CHECK-DAG: %[[none:.*]] = torch.constant.none
+  // CHECK-DAG: %[[false:.*]] = torch.constant.bool false
+  // CHECK-DAG: %[[true:.*]] = torch.constant.bool true
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[false]], %[[true]], %[[false]], %[[false]] : (!torch.bool, !torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+  // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<bool>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],i1>
+  // CHECK: return %[[TENSOR]] : !torch.vtensor<[4],i1>
+  %none = torch.constant.none
+  %0 = torch.vtensor.literal(dense<-1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %1 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int-1 = torch.constant.int -1
+  %int0 = torch.constant.int 0
+  %2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  %3 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  %4 = torch.prim.ListConstruct %3, %int-1, %2, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.aten.tensor %4, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
+  %6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1>
+  return %6 : !torch.vtensor<[4],i1>
+}
+
+
+
+// -----
+
+// CHECK-LABEL: @squeeze_dim_full_fold
+func.func @squeeze_dim_full_fold(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  // CHECK: return %[[SZE]] : !torch.int
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %51 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  %55 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %56 = torch.aten.full %55, %51, %none, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
+  %57 = torch.aten.squeeze.dim %56, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
+  %58 = torch.aten.item %57 : !torch.vtensor<[],si64> -> !torch.int
+  return %58 : !torch.int
+}
