Skip to content

Commit 813adf1

Browse files
kimishpatelfacebook-github-bot
authored andcommitted
[Pytorch Delegated Backend] Save operator name and function name in (pytorch#57441)
Summary: Pull Request resolved: pytorch#57441 debug info Previous diffs did not save operator name in debug info. For delegated backends that only idenfity op for profiling with debug handle, operator name should be stores as well. Furthermore to complete debug informaton also serialize function name. Test Plan: Existing lite interpreter and backend tests Existing lite interpreter and backend tests Imported from OSS Differential Revision: D28144581 D28144581 Reviewed By: raziel Pulled By: kimishpatel fbshipit-source-id: 415210f147530a53b444b07f1d6ee699a3570d99
1 parent a7a5992 commit 813adf1

13 files changed

+95
-69
lines changed

test/cpp/jit/test_backend.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ TEST(BackendTestDebugInfo, TestCompiler) {
190190
lm._save_for_mobile(ss, ExtraFilesMap(), true);
191191
auto mlm = _load_for_mobile(ss);
192192
std::string error_pattern = R"(
193-
Module hierarchy:top(backend_with_compiler_demoLoweredModule)
193+
Module hierarchy:top(backend_with_compiler_demoLoweredModule).aten::add
194194
Traceback of TorchScript (most recent call last):
195195
File "<string>", line 5, in FunctionName_UNKNOWN
196196
typed_inputs: List[Any] = [x, h, ]
@@ -244,7 +244,7 @@ TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) {
244244
lm._save_for_mobile(ss, ExtraFilesMap(), true);
245245
auto mlm = _load_for_mobile(ss);
246246
std::string error_pattern = R"(
247-
Module hierarchy:top(backend_with_compiler_demoLoweredModule).A0(A)
247+
Module hierarchy:top(backend_with_compiler_demoLoweredModule).A0(A).aten::add
248248
Traceback of TorchScript (most recent call last):
249249
File "<string>", line 5, in FunctionName_UNKNOWN
250250
typed_inputs: List[Any] = [x, y, ]
@@ -337,7 +337,7 @@ TEST(
337337
*
338338
*/
339339
std::string error_pattern = R"(
340-
Module hierarchy:top(backend_with_compiler_demoLoweredModule).B0(B).A0(A)
340+
Module hierarchy:top(backend_with_compiler_demoLoweredModule).B0(B).A0(A).aten::add
341341
Traceback of TorchScript (most recent call last):
342342
File "<string>", line 5, in FunctionName_UNKNOWN
343343
typed_inputs: List[Any] = [x, y, ]
@@ -424,7 +424,7 @@ TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithLoweredSubModule) {
424424
c._save_for_mobile(ss, ExtraFilesMap(), true);
425425
auto c_loaded = _load_for_mobile(ss);
426426
std::string error_pattern = R"(
427-
Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule)
427+
Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).aten::add
428428
Traceback of TorchScript (most recent call last):
429429
File "<string>", line 3, in FunctionName_UNKNOWN
430430
@@ -545,7 +545,7 @@ TEST(
545545
*
546546
* */
547547
std::string error_pattern = R"(
548-
Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA)
548+
Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA).aten::add
549549
Traceback of TorchScript (most recent call last):
550550
File "<string>", line 3, in FunctionName_UNKNOWN
551551

test/cpp/jit/test_cs_debug_info_serialization.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,23 @@ namespace jit {
2525

2626
namespace {
2727
bool validate_debug_info(
28-
const DebugInfoPair& pre_serialize,
29-
const DebugInfoPair& post_serialize) {
30-
auto sr1 = pre_serialize.first;
31-
auto sr2 = post_serialize.first;
28+
const DebugInfoTuple& pre_serialize,
29+
const DebugInfoTuple& post_serialize) {
30+
auto sr1 = std::get<kDebugInfoTupleSourceRangeIndex>(pre_serialize);
31+
auto sr2 = std::get<kDebugInfoTupleSourceRangeIndex>(post_serialize);
3232
if (sr1 != sr2) {
3333
return false;
3434
}
35-
if (!pre_serialize.second.defined()) {
36-
return !post_serialize.second.defined();
35+
auto csptr1 = std::get<kDebugInfoTupleInlinedCSIndex>(pre_serialize);
36+
auto csptr2 = std::get<kDebugInfoTupleInlinedCSIndex>(post_serialize);
37+
if (!csptr1.defined()) {
38+
return !csptr2.defined();
3739
}
38-
if (!post_serialize.second.defined()) {
40+
if (!csptr2.defined()) {
3941
return false;
4042
}
41-
auto vec1 = pre_serialize.second->vec();
42-
auto vec2 = post_serialize.second->vec();
43+
auto vec1 = csptr1->vec();
44+
auto vec2 = csptr2->vec();
4345
if (vec1.size() != vec2.size()) {
4446
return false;
4547
}

test/cpp/jit/test_lite_interpreter.cpp

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,7 @@ TEST(LiteInterpreterTest, ModuleInfoBasic) {
496496
}
497497
}
498498

499-
std::unordered_set<std::string> expected_result({"top(M)"});
500-
AT_ASSERT(module_debug_info_set == expected_result);
499+
AT_ASSERT(module_debug_info_set.count("top(M).aten::mul"));
501500
}
502501

503502
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -559,8 +558,9 @@ TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) {
559558
}
560559
}
561560

562-
std::set<std::string> expected_result({"top(B)", "top(B).A0(A)"});
563-
AT_ASSERT(module_debug_info_set == expected_result);
561+
AT_ASSERT(module_debug_info_set.count("top(B).aten::add"));
562+
AT_ASSERT(module_debug_info_set.count("top(B).A0(A).aten::add"));
563+
AT_ASSERT(module_debug_info_set.count("top(B).A0(A).aten::mul"));
564564
}
565565

