[TORCH][MLIR] Add E2E support for [aten.gt.Scalar|aten.where.self]
This commit adds lowering of `aten.gt.Scalar` and `aten.where.self` as part of
the element-wise op lowering.

Signed-Off-by: Gaurav Shukla <[email protected]>
Shukla-Gaurav authored and Gaurav Shukla committed Dec 9, 2021
1 parent 2414bdb commit f34eb66
Showing 5 changed files with 124 additions and 5 deletions.
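
For context, the eager-PyTorch semantics exercised by the new e2e tests below can be sketched as follows (illustrative only, not part of the commit; tensor shapes and thresholds are made up):

import torch

x = torch.rand(3, 5)
# aten.gt.Scalar: element-wise comparison of a tensor against a Python scalar,
# yielding a boolean tensor of the same shape.
mask = torch.gt(x, 0.6)

a, b, c = torch.rand(3, 4, 5), torch.rand(4, 5), torch.rand(5)
# aten.where.self: selects from `b` where the condition is true and from `c`
# elsewhere; the condition and both value tensors broadcast to a common shape.
out = torch.where(a > 0.5, b, c)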
43 changes: 43 additions & 0 deletions e2e_testing/torchscript/elementwise.py
@@ -85,6 +85,29 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseWhereSelfModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, a, b, c):
        return torch.where(a > 0.5, b, c)


@register_test_case(module_factory=lambda: ElementwiseWhereSelfModule())
def ElementwiseWhereSelfModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5), tu.rand(4, 5), tu.rand(5))


# ==============================================================================


# Addition is an interesting special case of a binary op, because under the hood
# it carries a third scalar "alpha" parameter, which needs special handling.
class ElementwiseAddModule(torch.nn.Module):
@@ -303,6 +326,26 @@ def ElementwiseMaximumModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseGtScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.gt(x, 0.6)


@register_test_case(module_factory=lambda: ElementwiseGtScalarModule())
def ElementwiseGtScalarModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5))

# ==============================================================================


class ElementwiseClampModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
16 changes: 16 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1071,6 +1071,22 @@ def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$condition,
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$condition `,` $self `,` $other attr-dict `:` type($condition) `,` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [
    AllowsTypeRefinement,
    HasValueSemantics
26 changes: 24 additions & 2 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1684,6 +1684,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
    return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
  }

  if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
    Type dtype = gtScalar.self().getType().cast<ValueTensorType>().getDtype();
    if (!dtype.isa<mlir::FloatType>()) {
      gtScalar.emitError("unimplemented: non-floating point operand dtype");
      return nullptr;
    }
    Value otherPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                   payloadArgs[0], otherPromoted);
  }

  if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
    Type dtype = converter->convertType(whereSelf.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
    return b.create<SelectOp>(loc, payloadArgs[0], lhs, rhs);
  }

  if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
    if (!lerp.getType()
             .cast<ValueTensorType>()
@@ -2040,7 +2061,7 @@ struct ConvertElementwiseOp : ConversionPattern {
             AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
             AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
             AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
             AtenBitwiseAndTensorOp>(op))
             AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -3461,7 +3482,8 @@ class ConvertTorchToLinalg
        AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
        AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
        AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
        AtenReciprocalOp, AtenBitwiseAndTensorOp>();
        AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
        AtenWhereSelfOp>();
    patterns.add<ConvertElementwiseOp>(typeConverter, context);
    target.addIllegalOp<AtenSqueezeOp>();
    patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
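
Conceptually, the payload that the conversion above emits into the linalg.generic body reduces to the following per-element logic (a Python analogue for illustration only; the actual lowering builds arith.cmpf/select ops, and these helper names are invented):

def gt_scalar_payload(x, other):
    # The scalar operand is first promoted to the tensor's float dtype
    # (convertScalarToDtype), then compared with an arith.cmpf "ugt".
    return x > float(other)

def where_self_payload(cond, lhs, rhs):
    # Both value operands are converted to the result element type, and a
    # select chooses between them based on the i1 condition value.
    return lhs if cond else rhs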
43 changes: 40 additions & 3 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -235,9 +235,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
      ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
    if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
            AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp,
            AtenGeluBackwardOp, AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp,
            AtenNeScalarOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp, AtenCosOp,
            AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
            AtenGeluBackwardOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp,
            AtenCosOp, AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
            AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp,
            AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
            AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp,
@@ -247,6 +246,20 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }

    // These comparison ops return a tensor with 1-bit integer dtype.
    if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenNeScalarOp>(
            op)) {
      auto operand = operands[0]->getValue();
      auto knowledge =
          ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
      if (operand.hasSizes) {
        knowledge.hasSizes = true;
        knowledge.sizes = operand.sizes;
      }
      knowledge.dtype = IntegerType::get(op->getContext(), 1);
      return getLatticeElement(op->getResult(0)).join(knowledge);
    }

    // Resize to [1, 1] with integer dtype.
    if (isa<AtenAnyOp, AtenAllOp>(op)) {
      auto input = operands[0]->getValue();
@@ -307,6 +320,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
                   AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
                   AtenMinimumOp, AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
      return visitBinaryBroadcastingOp(op, operands);
    } else if (auto whereSelf = llvm::dyn_cast<AtenWhereSelfOp>(op)) {
      return visitAtenWhereSelfOp(whereSelf, operands);
    } else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
      return visitAtenLerpTensorOp(lerpTensor, operands);
    } else if (auto flatten = dyn_cast<AtenFlattenUsingIntsOp>(op)) {
@@ -487,6 +502,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
  ChangeResult visitBinaryBroadcastingOp(
      Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenWhereSelfOp(AtenWhereSelfOp op,
                       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenLerpTensorOp(AtenLerpTensorOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult visitAtenFlattenUsingIntsOp(
@@ -856,6 +874,25 @@ ChangeResult TypeAnalyzer::visitBinaryBroadcastingOp(
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenWhereSelfOp(
    AtenWhereSelfOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto condition = operands[0]->getValue();
  auto lhs = operands[1]->getValue();
  auto rhs = operands[2]->getValue();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(getContext());
  if (condition.hasSizes && lhs.hasSizes && rhs.hasSizes) {
    knowledge.hasSizes = true;
    knowledge.sizes.resize(
        std::max(condition.sizes.size(),
                 std::max(lhs.sizes.size(), rhs.sizes.size())),
        kUnknownSize);
  }

  knowledge.dtype = getPromotedResultType(getContext(), {&lhs, &rhs});
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenLerpTensorOp(
    AtenLerpTensorOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  // This is a general broadcasting shape transfer function.
@@ -479,6 +479,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::gelu : (Tensor) -> (Tensor)")
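
The type-refinement rules added in RefineTypes.cpp above mirror eager PyTorch behaviour, which can be sanity-checked with a snippet like this (illustrative, assuming a reasonably recent PyTorch; not part of the commit):

import torch

x = torch.rand(3, 5)
print(torch.gt(x, 0.6).dtype)   # torch.bool, i.e. the 1-bit integer dtype

cond = torch.rand(3, 4, 5) > 0.5
lhs = torch.rand(4, 5, dtype=torch.float64)
rhs = torch.rand(5)
out = torch.where(cond, lhs, rhs)
print(out.dtype, out.shape)     # promoted dtype (float64) and broadcast shape (3, 4, 5)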
