Skip to content

Commit

Permalink
[onnx] Support for optional axis attribute for onnx.Pad (#3635)
Browse files Browse the repository at this point in the history
The `axis` attribute is optionally available. Added support by computing
the pad based on the axis values.

---------

Signed-off-by: Rob Suderman <[email protected]>
  • Loading branch information
rsuderman authored Aug 24, 2024
1 parent b3b8e2e commit 6cf1396
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 8 deletions.
98 changes: 90 additions & 8 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2700,15 +2700,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value data, pads, axes;
std::string mode;

// TODO: The `axes` parameter is not supported yet.
if (!binder.tensorOperandAtIndex(axes, 3)) {
return rewriter.notifyMatchFailure(
binder.op, "The axes parameter is not supported yet");
}
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.customOpNameStringAttr(mode, "mode", "constant"))
return failure();

(void)binder.tensorOperandAtIndex(axes, 3);

bool cstMode = (mode == "constant");

// get input rank
Expand Down Expand Up @@ -2822,16 +2820,100 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
if (!cstMode)
constantValue = rewriter.create<Torch::ConstantNoneOp>(loc);

llvm::SmallVector<Value> begins;
llvm::SmallVector<Value> ends;
for (uint32_t i = 0; i < padsSize / 2; ++i)
begins.push_back(padsTensorValue[i]);
for (uint32_t i = padsSize / 2; i < padsSize; ++i)
ends.push_back(padsTensorValue[i]);

// If we have the axes we need to compute the appropriate pads:
if (axes) {
auto axesTy = cast<Torch::ValueTensorType>(axes.getType());
assert(axesTy.getSizes().size() == 1);
assert(axesTy.getSizes()[0] != Torch::kUnknownSize);

auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
int64_t rank = dataTensorType.getSizes().size();
auto boolTy = rewriter.getType<Torch::BoolType>();
auto intTy = rewriter.getType<Torch::IntType>();
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));

// Extract the values:
int64_t numAxes = axesTy.getSizes()[0];
Type axesElemType = Torch::ValueTensorType::get(
axesTy.getContext(), ArrayRef<int64_t>{},
axesTy.getOptionalDtype());
llvm::SmallVector<Value> axesExtracted;
Value rankV = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank));
for (uint32_t i = 0; i < numAxes; ++i) {
Value index = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
auto select = rewriter.create<Torch::AtenSelectIntOp>(
loc, axesElemType, axes, constZero, index);
Value selectInt = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), select);

Value negAxis = rewriter.create<Torch::AtenLtIntOp>(
loc, boolTy, selectInt, constZero);
negAxis =
rewriter.create<Torch::AtenIntBoolOp>(loc, intTy, negAxis);
Value axis = rewriter.create<Torch::AtenMulIntOp>(loc, intTy,
negAxis, rankV);
axis = rewriter.create<Torch::AtenAddIntOp>(loc, intTy, axis,
selectInt);
axesExtracted.push_back(axis);
}

llvm::SmallVector<Value> newBegins;
llvm::SmallVector<Value> newEnds;

for (int j = 0; j < rank; ++j) {
Value newBegin = constZero;
Value newEnd = constZero;
Value iv = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(j));

for (size_t i = 0; i < axesExtracted.size(); ++i) {
Value begin = begins[i];
Value end = ends[i];

Value sameAxis = rewriter.create<Torch::AtenEqIntOp>(
loc, boolTy, axesExtracted[i], iv);
sameAxis =
rewriter.create<Torch::AtenIntBoolOp>(loc, intTy, sameAxis);

begin = rewriter.create<Torch::AtenMulIntOp>(loc, intTy, sameAxis,
begin);
end = rewriter.create<Torch::AtenMulIntOp>(loc, intTy, sameAxis,
end);

newBegin = rewriter.create<Torch::AtenAddIntOp>(loc, intTy,
newBegin, begin);
newEnd =
rewriter.create<Torch::AtenAddIntOp>(loc, intTy, newEnd, end);
}

newBegins.push_back(newBegin);
newEnds.push_back(newEnd);
}

begins = std::move(newBegins);
ends = std::move(newEnds);
}

