Skip to content

Commit b8ba480

Browse files
qihqipytorchmergebot
authored andcommitted
Add an option to skip loading of debug traces (pytorch#91430)
Summary: Debug traces consumes lots of memory especially for small models. Test Plan: Unit test Reviewers: Subscribers: Tasks: Tags: Fixes #ISSUE_NUMBER Pull Request resolved: pytorch#91430 Approved by: https://github.com/davidberard98
1 parent 6ec3d65 commit b8ba480

File tree

6 files changed

+191
-39
lines changed

6 files changed

+191
-39
lines changed

caffe2/serialize/inline_container.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
namespace caffe2 {
2727
namespace serialize {
28+
constexpr c10::string_view kDebugPklSuffix(".debug_pkl");
2829

2930
size_t istream_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) {
3031
auto self = static_cast<PyTorchStreamReader*>(pOpaque);
@@ -222,6 +223,10 @@ size_t getPadding(
222223

223224
bool PyTorchStreamReader::hasRecord(const std::string& name) {
224225
std::lock_guard<std::mutex> guard(reader_lock_);
226+
227+
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
228+
return false;
229+
}
225230
std::string ss = archive_name_plus_slash_ + name;
226231
mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
227232
const mz_zip_error err = mz_zip_get_last_error(ar_.get());
@@ -255,8 +260,11 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
255260
": ",
256261
buf);
257262
}
258-
// NOLINTNEXTLINE(modernize-use-emplace)
259-
out.push_back(buf + archive_name_plus_slash_.size());
263+
if ((load_debug_symbol_) ||
264+
(!c10::string_view(buf + archive_name_plus_slash_.size()).ends_with(kDebugPklSuffix))) {
265+
// NOLINTNEXTLINE(modernize-use-emplace)
266+
out.push_back(buf + archive_name_plus_slash_.size());
267+
}
260268
}
261269
return out;
262270
}
@@ -276,6 +284,10 @@ size_t PyTorchStreamReader::getRecordID(const std::string& name) {
276284
// return dataptr, size
277285
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
278286
std::lock_guard<std::mutex> guard(reader_lock_);
287+
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
288+
at::DataPtr retval;
289+
return std::make_tuple(std::move(retval), 0);
290+
}
279291
size_t key = getRecordID(name);
280292
mz_zip_archive_file_stat stat;
281293
mz_zip_reader_file_stat(ar_.get(), key, &stat);

caffe2/serialize/inline_container.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ class TORCH_API PyTorchStreamReader final {
110110
return version_;
111111
}
112112

113+
void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
114+
load_debug_symbol_ = should_load_debug_symbol;
115+
}
116+
113117
private:
114118
void init();
115119
size_t read(uint64_t pos, char* buf, size_t n);
@@ -124,6 +128,7 @@ class TORCH_API PyTorchStreamReader final {
124128
std::shared_ptr<ReadAdapterInterface> in_;
125129
int64_t version_;
126130
std::mutex reader_lock_;
131+
bool load_debug_symbol_ = true;
127132
};
128133

129134
class TORCH_API PyTorchStreamWriter final {

caffe2/serialize/inline_container_test.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,53 @@ TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
120120
EXPECT_TRUE(reader.hasRecord("key1"));
121121
}
122122

123+
TEST(PytorchStreamWriterAndReader, SkipDebugRecords) {
124+
std::ostringstream oss;
125+
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
126+
oss.write(static_cast<const char*>(b), n);
127+
return oss ? n : 0;
128+
});
129+
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
130+
std::array<char, 127> data1;
131+
132+
for (auto i: c10::irange(data1.size())) {
133+
data1[i] = data1.size() - i;
134+
}
135+
writer.writeRecord("key1.debug_pkl", data1.data(), data1.size());
136+
137+
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
138+
std::array<char, 64> data2;
139+
for (auto i: c10::irange(data2.size())) {
140+
data2[i] = data2.size() - i;
141+
}
142+
writer.writeRecord("key2.debug_pkl", data2.data(), data2.size());
143+
144+
const std::unordered_set<std::string>& written_records =
145+
writer.getAllWrittenRecords();
146+
ASSERT_EQ(written_records.size(), 2);
147+
ASSERT_EQ(written_records.count("key1.debug_pkl"), 1);
148+
ASSERT_EQ(written_records.count("key2.debug_pkl"), 1);
149+
writer.writeEndOfFile();
150+
151+
std::string the_file = oss.str();
152+
std::ofstream foo("output2.zip");
153+
foo.write(the_file.c_str(), the_file.size());
154+
foo.close();
155+
156+
std::istringstream iss(the_file);
157+
158+
// read records through readers
159+
PyTorchStreamReader reader(&iss);
160+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
161+
162+
reader.setShouldLoadDebugSymbol(false);
163+
EXPECT_FALSE(reader.hasRecord("key1.debug_pkl"));
164+
at::DataPtr ptr;
165+
size_t size;
166+
std::tie(ptr, size) = reader.getRecord("key1.debug_pkl");
167+
EXPECT_EQ(size, 0);
168+
}
169+
123170
} // namespace
124171
} // namespace serialize
125172
} // namespace caffe2

