Skip to content

Commit b14c320

Browse files
Meghan Lelefacebook-github-bot
Meghan Lele
authored andcommitted
[JIT] Add torch._C.ScriptDict (pytorch#52659)
Summary: Pull Request resolved: pytorch#52659 **Summary** This commit adds `torch._C.ScriptDict`, a dictionary type that has reference semantics across the Python/TorchScript boundary. That is, modifications made to instances of `torch._C.ScriptDict` in TorchScript are visible in Python even when it is not returned from the function. Instances can be constructed by passing an instance of a Python dictionary to `torch.jit.script`. In the case of an empty dictionary, its type is assumed to be `Dict[str, Tensor]` to be consistent with the handling of empty dictionaries in TorchScript source code. `torch._C.ScriptDict` is implemented using a modified version of pybind's `stl_bind.h`-style bindings attached to `ScriptDict`, `ScriptDictIterator` and `ScriptDictKeyIterator`, wrapper classes around `c10::impl::GenericDict` and `c10::impl::GenericDict::iterator`. These bindings allow instances of `torch._C.ScriptDict` to be used as if it were a regular `dict` Python. Reference semantics are achieved by simply retrieving the `IValue` contained in `ScriptDict` in `toIValue` (invoked when converting Python arguments to `IValues` before calling TorchScript code). **Test Plan** This commit adds `TestScriptDict` to `test_list_dict.py`, a set of tests that check that all of the common dictionary operations are supported and that instances have reference semantics across the Python/TorchScript boundary. Differential Revision: D27211605 D27211605 Test Plan: Imported from OSS Reviewed By: gmagogsfm Pulled By: SplitInfinity fbshipit-source-id: 446d4e5328375791aa73eb9e8b04dfe3465af960
1 parent 95b1bc1 commit b14c320

File tree

11 files changed

+612
-10
lines changed

11 files changed

+612
-10
lines changed

test/jit/test_builtins.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,15 @@ def del_list_multiple_operands(x: List[int]) -> List[int]:
122122

123123
py_out = del_list_multiple_operands([0, 1, 2])
124124
jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2])
125-
self.assertEquals(py_out, jit_out)
125+
self.assertEqual(py_out, jit_out)
126126

127127
def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]:
128128
del x['hi'], x['there']
129129
return x
130130

131131
py_out = del_dict_multiple_operands({"hi": 5, "there": 6})
132132
jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6})
133-
self.assertEquals(py_out, jit_out)
133+
self.assertEqual(py_out, jit_out)
134134

135135

136136
class TestTensorBuiltins(JitTestCase):

test/jit/test_list_dict.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1959,3 +1959,226 @@ def forward(self):
19591959

