From be5f327764194e15a48bf62b4ff1cfbe26b6f444 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Mon, 22 Apr 2024 05:09:06 +0000 Subject: [PATCH 1/4] support dynamic dims in DecomposeAtenEyeMOp --- .../Torch/Transforms/DecomposeComplexOps.cpp | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 87f93ba9c555..115a7fb5c42e 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1059,16 +1059,6 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenEyeMOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - int64_t n; - - if (!matchPattern(op.getN(), m_TorchConstantInt(&n))) - return rewriter.notifyMatchFailure(op, - "unimplemented: n must be constant"); - int64_t m; - if (!matchPattern(op.getM(), m_TorchConstantInt(&m))) - return rewriter.notifyMatchFailure(op, - "unimplemented: m must be constant"); - Value none = rewriter.create(loc); auto outType = op.getType().dyn_cast(); if (!outType) return rewriter.notifyMatchFailure( @@ -1076,27 +1066,39 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { if (!outType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } - if (n < 0) { - return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0"); - } - if (m < 0) { - return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0"); - } - + Value none = rewriter.create(loc); auto context = op.getContext(); auto int64Dtype = getDtypeIntValueForType( rewriter, loc, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); - auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type); + + int64_t n; + Type rangeNType; + if (matchPattern(op.getN(), m_TorchConstantInt(&n))) { + rangeNType = outType.getWithSizesAndDtype(std::nullopt, si64Type); + } else { + if (n < 0) + return rewriter.notifyMatchFailure(op, + "n must be greater or equal to 0"); + rangeNType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type); + } Value rangeN = rewriter.create( - loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none, + loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/op.getDevice(), /*pin_memory=*/none); - auto arangeType1 = - outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type); + int64_t m; + Type rangeMType; + if (matchPattern(op.getM(), m_TorchConstantInt(&m))) { + rangeMType = outType.getWithSizesAndDtype(std::nullopt, si64Type); + } else { + if (m < 0) + return rewriter.notifyMatchFailure(op, + "m must be greater or equal to 0"); + rangeMType = outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type); + } Value rangeM = rewriter.create( - loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none, + loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); Value constMinusOne = rewriter.create( From 9e0c455e2af4db412c60354b3f613266d4a44929 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Mon, 22 Apr 2024 19:50:45 +0000 Subject: [PATCH 2/4] prioritize getting shapes from output shapes and use kUnknown when no shape is available --- .../Torch/Transforms/DecomposeComplexOps.cpp | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 115a7fb5c42e..a253cbf02b0e 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1073,30 +1073,28 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); - int64_t n; - Type rangeNType; - if (matchPattern(op.getN(), m_TorchConstantInt(&n))) { - rangeNType = outType.getWithSizesAndDtype(std::nullopt, si64Type); - } else { - if (n < 0) - return rewriter.notifyMatchFailure(op, - "n must be greater or equal to 0"); - rangeNType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type); - } + int64_t n = kUnknownSize; + int64_t m = kUnknownSize; + // prioritize getting shape from output shape + if (outType.hasSizes()) { + n = outType.getSizes().front(); + m = outType.getSizes().back(); + } + // if output shape is not available, try to get shape from input + if (n == kUnknownSize) { + matchPattern(op.getN(), m_TorchConstantInt(&n)); + } + if (m == kUnknownSize) { + matchPattern(op.getM(), m_TorchConstantInt(&m)); + } + Type rangeNType = outType.getWithSizesAndDtype( + n == kUnknownSize ? std::nullopt : llvm::ArrayRef(n), si64Type); Value rangeN = rewriter.create( loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/op.getDevice(), /*pin_memory=*/none); - int64_t m; - Type rangeMType; - if (matchPattern(op.getM(), m_TorchConstantInt(&m))) { - rangeMType = outType.getWithSizesAndDtype(std::nullopt, si64Type); - } else { - if (m < 0) - return rewriter.notifyMatchFailure(op, - "m must be greater or equal to 0"); - rangeMType = outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type); - } + Type rangeMType = outType.getWithSizesAndDtype( + m == kUnknownSize ? std::nullopt : llvm::ArrayRef(m), si64Type); Value rangeM = rewriter.create( loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); From 9738ec463b2094614ef3ae6ed201174220cd08ce Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 26 Apr 2024 01:19:19 +0000 Subject: [PATCH 3/4] address review comments --- .../Torch/Transforms/DecomposeComplexOps.cpp | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index a253cbf02b0e..dc8ea08affab 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1075,26 +1075,31 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { int64_t n = kUnknownSize; int64_t m = kUnknownSize; + // prioritize getting shape from output shape - if (outType.hasSizes()) { + if (outType.hasSizes() && outType.getSizes().size() == 2) { n = outType.getSizes().front(); m = outType.getSizes().back(); } // if output shape is not available, try to get shape from input - if (n == kUnknownSize) { + if (n == kUnknownSize) matchPattern(op.getN(), m_TorchConstantInt(&n)); - } - if (m == kUnknownSize) { + if (m == kUnknownSize) matchPattern(op.getM(), m_TorchConstantInt(&m)); - } - Type rangeNType = outType.getWithSizesAndDtype( - n == kUnknownSize ? std::nullopt : llvm::ArrayRef(n), si64Type); + + // prepare two unsqueezed ranges that are equal on and only on the diagonal + std::optional> rangeNSize; + if (n != kUnknownSize) + rangeNSize = {n}; + Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type); Value rangeN = rewriter.create( loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/op.getDevice(), /*pin_memory=*/none); - Type rangeMType = outType.getWithSizesAndDtype( - m == kUnknownSize ? std::nullopt : llvm::ArrayRef(m), si64Type); + std::optional> rangeMSize; + if (m != kUnknownSize) + rangeMSize = {m}; + Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type); Value rangeM = rewriter.create( loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); From 5ac730edc31fd29110fbe0efb807b56eb6c5cb13 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 26 Apr 2024 01:53:16 +0000 Subject: [PATCH 4/4] fix smallvector implicit list initializer --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index dc8ea08affab..9fff06c456ed 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1075,7 +1075,6 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { int64_t n = kUnknownSize; int64_t m = kUnknownSize; - // prioritize getting shape from output shape if (outType.hasSizes() && outType.getSizes().size() == 2) { n = outType.getSizes().front(); @@ -1088,17 +1087,13 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { matchPattern(op.getM(), m_TorchConstantInt(&m)); // prepare two unsqueezed ranges that are equal on and only on the diagonal - std::optional> rangeNSize; - if (n != kUnknownSize) - rangeNSize = {n}; + auto rangeNSize = llvm::SmallVector({n}); Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type); Value rangeN = rewriter.create( loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/op.getDevice(), /*pin_memory=*/none); - std::optional> rangeMSize; - if (m != kUnknownSize) - rangeMSize = {m}; + auto rangeMSize = llvm::SmallVector({m}); Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type); Value rangeM = rewriter.create( loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none, @@ -1114,7 +1109,6 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { } Value unsqzRangeN = *unsqzTensorInfo; - // compare unsqueezed input with boundaries auto eqType = ValueTensorType::get( context, op.getType().cast().getSizes(), IntegerType::get(context, 1));