@@ -2075,6 +2075,30 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
2075
2075
return success ();
2076
2076
}
2077
2077
2078
+ template <>
2079
+ LogicalResult ConvertAtenOp<AtenIsfiniteOp>::matchAndRewrite(
2080
+ AtenIsfiniteOp op, OpAdaptor adaptor,
2081
+ ConversionPatternRewriter &rewriter) const {
2082
+ Value self = adaptor.getSelf ();
2083
+ auto selfTy = cast<RankedTensorType>(self.getType ());
2084
+ if (!selfTy)
2085
+ return rewriter.notifyMatchFailure (
2086
+ op, " Only Tensor types are currently supported" );
2087
+
2088
+ auto outType =
2089
+ dyn_cast<RankedTensorType>(getTypeConverter ()->convertType (op.getType ()));
2090
+ Type outElemTy = outType.getElementType ();
2091
+ if (!outElemTy.isInteger (1 )) {
2092
+ return rewriter.notifyMatchFailure (
2093
+ op, " Only i1 output element type is supported" );
2094
+ }
2095
+
2096
+ rewriter.replaceOpWithNewOp <stablehlo::IsFiniteOp>(op.getOperation (), outType,
2097
+ self);
2098
+
2099
+ return success ();
2100
+ }
2101
+
2078
2102
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality (
2079
2103
TypeConverter &typeConverter, RewritePatternSet &patterns,
2080
2104
ConversionTarget &target, const TorchToStablehloOptions &options) {
@@ -2248,6 +2272,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
2248
2272
INSERT_ATENOP_PATTERN (AtenBitwiseRightShiftTensorOp);
2249
2273
2250
2274
INSERT_ATENOP_PATTERN (AtenTrilOp);
2275
+ INSERT_ATENOP_PATTERN (AtenIsfiniteOp);
2251
2276
#undef INSERT_ATENOP_PATTERN
2252
2277
2253
2278
#define INSERT_BINARY_BROADCAST_PATTERN (AtenOp, StablehloOp ) \
0 commit comments