19601960
for name in ['a', 'b', 'c']:
19611961
self.assertEqual(getattr(out_loaded, name), getattr(out, name))
1962+
1963+
class TestScriptDict(JitTestCase):
1964+
"""
1965+
This class contains a suite of tests for torch.jit.script, a
1966+
function that returns a dictionary-like object that has reference
1967+
semantics across the Python/TorchScript boundary. That is,
1968+
it can be passed to a TorchScript function that mutates it
1969+
and those modifications are visible in the scope of the Python
1970+
caller of said TorchScript function.
1971+
1972+
The vast majority of tests are for making sure that objects returned
1973+
by torch.jit.script behave like dictionaries do so that they are fungible
1974+
in almost all cirumstances with regular dictionaries.
1975+
"""
1976+
def _script_dict_add(self, d: torch._C.ScriptDict, k: int, v: int):
1977+
"""
1978+
This is a helper function that inserts the pair (k, v) into the
1979+
dictionary d in TorchScript. It is used for testing reference
1980+
semantics.
1981+
"""
1982+
@torch.jit.script
1983+
def dict_add(d: Dict[int, int], k: int, v: int):
1984+
d[k] = v
1985+
1986+
dict_add(d, k, v)
1987+
1988+
def _compare_eager_and_script(self, fn, input_dict, script_input_dict=None):
1989+
"""
1990+
This is a helper function that facilitates comparing behaviour between
1991+
Python dictionaries and "scripted" dictionaries.
1992+
1993+
Args:
1994+
fn: The function to test and compare the behaviour of.
1995+
input_dict: The input dictionary to use for the test (passed to fn).
1996+
script_input_dict: The scripted input dictionary to use for the tests.
1997+
If None, input_dict is scripted with torch.jit.script
1998+
and used instead.
1999+
"""
2000+
# Create ScriptDict version of input_dict if needed.
2001+
script_input_dict = script_input_dict or torch.jit.script(input_dict)
2002+
2003+
# Run fn with both input_dict and scripted_dict.
2004+
eager_raised, script_raised = False, False
2005+
2006+
try:
2007+
eager_out = fn(input_dict)
2008+
except Exception as e:
2009+
eager_exception = e
2010+
eager_raised = True
2011+
2012+
try:
2013+
script_out = fn(script_input_dict)
2014+
except Exception as e:
2015+
script_exception = e
2016+
script_raised = True
2017+
2018+
# Check that both calls raised or none of them raised.
2019+
self.assertEqual(eager_raised, script_raised)
2020+
2021+
if eager_raised:
2022+
# If fn raised an exception, it should be the same between
2023+
# regular and scripted dictionaries.
2024+
self.assertEqual(type(eager_exception), type(script_exception))
2025+
else:
2026+
# Otherwise, make sure the outputs match and the dictionaries
2027+
# match (the latter may not be the same as the output).
2028+
self.assertEqual(eager_out, script_out)
2029+
self.assertEqual(input_dict, script_input_dict)
2030+
2031+
def test_repr(self):
2032+
"""
2033+
Test the __repr__ method.
2034+
"""
2035+
self._compare_eager_and_script(lambda d: repr(d), {1: 2})
2036+
2037+
def test_bool(self):
2038+
"""
2039+
Test the __bool__ method. This should return True
2040+
if the dictionary is non-empty and False otherwise.
2041+
"""
2042+
self._compare_eager_and_script(lambda d: bool(d), {1: 2})
2043+
self._compare_eager_and_script(lambda d: bool(d), {})
2044+
2045+
def test_iter(self):
2046+
"""
2047+
Test iteration over a dictionary's keys.
2048+
"""
2049+
def sum_keys(input_dict):
2050+
s = 0
2051+
for k in input_dict:
2052+
s += k
2053+
2054+
return s
2055+
2056+
self._compare_eager_and_script(sum_keys, {1: 2, 3: 4})
2057+
2058+
def test_items(self):
2059+
"""
2060+
Test .items().
2061+
"""
2062+
def sum_pair_product(input_dict):
2063+
s = 0
2064+
for k, v in input_dict.items():
2065+
s += k * v
2066+
2067+
return s
2068+
2069+
self._compare_eager_and_script(sum_pair_product, {1: 2, 3: 4})
2070+
2071+
def test_getitem(self):
2072+
"""
2073+
Test accessing dictionary values using the [] operator.
2074+
"""
2075+
data = {1: 2, 3: 4}
2076+
self._compare_eager_and_script(lambda d: d[1], data)
2077+
self._compare_eager_and_script(lambda d: d[4], data)
2078+
self._compare_eager_and_script(lambda d: d[2], data)
2079+
self._compare_eager_and_script(lambda d: d["key"], data)
2080+
2081+
def test_setitem(self):
2082+
"""
2083+
Test setting dictionary values using the [] operator.
2084+
"""
2085+
data = {1: 2, 3: 4}
2086+
2087+
def fn(input_dict):
2088+
input_dict[1] = 10
2089+
input_dict[3] = 11
2090+
2091+
self._compare_eager_and_script(fn, data)
2092+
2093+
# Check that using improperly typed keys and values
2094+
# throws TypeError.
2095+
# _compare_eager_and_script cannot be used here since
2096+
# the following uses of __setitem__ are valid in
2097+
# Python.
2098+
script_data = torch.jit.script(data)
2099+
2100+
with self.assertRaises(TypeError):
2101+
script_data["str"] = 3
2102+
2103+
with self.assertRaises(TypeError):
2104+
script_data[3] = "str"
2105+
2106+
def test_contains(self):
2107+
"""
2108+
Test membership checks (x in y, x not in y).
2109+
"""
2110+
data = {1: 2, 3: 4}
2111+
2112+
def fn(input_dict):
2113+
return 1 in input_dict, 2 not in input_dict, 3 in input_dict, 4 not in input_dict
2114+
2115+
self._compare_eager_and_script(fn, data)
2116+
2117+
# Check that using an improperly typed key
2118+
# throws KeyError.
2119+
script_data = torch.jit.script(data)
2120+
2121+
with self.assertRaises(KeyError):
2122+
a = "str" in script_data
2123+
2124+
def test_delitem(self):
2125+
"""
2126+
Test deletion.
2127+
"""
2128+
data = {1: 2, 3: 4}
2129+
2130+
def del_fn(input_dict):
2131+
del input_dict[1]
2132+
2133+
def del_fn_raises(input_dict):
2134+
del input_dict[10]
2135+
2136+
self._compare_eager_and_script(del_fn, data)
2137+
self._compare_eager_and_script(del_fn_raises, data)
2138+
2139+
# Check that using an improperly typed key
2140+
# throws TypeError.
2141+
script_data = torch.jit.script(data)
2142+
2143+
with self.assertRaises(TypeError):
2144+
del script_data["str"]
2145+
2146+
def test_len(self):
2147+
"""
2148+
Test len() builtin function.
2149+
"""
2150+
self._compare_eager_and_script(lambda d: len(d), {1: 2})
2151+
self._compare_eager_and_script(lambda d: len(d), {})
2152+
2153+
@unittest.skip("Cannot pass until all dicts returned from TorchScript are ScriptDicts")
2154+
def test_nested(self):
2155+
"""
2156+
Test that reference semantics are honoured when the ScriptDict that is
2157+
mutated using TorchScript is inside another.
2158+
"""
2159+
nested = torch.jit.script({1: {1: 2}, 2: {3: 4}}, type_hint=Dict[int, Dict[int, int]])
2160+
2161+
one = nested[1]
2162+
two = nested[2]
2163+
2164+
self._script_dict_add(one, 9, 10)
2165+
self._script_dict_add(two, 11, 12)
2166+
2167+
# The mutation should be visible in the original dictionary, nested.
2168+
self.assertEqual(len(one), 2)
2169+
self.assertEqual(len(two), 2)
2170+
self.assertEqual(len(nested[1]), 2)
2171+
self.assertEqual(len(nested[2]), 2)
2172+
2173+
def test_reference_semantics(self):
2174+
"""
2175+
Test that reference semantics are honoured; that modifications made
2176+
to a ScriptDict in TorchScript are visible in Python.
2177+
"""
2178+
data = torch.jit.script({1: 2})
2179+
self._script_dict_add(data, 3, 4)
2180+
2181+
# The mutation should be visible in the original dictionary.
2182+
self.assertEqual(len(data), 2)
2183+
self.assertTrue(3 in data)
2184+
self.assertEqual(data[3], 4)

