diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 74a2e2327d1b..12a2bf4a86e2 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -2793,6 +2793,53 @@ def Torch_AtenLog1p_Op : Torch_Op<"aten.log1p_", [
   }];
 }
 
+def Torch_AtenLogitOp : Torch_Op<"aten.logit", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::logit : (Tensor, float?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalFloatType:$eps
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogitOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogitOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenLogit_Op : Torch_Op<"aten.logit_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::logit_ : (Tensor, float?) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    AnyTorchOptionalFloatType:$eps
+  );
+  let results = (outs
+    Torch_NonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogit_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogit_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenRsqrtOp : Torch_Op<"aten.rsqrt", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index f742ded3f1bd..749945dee6e2 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1969,7 +1969,6 @@ class ConvertPrimsCollapseOp : public OpConversionPattern<PrimsCollapseOp> {
       associations.push_back(ReassociationIndices{i});
     }
 
-
     rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
         op, resultRankedTensorType, adaptor.getA(), associations);
 
@@ -1996,6 +1995,91 @@ class ConvertTensorStaticInfoCastOp
 };
 } // namespace
 
+namespace {
+class ConvertLogitOp : public OpConversionPattern<AtenLogitOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenLogitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op->getLoc();
+    Value input = adaptor.getSelf();
+    Value eps = adaptor.getEps();
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    bool handleEps = false;
+    if (succeeded(checkNotNone(rewriter, op, eps)))
+      handleEps = true;
+
+    if (handleEps && !eps.getType().isa<mlir::FloatType>()) {
+      op.emitError("Logit does not support non-floating point type");
+      return failure();
+    }
+
+    auto inputType = input.getType().cast<RankedTensorType>();
+    auto inputElementType = inputType.getElementType();
+
+    if (!inputElementType.isa<mlir::FloatType>()) {
+      op.emitError("Logit does not support non-floating point type");
+      return failure();
+    }
+
+    auto inputRank = inputType.getRank();
+
+    SmallVector<AffineMap> indexingMaps = {
+        rewriter.getMultiDimIdentityMap(inputRank), // input
+        rewriter.getMultiDimIdentityMap(inputRank), // output
+    };
+    SmallVector<utils::IteratorType> iteratorTypes(
+        inputRank, utils::IteratorType::parallel);
+    Value logit =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, input.getType(),
+                /*ins=*/input,
+                /*outs=*/input,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0];
+
+                  TypedAttr oneAttr = b.getFloatAttr(inputElementType, 1.0);
+                  Value oneValue = b.create<arith::ConstantOp>(loc, oneAttr);
+
+                  Value zI;
+                  if (!handleEps) {
+                    zI = input;
+                  } else {
+                    Value truncEps =
+                        b.create<arith::TruncFOp>(loc, inputElementType, eps);
+                    Value oneMinusEps =
+                        b.create<arith::SubFOp>(loc, oneValue, truncEps);
+
+                    Value min =
+                        b.create<arith::MinimumFOp>(loc, input, oneMinusEps);
+                    Value clampedInput =
+                        b.create<arith::MaximumFOp>(loc, min, truncEps);
+
+                    zI = clampedInput;
+                  }
+
+                  Value probability =
+                      b.create<arith::SubFOp>(loc, oneValue, zI);
+                  Value odds = b.create<arith::DivFOp>(loc, zI, probability);
+                  Value result = b.create<math::LogOp>(loc, odds);
+
+                  b.create<linalg::YieldOp>(loc, result);
+                })
+            .getResult(0);
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, logit);
+    return success();
+  }
+};
+} // namespace
 void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -2028,6 +2112,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);
+  target.addIllegalOp<AtenLogitOp>();
+  patterns.add<ConvertLogitOp>(typeConverter, context);
   target.addIllegalOp<AtenDetachOp>();
   patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
   target.addIllegalOp<AtenBatchNormOp>();
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 55b9638dd0cc..3901cd34a4aa 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6352,6 +6352,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.logit\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.rsqrt\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -8739,6 +8743,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.logit\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
+"    return %1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index ddb4865ec535..98cde05a8f73 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1446,6 +1446,7 @@
     "ViewCollapseDynamicWithAtenSizeIntModule_basic",
"AtenEmbeddingBagSumExample_basic", "Aten_EmbeddingBagExample_basic", + "ElementwiseLogitModule_basic", "ElementwiseRemainderScalarModule_Int_Float_basic", "ElementwiseRemainderScalarModule_Bool_basic", "AtenIntTensorByteDtypeModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index a16d778c79a7..211023a9deec 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -141,6 +141,9 @@ def aten〇log10〡shape(self: List[int]) -> List[int]: def aten〇log1p〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇logit〡shape(self: List[int], eps: Optional[float] = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇rsqrt〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -1659,6 +1662,11 @@ def aten〇log1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇logit〡dtype(self_rank_dtype: Tuple[int, int], eps: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇rsqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 9c0a0759b443..a9f9ed96dce2 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -315,6 +315,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::log10 : (Tensor) -> (Tensor)", "aten::sqrt : (Tensor) -> (Tensor)", "aten::log1p : (Tensor) -> (Tensor)", + "aten::logit : (Tensor, float?) -> (Tensor)", "aten::rsqrt : (Tensor) -> (Tensor)", "aten::abs : (Tensor) -> (Tensor)", "aten::reciprocal : (Tensor) -> (Tensor)", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index c18c9103d888..5d6217b59072 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1508,6 +1508,28 @@ def ElementwiseLog1pModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseLogitModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.logit(a, eps=1e-7) + + +@register_test_case(module_factory=lambda: ElementwiseLogitModule()) +def ElementwiseLogitModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + class ElementwiseErfModule(torch.nn.Module): def __init__(self):