566566
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -594,7 +594,6 @@ TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) {
594594
std::string module_info = bc.get_forward_method_debug_info(pc);
595595
if (!module_info.empty() &&
596596
(module_info.find("debug_handle") == std::string::npos)) {
597-
std::cout << "Module info:" << module_info << std::endl;
598597
module_debug_info_set.insert(module_info);
599598
}
600599
++pc;
@@ -603,9 +602,9 @@ TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) {
603602
}
604603
}
605604

606-
std::set<std::string> expected_result(
607-
{"top(C)", "top(C).A0(A)", "top(C).B0(B)"});
608-
AT_ASSERT(module_debug_info_set == expected_result);
605+
AT_ASSERT(module_debug_info_set.count("top(C).aten::add"));
606+
AT_ASSERT(module_debug_info_set.count("top(C).A0(A).aten::add"));
607+
AT_ASSERT(module_debug_info_set.count("top(C).B0(B).aten::add"));
609608
}
610609

611610
TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) {
@@ -790,9 +789,9 @@ TEST(LiteInterpreterTest, SequentialModuleInfo) {
790789
// def forward(self, x):
791790
// return self.A0.forward(self.B0.forward(x))
792791

793-
std::set<std::string> expected_result(
794-
{"top(C)", "top(C).A0(A)", "top(C).B0(B)"});
795-
AT_ASSERT(module_debug_info_set == expected_result);
792+
AT_ASSERT(module_debug_info_set.count("top(C).prim::Return"));
793+
AT_ASSERT(module_debug_info_set.count("top(C).A0(A).aten::add"));
794+
AT_ASSERT(module_debug_info_set.count("top(C).B0(B).aten::add"));
796795
}
797796

798797
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -838,9 +837,9 @@ TEST(LiteInterpreterTest, HierarchyModuleInfo) {
838837
// "top(C).forward": for the add operator in top.
839838
// "top(C).B0(B).forward": for the add operator in B0.
840839
// "top(C).B0(B).forward.A0(A).forward": for the add operator in A0.
841-
std::set<std::string> expected_result(
842-
{"top(C)", "top(C).B0(B)", "top(C).B0(B).A0(A)"});
843-
AT_ASSERT(module_debug_info_set == expected_result);
840+
AT_ASSERT(module_debug_info_set.count("top(C).aten::add"));
841+
AT_ASSERT(module_debug_info_set.count("top(C).B0(B).aten::add"));
842+
AT_ASSERT(module_debug_info_set.count("top(C).B0(B).A0(A).aten::add"));
844843
}
845844

846845
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -898,9 +897,9 @@ TEST(LiteInterpreterTest, DuplicatedClassTypeModuleInfo) {
898897
// "top(B).A0(A).forward": for the add operator in A0.
899898
// "top(B).A1(A).forward": for the add operator in A1.
900899

901-
std::set<std::string> expected_result(
902-
{"top(B)", "top(B).A0(A)", "top(B).A1(A)"});
903-
AT_ASSERT(module_debug_info_set == expected_result);
900+
AT_ASSERT(module_debug_info_set.count("top(B).aten::add"));
901+
AT_ASSERT(module_debug_info_set.count("top(B).A0(A).aten::add"));
902+
AT_ASSERT(module_debug_info_set.count("top(B).A1(A).aten::add"));
904903
}
905904

906905
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
Binary file not shown.

test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ TEST(RunTimeTest, DelegateException) {
142142
inputs.emplace_back(torch::rand({13, 9}));
143143

144144
std::string error_pattern = R"(
145-
Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA)
145+
Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA).aten::add
146146
Traceback of TorchScript (most recent call last):
147147
File "<string>", line 3, in FunctionName_UNKNOWN
148148

