Skip to content

Commit

Permalink
* [tosa] Support for Rsqrt legalization (#480)
Browse files Browse the repository at this point in the history
Signed-off-by: Anup Gangwar <[email protected]>

Co-authored-by: Anup Gangwar <[email protected]>
  • Loading branch information
anupgangwar and Anup Gangwar committed Dec 14, 2021
1 parent 6dabf18 commit cce490d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 0 deletions.
1 change: 1 addition & 0 deletions e2e_testing/torchscript/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@
"BoolTensorReturnFalseModule_basic",
"BoolTensorReturnTrueModule_basic",
"BoolTensorReturnMixedModule_basic",
"ElementwiseRsqrtModule_basic",
}
1 change: 1 addition & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp)
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
#undef INSERT_UNARY_PATTERN

Expand Down
13 changes: 13 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,16 @@ func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtens
return %0 : !torch.vtensor<[1],i1>
}

// -----

// CHECK-LABEL: func @torch.aten.rsqrt$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.rsqrt"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}

0 comments on commit cce490d

Please sign in to comment.