Commit fc2cf3d

pavithranrao authored and pytorchmergebot committed
Back out "Revert D34805092: Extend _save_for_mobile and _load_for_mobile to support flatbuffer format; Default format is pickle + Change buck targets to support only pickle and pickle + flatbuffer for migration" (pytorch#74594)
Summary:
Pull Request resolved: pytorch#74594

Extending `_save_for_mobile` and `_load_for_mobile` to support the flatbuffer format through an additional optional argument, which defaults to pickle.

Adding new binary targets with the suffix `_pickle_and_flatbuffer` to help migration.

The size test in D34909502 shows that the size has regressed by ~40K. After removing pickle and comparing lite_predictors, the size measure is ~120K, which is what we expect to achieve once pickle is deprecated and we move to flatbuffer.

**BEFORE:**

```lang=mermaid
graph TD;
torch_core-->torch_mobile_deserialize;
torch_mobile_core-->torch_mobile_deserialize;
jit_module_saving-->torch_core;
jit_module_saving-->torch_mobile_core;
torch_mobile_deserialize-->caffe2_serialize;
torch_mobile_deserialize-->torch_mobile_module;
caffe2_serialize-->miniz;
flatbuffer_loader-->mobile_bytecode;
flatbuffer_serializer-->mobile_bytecode;
mobile_bytecode-->flatbuffer_2.0;
flatbuffer_loader-->torch_mobile_module;
flatbuffer_serializer-->torch_mobile_module;
```

**AFTER:**

```lang=mermaid
graph TD;
torch_core-->torch_mobile_deserialize;
torch_mobile_core-->torch_mobile_deserialize;
jit_module_saving-->torch_core;
jit_module_saving-->torch_mobile_core;
torch_mobile_deserialize-->caffe2_serialize;
torch_mobile_deserialize-->torch_mobile_module;
caffe2_serialize-->miniz;
flatbuffer_loader-->mobile_bytecode;
flatbuffer_serializer-->mobile_bytecode;
mobile_bytecode-->flatbuffer_2.0;
torch_mobile_deserialize_pickle_and_flatbuffer-->|new| flatbuffer_loader;
torch_mobile_deserialize_pickle_and_flatbuffer-->|new| torch_mobile_deserialize;
torch_mobile_core_pickle_and_flatbuffer-->|new| torch_mobile_deserialize_pickle_and_flatbuffer;
torch_core_pickle_and_flatbuffer-->|new| torch_mobile_deserialize_pickle_and_flatbuffer;
jit_module_saving_pickle_and_flatbuffer-->|new| torch_core_pickle_and_flatbuffer;
jit_module_saving_pickle_and_flatbuffer-->|new| torch_mobile_core_pickle_and_flatbuffer;
flatbuffer_serializer-->torch_mobile_module;
jit_module_saving_pickle_and_flatbuffer-->|new|jit_module_saving;
jit_module_saving_pickle_and_flatbuffer-->|new|flatbuffer_serializer;
flatbuffer_loader-->torch_mobile_module;
```

Original commit changeset: 780dfb6fd6ba

Original Phabricator Diff: D34805092 (pytorch@284b2b7)

ghstack-source-id: 152044801

(Note: this ignores all push blocking failures!)

Test Plan:
CI

```
~/fbsource/fbcode] cd ~/fbsource/fbcode/ && buck test -c fbcode.caffe2_enable_flatbuffer=1 //caffe2/test/cpp/jit:jit -- FlatbufferTest.ExtraFiles
Parsing buck files: finished in 0.9 sec
Building: finished in 5.3 sec (100%) 12992/54304 jobs, 0/54304 updated
Total time: 6.2 sec
More details at https://www.internalfb.com/intern/buck/build/2b387fff-f813-4cfa-b53f-eb2378630d4e
BUILD SUCCEEDED
Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details.
Running with tpx session id: f93a84d6-e7ce-41a0-a97f-0ef3fa6d199d
Trace available for this run at /tmp/tpx-20220323-134108.766518-f93a84d6-e7ce-41a0-a97f-0ef3fa6d199d/trace.log
RemoteExecution session id: reSessionID-f93a84d6-e7ce-41a0-a97f-0ef3fa6d199d-tpx
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/4503599723101693
✓ ListingSuccess: caffe2/test/cpp/jit:jit : 486 tests discovered (19.122)
✓ Pass: caffe2/test/cpp/jit:jit - FlatbufferTest.ExtraFiles (0.187)
Summary
Pass: 1
ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/4503599723101693
```

Similar Build Deps Dags

```
[[email protected] /data/users/pavithran/fbsource] buck query 'allpaths(//xplat/caffe2:torch_mobile_all_ops_pickle_and_flatbuffer, //xplat/caffe2:torch_mobile_deserialize_pickle_and_flatbuffer)' --output-format dot-compact | pastry
P486770901: https://www.internalfb.com/intern/paste/P486770901/
[[email protected] /data/users/pavithran/fbsource] buck query 'allpaths(//xplat/caffe2:torch_mobile_all_ops, //xplat/caffe2:torch_mobile_deserialize)' --output-format dot-compact | pastry
P486771278: https://www.internalfb.com/intern/paste/P486771278/
```

pickle_and_flatbuffer: https://www.internalfb.com/intern/dgw/graph/?build_id=P486770901
pickle: https://www.internalfb.com/intern/dgw/graph/?build_id=P486771278

Reviewed By: iseeyuan

Differential Revision: D35067157

fbshipit-source-id: 9044259c17a2e0da79bd6aedb28efbdfd57e23e0

(cherry picked from commit f738069)
1 parent d64e763 commit fc2cf3d

12 files changed: 240 additions and 69 deletions

test/cpp/jit/test_flatbuffer.cpp

Lines changed: 15 additions & 1 deletion
@@ -153,16 +153,30 @@ TEST(FlatbufferTest, ExtraFiles) {
   extra_files["metadata.json"] = "abc";
   extra_files["mobile_info.json"] = "{\"key\": 23}";
 
+  std::unordered_map<std::string, std::string> loaded_extra_files;
+#if defined ENABLE_FLATBUFFER
+  std::stringstream ss;
+  module->_save_for_mobile(ss, extra_files, true, /*use_flatbuffer=*/true);
+
+  loaded_extra_files["metadata.json"] = "";
+  auto mobile_module = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);
+
+  ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
+  ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
+
+  // load it twice using the same stream
+  auto mobile_module2 = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);
+#else
   CompilationOptions options;
   mobile::Module bc = jitModuleToMobile(*module, options);
   auto buff = save_mobile_module_to_bytes(bc, extra_files);
 
-  std::unordered_map<std::string, std::string> loaded_extra_files;
   loaded_extra_files["metadata.json"] = "";
   auto* flatbuffer_module =
       mobile::serialization::GetMutableModule(buff.data());
 
   parseExtraFiles(flatbuffer_module, loaded_extra_files);
+#endif
 
   ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
   ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");

test/cpp/jit/test_lite_interpreter.cpp

Lines changed: 1 addition & 2 deletions
@@ -991,7 +991,6 @@ TEST(LiteInterpreterTest, ExtraFiles) {
   module->_save_for_mobile(oss, extra_files);
 
   std::istringstream iss(oss.str());
-  caffe2::serialize::IStreamAdapter adapter{&iss};
   std::unordered_map<std::string, std::string> loaded_extra_files;
   loaded_extra_files["metadata.json"] = "";
   torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
@@ -1006,7 +1005,7 @@ TEST(LiteInterpreterTest, ExtraFiles) {
       loaded_extra_files[file_name.substr(6)] = "";
     }
   }
-
+  iss.seekg(0, iss.beg);
   torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
   ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
   ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
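The `iss.seekg(0, iss.beg)` line added in this test rewinds the stream before the second `_load_for_mobile` call. A minimal standalone sketch of why that matters, using only the standard library: `consume_all` and `read_twice` are hypothetical helpers standing in for a loader that reads a stream to its end, not PyTorch APIs.

```cpp
#include <cassert>
#include <istream>
#include <iterator>
#include <sstream>
#include <string>

// Hypothetical stand-in for a loader: consumes everything left in the stream.
std::string consume_all(std::istream& in) {
  return std::string(
      std::istreambuf_iterator<char>(in), std::istreambuf_iterator<char>());
}

// Reads the same stream twice and returns what the second pass saw.
// When `rewind` is true, the stream is reset to its beginning in between.
std::string read_twice(std::istream& in, bool rewind) {
  consume_all(in);
  if (rewind) {
    in.clear();           // clear any EOF/fail flags a real reader may have set
    in.seekg(0, in.beg);  // same call the test adds before its second load
    assert(in.good());    // the rewind is expected to succeed on a seekable stream
  }
  return consume_all(in);
}
```

Without the rewind the second pass sees an empty stream; with it, the second pass sees the full payload again.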

torch/csrc/jit/api/module.h

Lines changed: 4 additions & 2 deletions
@@ -223,12 +223,14 @@ struct TORCH_API Module : public Object {
   void _save_for_mobile(
       std::ostream& out,
       const ExtraFilesMap& extra_files = ExtraFilesMap(),
-      bool save_mobile_debug_info = false) const;
+      bool save_mobile_debug_info = false,
+      bool use_flatbuffer = false) const;
 
   void _save_for_mobile(
       const std::string& filename,
       const ExtraFilesMap& extra_files = ExtraFilesMap(),
-      bool save_mobile_debug_info = false) const;
+      bool save_mobile_debug_info = false,
+      bool use_flatbuffer = false) const;
 
   Module copy() const;

torch/csrc/jit/api/module_save.cpp

Lines changed: 8 additions & 4 deletions
@@ -16,25 +16,29 @@ void Module::save(const std::string& filename, const ExtraFilesMap& extra_files)
 void Module::_save_for_mobile(
     std::ostream& out,
     const ExtraFilesMap& extra_files,
-    bool save_mobile_debug_info) const {
+    bool save_mobile_debug_info,
+    bool use_flatbuffer) const {
   ExportModule(
       *this,
       out,
       extra_files,
       true /* bytecode_format */,
-      save_mobile_debug_info);
+      save_mobile_debug_info,
+      use_flatbuffer);
 }
 
 void Module::_save_for_mobile(
     const std::string& filename,
     const ExtraFilesMap& extra_files,
-    bool save_mobile_debug_info) const {
+    bool save_mobile_debug_info,
+    bool use_flatbuffer) const {
   ExportModule(
       *this,
       filename,
       extra_files,
       true /* bytecode_format */,
-      save_mobile_debug_info);
+      save_mobile_debug_info,
+      use_flatbuffer);
 }
 
 } // namespace jit

torch/csrc/jit/mobile/flatbuffer_loader.cpp

Lines changed: 25 additions & 13 deletions
@@ -6,6 +6,7 @@
 #include <ATen/core/ivalue.h>
 #include <ATen/core/qualified_name.h>
 #include <c10/core/CPUAllocator.h>
+#include <c10/core/impl/alloc_cpu.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Optional.h>
 #include <c10/util/ScopeExit.h>
@@ -589,26 +590,34 @@ std::tuple<std::shared_ptr<char>, size_t> get_file_content(
   // make sure buffer size is multiple of alignment
   size_t buffer_size =
       (size / FLATBUFFERS_MAX_ALIGNMENT + 1) * FLATBUFFERS_MAX_ALIGNMENT;
-#if defined(__ANDROID__)
   std::shared_ptr<char> data(
-      static_cast<char*>(memalign(FLATBUFFERS_MAX_ALIGNMENT, buffer_size)),
-      free);
-#elif defined(_WIN32)
-  std::shared_ptr<char> data(
-      static_cast<char*>(
-          _aligned_malloc(buffer_size, FLATBUFFERS_MAX_ALIGNMENT)),
-      _aligned_free); // NOLINT
-#else
-  std::shared_ptr<char> data(
-      static_cast<char*>(aligned_alloc(FLATBUFFERS_MAX_ALIGNMENT, buffer_size)),
-      free); // NOLINT
-#endif
+      static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
   fread(data.get(), size, 1, f);
   fclose(f);
 #endif
   return std::make_tuple(data, size);
 }
 
+std::tuple<std::shared_ptr<char>, size_t> get_stream_content(std::istream& in) {
+  // get size of the stream and reset to orig
+  std::streampos orig_pos = in.tellg();
+  in.seekg(orig_pos, std::ios::end);
+  const long size = in.tellg();
+  in.seekg(orig_pos, in.beg);
+
+  // read stream
+  // NOLINT make sure buffer size is multiple of alignment
+  size_t buffer_size =
+      (size / FLATBUFFERS_MAX_ALIGNMENT + 1) * FLATBUFFERS_MAX_ALIGNMENT;
+  std::shared_ptr<char> data(
+      static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
+  in.read(data.get(), size);
+
+  // reset stream to original position
+  in.seekg(orig_pos, in.beg);
+  return std::make_tuple(data, size);
+}
+
 void FlatbufferLoader::extractJitSourceAndConstants(
     ExtraFilesMap* jit_sources,
     std::vector<IValue>* constants) {
@@ -626,6 +635,9 @@ mobile::Module parse_and_initialize_mobile_module(
     std::shared_ptr<char> data,
     size_t,
     c10::optional<at::Device>) {
+  TORCH_CHECK(
+      mobile::serialization::ModuleBufferHasIdentifier(data.get()),
+      "Format error");
   auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
   mobile::Module m = FlatbufferLoader().parseModule(flatbuffer_module);
   m.set_delete_memory(std::move(data));
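The new `get_stream_content` measures the stream, rounds the buffer size up to a multiple of the alignment, reads the bytes, and restores the stream position so the same stream can be loaded again. The following is a standalone sketch of that pattern, not the code from this commit: `kAlignment`, `padded_size`, and `read_remaining` are hypothetical names, 16 is an assumed stand-in for `FLATBUFFERS_MAX_ALIGNMENT`, and `std::aligned_alloc` (C++17) replaces `c10::alloc_cpu`. Unlike the diff, it measures only the bytes remaining after the current position.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <istream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>

// Assumed alignment for illustration; the real code uses FLATBUFFERS_MAX_ALIGNMENT.
constexpr std::size_t kAlignment = 16;

// Same rounding as in the diff: always yields a non-zero multiple of kAlignment.
std::size_t padded_size(std::size_t size) {
  return (size / kAlignment + 1) * kAlignment;
}

// Reads everything from the current position to the end of the stream into an
// aligned buffer, then puts the stream back where it started.
std::pair<std::shared_ptr<char>, std::size_t> read_remaining(std::istream& in) {
  const std::streampos orig = in.tellg();
  assert(orig != std::streampos(-1));  // the stream must be seekable

  in.seekg(0, std::ios::end);
  const std::size_t size = static_cast<std::size_t>(in.tellg() - orig);
  in.seekg(orig);

  // std::aligned_alloc requires the size to be a multiple of the alignment,
  // which is exactly what padded_size guarantees.
  std::shared_ptr<char> data(
      static_cast<char*>(std::aligned_alloc(kAlignment, padded_size(size))),
      std::free);
  in.read(data.get(), static_cast<std::streamsize>(size));

  in.clear();
  in.seekg(orig);  // restore the position so the caller can read again
  return {data, size};
}
```

Because the position is restored, calling `read_remaining` twice on the same stream returns the same bytes both times, which is the behavior the "load it twice using the same stream" test relies on.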

torch/csrc/jit/mobile/flatbuffer_loader.h

Lines changed: 3 additions & 0 deletions
@@ -59,6 +59,9 @@ TORCH_API void parseExtraFiles(
 TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_file_content(
     const char* filename);
 
+TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
+    std::istream& in);
+
 class TORCH_API FlatbufferLoader {
  public:
   FlatbufferLoader();

torch/csrc/jit/mobile/import.cpp

Lines changed: 70 additions & 10 deletions
@@ -10,6 +10,10 @@
 #include <caffe2/serialize/inline_container.h>
 #include <caffe2/serialize/versions.h>
 #include <torch/csrc/jit/api/compilation_unit.h>
+#include <torch/csrc/jit/mobile/file_format.h>
+#if defined(ENABLE_FLATBUFFER)
+#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
+#endif
 #include <torch/csrc/jit/mobile/interpreter.h>
 #include <torch/csrc/jit/mobile/observer.h>
 #include <torch/csrc/jit/mobile/type_parser.h>
@@ -536,29 +540,85 @@ mobile::Module _load_for_mobile(
     std::istream& in,
     c10::optional<at::Device> device,
     ExtraFilesMap& extra_files) {
-  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
-  auto module = _load_for_mobile(std::move(rai), device, extra_files);
-  return module;
+  auto format = getFileFormat(in);
+  switch (format) {
+    case FileFormat::ZipFileFormat: {
+      std::unique_ptr<IStreamAdapter> rai =
+          std::make_unique<IStreamAdapter>(&in);
+      auto module = _load_for_mobile(std::move(rai), device, extra_files);
+      return module;
+    }
+#if defined(ENABLE_FLATBUFFER)
+    case FileFormat::FlatbufferFileFormat: {
+      std::shared_ptr<char> data;
+      size_t size = 0;
+      std::tie(data, size) = get_stream_content(in);
+      auto* flatbuffer_module =
+          mobile::serialization::GetMutableModule(data.get());
+      mobile::Module m = initialize_mobile_module(flatbuffer_module);
+      parseExtraFiles(flatbuffer_module, extra_files);
+      return m;
+    }
+#else
+    case FileFormat::FlatbufferFileFormat: {
+      TORCH_CHECK(
+          false,
+          "Flatbuffer input file but the build hasn't enabled flatbuffer");
+    }
+#endif
+    default: {
+      TORCH_CHECK(false, "Format error");
+    }
+  }
 }
 
 mobile::Module _load_for_mobile(
     const std::string& filename,
     c10::optional<at::Device> device,
     ExtraFilesMap& extra_files) {
-  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
-  auto module = _load_for_mobile(std::move(rai), device, extra_files);
-  return module;
+  return _load_for_mobile(
+      filename,
+      device,
+      extra_files,
+      /*module_load_options=*/_default_mobile_module_load_options);
 }
 
 mobile::Module _load_for_mobile(
     const std::string& filename,
     c10::optional<at::Device> device,
     ExtraFilesMap& extra_files,
     uint64_t module_load_options) {
-  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
-  auto module = _load_for_mobile_impl(
-      std::move(rai), device, extra_files, module_load_options);
-  return module;
+  auto format = getFileFormat(filename);
+  switch (format) {
+    case FileFormat::ZipFileFormat: {
+      std::unique_ptr<FileAdapter> rai =
+          std::make_unique<FileAdapter>(filename);
+      auto module = _load_for_mobile_impl(
+          std::move(rai), device, extra_files, module_load_options);
+      return module;
+    }
+#if defined(ENABLE_FLATBUFFER)
+    case FileFormat::FlatbufferFileFormat: {
+      std::shared_ptr<char> data;
+      size_t size = 0;
+      std::tie(data, size) = get_file_content(filename.c_str());
+      auto* flatbuffer_module =
+          mobile::serialization::GetMutableModule(data.get());
+      mobile::Module m = initialize_mobile_module(flatbuffer_module);
+      parseExtraFiles(flatbuffer_module, extra_files);
+      return m;
+    }
+#else
+    case FileFormat::FlatbufferFileFormat: {
+      TORCH_CHECK(
+          false,
+          "Flatbuffer input file but the build hasn't enabled flatbuffer");
+    }
+#endif
+    default: {
+      TORCH_CHECK(false, "Format error");
+    }
+  }
 }
 
 mobile::Module _load_for_mobile(
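Both loaders now call `getFileFormat` and dispatch on the result. `file_format.h` is not part of this diff, so the sketch below only illustrates the general technique of sniffing a format from the first bytes of a stream. The names `SketchFormat`, `sniff_format`, and `pick_loader` are hypothetical, and the header layout (zip magic `PK\x03\x04` at offset 0, a four-byte flatbuffer file identifier `PTMF` at offset 4) is an assumption made for illustration.

```cpp
#include <array>
#include <cassert>
#include <cstring>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>

enum class SketchFormat { Unknown, Zip, Flatbuffer };

// Peeks at the first 8 bytes and restores the stream position afterwards.
// The magic values checked here are assumptions, not taken from file_format.h.
SketchFormat sniff_format(std::istream& in) {
  std::array<char, 8> header{};
  const std::streampos orig = in.tellg();
  assert(orig != std::streampos(-1));  // the stream must be seekable

  in.read(header.data(), static_cast<std::streamsize>(header.size()));
  const std::streamsize got = in.gcount();
  in.clear();      // a short read sets eofbit/failbit; clear before seeking
  in.seekg(orig);  // leave the stream untouched for the real loader

  if (got >= 4 && std::memcmp(header.data(), "PK\x03\x04", 4) == 0) {
    return SketchFormat::Zip;
  }
  if (got >= 8 && std::memcmp(header.data() + 4, "PTMF", 4) == 0) {
    return SketchFormat::Flatbuffer;
  }
  return SketchFormat::Unknown;
}

// Mirrors the shape of the switch in the diff: zip goes to the pickle path,
// flatbuffer is accepted only when enabled, anything else is an error.
std::string pick_loader(std::istream& in, bool flatbuffer_enabled) {
  switch (sniff_format(in)) {
    case SketchFormat::Zip:
      return "pickle";
    case SketchFormat::Flatbuffer:
      if (!flatbuffer_enabled) {
        throw std::runtime_error(
            "Flatbuffer input file but the build hasn't enabled flatbuffer");
      }
      return "flatbuffer";
    default:
      throw std::runtime_error("Format error");
  }
}
```

Restoring the position after the peek matters because the selected loader expects to read the stream from where the caller left it.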

torch/csrc/jit/mobile/import_data.cpp

Lines changed: 0 additions & 1 deletion
@@ -9,7 +9,6 @@
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/serialization/unpickler.h>
 #include <torch/custom_class.h>
-
 #include <exception>
 #include <fstream>
 #include <string>

torch/csrc/jit/python/script_init.cpp

Lines changed: 19 additions & 6 deletions
@@ -1096,23 +1096,32 @@ void initJitScriptBindings(PyObject* module) {
           [](Module& m,
              const std::string& filename,
              const ExtraFilesMap& _extra_files = ExtraFilesMap(),
-             bool _save_mobile_debug_info = false) {
-            m._save_for_mobile(filename, _extra_files, _save_mobile_debug_info);
+             bool _save_mobile_debug_info = false,
+             bool _use_flatbuffer = false) {
+            m._save_for_mobile(
+                filename,
+                _extra_files,
+                _save_mobile_debug_info,
+                _use_flatbuffer);
           },
           py::arg("filename"),
           py::arg("_extra_files") = ExtraFilesMap(),
-          py::arg("_save_mobile_debug_info") = false)
+          py::arg("_save_mobile_debug_info") = false,
+          py::arg("_use_flatbuffer") = false)
       .def(
           "_save_to_buffer_for_mobile",
           [](Module& m,
              const ExtraFilesMap& _extra_files = ExtraFilesMap(),
-             bool _save_mobile_debug_info = false) {
+             bool _save_mobile_debug_info = false,
+             bool _use_flatbuffer = false) {
             std::ostringstream buf;
-            m._save_for_mobile(buf, _extra_files, _save_mobile_debug_info);
+            m._save_for_mobile(
+                buf, _extra_files, _save_mobile_debug_info, _use_flatbuffer);
             return py::bytes(buf.str());
           },
           py::arg("_extra_files") = ExtraFilesMap(),
-          py::arg("_save_mobile_debug_info") = false)
+          py::arg("_save_mobile_debug_info") = false,
+          py::arg("_use_flatbuffer") = false)
       .def("_set_optimized", &Module::set_optimized)
       .def(
           "dump",
@@ -1891,6 +1900,10 @@ void initJitScriptBindings(PyObject* module) {
         std::istringstream in(buffer);
         return _get_mobile_model_contained_types(in);
       });
+  m.def("_nn_module_to_mobile", [](const Module& module) {
+    CompilationOptions options;
+    return jitModuleToMobile(module, options);
+  });
   py::class_<OperatorInfo>(m, "OperatorInfo")
       .def_readonly("num_schema_args", &OperatorInfo::num_schema_args);
   m.def("_get_model_ops_and_info", [](const std::string& filename) {

torch/csrc/jit/serialization/export.h

Lines changed: 6 additions & 3 deletions
@@ -158,21 +158,24 @@ TORCH_API void ExportModule(
     std::ostream& out,
     const ExtraFilesMap& metadata = ExtraFilesMap(),
     bool bytecode_format = false,
-    bool save_mobile_debug_info = false);
+    bool save_mobile_debug_info = false,
+    bool use_flatbuffer = false);
 
 TORCH_API void ExportModule(
     const Module& module,
     const std::string& filename,
     const ExtraFilesMap& metadata = ExtraFilesMap(),
     bool bytecode_format = false,
-    bool save_mobile_debug_info = false);
+    bool save_mobile_debug_info = false,
+    bool use_flatbuffer = false);
 
 TORCH_API void ExportModule(
     const Module& module,
     const std::function<size_t(const void*, size_t)>& writer_func,
     const ExtraFilesMap& metadata = ExtraFilesMap(),
     bool bytecode_format = false,
-    bool save_mobile_debug_info = false);
+    bool save_mobile_debug_info = false,
+    bool use_flatbuffer = false);
 
 // Write the bytes of a pickle archive and the tensors referenced inside that
 // archive
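Every signature touched by this commit gains `use_flatbuffer` as a trailing parameter defaulted to `false`, so existing call sites keep compiling and keep producing pickle output. A toy sketch of that pattern follows. `export_sketch` is a hypothetical function that writes a one-word tag rather than a model and is not related to the real `ExportModule`.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>

// Hypothetical exporter: the trailing defaulted flag selects the format, and
// callers that predate the flag get the old (pickle) behavior unchanged.
void export_sketch(
    std::ostream& out,
    bool save_debug_info = false,
    bool use_flatbuffer = false) {
  assert(out.good());  // the caller is expected to pass a usable stream
  out << (use_flatbuffer ? "flatbuffer" : "pickle");
  if (save_debug_info) {
    out << "+debug";
  }
}
```

A pre-existing call such as `export_sketch(out)` still selects pickle, while a new call opts in with `export_sketch(out, true, /*use_flatbuffer=*/true)`, using the same inline argument comment style as the updated test.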
