From cce490d71d79d41c7652d4d1e8d2d36d7c302a37 Mon Sep 17 00:00:00 2001 From: Anup Gangwar Date: Tue, 14 Dec 2021 12:03:58 -0600 Subject: [PATCH] * [tosa] Support for Rsqrt legalization (#480) Signed-off-by: Anup Gangwar Co-authored-by: Anup Gangwar --- e2e_testing/torchscript/xfail_sets.py | 1 + lib/Conversion/TorchToTosa/TorchToTosa.cpp | 1 + test/Conversion/TorchToTosa/basic.mlir | 13 +++++++++++++ 3 files changed, 15 insertions(+) diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 1e0ed43bbf2..6ce28c876aa 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -43,4 +43,5 @@ "BoolTensorReturnFalseModule_basic", "BoolTensorReturnTrueModule_basic", "BoolTensorReturnMixedModule_basic", + "ElementwiseRsqrtModule_basic", } diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 3320e4b632b..a7097fe7092 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -442,6 +442,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { patterns.add>(typeConverter, context); INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) + INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) #undef INSERT_UNARY_PATTERN diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 0a5ab6ca7ff..893c6c98db7 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -285,3 +285,16 @@ func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtens return %0 : !torch.vtensor<[1],i1> } +// ----- + +// CHECK-LABEL: func @torch.aten.rsqrt$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = "tosa.rsqrt"(%[[VAL_1]]) : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +}