Skip to content

Commit

Permalink
[onnx] Fix edge condition for onnx.ReduceMax (#3598)
Browse files Browse the repository at this point in the history
For length-0 on `onnx.ReduceMax` the length 0 case was incorrect due to
a copy paste error.
  • Loading branch information
rsuderman authored Aug 7, 2024
1 parent 8d95fe9 commit 1813999
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
13 changes: 7 additions & 6 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1328,30 +1328,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
auto dataTy = cast<Torch::BaseTensorType>(data.getType());
Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();

// If any of the input dims are 0 we set to the upper limit:
// If any of the input dims are 0 we set to the lower limit:
if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) &&
(llvm::any_of(dataTy.getSizes(),
[](int64_t d) { return d == Torch::kUnknownSize; }) ||
keepDims)) {
auto dty = dataTy.getDtype();
Value scalar;
if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
auto inf = APFloat::getInf(fpTy.getFloatSemantics());
auto inf =
APFloat::getInf(fpTy.getFloatSemantics(), /*Negative=*/true);
scalar = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(),
inf.convertToDouble()));
}

if (IntegerType intTy = dyn_cast<IntegerType>(dty)) {
auto mx =
auto minInt =
intTy.isSigned()
? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
: APInt::getMaxValue(intTy.getIntOrFloatBitWidth());
? APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
: APInt::getMinValue(intTy.getIntOrFloatBitWidth());
scalar = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy,
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
mx.getSExtValue()));
minInt.getSExtValue()));
}

llvm::SmallVector<Value> fillDims;
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,

// CHECK-LABEL: func.func @test_reduce_max_empty_set_fp
func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000
// CHECK-DAG: %[[INF:.+]] = torch.constant.float 0xFFF0000000000000
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
Expand All @@ -660,7 +660,7 @@ func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg

// CHECK-LABEL: func.func @test_reduce_max_empty_set_int
func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647
// CHECK-DAG: %[[INF:.+]] = torch.constant.int -2147483648
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
Expand Down

0 comments on commit 1813999

Please sign in to comment.