[Stablehlo] Support AtenPowScalarOp, AtenTanOp, AtenAsinhOp, AtenAcoshOp, AtenAtanhOp, Atan2Op #3233

Merged · 4 commits · Apr 26, 2024
43 changes: 43 additions & 0 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -960,6 +960,43 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
  return success();
}

// AtenPowScalarOp
template <>
LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
    AtenPowScalarOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value lhs = adaptor.getSelf();
  auto lhsType = dyn_cast<TensorType>(lhs.getType());
  Value rhs = adaptor.getExponent();
  auto rhsType = dyn_cast<TensorType>(rhs.getType());

  if (!rhsType)
    return op.emitError("only Tensor types supported in StableHLO");

  auto outType = cast<TensorType>(
      OpConversionPattern<AtenPowScalarOp>::getTypeConverter()->convertType(
          op.getType()));

  Type outElemTy = outType.getElementType();
  if (!outElemTy.isIntOrFloat()) {
    return op.emitError(
        "only floating-point or integer datatype legalization supported");
  }

  if (!lhsType) {
    lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
  }
  DenseI64ArrayAttr bcastDimensions;
  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
  auto loc = op.getLoc();
  Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
                                                       bcastDimensions);

  rewriter.replaceOp(op, result);
  return success();
}

// PrimNumToTensorScalarOp
template <>
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
@@ -2020,11 +2057,15 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanhOp, stablehlo::TanhOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanOp, chlo::TanOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinhOp, chlo::SinhOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCoshOp, chlo::CoshOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinhOp, chlo::AsinhOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcoshOp, chlo::AcoshOp);
  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanhOp, chlo::AtanhOp);
#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN

#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
@@ -2107,6 +2148,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
  INSERT_ATENOP_PATTERN(AtenTensorIntOp);
  INSERT_ATENOP_PATTERN(AtenReciprocalOp);
  INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
  INSERT_ATENOP_PATTERN(AtenPowScalarOp);
  INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
  INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
  INSERT_ATENOP_PATTERN(AtenContiguousOp);
@@ -2149,5 +2191,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
  INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseAndTensorOp, chlo::BroadcastAndOp);
  INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseOrTensorOp, chlo::BroadcastOrOp);
  INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseXorTensorOp, chlo::BroadcastXorOp);
  INSERT_BINARY_BROADCAST_PATTERN(AtenAtan2Op, chlo::BroadcastAtan2Op);
#undef INSERT_BINARY_BROADCAST_PATTERN
}
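
Note on the new AtenPowScalarOp pattern above: it covers torch.pow with a Scalar base (`self`) and a Tensor exponent. The scalar is materialized as a rank-0 StableHLO tensor via hlo::scalarToStablehloTensor, both operands are promoted to the result element type, and the power itself is emitted as chlo::BroadcastPowOp. A minimal sketch of the PyTorch-level op this handles (module name, shapes, and values are illustrative only, not part of this PR):

    import torch

    class PowScalarBaseSketch(torch.nn.Module):
        # Illustrative only: a scalar base with a tensor exponent dispatches to
        # torch.aten.pow.Scalar, the op handled by the new conversion pattern.
        def forward(self, exponent):
            return torch.pow(2.0, exponent)

    m = PowScalarBaseSketch()
    print(m(torch.tensor([0.0, 1.0, 3.0])))  # tensor([1., 2., 8.])
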
20 changes: 13 additions & 7 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -622,16 +622,10 @@
"DiagonalModule_with_offset",
"DivFloatModule_basic",
"DivIntModule_basic",
"ElementwiseAcoshIntModule_basic",
"ElementwiseAcoshModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAsinhIntModule_basic",
"ElementwiseAsinhModule_basic",
"ElementwiseAtan2FloatIntModule_basic",
"ElementwiseAtan2TensorFloatModule_basic",
"ElementwiseAtan2TensorIntModule_basic",
"ElementwiseAtanhIntModule_basic",
"ElementwiseAtanhModule_basic",
"ElementwiseBitwiseLeftShiftInt32Module_basic",
"ElementwiseBitwiseLeftShiftInt64Module_basic",
"ElementwiseBitwiseLeftShiftInt8Module_basic",
@@ -643,7 +637,6 @@
"ElementwiseErfIntModule_basic",
"ElementwiseLogitModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwisePowScalarModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
@@ -992,6 +985,15 @@
"DropoutEvalIntModule_basic",
"ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic",
"ElementwiseAcoshIntModule_basic",
"ElementwiseAcoshModule_basic",
"ElementwiseAsinhIntModule_basic",
"ElementwiseAsinhModule_basic",
"ElementwiseAtanhIntModule_basic",
"ElementwiseAtanhModule_basic",
"ElementwiseAtan2TensorFloatStaticModule_basic",
"ElementwiseAtan2TensorIntStaticModule_basic",
"ElementwiseAtan2FloatIntStaticModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseAtenIsinfOpModule_basic",
@@ -1053,6 +1055,7 @@
"ElementwiseNegModule_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseAndScalarStaticShapeModule_basic",
"ElementwisePowScalarModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic",
"ElementwisePowTensorStaticModule_basic",
"ElementwisePreluStaticModule_basic",
@@ -1065,6 +1068,8 @@
"ElementwiseSigmoidModule_basic",
"ElementwiseSinModule_basic",
"ElementwiseSqrtModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToI8Module_basic",
"ElementwiseToDtypeIdentityModule_basic",
@@ -2176,6 +2181,7 @@
"AvgPool2dDivisorOverrideModule_basic",
"BroadcastDynamicDimModule_basic",
"ElementwiseAtan2TensorIntModule_basic",
"ElementwiseAtan2TensorIntStaticModule_basic",
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
"ElementwiseLog10IntModule_basic",
70 changes: 70 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1792,6 +1792,28 @@ def ElementwiseAtan2TensorFloatModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseAtan2TensorFloatStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([4, 5, 6], torch.float32, True),
    ])
    def forward(self, a, b):
        return torch.atan2(a, b)


@register_test_case(module_factory=lambda: ElementwiseAtan2TensorFloatStaticModule())
def ElementwiseAtan2TensorFloatStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6), tu.rand(4, 5, 6))


# ==============================================================================

class ElementwiseAtan2TensorIntModule(torch.nn.Module):

    def __init__(self):
@@ -1816,6 +1838,30 @@ def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseAtan2TensorIntStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.int32, True),
        ([4, 5, 6], torch.int64, True),
    ])
    def forward(self, a, b):
        return torch.atan2(a, b)


@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntStaticModule())
def ElementwiseAtan2TensorIntStaticModule_basic(module, tu: TestUtils):
    module.forward(
        tu.randint(4, 5, 6, low=1, high=10).type(torch.int32), tu.randint(4, 5, 6, low=1, high=10))


# ==============================================================================


class ElementwiseAtan2FloatIntModule(torch.nn.Module):

    def __init__(self):
@@ -1840,6 +1886,30 @@ def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseAtan2FloatIntStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.int32, True),
        ([4, 5, 6], torch.float64, True),
    ])
    def forward(self, a, b):
        return torch.atan2(a, b)


@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntStaticModule())
def ElementwiseAtan2FloatIntStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(4, 5, 6, low=1, high=10).to(torch.int32),
                   tu.rand(4, 5, 6).double())


# ==============================================================================


class ElementwiseLogModule(torch.nn.Module):

    def __init__(self):
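
Note on the new static-shape atan2 tests above: they mirror the existing dynamic-shape modules but pin the input sizes to [4, 5, 6], and they rely on PyTorch's usual type promotion for atan2, where integer inputs produce a floating-point result. A small sketch of that promotion behavior, assuming standard PyTorch semantics (variable names and shapes are illustrative):

    import torch

    a = torch.randint(1, 10, (4, 5, 6), dtype=torch.int32)
    b = torch.randint(1, 10, (4, 5, 6), dtype=torch.int64)
    c = torch.rand(4, 5, 6, dtype=torch.float64)

    print(torch.atan2(a, b).dtype)  # two integer inputs promote to the default float dtype: torch.float32
    print(torch.atan2(a, c).dtype)  # mixed int/float promotes to the float operand's dtype: torch.float64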