[TORCH][MLIR] Add E2E support for aten._softmax operation. (#431)
Signed-Off-By: Prateek Gupta <[email protected]>
gprateek93 committed Nov 25, 2021
1 parent 67ce816 commit f461a7e
Showing 5 changed files with 118 additions and 13 deletions.
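For orientation, the decomposition this commit implements is softmax(x) = exp(x) / sum(exp(x)) taken along the chosen dimension. Below is a minimal PyTorch sketch of that relation checked against aten._softmax itself; the shape, dim, and tolerance are illustrative assumptions, not values taken from the patch.

    import torch

    # Reference relation mirrored by the new decomposition in
    # DecomposeComplexOps.cpp (AtenExpOp, sum along `dim`, AtenDivTensorOp).
    x = torch.randn(3, 2, 4)
    dim = 0
    expected = torch.ops.aten._softmax(x, dim, False)  # half_to_float=False
    decomposed = torch.exp(x) / torch.exp(x).sum(dim, keepdim=True)
    assert torch.allclose(expected, decomposed, atol=1e-6)
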
17 changes: 17 additions & 0 deletions e2e_testing/torchscript/basic.py
@@ -457,6 +457,23 @@ def forward(self, tensor):
def SoftmaxIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4))

class _SoftmaxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, tensor):
        return torch.ops.aten._softmax(tensor, 0, False)


@register_test_case(module_factory=lambda: _SoftmaxModule())
def _SoftmaxModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4))


class SoftmaxIntNegDimModule(torch.nn.Module):
    def __init__(self):
16 changes: 16 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1392,6 +1392,22 @@ def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [
  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}

def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::_softmax : (Tensor, int, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_IntType:$dim,
    Torch_BoolType:$half_to_float
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` type($self) `,` type($dim) `,` type($half_to_float) `->` type($result)";
}

def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
    AllowsTypeRefinement
  ]> {
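As a quick sanity check outside the patch, scripting a module that calls torch.ops.aten._softmax shows the aten::_softmax node whose `(Tensor, int, bool) -> (Tensor)` signature this generated op definition models; the demo module below is an assumption for illustration.

    import torch

    class SoftmaxDemo(torch.nn.Module):
        def forward(self, x):
            # dim=0, half_to_float=False, matching the new e2e test above.
            return torch.ops.aten._softmax(x, 0, False)

    # The printed TorchScript graph contains an aten::_softmax call with the
    # same operand types that Torch_Aten_SoftmaxOp declares.
    print(torch.jit.script(SoftmaxDemo()).graph)
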
64 changes: 55 additions & 9 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -126,16 +126,34 @@ class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
};
} // namespace

// Calculates the softmax function on the given `input` tensor. Softmax(x) =
// exp(x)/sum(exp(x)).
template <typename OpTy>
static Value getSoftmaxResult(OpTy op, Type resultType,
                              PatternRewriter &rewriter) {
  Location loc = op.getLoc();
  Value dim = op.dim();
  Value self = op.self();

  // exp(x)
  Value exp = rewriter.create<AtenExpOp>(loc, resultType, self);
  // sum(exp(x))
  Value sum =
      createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
  if (!sum)
    return nullptr;
  // exp(x) / sum(exp(x))
  return rewriter.create<AtenDivTensorOp>(loc, resultType, exp, sum);
}

// Decompose softmax into: exp(x) / sum(exp(x))
namespace {
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    Value dim = op.dim();
    if (!op.dtype().getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented non-None dtype for softmax");
@@ -144,14 +162,40 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
      return rewriter.notifyMatchFailure(op, "Only support floating type");

-    // exp(x)
-    Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
-    // sum(exp(x))
-    Value sum = createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
-    if (!sum)
+    Value result = getSoftmaxResult(op, tensorType, rewriter);
+    if (!result)
      return failure();
-    // exp(x) / sum(exp(x))
-    Value result = rewriter.create<AtenDivTensorOp>(loc, tensorType, exp, sum);
    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
                                                        result);
    return success();
  }
};
} // namespace

namespace {
class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
                                PatternRewriter &rewriter) const override {
    Value self = op.self();
    BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
      return rewriter.notifyMatchFailure(op, "Only support floating type");
    bool halfToFloat;
    if (!matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat)))
      return rewriter.notifyMatchFailure(
          op, "Expected a boolean value for half_to_float");

    // Currently, setting `halfToFloat` is not supported as the E2E testing for
    // the same is not present on CPU.
    if (halfToFloat)
      return rewriter.notifyMatchFailure(
          op, "halfToFloat is currently not supported.");

    Value result = getSoftmaxResult(op, tensorType, rewriter);
    if (!result)
      return op.emitError("failed to get softmax result");
    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
                                                        result);
    return success();
@@ -406,6 +450,8 @@ class DecomposeComplexOpsPass

    patterns.add<DecomposeAtenSoftmaxIntOp>(context);
    target.addIllegalOp<AtenSoftmaxIntOp>();
    patterns.add<DecomposeAten_SoftmaxOp>(context);
    target.addIllegalOp<Aten_SoftmaxOp>();
    patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
    target.addIllegalOp<AtenLogSoftmaxIntOp>();
    patterns.add<DecomposeAtenExpandOp>(context);
33 changes: 29 additions & 4 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -418,6 +418,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
      return visitAtenMatmulOp(matmul, operands);
    } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
      return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
    } else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {
      return visitAten_SoftmaxOp(_softmaxOp, operands);
    } else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
      return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
    } else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
@@ -541,6 +543,10 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
  ChangeResult
  visitAtenAddCLikeOp(Operation *op,
                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);

  ChangeResult
  visitAten_SoftmaxOp(Aten_SoftmaxOp op,
                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace

@@ -1332,21 +1338,40 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
  return getLatticeElement(op.getResult()).join(knowledge);
}

static ValueKnowledge
getSameSizeAsInput(Operation *op,
                   ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto input = operands[0]->getValue();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  knowledge.hasSizes = input.hasSizes;
  knowledge.sizes = input.sizes;
  return knowledge;
}

// Common template for softmax-like ops, e.g., log_softmax.
template <typename OpTy>
ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
    OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto input = operands[0]->getValue();
  auto dtype = op.dtype();
-  auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
-  knowledge.hasSizes = input.hasSizes;
-  knowledge.sizes = input.sizes;
+  ValueKnowledge knowledge = getSameSizeAsInput(op, operands);
  fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype);
  return getLatticeElement(op.getResult()).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAten_SoftmaxOp(
    Aten_SoftmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto input = operands[0]->getValue();
  ValueKnowledge knowledge = getSameSizeAsInput(op, operands);
  bool halfToFloat;
  if (matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat))) {
    knowledge.dtype =
        halfToFloat ? Float32Type::get(op->getContext()) : input.dtype;
  }
  return getLatticeElement(op.getResult()).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenBmmOp(
    AtenBmmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto knowledge =
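The dtype rule encoded by visitAten_SoftmaxOp above: when half_to_float is a constant true, the refined result dtype is f32; otherwise it follows the input dtype. A tiny sketch of that rule (the helper name is hypothetical, not torch-mlir API):

    import torch

    # Hypothetical helper mirroring visitAten_SoftmaxOp's refinement logic.
    def refined_softmax_dtype(input_dtype: torch.dtype,
                              half_to_float: bool) -> torch.dtype:
        return torch.float32 if half_to_float else input_dtype

    assert refined_softmax_dtype(torch.float16, True) == torch.float32
    assert refined_softmax_dtype(torch.float16, False) == torch.float16
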
@@ -516,6 +516,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)")
emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::sqrt : (Tensor) -> (Tensor)")
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")

# Misc tensor ops.
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
