From e70be44cfc535855a670c46585b8863c7f565d04 Mon Sep 17 00:00:00 2001 From: Andrija Bosnjakovic Date: Fri, 20 Sep 2024 17:14:37 +0200 Subject: [PATCH] add forward+backward test --- .../Torch/Transforms/DecomposeComplexOps.cpp | 94 ++++++++++++++++++- .../test_suite/backprop.py | 38 ++++++++ .../test_suite/elementwise.py | 14 +-- 3 files changed, 135 insertions(+), 11 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index a54399982711..b353e4dc91ee 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3563,7 +3563,7 @@ class DecomposeAtenRreluWithNoiseOp Value noise = op.getNoise(); Value lower = op.getLower(); Value upper = op.getUpper(); - auto resType = cast(self.getType()); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -3609,13 +3609,12 @@ class DecomposeAtenRreluWithNoiseOp rewriter.getI1Type()); Value oneTensor = createRank0Tensor(rewriter, loc, resType, constantOneFloat); - Value not_positive = rewriter.create( + Value not_positive = rewriter.create( loc, boolResType, self, constantZeroFloat); noise = rewriter.create(loc, resType, not_positive, - noise, oneTensor); + alpha, oneTensor); } else { scaledSelf = rewriter.create(loc, resType, self, alpha); - noise = alpha; } Value negativeOutput = @@ -3628,6 +3627,93 @@ class DecomposeAtenRreluWithNoiseOp }; } // namespace +// namespace { +// class DecomposeAtenRreluWithNoiseOp +// : public OpRewritePattern { +// public: +// using OpRewritePattern::OpRewritePattern; +// LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op, +// PatternRewriter &rewriter) const override { +// Location loc = op.getLoc(); +// Value self = op.getSelf(); +// Value noise = op.getNoise(); +// Value lower = op.getLower(); +// Value upper = op.getUpper(); +// auto resType = cast(op.getType()); +// if (!resType.hasDtype()) { +// return rewriter.notifyMatchFailure(op, "result should have dtype"); +// } + +// bool training; +// if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { +// return rewriter.notifyMatchFailure(op, "training should be a +// constant"); +// } + +// Value constantZeroFloat = +// rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); +// Value constantOneFloat = +// rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); +// Value constantTwoFloat = +// rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + +// // Value alpha; +// // if (training) { +// // Value none = rewriter.create(loc); +// // Value emptyTensor = rewriter.create( +// // loc, resType, self, constantZeroFloat, /*dtype=*/none, +// // /*layout=*/none, +// // /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); +// // alpha = rewriter.create(loc, resType, emptyTensor, +// // /*from=*/lower, /*to=*/upper, +// // /*generator=*/none); +// // } else { +// // Value half = rewriter.create(loc, +// constantTwoFloat.getType(), +// // lower, upper); +// // alpha = rewriter.create(loc, constantTwoFloat.getType(), +// half, +// // constantTwoFloat); +// // } + +// Value zeroTensor = +// createRank0Tensor(rewriter, loc, resType, constantZeroFloat); +// Value positiveOutput = +// rewriter.create(loc, resType, zeroTensor, self); + +// Value scaledSelf; +// if (training) { +// scaledSelf = rewriter.create(loc, resType, self, +// noise); auto boolResType = +// resType.getWithSizesAndDtype(resType.getSizes(), +// rewriter.getI1Type()); +// Value oneTensor = +// createRank0Tensor(rewriter, loc, resType, constantOneFloat); +// Value not_positive = rewriter.create( +// loc, boolResType, self, constantZeroFloat); +// noise = rewriter.create(loc, resType, not_positive, +// noise, oneTensor); +// } else { +// Value half = rewriter.create(loc, +// constantTwoFloat.getType(), +// lower, upper); +// Value alpha = rewriter.create(loc, +// constantTwoFloat.getType(), half, +// constantTwoFloat); +// scaledSelf = rewriter.create(loc, resType, self, +// alpha); +// } + +// Value negativeOutput = +// rewriter.create(loc, resType, zeroTensor, scaledSelf); +// Value rreluOutput = rewriter.create( +// loc, resType, positiveOutput, negativeOutput, constantOneFloat); +// rewriter.replaceOp(op, rreluOutput); +// return success(); +// } +// }; +// } // namespace + // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1)) namespace { class DecomposeAtenCeluOp : public OpRewritePattern { diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py index 70125ff074a0..73cbbfc3bb10 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -448,3 +448,41 @@ def forward(self, grad, input, noise): @register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalStaticModule()) def RreluWithNoiseBackwardEvalStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +class RreluWithNoiseForwardBackwardModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + torch.ops.aten.rrelu_with_noise( + input, noise, lower=0.4, upper=0.6, training=True + ) + res = torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.4, + upper=0.6, + training=True, + self_is_result=False, + ) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule()) +def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(256, 244), + tu.rand(256, 244, low=-1.0, high=1.0), + tu.rand(256, 244, low=0.4, high=0.6), + ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 15f8ca97b6bc..9ec36ad3d279 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1188,13 +1188,13 @@ def __init__(self): [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] ) def forward(self, x, noise): - res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True) + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True) return torch.mean(res), torch.std(res) @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule()) def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils): - module.forward(tu.rand(1024, 1536), torch.zeros((1024, 1536))) + module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128)) # ============================================================================== @@ -1206,16 +1206,16 @@ def __init__(self): @export @annotate_args( - [None, ([1024, 1536], torch.float32, True), ([1024, 1536], torch.float32, True)] + [None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)] ) def forward(self, x, noise): - res = torch.ops.aten.rrelu_with_noise(x, noise, 0.1, 0.9, True) + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True) return torch.mean(res), torch.std(res) @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule()) def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(1024, 1536), torch.zeros((1024, 1536))) + module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128)) # ============================================================================== @@ -1236,7 +1236,7 @@ def forward(self, x, noise): @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalModule()) def ElementwiseRreluWithNoiseEvalModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 3, low=-1, high=1), torch.zeros((5, 3))) + module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3)) # ============================================================================== @@ -1255,7 +1255,7 @@ def forward(self, x, noise): @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalStaticModule()) def ElementwiseRreluWithNoiseEvalStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 3, low=-1, high=1), torch.zeros((5, 3))) + module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3)) # ==============================================================================