Skip to content

Commit ded82ad

Browse files
qihqi authored and pytorchmergebot committed
Create method to map JIT module to (source, constant) and back. (pytorch#74119)
Summary: Pull Request resolved: pytorch#74119. Implemented a function to generate source as an ExtraFilesMap plus constants, and wrote a function to construct a JIT module given the (ivalue, source, constants) triple. Test Plan: unittest. Reviewed By: pavithranrao. Differential Revision: D34803945. fbshipit-source-id: 2edc798407fe68294cb4c3c7516f5bd143df88c3 (cherry picked from commit 35e54e1)
1 parent cc5f8ae commit ded82ad

File tree

5 files changed

+268
-0
lines changed

5 files changed

+268
-0
lines changed

test/cpp/jit/test_save_load.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
#include <test/cpp/jit/test_utils.h>
44
#include <sstream>
55

6+
#include <torch/csrc/jit/mobile/module.h>
67
#include <torch/csrc/jit/serialization/export.h>
8+
#include <torch/csrc/jit/serialization/export_bytecode.h>
79
#include <torch/csrc/jit/serialization/import.h>
810
#include <torch/csrc/jit/serialization/import_source.h>
911
#include <torch/torch.h>
@@ -13,6 +15,20 @@
1315
namespace torch {
1416
namespace jit {
1517

18+
namespace {
19+
20+
Module roundtripThroughMobile(const Module& m) {
21+
ExtraFilesMap files;
22+
std::vector<IValue> constants;
23+
jitModuleToPythonCodeAndConstants(m, &files, &constants);
24+
CompilationOptions options;
25+
mobile::Module mobilem = jitModuleToMobile(m, options);
26+
return jitModuleFromSourceAndConstants(
27+
mobilem._ivalue(), files, constants, 8);
28+
}
29+
30+
} // namespace
31+
1632
TEST(SerializationTest, ExtraFilesHookPreference) {
1733
// Tests that an extra file written explicitly has precedence over
1834
// extra files written by a hook
@@ -149,5 +165,78 @@ TEST(SerializationTest, TestJitStream_CUDA) {
149165
// Check if both the output tensors are equal
150166
ASSERT_TRUE(op.equal(c));
151167
}
168+
169+
TEST(TestSourceRoundTrip, UpsampleNearest2d) {
  // An op with scalar args must produce identical results after the
  // (source, constants) round trip.
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  const auto expected = m.forward(inputs);

  Module reloaded = roundtripThroughMobile(m);
  const auto actual = reloaded.forward(inputs);

  ASSERT_TRUE(actual.toTensor().equal(expected.toTensor()));
}
188+
189+
TEST(TestSourceRoundTrip, CheckAttrAccess) {
  // A bool attribute registered on the module must survive the
  // (source, constants) round trip and remain readable via attr().
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);
  Module reloaded = roundtripThroughMobile(m);
  AT_ASSERT(reloaded.attr("mobile_optimized", false).toBool());
}
196+
197+
TEST(TestSourceRoundTrip,
     MethodInvocation) { // NOLINT (use =delete in gtest)
  // For each program: define test_func on a fresh module, round-trip the
  // module through mobile, and check the reconstructed method returns the
  // same value as the original.
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x) # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    // `foo` is referenced by every program above; not a trainable param.
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    auto ref = m.run_method("test_func", minput);

    Module m2 = roundtripThroughMobile(m);
    const auto& test_func = m2.get_method("test_func");
    IValue res;
    // Invoke repeatedly to make sure the method is re-entrant after import.
    for (int i = 0; i < 3; ++i) {
      res = test_func({minput});
    }

    auto resd = res.toTensor().item<float>();
    auto refd = ref.toTensor().item<float>();
    AT_ASSERT(resd == refd);
  }
}
240+
152241
} // namespace jit
153242
} // namespace torch

torch/csrc/jit/serialization/import.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,5 +403,75 @@ Module load(
403403
return deserializer.deserialize(device, extra_files);
404404
}
405405

