Skip to content

Commit 30699cb

Browse files
zhxchen17facebook-github-bot
authored andcommitted
Reland D33284352: [jit][edge] Do not reuse mobile type parser for all unpicklers. (pytorch#71048)
Summary: Pull Request resolved: pytorch#71048 reland D33284352 (pytorch@0a921ba) ghstack-source-id: 146735646 Test Plan: All Github CI: ciflow rerun -l ciflow/all Reviewed By: gmagogsfm Differential Revision: D33489731 fbshipit-source-id: 3e160209a1abb193ad3eed3018054aa7d331025e
1 parent fb66f56 commit 30699cb

21 files changed

+87
-54
lines changed

test/cpp/jit/test_mobile_type_parser.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,7 @@
22
#include <test/cpp/jit/test_utils.h>
33

44
#include <ATen/core/jit_type.h>
5-
6-
namespace c10 {
7-
TypePtr parseType(const std::string& pythonStr);
8-
std::vector<TypePtr> parseType(std::vector<std::string>& pythonStr);
9-
} // namespace c10
5+
#include <torch/csrc/jit/mobile/type_parser.h>
106

117
namespace torch {
128
namespace jit {

test/mobile/test_upgrader_bytecode_table_example.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@
55
* cd ~/pytorch && python torch/csrc/jit/mobile/upgrader_mobile.cpp
66
*/
77

8-
#include <caffe2/serialize/versions.h>
98
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
10-
#include <ATen/core/ivalue.h>
119

12-
namespace c10 {
13-
TypePtr parseType(const std::string& pythonStr);
14-
} // namespace c10
10+
#include <ATen/core/ivalue.h>
11+
#include <caffe2/serialize/versions.h>
12+
#include <torch/csrc/jit/mobile/type_parser.h>
1513

1614
namespace torch {
1715
namespace jit {

tools/codegen/operator_versions/gen_mobile_upgraders.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,11 @@ class ByteCode(Enum):
101101
* cd ~/pytorch && python torch/csrc/jit/mobile/upgrader_mobile.cpp
102102
*/
103103
104-
#include <caffe2/serialize/versions.h>
105104
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
106-
#include <ATen/core/ivalue.h>
107105
108-
namespace c10 {
109-
TypePtr parseType(const std::string& pythonStr);
110-
} // namespace c10
106+
#include <ATen/core/ivalue.h>
107+
#include <caffe2/serialize/versions.h>
108+
#include <torch/csrc/jit/mobile/type_parser.h>
111109
112110
namespace torch {
113111
namespace jit {

torch/csrc/jit/frontend/tree.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ struct Tree;
2929
using TreeRef = c10::intrusive_ptr<Tree>;
3030
using TreeList = at::SmallVector<TreeRef, 4>;
3131

32-
static const TreeList empty_trees = {};
33-
3432
struct Tree : c10::intrusive_ptr_target {
3533
Tree(int kind_) : kind_(kind_) {}
3634
int kind() const {
@@ -46,6 +44,7 @@ struct Tree : c10::intrusive_ptr_target {
4644
throw std::runtime_error("stringValue can only be called on TK_STRING");
4745
}
4846
virtual const TreeList& trees() const {
47+
static const TreeList empty_trees = {};
4948
return empty_trees;
5049
}
5150
const TreeRef& tree(size_t i) const {
@@ -149,11 +148,11 @@ struct Compound : public Tree {
149148
return false;
150149
}
151150
TreeRef map(const std::function<TreeRef(TreeRef)>& fn) override {
152-
TreeList trees_;
151+
TreeList ret;
153152
for (auto& t : trees()) {
154-
trees_.push_back(fn(t));
153+
ret.push_back(fn(t));
155154
}
156-
return Compound::create(kind(), range(), std::move(trees_));
155+
return Compound::create(kind(), range(), std::move(ret));
157156
}
158157

159158
const SourceRange& range() const override {

torch/csrc/jit/mobile/debug_info.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <torch/csrc/jit/frontend/source_range.h>
22
#include <torch/csrc/jit/mobile/debug_info.h>
3+
#include <torch/csrc/jit/mobile/type_parser.h>
34
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
45
#include <torch/csrc/jit/serialization/source_range_serialization.h>
56

@@ -122,10 +123,13 @@ MobileDebugTable::MobileDebugTable(
122123
size_t debug_size{0};
123124
std::tie(debug_data, debug_size) = reader->getRecord(record_name);
124125
auto ivalues =
125-
std::move(
126-
*jit::unpickle(
127-
reinterpret_cast<const char*>(debug_data.get()), debug_size)
128-
.toTuple())
126+
std::move(*jit::unpickle(
127+
reinterpret_cast<const char*>(debug_data.get()),
128+
debug_size,
129+
nullptr,
130+
{},
131+
c10::parseType)
132+
.toTuple())
129133
.elements();
130134
SourceRangeDeserializer deserializer;
131135
for (auto& val : ivalues) {

torch/csrc/jit/mobile/import.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <torch/csrc/jit/api/compilation_unit.h>
1313
#include <torch/csrc/jit/mobile/interpreter.h>
1414
#include <torch/csrc/jit/mobile/observer.h>
15+
#include <torch/csrc/jit/mobile/type_parser.h>
1516
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
1617
#include <torch/csrc/jit/runtime/instruction.h>
1718
#include <torch/csrc/jit/serialization/import_export_constants.h>
@@ -78,11 +79,6 @@
7879
// - Argument::{known_length_,kwarg_only_}
7980
// - FunctionSchema::{overload_name_, is_vararg_, is_varret_}
8081

81-
namespace c10 {
82-
// std::string serializeType(const Type &t);
83-
TypePtr parseType(const std::string& pythonStr);
84-
} // namespace c10
85-
8682
namespace torch {
8783
namespace jit {
8884
using caffe2::serialize::IStreamAdapter;
@@ -502,7 +498,8 @@ c10::IValue BytecodeDeserializer::readArchive(
502498
type_resolver,
503499
obj_loader,
504500
device_,
505-
*reader_.get());
501+
*reader_.get(),
502+
nullptr);
506503
return ivalues;
507504
}
508505

torch/csrc/jit/mobile/import_data.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <caffe2/serialize/inline_container.h>
66
#include <torch/csrc/jit/api/compilation_unit.h>
77
#include <torch/csrc/jit/mobile/observer.h>
8+
#include <torch/csrc/jit/mobile/type_parser.h>
89
#include <torch/csrc/jit/runtime/instruction.h>
910
#include <torch/csrc/jit/serialization/unpickler.h>
1011
#include <torch/custom_class.h>
@@ -14,11 +15,6 @@
1415
#include <string>
1516
#include <vector>
1617

17-
namespace c10 {
18-
// std::string serializeType(const Type &t);
19-
TypePtr parseType(const std::string& pythonStr);
20-
} // namespace c10
21-
2218
namespace torch {
2319
namespace jit {
2420
using caffe2::serialize::IStreamAdapter;
@@ -151,7 +147,9 @@ c10::IValue BytecodeDeserializer::readArchive(
151147
std::move(obj_loader),
152148
std::move(read_record),
153149
// NOLINTNEXTLINE(performance-move-const-arg)
154-
std::move(device));
150+
std::move(device),
151+
false,
152+
nullptr);
155153
return unpickler.parse_ivalue();
156154
}
157155

torch/csrc/jit/mobile/model_compatibility.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ c10::IValue readArchive(
5353
type_resolver,
5454
obj_loader,
5555
device,
56-
stream_reader);
56+
stream_reader,
57+
nullptr);
5758
return ivalues;
5859
}
5960

torch/csrc/jit/mobile/type_parser.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include <torch/csrc/jit/mobile/type_parser.h>
2+
13
#include <ATen/core/jit_type.h>
24
#include <c10/util/string_view.h>
35
#include <torch/csrc/jit/frontend/parser_constants.h>

torch/csrc/jit/mobile/type_parser.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#pragma once
2+
13
#include <ATen/core/dynamic_type.h>
24
#include <ATen/core/jit_type.h>
35

torch/csrc/jit/runtime/register_ops_utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,8 @@ void listAdd(Stack& stack) {
384384
}
385385

386386
void listInplaceAdd(Stack& stack) {
387-
c10::List<IValue> b = pop(stack).to<List<IValue>>();
388-
c10::List<IValue> a = pop(stack).to<List<IValue>>();
387+
c10::List<IValue> b = pop(stack).to<c10::List<IValue>>();
388+
c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
389389
a.append(std::move(b));
390390
push(stack, std::move(a));
391391
}

torch/csrc/jit/runtime/register_prim_ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
973973
TORCH_SELECTIVE_SCHEMA(
974974
"aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"),
975975
[](Stack& stack) {
976-
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
976+
auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
977977
auto self = pop(stack).toTensor();
978978
auto result = at::index(self, indices);
979979
push(stack, std::move(result));
@@ -986,7 +986,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
986986
auto unsafe = pop(stack).toBool();
987987
auto accumulate = pop(stack).toBool();
988988
auto values = pop(stack).toTensor();
989-
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
989+
auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
990990
auto self = pop(stack).toTensor();
991991
auto result =
992992
at::_index_put_impl_(self, indices, values, accumulate, unsafe);
@@ -999,7 +999,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
999999
[](Stack& stack) {
10001000
auto accumulate = pop(stack).toBool();
10011001
auto values = pop(stack).toTensor();
1002-
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
1002+
auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
10031003
auto self = pop(stack).toTensor();
10041004
auto result = at::index_put_(self, indices, values, accumulate);
10051005
push(stack, std::move(result));
@@ -1011,7 +1011,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
10111011
[](Stack& stack) {
10121012
auto accumulate = pop(stack).toBool();
10131013
auto values = pop(stack).toTensor();
1014-
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
1014+
auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
10151015
auto self = pop(stack).toTensor();
10161016
auto result = at::index_put_(self, indices, values, accumulate);
10171017
push(stack, std::move(result));

torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <torch/csrc/jit/api/compilation_unit.h>
2+
#include <torch/csrc/jit/mobile/type_parser.h>
23
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
34
#include <torch/csrc/jit/serialization/pickle.h>
45

@@ -214,7 +215,12 @@ ska::flat_hash_map<int64_t, DebugInfoTuple> CallStackDebugInfoUnpickler::
214215
size_t size,
215216
const ska::flat_hash_map<int64_t, SourceRange>& source_range_map,
216217
const std::shared_ptr<CompilationUnit>& cu) {
217-
auto ival = jit::unpickle(reinterpret_cast<const char*>(data.get()), size);
218+
auto ival = jit::unpickle(
219+
reinterpret_cast<const char*>(data.get()),
220+
size,
221+
nullptr,
222+
{},
223+
c10::parseType);
218224
ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptrs;
219225
auto ivalues = std::move(*std::move(ival).toTuple()).elements();
220226
for (auto& val : ivalues) {

torch/csrc/jit/serialization/import.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
176176
obj_loader,
177177
device_,
178178
*reader_.get(),
179+
nullptr,
179180
storage_context_);
180181
}
181182

torch/csrc/jit/serialization/import_read.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ IValue readArchiveAndTensors(
1212
c10::optional<ObjLoader> obj_loader,
1313
c10::optional<at::Device> device,
1414
caffe2::serialize::PyTorchStreamReader& stream_reader,
15+
c10::TypePtr (*type_parser)(const std::string&),
1516
std::shared_ptr<DeserializationStorageContext> storage_context) {
1617
std::string picklename = pickle_prefix + archive_name + ".pkl";
1718
at::DataPtr pickle_ptr;
@@ -47,6 +48,7 @@ IValue readArchiveAndTensors(
4748
std::move(read_record),
4849
device,
4950
false,
51+
type_parser,
5052
storage_context);
5153
unpickler.set_version(stream_reader.version());
5254
return unpickler.parse_ivalue();

torch/csrc/jit/serialization/import_read.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ TORCH_API IValue readArchiveAndTensors(
2020
c10::optional<ObjLoader> obj_loader,
2121
c10::optional<at::Device> device,
2222
caffe2::serialize::PyTorchStreamReader& stream_reader,
23+
c10::TypePtr (*type_parser)(const std::string&) =
24+
Unpickler::defaultTypeParser,
2325
std::shared_ptr<DeserializationStorageContext> storage_context = nullptr);
2426

2527
bool check_zip_file(

torch/csrc/jit/serialization/pickle.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,19 @@ IValue pickle_load(const std::vector<char>& data) {
120120
IValue unpickle(
121121
std::function<size_t(char*, size_t)> reader,
122122
TypeResolver type_resolver,
123-
c10::ArrayRef<at::Tensor> tensor_table) {
123+
c10::ArrayRef<at::Tensor> tensor_table,
124+
c10::TypePtr (*type_parser)(const std::string&)) {
124125
Unpickler unpickler(
125-
std::move(reader), std::move(type_resolver), tensor_table);
126+
std::move(reader), std::move(type_resolver), tensor_table, type_parser);
126127
return unpickler.parse_ivalue();
127128
}
128129

129130
IValue unpickle(
130131
const char* data,
131132
size_t size,
132133
TypeResolver type_resolver,
133-
c10::ArrayRef<at::Tensor> tensor_table) {
134+
c10::ArrayRef<at::Tensor> tensor_table,
135+
c10::TypePtr (*type_parser)(const std::string&)) {
134136
size_t bytes_read = 0;
135137
return unpickle(
136138
[&](char* buffer, size_t len) -> size_t {
@@ -145,7 +147,8 @@ IValue unpickle(
145147
return len;
146148
},
147149
std::move(type_resolver),
148-
tensor_table);
150+
tensor_table,
151+
type_parser);
149152
}
150153

151154
} // namespace jit

torch/csrc/jit/serialization/pickle.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ TORCH_API IValue pickle_load(const std::vector<char>& data);
6969
TORCH_API IValue unpickle(
7070
std::function<size_t(char*, size_t)> reader,
7171
TypeResolver type_resolver,
72-
c10::ArrayRef<at::Tensor> tensor_table);
72+
c10::ArrayRef<at::Tensor> tensor_table,
73+
c10::TypePtr (*type_parser)(const std::string&) =
74+
Unpickler::defaultTypeParser);
7375

7476
/// Decode a chunk of memory containing pickled data into its `torch::IValue`s.
7577
///
@@ -81,7 +83,9 @@ TORCH_API IValue unpickle(
8183
const char* data,
8284
size_t size,
8385
TypeResolver type_resolver = nullptr,
84-
c10::ArrayRef<at::Tensor> tensor_table = {});
86+
c10::ArrayRef<at::Tensor> tensor_table = {},
87+
c10::TypePtr (*type_parser)(const std::string&) =
88+
Unpickler::defaultTypeParser);
8589

8690
} // namespace jit
8791
} // namespace torch

torch/csrc/jit/serialization/source_range_serialization.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <torch/csrc/jit/serialization/source_range_serialization.h>
22
#include <torch/csrc/jit/serialization/source_range_serialization_impl.h>
33

4+
#include <torch/csrc/jit/mobile/type_parser.h>
45
#include <torch/csrc/jit/serialization/pickle.h>
56

67
namespace torch {
@@ -111,8 +112,13 @@ void ConcreteSourceRangeUnpickler::unpickle() {
111112
return;
112113
}
113114

114-
auto ivaluesTuple =
115-
jit::unpickle(reinterpret_cast<const char*>(data.get()), size).toTuple();
115+
auto ivaluesTuple = jit::unpickle(
116+
reinterpret_cast<const char*>(data.get()),
117+
size,
118+
nullptr,
119+
{},
120+
c10::parseType)
121+
.toTuple();
116122
const auto& ivalues = ivaluesTuple->elements();
117123

118124
unpickled_records = std::make_shared<SourceRangeRecords>();

torch/csrc/jit/serialization/unpickler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ void Unpickler::readGlobal(
565565
if (type_resolver_ == nullptr) {
566566
// If we haven't injected a custom way of retrieving types from
567567
// names, use a barebones type parser.
568-
type = c10::parseType(type_str);
568+
type = type_parser_(type_str);
569569
} else {
570570
type = type_resolver_(type_str).type_;
571571
}

0 commit comments

Comments
 (0)