diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index a2f25b54e63..908dea5280a 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -196,6 +196,24 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 2) - 0.5)
 
 # ==============================================================================
+class ElementwiseLeakyReluModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.leaky_relu(x, negative_slope=0.1)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLeakyReluModule())
+def ElementwiseLeakyReluModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 2) - 0.5)
+
+# ==============================================================================
 class ElementwiseGeluModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index f3d4d0c6800..2f776353caa 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -72,6 +72,36 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [
   let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
 }
 
+def Torch_AtenLeakyReluOp : Torch_Op<"aten.leaky_relu", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::leaky_relu : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$negative_slope
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $negative_slope attr-dict `:` type($self) `,` type($negative_slope) `->` type($result)";
+}
+
+def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::leaky_relu_ : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$negative_slope
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $negative_slope attr-dict `:` type($self) `,` type($negative_slope) `->` type($result)";
+}
+
 def Torch_AtenLogOp : Torch_Op<"aten.log", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 47dcade816f..93f0dfa5513 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -10,6 +10,7 @@
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
 
 #include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -1378,11 +1379,30 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
     Type elementType = payloadArgs[0].getType();
     Value constZero =
-        b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
+        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
     Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                          payloadArgs[0], constZero);
     return b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
   }
+  if (auto lrelu = dyn_cast<AtenLeakyReluOp>(op)) {
+    if (!lrelu.getType()
+             .cast<ValueTensorType>()
+             .getDtype()
+             .isa<mlir::FloatType>()) {
+      lrelu.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    Type elementType = payloadArgs[0].getType();
+    Value constZero =
+        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
+    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
+                                         payloadArgs[0], constZero);
+    Value positivePart = b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
+    Value negativePart = b.create<SelectOp>(loc, pred, constZero, payloadArgs[0]);
+    Value scale = convertScalarToDtype(b, loc, operands[1], elementType);
+    Value scaledNegativePart = b.create<arith::MulFOp>(loc, negativePart, scale);
+    return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
+  }
   if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
     if (!gelu.getType()
              .cast<ValueTensorType>()
@@ -1812,7 +1832,7 @@ struct ConvertElementwiseOp : ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
+    if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
@@ ... @@
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
     target.addIllegalOp<
-        AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
+        AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
         AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
         AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
         AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index e2079609e47..8036ff01d2a 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -289,7 +289,7 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
     } else if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp,
                    AtenDivScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp,
-                   AtenPowTensorScalarOp, AtenRsubScalarOp>(op)) {
+                   AtenPowTensorScalarOp, AtenRsubScalarOp, AtenLeakyReluOp>(op)) {
       return visitBinaryTensorScalarOp(op, operands);
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ ... @@
         "aten::tanh : (Tensor) -> (Tensor)",
         "aten::relu : (Tensor) -> (Tensor)",
+        "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
         "aten::log : (Tensor) -> (Tensor)",
         "aten::sigmoid : (Tensor) -> (Tensor)",
         "aten::sin : (Tensor) -> (Tensor)",
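
Note on the lowering above: the new case in createLinalgPayloadCalculationForElementwiseOp
computes leaky_relu as select(x > 0, x, 0) + negative_slope * select(x > 0, 0, x), i.e. the
positive part plus the scaled negative part. Below is a minimal PyTorch sketch of that same
formula, not part of the change itself; the helper name leaky_relu_reference is made up here
for illustration, and it simply mirrors the two selects emitted in the Linalg payload:

    import torch

    def leaky_relu_reference(x: torch.Tensor, negative_slope: float = 0.1) -> torch.Tensor:
        # Positive part plus scaled negative part, mirroring the two select
        # ops emitted in the Linalg payload.
        zero = torch.zeros_like(x)
        positive_part = torch.where(x > 0, x, zero)  # select(pred, x, 0)
        negative_part = torch.where(x > 0, zero, x)  # select(pred, 0, x)
        return positive_part + negative_slope * negative_part

    x = torch.rand(4, 2) - 0.5  # same input distribution as the e2e test
    assert torch.allclose(leaky_relu_reference(x, 0.1),
                          torch.ops.aten.leaky_relu(x, 0.1))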