406+
// Replace object with a newly created but equivalent object.
407+
// The goal is to replace object's methods. However, since object's
408+
// methods are attached to type; we need to replace it's type.
409+
// Non-objects are unchanged; however, nested structures such as list, dict
410+
// are also reconstructed because they might contain an object.
411+
static IValue recreateObject(IValue ivalue, TypeResolver resolver) {
412+
if (ivalue.isObject()) {
413+
auto obj = ivalue.toObject();
414+
auto classtype_old = obj->type();
415+
auto newtype = resolver(*classtype_old->name());
416+
size_t n = classtype_old->numAttributes();
417+
auto newobj = c10::ivalue::Object::create(newtype, n);
418+
for (const auto i : c10::irange(n)) {
419+
newobj->setSlot(i, recreateObject(obj->getSlot(i), resolver));
420+
}
421+
return newobj;
422+
} else if (ivalue.isList()) {
423+
auto res = c10::impl::GenericList(ivalue.type()->containedType(0));
424+
for (const auto& ival : ivalue.toList()) {
425+
res.emplace_back(recreateObject(ival, resolver));
426+
}
427+
return res;
428+
} else if (ivalue.isGenericDict()) {
429+
auto result = c10::impl::GenericDict(
430+
ivalue.type()->containedType(0), ivalue.type()->containedType(1));
431+
for (const auto& kv : ivalue.toGenericDict()) {
432+
result.insert_or_assign(
433+
recreateObject(kv.key(), resolver),
434+
recreateObject(kv.value(), resolver));
435+
}
436+
return result;
437+
} else if (ivalue.isTuple()) {
438+
std::vector<IValue> res;
439+
for (const auto& ival : ivalue.toTuple()->elements()) {
440+
res.push_back(recreateObject(ival, resolver));
441+
}
442+
return c10::ivalue::Tuple::create(res);
443+
}
444+
// Leaf types are returned verbatim.
445+
return ivalue;
446+
}
447+
448+
// Rebuild a JIT Module from its object state, per-file python source, and
// constant table (the reverse of jitModuleToPythonCodeAndConstants).
Module jitModuleFromSourceAndConstants(
    const IValue& ivalue,
    const ExtraFilesMap& source,
    const std::vector<IValue>& constants,
    int32_t version) {
  auto compilation_unit = std::make_shared<CompilationUnit>();
  // Serve source text out of the in-memory map, keyed by qualifier.
  auto source_loader =
      [&source](const std::string& qualifier) -> std::shared_ptr<SourceView> {
    auto found = source.find(qualifier);
    if (found == source.end()) {
      return nullptr;
    }
    return std::make_shared<Source>(found->second, qualifier, 1, nullptr);
  };
  SourceImporter importer(compilation_unit, &constants, source_loader, version);
  // Resolve qualified names against the freshly imported compilation unit.
  auto type_resolver = [&](const c10::QualifiedName& qn) {
    auto cls = importer.loadType(qn);
    return c10::StrongTypePtr(compilation_unit, std::move(cls));
  };
  Module m(recreateObject(ivalue, type_resolver).toObject());
  rewriteQuantizedConvForBC(m);
  return m;
}
475+
406476
} // namespace jit
407477
} // namespace torch

torch/csrc/jit/serialization/import.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,11 @@ TORCH_API Module load(
9898
c10::optional<c10::Device> device,
9999
ExtraFilesMap& extra_files);
100100

101+
// Reconstruct a JIT Module from its object state (`ivalue`), the python
// source of its classes (`source`, keyed by qualifier), and the constant
// table that source refers to. `version` is the version the source is
// imported at. Counterpart of jitModuleToPythonCodeAndConstants.
TORCH_API Module jitModuleFromSourceAndConstants(
    const IValue& ivalue,
    const ExtraFilesMap& source,
    const std::vector<IValue>& constants,
    int32_t version);
106+
101107
} // namespace jit
102108
} // namespace torch

torch/csrc/jit/serialization/python_print.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <algorithm>
44

5+
#include <ATen/core/ivalue.h>
56
#include <ATen/core/qualified_name.h>
67
#include <c10/util/Exception.h>
78
#include <c10/util/StringUtil.h>
@@ -17,6 +18,7 @@
1718
#include <torch/csrc/jit/operator_upgraders/version_map.h>
1819
#include <torch/csrc/jit/resource_guard.h>
1920
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
21+
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
2022

2123
using c10::QualifiedName;
2224

