Skip to content

Commit

Permalink
[TORCH][MLIR] Add E2E support for aten.ceil op
Browse files Browse the repository at this point in the history
This commit adds lowering of `aten.ceil` op as a part of element-wise
ops lowering.

Signed-Off-by: Gaurav Shukla <[email protected]>
  • Loading branch information
Shukla-Gaurav authored and Gaurav Shukla committed Dec 11, 2021
1 parent 03b6edc commit a778f99
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 4 deletions.
16 changes: 16 additions & 0 deletions e2e_testing/torchscript/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,22 @@ def forward(self, a):
def ElementwiseFloorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

class ElementwiseCeilModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])

def forward(self, a):
return torch.ceil(a)

@register_test_case(module_factory=lambda: ElementwiseCeilModule())
def ElementwiseCeilModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

class ElementwisePowModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
28 changes: 28 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,34 @@ def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}

def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::ceil : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}

def Torch_AtenCeil_Op : Torch_Op<"aten.ceil_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::ceil_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}

def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [
AllowsTypeRefinement,
HasValueSemantics
Expand Down
9 changes: 6 additions & 3 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1512,6 +1512,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<math::ExpOp>(loc, payloadArgs[0]);
if (isa<AtenFloorOp>(op))
return b.create<math::FloorOp>(loc, payloadArgs[0]);
if (isa<AtenCeilOp>(op))
return b.create<math::CeilOp>(loc, payloadArgs[0]);
if (isa<AtenLogOp>(op))
return b.create<math::LogOp>(loc, payloadArgs[0]);
if (isa<AtenSqrtOp>(op))
Expand Down Expand Up @@ -2067,7 +2069,8 @@ struct ConvertElementwiseOp : ConversionPattern {
AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp>(op))
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp,
AtenCeilOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
Expand Down Expand Up @@ -3635,8 +3638,8 @@ class ConvertTorchToLinalg
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
AtenWhereSelfOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeOp>();
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp,
AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
AtenDropoutOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp,
AtenAddIntOp, AtenAbsOp, AtenReciprocalOp>(op)) {
AtenAddIntOp, AtenAbsOp, AtenReciprocalOp, AtenCeilOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::cos : (Tensor) -> (Tensor)",
"aten::neg : (Tensor) -> (Tensor)",
"aten::floor : (Tensor) -> (Tensor)",
"aten::ceil : (Tensor) -> (Tensor)",
"aten::bitwise_not : (Tensor) -> (Tensor)",
"aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
Expand Down

0 comments on commit a778f99

Please sign in to comment.