Implement lowering of torch.aten.logit (#2697)
ikalinic authored Jan 11, 2024
1 parent 5862854 commit e1a86e4
Showing 7 changed files with 175 additions and 1 deletion.
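Background for readers of this diff (paraphrasing the PyTorch documentation for `torch.logit`; not part of the commit message): logit is the inverse of the sigmoid, and the optional `eps` clamps the input away from 0 and 1 so the result stays finite:

```math
\operatorname{logit}(z) = \ln\frac{z}{1 - z},
\qquad
z =
\begin{cases}
\min(\max(x,\ \varepsilon),\ 1 - \varepsilon) & \text{if } \varepsilon \text{ is given} \\
x & \text{otherwise.}
\end{cases}
```

This is exactly the computation the TorchToLinalg pattern below builds out of `arith` and `math` ops.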
47 changes: 47 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -2793,6 +2793,53 @@ def Torch_AtenLog1p_Op : Torch_Op<"aten.log1p_", [
}];
}

def Torch_AtenLogitOp : Torch_Op<"aten.logit", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::logit : (Tensor, float?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalFloatType:$eps
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogitOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenLogitOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenLogit_Op : Torch_Op<"aten.logit_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::logit_ : (Tensor, float?) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
AnyTorchOptionalFloatType:$eps
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogit_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenLogit_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenRsqrtOp : Torch_Op<"aten.rsqrt", [
AllowsTypeRefinement,
HasValueSemantics,
88 changes: 87 additions & 1 deletion lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1969,7 +1969,6 @@ class ConvertPrimsCollapseOp : public OpConversionPattern<PrimsCollapseOp> {
associations.push_back(ReassociationIndices{i});
}


rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
op, resultRankedTensorType, adaptor.getA(), associations);

@@ -1996,6 +1995,91 @@ class ConvertTensorStaticInfoCastOp
};
} // namespace

namespace {
class ConvertLogitOp : public OpConversionPattern<AtenLogitOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenLogitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Location loc = op->getLoc();
Value input = adaptor.getSelf();
Value eps = adaptor.getEps();

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

bool handleEps = false;
if (succeeded(checkNotNone(rewriter, op, eps)))
handleEps = true;

if (handleEps && !eps.getType().isa<mlir::FloatType>()) {
op.emitError("Logit does not support non-floating point type");
return failure();
}

auto inputType = input.getType().cast<RankedTensorType>();
auto inputElementType = inputType.getElementType();

if (!inputElementType.isa<mlir::FloatType>()) {
op.emitError("Logit does not support non-floating point type");
return failure();
}

auto inputRank = inputType.getRank();

SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(inputRank), // input
rewriter.getMultiDimIdentityMap(inputRank), // output
};
SmallVector<utils::IteratorType> iteratorTypes(
inputRank, utils::IteratorType::parallel);
Value logit =
rewriter
.create<linalg::GenericOp>(
loc, input.getType(),
/*ins=*/input,
/*outs=*/input,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0];

TypedAttr oneAttr = b.getFloatAttr(inputElementType, 1.0);
Value oneValue = b.create<arith::ConstantOp>(loc, oneAttr);

Value zI;
if (!handleEps) {
zI = input;
} else {
Value truncEps =
b.create<arith::TruncFOp>(loc, inputElementType, eps);
Value oneMinusEps =
b.create<arith::SubFOp>(loc, oneValue, truncEps);

Value min =
b.create<arith::MinimumFOp>(loc, input, oneMinusEps);
Value clampedInput =
b.create<arith::MaximumFOp>(loc, min, truncEps);

zI = clampedInput;
}

Value probability =
b.create<arith::SubFOp>(loc, oneValue, zI);
Value odds = b.create<arith::DivFOp>(loc, zI, probability);
Value result = b.create<math::LogOp>(loc, odds);

b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, logit);
return success();
}
};
} // namespace
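For readability, here is an informal scalar model of the region built for the `linalg.generic` above (illustrative Python, not part of the patch):

```python
from math import log
from typing import Optional

def logit_scalar(x: float, eps: Optional[float] = None) -> float:
    # With eps, clamp the input to [eps, 1 - eps] (the MinimumFOp/MaximumFOp
    # pair above); without it, inputs outside (0, 1) yield inf/-inf/NaN.
    z = x if eps is None else max(min(x, 1.0 - eps), eps)
    odds = z / (1.0 - z)  # SubFOp + DivFOp
    return log(odds)      # math::LogOp

print(logit_scalar(0.5))        # 0.0
print(logit_scalar(1.0, 1e-7))  # large but finite
```

The pattern itself rejects non-floating-point element types and a non-float `eps`, as the `emitError` calls above show.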
void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -2028,6 +2112,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
target.addIllegalOp<AtenBatchNormOp>();
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
target.addIllegalOp<AtenLogitOp>();
patterns.add<ConvertLogitOp>(typeConverter, context);
target.addIllegalOp<PrimsCollapseOp>();
patterns.add<ConvertPrimsCollapseOp>(typeConverter, context);
target.addIllegalOp<PrimsSplitDimOp>();
9 changes: 9 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6352,6 +6352,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.logit\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rsqrt\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -8739,6 +8743,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.logit\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1446,6 +1446,7 @@
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"AtenEmbeddingBagSumExample_basic",
"Aten_EmbeddingBagExample_basic",
"ElementwiseLogitModule_basic",
"ElementwiseRemainderScalarModule_Int_Float_basic",
"ElementwiseRemainderScalarModule_Bool_basic",
"AtenIntTensorByteDtypeModule_basic",
8 changes: 8 additions & 0 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -141,6 +141,9 @@ def aten〇log10〡shape(self: List[int]) -> List[int]:
def aten〇log1p〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇logit〡shape(self: List[int], eps: Optional[float] = None) -> List[int]:
return upstream_shape_functions.unary(self)
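For context: `upstream_shape_functions.unary`, used throughout this file, simply propagates the input shape unchanged, which is the right rule for an elementwise op like logit. A minimal model of that behavior (illustrative, not the upstream implementation):

```python
from typing import List

def unary(self: List[int]) -> List[int]:
    # Elementwise ops preserve the input shape; return a fresh copy.
    return list(self)

assert unary([3, 4]) == [3, 4]
```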

def aten〇rsqrt〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

@@ -1659,6 +1662,11 @@ def aten〇log1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇logit〡dtype(self_rank_dtype: Tuple[int, int], eps: Optional[float] = None) -> int:
self_rank, self_dtype = self_rank_dtype
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇rsqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
1 change: 1 addition & 0 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -315,6 +315,7 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::log10 : (Tensor) -> (Tensor)",
"aten::sqrt : (Tensor) -> (Tensor)",
"aten::log1p : (Tensor) -> (Tensor)",
"aten::logit : (Tensor, float?) -> (Tensor)",
"aten::rsqrt : (Tensor) -> (Tensor)",
"aten::abs : (Tensor) -> (Tensor)",
"aten::reciprocal : (Tensor) -> (Tensor)",
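Each signature string in this list is fed to `emit_with_mutating_variants`, which emits the two TableGen definitions seen at the top of this diff: the value-semantics op and the trailing-underscore in-place variant. A hypothetical sketch of that naming convention (not the actual torch_ods_gen helper):

```python
from typing import Tuple

def generated_op_names(signature: str) -> Tuple[str, str]:
    # "aten::logit : (Tensor, float?) -> (Tensor)" -> ("AtenLogitOp", "AtenLogit_Op")
    name = signature.split(" : ")[0].split("::")[1]  # "logit"
    camel = "".join(part.capitalize() for part in name.split("_"))
    return (f"Aten{camel}Op", f"Aten{camel}_Op")

print(generated_op_names("aten::logit : (Tensor, float?) -> (Tensor)"))
```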
22 changes: 22 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1508,6 +1508,28 @@ def ElementwiseLog1pModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseLogitModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.logit(a, eps=1e-7)


@register_test_case(module_factory=lambda: ElementwiseLogitModule())
def ElementwiseLogitModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))


# ==============================================================================


class ElementwiseErfModule(torch.nn.Module):

def __init__(self):
