Skip to content

Commit

Permalink
[ONNX] fix padding for onnx.MaxPool (#3611)
Browse files Browse the repository at this point in the history
The saga of aligning onnx and torch padding conventions continues. 

```python
onnx_pads = [low_x, low_y, low_z, high_x, high_y, high_z]
torch_pads = [low_z, high_z, low_y, high_y, low_x, high_x]
```

So not only is the lexicographical ordering hierarchy swapped (low/high
x spatial-dim -> spatial-dim x low/high) but the ordering in the the
spatial-dim specification is also reversed.

This patch properly reverses the pad ordering (and actually uses the
`shuffledPadding` to pad).
  • Loading branch information
zjgarvey authored Aug 8, 2024
1 parent 6c33ab0 commit 7f2a17e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
7 changes: 3 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,15 +788,14 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
llvm::SmallVector<int64_t> shuffledPadding(spatial * 2);
llvm::SmallVector<int64_t> paddedShape(operandTy.getSizes());
shuffledPadding.resize(2 * rank);
for (int i = 0; i < spatial; ++i) {
paddedShape[i + 2] += padding[i] + padding[i + spatial];
shuffledPadding[2 * i] = padding[i];
shuffledPadding[2 * i + 1] = padding[i + spatial];
shuffledPadding[2 * i] = padding[spatial - i - 1];
shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1];
}

Value shuffledPaddingList =
createConstantIntList(binder, rewriter, padding);
createConstantIntList(binder, rewriter, shuffledPadding);
Value zero;
if (isa<FloatType>(resultTypeOut.getDtype())) {
zero = rewriter.create<Torch::ConstantFloatOp>(
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -670,8 +670,8 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) ->
// CHECK-LABEL: func.func @test_maxpool_pad
func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
// CHECK: %[[INT1_1:.+]] = torch.constant.int 1
// CHECK: %[[INT2_0:.+]] = torch.constant.int 2
// CHECK: %[[INT1_1:.+]] = torch.constant.int 2
// CHECK: %[[INT2_0:.+]] = torch.constant.int 1
// CHECK: %[[INT2_1:.+]] = torch.constant.int 2
// CHECK: %[[PADI:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]], %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[MIN:.+]] = torch.constant.float -1.7976931348623157E+308
Expand Down

0 comments on commit 7f2a17e

Please sign in to comment.