[torch] Add folders for torch.fill, torch.ones, torch.zeros and `aten.getItem` (#2849)

So that the CumSum op in OPT can get the constant dimension it requires to be lowered to TMTensor.
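A minimal sketch of the pattern in play (constant definitions elided; types follow the new e2e and canonicalize tests below): the importer feeds the cumsum dimension through a 1-element constant tensor and `aten.item`, and the new folders collapse that chain into a plain `torch.constant.int`.

// Before canonicalization (sketch): the dim is hidden behind ones + item.
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%1 = torch.aten.ones %0, %int3, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
%dim = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
%2 = torch.aten.cumsum %arg0, %dim, %none : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.none -> !torch.vtensor<[?,?,?],f32>

// After canonicalization (sketch): a constant dim that TMTensor can consume.
%int1 = torch.constant.int 1
%2 = torch.aten.cumsum %arg0, %int1, %none : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.none -> !torch.vtensor<[?,?,?],f32>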

---------

Co-authored-by: Rob Suderman <[email protected]>
Co-authored-by: Xida Ren <[email protected]>
3 people authored Feb 2, 2024
1 parent 962d514 commit 24b8c86
Showing 7 changed files with 208 additions and 11 deletions.
4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8416,6 +8416,7 @@ def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
       printDefaultTorchOp(printer, *this, 5, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [
@@ -8471,6 +8472,7 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [
       printDefaultTorchOp(printer, *this, 5, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
@@ -9858,6 +9860,7 @@ def Torch_AtenItemOp : Torch_Op<"aten.item", [
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenMaskedSelectOp : Torch_Op<"aten.masked_select", [
@@ -11202,6 +11205,7 @@ def Torch_AtenFullOp : Torch_Op<"aten.full", [
       printDefaultTorchOp(printer, *this, 6, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [
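In ODS, `let hasFolder = 1` declares a `fold(FoldAdaptor)` hook on the generated op; the implementations live in TorchOps.cpp below. The IR-level effect is roughly the following, a sketch assuming the torch dialect's usual constant materialization:

// A fully-constant aten.ones with a static size list and a known dtype ...
%1 = torch.aten.ones %0, %int3, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
// ... can now fold to a dense value-tensor literal:
%1 = torch.vtensor.literal(dense<1> : tensor<1xsi32>) : !torch.vtensor<[1],si32>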
8 changes: 3 additions & 5 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1089,11 +1089,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
               binder.op, "expected result type to have a dtype");
         }
         // resultTensorType.print(llvm::outs());
-        Value resultDType = Torch::getDtypeIntValueForType(
-            rewriter, loc, resultTensorType.getDtype());
-
-        rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(
-            binder.op, resultType, operand, dim, resultDType);
+        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+        rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(binder.op, resultType,
+                                                         operand, dim, none);
         return success();
       });
   patterns.onOp(
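The conversion now passes `none` for cumsum's dtype argument instead of materializing a dtype integer from the result type, so the op it emits looks roughly like this (a sketch; value names and tensor types are illustrative):

%none = torch.constant.none
%result = torch.aten.cumsum %operand, %dim, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.none -> !torch.vtensor<[?,?],f32>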
148 changes: 147 additions & 1 deletion lib/Dialect/Torch/IR/TorchOps.cpp
@@ -6,9 +6,10 @@
 // Also available under a BSD-style license. See LICENSE.
 //
 //===----------------------------------------------------------------------===//
-
+#define DEBUG_TYPE "torch-mlir-torch-dialect"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/Support/Debug.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Builders.h"
@@ -2813,6 +2814,151 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenItemOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
+  // see if we have a constant tensor
+  DenseElementsAttr attr;
+  if (matchPattern(getOperand(), m_Constant(&attr))) {
+    auto splat = attr.getSplatValue<Attribute>();
+    if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
+      return getI64IntegerAttr(getContext(), intAttr.getSInt());
+    }
+    if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
+      return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble());
+    }
+    return nullptr;
+  }
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// AtenOnesOp, AtenZerosOp, AtenFullOp
+//===----------------------------------------------------------------------===//
+OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
+  SmallVector<int64_t> sizes;
+  if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
+    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: size operand is "
+                               "not a list of constant integers.\n");
+    return nullptr;
+  }
+
+  Type resultType = getResult().getType();
+  BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
+  if (!resultTensorType || !resultTensorType.hasDtype()) {
+    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: result type is not "
+                               "a tensor type or does not have a dtype.\n");
+    return nullptr;
+  }
+
+  ShapedType shapedty =
+      mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
+          sizes, resultTensorType.getDtype());
+  if (!shapedty) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Failing to fold AtenOnesOp: ShapedType cast failed.\n");
+    return nullptr;
+  }
+  auto elementType = shapedty.getElementType();
+  if (elementType.isa<IntegerType>()) {
+    Attribute attribute = IntegerAttr::get(elementType, 1);
+    return DenseElementsAttr::get(shapedty, attribute);
+  }
+  if (elementType.isa<FloatType>()) {
+    Attribute attribute = FloatAttr::get(elementType, 1.0);
+    return DenseElementsAttr::get(shapedty, attribute);
+  }
+  LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: element type is "
+                             "not integer or float.\n");
+  return nullptr;
+}
+
+OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
+  SmallVector<int64_t> sizes;
+  if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
+    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: size operand is "
+                               "not a list of constant integers.\n");
+    return nullptr;
+  }
+
+  Type resultType = getResult().getType();
+  BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
+  if (!resultTensorType || !resultTensorType.hasDtype()) {
+    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: result type is "
+                               "not a tensor type or does not have a dtype.\n");
+    return nullptr;
+  }
+
+  ShapedType shapedty =
+      mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
+          sizes, resultTensorType.getDtype());
+  if (!shapedty) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Failing to fold AtenZerosOp: ShapedType cast failed.\n");
+    return nullptr;
+  }
+
+  auto elementType = shapedty.getElementType();
+  if (elementType.isa<IntegerType>()) {
+    Attribute attribute = IntegerAttr::get(elementType, 0);
+    return DenseElementsAttr::get(shapedty, attribute);
+  }
+  if (elementType.isa<FloatType>()) {
+    Attribute attribute = FloatAttr::get(elementType, 0.0);
+    return DenseElementsAttr::get(shapedty, attribute);
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: element type is "
+                             "not integer or float.\n");
+  return nullptr;
+}
+
+OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
+  SmallVector<int64_t> sizes;
+  if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
+    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: size operand is "
+                               "not a list of constant integers.\n");
+    return nullptr;
+  }
+
+  Type resultType = getResult().getType();
+  BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
+  if (!resultTensorType || !resultTensorType.hasDtype()) {
+    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: result type is not "
+                               "a tensor type or does not have a dtype.\n");
+    return nullptr;
+  }
+
+  ShapedType shapedty =
+      mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
+          sizes, resultTensorType.getDtype());
+  if (!shapedty) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Failing to fold AtenFullOp: ShapedType cast failed.\n");
+    return nullptr;
+  }
+  auto elementType = shapedty.getElementType();
+  if (elementType.isa<IntegerType>()) {
+    int64_t value = 0;
+    if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) {
+      Attribute attribute = IntegerAttr::get(elementType, value);
+      return DenseElementsAttr::get(shapedty, attribute);
+    }
+  }
+  if (elementType.isa<FloatType>()) {
+    double value = 0.0;
+    if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
+      Attribute attribute = FloatAttr::get(elementType, value);
+      return DenseElementsAttr::get(shapedty, attribute);
+    }
+  }
+  LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: element type is "
+                             "not integer or float.\n");
+  return nullptr;
+}
 //===----------------------------------------------------------------------===//
 // AtenCeilFloatOp
 //===----------------------------------------------------------------------===//
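The `AtenItemOp` folder above fires only when the operand is a constant splat tensor, rebuilding the scalar as an i64 or f64 attribute. A sketch of the rewrite with illustrative values:

// A splat value-tensor literal feeding aten.item ...
%0 = torch.vtensor.literal(dense<42> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// ... folds to a scalar constant:
%int42 = torch.constant.int 42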
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -26,6 +26,9 @@
 TORCHDYNAMO_XFAIL_SET = {
     #### General TorchDynamo/PyTorch errors
 
+    # torch._dynamo.exc.Unsupported: Tensor.item
+    "CumsumModule_basic",
+
     # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
     # RuntimeError: Failed running call_function aten.convolution_backward(...
     # https://github.com/pytorch/pytorch/issues/89629
@@ -564,9 +564,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
     emit("aten::Bool.Tensor : (Tensor) -> (bool)")
     emit("aten::is_floating_point : (Tensor) -> (bool)", has_folder=True)
-    emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
+    emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
     emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
-    emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
+    emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
     emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
@@ -618,7 +618,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
     emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
     emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
-    emit("aten::item : (Tensor) -> (Scalar)")
+    emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
     emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
     emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True)
     emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
@@ -669,7 +669,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
     emit("aten::t : (Tensor) -> (Tensor)")
     emit("aten::numpy_T : (Tensor) -> (Tensor)")
-    emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
+    emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
     emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
     emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
     emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
8 changes: 7 additions & 1 deletion projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -4092,7 +4092,13 @@ def __init__(self):
         ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, val):
-        return torch.ops.aten.cumsum(val, 1)
+        # The ONNX CumSum op uses a constant 1-d tensor
+        # to specify the dimension along which to do cumsum.
+        # We replicate that here to ensure that cumsum correctly
+        # triggers the relevant folders and provides TMTensor
+        # with a constant dimension.
+        ones = torch.ones([1], dtype=torch.int32)
+        return torch.ops.aten.cumsum(val, ones.item())
 
 @register_test_case(module_factory=lambda: CumsumModule())
 def CumsumModule_basic(module, tu: TestUtils):
40 changes: 40 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -29,6 +29,46 @@ func.func @torch.runtime.assert() {
   return
 }
 
+// CHECK-LABEL: func.func @torch.aten.ones_item
+// CHECK: %[[CONST:.*]] = torch.constant.int 1
+// CHECK: return %[[CONST]] : !torch.int
+func.func @torch.aten.ones_item() -> !torch.int {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.aten.ones %0, %int3, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
+  %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
+  return %2 : !torch.int
+}
+
+// CHECK-LABEL: func.func @torch.aten.zeros_item
+// CHECK: %[[CONST:.*]] = torch.constant.int 0
+// CHECK: return %[[CONST]] : !torch.int
+func.func @torch.aten.zeros_item() -> !torch.int {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.aten.zeros %0, %int3, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
+  %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
+  return %2 : !torch.int
+}
+
+// CHECK-LABEL: func.func @torch.aten.full_item
+// CHECK: %[[CONST:.*]] = torch.constant.int 1337
+// CHECK: return %[[CONST]] : !torch.int
+func.func @torch.aten.full_item() -> !torch.int {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 1337
+  %int5 = torch.constant.int 5
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.aten.full %0, %int3, %int5, %none, %none, %none : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
+  %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
+  return %2 : !torch.int
+}
+
 // CHECK-LABEL: func.func @torch.aten.is_floating_point$fold_true
 // CHECK: %[[TRUE:.*]] = torch.constant.bool true
 // CHECK: return %[[TRUE]] : !torch.bool
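Reconstructed from its CHECK lines, the canonicalized form of the new @torch.aten.full_item test should be roughly the following (a sketch; FileCheck pins only the constant and the return):

func.func @torch.aten.full_item() -> !torch.int {
  %int1337 = torch.constant.int 1337
  return %int1337 : !torch.int
}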