torch/csrc/jit/backends/backend_debug_handler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ int64_t BackendDebugInfoRecorder::getNextDebugHandle(const Node* node) {
2020
DebugHandleType debug_handle = unique_debug_handle_;
2121
const SourceRange& range = node->sourceRange();
2222
handles_to_inlined_callstack_ptrs_[debug_handle] =
23-
std::make_pair(range, cs_ptr);
23+
std::make_tuple(range, node->kind().toQualString(), cs_ptr);
2424
// This increment is with seq memory order.
2525
// Not trying to perf optimizing this for now.
2626
unique_debug_handle_++;

torch/csrc/jit/backends/backend_debug_handler.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,22 @@ namespace jit {
1313
* BackendDebugHandleManager is responsible for issuing debug handles to
1414
* backends. Debug handles are associated with nodes of a graph.
1515
* BackendDebugHandleManager also maintains a map
16-
* [debug-handle, DebugInfoPair = {source range, inlined callstack ptr]} that
16+
* [debug-handle, DebugInfoTuple = {source range, inlined callstack ptr]} that
1717
* will help generate a callstack for exception raised using debug handles.
1818
* Effectively debug handles are something that is given to backend and later
1919
* when an exception occurs in the backend, backend can tell, using debug
2020
* handle, that an exception occurred here. Then the runtime can generate
2121
* callstack correspoding to the exception.
2222
* There are two parts to BackendDebugHandleManager:
2323
* 1. static std::atomic debug_handle
24-
* 2. Map of [debug-handle, DebugInfoPair]
24+
* 2. Map of [debug-handle, DebugInfoTuple]
2525
*
2626
* About 1:
2727
* Why do they have to be unique. The reason is that by ensuring
2828
* uniqueness of debug handles, we remove the burden of another layer of
2929
* mapping where we need to say this set of debug handles were generated for
3030
* this lowered module or this bytecode function. This simplifies the API for
31-
* serialization since debug handles can uniquely identify DebugInfoPair.
31+
* serialization since debug handles can uniquely identify DebugInfoTuple.
3232
* Thus simplifies the runtime API for throwing exception. Exception throwing
3333
* only needs to know debug_handle and not which module or method threw it.
3434
* There are 2 issues to keep in mind, though,for static std::atomic
@@ -40,8 +40,8 @@ namespace jit {
4040
* done.
4141
*
4242
* Now about 2:
43-
* There are two usecases for [debug-handle, DebugInfoPair]
44-
* A. During bytecode generation the DebugInfoPair corresponding to the nodes
43+
* There are two usecases for [debug-handle, DebugInfoTuple]
44+
* A. During bytecode generation the DebugInfoTuple corresponding to the nodes
4545
* of the inlined graph being serialized, are stored in this object and a
4646
* unique debug handle is returned. This unique debug handle is stored in
4747
* mobile_debug info for pytorch lite models. It will be used for raising
@@ -52,29 +52,29 @@ namespace jit {
5252
* the debug handles provide a way to map nodes of the graph to the model level
5353
* debug info.
5454
*
55-
* During byte-code model serialization, [debug-handle, DebugInfoPair] is
55+
* During byte-code model serialization, [debug-handle, DebugInfoTuple] is
5656
* serialized. Now we know a. debug handles and b. how to map debug handles to
5757
* model source code. Thus we can either do eager symbolication by converting
5858
* debug handles to corresponding source code at runtime, or do lazy
5959
* symbolicattion offline.
6060
*
61-
* Note that it is not necessary to serialize [debug-handle, DebugInfoPair]
61+
* Note that it is not necessary to serialize [debug-handle, DebugInfoTuple]
6262
* corresponding to lowered backend if the lowering process, that is
6363
* preprocess/compile, and execution happens in the same session, then eager
6464
* symbolication can be employed.
6565
*
6666
* Now how does BackendDebugHandleManager capture all of the above?
6767
* By providing two API.
6868
* 1. getNextDebugHandle which given a Node* returns a unique debug handle,
69-
* that will uniquely identify DebugInfoPair.
69+
* that will uniquely identify DebugInfoTuple.
7070
* and
7171
* 2. getCallStackPtrMap which returns the map
72-
* [debug-handle, DebugInfoPair]
72+
* [debug-handle, DebugInfoTuple]
7373
*
7474
* 1 provides debug handles to backends and 2 provides runtime a way to map
7575
* debug handles to source level debug info.
7676
*
77-
* So why does debug handle map to DebugInfoPair = {source range and inlined
77+
* So why does debug handle map to DebugInfoTuple = {source range and inlined
7878
* cs}? {debug_handle, source_range_tag, serialized_callstack} Take this
7979
* example: class L(nn.Module): def __init__(self):
8080
* ...
@@ -112,7 +112,7 @@ namespace jit {
112112
using DebugHandleType = int64_t;
113113

114114
using BackendDebugInfoMapType =
115-
std::unordered_map<DebugHandleType, DebugInfoPair>;
115+
std::unordered_map<DebugHandleType, DebugInfoTuple>;
116116

117117
/*
118118
* This class is used to generate debug info map.

torch/csrc/jit/ir/scope.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,13 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target {
175175
}
176176
};
177177

178-
using DebugInfoPair = std::pair<SourceRange, InlinedCallStackPtr>;
178+
// {source range, node name, InlinedCallStack}
179+
// We store node name because same debug infor will be used for
180+
// profiling as well, so we need to know op names as well.
181+
using DebugInfoTuple =
182+
std::tuple<SourceRange, std::string, InlinedCallStackPtr>;
183+
constexpr size_t kDebugInfoTupleSourceRangeIndex{0};
184+
constexpr size_t kDebugInfoTupleNodeNameIndex{1};
185+
constexpr size_t kDebugInfoTupleInlinedCSIndex{2};
179186
} // namespace jit
180187
} // namespace torch

torch/csrc/jit/mobile/debug_info.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ namespace jit {
1414
namespace {
1515

1616
std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy(
17-
const DebugInfoPair& source_callstack) {
17+
const DebugInfoTuple& source_callstack) {
1818
constexpr size_t kSourceRange = 1;
1919
constexpr size_t kModuleInstanceInfo = 2;
2020
std::vector<StackEntry> entries;
2121

22-
const SourceRange& range = source_callstack.first;
23-
InlinedCallStackPtr callstack_ptr = source_callstack.second;
22+
const SourceRange& range =
23+
std::get<kDebugInfoTupleSourceRangeIndex>(source_callstack);
24+
InlinedCallStackPtr callstack_ptr =
25+
std::get<kDebugInfoTupleInlinedCSIndex>(source_callstack);
2426
std::string module_info;
2527
if (!callstack_ptr) {
2628
// If not cs then top level node
@@ -70,7 +72,7 @@ std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy
7072
// will be TopM(A).MyModule(B).SomeModule(C).Conv2d(conv)
7173
// Source level stack information will be from model source code.
7274
std::pair<std::string, std::string> getStackTraceWithModuleHierarchy(
73-
const std::vector<DebugInfoPair>& source_callstacks,
75+
const std::vector<DebugInfoTuple>& source_callstacks,
7476
const std::string& root_scope_string,
7577
const std::string& top_module_type_name) {
7678
std::vector<StackEntry> stack_entries;
@@ -82,6 +84,12 @@ std::pair<std::string, std::string> getStackTraceWithModuleHierarchy(
8284
stack_entries.insert(stack_entries.end(), entries.begin(), entries.end());
8385
module_info += debug_info_pair.second;
8486
}
87+
// Only last entry in the callstack will have a node name of interest.
88+
// Rest are likely CallMethod/CallFunction nodes
89+
auto last_entry = source_callstacks.back();
90+
const std::string& node_name =
91+
std::get<kDebugInfoTupleNodeNameIndex>(last_entry);
92+
module_info += "." + node_name;
8593
std::ostringstream ss;
8694
ss << "Module hierarchy:" << module_info << "\n";
8795
format_stack_trace(ss, stack_entries);
@@ -177,7 +185,7 @@ std::pair<std::string, std::string> MobileDebugTable::
177185
getSourceDebugModuleHierarchyInfo(
178186
const std::vector<int64_t>& debug_handles,
179187
const std::string& top_module_type_name) const {
180-
std::vector<DebugInfoPair> debug_infos;
188+
std::vector<DebugInfoTuple> debug_infos;
181189
bool debug_handle_not_found{false};
182190
for (auto it = debug_handles.rbegin(); it != debug_handles.rend(); ++it) {
183191
auto debug_handle = *it;

torch/csrc/jit/mobile/debug_info.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class MobileDebugTable {
4040
std::pair<std::string, std::string> getSourceDebugModuleHierarchyInfo(
4141
const std::vector<int64_t>& debug_handles,
4242
const std::string& top_module_type_name = "ModuleTypeUnknown") const;
43-
ska::flat_hash_map<int64_t, DebugInfoPair> callstack_ptr_map_;
43+
ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptr_map_;
4444
};
4545

4646
} // namespace jit

0 commit comments

Comments
 (0)