Skip to content

Commit 08126c9

Browse files
BowenBao authored and facebook-github-bot committed
[ONNX] Utilize ONNX shape inference for ONNX exporter (pytorch#40628)
Summary: The conversion from a torch operator to an ONNX operator often requires the input rank/dtype/shape to be known. Previously, the conversion depended on the tracer to provide this information, leaving a gap in the conversion of scripted modules. We are extending the exporter with support from ONNX shape inference. If enabled, ONNX shape inference is called whenever an ONNX node is created. This is the first PR, introducing the initial form of the feature; more cases will be supported in follow-up PRs. * Added a pass to run ONNX shape inference on a given node. The node must be in the `onnx` namespace. * Moved helper functions from `export.cpp` to a common place for re-use. * This feature is currently experimental, and can be turned on through the flag `onnx_shape_inference` in the internal API `torch.onnx._export`. * Currently skipping ONNX Sequence ops, If/Loop and ConstantOfShape due to limitations. Support will be added in the future. Pull Request resolved: pytorch#40628 Reviewed By: mrshenli Differential Revision: D22709746 Pulled By: bzinodev fbshipit-source-id: b52aeeae00667e66e0b0c1144022f7af9a8b2948
1 parent 3aeb70d commit 08126c9

File tree

19 files changed

+559
-193
lines changed

19 files changed

+559
-193
lines changed

.github/workflows/lint.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ jobs:
152152
--verbose \
153153
--paths torch/csrc/ \
154154
--diff "$MERGE_BASE" \
155+
-g"-torch/csrc/jit/passes/onnx/helper.cpp" \
156+
-g"-torch/csrc/jit/passes/onnx/shape_type_inference.cpp"\
157+
-g"-torch/csrc/jit/serialization/onnx.cpp" \
155158
-g"-torch/csrc/jit/serialization/export.cpp" \
156159
-g"-torch/csrc/jit/serialization/import.cpp" \
157160
-g"-torch/csrc/jit/serialization/import_legacy.cpp" \

aten/src/ATen/core/interned_strings.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ namespace c10 {
283283
_(onnx, SequenceConstruct) \
284284
_(onnx, SequenceEmpty) \
285285
_(onnx, SequenceInsert) \
286+
_(onnx, SequenceErase) \
286287
_(onnx, ConcatFromSequence) \
287288
_(onnx, Identity) \
288289
_(onnx, SoftmaxCrossEntropyLoss) \

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
459459
if(NOT INTERN_BUILD_MOBILE)
460460
list(APPEND TORCH_SRCS
461461
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
462+
${TORCH_SRC_DIR}/csrc/jit/serialization/onnx.cpp
462463
${TORCH_SRC_DIR}/csrc/jit/serialization/export.cpp
463464
${TORCH_SRC_DIR}/csrc/jit/serialization/export_module.cpp
464465
${TORCH_SRC_DIR}/csrc/jit/serialization/import_legacy.cpp

test/onnx/test_pytorch_common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,14 @@ def wrapper(self):
8686
return wrapper
8787
return skip_dec
8888

89+
def skipIfONNXShapeInference(onnx_shape_inference):
    """Build a decorator that skips a test for one onnx_shape_inference setting.

    Args:
        onnx_shape_inference: the setting for which the decorated test must be
            skipped. It is compared by identity (``is``) against the test
            instance's ``onnx_shape_inference`` attribute.

    Returns:
        A decorator. The wrapped test raises ``unittest.SkipTest`` when the
        instance's setting matches, and otherwise returns ``func(self)``.
    """
    def skip_dec(func):
        def wrapper(self):
            if self.onnx_shape_inference is onnx_shape_inference:
                # Name the setting that triggered the skip; the previous
                # message blamed an unsupported opset_version.
                raise unittest.SkipTest(
                    "Skip verify test for onnx_shape_inference={}".format(
                        onnx_shape_inference))
            return func(self)
        return wrapper
    return skip_dec
97+
8998
def flatten(x):
9099
return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
1818
from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion, enableScriptTest,
1919
skipIfUnsupportedOpsetVersion, skipIfNoLapack,
20-
skipIfUnsupportedMaxOpsetVersion)
20+
skipIfUnsupportedMaxOpsetVersion, skipIfONNXShapeInference)
2121
from test_pytorch_common import BATCH_SIZE
2222
from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
2323
import model_defs.word_language_model as word_language_model
@@ -79,7 +79,8 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
7979
keep_initializers_as_inputs=self.keep_initializers_as_inputs,
8080
dynamic_axes=dynamic_axes,
8181
input_names=input_names, output_names=output_names,
82-
fixed_batch_size=fixed_batch_size)
82+
fixed_batch_size=fixed_batch_size,
83+
onnx_shape_inference=self.onnx_shape_inference)
8384