test/cpp/jit/test_save_load.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#include <gtest/gtest.h>
22

33
#include <test/cpp/jit/test_utils.h>
4+
#include <iostream>
45
#include <sstream>
56

7+
#include <caffe2/serialize/inline_container.h>
68
#include <torch/csrc/jit/mobile/module.h>
79
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
810
#include <torch/csrc/jit/serialization/export.h>
@@ -272,5 +274,49 @@ TEST(SerializationTest, CalculateNecessaryArgsTest) {
272274
EXPECT_EQ(0, necessary.second);
273275
}
274276

277+
TEST(TestSaveLoad, LoadWithoutDebugInfo) { // NOLINT (use =delete in gtest)
278+
Module m("m");
279+
m.register_parameter("foo", torch::ones({}), false);
280+
m.define(
281+
R"(
282+
def test_func(self, x):
283+
b = 4
284+
return self.foo + x + b
285+
)");
286+
m.define(
287+
R"(
288+
def exception(self):
289+
assert False, "message"
290+
)");
291+
std::stringstream ss;
292+
m.save(ss);
293+
ss.seekg(0);
294+
caffe2::serialize::PyTorchStreamReader reader(&ss);
295+
reader.setShouldLoadDebugSymbol(true);
296+
EXPECT_TRUE(reader.hasRecord("code/__torch__.py.debug_pkl"));
297+
reader.setShouldLoadDebugSymbol(false);
298+
EXPECT_FALSE(reader.hasRecord("code/__torch__.py.debug_pkl"));
299+
ss.seekg(0);
300+
Module m2 = torch::jit::load(ss);
301+
std::string error_msg = R"(
302+
def exception(self):
303+
assert False, "message"
304+
~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE)";
305+
ASSERT_THROWS_WITH_MESSAGE(m2.run_method("exception"), error_msg);
306+
307+
ss.seekg(0);
308+
// NO DEBUG trace so error message points to torchscript generated
309+
// source instead of original python source.
310+
std::string error2 = R"(
311+
def exception(self: __torch__.m) -> NoneType:
312+
_0 = uninitialized(NoneType)
313+
ops.prim.RaiseException("AssertionError: message")
314+
~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
315+
return _0
316+
)";
317+
Module m3 = torch::jit::load(ss, c10::nullopt, false);
318+
ASSERT_THROWS_WITH_MESSAGE(m3.run_method("exception"), error2);
319+
}
320+
275321
} // namespace jit
276322
} // namespace torch

