
Commit 7058f45

[Stablehlo] support aten.isfinite (llvm#3850)
1 parent dda65b1 commit 7058f45

File tree

7 files changed: +86 −0 lines

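For reference, `aten::isfinite` follows the semantics of `torch.isfinite`: it returns a boolean tensor that is True where the input element is neither infinite nor NaN. A minimal illustration in plain PyTorch:

    import torch

    # True for ordinary values, False for +/-inf and NaN.
    print(torch.isfinite(torch.tensor([1.0, float("inf"), float("nan")])))
    # tensor([ True, False, False])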

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

+23
@@ -4976,6 +4976,29 @@ def Torch_AtenFakeQuantizePerChannelAffineCachemaskOp : Torch_Op<"aten.fake_quan
   }];
 }
 
+def Torch_AtenIsfiniteOp : Torch_Op<"aten.isfinite", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::isfinite : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIsfiniteOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenIsfiniteOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Conversion/TorchToStablehlo/Basic.cpp

+25
@@ -2075,6 +2075,30 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenIsfiniteOp>::matchAndRewrite(
+    AtenIsfiniteOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value self = adaptor.getSelf();
+  auto selfTy = cast<RankedTensorType>(self.getType());
+  if (!selfTy)
+    return rewriter.notifyMatchFailure(
+        op, "Only Tensor types are currently supported");
+
+  auto outType =
+      dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
+  Type outElemTy = outType.getElementType();
+  if (!outElemTy.isInteger(1)) {
+    return rewriter.notifyMatchFailure(
+        op, "Only i1 output element type is supported");
+  }
+
+  rewriter.replaceOpWithNewOp<stablehlo::IsFiniteOp>(op.getOperation(), outType,
+                                                     self);
+
+  return success();
+}
+
 void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const TorchToStablehloOptions &options) {

@@ -2248,6 +2272,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);
 
   INSERT_ATENOP_PATTERN(AtenTrilOp);
+  INSERT_ATENOP_PATTERN(AtenIsfiniteOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
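A quick way to exercise this pattern end to end (a sketch only: it assumes the torch_mlir.fx.export_and_import entry point and its output_type keyword, which are assumptions about the Python API at this revision, not part of this commit):

    import torch
    from torch_mlir import fx  # assumed import path

    class IsFinite(torch.nn.Module):
        def forward(self, x):
            return torch.ops.aten.isfinite(x)

    # Lower through the FX importer to StableHLO; the printed module should
    # contain a stablehlo.is_finite op produced by the pattern above.
    module = fx.export_and_import(IsFinite(), torch.randn(4), output_type="stablehlo")
    print(module)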

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

+7
@@ -6495,6 +6495,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.isfinite\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.cosine_similarity\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
 "    %none = torch.constant.none\n"
 "    %int1 = torch.constant.int 1\n"

@@ -11448,6 +11451,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.isfinite\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.rad2deg\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"

projects/pt1/e2e_testing/xfail_sets.py

+1
@@ -519,6 +519,7 @@
     "IndexPutImpl2DNoneIndexStaticModule_basic",
     "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
+    "IsInfiniteModule_basic",
     "InterpolateDynamicModule_sizes_nearest",
     "IouOfModule_basic",
     "MeshgridIndexingIJ_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

+6
@@ -222,6 +222,9 @@ def aten〇exp2〡shape(self: List[int]) -> List[int]:
 def aten〇expm1〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇isfinite〡shape(self: List[int]) -> List[int]:
+    return self
+
 def aten〇cosine_similarity〡shape(x1: List[int], x2: List[int], dim: int = 1, eps: float = 1e-08) -> List[int]:
     broadcast = upstream_shape_functions.broadcast(x1, x2)
     return broadcast[:dim] + broadcast[dim + 1:]

@@ -2656,6 +2659,9 @@ def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return _get_dtype_of_floating_point_op(self_dtype)
 
+def aten〇isfinite〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    return torch.bool
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}))
 def aten〇rad2deg〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
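The generated dtype function in AbstractInterpLibrary.cpp returns the constant 11 because that is PyTorch's integer code for torch.bool (c10::ScalarType::Bool); the Python rule above states the same thing directly. A small sanity check in plain PyTorch:

    import torch

    x = torch.tensor([float("-inf"), 0.0, float("nan")])
    # aten.isfinite always yields a boolean tensor for floating-point input.
    assert torch.ops.aten.isfinite(x).dtype == torch.bool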

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

+1
@@ -484,6 +484,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::fake_quantize_per_channel_affine_cachemask : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor, Tensor)"
     )
+    emit("aten::isfinite : (Tensor) -> (Tensor)")
     emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
     emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
     emit("aten::fmax : (Tensor, Tensor) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

+23
@@ -4373,6 +4373,29 @@ def PowIntFloatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class IsInfiniteModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.isfinite(x)
+
+
+@register_test_case(module_factory=lambda: IsInfiniteModule())
+def IsInfiniteModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([-torch.inf, torch.inf, torch.nan, -2.3, 0.0, 1.5]))
+
+
+# ==============================================================================
+
+
 class BaddbmmDynamicModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
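For this test's input, the expected result can be worked out directly in plain PyTorch (outside the e2e harness):

    import torch

    x = torch.tensor([-torch.inf, torch.inf, torch.nan, -2.3, 0.0, 1.5])
    print(torch.ops.aten.isfinite(x))
    # tensor([False, False, False,  True,  True,  True])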