8485
# compute onnxruntime output prediction
8586
ort_sess = onnxruntime.InferenceSession(f.getvalue())
@@ -103,6 +104,7 @@ class TestONNXRuntime(unittest.TestCase):
103104
from torch.onnx.symbolic_helper import _export_onnx_opset_version
104105
opset_version = _export_onnx_opset_version
105106
keep_initializers_as_inputs = True # For IR version 3 type export.
107+
onnx_shape_inference = False
106108

107109
def setUp(self):
108110
torch.manual_seed(0)
@@ -496,15 +498,15 @@ def test_tensor(self):
496498
class ScalarInputModel(torch.jit.ScriptModule):
497499
@torch.jit.script_method
498500
def forward(self, input):
499-
return torch.tensor(input.shape[1])
501+
return torch.tensor(input.shape[1])
500502

501503
x = torch.randn(3, 4)
502504
self.run_test(ScalarInputModel(), x)
503505

504506
class TensorInputModel(torch.jit.ScriptModule):
505507
@torch.jit.script_method
506508
def forward(self, input):
507-
return torch.tensor([input.shape[0], input.shape[1]])
509+
return torch.tensor([input.shape[0], input.shape[1]])
508510

509511
x = torch.randn(3, 4)
510512
self.run_test(TensorInputModel(), x)
@@ -520,15 +522,15 @@ def forward(self, input):
520522
class InputWithDtypeModel(torch.jit.ScriptModule):
521523
@torch.jit.script_method
522524
def forward(self, input):
523-
return torch.tensor(input.shape[1], dtype=torch.long)
525+
return torch.tensor(input.shape[1], dtype=torch.long)
524526

525527
x = torch.randn(3, 4)
526528
self.run_test(InputWithDtypeModel(), x)
527529

528530
class MixedInputModel(torch.jit.ScriptModule):
529531
@torch.jit.script_method
530532
def forward(self, input):
531-
return torch.tensor([input.shape[0], int(input)])
533+
return torch.tensor([input.shape[0], int(input)])
532534

533535
x = torch.randn(1)
534536
self.run_test(MixedInputModel(), x)
@@ -686,6 +688,23 @@ def forward(self, input1, input2, input3):
686688
self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
687689
self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)
688690

691+
# Conversion of Transpose depends on input shape to be known.
692+
# The following test only works when onnx shape inference is enabled.
693+
@skipIfONNXShapeInference(False)
694+
def test_transpose_infer_shape(self):
695+
class TransposeModule(torch.jit.ScriptModule):
696+
def __init__(self):
697+
super(TransposeModule, self).__init__()
698+
self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)
699+
700+
@torch.jit.script_method
701+
def forward(self, x):
702+
x = self.conv(x)
703+
return x.transpose(0, 1)
704+
705+
x = torch.randn(32, 3, 64, 64)
706+
self.run_test(TransposeModule(), x)
707+
689708
def squeeze_model_tests(self, d, x1, x2):
690709
class Squeeze(torch.nn.Module):
691710
def forward(self, x):
@@ -842,6 +861,23 @@ def forward(self, x):
842861
x = torch.randn(2, 3, 4)
843862
self.run_test(ArithmeticModule(), x)
844863

864+
# In scripting, the first transpose node does not carry shape and dtype info.
865+
# The following test only works when onnx shape inference is enabled.
866+
@skipIfONNXShapeInference(False)
867+
def test_arithmetic_infer_dtype(self):
868+
class ArithmeticModule(torch.jit.ScriptModule):
869+
@torch.jit.script_method
870+
def forward(self, x):
871+
x = x.t()
872+
x = x + 2
873+
x = x - 4
874+
x = x * 6
875+
x = x / 8
876+
return x
877+
878+
x = torch.randn(2, 3)
879+
self.run_test(ArithmeticModule(), x)
880+
845881
def test_floor_div(self):
846882
class FloorDivModule(torch.nn.Module):
847883
def forward(self, x, y):
@@ -3015,6 +3051,21 @@ def forward(self, x):
30153051
x = torch.randn(4, 2, 3, requires_grad=True)
30163052
self.run_test(UnfoldModel(), x)
30173053

