
Commit 7980ed9

tangleintel authored and pytorchmergebot committed
Support unpacking python dictionary in torch.jit.trace() (pytorch#81623)
# Support unpacking python dictionary in **torch.jit.trace()**

## Problem statement & Motivation

### Problem 1 (usability)

Say you have a model whose forward method is defined as follows:

**`def forward(self, key1=value1, key2=value2, key3=value3)`**

And you have a dataset in which each data point is a python dict:

**`data = {key1:value1, key3:value3, key2:value2}`**

If you want to trace the model with the dict data from the given dataset, you have to unpack the dictionary, manually reorder its values, and build a tuple such as **`data_tuple = (value1, value2, value3)`** to pass as the **`example_inputs`** parameter of **`torch.jit.trace()`**. This marshalling process is not user friendly.

### Problem 2 (feasibility)

Say you have a model whose forward method is defined as follows:

**`def forward(self, key1=None, key2=None, key3=None)`** -> The default value is **None**

And you have a dataset in which each data point is a python dict:

**`data = {key1:value1, key3:value3}`** -> Only **part of** the values required by forward are given; the rest use the default values.

In this case, tracing the model with the dict data is not feasible at all, because you can pass neither a tuple like **`T1 = (value1, value3)`** nor **`T2 = (value1, None, value3)`**: T1 mismatches value3 with key2, and T2 contains a **None**, which is rejected by the tracer's type checking. (You could pass **`T3 = (value1,)`** to make the trace call finish without an exception, but the traced model you get is probably not what you expect, because different inputs can produce different traced results.)

These problems come up with HuggingFace PyTorch models, especially in text-classification tasks with datasets such as [MRPC](https://paperswithcode.com/dataset/mrpc), [MNLI](https://paperswithcode.com/dataset/multinli), etc.

## Solution

To address these two issues, we propose supporting a new type, a python dict, as the example_inputs parameter of torch.jit.trace(). Based on the runtime type of the example_inputs object, we either fall back to the original tuple path or take the new dictionary path. Both problem 1 and problem 2 are solved by using the "**`**`**" operator.

## Limitation & Mitigation

1. If we use a dict as example_inputs to trace the model, then we have to pass a dictionary to the traced model too. (Because we may change the order of the input parameters' debug names in the TorchScript IR, we cannot assume that the traced model's input parameter order matches the original model's.) We need to highlight this in the documentation to mitigate the problem. For example:

```
# Fetch a data point from the dataloader; the data is a dictionary,
# e.g. example_inputs_dict = {key1:value1, key3:value3, key2:value2},
# and forward() is defined as: def forward(self, key1=value1, key2=value2, key3=value3)
example_inputs_dict = next(iter(dataloader))
jit_model = model.eval()
# Use the dictionary to trace the model.
jit_model = torch.jit.trace(jit_model, example_inputs_dict, strict=False)
# Now the IR will be:
# graph(%self : __torch__.module.___torch_mangle_n.Mymodule, %key1 : type1, %key3 : type3, %key2 : type2)
jit_model = torch.jit.freeze(jit_model)
# It's OK to use a dict as the parameter for the traced model.
jit_model(**example_inputs_dict)
example_inputs_tuple = (value1, value3, value2)
# It's wrong to rely on the original argument order.
jit_model(*example_inputs_tuple)
```

## Note

1. This PR makes some UTs introduced in [39601](pytorch#39601) fail; under this solution those cases should be classified as unpacking a tuple that contains a single dictionary element.
2. There is some ambiguity here: the documentation of **torch.jit.trace()** currently only specifies passing a tuple or a single Tensor as the example_inputs parameter, yet it seems a dictionary can still be passed.

Pull Request resolved: pytorch#81623
Approved by: https://github.com/davidberard98
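For the dict path that lands in the tests below, a minimal end-to-end sketch looks like the following. It uses the `example_kwarg_inputs` keyword that the new test exercises; `ToyModel` and the key names are illustrative only, not part of the PR.

```python
import torch

# Illustrative model: the dict passed at trace time is out of order (problem 1)
# and omits 'key3', which falls back to its default of None (problem 2).
class ToyModel(torch.nn.Module):
    def forward(self, key1=None, key2=None, key3=None):
        return key1 + key2

model = ToyModel().eval()
example_kwargs = {'key2': torch.ones(1), 'key1': torch.ones(1)}

# Values are matched to forward()'s parameters by name, so no manual
# reordering or padding with None is needed.
traced = torch.jit.trace(model, example_kwarg_inputs=example_kwargs, strict=False)

# Call the traced module with keyword arguments as well, since the positional
# order of its inputs may differ from the original forward() signature.
print(traced(**example_kwargs))  # tensor([2.])
```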
1 parent bdefa26 commit 7980ed9

File tree

7 files changed: +321 −38 lines

test/test_jit.py

Lines changed: 40 additions & 0 deletions
@@ -3027,6 +3027,46 @@ def forward(self, x):
         checker.check("def forward")
         checker.run(str(cm.exception))
 
+    def test_dictionary_as_example_inputs_for_jit_trace(self):
+        class TestModule_v1(torch.nn.Module):
+            def __init__(self):
+                super(TestModule_v1, self).__init__()
+
+            def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None):
+                return key1 + key2 + key3
+
+        class TestModule_v2(torch.nn.Module):
+            def __init__(self):
+                super(TestModule_v2, self).__init__()
+
+            def forward(self, x, y):
+                return x + y
+
+        def test_func(x, y):
+            return x + y
+
+        model_1 = TestModule_v1()
+        model_2 = TestModule_v2()
+        value1 = torch.ones(1)
+        value2 = torch.ones(1)
+        value3 = torch.ones(1)
+        example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3}
+        example_input_dict_func = {'x': value1, 'y': value2}
+        traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False)
+        traced_model_1_m = torch.jit.trace_module(
+            model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False)
+        traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])})
+        traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False)
+        res_1 = traced_model_1(**example_input_dict)
+        res_1_m = traced_model_1_m(**example_input_dict)
+        self.assertEqual(res_1, 3 * torch.ones(1))
+        self.assertEqual(res_1_m, 3 * torch.ones(1))
+        res_func = traced_func(**example_input_dict_func)
+        self.assertEqual(res_func, 2 * torch.ones(1))
+        with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."):
+            res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])})
+        with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."):
+            res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])})
+
 
 class TestScript(JitTestCase):

torch/_C/__init__.pyi.in

Lines changed: 9 additions & 0 deletions
@@ -332,6 +332,15 @@ def _create_function_from_trace(
     force_outplace: _bool,
     argument_names: List[str]
 ) -> Tuple[Graph, Stack]: ...
+def _create_function_from_trace_with_dict(
+    qualname: str,
+    func: Callable[..., Any],
+    input_dict: Dict[str, Any],
+    var_lookup_fn: Callable[[Tensor], str],
+    strict: _bool,
+    force_outplace: _bool,
+    argument_names: List[str]
+) -> Tuple[Graph, Stack]: ...
 def _jit_is_script_object(obj: Any) -> _bool: ...
 def _last_executed_optimized_graph() -> Graph: ...
 def parse_type_comment(comment: str) -> Decl: ...

torch/csrc/jit/python/pybind_utils.h

Lines changed: 11 additions & 0 deletions
@@ -565,6 +565,17 @@ inline Stack toTraceableStack(const py::tuple& inputs) {
   return info.toTupleRef().elements().vec();
 }
 
+// Serialize the python dictionary into a traceable stack.
+inline Stack toTraceableStack(const py::dict& inputs) {
+  Stack res;
+  for (auto it = inputs.begin(); it != inputs.end(); it++) {
+    if (THPVariable_Check(it->second.ptr())) {
+      res.push_back(toIValue(it->second, tryToInferType(it->second).type()));
+    }
+  }
+  return res;
+}
+
 inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
   auto elems = c10::impl::GenericList(elem_type);
   for (auto elem : obj) {

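For readers who don't want to parse the pybind code, here is a rough Python rendering of what the new `toTraceableStack(const py::dict&)` overload does; `to_traceable_stack` is a hypothetical helper, not part of the PR.

```python
import torch

# Rough Python equivalent of the new dict overload of toTraceableStack():
# iterate the example-input dict and keep only the tensor values for the
# tracing stack; non-tensor entries are simply skipped.
def to_traceable_stack(example_inputs: dict):
    return [v for v in example_inputs.values() if isinstance(v, torch.Tensor)]

# to_traceable_stack({'key1': torch.ones(1), 'flag': True}) -> [tensor([1.])]
```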
torch/csrc/jit/python/python_tracer.cpp

Lines changed: 63 additions & 0 deletions
@@ -73,6 +73,69 @@ SourceRange getPythonInterpreterSourceRange() {
   return SourceRange(source, 0, stack_trace_text.size());
 }
 
+std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracingWithDict(
+    const py::function& func,
+    const py::dict& inputs_dict,
+    Stack trace_inputs,
+    const py::function& var_name_lookup_fn,
+    bool strict,
+    bool force_outplace,
+    Module* self,
+    const std::vector<std::string>& argument_names) {
+  C10_LOG_API_USAGE_ONCE("torch.tracer");
+
+  auto lookup_fn_adapter =
+      [var_name_lookup_fn](const Variable& var) -> std::string {
+    pybind11::gil_scoped_acquire ag;
+    return py::cast<std::string>(var_name_lookup_fn(var));
+  };
+
+  // The argument_names parameter is parsed in python and its order
+  // is the same as the arguments' declaration order in the forward() method.
+  // These names shall be added to the graph as debug names, and the order
+  // should align with the traceable stack we generated from the python dict.
+  std::vector<std::string> compact_argument_names;
+  Stack compact_trace_inputs;
+  for (std::vector<std::string>::size_type i = 0; i < argument_names.size();
+       i++) {
+    if (inputs_dict.contains(argument_names[i])) {
+      compact_argument_names.push_back(argument_names[i]);
+    }
+  }
+  for (std::vector<std::string>::size_type i = 0;
+       i < compact_argument_names.size();
+       i++) {
+    for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) {
+      if (py::cast<std::string>(it->first) == compact_argument_names[i]) {
+        if (THPVariable_Check(it->second.ptr())) {
+          compact_trace_inputs.push_back(
+              toIValue(it->second, tryToInferType(it->second).type()));
+        }
+      }
+    }
+  }
+
+  auto outs = tracer::trace(
+      std::move(compact_trace_inputs),
+      [&](Stack inputs) -> Stack {
+        // We just leave the inputs_dict as it was and pass it to the forward
+        // method.
+        auto out = func(**inputs_dict);
+        if (out.ptr() == Py_None) {
+          AT_ERROR(
+              "The traced function didn't return any values! Side-effects are not "
+              "captured in traces, so it would be a no-op.");
+        }
+        return {toTypeInferredIValue(out)};
+      },
+      lookup_fn_adapter,
+      strict,
+      force_outplace,
+      self,
+      compact_argument_names);
+  return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs));
+}
+
 std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing(
     const py::function& func,
     Stack trace_inputs,

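The comment in `createGraphByTracingWithDict` above describes the key step: the forward() parameter names are filtered down to those present in the input dict, and the trace stack is rebuilt in that same declaration order. A rough Python rendering follows; `compact_inputs` is a hypothetical helper for illustration only.

```python
import torch

# Rough Python rendering of the compaction step in createGraphByTracingWithDict:
# keep only the forward() parameters that appear in the inputs dict (preserving
# declaration order) and collect their tensor values in that same order.
def compact_inputs(argument_names, inputs_dict):
    compact_names = [name for name in argument_names if name in inputs_dict]
    compact_stack = [
        inputs_dict[name]
        for name in compact_names
        if isinstance(inputs_dict[name], torch.Tensor)
    ]
    return compact_names, compact_stack

# forward(self, key1=None, key2=None, key3=None) traced with an out-of-order,
# partial dict: the debug names and the stack follow the declaration order.
names, stack = compact_inputs(
    ['key1', 'key2', 'key3'],
    {'key3': torch.ones(1), 'key1': torch.ones(1)},
)
# names == ['key1', 'key3'], stack == [tensor([1.]), tensor([1.])]
```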
torch/csrc/jit/python/python_tracer.h

Lines changed: 10 additions & 0 deletions
@@ -24,6 +24,16 @@ Node* preRecordPythonTrace(
     at::ArrayRef<autograd::Variable> inputs,
     std::vector<THPObjectPtr> scalar_args);
 
+std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracingWithDict(
+    const py::function& func,
+    const py::dict& inputs_dict,
+    Stack inputs,
+    const py::function& var_name_lookup_fn,
+    bool strict,
+    bool force_outplace,
+    Module* self = nullptr,
+    const std::vector<std::string>& argument_names = {});
+
 std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing(
     const py::function& func,
     Stack inputs,

torch/csrc/jit/python/script_init.cpp

Lines changed: 74 additions & 0 deletions
@@ -1218,6 +1218,43 @@ void initJitScriptBindings(PyObject* module) {
           py::arg("strict"),
           py::arg("force_outplace"),
           py::arg("argument_names") = std::vector<std::string>())
+      .def(
+          "_create_method_from_trace_with_dict",
+          [](Module& self,
+             const std::string& name,
+             const py::function& func,
+             const py::dict& input_dict,
+             const py::function& var_name_lookup_fn,
+             bool strict,
+             bool force_outplace,
+             const std::vector<std::string>& argument_names) {
+            // prereq: Module's buffers and parameters are unique
+            // this was ensured in python before calling this function
+            auto typed_inputs = toTraceableStack(input_dict);
+
+            std::shared_ptr<Graph> graph =
+                std::get<0>(tracer::createGraphByTracingWithDict(
+                    func,
+                    input_dict,
+                    typed_inputs,
+                    var_name_lookup_fn,
+                    strict,
+                    force_outplace,
+                    &self,
+                    argument_names));
+            const auto method_name = QualifiedName(*self.type()->name(), name);
+            auto fn = self._ivalue()->compilation_unit()->create_function(
+                method_name, graph);
+            self.type()->addMethod(fn);
+            didFinishEmitModule(self);
+          },
+          py::arg("name"),
+          py::arg("func"),
+          py::arg("input_dict"),
+          py::arg("var_name_lookup_fn"),
+          py::arg("strict"),
+          py::arg("force_outplace"),
+          py::arg("argument_names") = std::vector<std::string>())
       .def(
           "_get_forward_hooks",
           [](const Module& m) {
@@ -1668,6 +1705,43 @@ void initJitScriptBindings(PyObject* module) {
       py::arg("force_outplace"),
       py::arg("argument_names") = std::vector<std::string>());
 
+  m.def(
+      "_create_function_from_trace_with_dict",
+      [](const std::string& qualname,
+         const py::function& func,
+         const py::dict& input_dict,
+         const py::function& var_name_lookup_fn,
+         bool strict,
+         bool force_outplace,
+         const std::vector<std::string>& argument_names) {
+        auto typed_inputs = toTraceableStack(input_dict);
+        std::shared_ptr<Graph> graph =
+            std::get<0>(tracer::createGraphByTracingWithDict(
+                func,
+                input_dict,
+                typed_inputs,
+                var_name_lookup_fn,
+                strict,
+                force_outplace,
+                /*self=*/nullptr,
+                argument_names));
+
+        auto cu = get_python_cu();
+        auto name = c10::QualifiedName(qualname);
+        auto result = cu->create_function(
+            std::move(name), std::move(graph), /*shouldMangle=*/true);
+        StrongFunctionPtr ret(std::move(cu), result);
+        didFinishEmitFunction(ret);
+        return ret;
+      },
+      py::arg("name"),
+      py::arg("func"),
+      py::arg("input_dict"),
+      py::arg("var_name_lookup_fn"),
+      py::arg("strict"),
+      py::arg("force_outplace"),
+      py::arg("argument_names") = std::vector<std::string>());
+
   m.def(
       "_jit_script_class_compile",
       [](const std::string& qualifiedName,

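These two bindings back the module and free-function tracing paths, respectively. A usage sketch of the Python-level entry points follows, mirroring the new test; the `add`/`Add` names are illustrative, and the routing noted in the comments is inferred from the diff rather than stated in it.

```python
import torch

def add(x, y):
    return x + y

kwargs = {'x': torch.ones(1), 'y': torch.ones(1)}

# Free-function path, presumably routed to _create_function_from_trace_with_dict.
traced_fn = torch.jit.trace(add, example_kwarg_inputs=kwargs, strict=False)

# Module path, presumably routed to _create_method_from_trace_with_dict.
class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

traced_mod = torch.jit.trace_module(
    Add(), {'forward': kwargs}, example_inputs_is_kwarg=True, strict=False)

print(traced_fn(**kwargs), traced_mod(**kwargs))  # tensor([2.]) tensor([2.])
```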