Add leakyrelu support
ljfitz authored and Prashant Kumar committed Nov 27, 2021
1 parent f461a7e commit 7616d28
Showing 5 changed files with 73 additions and 4 deletions.
18 changes: 18 additions & 0 deletions e2e_testing/torchscript/elementwise.py
@@ -196,6 +196,24 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 2) - 0.5)

# ==============================================================================
class ElementwiseLeakyReluModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.leaky_relu(x, negative_slope=0.1)


@register_test_case(module_factory=lambda: ElementwiseLeakyReluModule())
def ElementwiseLeakyReluModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 2) - 0.5)

# ==============================================================================


class ElementwiseGeluModule(torch.nn.Module):
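For orientation, here is a minimal eager-mode sketch of the behavior the new test case checks; this snippet is illustrative only and is not part of the e2e harness or of this commit.

import torch

# Same input distribution as the test: values in [-0.5, 0.5), so both the
# positive and the negative branch of leaky_relu are exercised.
x = torch.rand(4, 2) - 0.5
expected = torch.where(x > 0, x, 0.1 * x)   # keep positives, scale negatives
actual = torch.ops.aten.leaky_relu(x, 0.1)  # negative_slope = 0.1, as in the test
assert torch.allclose(actual, expected)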
30 changes: 30 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -72,6 +72,36 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}

def Torch_AtenLeakyReluOp : Torch_Op<"aten.leaky_relu", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::leaky_relu : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$negative_slope
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $negative_slope attr-dict `:` type($self) `,` type($negative_slope) `->` type($result)";
}

def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::leaky_relu_ : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$negative_slope
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $negative_slope attr-dict `:` type($self) `,` type($negative_slope) `->` type($result)";
}

def Torch_AtenLogOp : Torch_Op<"aten.log", [
AllowsTypeRefinement,
HasValueSemantics
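For reference, the two generated definitions mirror PyTorch's out-of-place and in-place variants. A small eager-mode sketch (illustrative only, not part of this change) of the distinction the IsTrailingUnderscoreInplaceVariant trait encodes:

import torch

x = torch.randn(3)
y = torch.ops.aten.leaky_relu(x, 0.1)   # value-semantic: returns a fresh tensor
z = torch.ops.aten.leaky_relu_(x, 0.1)  # trailing underscore: mutates x in place
assert z.data_ptr() == x.data_ptr()
assert y.data_ptr() != x.data_ptr()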
26 changes: 23 additions & 3 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -1378,11 +1379,30 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
Type elementType = payloadArgs[0].getType();
Value constZero =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], constZero);
return b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
}
if (auto lrelu = dyn_cast<AtenLeakyReluOp>(op)) {
if (!lrelu.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
lrelu.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Type elementType = payloadArgs[0].getType();
Value constZero =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], constZero);
Value positivePart = b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
Value negativePart = b.create<SelectOp>(loc, pred, constZero, payloadArgs[0]);
Value scale = convertScalarToDtype(b, loc, operands[1], elementType);
Value scaledNegativePart = b.create<arith::MulFOp>(loc, negativePart, scale);
return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
}
if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
if (!gelu.getType()
.cast<ValueTensorType>()
@@ -1812,7 +1832,7 @@ struct ConvertElementwiseOp : ConversionPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
@@ -2969,7 +2989,7 @@ class ConvertTorchToLinalg
target.addIllegalOp<AtenBatchNormOp>();
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
target.addIllegalOp<
AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
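The new element-wise payload added above builds leaky_relu out of one compare, two selects, a scalar multiply, and an add. Below is a NumPy sketch of the same per-element computation; the function name and example values are mine, for illustration only, not from the patch.

import numpy as np

def leaky_relu_payload(x, negative_slope):
    pred = x > 0.0                          # arith.cmpf ugt x, 0
    positive_part = np.where(pred, x, 0.0)  # select(pred, x, 0)
    negative_part = np.where(pred, 0.0, x)  # select(pred, 0, x)
    return positive_part + negative_slope * negative_part  # mulf then addf

print(leaky_relu_payload(np.array([-2.0, -0.5, 0.0, 3.0]), 0.1))
# -> approximately [-0.2, -0.05, 0.0, 3.0]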
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -289,7 +289,7 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
} else if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp,
AtenDivScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp,
AtenPowTensorScalarOp, AtenRsubScalarOp>(op)) {
AtenPowTensorScalarOp, AtenRsubScalarOp, AtenLeakyReluOp>(op)) {
return visitBinaryTensorScalarOp(op, operands);
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
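Routing aten.leaky_relu through visitBinaryTensorScalarOp models it like the other tensor-scalar ops listed above. As an eager-mode reference point for the op's behavior (a sketch, not part of the patch), the result keeps the tensor operand's shape and floating-point dtype even when the scalar is an integer:

import torch

x = torch.randn(4, 2, dtype=torch.float32)
y = torch.ops.aten.leaky_relu(x, 2)  # integer scalar for negative_slope
assert y.shape == x.shape and y.dtype == torch.float32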
1 change: 1 addition & 0 deletions python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -439,6 +439,7 @@ def emit_with_mutating_variants(key, **kwargs):
for key in [
"aten::tanh : (Tensor) -> (Tensor)",
"aten::relu : (Tensor) -> (Tensor)",
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
"aten::log : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",
"aten::sin : (Tensor) -> (Tensor)",
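The single registry entry added above is what produces both Torch_AtenLeakyReluOp and Torch_AtenLeakyRelu_Op in GeneratedAtenOps.td. A hypothetical sketch of that naming correspondence follows; this helper is mine and is not the generator's actual code.

def ods_names(key: str):
    op_name = key.split(" : ")[0]                  # "aten::leaky_relu"
    base = op_name.split("::")[1]                  # "leaky_relu"
    camel = "".join(p.capitalize() for p in base.split("_"))  # "LeakyRelu"
    # One value-semantic op plus the trailing-underscore in-place variant.
    return f"Torch_Aten{camel}Op", f"Torch_Aten{camel}_Op"

print(ods_names("aten::leaky_relu : (Tensor, Scalar) -> (Tensor)"))
# ('Torch_AtenLeakyReluOp', 'Torch_AtenLeakyRelu_Op')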
