Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TorchToLinalg] Implement lowering of torch.aten.rrelu_with_noise and torch.aten.rrelu_with_noise_backward ops #3645

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,61 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
}];
}

def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_NonValueTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -16479,6 +16534,35 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
}];
}

def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
AnyTorchTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
Torch_BoolType:$self_is_result
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
57 changes: 57 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6683,6 +6683,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -7277,6 +7281,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -11877,6 +11885,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down Expand Up @@ -12063,6 +12079,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %true = torch.constant.bool true\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" }\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" }\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
Expand Down
132 changes: 132 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3399,6 +3399,59 @@ class DecomposeAtenLeakyReluBackwardOp
};
} // namespace

namespace {
class DecomposeAtenRreluWithNoiseBackwardOp
: public OpRewritePattern<AtenRreluWithNoiseBackwardOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value gradOutput = op.getGradOutput();
Value self = op.getSelf();
Value noise = op.getNoise();
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}

bool training;
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
return rewriter.notifyMatchFailure(op,
"training should be a bool constant");
}

bool selfIsResult = false;
if (!matchPattern(op.getSelfIsResult(),
m_TorchConstantBool(&selfIsResult)) ||
selfIsResult)
return rewriter.notifyMatchFailure(
op, "unimplemented: self_is_result should be false");

double lower, upper;
if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) ||
!matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) {
return rewriter.notifyMatchFailure(
op, "lower and upper should be float constants");
}

if (training && (upper - lower > 0.000001)) {
Value rreluWithNoiseBackwardOutput =
rewriter.create<AtenMulTensorOp>(loc, resType, gradOutput, noise);
rewriter.replaceOp(op, rreluWithNoiseBackwardOutput);
} else {
double negative_slope = (upper + lower) / 2;
Value cstNegativeSlope = rewriter.create<ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(negative_slope));
rewriter.replaceOpWithNewOp<AtenLeakyReluBackwardOp>(
op, resType, gradOutput, self, cstNegativeSlope,
op.getSelfIsResult());
}
return success();
}
};
} // namespace

namespace {
class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {
public:
Expand Down Expand Up @@ -3498,6 +3551,82 @@ class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
};
} // namespace

namespace {
class DecomposeAtenRreluWithNoiseOp
: public OpRewritePattern<AtenRreluWithNoiseOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value noise = op.getNoise();
Value lower = op.getLower();
Value upper = op.getUpper();
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}

bool training;
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
return rewriter.notifyMatchFailure(op, "training should be a constant");
}

Value constantZeroFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value constantOneFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value constantTwoFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));

Value alpha;
if (training) {
Value none = rewriter.create<ConstantNoneOp>(loc);
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
loc, resType, self, constantZeroFloat, /*dtype=*/none,
/*layout=*/none,
/*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
/*from=*/lower, /*to=*/upper,
/*generator=*/none);
} else {
Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
lower, upper);
alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(), half,
constantTwoFloat);
}

Value zeroTensor =
createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
Value positiveOutput =
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);

Value scaledSelf;
if (training) {
scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, alpha);
auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(),
rewriter.getI1Type());
Value oneTensor =
createRank0Tensor(rewriter, loc, resType, constantOneFloat);
Value not_positive = rewriter.create<AtenLtScalarOp>(
loc, boolResType, self, constantZeroFloat);
noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
alpha, oneTensor);
} else {
scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
}

Value negativeOutput =
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
Value rreluOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, positiveOutput, negativeOutput, constantOneFloat);
rewriter.replaceOp(op, rreluOutput);
return success();
}
};
} // namespace

// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
namespace {
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
Expand Down Expand Up @@ -9474,6 +9603,9 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenPadOp>();
target.addIllegalOp<AtenPreluOp>();
target.addIllegalOp<AtenRreluOp>();
target.addIllegalOp<AtenRreluWithNoiseOp>();
target.addIllegalOp<AtenRreluWithNoiseBackwardOp>();
target.addIllegalOp<AtenCeluOp>();
target.addIllegalOp<AtenToDtypeLayoutOp>();
target.addIllegalOp<AtenToDeviceOp>();
Expand Down
Loading
Loading