@@ -1662,5 +1664,98 @@ uint64_t PythonPrint::minVersion() const {
16621664

16631665
PythonPrint::~PythonPrint() = default;
16641666

1667+
// Iteratively (depth-first, explicit stack) walk `ivalue` and collect every
// object reachable from it, including objects nested inside dicts, lists,
// tuples, and __getstate__ results.
//
// BUG FIX: the container branches previously tested the ROOT `ivalue`
// instead of the popped `head`, so nested containers were never traversed
// correctly and a container at the root re-enqueued its own children on
// every iteration (non-termination). All branches now inspect `head`.
std::vector<IValue> traverseIValueAndGetObjects(IValue ivalue) {
  std::vector<IValue> result;
  std::vector<IValue> stack;
  stack.emplace_back(ivalue);
  while (!stack.empty()) {
    IValue head = stack.back();
    stack.pop_back();
    if (head.isObject()) {
      result.push_back(head);
      auto obj = head.toObject();
      ClassTypePtr type = obj->type();
      if (type->hasMethod("__getstate__")) {
        // Objects with custom serialization: traverse the serialized state
        // rather than the raw attribute slots.
        Function& getstate = type->getMethod("__getstate__");
        stack.emplace_back(getstate({obj}));
      } else {
        for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
          stack.emplace_back(obj->getSlot(i));
        }
      }
    } else if (head.isGenericDict()) {
      for (const auto& kv : head.toGenericDict()) {
        // skip key because key cannot be an object
        stack.emplace_back(kv.value());
      }
    } else if (head.isList()) {
      for (const auto& v : head.toList()) {
        stack.emplace_back(v);
      }
    } else if (head.isTuple()) {
      for (const auto& v : head.toTuple()->elements()) {
        stack.emplace_back(v);
      }
    }
  }
  return result;
}
1703+
1704+
// Render a type name for printing: DynamicType is printed via its concrete
// fallback (recursing through this same printer), named types go through
// the uniquer so each class gets a stable unique qualified name, and
// everything else returns nullopt (caller uses the default annotation).
c10::optional<std::string> printType(
    const c10::Type& type,
    torch::jit::TypeNameUniquer& type_name_uniquer) {
  if (auto dyn = type.castRaw<c10::DynamicType>()) {
    return dyn->fallback()->annotation_str(
        [&](auto&& t) { return printType(t, type_name_uniquer); });
  }
  if (auto named = type.cast<c10::NamedType>()) {
    if (named->name()) {
      return type_name_uniquer.getUniqueName(named).qualifiedName();
    }
  }
  return c10::nullopt;
}
1717+
1718+
void jitModuleToPythonCodeAndConstants(
1719+
const Module& module,
1720+
ExtraFilesMap* jit_sources, // output
1721+
std::vector<IValue>* constants // output
1722+
) {
1723+
std::vector<IValue> objects = traverseIValueAndGetObjects(module._ivalue());
1724+
std::unordered_set<c10::QualifiedName> visited;
1725+
PrintDepsTable class_deps;
1726+
TypeNameUniquer uniquer;
1727+
auto type_printer = [&](const c10::Type& t) { return printType(t, uniquer); };
1728+
1729+
// Group by prefix; because every prefix is a file.
1730+
std::unordered_map<std::string, PythonPrint> grouped_by_prefix;
1731+
for (const IValue& obj : objects) {
1732+
ObjectPtr obj_ptr = obj.toObject();
1733+
ClassTypePtr class_type = obj_ptr->type();
1734+
class_deps.add(class_type);
1735+
}
1736+
1737+
for (int i = 0; i < class_deps.size(); ++i) {
1738+
auto type = class_deps[i];
1739+
auto qualname = uniquer.getUniqueName(type);
1740+
std::string qualifier = qualname.prefix();
1741+
auto pp_iter = grouped_by_prefix.find(qualifier);
1742+
if (pp_iter == grouped_by_prefix.end()) {
1743+
pp_iter = grouped_by_prefix
1744+
.emplace(
1745+
qualifier,
1746+
PythonPrint(
1747+
*constants,
1748+
class_deps,
1749+
type_printer,
1750+
/*enforce_importable=*/true))
1751+
.first;
1752+
}
1753+
pp_iter->second.printNamedType(type);
1754+
}
1755+
for (const auto& kv : grouped_by_prefix) {
1756+
(*jit_sources)[kv.first] = kv.second.str();
1757+
}
1758+
}
1759+
16651760
} // namespace jit
16661761
} // namespace torch

torch/csrc/jit/serialization/python_print.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22
#include <torch/csrc/Export.h>
3+
#include <torch/csrc/jit/api/module.h>
34
#include <torch/csrc/jit/ir/ir.h>
45
#include <iostream>
56
#include <vector>
@@ -49,5 +50,12 @@ struct TORCH_API PythonPrint {
4950
};
5051

5152
TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym);
53+
54+
// Serialize a JIT module's class definitions to python source, one
// `jit_sources` entry per file (keyed by qualifier prefix), collecting the
// constants referenced by that source into `constants`. Both out-params
// must be non-null. Counterpart of jitModuleFromSourceAndConstants.
TORCH_API void jitModuleToPythonCodeAndConstants(
    const Module& module,
    ExtraFilesMap* jit_sources, // output
    std::vector<IValue>* constants // output
);
59+
5260
} // namespace jit
5361
} // namespace torch

0 commit comments

Comments
 (0)