Skip to content

Commit 7c1ce86

Browse files
committed
Add custom support for onnx.Resize options in aten.interpolate lowering
1 parent f8a5793 commit 7c1ce86

File tree

5 files changed

+255
-173
lines changed

5 files changed

+255
-173
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

+46-26
Original file line numberDiff line numberDiff line change
@@ -2786,16 +2786,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27862786
coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
27872787
binder.customOpNameStringAttr(nearest_mode, "nearest_mode", ""))
27882788
return failure();
2789-
2789+
if (coordTfMode == "tf_crop_and_resize")
2790+
return rewriter.notifyMatchFailure(
2791+
binder.op, "unimplemented: coordinate transformation mode: "
2792+
"tf_crop_and_resize");
27902793
if (mode == "nearest" && nearest_mode != "floor") {
27912794
return rewriter.notifyMatchFailure(
27922795
binder.op, "unimplemented: support not present for nearest_mode "
27932796
"except floor");
27942797
}
2795-
if (coordTfMode == "half_pixel_symmetric" ||
2796-
coordTfMode == "asymmetric" || coordTfMode == "tf_crop_and_resize")
2797-
return rewriter.notifyMatchFailure(
2798-
binder.op, "unimplemented coordinate transformation mode.");
2798+
unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
2799+
.getSizes()
2800+
.size();
27992801

28002802
Value zero = rewriter.create<Torch::ConstantIntOp>(
28012803
binder.getLoc(), rewriter.getType<Torch::IntType>(),
@@ -2857,36 +2859,54 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
28572859
Value sizesValueList = noneVal;
28582860
Value alignCorners =
28592861
coordTfMode == "align_corners" ? cstTrue : cstFalse;
2860-
28612862
if (mode == "cubic") {
28622863
return rewriter.notifyMatchFailure(binder.op,
28632864
"unimplemented: bicubic mode");
28642865
}
2866+
// supported modes:
2867+
// bilinear (half_pixel), bilinear with align_corners,
2868+
// bilinear_pytorch_half_pixel, bilinear_asymmetric nearest
2869+
// (asymmetric), nearest with align_corners, nearest_half_pixel,
2870+
// nearest_pytorch_half_pixel
28652871
if (mode == "linear") {
2866-
modeStrValue = rewriter.create<Torch::ConstantStrOp>(binder.getLoc(),
2867-
"bilinear");
2868-
if (operands.size() < 4) {
2869-
Value scaleOperand = operands[2];
2870-
scalesValueList = getValueList(scaleOperand);
2871-
sizesValueList = noneVal;
2872-
} else {
2873-
Value sizeOperand = operands[3];
2874-
scalesValueList = noneVal;
2875-
sizesValueList = getValueList(sizeOperand);
2872+
std::string modeStr;
2873+
switch (rank) {
2874+
case 3:
2875+
modeStr = "linear";
2876+
break;
2877+
case 4:
2878+
modeStr = "bilinear";
2879+
break;
2880+
case 5:
2881+
modeStr = "trilinear";
2882+
break;
2883+
default:
2884+
return failure();
28762885
}
2886+
// Confusingly enough, the default coordTfMode for pytorch bilinear
2887+
// mode is apparently half_pixel, NOT pytorch_half_pixel
2888+
if (coordTfMode != "half_pixel" && coordTfMode != "align_corners")
2889+
modeStr = (modeStr + "_") + coordTfMode;
2890+
modeStrValue = rewriter.create<Torch::ConstantStrOp>(binder.getLoc(),
2891+
modeStr);
28772892
}
28782893
if (mode == "nearest") {
2894+
std::string modeStr = "nearest";
2895+
// The default coordTfMode for pytorch with mode = nearest is
2896+
// apparently asymmetric
2897+
if (coordTfMode != "asymmetric" && coordTfMode != "align_corners")
2898+
modeStr = (modeStr + "_") + coordTfMode;
28792899
modeStrValue =
2880-
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), "nearest");
2881-
if (operands.size() < 4) {
2882-
Value scaleOperand = operands[2];
2883-
scalesValueList = getValueList(scaleOperand);
2884-
sizesValueList = noneVal;
2885-
} else {
2886-
Value sizesOperand = operands[3];
2887-
scalesValueList = noneVal;
2888-
sizesValueList = getValueList(sizesOperand);
2889-
}
2900+
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
2901+
}
2902+
if (operands.size() < 4) {
2903+
Value scaleOperand = operands[2];
2904+
scalesValueList = getValueList(scaleOperand);
2905+
sizesValueList = noneVal;
2906+
} else {
2907+
Value sizeOperand = operands[3];
2908+
scalesValueList = noneVal;
2909+
sizesValueList = getValueList(sizeOperand);
28902910
}
28912911
if (scalesValueList.getType().isa<Torch::NoneType>() &&
28922912
sizesValueList.getType().isa<Torch::NoneType>()) {

0 commit comments

Comments
 (0)