Skip to content

Commit

Permalink
[Stablehlo]fix CumsumInputDtypeInt32Module_basic on stablehlo backend. (
Browse files Browse the repository at this point in the history
#2797)

Code used for testing.For the location of CumsumInputDtypeInt32Module in
the repo you can see
[here](https://github.com/llvm/torch-mlir/blob/311b6b0286bfa016346bc7fd8b441bbd50216060/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py#L4148).
```python
import torch
import torch_mlir

class CumsumInputDtypeInt32Module(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, val):
        return torch.ops.aten.cumsum(val, 1)
module = torch_mlir.compile(CumsumInputDtypeInt32Module(), [torch.randn(2, 7, 4).to(torch.int32)], output_type="stablehlo")
print(module.operation.get_asm())
```
After fixing the bugs.
```
module attributes {torch.debug_module_name = "CumsumInputDtypeInt32Module"} {
  func.func @forward(%arg0: tensor<2x7x4xi32>) -> tensor<2x7x4xi64> {
    %0 = stablehlo.constant dense<0> : tensor<i64>
    %1 = stablehlo.convert %arg0 : (tensor<2x7x4xi32>) -> tensor<2x7x4xi64>
    %2 = "stablehlo.reduce_window"(%1, %0) ({
    ^bb0(%arg1: tensor<i64>, %arg2: tensor<i64>):
      %3 = stablehlo.add %arg1, %arg2 : tensor<i64>
      stablehlo.return %3 : tensor<i64>
    }) {padding = dense<[[0, 0], [6, 0], [0, 0]]> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 7, 1]> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64>} : (tensor<2x7x4xi64>, tensor<i64>) -> tensor<2x7x4xi64>
    return %2 : tensor<2x7x4xi64>
  }
}
```
  • Loading branch information
linuxlonelyeagle authored Jan 25, 2024
1 parent f6f8905 commit e581b33
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions lib/Conversion/TorchToStablehlo/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,13 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = input.getType().cast<RankedTensorType>();
auto outTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
inputTy = input.getType().cast<RankedTensorType>();
auto inputElemTy = inputTy.getElementType();
auto inputRank = inputTy.getRank();
auto inputShape = inputTy.getShape();
auto outTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
Expand Down

0 comments on commit e581b33

Please sign in to comment.