3054+
@skipIfONNXShapeInference(False)
3055+
def test_unfold_infer_shape(self):
3056+
class UnfoldModule(torch.jit.ScriptModule):
3057+
def __init__(self):
3058+
super(UnfoldModule, self).__init__()
3059+
self.conv = torch.nn.Conv1d(3, 1, 3, stride=2)
3060+
3061+
@torch.jit.script_method
3062+
def forward(self, x):
3063+
x = self.conv(x)
3064+
return x.unfold(dimension=2, size=2, step=2)
3065+
3066+
x = torch.randn(32, 3, 64)
3067+
self.run_test(UnfoldModule(), x)
3068+
30183069
def test_remainder(self):
30193070
class RemainderModel(torch.nn.Module):
30203071
def forward(self, input, other):
@@ -4187,5 +4238,11 @@ def setup_rnn_tests():
41874238
keep_initializers_as_inputs=False))
41884239

41894240

4241+
# opset 12 tests, with onnx_shape_inference=True.
4242+
TestONNXRuntime_opset12_onnx_shape_inference = type(str("TestONNXRuntime_opset12_onnx_shape_inference"),
4243+
(unittest.TestCase,),
4244+
dict(TestONNXRuntime.__dict__, opset_version=12,
4245+
onnx_shape_inference=True))
4246+
41904247
if __name__ == '__main__':
41914248
unittest.main()

