From f34eb6612415fb64d4e3812a1e99b0f36a2faa6c Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Thu, 9 Dec 2021 01:58:50 +0530
Subject: [PATCH] [TORCH][MLIR] Add E2E support for [`aten.gt.Scalar`|`aten.where.self`]

This commit adds lowering of `aten.gt.Scalar` and `aten.where.self` as
part of the element-wise ops lowering.

Signed-off-by: Gaurav Shukla
---
 e2e_testing/torchscript/elementwise.py        | 43 +++++++++++++++++++
 .../Dialect/Torch/IR/GeneratedAtenOps.td      | 16 +++++++
 .../TorchToLinalg/TorchToLinalg.cpp           | 26 ++++++++++-
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  | 43 +++++++++++++++++--
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 5 files changed, 124 insertions(+), 5 deletions(-)

diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index af973429994..9bb5ac7c11b 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -85,6 +85,29 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseWhereSelfModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, a, b, c):
+        return torch.where(a > 0.5, b, c)
+
+
+@register_test_case(module_factory=lambda: ElementwiseWhereSelfModule())
+def ElementwiseWhereSelfModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5), tu.rand(4, 5), tu.rand(5))
+
+
+# ==============================================================================
+
+
 # Addition is an interesting special case of a binary op, because under the hood
 # it carries a third scalar "alpha" parameter, which needs special handling.
 class ElementwiseAddModule(torch.nn.Module):
@@ -303,6 +326,26 @@ def ElementwiseMaximumModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseGtScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.gt(x, 0.6)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGtScalarModule())
+def ElementwiseGtScalarModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5))
+
+# ==============================================================================
+
+
 class ElementwiseClampModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index d7858a60628..380d07e26fd 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1071,6 +1071,22 @@ def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
   let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
 }
 
+def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$condition,
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$condition `,` $self `,` $other attr-dict `:` type($condition) `,` type($self) `,` type($other) `->` type($result)";
+}
+
 def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 322405ea036..7234178c50c 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1684,6 +1684,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
     return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
   }
+
+  if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
+    Type dtype = gtScalar.self().getType().cast<BaseTensorType>().getDtype();
+    if (!dtype.isa<mlir::FloatType>()) {
+      gtScalar.emitError("unimplemented: non-floating point operand dtype");
+      return nullptr;
+    }
+    Value otherPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
+    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
+                                   payloadArgs[0], otherPromoted);
+  }
+
+  if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
+    Type dtype = converter->convertType(whereSelf.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
+    return b.create<SelectOp>(loc, payloadArgs[0], lhs, rhs);
+  }
+
   if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
     if (!lerp.getType()
              .cast<ValueTensorType>()
@@ -2040,7 +2061,7 @@ struct ConvertElementwiseOp : ConversionPattern {
             AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
             AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
             AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
-            AtenBitwiseAndTensorOp>(op))
+            AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -3461,7 +3482,8 @@ class ConvertTorchToLinalg
         AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
         AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
         AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
-        AtenReciprocalOp, AtenBitwiseAndTensorOp>();
+        AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
+        AtenWhereSelfOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index b61a3e145cd..2d2b0e94105 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -235,9 +235,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
                  ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
     if (isa {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
+    // These comparison ops return a tensor with 1-bit integer dtype.
+    if (isa<AtenGtScalarOp>(op)) {
+      auto operand = operands[0]->getValue();
+      auto knowledge =
+          ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+      if (operand.hasSizes) {
+        knowledge.hasSizes = true;
+        knowledge.sizes = operand.sizes;
+      }
+      knowledge.dtype = IntegerType::get(op->getContext(), 1);
+      return getLatticeElement(op->getResult(0)).join(knowledge);
+    }
+
     // Resize to [1, 1] with integer dtype.
     if (isa(op)) {
       auto input = operands[0]->getValue();
@@ -307,6 +320,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
                    AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
                    AtenMinimumOp, AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
       return visitBinaryBroadcastingOp(op, operands);
+    } else if (auto whereSelf = llvm::dyn_cast<AtenWhereSelfOp>(op)) {
+      return visitAtenWhereSelfOp(whereSelf, operands);
     } else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
       return visitAtenLerpTensorOp(lerpTensor, operands);
     } else if (auto flatten = dyn_cast<AtenFlattenUsingIntsOp>(op)) {
@@ -487,6 +502,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   ChangeResult visitBinaryBroadcastingOp(
       Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult
+  visitAtenWhereSelfOp(AtenWhereSelfOp op,
+                       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
   visitAtenLerpTensorOp(AtenLerpTensorOp op,
                         ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult visitAtenFlattenUsingIntsOp(
@@ -856,6 +874,25 @@ ChangeResult TypeAnalyzer::visitBinaryBroadcastingOp(
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitAtenWhereSelfOp(
+    AtenWhereSelfOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto condition = operands[0]->getValue();
+  auto lhs = operands[1]->getValue();
+  auto rhs = operands[2]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(getContext());
+  if (condition.hasSizes && lhs.hasSizes && rhs.hasSizes) {
+    knowledge.hasSizes = true;
+    knowledge.sizes.resize(
+        std::max(condition.sizes.size(),
+                 std::max(lhs.sizes.size(), rhs.sizes.size())),
+        kUnknownSize);
+  }
+
+  knowledge.dtype = getPromotedResultType(getContext(), {&lhs, &rhs});
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
 ChangeResult TypeAnalyzer::visitAtenLerpTensorOp(
     AtenLerpTensorOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   // This is a general broadcasting shape transfer function.
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 67f4dd8a1b8..1db64c2824c 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -479,6 +479,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
     emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
     emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
     emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
     emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
    emit("aten::gelu : (Tensor) -> (Tensor)")
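
Note (not part of the patch): for reviewers unfamiliar with these two ops, the eager PyTorch behavior exercised by the new e2e tests is sketched below. The torch-mlir e2e harness checks the compiled module against eager PyTorch, so this is the behavior the lowerings above must reproduce; the shapes mirror the test cases, everything else is plain PyTorch API.

```python
import torch

# `aten.gt.Scalar`: elementwise greater-than against a Python scalar.
# The result has the same shape as the input and dtype torch.bool,
# which corresponds to the i1 dtype assigned in the RefineTypes change.
x = torch.rand(3, 5)
mask = torch.gt(x, 0.6)
assert mask.shape == (3, 5) and mask.dtype == torch.bool

# `aten.where.self`: ternary select that broadcasts all three operands.
# With the shapes used in ElementwiseWhereSelfModule, the result
# broadcasts to (3, 4, 5); `b` and `c` are promoted to a common dtype,
# matching the convertScalarToDtype calls in the linalg payload.
a = torch.rand(3, 4, 5)
b = torch.rand(4, 5)
c = torch.rand(5)
out = torch.where(a > 0.5, b, c)
assert out.shape == (3, 4, 5)
```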