Skip to content

Commit 52d1ffb

Browse files
zou3519facebook-github-bot
authored andcommitted
Teach pytrees about namedtuple (pytorch#62292)
Summary: Pull Request resolved: pytorch#62292 This PR adds pytree support for namedtuples. The challenge about namedtuple is that each namedtuple class is actually different. This PR does the following: - it adds a namedtuple flatten/unflatten. The flatten function returns a context that is the actual type of the namedtuple subclass. The unflatten function uses that type to reconstruct the namedtuple - Special cases all pytree logic to consider all namedtuples the same. This is done by creating a `_get_node_type(pytree)` helper function that returns `namedtuple` if `pytree` is any namedtuple subclass. The effect of this is that all namedtuple subclasses will go through the namedtuple flatten/unflatten functions - Adds a `_namedtuple_flatten_spec` function for FX pytrees. This function flattens the namedtuple based on the spec and is equivalent to the `_tuple_flatten_spec`. Test Plan - new tests in test/test_pytree.py and test/test_fx.py Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D29947302 Pulled By: zou3519 fbshipit-source-id: 19c00665b13546642c315df0f243ad99b8e7ff7c
1 parent c06b6e4 commit 52d1ffb

File tree

4 files changed

+75
-6
lines changed

4 files changed

+75
-6
lines changed

test/test_fx.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from torch.fx.experimental.rewriter import RewritingTracer
2929
from torch.fx.operator_schemas import get_signature_for_torch_op
3030
from copy import deepcopy
31+
from collections import namedtuple
3132

3233
from torch.fx.proxy import TraceError
3334

@@ -65,6 +66,11 @@ def forward(self, x):
6566
def a_non_torch_leaf(a, b):
6667
return a + b
6768

69+
# used in test_pytree. It's all the way out here because pickling a GraphModule
70+
# that uses Point errors out if Point is local to the function
71+
Point = namedtuple('Point', ['x', 'y'])
72+
73+
6874
# Test wrap() passing both a function name as well as a function
6975
# directly
7076
def a_lifted_leaf(a, b):
@@ -2610,6 +2616,8 @@ def f_dict_list_map(x):
26102616
def f_dict_add(x):
26112617
return x['a'] + sum(x['z'])
26122618

2619+
def f_namedtuple_add(x):
2620+
return x.x + x.y
26132621

26142622
pytree._register_pytree_node(
26152623
Foo,
@@ -2639,6 +2647,7 @@ def f_return_custom(x):
26392647
(f_custom, Foo(PH, 3)),
26402648
(f_custom_dict, Foo({'a': PH, 'b': PH}, PH)),
26412649
# (f_return_custom, Foo(PH, PH)), # Don't currently support output pytrees
2650+
(f_namedtuple_add, Point(PH, PH)),
26422651
]
26432652

26442653
def verify_pytree(f, inp):

test/test_pytree.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from torch.testing._internal.common_utils import TestCase, run_tests
33
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten, TreeSpec, LeafSpec
44
from torch.utils._pytree import _broadcast_to_and_flatten
5+
from collections import namedtuple
56

67
class TestPytree(TestCase):
78
def test_treespec_equality(self):
@@ -59,6 +60,33 @@ def run_test(tup):
5960
run_test((1., 2))
6061
run_test((torch.tensor([1., 2]), 2, 10, 9, 11))
6162

63+
def test_flatten_unflatten_namedtuple(self):
64+
Point = namedtuple('Point', ['x', 'y'])
65+
66+
def run_test(tup):
67+
expected_spec = TreeSpec(namedtuple, Point, [LeafSpec() for _ in tup])
68+
values, treespec = tree_flatten(tup)
69+
self.assertTrue(isinstance(values, list))
70+
self.assertEqual(values, list(tup))
71+
self.assertEqual(treespec, expected_spec)
72+
73+
unflattened = tree_unflatten(values, treespec)
74+
self.assertEqual(unflattened, tup)
75+
self.assertTrue(isinstance(unflattened, Point))
76+
77+
run_test(Point(1., 2))
78+
run_test(Point(torch.tensor(1.), 2))
79+
80+
def test_flatten_unflatten_torch_namedtuple_return_type(self):
81+
x = torch.randn(3, 3)
82+
expected = torch.max(x, dim=0)
83+
84+
values, spec = tree_flatten(expected)
85+
result = tree_unflatten(values, spec)
86+
87+
self.assertEqual(type(result), type(expected))
88+
self.assertEqual(result, expected)
89+
6290
def test_flatten_unflatten_dict(self):
6391
def run_test(tup):
6492
expected_spec = TreeSpec(dict, list(tup.keys()),

torch/fx/_pytree.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Callable, Any, Tuple, List, Dict, Type
1+
from typing import Callable, Any, Tuple, List, Dict, Type, NamedTuple
22
from torch.utils._pytree import PyTree, TreeSpec, LeafSpec
3+
from collections import namedtuple
34

45
FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]
56

@@ -32,6 +33,10 @@ def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]:
3233
def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]:
3334
return [d[i] for i in range(len(spec.children_specs))]
3435