test/test_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from jit.test_type_sharing import TestTypeSharing # noqa: F401
1010
from jit.test_logging import TestLogging # noqa: F401
1111
from jit.test_backends import TestBackends # noqa: F401
12-
from jit.test_list_dict import TestList, TestDict, TestNamedTuple # noqa: F401
12+
from jit.test_list_dict import TestList, TestDict, TestNamedTuple, TestScriptDict # noqa: F401
1313
from jit.test_async import TestAsync # noqa: F401
1414
from jit.test_data_parallel import TestDataParallel # noqa: F401
1515
from jit.test_models import TestModels # noqa: F401

tools/build_variables.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ libtorch_python_core_sources = [
621621
"torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp",
622622
"torch/csrc/jit/python/python_arg_flatten.cpp",
623623
"torch/csrc/jit/python/python_custom_class.cpp",
624+
"torch/csrc/jit/python/python_dict.cpp",
624625
"torch/csrc/jit/python/python_interpreter.cpp",
625626
"torch/csrc/jit/python/python_ir.cpp",
626627
"torch/csrc/jit/python/python_tracer.cpp",

torch/csrc/jit/python/pybind_utils.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include <torch/csrc/jit/python/pybind_utils.h>
2-
2+
#include <torch/csrc/jit/python/python_dict.h>
33
#include <torch/csrc/jit/python/python_ivalue.h>
44

55
namespace torch {
@@ -136,6 +136,23 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
136136
}
137137
case TypeKind::DictType: {
138138
const auto& dict_type = type->expect<DictType>();
139+
140+
// If the object is a ScriptDict, retrieve the c10::Dict
141+
// instance inside it.
142+
try {
143+
auto script_dict = py::cast<ScriptDict>(obj);
144+
return script_dict.dict_;
145+
} catch (py::cast_error& e) {
146+
}
147+
148+
// If not (i.e. it is a regular Python dictionary), make a new
149+
// c10::Dict.
150+
151+
TORCH_WARN(
152+
"Script your dictionary using torch.jit.script in order to get reference semantics and reduced copy overhead between Python and TorchScript");
153+
154+
// If not (i.e. it is a regular Python dictionary), make a new
155+
// c10::Dict.
139156
return createGenericDict(
140157
py::cast<py::dict>(obj),
141158
dict_type->getKeyType(),

torch/csrc/jit/python/pybind_utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <torch/csrc/jit/frontend/tracer.h>
1919
#include <torch/csrc/jit/python/module_python.h>
2020
#include <torch/csrc/jit/python/python_custom_class.h>
21+
#include <torch/csrc/jit/python/python_dict.h>
2122
#include <torch/csrc/jit/python/python_tracer.h>
2223
#include <torch/csrc/jit/resource_guard.h>
2324
#include <torch/csrc/jit/runtime/operator.h>

0 commit comments

Comments
 (0)