From 01b49439072c63e038dc19dca8666a5146b265ab Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Fri, 26 Apr 2024 01:58:38 +0800 Subject: [PATCH 1/4] support several math ops --- lib/Conversion/TorchToStablehlo/Basic.cpp | 43 +++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 1858b1a6d7ca..cebd1d9023a5 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -960,6 +960,43 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenPowScalarOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenPowScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value lhs = adaptor.getSelf(); + auto lhsType = lhs.getType().dyn_cast(); + Value rhs = adaptor.getExponent(); + auto rhsType = rhs.getType().dyn_cast(); + + if (!rhsType) + return op.emitError("only Tensor types supported in StableHLO"); + + auto outType = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + Type outElemTy = outType.getElementType(); + if (!outElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } + + if (!lhsType) { + lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy); + } + DenseI64ArrayAttr bcastDimensions; + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); + auto loc = op.getLoc(); + Value result = rewriter.create(loc, outType, lhs, rhs, + bcastDimensions); + + rewriter.replaceOp(op, result); + return success(); +} + // PrimNumToTensorScalarOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -2020,11 +2057,15 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanhOp, stablehlo::TanhOp); 
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp); INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanOp, chlo::TanOp); INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp); INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinhOp, chlo::SinhOp); INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp); INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCoshOp, chlo::CoshOp); INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinhOp, chlo::AsinhOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcoshOp, chlo::AcoshOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanhOp, chlo::AtanhOp); #undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ @@ -2107,6 +2148,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenTensorIntOp); INSERT_ATENOP_PATTERN(AtenReciprocalOp); INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp); + INSERT_ATENOP_PATTERN(AtenPowScalarOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(AtenScalarImplicitOp); INSERT_ATENOP_PATTERN(AtenContiguousOp); @@ -2149,5 +2191,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseAndTensorOp, chlo::BroadcastAndOp); INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseOrTensorOp, chlo::BroadcastOrOp); INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseXorTensorOp, chlo::BroadcastXorOp); + INSERT_BINARY_BROADCAST_PATTERN(AtenAtan2Op, chlo::BroadcastAtan2Op); #undef INSERT_BINARY_BROADCAST_PATTERN } From 4924a6d8624f0514a6e06243066ecaaaf30a31ed Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Thu, 25 Apr 2024 18:18:39 +0000 Subject: [PATCH 2/4] finish --- projects/pt1/e2e_testing/xfail_sets.py | 19 +++-- .../test_suite/elementwise.py | 70 +++++++++++++++++++ 2 files changed, 82 insertions(+), 7 deletions(-) diff --git 
a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index c3d9d0dfeb09..70000eda3174 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -622,16 +622,10 @@ "DiagonalModule_with_offset", "DivFloatModule_basic", "DivIntModule_basic", - "ElementwiseAcoshIntModule_basic", - "ElementwiseAcoshModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - "ElementwiseAsinhIntModule_basic", - "ElementwiseAsinhModule_basic", "ElementwiseAtan2FloatIntModule_basic", "ElementwiseAtan2TensorFloatModule_basic", "ElementwiseAtan2TensorIntModule_basic", - "ElementwiseAtanhIntModule_basic", - "ElementwiseAtanhModule_basic", "ElementwiseBitwiseLeftShiftInt32Module_basic", "ElementwiseBitwiseLeftShiftInt64Module_basic", "ElementwiseBitwiseLeftShiftInt8Module_basic", @@ -643,7 +637,6 @@ "ElementwiseErfIntModule_basic", "ElementwiseLogitModule_basic", "ElementwiseMulTensorComplexModule_basic", - "ElementwisePowScalarModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", @@ -992,6 +985,15 @@ "DropoutEvalIntModule_basic", "ElementwiseAbsFloatModule_basic", "ElementwiseAbsIntModule_basic", + "ElementwiseAcoshIntModule_basic", + "ElementwiseAcoshModule_basic", + "ElementwiseAsinhIntModule_basic", + "ElementwiseAsinhModule_basic", + "ElementwiseAtanhIntModule_basic", + "ElementwiseAtanhModule_basic", + "ElementwiseAtan2TensorFloatStaticModule_basic", + "ElementwiseAtan2TensorIntStaticModule_basic", + "ElementwiseAtan2FloatIntStaticModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", "ElementwiseAtenIsinfOpModule_basic", @@ -1053,6 +1055,7 @@ "ElementwiseNegModule_basic", "ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseAndScalarStaticShapeModule_basic", + "ElementwisePowScalarModule_basic", 
"ElementwisePowTensorBroadcastStaticModule_basic", "ElementwisePowTensorStaticModule_basic", "ElementwisePreluStaticModule_basic", @@ -1065,6 +1068,8 @@ "ElementwiseSigmoidModule_basic", "ElementwiseSinModule_basic", "ElementwiseSqrtModule_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseTanModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeI64ToI8Module_basic", "ElementwiseToDtypeIdentityModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 3aa8f10ff9dd..cbd2868b71d6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1792,6 +1792,28 @@ def ElementwiseAtan2TensorFloatModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAtan2TensorFloatStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([4, 5, 6], torch.float32, True), + ([4, 5, 6], torch.float32, True), + ]) + def forward(self, a, b): + return torch.atan2(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseAtan2TensorFloatStaticModule()) +def ElementwiseAtan2TensorFloatStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 5, 6), tu.rand(4, 5, 6)) + + +# ============================================================================== + class ElementwiseAtan2TensorIntModule(torch.nn.Module): def __init__(self): @@ -1816,6 +1838,30 @@ def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAtan2TensorIntStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([4, 5, 6], torch.int32, True), + ([4, 5, 6], torch.int64, True), + ]) + 
def forward(self, a, b): + return torch.atan2(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntStaticModule()) +def ElementwiseAtan2TensorIntStaticModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(4, 5, 6, low=1, high=10).type(torch.int32), tu.randint(4, 5, 6, low=1, high=10)) + + +# ============================================================================== + + class ElementwiseAtan2FloatIntModule(torch.nn.Module): def __init__(self): @@ -1840,6 +1886,30 @@ def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAtan2FloatIntStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([4, 5, 6], torch.int32, True), + ([4, 5, 6], torch.float64, True), + ]) + def forward(self, a, b): + return torch.atan2(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntStaticModule()) +def ElementwiseAtan2FloatIntStaticModule_basic(module, tu: TestUtils): + module.forward(tu.randint(4, 5, 6, low=1, high=10).to(torch.int32), + tu.rand(4, 5, 6).double()) + + +# ============================================================================== + + class ElementwiseLogModule(torch.nn.Module): def __init__(self): From db50636b22b11d158e85c4794f7b762aa6540467 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Fri, 26 Apr 2024 03:07:02 +0000 Subject: [PATCH 3/4] fix --- projects/pt1/e2e_testing/xfail_sets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 70000eda3174..fc7a20fcae10 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2181,6 +2181,7 @@ "AvgPool2dDivisorOverrideModule_basic", "BroadcastDynamicDimModule_basic", "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseAtan2TensorIntStaticModule_basic", 
"ElementwiseAtenFloorDivideScalarNegativeModule_basic", "ElementwiseAtenFloorDivideTensorNegativeModule_basic", "ElementwiseLog10IntModule_basic", From 405c7f948c587ef3e1ce5a2fc56a0d267db7094d Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Fri, 26 Apr 2024 07:00:56 +0000 Subject: [PATCH 4/4] fix dyn_cast style --- lib/Conversion/TorchToStablehlo/Basic.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index cebd1d9023a5..700ad9dacf59 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -966,16 +966,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenPowScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.getSelf(); - auto lhsType = lhs.getType().dyn_cast(); + auto lhsType = dyn_cast(lhs.getType()); Value rhs = adaptor.getExponent(); - auto rhsType = rhs.getType().dyn_cast(); + auto rhsType = dyn_cast(rhs.getType()); if (!rhsType) return op.emitError("only Tensor types supported in StableHLO"); - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) {