From 18139994e807d262f52a13b2c8e1b3edfa45ffa0 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 7 Aug 2024 10:32:28 -0700 Subject: [PATCH] [onnx] Fix edge condition for `onnx.ReduceMax` (#3598) For length-0 on `onnx.ReduceMax` the length 0 case was incorrect due to a copy paste error. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 13 +++++++------ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 4 ++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 957700d1ae19..399f2731b958 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1328,7 +1328,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto dataTy = cast(data.getType()); Torch::IntType torchIntTy = rewriter.getType(); - // If any of the input dims are 0 we set to the upper limit: + // If any of the input dims are 0 we set to the lower limit: if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) && (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == Torch::kUnknownSize; }) || @@ -1336,7 +1336,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto dty = dataTy.getDtype(); Value scalar; if (FloatType fpTy = dyn_cast(dty)) { - auto inf = APFloat::getInf(fpTy.getFloatSemantics()); + auto inf = + APFloat::getInf(fpTy.getFloatSemantics(), /*Negative=*/true); scalar = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), @@ -1344,14 +1345,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } if (IntegerType intTy = dyn_cast(dty)) { - auto mx = + auto minInt = intTy.isSigned() - ? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) - : APInt::getMaxValue(intTy.getIntOrFloatBitWidth()); + ? APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth()) + : APInt::getMinValue(intTy.getIntOrFloatBitWidth()); scalar = rewriter.create( binder.getLoc(), torchIntTy, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - mx.getSExtValue())); + minInt.getSExtValue())); } llvm::SmallVector fillDims; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index e57cd605b007..403b320833fb 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -644,7 +644,7 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // CHECK-LABEL: func.func @test_reduce_max_empty_set_fp func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000 + // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0xFFF0000000000000 // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 @@ -660,7 +660,7 @@ func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg // CHECK-LABEL: func.func @test_reduce_max_empty_set_int func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647 + // CHECK-DAG: %[[INF:.+]] = torch.constant.int -2147483648 // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4