Skip to content

Commit f3393b7

Browse files
larryliu0820pytorchmergebot
authored andcommitted
[torchgen] Introduce Executorch types and signatures (pytorch#90781)
Retry of pytorch#90591, which is a retry of pytorch#89595. Reverted due to dependency PR breaking internal fbcode. ## Forked BaseCppType Created a module for Executorch: `torchgen.executorch`. ## In `torchgen.executorch.api.types.types`: * Define `BaseCppType` with `torch::executor` namespace. ## In `torchgen.executorch.api.et_cpp`: * Help generate `NamedCType` for `ExecutorchCppSignature` arguments. ## In `torchgen.executorch.api.types.signatures`: * Define the signature using these types. (`ExecutorchCppSignature`) ## In `torchgen.executorch.api.types.__init__`: * Suppress flake8 error for `import *`. Pull Request resolved: pytorch#90781 Approved by: https://github.com/ezyang
1 parent 4adffe6 commit f3393b7

File tree

8 files changed

+608
-1
lines changed

8 files changed

+608
-1
lines changed

.flake8

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@ ignore =
1212
B007,B008,
1313
# these ignores are from flake8-comprehensions; please fix!
1414
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
15-
per-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950 torchgen/api/types/__init__.py: F401,F403
15+
per-file-ignores =
16+
__init__.py: F401
17+
torch/utils/cpp_extension.py: B950
18+
torchgen/api/types/__init__.py: F401,F403
19+
torchgen/executorch/api/types/__init__.py: F401,F403
1620
optional-ascii-coding = True
1721
exclude =
1822
./.git,

tools/test/test_executorch_types.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import unittest
2+
3+
from torchgen import local
4+
from torchgen.api.types import (
5+
BaseCType,
6+
ConstRefCType,
7+
CType,
8+
longT,
9+
MutRefCType,
10+
NamedCType,
11+
OptionalCType,
12+
TupleCType,
13+
VectorCType,
14+
voidT,
15+
)
16+
from torchgen.executorch.api.et_cpp import argument_type, return_type, returns_type
17+
from torchgen.executorch.api.types import ArrayRefCType, scalarT, tensorListT, tensorT
18+
from torchgen.model import Argument, FunctionSchema, Return
19+
20+
21+
class ExecutorchCppTest(unittest.TestCase):
22+
"""
23+
Test torchgen.executorch.api.cpp
24+
"""
25+
26+
def _test_argumenttype_type(self, arg_str: str, expected: NamedCType) -> None:
27+
arg = Argument.parse(arg_str)
28+
self.assertEqual(str(argument_type(arg, binds=arg.name)), str(expected))
29+
30+
@local.parametrize(
31+
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
32+
)
33+
def test_argumenttype_type(self) -> None:
34+
data = [
35+
("Tensor self", NamedCType("self", ConstRefCType(BaseCType(tensorT)))),
36+
("Tensor(a!) out", NamedCType("out", MutRefCType(BaseCType(tensorT)))),
37+
(
38+
"Tensor? opt",
39+
NamedCType("opt", ConstRefCType(OptionalCType(BaseCType(tensorT)))),
40+
),
41+
("Scalar scalar", NamedCType("scalar", ConstRefCType(BaseCType(scalarT)))),
42+
(
43+
"Scalar? scalar",
44+
NamedCType("scalar", ConstRefCType(OptionalCType(BaseCType(scalarT)))),
45+
),
46+
("int[] size", NamedCType("size", ArrayRefCType(BaseCType(longT)))),
47+
("int? dim", NamedCType("dim", OptionalCType(BaseCType(longT)))),
48+
("Tensor[] weight", NamedCType("weight", BaseCType(tensorListT))),
49+
(
50+
"Scalar[] spacing",
51+
NamedCType("spacing", ArrayRefCType(ConstRefCType(BaseCType(scalarT)))),
52+
),
53+
(
54+
"Tensor?[] weight",
55+
NamedCType("weight", ArrayRefCType(OptionalCType(BaseCType(tensorT)))),
56+
),
57+
(
58+
"SymInt[]? output_size",
59+
NamedCType(
60+
"output_size", OptionalCType(ArrayRefCType(BaseCType(longT)))
61+
),
62+
),
63+
(
64+
"int[]? dims",
65+
NamedCType("dims", OptionalCType(ArrayRefCType(BaseCType(longT)))),
66+
),
67+
]
68+
for d in data:
69+
self._test_argumenttype_type(*d)
70+
71+
def _test_returntype_type(self, ret_str: str, expected: CType) -> None:
72+
ret = Return.parse(ret_str)
73+
self.assertEqual(str(return_type(ret)), str(expected))
74+
75+
@local.parametrize(
76+
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
77+
)
78+
def test_returntype_type(self) -> None:
79+
data = [
80+
("Tensor", BaseCType(tensorT)),
81+
("Tensor(a!)", MutRefCType(BaseCType(tensorT))),
82+
("Tensor[]", VectorCType(BaseCType(tensorT))),
83+
]
84+
for d in data:
85+
self._test_returntype_type(*d)
86+
87+
@local.parametrize(
88+
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
89+
)
90+
def test_returns_type(self) -> None:
91+
func = FunctionSchema.parse(
92+
"min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"
93+
)
94+
expected = TupleCType([BaseCType(tensorT), BaseCType(tensorT)])
95+
self.assertEqual(str(returns_type(func.returns)), str(expected))
96+
97+
@local.parametrize(
98+
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
99+
)
100+
def test_void_return_type(self) -> None:
101+
func = FunctionSchema.parse(
102+
"_foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()"
103+
)
104+
expected = BaseCType(voidT)
105+
self.assertEqual(str(returns_type(func.returns)), str(expected))
106+
107+
108+
if __name__ == "__main__":
109+
unittest.main()

torchgen/executorch/__init__.py

Whitespace-only changes.

torchgen/executorch/api/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)