lowered addcmul and addcdiv to linalg
nodlabs authored and cathyzhyi committed Nov 24, 2021
1 parent 8d8d2c2 commit 67ce816
Showing 6 changed files with 124 additions and 0 deletions.
37 changes: 37 additions & 0 deletions e2e_testing/torchscript/basic.py
@@ -684,3 +684,40 @@ def forward(self, a, b, c):
def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))

class AddCMulModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])

    def forward(self, input, tensor1, tensor2):
        return torch.addcmul(input, tensor1, tensor2, value=1.0)

@register_test_case(module_factory=lambda: AddCMulModule())
def AddCMulModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 3), tu.rand(1, 3), tu.rand(1, 3))

class AddCDivModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])

    def forward(self, input, tensor1, tensor2):
        return torch.addcdiv(input, tensor1, tensor2, value=1.0)

@register_test_case(module_factory=lambda: AddCDivModule())
def AddCDivModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 3), tu.rand(1, 3), tu.rand(1, 3))
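
For reference, these two modules exercise the identities the decomposition added below relies on: addcmul(input, t1, t2, value) computes input + value * t1 * t2, and addcdiv computes input + value * t1 / t2. A minimal eager-mode sanity check of that behavior (illustrative only, not part of this commit; plain PyTorch, no torch-mlir required):

import torch

input, t1, t2 = torch.rand(1, 3), torch.rand(1, 3), torch.rand(1, 3)
value = 1.0
# addcmul: input + value * (t1 * t2)
assert torch.allclose(torch.addcmul(input, t1, t2, value=value),
                      input + value * t1 * t2)
# addcdiv: input + value * (t1 / t2)
assert torch.allclose(torch.addcdiv(input, t1, t2, value=value),
                      input + value * t1 / t2)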
2 changes: 2 additions & 0 deletions e2e_testing/torchscript/xfail_sets.py
@@ -30,4 +30,6 @@
"ElementwiseLogModule_basic",
"TanhBackward_basic",
"ReturnThreeTensorFloat32_basic",
"AddCMulModule_basic",
"AddCDivModule_basic",
}
34 changes: 34 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -2895,3 +2895,37 @@ def Torch_Aten_LogSoftmaxBackwardDataOp : Torch_Op<"aten._log_softmax_backward_d
let assemblyFormat = "$grad_output `,` $output `,` $dim `,` $input_dtype attr-dict `:` type($grad_output) `,` type($output) `,` type($dim) `,` type($input_dtype) `->` type($result)";
}

def Torch_AtenAddCMulOp : Torch_Op<"aten.addcmul", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$tensor1,
AnyTorchTensorType:$tensor2,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $tensor1 `,` $tensor2 `,` $value attr-dict `:` type($self) `,` type($tensor1) `,` type($tensor2) `,` type($value) `->` type($result)";
}

def Torch_AtenAddCDivOp : Torch_Op<"aten.addcdiv", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$tensor1,
AnyTorchTensorType:$tensor2,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $tensor1 `,` $tensor2 `,` $value attr-dict `:` type($self) `,` type($tensor1) `,` type($tensor2) `,` type($value) `->` type($result)";
}
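
Both ops share the aten signature (Tensor, Tensor, Tensor, Scalar) -> (Tensor), with value as the trailing Scalar operand. As a rough sketch of the overload being modeled, the same schema can be exercised through the PyTorch dispatcher (illustrative only, assuming a stock PyTorch install):

import torch

a, b, c = torch.rand(2, 3), torch.rand(2, 3), torch.rand(2, 3)
# aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor); value is the Scalar.
out = torch.ops.aten.addcmul(a, b, c, value=2.0)
assert torch.allclose(out, a + 2.0 * b * c)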

24 changes: 24 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -375,6 +375,26 @@ class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
};
} // namespace

namespace {
template <typename OpTy, typename T1T2Op>
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.self();
    Value tensor1 = op.tensor1();
    Value tensor2 = op.tensor2();
    Value value = op.value();

    Value product =
        rewriter.create<T1T2Op>(loc, op.getType(), tensor1, tensor2);
    rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), input,
                                                 product, value);
    return success();
  }
};
} // namespace
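
The pattern above turns addcmul/addcdiv into a two-op sequence: an elementwise aten.mul.Tensor or aten.div.Tensor of tensor1 and tensor2, followed by aten.add.Tensor with value forwarded as the add's alpha operand. In eager PyTorch terms the rewrite corresponds to the following sketch (for illustration only; the helper name is made up here):

import torch

def addcmul_decomposed(input, tensor1, tensor2, value):
    # product = AtenMulTensorOp(tensor1, tensor2)
    product = torch.mul(tensor1, tensor2)
    # result = AtenAddTensorOp(input, product, alpha=value)
    return torch.add(input, product, alpha=value)

a, b, c = torch.rand(2, 3), torch.rand(2, 3), torch.rand(2, 3)
assert torch.allclose(addcmul_decomposed(a, b, c, 0.5),
                      torch.addcmul(a, b, c, value=0.5))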

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -408,6 +428,10 @@ class DecomposeComplexOpsPass
      // Make aten.matmul legal if the following condition is satisfied.
      return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
    });
    patterns.add<DecomposeAtenAddCLikeOp<AtenAddCMulOp, AtenMulTensorOp>>(
        context);
    target.addIllegalOp<AtenAddCMulOp>();
    patterns.add<DecomposeAtenAddCLikeOp<AtenAddCDivOp, AtenDivTensorOp>>(
        context);
    target.addIllegalOp<AtenAddCDivOp>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      return signalPassFailure();
25 changes: 25 additions & 0 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -422,6 +422,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
      return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
    } else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
      return visitNumToTensorOp(numToTensorOp);
    } else if (isa<AtenAddCMulOp, AtenAddCDivOp>(op)) {
      return visitAtenAddCLikeOp(op, operands);
    }

    // Otherwise, this is an unknown operation. Just mark all results as
@@ -535,6 +537,10 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
  ChangeResult
  visitAtenSoftmaxLikeOp(OpTy op,
                         ArrayRef<LatticeElement<ValueKnowledge> *> operands);

  ChangeResult
  visitAtenAddCLikeOp(Operation *op,
                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace

@@ -1376,6 +1382,25 @@ ChangeResult TypeAnalyzer::visitAtenMatmulOp(
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenAddCLikeOp(
    Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  auto self = operands[0]->getValue();
  auto tensor1 = operands[1]->getValue();
  auto tensor2 = operands[2]->getValue();
  if (tensor1.hasSizes && tensor2.hasSizes && self.hasSizes) {
    knowledge.hasSizes = true;
    knowledge.sizes.resize(
        std::max(self.sizes.size(),
                 std::max(tensor1.sizes.size(), tensor2.sizes.size())),
        kUnknownSize);
  }
  knowledge.dtype =
      getPromotedResultType(getContext(), {&self, &tensor1, &tensor2});
  return getLatticeElement(op->getResult(0)).join(knowledge);
}
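
The transfer function above propagates only a rank: when self, tensor1, and tensor2 all have known sizes, the result gets max of their ranks as its number of dimensions, each left as kUnknownSize, and the dtype follows promotion across the three operands. That rank rule matches how these ops broadcast in eager mode; a small check (illustrative only, not part of the commit):

import torch

self_t = torch.rand(2, 3)      # rank 2
t1 = torch.rand(3)             # rank 1
t2 = torch.rand(4, 1, 3)       # rank 3
out = torch.addcmul(self_t, t1, t2, value=1.0)
# Broadcasting gives rank max(2, 1, 3) = 3; the analysis records that rank
# with unknown extents rather than the concrete (4, 2, 3).
assert out.dim() == 3 and out.shape == (4, 2, 3)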

// -----------------------------------------------------------------------------
// Transforms.
// -----------------------------------------------------------------------------
@@ -471,6 +471,8 @@ def emit_with_mutating_variants(key, **kwargs):
        emit_with_mutating_variants(key)
    # Elementwise tensor compute ops that don't have the standard mutating
    # variants.
    emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
    emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
    emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
    emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
    emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
