add aten.sub.int/aten.mul.int lowering in TorchToStd
xndcn authored and silvasean committed Dec 17, 2021
1 parent d8ba681 commit 5eed562
Showing 4 changed files with 82 additions and 15 deletions.
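
At a glance: this change generalizes the existing torch.aten.add.int lowering into a ConvertAtenBinaryOp pattern template and reuses it for torch.aten.sub.int and torch.aten.mul.int, mapping them to arith.subi and arith.muli on i64. A sketch of the IR the new lowering is expected to produce for torch.aten.sub.int, reconstructed from the FileCheck patterns added in basic.mlir below (the SSA value names are illustrative, not taken from the test):

func @torch.aten.sub.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
  // Unwrap both !torch.int operands into builtin i64 values, subtract in the
  // arith dialect, then wrap the i64 result back into !torch.int.
  %lhs = torch_c.to_i64 %arg0
  %rhs = torch_c.to_i64 %arg1
  %diff = arith.subi %lhs, %rhs : i64
  %result = torch_c.from_i64 %diff
  return %result : !torch.int
}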
34 changes: 34 additions & 0 deletions e2e_testing/torchscript/scalar.py
@@ -23,7 +23,41 @@ def __init__(self):
    def forward(self, lhs, rhs):
        return int(lhs)+int(rhs)

class SubIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([], torch.int64, True),
        ([], torch.int64, True),
    ])
    def forward(self, lhs, rhs):
        return int(lhs)-int(rhs)

class MulIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([], torch.int64, True),
        ([], torch.int64, True),
    ])
    def forward(self, lhs, rhs):
        return int(lhs)*int(rhs)


@register_test_case(module_factory=lambda: AddIntModule())
def AddIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))

@register_test_case(module_factory=lambda: SubIntModule())
def SubIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))

@register_test_case(module_factory=lambda: MulIntModule())
def MulIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
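
The new e2e modules mirror AddIntModule: each takes two rank-0 int64 tensors, extracts Python ints, and returns the scalar result, so the compiled program has to go through the scalar aten.sub.int/aten.mul.int ops rather than tensor arithmetic. As a plain-PyTorch illustration of the behavior the harness compares against (eager mode only, outside the torch-mlir test runner; the module name and values here are just an example):

import torch

class MulIntExample(torch.nn.Module):
    def forward(self, lhs, rhs):
        # int() on a rank-0 tensor yields a Python int, so this is scalar
        # integer multiplication, not tensor multiplication.
        return int(lhs) * int(rhs)

m = MulIntExample()
print(m(torch.tensor(6), torch.tensor(-7)))  # prints -42, an ordinary Python int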
20 changes: 14 additions & 6 deletions lib/Conversion/TorchToStd/TorchToStd.cpp
@@ -45,13 +45,15 @@ class ConvertAtenDimOp : public OpConversionPattern<AtenDimOp> {
 } // namespace
 
 namespace {
-class ConvertAtenAddIntOp : public OpConversionPattern<AtenAddIntOp> {
+template <typename AtenOp, typename BinOp>
+class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<AtenOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenAddIntOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenOp op,
+                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<arith::AddIOp>(op, adaptor.a(), adaptor.b());
+    rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.a(), adaptor.b());
     return success();
   }
 };
@@ -142,8 +144,14 @@ class ConvertTorchToStd : public ConvertTorchToStdBase<ConvertTorchToStd> {
     target.addIllegalOp<Torch::ConstantIntOp>();
     patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
                                                                context);
-    target.addIllegalOp<AtenAddIntOp>();
-    patterns.add<ConvertAtenAddIntOp>(typeConverter, context);
+    target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
+    patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
+        typeConverter, context);
+    patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
+        typeConverter, context);
+    patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
+        typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
19 changes: 11 additions & 8 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -471,8 +471,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return visitNumToTensorOp(numToTensorOp);
     } else if (isa<AtenAddcmulOp, AtenAddcdivOp>(op)) {
       return visitAtenAddCLikeOp(op, operands);
-    } else if (auto scalarOp = dyn_cast<AtenAddIntOp>(op)) {
-      return visitBinaryScalarOp(scalarOp);
+    } else if (isa<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>(op)) {
+      return visitBinaryScalarOp(op, operands);
     } else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
       return visitAtenNllLossForwardOp(nllForwardOp, operands);
     } else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
@@ -590,7 +590,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   ChangeResult
   visitAtenEmbeddingOp(AtenEmbeddingOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
-  template <typename OpTy> ChangeResult visitBinaryScalarOp(OpTy op);
+  ChangeResult
+  visitBinaryScalarOp(Operation *op,
+                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 
   ChangeResult
   visitAtenBmmOp(AtenBmmOp op,
@@ -1276,12 +1278,13 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
-template <typename OpTy>
-ChangeResult TypeAnalyzer::visitBinaryScalarOp(OpTy op) {
+ChangeResult TypeAnalyzer::visitBinaryScalarOp(
+    Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
-  knowledge.dtype = getPromotedResultType({op.a().getType(), op.b().getType()});
-  return getLatticeElement(op.getResult()).join(knowledge);
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  knowledge.dtype = getPromotedResultType(
+      {op->getOperand(0).getType(), op->getOperand(1).getType()});
+  return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 
 // `torch.aten.tensor` get a tensor from a list. Each layer of the list
24 changes: 23 additions & 1 deletion test/Conversion/TorchToStd/basic.mlir
@@ -82,6 +82,28 @@ func @torch.constant.int() -> !torch.int {
// CHECK: %[[INT:.*]] = torch_c.from_i64 %[[INT:.*]]
// CHECK: return %[[INT:.*]] : !torch.int
func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
  %0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
  return %0 : !torch.int
}

// CHECK-LABEL: func @torch.aten.sub.int(
// CHECK-SAME:      %[[LHS:.*]]: !torch.int,
// CHECK-SAME:      %[[RHS:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
// CHECK: %[[SUB:.*]] = arith.subi %[[LHS_I64]], %[[RHS_I64]] : i64
// CHECK: %[[INT:.*]] = torch_c.from_i64 %[[SUB]]
// CHECK: return %[[INT]] : !torch.int
func @torch.aten.sub.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
  %0 = torch.aten.sub.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
  return %0 : !torch.int
}

// CHECK-LABEL: func @torch.aten.mul.int(
// CHECK-SAME:      %[[LHS:.*]]: !torch.int,
// CHECK-SAME:      %[[RHS:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
// CHECK: %[[MUL:.*]] = arith.muli %[[LHS_I64]], %[[RHS_I64]] : i64
// CHECK: %[[INT:.*]] = torch_c.from_i64 %[[MUL]]
// CHECK: return %[[INT]] : !torch.int
func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
  %0 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
  return %0 : !torch.int
}
