diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 96459a3a06a9..acb6fb21bc06 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -788,15 +788,14 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( auto operandTy = cast(operand.getType()); llvm::SmallVector shuffledPadding(spatial * 2); llvm::SmallVector paddedShape(operandTy.getSizes()); - shuffledPadding.resize(2 * rank); for (int i = 0; i < spatial; ++i) { paddedShape[i + 2] += padding[i] + padding[i + spatial]; - shuffledPadding[2 * i] = padding[i]; - shuffledPadding[2 * i + 1] = padding[i + spatial]; + shuffledPadding[2 * i] = padding[spatial - i - 1]; + shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1]; } Value shuffledPaddingList = - createConstantIntList(binder, rewriter, padding); + createConstantIntList(binder, rewriter, shuffledPadding); Value zero; if (isa(resultTypeOut.getDtype())) { zero = rewriter.create( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index c879cefc56c0..ce8a60109106 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -670,8 +670,8 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> // CHECK-LABEL: func.func @test_maxpool_pad func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 - // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 - // CHECK: %[[INT2_0:.+]] = torch.constant.int 2 + // CHECK: %[[INT1_1:.+]] = torch.constant.int 2 + // CHECK: %[[INT2_0:.+]] = torch.constant.int 1 // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 // CHECK: %[[PADI:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]], %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[MIN:.+]] = torch.constant.float -1.7976931348623157E+308