@@ -63,6 +63,29 @@ LogicalResult getListOperands(Value value, SmallVector<Value> &vals) {
   return success();
 }
 
+LogicalResult constructListFromLiteral(PatternRewriter &rewriter,
+                                       ValueTensorLiteralOp literalOp,
+                                       SmallVector<Value> &vals) {
+  // only supports splat ValueTensorLiterals for now. TODO: add support for
+  // small non-splat valuetensorliterals.
+  auto ty = dyn_cast<ValueTensorType>(literalOp.getType());
+  if (!ty || !ty.hasSizes())
+    return failure();
+  auto attr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
+  if (!attr)
+    return failure();
+  auto attrInt = dyn_cast<IntegerAttr>(attr.getSplatValue<Attribute>());
+  if (!attrInt)
+    return failure();
+  IntegerType intty = cast<IntegerType>(attrInt.getType());
+  if (!intty.isSignedInteger())
+    return failure();
+  Value materializedVal = rewriter.create<Torch::ConstantIntOp>(
+      literalOp.getLoc(), attrInt.getSInt());
+  vals.resize(vals.size() + ty.getSizes()[0], materializedVal);
+  return success();
+}
+
 LogicalResult getListFromTensor(Value value, SmallVector<Value> &vals) {
   constexpr int64_t kMaxFold = 16;
   if (auto tensor = value.getDefiningOp<Torch::AtenTensorOp>())
@@ -351,6 +374,172 @@ class PropagateAtenSliceTensorPattern
 };
 } // namespace
 