tools/build_variables.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
322322
"torch/csrc/jit/mobile/observer.cpp",
323323
"torch/csrc/jit/mobile/optim/sgd.cpp",
324324
"torch/csrc/jit/mobile/sequential.cpp",
325+
"torch/csrc/jit/serialization/onnx.cpp",
325326
"torch/csrc/jit/serialization/export.cpp",
326327
"torch/csrc/jit/serialization/export_module.cpp",
327328
"torch/csrc/jit/serialization/import_legacy.cpp",
@@ -501,6 +502,7 @@ libtorch_python_core_sources = [
501502
"torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp",
502503
"torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp",
503504
"torch/csrc/jit/passes/onnx/prepare_inplace_ops_for_onnx.cpp",
505+
"torch/csrc/jit/passes/onnx/shape_type_inference.cpp",
504506
"torch/csrc/jit/python/python_arg_flatten.cpp",
505507
"torch/csrc/jit/python/python_custom_class.cpp",
506508
"torch/csrc/jit/python/python_interpreter.cpp",

tools/git-pre-commit

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ then
1010
python tools/clang_tidy.py \
1111
--paths torch/csrc \
1212
--diff HEAD \
13+
-g"-torch/csrc/jit/passes/onnx/helper.cpp" \
14+
-g"-torch/csrc/jit/passes/onnx/shape_type_inference.cpp" \
15+
-g"-torch/csrc/jit/serialization/onnx.cpp" \
1316
-g"-torch/csrc/jit/serialization/export.cpp" \
1417
-g"-torch/csrc/jit/serialization/import.cpp" \
1518
-j

torch/csrc/jit/passes/onnx/constant_fold.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,10 @@ c10::optional<at::Tensor> runTorchSlice_opset9(
103103
c10::optional<at::Tensor> runTorchSlice_opset10(
104104
const Node* node,
105105
std::vector<at::Tensor>& inputTensorValues) {
106-
if (inputTensorValues.size() < 3 || inputTensorValues.size() > 5) {
106+
const int maxSliceInputCount = 5;
107+
const int minSliceInputCount = 3;
108+
if (inputTensorValues.size() < minSliceInputCount ||
109+
inputTensorValues.size() > maxSliceInputCount) {
107110
std::cerr
108111
<< "Warning: Constant folding - Invalid number of inputs found for opset 10 or 11 onnx::Slice op. "
109112
<< "Constant folding not applied." << std::endl;
@@ -249,11 +252,9 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
249252
return c10::optional<at::Tensor>(updated_val);
250253
} else if (node->kind() == onnx::Cast) {
251254
assert(inputTensorValues.size() == 1);
252-
if (node->hasAttributeS("to") &&
253-
onnxTypeToScalarTypeMap.find(node->i(attr::to)) !=
254-
onnxTypeToScalarTypeMap.end()) {
255-
updated_val =
256-
inputTensorValues[0].to(onnxTypeToScalarTypeMap[node->i(attr::to)]);
255+
if (node->hasAttributeS("to") && ONNXTypeToATenType(node->i(attr::to))) {
256+
updated_val = inputTensorValues[0].to(
257+
ONNXTypeToATenType(node->i(attr::to)).value());
257258
return c10::optional<at::Tensor>(updated_val);
258259
}
259260
return c10::nullopt;

torch/csrc/jit/passes/onnx/helper.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include <torch/csrc/jit/passes/onnx/helper.h>
2-
#include <torch/csrc/jit/jit_log.h>
2+
#include <onnx/onnx_pb.h>
33

44
namespace torch {
55
namespace jit {
@@ -59,5 +59,40 @@ Node* addNodeToBlock(Block* block, Value* input, Symbol kind) {
5959
}
6060
return new_node;
6161
}
62+
63+
// Maps an ONNX TensorProto data-type enum value to the matching ATen scalar
// type.
//
// Returns an empty optional when `onnx_type` has no mapping below, so callers
// can test the result to decide whether the type is supported.
// Note that TensorProto_DataType_UNDEFINED maps to at::ScalarType::Undefined,
// which is an engaged (non-empty) optional.
c10::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type) {
  switch (onnx_type) {
    case ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED:
      return at::ScalarType::Undefined;
    case ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
      return at::kFloat;
    case ::ONNX_NAMESPACE::TensorProto_DataType_UINT8:
      return at::kByte;
    case ::ONNX_NAMESPACE::TensorProto_DataType_INT8:
      return at::kChar;
    case ::ONNX_NAMESPACE::TensorProto_DataType_INT16:
      return at::kShort;
    case ::ONNX_NAMESPACE::TensorProto_DataType_INT32:
      return at::kInt;
    case ::ONNX_NAMESPACE::TensorProto_DataType_INT64:
      return at::kLong;
    case ::ONNX_NAMESPACE::TensorProto_DataType_BOOL:
      return at::kBool;
    case ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
      return at::kHalf;
    case ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
      return at::kDouble;
    case ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64:
      return at::kComplexFloat;
    case ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128:
      return at::kComplexDouble;
    case ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
      return at::kBFloat16;
    default:
      // The previous `TORCH_CHECK("unexpected tensor scalar type")` passed the
      // message as the condition. A string literal is always true, so the
      // check never fired and control fell through to the empty optional.
      // Return it explicitly to keep that behaviour.
      return c10::optional<at::ScalarType>{};
  }
}
96+
6297
} // namespace jit
6398
} // namespace torch

torch/csrc/jit/passes/onnx/helper.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,19 @@
88
namespace torch {
99
namespace jit {
1010

11-
namespace onnx {
1211
static const int OPSET_VERSION_1 = 1;
1312
static const int OPSET_VERSION_9 = 9;
1413
static const int OPSET_VERSION_10 = 10;
1514
static const int OPSET_VERSION_11 = 11;
1615
static const int OPSET_VERSION_12 = 12;
17-
} // namespace onnx
1816

1917
using ValueToParamPairMap = std::map<Value*, std::pair<std::string, IValue>>;
2018

2119
using ParamMap = std::map<std::string, IValue>;
2220

21+
void buildParamsMapFromValueToParamsMap(
22+
const ValueToParamPairMap& valsToParamsMap,
23+
ParamMap& paramsDict);
2324
ValueToParamPairMap buildValueToParamsMap(Block* b, const ParamMap& paramsDict);
2425
void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap);
2526
void eraseUnusedBlockInputs(Block* b);
@@ -28,5 +29,6 @@ void buildParamsMapFromValueToParamsMap(
2829
ParamMap& paramsDict);
2930
Node* addNodeToBlock(Block* block, Value* input, Symbol kind);
3031

32+
TORCH_API c10::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type);
3133
} // namespace jit
3234
} // namespace torch

torch/csrc/jit/passes/onnx/peephole.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ static void eraseListConstruct(Block* block, int opset_version) {
576576
i, std::vector<Value*>({concat_node->output()}));
577577

578578
} else {
579-
if (opset_version < onnx::OPSET_VERSION_11) {
579+
if (opset_version < OPSET_VERSION_11) {
580580
// Tensor lists are used mostly for inputs to cat/stack. They are
581581
// already handled in those symbolics, and should become dead
582582
// afterwards.

0 commit comments

Comments
 (0)