36+
def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]:
37+
return [d[i] for i in range(len(spec.children_specs))]
38+
3539
register_pytree_flatten_spec(dict, _dict_flatten_spec)
3640
register_pytree_flatten_spec(list, _list_flatten_spec)
3741
register_pytree_flatten_spec(tuple, _tuple_flatten_spec)
42+
register_pytree_flatten_spec(namedtuple, _tuple_flatten_spec)

torch/utils/_pytree.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional
2+
from collections import namedtuple
23

34
"""
45
Contains utility functions for working with nested python data structures.
@@ -56,14 +57,38 @@ def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
5657
def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]:
5758
return tuple(values)
5859

60+
def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
61+
return list(d), type(d)
62+
63+
def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple:
64+
return cast(NamedTuple, context(*values))
65+
5966
_register_pytree_node(dict, _dict_flatten, _dict_unflatten)
6067
_register_pytree_node(list, _list_flatten, _list_unflatten)
6168
_register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten)
69+
_register_pytree_node(namedtuple, _namedtuple_flatten, _namedtuple_unflatten)
70+
6271

72+
# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
73+
def _is_namedtuple_instance(pytree: Any) -> bool:
74+
typ = type(pytree)
75+
bases = typ.__bases__
76+
if len(bases) != 1 or bases[0] != tuple:
77+
return False
78+
fields = getattr(typ, '_fields', None)
79+
if not isinstance(fields, tuple):
80+
return False
81+
return all(type(entry) == str for entry in fields)
82+
83+
def _get_node_type(pytree: Any) -> Any:
84+
if _is_namedtuple_instance(pytree):
85+
return namedtuple
86+
return type(pytree)
6387

6488
# A leaf is defined as anything that is not a Node.
6589
def _is_leaf(pytree: PyTree) -> bool:
66-
return type(pytree) not in SUPPORTED_NODES.keys()
90+
return _get_node_type(pytree) not in SUPPORTED_NODES.keys()
91+
6792

6893
# A TreeSpec represents the structure of a pytree. It holds:
6994
# "type": the type of root Node of the pytree
@@ -105,7 +130,8 @@ def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
105130
if _is_leaf(pytree):
106131
return [pytree], LeafSpec()
107132

108-
flatten_fn = SUPPORTED_NODES[type(pytree)].flatten_fn
133+
node_type = _get_node_type(pytree)
134+
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
109135
child_pytrees, context = flatten_fn(pytree)
110136

111137
# Recursively flatten the children
@@ -116,7 +142,7 @@ def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
116142
result += flat
117143
children_specs.append(child_spec)
118144

119-
return result, TreeSpec(type(pytree), context, children_specs)
145+
return result, TreeSpec(node_type, context, children_specs)
120146

121147

122148
def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
@@ -167,10 +193,11 @@ def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[A
167193
return [pytree] * spec.num_leaves
168194
if isinstance(spec, LeafSpec):
169195
return None
170-
if type(pytree) != spec.type:
196+
node_type = _get_node_type(pytree)
197+
if node_type != spec.type:
171198
return None
172199

173-
flatten_fn = SUPPORTED_NODES[type(pytree)].flatten_fn
200+
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
174201
child_pytrees, ctx = flatten_fn(pytree)
175202

176203
# Check if the Node is different from the spec

0 commit comments

Comments
 (0)