[TorchToLinalg] Add lowering for torch.aten.diagonal (#2632)
frafranz authored Jan 22, 2024
1 parent 50ac3b1 commit b9806cf
Showing 8 changed files with 396 additions and 1 deletion.
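For context, `aten::diagonal` follows the semantics of PyTorch's `torch.diagonal(input, offset=0, dim1=0, dim2=1)`: the two selected dimensions are removed and the extracted diagonal is appended as the last dimension of the result. A minimal PyTorch sketch of the behavior this lowering targets (illustration only, not part of the commit):

    import torch

    x = torch.arange(24).reshape(2, 3, 4)
    # Take the diagonal over dims 1 and 2 with offset 1: both chosen dimensions
    # are dropped and the diagonal becomes the trailing dimension.
    d = torch.diagonal(x, offset=1, dim1=1, dim2=2)
    print(d.shape)  # torch.Size([2, 3]), since min(4 - 1, 3) = 3 diagonal elements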
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -11306,6 +11306,31 @@ def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
}];
}

def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::diagonal : (Tensor, int, int, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$offset,
Torch_IntType:$dim1,
Torch_IntType:$dim2
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDiagonalOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenDiagonalOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenDiagonalCopyOp : Torch_Op<"aten.diagonal_copy", [
AllowsTypeRefinement,
HasValueSemantics,
138 changes: 138 additions & 0 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1834,6 +1834,142 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern<AtenViewAsRealOp> {
};
} // namespace

namespace {
class ConvertAtenDiagonalOp : public OpConversionPattern<AtenDiagonalOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenDiagonalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

int64_t offset;
if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
return rewriter.notifyMatchFailure(op, "offset must be constant");
int64_t dim1;
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
int64_t dim2;
if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2)))
return rewriter.notifyMatchFailure(op, "dim2 must be constant");

Value inputMatrix = adaptor.getSelf();
RankedTensorType inputType = inputMatrix.getType().cast<RankedTensorType>();
int64_t inputRank = inputType.getRank();

if (inputRank < 2)
return rewriter.notifyMatchFailure(
op, "input must have at least two dimensions");
int64_t outputRank = inputRank - 1;

dim1 = toPositiveDim(dim1, inputRank);
if (!isValidDim(dim1, inputRank))
return rewriter.notifyMatchFailure(op, "dim1 out of range");
dim2 = toPositiveDim(dim2, inputRank);
if (!isValidDim(dim2, inputRank))
return rewriter.notifyMatchFailure(op, "dim2 out of range");
if (dim1 == dim2)
return rewriter.notifyMatchFailure(
op, "diagonal dimensions cannot be identical");

Type elementType = inputType.getElementType();
RankedTensorType outputType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Location loc = op.getLoc();

Value dim1Size, dim2Size;
dim1Size = getDimOp(rewriter, loc, inputMatrix, dim1);
dim2Size = getDimOp(rewriter, loc, inputMatrix, dim2);

// Compute the length of the diagonal, taking the offset into account. If
// the offset is too large or too small, diagSize is 0 and an empty tensor
// is returned.
Value indexZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value indexMinusOne = rewriter.create<arith::ConstantIndexOp>(loc, -1);
Value indexOffset = rewriter.create<arith::ConstantIndexOp>(loc, offset);
Value offsetIsNegative = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, indexOffset, indexZero);
Value sizeForNegativeOffset = rewriter.create<arith::MaxSIOp>(
loc,
rewriter.create<arith::MinSIOp>(
loc, rewriter.create<arith::AddIOp>(loc, dim1Size, indexOffset),
dim2Size),
indexZero);
Value sizeForPositiveOffset = rewriter.create<arith::MaxSIOp>(
loc,
rewriter.create<arith::MinSIOp>(
loc, rewriter.create<arith::SubIOp>(loc, dim2Size, indexOffset),
dim1Size),
indexZero);
Value diagSize = rewriter.create<arith::SelectOp>(
loc, offsetIsNegative, sizeForNegativeOffset, sizeForPositiveOffset);

// Depending on its sign, the offset shifts either the row (dim1) or the
// column (dim2) start index of the diagonal.
Value diagStart1 = rewriter.create<arith::SelectOp>(
loc, offsetIsNegative,
rewriter.create<arith::MulIOp>(loc, indexOffset, indexMinusOne),
indexZero);
Value diagStart2 = rewriter.create<arith::SelectOp>(loc, offsetIsNegative,
indexZero, indexOffset);

SmallVector<Value> outputDims;
for (auto i = 0; i < inputRank; i++) {
if (!(i == dim1 || i == dim2))
outputDims.push_back(getDimOp(rewriter, loc, inputMatrix, i));
}
outputDims.push_back(diagSize);

Value outputMatrix = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outputDims), elementType);

SmallVector<AffineMap> indexingMaps = {
AffineMap::getMultiDimIdentityMap(outputRank, rewriter.getContext())};
SmallVector<utils::IteratorType> iteratorTypes(
outputRank, utils::IteratorType::parallel);

auto diagonal =
rewriter
.create<linalg::GenericOp>(
loc, outputMatrix.getType(), ValueRange{}, outputMatrix,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> diagIndices;
Value indexOnDiag =
b.create<linalg::IndexOp>(loc, outputRank - 1);
Value dim1Index =
b.create<arith::AddIOp>(loc, indexOnDiag, diagStart1);
Value dim2Index =
b.create<arith::AddIOp>(loc, indexOnDiag, diagStart2);

// specify at which input indices the diagonal values are
// extracted
for (int indIn = 0, indOut = 0; indIn < inputRank; indIn++) {
if (indIn == dim1)
diagIndices.push_back(dim1Index);
else if (indIn == dim2)
diagIndices.push_back(dim2Index);
else {
diagIndices.push_back(
b.create<linalg::IndexOp>(loc, indOut));
indOut++;
}
}
Value diagElt = b.create<tensor::ExtractOp>(
loc, elementType, inputMatrix, diagIndices);
b.create<linalg::YieldOp>(loc, diagElt);
})
.getResult(0);

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType, diagonal);
return success();
}
};
} // namespace

void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -1872,4 +2008,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenViewAsComplexOp>(typeConverter, context);
target.addIllegalOp<AtenViewAsRealOp>();
patterns.add<ConvertAtenViewAsRealOp>(typeConverter, context);
target.addIllegalOp<AtenDiagonalOp>();
patterns.add<ConvertAtenDiagonalOp>(typeConverter, context);
}
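In the conversion above, the diagonal length and its starting row/column indices are computed with arith ops on the dynamic dimension sizes. A rough scalar restatement of that arithmetic, for illustration only (the names below are ours, not from the commit):

    # Scalar sketch of the diagSize / diagStart computation in
    # ConvertAtenDiagonalOp (illustration only).
    def diag_params(dim1_size: int, dim2_size: int, offset: int):
        if offset <= 0:
            # A non-positive offset shifts the dim1 (row) start index.
            diag_size = max(min(dim1_size + offset, dim2_size), 0)
            start1, start2 = -offset, 0
        else:
            # A positive offset shifts the dim2 (column) start index.
            diag_size = max(min(dim2_size - offset, dim1_size), 0)
            start1, start2 = 0, offset
        return diag_size, start1, start2

    # Element i of the diagonal is then read from
    # input[..., start1 + i, ..., start2 + i, ...].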
72 changes: 72 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6238,6 +6238,74 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str_0 = torch.constant.str \"AssertionError: input must have at least two dimensions\"\n"
" %int2 = torch.constant.int 2\n"
" %int9223372036854775807 = torch.constant.int 9223372036854775807\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %3 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg2, %2, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n"
" %4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %5 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg3, %4, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n"
" %6 = torch.aten.ne.int %3, %5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %9 = torch.prim.ListConstruct %int9223372036854775807, %8 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %10 = torch.prim.min.self_int %9 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %10, %true, init() {\n"
" ^bb0(%arg4: !torch.int):\n"
" %19 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %20 = torch.aten.eq.int %arg4, %3 : !torch.int, !torch.int -> !torch.bool\n"
" %21 = torch.prim.If %20 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %22 = torch.aten.eq.int %arg4, %5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %22 : !torch.bool\n"
" }\n"
" torch.prim.If %21 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %22 = torch.aten.append.t %7, %19 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %11 = torch.aten.__getitem__.t %arg0, %3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.__getitem__.t %arg0, %5 : !torch.list<int>, !torch.int -> !torch.int\n"
" %13 = torch.aten.sub.int %12, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %14 = torch.prim.min.int %11, %13 : !torch.int, !torch.int -> !torch.int\n"
" %15 = torch.prim.max.int %14, %int0 : !torch.int, !torch.int -> !torch.int\n"
" %16 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %17 = torch.prim.If %16 -> (!torch.int) {\n"
" %19 = torch.aten.__getitem__.t %arg0, %3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %20 = torch.aten.add.int %19, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %21 = torch.aten.__getitem__.t %arg0, %5 : !torch.list<int>, !torch.int -> !torch.int\n"
" %22 = torch.prim.min.int %20, %21 : !torch.int, !torch.int -> !torch.int\n"
" %23 = torch.prim.max.int %22, %int0 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %23 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %15 : !torch.int\n"
" }\n"
" %18 = torch.aten.append.t %7, %17 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" return %7 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.tan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -9980,6 +10048,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.diagonal\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.uniform\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Utils/Utils.cpp
@@ -222,7 +222,7 @@ bool Torch::isViewLikeOp(Operation *op) {
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
AtenPixelShuffleOp>(op);
AtenPixelShuffleOp, AtenDiagonalOp>(op);
}

Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
@@ -59,6 +59,36 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]:
def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]:
return upstream_shape_functions.unary(self)

@check_shape_function([
Invocation(TensorOfShape(2, 3, 4)), # Basic case.
Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`.
Invocation(TensorOfShape(2, 3, 4), dim1=-1, dim2=-2, offset=1), # Positive `offset`.
Invocation(TensorOfShape(2, 3, 4), offset=-1), # Negative `offset`.
Invocation(TensorOfShape(2, 3, 4), offset=3), # Empty result due to large `offset`.
ErrorInvocation(TensorOfShape(2)), # Input one-dimensional.
ErrorInvocation(TensorOfShape(2, 3, 4), dim1=1, dim2=1), # `dim1` and `dim2` equal.
ErrorInvocation(TensorOfShape(2, 3, 4), dim1=3, dim2=1), # `dim1` out of bounds.
])
def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim2: int = 1) -> List[int]:
assert len(self) >= 2, "input must have at least two dimensions"
dim1 = upstream_shape_functions.maybe_wrap_dim(dim1, len(self))
dim2 = upstream_shape_functions.maybe_wrap_dim(dim2, len(self))
assert dim1 != dim2, "diagonal dimensions cannot be identical"

diagonal: List[int] = []
for i, self_dim in enumerate(self):
if (i==dim1) or (i==dim2):
pass
else:
diagonal.append(self_dim)

diag_size = max(min(self[dim1], self[dim2] - offset), 0)
if offset<0:
diag_size = max(min(self[dim1] + offset, self[dim2]), 0)
diagonal.append(diag_size)

return diagonal
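A worked example of this shape function, matching the large-offset invocation above (illustration only, not part of the diff):

    # aten.diagonal on a (2, 3, 4) tensor with offset=3 and the default
    # dim1=0, dim2=1: the remaining dimension [4] is kept, and
    # diag_size = max(min(2, 3 - 3), 0) = 0, so the result is empty.
    assert aten〇diagonal〡shape([2, 3, 4], offset=3) == [4, 0]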

def aten〇tan〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

@@ -2493,6 +2523,11 @@ def aten〇tril〡dtype(self_rank_dtype: Tuple[int, int], diagonal: int = 0) ->
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim1=0, dim2=1))
def aten〇diagonal〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0, dim1: int = 0, dim2: int = 1) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., to: float = 1., generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Up @@ -672,6 +672,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::alias_copy : (Tensor) -> (Tensor)")
emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
emit("aten::permute_copy : (Tensor, int[]) -> (Tensor)")
@@ -60,3 +60,4 @@ def register_all_tests():
from . import control_flow
from . import stats
from . import padding
from . import diagonal