// The torch.pad op expects a different arrangement of padding pairs for
// each dimension as compared to the onnx.pad op. Rearrange the pad
// tensor as shown below:
//
// [x1_begin, x2_begin, ..., x1_end, x2_end,...] ->
// [xn_begin, xn_end, ...., x2_begin, x2_end, x1_begin, x1_end]
SmallVector<Value> padsRearrange;
for (uint32_t i = padsSize - 1; i >= padsSize / 2; i--) {
padsRearrange.emplace_back(padsTensorValue[i - padsSize / 2]);
padsRearrange.emplace_back(padsTensorValue[i]);
for (int32_t i = begins.size() - 1; i >= 0; i--) {
padsRearrange.emplace_back(begins[i]);
padsRearrange.emplace_back(ends[i]);
}

Value padsSizeList =
Expand Down
81 changes: 81 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,87 @@ func.func @test_pad_edge(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor

// -----

func.func @test_center_crop_pad_crop_axes_chw_expanded(%arg0: !torch.vtensor<[4,5],f32>, %arg1: !torch.vtensor<[4],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {

// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[ZERO:.+]] = torch.constant.int 0

// CHECK: %[[IDX:.+]] = torch.constant.int 0
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]]
// CHECK: %[[PAD0:.+]] = torch.aten.item %[[SEL]]

// CHECK: %[[IDX:.+]] = torch.constant.int 1
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]]
// CHECK: %[[PAD1:.+]] = torch.aten.item %[[SEL]]

// CHECK: %[[IDX:.+]] = torch.constant.int 2
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]]
// CHECK: %[[PAD2:.+]] = torch.aten.item %[[SEL]]

// CHECK: %[[IDX:.+]] = torch.constant.int 3
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]]
// CHECK: %[[PAD3:.+]] = torch.aten.item %[[SEL]]

// CHECK: %[[ZERO:.+]] = torch.constant.int 0
// CHECK: %[[RANK:.+]] = torch.constant.int 2

// CHECK: %[[IDX:.+]] = torch.constant.int 0
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg2, %[[ZERO]], %[[IDX]]
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[ZERO]]
// CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[LT]]
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[INT]], %[[RANK]]
// CHECK: %[[AXIS0:.+]] = torch.aten.add.int %[[MUL]], %[[ITEM]]

// CHECK: %[[IDX:.+]] = torch.constant.int 1
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg2, %[[ZERO]], %[[IDX]]
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[ZERO]]
// CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[LT]]
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[INT]], %[[RANK]]
// CHECK: %[[AXIS1:.+]] = torch.aten.add.int %[[MUL]], %[[ITEM]]


// CHECK: %[[AX:.+]] = torch.constant.int 0
// CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS0]], %[[AX]]
// CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]]
// CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD0]]
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD2]]
// CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL0]]
// CHECK: %[[ADD1:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL1]]

// CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS1]], %[[AX]]
// CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]]
// CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD1]]
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD3]]
// CHECK: %[[BEGIN0:.+]] = torch.aten.add.int %[[ADD0]], %[[MUL0]]
// CHECK: %[[END0:.+]] = torch.aten.add.int %[[ADD1]], %[[MUL1]]

// CHECK: %[[AX:.+]] = torch.constant.int 1
// CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS0]], %[[AX]]
// CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]]
// CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD0]]
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD2]]
// CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL0]]
// CHECK: %[[ADD1:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL1]]

// CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS1]], %[[AX]]
// CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]]
// CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD1]]
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD3]]
// CHECK: %[[BEGIN1:.+]] = torch.aten.add.int %[[ADD0]], %[[MUL0]]
// CHECK: %[[END1:.+]] = torch.aten.add.int %[[ADD1]], %[[MUL1]]

// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[BEGIN1]], %[[END1]], %[[BEGIN0]], %[[END0]]
// CHECK: %[[MODE:.+]] = torch.constant.str "constant"
// CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[MODE]], %[[NONE]]
%none = torch.constant.none
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[4,5],f32>, !torch.vtensor<[4],si64>, !torch.none, !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: func.func @test_pow
func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
Expand Down

0 comments on commit 6cf1396

Please sign in to comment.