+namespace {
+class PropagateAtenWhereSelfPattern : public OpRewritePattern<AtenWhereSelfOp> {
+public:
+  using OpRewritePattern<AtenWhereSelfOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenWhereSelfOp op,
+                                PatternRewriter &rewriter) const override {
+    Value condition = op.getCondition();
+    Value self = op.getSelf();
+    Value other = op.getOther();
+    auto conditionTy = dyn_cast<Torch::ValueTensorType>(condition.getType());
+    if (!conditionTy || !conditionTy.hasSizes() ||
+        conditionTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad condition type");
+    auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
+    if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad self type");
+    auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
+    if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad other type");
+    int64_t conditionSize = conditionTy.getSizes()[0];
+    int64_t selfSize = selfTy.getSizes()[0];
+    int64_t otherSize = otherTy.getSizes()[0];
+
+    if (selfSize != otherSize || selfSize != conditionSize)
+      return rewriter.notifyMatchFailure(
+          op,
+ " unimplemented: support for propogating with implicit broadcasting." );
+
+    constexpr int64_t kMaxFold = 16;
+    if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold)
+      return rewriter.notifyMatchFailure(op,
+                                         "arguments are dynamic or too big");
+
+    SmallVector<Value> conditionList, selfList, otherList;
+    if (failed(getListFromTensor(condition, conditionList)) ||
+        (int64_t)conditionList.size() != conditionSize)
+      return failure();
+
+    // If one of these tensors is a value tensor literal op, we will need to
+    // create constant ints in the IR to form a list. Before calling
+    // constructListFromLiteral, we must be certain that the conversion can no
+    // longer fail, otherwise we will cause an infinite loop of creating a
+    // constant and removing it.
+    LogicalResult selfFromList = getListFromTensor(self, selfList);
+    LogicalResult otherFromList = getListFromTensor(other, otherList);
+
+    if (failed(selfFromList) && failed(otherFromList))
+      return rewriter.notifyMatchFailure(
+          op, "At least one operand must succeed at constructing a list");
+
+    auto selfLiteral = self.getDefiningOp<Torch::ValueTensorLiteralOp>();
+    auto otherLiteral = other.getDefiningOp<Torch::ValueTensorLiteralOp>();
+    if (succeeded(selfFromList) && otherLiteral &&
+        failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
+      return failure();
+    if (succeeded(otherFromList) && selfLiteral &&
+        failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
+      return failure();
+    if ((int64_t)selfList.size() != selfSize ||
+        (int64_t)otherList.size() != otherSize)
+      // this should only occur if we did not generate IR with
+      // constructListFromLiteral
+      return failure();
+
+    Location loc = op.getLoc();
+    SmallVector<Value> whereVals;
+    auto rank0IntTy = rewriter.getType<Torch::ValueTensorType>(
+        ArrayRef<int64_t>({}), selfTy.getDtype());
+    auto rank0BoolTy = rewriter.getType<Torch::ValueTensorType>(
+        ArrayRef<int64_t>({}), conditionTy.getDtype());
+    for (uint64_t i = 0; i < selfList.size(); i++) {
+      Value rank0Cond = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+          loc, rank0BoolTy, conditionList[i]);
+      Value rank0Self = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+          loc, rank0IntTy, selfList[i]);
+      Value rank0Other = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+          loc, rank0IntTy, otherList[i]);
+      Value rank0Where = rewriter.create<AtenWhereSelfOp>(
+          loc, rank0IntTy, rank0Cond, rank0Self, rank0Other);
+      whereVals.push_back(rewriter.create<AtenItemOp>(
+          loc, rewriter.getType<Torch::IntType>(), rank0Where));
+    }
+    Value list = rewriter.create<Torch::PrimListConstructOp>(
+        op.getLoc(), Torch::ListType::get(whereVals[0].getType()), whereVals);
+    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+        op.getLoc(), rewriter.getBoolAttr(false));
+    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
+        op, op.getType(), list, cstNone, cstNone, cstFalse);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class PropagateAtenEqTensorPattern : public OpRewritePattern<AtenEqTensorOp> {
+public:
+  using OpRewritePattern<AtenEqTensorOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenEqTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    Value self = op.getSelf();
+    Value other = op.getOther();
+    auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
+    if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad self type");
+    auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
+    if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad other type");
+    int64_t selfSize = selfTy.getSizes()[0];
+    int64_t otherSize = otherTy.getSizes()[0];
+
+    if (selfSize != otherSize)
+      return rewriter.notifyMatchFailure(
+          op,
+ " unimplemented: support for propogating with implicit broadcasting." );
+
+    constexpr int64_t kMaxFold = 16;
+    if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold ||
+        otherSize == Torch::kUnknownSize || otherSize > kMaxFold)
+      return rewriter.notifyMatchFailure(op,
+                                         "self or other is dynamic or too big");
+
+    SmallVector<Value> selfList, otherList;
+    // If one of these tensors is a value tensor literal op, we will need to
+    // create constant ints in the IR to form a list. Before calling
+    // constructListFromLiteral, we must be certain that the conversion can no
+    // longer fail, otherwise we will cause an infinite loop of creating a
+    // constant and removing it.
+    LogicalResult selfFromList = getListFromTensor(self, selfList);
+    LogicalResult otherFromList = getListFromTensor(other, otherList);
+
+    if (failed(selfFromList) && failed(otherFromList))
+      return rewriter.notifyMatchFailure(
+          op, "At least one operand must succeed at constructing a list");
+
+    auto selfLiteral = self.getDefiningOp<Torch::ValueTensorLiteralOp>();
+    auto otherLiteral = other.getDefiningOp<Torch::ValueTensorLiteralOp>();
+    if (succeeded(selfFromList) && otherLiteral &&
+        failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
+      return failure();
+    if (succeeded(otherFromList) && selfLiteral &&
+        failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
+      return failure();
+    if ((int64_t)selfList.size() != selfSize ||
+        (int64_t)otherList.size() != otherSize)
+      // this should only occur if we did not generate IR with
+      // constructListFromLiteral
+      return failure();
+
+    SmallVector<Value> eqVals;
+    for (uint64_t i = 0; i < selfList.size(); i++) {
+      eqVals.push_back(
+          rewriter.create<AtenEqIntOp>(op.getLoc(), selfList[i], otherList[i]));
+    }
+    Value list = rewriter.create<Torch::PrimListConstructOp>(
+        op.getLoc(), Torch::ListType::get(eqVals[0].getType()), eqVals);
+    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+        op.getLoc(), rewriter.getBoolAttr(false));
+    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
+        op, op.getType(), list, cstNone, cstNone, cstFalse);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
 public:
@@ -454,6 +643,26 @@ class FoldAtenSqueezePattern : public OpRewritePattern<AtenSqueezeOp> {
 };
 } // namespace
 
+namespace {
+class FoldAtenSqueezeDimPattern : public OpRewritePattern<AtenSqueezeDimOp> {
+public:
+  using OpRewritePattern<AtenSqueezeDimOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSqueezeDimOp op,
+                                PatternRewriter &rewriter) const override {
+    auto resultTy = cast<ValueTensorType>(op.getType());
+    if (!resultTy.hasSizes() || resultTy.getSizes().size() != 0)
+      return rewriter.notifyMatchFailure(op, "Unknown result shape");
+
+    if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
+      rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
+          op, resultTy, atenFull.getFillValue());
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
 namespace {
 class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
 public:
@@ -694,6 +903,8 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
     PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
     FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
     FoldAtenWhereSelf, CanonicalizeAtenViewPattern,
+    PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern,
+    FoldAtenSqueezeDimPattern,
     RemoveUnusedPattern<Torch::AtenIntBoolOp>,
     RemoveUnusedPattern<Torch::AtenEqIntOp>,
     RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,