Commit 577f87b

qihqi authored and pytorchmergebot committed

Make flatbuffer loads faster if loading as mobile module. (pytorch#78998)

BC/FC check: verified that a flatbuffer file created in this commit can be loaded at HEAD, and a file created at HEAD can be loaded in this commit.

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#78998
Approved by: https://github.com/zhxchen17
1 parent 81cd276 commit 577f87b
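
The gist of the change: the serializer now records how many entries in the module's ivalues table belong to the mobile module proper (mobile_ivalue_size), and the loader only parses that prefix eagerly. The remaining entries, which exist only to carry the attached jit::Module sources and constants, are parsed later in extractJitSourceAndConstants if they are ever needed. Older files store 0 for the new field, in which case the loader falls back to parsing everything, which is what the BC/FC check above exercises. The snippet below is a minimal standalone sketch of that fallback rule, using hypothetical stand-in types rather than the real flatbuffer tables:

// Sketch only: FakeModule and eager_parse_count are illustrative stand-ins.
#include <cstdint>
#include <vector>

struct FakeModule {
  std::vector<int> ivalues;        // stand-in for the serialized ivalues table
  uint32_t mobile_ivalue_size = 0; // 0 in files written before this commit
};

// How many ivalues the mobile loader parses at load time.
uint32_t eager_parse_count(const FakeModule& m) {
  uint32_t n = m.mobile_ivalue_size;
  if (n == 0) {
    // Old file (field absent/defaulted): parse every ivalue, as before.
    n = static_cast<uint32_t>(m.ivalues.size());
  }
  return n;
}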

File tree

5 files changed: +67 -13 lines changed


torch/csrc/jit/mobile/flatbuffer_loader.cpp

Lines changed: 34 additions & 8 deletions
@@ -177,6 +177,18 @@ void parseExtraFiles(
   parseExtraFilesFromVector(extra_files_offsets, &extra_files);
 }
 
+void FlatbufferLoader::parseAndPopulate(
+    uint32_t i,
+    const mobile::serialization::IValue* ivalue) {
+  if (const auto* func = ivalue->val_as_Function()) {
+    auto func_ptr = parseFunction(func);
+    all_functions_[i] = func_ptr.get();
+    mcu_->register_function(std::move(func_ptr));
+  } else {
+    all_ivalues_[i] = parseIValue(ivalue);
+  }
+}
+
 mobile::Module FlatbufferLoader::parseModule(
     mobile::serialization::Module* module) {
   module_ = module;
@@ -192,15 +204,14 @@ mobile::Module FlatbufferLoader::parseModule(
   storages_.resize(module->storage_data_size());
   storage_loaded_.resize(module->storage_data_size(), false);
 
-  for (uint32_t i = 0; i < ivalues->size(); i++) {
+  mobile_ivalue_size_ = module_->mobile_ivalue_size();
+  if (mobile_ivalue_size_ == 0) {
+    mobile_ivalue_size_ = ivalues->size();
+  }
+
+  for (uint32_t i = 0; i < mobile_ivalue_size_; i++) {
     const auto* ival = ivalues->Get(i);
-    if (const auto* func = ival->val_as_Function()) {
-      auto func_ptr = parseFunction(func);
-      all_functions_[i] = func_ptr.get();
-      mcu_->register_function(std::move(func_ptr));
-    } else {
-      all_ivalues_[i] = parseIValue(ival);
-    }
+    parseAndPopulate(i, ival);
   }
   IValue& module_ivalue = getIValue(module->state_obj());
 
@@ -660,6 +671,21 @@ void FlatbufferLoader::extractJitSourceAndConstants(
   AT_ASSERT(
       module_parsed_,
       "Need to first parse a flatbuffer file before extracing jit_sources");
+
+  const auto* ivalues = module_->ivalues();
+  for (uint32_t i = mobile_ivalue_size_; i < ivalues->size(); i++) {
+    const auto* ival = ivalues->Get(i);
+    parseAndPopulate(i, ival);
+  }
+  // register functions
+  for (const auto& f : all_functions_) {
+    if (f.first >= mobile_ivalue_size_) {
+      uint32_t class_index =
+          ivalues->Get(f.first)->val_as_Function()->class_type();
+      ClassTypePtr class_type = all_types_[class_index];
+      class_type->addMethod(f.second);
+    }
+  }
   const auto* jit_constants = module_->jit_constants();
   for (auto i = 0; i < jit_constants->size(); ++i) {
     constants->emplace_back(getIValue(jit_constants->Get(i)));
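
Taken together, parseModule now stops at mobile_ivalue_size_ and extractJitSourceAndConstants picks up the deferred tail, registering any functions parsed there on their owning class types. Below is a rough sketch of that two-phase control flow with simplified hypothetical stand-ins (parse_jobs is not a real member; the real loader walks flatbuffer tables):

// Illustrative only: mirrors the eager/deferred split in the diff above.
#include <cstdint>
#include <functional>
#include <vector>

struct SketchLoader {
  uint32_t mobile_ivalue_size = 0;                // boundary read from the file
  std::vector<std::function<void()>> parse_jobs;  // one per serialized ivalue

  void parse_module() {
    // Eager phase: only the mobile module's ivalues are parsed at load time.
    for (uint32_t i = 0; i < mobile_ivalue_size; ++i) {
      parse_jobs[i]();
    }
  }

  void extract_jit_source_and_constants() {
    // Deferred phase: jit-only ivalues are parsed on demand; the real code then
    // attaches any functions parsed here to their class types (addMethod).
    for (uint32_t i = mobile_ivalue_size; i < parse_jobs.size(); ++i) {
      parse_jobs[i]();
    }
  }
};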

torch/csrc/jit/mobile/flatbuffer_loader.h

Lines changed: 5 additions & 0 deletions
@@ -141,6 +141,9 @@ class TORCH_API FlatbufferLoader {
   IValue parseIValue(const mobile::serialization::IValue* ivalue);
   std::unique_ptr<mobile::Function> parseFunction(
       const mobile::serialization::Function* method);
+  void parseAndPopulate(
+      uint32_t i,
+      const mobile::serialization::IValue* ivalue);
 
   std::unordered_map<uint32_t, mobile::Function*> all_functions_;
   std::vector<ClassTypePtr> all_types_;
@@ -158,6 +161,8 @@ class TORCH_API FlatbufferLoader {
   bool module_parsed_ = false;
   bool should_copy_tensor_memory_ = false;
   bool should_load_operators_ = true;
+  // 0 -> mobile_ivalue_size_ elements are from the mobile module.
+  uint32_t mobile_ivalue_size_ = 0;
 };
 
 } // namespace jit

torch/csrc/jit/serialization/flatbuffer_serializer.cpp

Lines changed: 3 additions & 1 deletion
@@ -358,6 +358,7 @@ flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule(
   auto jit_source_offset = storeExtraFilesAndGetOffset(fbb, jit_sources);
   std::vector<uint32_t> jit_constants_indexes;
   jit_constants_indexes.reserve(jit_constants.size());
+  const uint32_t mobile_ivalue_size = ivalue_offsets_.size();
   for (const auto& ival : jit_constants) {
     jit_constants_indexes.emplace_back(storeIValueAndGetIndex(fbb, ival));
   }
@@ -408,7 +409,8 @@ flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule(
       fbb.CreateVector(obj_types_offset_),
       jit_source_offset,
       fbb.CreateVector(jit_constants_indexes),
-      operator_version);
+      operator_version,
+      mobile_ivalue_size);
   FinishModuleBuffer(fbb, mod);
   return fbb.Release();
 }
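
The serializer side is a one-liner because of ordering: by the time the new line runs, every entry in ivalue_offsets_ was written while serializing the mobile module, and the jit constants are appended just below, so the current size is exactly the boundary the loader needs. A tiny sketch of that invariant (hypothetical names, not the real serializer state):

// Illustrative only: capture the boundary before appending jit-only entries.
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> ivalue_offsets = {"method", "object", "attribute"};

  // Boundary recorded before any jit-only ivalues are stored.
  const uint32_t mobile_ivalue_size =
      static_cast<uint32_t>(ivalue_offsets.size());

  ivalue_offsets.push_back("jit_constant_0");  // appended afterwards
  ivalue_offsets.push_back("jit_constant_1");

  // Indices [0, mobile_ivalue_size) are mobile; the rest are jit-only.
  assert(mobile_ivalue_size == 3 && ivalue_offsets.size() == 5);
  return 0;
}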

torch/csrc/jit/serialization/mobile_bytecode.fbs

Lines changed: 6 additions & 0 deletions
@@ -211,6 +211,12 @@ table Module {
   // To read more:
   // https://github.com/pytorch/rfcs/blob/master/RFC-0017-PyTorch-Operator-Versioning.md
   operator_version:uint;
+
+  // Number of ivalues that come from the mobile module.
+  // The ivalues array above can also contain ivalues that come from the
+  // jit::Module whose source is attached to the flatbuffer file, so this
+  // value can be smaller than ivalues.size().
+  mobile_ivalue_size:uint;
 }
 
 root_type Module;
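
To put numbers on the comment above (purely illustrative, not taken from the commit): if a serialized Module holds 120 entries in ivalues and the first 100 were produced for the mobile module, the serializer writes mobile_ivalue_size = 100; the loader then parses entries 0-99 at load time and touches 100-119 only when jit sources and constants are extracted. A file written before this field existed reads back as 0, and the loader falls back to parsing all 120 entries. Because the field is a scalar appended at the end of the table, old readers simply ignore it, which is the basis of the BC/FC guarantee mentioned in the commit message.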

torch/csrc/jit/serialization/mobile_bytecode_generated.h

Lines changed: 19 additions & 4 deletions
@@ -2228,7 +2228,8 @@ struct Module FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     VT_OBJECT_TYPES = 18,
     VT_JIT_SOURCES = 20,
     VT_JIT_CONSTANTS = 22,
-    VT_OPERATOR_VERSION = 24
+    VT_OPERATOR_VERSION = 24,
+    VT_MOBILE_IVALUE_SIZE = 26
   };
   uint32_t bytecode_version() const {
     return GetField<uint32_t>(VT_BYTECODE_VERSION, 0);
@@ -2296,6 +2297,12 @@ struct Module FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   bool mutate_operator_version(uint32_t _operator_version = 0) {
     return SetField<uint32_t>(VT_OPERATOR_VERSION, _operator_version, 0);
   }
+  uint32_t mobile_ivalue_size() const {
+    return GetField<uint32_t>(VT_MOBILE_IVALUE_SIZE, 0);
+  }
+  bool mutate_mobile_ivalue_size(uint32_t _mobile_ivalue_size = 0) {
+    return SetField<uint32_t>(VT_MOBILE_IVALUE_SIZE, _mobile_ivalue_size, 0);
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyField<uint32_t>(verifier, VT_BYTECODE_VERSION) &&
@@ -2321,6 +2328,7 @@ struct Module FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
            VerifyOffset(verifier, VT_JIT_CONSTANTS) &&
            verifier.VerifyVector(jit_constants()) &&
            VerifyField<uint32_t>(verifier, VT_OPERATOR_VERSION) &&
+           VerifyField<uint32_t>(verifier, VT_MOBILE_IVALUE_SIZE) &&
           verifier.EndTable();
   }
 };
@@ -2362,6 +2370,9 @@ struct ModuleBuilder {
   void add_operator_version(uint32_t operator_version) {
     fbb_.AddElement<uint32_t>(Module::VT_OPERATOR_VERSION, operator_version, 0);
   }
+  void add_mobile_ivalue_size(uint32_t mobile_ivalue_size) {
+    fbb_.AddElement<uint32_t>(Module::VT_MOBILE_IVALUE_SIZE, mobile_ivalue_size, 0);
+  }
   explicit ModuleBuilder(flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -2385,8 +2396,10 @@ inline flatbuffers::Offset<Module> CreateModule(
     flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<torch::jit::mobile::serialization::ObjectType>>> object_types = 0,
     flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<torch::jit::mobile::serialization::ExtraFile>>> jit_sources = 0,
     flatbuffers::Offset<flatbuffers::Vector<uint32_t>> jit_constants = 0,
-    uint32_t operator_version = 0) {
+    uint32_t operator_version = 0,
+    uint32_t mobile_ivalue_size = 0) {
   ModuleBuilder builder_(_fbb);
+  builder_.add_mobile_ivalue_size(mobile_ivalue_size);
   builder_.add_operator_version(operator_version);
   builder_.add_jit_constants(jit_constants);
   builder_.add_jit_sources(jit_sources);
@@ -2413,7 +2426,8 @@ inline flatbuffers::Offset<Module> CreateModuleDirect(
     const std::vector<flatbuffers::Offset<torch::jit::mobile::serialization::ObjectType>> *object_types = nullptr,
     const std::vector<flatbuffers::Offset<torch::jit::mobile::serialization::ExtraFile>> *jit_sources = nullptr,
     const std::vector<uint32_t> *jit_constants = nullptr,
-    uint32_t operator_version = 0) {
+    uint32_t operator_version = 0,
+    uint32_t mobile_ivalue_size = 0) {
   auto extra_files__ = extra_files ? _fbb.CreateVector<flatbuffers::Offset<torch::jit::mobile::serialization::ExtraFile>>(*extra_files) : 0;
   auto methods__ = methods ? _fbb.CreateVector<uint32_t>(*methods) : 0;
   auto ivalues__ = ivalues ? _fbb.CreateVector<flatbuffers::Offset<torch::jit::mobile::serialization::IValue>>(*ivalues) : 0;
@@ -2433,7 +2447,8 @@ inline flatbuffers::Offset<Module> CreateModuleDirect(
       object_types__,
       jit_sources__,
       jit_constants__,
-      operator_version);
+      operator_version,
+      mobile_ivalue_size);
 }
 
 inline bool VerifyIValueUnion(flatbuffers::Verifier &verifier, const void *obj, IValueUnion type) {
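
mobile_bytecode_generated.h is the FlatBuffers-generated C++ counterpart of the schema change above rather than hand-written code. If it needs regenerating after editing mobile_bytecode.fbs, the usual route is the FlatBuffers compiler; the exact invocation used in the PyTorch tree may differ, but it is along the lines of:

flatc --cpp --gen-mutable torch/csrc/jit/serialization/mobile_bytecode.fbs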
