Skip to content

Commit 026cfe8

Browse files
kimishpatelfacebook-github-bot
authored andcommitted
Fix InlinedCallStack annotation to account for module calling its own (pytorch#61791)
Summary: Pull Request resolved: pytorch#61791 methods from forward During inlining we attached InlinedCallstack to nodes being inlined. In the process we attach moodule information as well, such that if CallMethod is being inlined we know which class instance and class type the method belongs to. However, CallMethod can be calling a method of the same object to which the graph belongs. e.g.: ``` def forward(self, input): x = input + 10 return forward_impl_(x, input) ``` Here forward_impl is method defined on the same class in which forward is defined. Existing module hierarchy annotation will mislabel this as unknown instance since the method is not associated with output of GetAttr node (it would be we had called self.conv.forward_impl_ for example). Change in this PR reconciles this by creating a placeholder name "SELF" for module instance indicating that you can traverse InlinedCallStack backwards to find first node with name != SELF, which would be the name of the object. e.g.: TOP(ResNet)::forward.SELF(ResNet)::_forward_impl.layer1(Sequential)::forward.0(BasicBlock)::forward.conv1(Conv2d)::forward.SELF(Conv2d)::_conv_forward Test Plan: Add test Imported from OSS Reviewed By: larryliu0820 Differential Revision: D29745443 fbshipit-source-id: 1525e41df53913341c4c36a56772454782a0ba93
1 parent f16102f commit 026cfe8

File tree

7 files changed

+83
-32
lines changed

7 files changed

+83
-32
lines changed

test/cpp/jit/test_misc.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2102,6 +2102,49 @@ TEST(InlinedCallStackTest, BlockAnnotation) {
21022102
ASSERT_NE(mul_ss.str().find("return x * y"), std::string::npos);
21032103
}
21042104

2105+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
2106+
TEST(InlinedCallStackTest, SelfCallMethods) {
2107+
Module a("A");
2108+
a.define(R"(
2109+
def my_new_method(self, x):
2110+
return x * 3
2111+
def forward_impl_(self, x, y):
2112+
return self.my_new_method(x) + y
2113+
def forward(self, x, y):
2114+
y = y + 2
2115+
return self.forward_impl_(x, y)
2116+
)");
2117+
Module b("B");
2118+
b.define(R"(
2119+
def forward(self, x):
2120+
return x + 2
2121+
)");
2122+
Module c("C");
2123+
c.register_module("A0", a);
2124+
c.register_module("B0", b);
2125+
c.define(R"(
2126+
def call_b(self, x):
2127+
return self.B0.forward(x)
2128+
def forward(self, x, y):
2129+
return self.A0.forward(x, y) + self.call_b(x)
2130+
)");
2131+
2132+
auto graph = c.get_method("forward").function().optimized_graph();
2133+
std::unordered_map<std::string, size_t> module_hierarchies;
2134+
for (Node* n : graph->nodes()) {
2135+
auto hierarchy = torch::jit::utils::getNodesModuleHierarchy(*n);
2136+
if (module_hierarchies.count(hierarchy) == 0) {
2137+
module_hierarchies[hierarchy] = 0;
2138+
}
2139+
module_hierarchies[hierarchy] += 1;
2140+
}
2141+
ASSERT_EQ(module_hierarchies["A0(A)"], 2);
2142+
ASSERT_EQ(module_hierarchies["A0(A).SELF(A).SELF(A)"], 2);
2143+
ASSERT_EQ(module_hierarchies["A0(A).SELF(A)"], 1);
2144+
ASSERT_EQ(module_hierarchies["SELF(C)"], 1);
2145+
ASSERT_EQ(module_hierarchies["SELF(C).B0(B)"], 1);
2146+
}
2147+
21052148
TEST(AutogradSymbolsTest, Basic) {
21062149
Symbol sym = Symbol::fromQualString("aten::test_symbol");
21072150
Graph graph;

torch/csrc/jit/ir/ir.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,29 @@
2424
namespace torch {
2525
namespace jit {
2626

27+
namespace utils {
28+
std::string getNodesModuleHierarchy(const Node& n) {
29+
if (!n.callstack().has_value()) {
30+
return std::string();
31+
}
32+
InlinedCallStackPtr callstack_ptr = n.callstack().value();
33+
std::string module_hierarchy;
34+
for (auto& entry : callstack_ptr->vec()) {
35+
const auto& opt_module_info = std::get<kModuleInstanceInfo>(entry);
36+
if (opt_module_info.has_value()) {
37+
const auto& module_instance_info = opt_module_info.value();
38+
if (!module_hierarchy.empty()) {
39+
module_hierarchy.append(".");
40+
}
41+
module_hierarchy.append(utils::get_module_info(module_instance_info));
42+
} else {
43+
module_hierarchy += ".UNKNOWN_INSTANCE(UNKNOWN_TYPE)";
44+
}
45+
}
46+
return module_hierarchy;
47+
}
48+
} // namespace utils
49+
2750
namespace {
2851

2952
// Constants relating to maintaining the topological index of nodes.
@@ -2059,10 +2082,18 @@ std::vector<Value*> inlineCallTo(
20592082
if (to_replace->input(0)->node()->kind() == prim::GetAttr) {
20602083
module_instance_info = c10::make_optional(ModuleInstanceInfo(
20612084
class_type_ptr, to_replace->input(0)->node()->s(attr::name)));
2085+
} else if (
2086+
to_replace->owningGraph()->inputs().size() > 0 &&
2087+
to_replace->input(0) == to_replace->owningGraph()->inputs()[0]) {
2088+
// This CallMethod must correspond to method of the same object
2089+
// to which this graph belongs.
2090+
module_instance_info =
2091+
c10::make_optional(ModuleInstanceInfo(class_type_ptr, "SELF"));
20622092
} else {
2063-
std::string instance_name_unknown("INSTANCE_NAME_UNKNOWN");
2093+
// Not sure if it is possible to come here ever.
2094+
// TODO: Remove this else. Or add assert
20642095
module_instance_info = c10::make_optional(
2065-
ModuleInstanceInfo(class_type_ptr, instance_name_unknown));
2096+
ModuleInstanceInfo(class_type_ptr, "INSTANCE_NAME_UNKNOWN"));
20662097
}
20672098
}
20682099

torch/csrc/jit/ir/ir.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ using pyobj_list = std::vector<THPObjectPtr>;
3232

3333
namespace torch {
3434
namespace jit {
35+
namespace utils {
36+
TORCH_API std::string getNodesModuleHierarchy(const Node& n);
37+
} // namespace utils
3538
class AliasDb;
3639

3740
using ::c10::Argument;

torch/csrc/jit/ir/scope.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
namespace torch {
1111
namespace jit {
1212
struct ModuleInstanceInfo;
13+
constexpr size_t kModuleInstanceInfo = 2;
14+
1315
namespace utils {
14-
TORCH_API std::string get_module_info(
15-
const ModuleInstanceInfo& module_instance_info);
16+
std::string get_module_info(const ModuleInstanceInfo& module_instance_info);
1617
} // namespace utils
1718

1819
// Scope is a node of a trie that represents the tree of nested scopes.

torch/csrc/jit/mobile/debug_info.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ namespace {
1616
std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy(
1717
const DebugInfoTuple& source_callstack,
1818
const std::string& caller_name) {
19-
// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
20-
constexpr size_t kSourceRange = 1;
21-
// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
22-
constexpr size_t kModuleInstanceInfo = 2;
2319
std::vector<StackEntry> entries;
2420

2521
const SourceRange& range =

torch/csrc/jit/python/python_ir.cpp

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -569,28 +569,7 @@ void initPythonIRBindings(PyObject* module_) {
569569
.def("output", [](Node& n) { return n.output(); })
570570
.def(
571571
"getModuleHierarchy",
572-
[](Node& n) {
573-
if (!n.callstack().has_value()) {
574-
return std::string();
575-
}
576-
InlinedCallStackPtr callstack_ptr = n.callstack().value();
577-
std::string module_info;
578-
for (auto& entry : callstack_ptr->vec()) {
579-
const auto& opt_module_info =
580-
std::get<kModuleInstanceInfo>(entry);
581-
if (opt_module_info.has_value()) {
582-
const auto& module_instance_info = opt_module_info.value();
583-
if (!module_info.empty()) {
584-
module_info.append(".");
585-
}
586-
module_info.append(
587-
utils::get_module_info(module_instance_info));
588-
} else {
589-
module_info += ".UNKNOWN_INSTANCE(UNKNOWN_TYPE)";
590-
}
591-
}
592-
return module_info;
593-
})
572+
[](Node& n) { return torch::jit::utils::getNodesModuleHierarchy(n); })
594573
.NS(addInput)
595574
.NS(replaceInput)
596575
.NS(replaceInputWith)

torch/csrc/jit/python/python_ir.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
namespace torch {
77
namespace jit {
88

9-
constexpr size_t kModuleInstanceInfo = 2;
10-
119
void initPythonIRBindings(PyObject* module);
1210

1311
// execute a Python function, used for Ops we can't optimize but that we want to

0 commit comments

Comments
 (0)