From 7f2a17e7571b03e05a5cf329c8f271976281e280 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Wed, 7 Aug 2024 20:34:00 -0700 Subject: [PATCH] [ONNX] fix padding for `onnx.MaxPool` (#3611) The saga of aligning onnx and torch padding conventions continues. ```python onnx_pads = [low_x, low_y, low_z, high_x, high_y, high_z] torch_pads = [low_z, high_z, low_y, high_y, low_x, high_x] ``` So not only is the lexicographical ordering hierarchy swapped (low/high x spatial-dim -> spatial-dim x low/high) but the ordering in the spatial-dim specification is also reversed. This patch properly reverses the pad ordering (and actually uses the `shuffledPadding` to pad). --- lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp | 7 +++---- test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 96459a3a06a9..acb6fb21bc06 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -788,15 +788,14 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( auto operandTy = cast(operand.getType()); llvm::SmallVector shuffledPadding(spatial * 2); llvm::SmallVector paddedShape(operandTy.getSizes()); - shuffledPadding.resize(2 * rank); for (int i = 0; i < spatial; ++i) { paddedShape[i + 2] += padding[i] + padding[i + spatial]; - shuffledPadding[2 * i] = padding[i]; - shuffledPadding[2 * i + 1] = padding[i + spatial]; + shuffledPadding[2 * i] = padding[spatial - i - 1]; + shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1]; } Value shuffledPaddingList = - createConstantIntList(binder, rewriter, padding); + createConstantIntList(binder, rewriter, shuffledPadding); Value zero; if (isa(resultTypeOut.getDtype())) { zero = rewriter.create( diff --git 
a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index c879cefc56c0..ce8a60109106 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -670,8 +670,8 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> // CHECK-LABEL: func.func @test_maxpool_pad func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 - // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 - // CHECK: %[[INT2_0:.+]] = torch.constant.int 2 + // CHECK: %[[INT1_1:.+]] = torch.constant.int 2 + // CHECK: %[[INT2_0:.+]] = torch.constant.int 1 // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 // CHECK: %[[PADI:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]], %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[MIN:.+]] = torch.constant.float -1.7976931348623157E+308