torch/csrc/jit/serialization/import.cpp

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,11 @@ Module ScriptModuleDeserializer::deserialize(
290290
Module import_ir_module(
291291
std::shared_ptr<CompilationUnit> cu,
292292
std::istream& in,
293-
c10::optional<at::Device> device) {
293+
c10::optional<at::Device> device,
294+
bool load_debug_files) {
294295
ExtraFilesMap extra_files;
295-
return import_ir_module(std::move(cu), in, device, extra_files);
296+
return import_ir_module(
297+
std::move(cu), in, device, extra_files, load_debug_files);
296298
}
297299

298300
static Module _load_jit_module_from_bytes(
@@ -344,12 +346,14 @@ Module import_ir_module(
344346
std::shared_ptr<CompilationUnit> cu,
345347
std::istream& in,
346348
c10::optional<at::Device> device,
347-
ExtraFilesMap& extra_files) {
349+
ExtraFilesMap& extra_files,
350+
bool load_debug_files) {
348351
in.seekg(0, in.beg);
349352
// NOTE: Zipformat can be large files. So using stream version directly
350353
// instead of reading the file all at once.
351354
if (getFileFormat(in) != FileFormat::FlatbufferFileFormat) {
352355
auto reader = torch::make_unique<PyTorchStreamReader>(&in);
356+
reader->setShouldLoadDebugSymbol(load_debug_files);
353357
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
354358
return deserializer.deserialize(device, extra_files);
355359
}
@@ -379,20 +383,24 @@ Module import_ir_module(
379383
Module import_ir_module(
380384
std::shared_ptr<CompilationUnit> cu,
381385
const std::string& filename,
382-
c10::optional<at::Device> device) {
386+
c10::optional<at::Device> device,
387+
bool load_debug_files) {
383388
ExtraFilesMap extra_files;
384-
return import_ir_module(std::move(cu), filename, device, extra_files);
389+
return import_ir_module(
390+
std::move(cu), filename, device, extra_files, load_debug_files);
385391
}
386392

387393
Module import_ir_module(
388394
std::shared_ptr<CompilationUnit> cu,
389395
const std::string& filename,
390396
c10::optional<at::Device> device,
391-
ExtraFilesMap& extra_files) {
397+
ExtraFilesMap& extra_files,
398+
bool load_debug_files) {
392399
// NOTE: Zipformat can be large files. So using stream version directly
393400
// instead of reading the file all at once.
394401
if (getFileFormat(filename) != FileFormat::FlatbufferFileFormat) {
395402
auto reader = torch::make_unique<PyTorchStreamReader>(filename);
403+
reader->setShouldLoadDebugSymbol(load_debug_files);
396404
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
397405
return deserializer.deserialize(device, extra_files);
398406
}
@@ -405,70 +413,90 @@ Module import_ir_module(
405413
Module import_ir_module(
406414
std::shared_ptr<CompilationUnit> cu,
407415
std::unique_ptr<ReadAdapterInterface> rai,
408-
c10::optional<at::Device> device) {
416+
c10::optional<at::Device> device,
417+
bool load_debug_files) {
409418
ExtraFilesMap extra_files;
410-
return import_ir_module(std::move(cu), std::move(rai), device, extra_files);
419+
return import_ir_module(
420+
std::move(cu), std::move(rai), device, extra_files, load_debug_files);
411421
}
412422

413423
Module import_ir_module(
414424
std::shared_ptr<CompilationUnit> cu,
415425
std::unique_ptr<ReadAdapterInterface> rai,
416426
c10::optional<at::Device> device,
417-
ExtraFilesMap& extra_files) {
427+
ExtraFilesMap& extra_files,
428+
bool load_debug_files) {
418429
std::shared_ptr<ReadAdapterInterface> rai_shared = std::move(rai);
419-
return import_ir_module(cu, rai_shared, device, extra_files);
430+
return import_ir_module(
431+
cu, rai_shared, device, extra_files, load_debug_files);
420432
}
421433

422434
Module import_ir_module(
423435
std::shared_ptr<CompilationUnit> cu,
424436
std::shared_ptr<ReadAdapterInterface> rai,
425437
c10::optional<at::Device> device,
426-
ExtraFilesMap& extra_files) {
438+
ExtraFilesMap& extra_files,
439+
bool load_debug_files) {
427440
auto reader = std::make_shared<PyTorchStreamReader>(std::move(rai));
441+
reader->setShouldLoadDebugSymbol(load_debug_files);
428442
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
429443
return deserializer.deserialize(device, extra_files);
430444
}
431445

432-
Module load(std::istream& in, c10::optional<at::Device> device) {
446+
Module load(
447+
std::istream& in,
448+
c10::optional<at::Device> device,
449+
bool load_debug_files) {
433450
auto cu = std::make_shared<CompilationUnit>();
434-
return import_ir_module(std::move(cu), in, device);
451+
return import_ir_module(std::move(cu), in, device, load_debug_files);
435452
}
436453

437454
Module load(
438455
std::istream& in,
439456
c10::optional<at::Device> device,
440-
ExtraFilesMap& extra_files) {
457+
ExtraFilesMap& extra_files,
458+
bool load_debug_files) {
441459
auto cu = std::make_shared<CompilationUnit>();
442-
return import_ir_module(std::move(cu), in, device, extra_files);
460+
return import_ir_module(
461+
std::move(cu), in, device, extra_files, load_debug_files);
443462
}
444463

445-
Module load(const std::string& filename, c10::optional<at::Device> device) {
464+
Module load(
465+
const std::string& filename,
466+
c10::optional<at::Device> device,
467+
bool load_debug_files) {
446468
auto cu = std::make_shared<CompilationUnit>();
447-
return import_ir_module(std::move(cu), filename, device);
469+
return import_ir_module(std::move(cu), filename, device, load_debug_files);
448470
}
449471

450472
Module load(
451473
const std::string& filename,
452474
c10::optional<at::Device> device,
453-
ExtraFilesMap& extra_files) {
475+
ExtraFilesMap& extra_files,
476+
bool load_debug_files) {
454477
auto cu = std::make_shared<CompilationUnit>();
455-
return import_ir_module(std::move(cu), filename, device, extra_files);
478+
return import_ir_module(
479+
std::move(cu), filename, device, extra_files, load_debug_files);
456480
}
457481

458482
Module load(
459483
std::shared_ptr<ReadAdapterInterface> rai,
460-
c10::optional<c10::Device> device) {
484+
c10::optional<c10::Device> device,
485+
bool load_debug_files) {
461486
auto cu = std::make_shared<CompilationUnit>();
462487
ExtraFilesMap extra_files;
463-
return import_ir_module(std::move(cu), std::move(rai), device, extra_files);
488+
return import_ir_module(
489+
std::move(cu), std::move(rai), device, extra_files, load_debug_files);
464490
}
465491

466492
Module load(
467493
std::shared_ptr<ReadAdapterInterface> rai,
468494
c10::optional<c10::Device> device,
469-
ExtraFilesMap& extra_files) {
495+
ExtraFilesMap& extra_files,
496+
bool load_debug_files) {
470497
auto cu = std::make_shared<CompilationUnit>();
471-
return import_ir_module(std::move(cu), std::move(rai), device, extra_files);
498+
return import_ir_module(
499+
std::move(cu), std::move(rai), device, extra_files, load_debug_files);
472500
}
473501

474502
Module _load_jit_module_from_bytes(

0 commit comments

Comments
 (0)