diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 1e0ed43bbf2e..6ce28c876aaf 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -43,4 +43,5 @@ "BoolTensorReturnFalseModule_basic", "BoolTensorReturnTrueModule_basic", "BoolTensorReturnMixedModule_basic", + "ElementwiseRsqrtModule_basic", } diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 3320e4b632b2..a7097fe70926 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -442,6 +442,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { patterns.add>(typeConverter, context); INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) + INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) #undef INSERT_UNARY_PATTERN diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 0a5ab6ca7ff3..893c6c98db73 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -285,3 +285,16 @@ func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtens return %0 : !torch.vtensor<[1],i1> } +// ----- + +// CHECK-LABEL: func @torch.aten.rsqrt$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = "tosa.rsqrt"(%[[VAL_1]]) : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +}