@@ -2786,16 +2786,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
2786
2786
coordTfMode, " coordinate_transformation_mode" , " half_pixel" ) ||
2787
2787
binder.customOpNameStringAttr (nearest_mode, " nearest_mode" , " " ))
2788
2788
return failure ();
2789
-
2789
+ if (coordTfMode == " tf_crop_and_resize" )
2790
+ return rewriter.notifyMatchFailure (
2791
+ binder.op , " unimplemented: coordinate transformation mode: "
2792
+ " tf_crop_and_resize" );
2790
2793
if (mode == " nearest" && nearest_mode != " floor" ) {
2791
2794
return rewriter.notifyMatchFailure (
2792
2795
binder.op , " unimplemented: support not present for nearest_mode "
2793
2796
" except floor" );
2794
2797
}
2795
- if (coordTfMode == " half_pixel_symmetric" ||
2796
- coordTfMode == " asymmetric" || coordTfMode == " tf_crop_and_resize" )
2797
- return rewriter.notifyMatchFailure (
2798
- binder.op , " unimplemented coordinate transformation mode." );
2798
+ unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0 ].getType ())
2799
+ .getSizes ()
2800
+ .size ();
2799
2801
2800
2802
Value zero = rewriter.create <Torch::ConstantIntOp>(
2801
2803
binder.getLoc (), rewriter.getType <Torch::IntType>(),
@@ -2857,36 +2859,54 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
2857
2859
Value sizesValueList = noneVal;
2858
2860
Value alignCorners =
2859
2861
coordTfMode == " align_corners" ? cstTrue : cstFalse;
2860
-
2861
2862
if (mode == " cubic" ) {
2862
2863
return rewriter.notifyMatchFailure (binder.op ,
2863
2864
" unimplemented: bicubic mode" );
2864
2865
}
2866
+ // supported modes:
2867
+ // bilinear (half_pixel), bilinear with align_corners,
2868
+ // bilinear_pytorch_half_pixel, bilinear_asymmetric nearest
2869
+ // (asymmetric), nearest with align_corners, nearest_half_pixel,
2870
+ // nearest_pytorch_half_pixel
2865
2871
if (mode == " linear" ) {
2866
- modeStrValue = rewriter.create <Torch::ConstantStrOp>(binder.getLoc (),
2867
- " bilinear" );
2868
- if (operands.size () < 4 ) {
2869
- Value scaleOperand = operands[2 ];
2870
- scalesValueList = getValueList (scaleOperand);
2871
- sizesValueList = noneVal;
2872
- } else {
2873
- Value sizeOperand = operands[3 ];
2874
- scalesValueList = noneVal;
2875
- sizesValueList = getValueList (sizeOperand);
2872
+ std::string modeStr;
2873
+ switch (rank) {
2874
+ case 3 :
2875
+ modeStr = " linear" ;
2876
+ break ;
2877
+ case 4 :
2878
+ modeStr = " bilinear" ;
2879
+ break ;
2880
+ case 5 :
2881
+ modeStr = " trilinear" ;
2882
+ break ;
2883
+ default :
2884
+ return failure ();
2876
2885
}
2886
+ // Confusingly enough, the default coordTfMode for pytorch bilinear
2887
+ // mode is apparently half_pixel, NOT pytorch_half_pixel
2888
+ if (coordTfMode != " half_pixel" && coordTfMode != " align_corners" )
2889
+ modeStr = (modeStr + " _" ) + coordTfMode;
2890
+ modeStrValue = rewriter.create <Torch::ConstantStrOp>(binder.getLoc (),
2891
+ modeStr);
2877
2892
}
2878
2893
if (mode == " nearest" ) {
2894
+ std::string modeStr = " nearest" ;
2895
+ // The default coordTfMode for pytorch with mode = nearest is
2896
+ // apparently asymmetric
2897
+ if (coordTfMode != " asymmetric" && coordTfMode != " align_corners" )
2898
+ modeStr = (modeStr + " _" ) + coordTfMode;
2879
2899
modeStrValue =
2880
- rewriter.create <Torch::ConstantStrOp>(binder.getLoc (), " nearest " );
2881
- if (operands. size () < 4 ) {
2882
- Value scaleOperand = operands[ 2 ];
2883
- scalesValueList = getValueList (scaleOperand) ;
2884
- sizesValueList = noneVal ;
2885
- } else {
2886
- Value sizesOperand = operands[ 3 ];
2887
- scalesValueList = noneVal ;
2888
- sizesValueList = getValueList (sizesOperand) ;
2889
- }
2900
+ rewriter.create <Torch::ConstantStrOp>(binder.getLoc (), modeStr );
2901
+ }
2902
+ if (operands. size () < 4 ) {
2903
+ Value scaleOperand = operands[ 2 ] ;
2904
+ scalesValueList = getValueList (scaleOperand) ;
2905
+ sizesValueList = noneVal;
2906
+ } else {
2907
+ Value sizeOperand = operands[ 3 ] ;
2908
+ scalesValueList = noneVal ;
2909
+ sizesValueList = getValueList (sizeOperand);
2890
2910
}
2891
2911
if (scalesValueList.getType ().isa <Torch::NoneType>() &&
2892
2912
sizesValueList.getType ().isa <Torch::NoneType>()) {
0 commit comments