Skip to content

Commit c8c5e11

Browse files
houseroadfacebook-github-bot
authored andcommitted
Support variadic returns in Schema's operator<< (pytorch#23204)
Summary: old: prim::PythonOp(...) -> new: prim::PythonOp(...) -> ... Pull Request resolved: pytorch#23204 ghstack-source-id: 87208343 Reviewed By: zrphercule Differential Revision: D16433592 fbshipit-source-id: 36cbb329188f112e09c3b1708a8090781b830dfe
1 parent 34f5356 commit c8c5e11

13 files changed

+90
-84
lines changed

aten/src/ATen/core/function_schema_inl.h

+14-8
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,22 @@ inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema)
2929
}
3030

3131
out << ") -> ";
32-
if (schema.returns().size() == 1) {
33-
out << schema.returns().at(0).type()->str();
34-
} else if (schema.returns().size() > 1) {
35-
out << "(";
36-
for (size_t i = 0; i < schema.returns().size(); ++i) {
37-
if (i > 0) out << ", ";
38-
out << schema.returns()[i].type()->str();
32+
33+
const auto& returns = schema.returns();
34+
out << "(";
35+
for(size_t i = 0; i < returns.size(); ++i) {
36+
if (i > 0) {
37+
out << ", ";
3938
}
40-
out << ")";
39+
out << returns.at(i);
40+
}
41+
if (schema.is_varret()) {
42+
if (returns.size() != 0) {
43+
out << ", ";
44+
}
45+
out << "...";
4146
}
47+
out << ")";
4248
return out;
4349
}
4450

Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
foo(Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
2-
foo(Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
3-
foo(str x, (Tensor, Tensor) y) -> (str, str)
4-
foo(int x, (Tensor, Tensor) y) -> (int, int)
5-
foo(bool x, (Tensor, Tensor) y) -> (bool, bool)
6-
foo(float[3] x, (Tensor, Tensor) y) -> (float[], float[])
7-
foo(int[2] x, (Tensor, Tensor) y) -> (int[], int[])
8-
foo(int[] x, (Tensor, Tensor) y) -> (int[], int[])
9-
foo(int? x, (Tensor, Tensor) y) -> (int?, int?)
1+
foo(Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
2+
foo(Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
3+
foo(str x, (Tensor, Tensor) y) -> ((str, str))
4+
foo(int x, (Tensor, Tensor) y) -> ((int, int))
5+
foo(bool x, (Tensor, Tensor) y) -> ((bool, bool))
6+
foo(float[3] x, (Tensor, Tensor) y) -> ((float[], float[]))
7+
foo(int[2] x, (Tensor, Tensor) y) -> ((int[], int[]))
8+
foo(int[] x, (Tensor, Tensor) y) -> ((int[], int[]))
9+
foo(int? x, (Tensor, Tensor) y) -> ((int?, int?))
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
foo(ClassType<FooModule> self, Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
2-
foo(ClassType<FooModule> self, Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
3-
foo(ClassType<FooModule> self, str x, (Tensor, Tensor) y) -> (str, str)
4-
foo(ClassType<FooModule> self, int x, (Tensor, Tensor) y) -> (int, int)
5-
foo(ClassType<FooModule> self, bool x, (Tensor, Tensor) y) -> (bool, bool)
6-
foo(ClassType<FooModule> self, float[3] x, (Tensor, Tensor) y) -> (float[], float[])
7-
foo(ClassType<FooModule> self, int[2] x, (Tensor, Tensor) y) -> (int[], int[])
8-
foo(ClassType<FooModule> self, int[] x, (Tensor, Tensor) y) -> (int[], int[])
9-
foo(ClassType<FooModule> self, int? x, (Tensor, Tensor) y) -> (int?, int?)
1+
foo(ClassType<FooModule> self, Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
2+
foo(ClassType<FooModule> self, Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
3+
foo(ClassType<FooModule> self, str x, (Tensor, Tensor) y) -> ((str, str))
4+
foo(ClassType<FooModule> self, int x, (Tensor, Tensor) y) -> ((int, int))
5+
foo(ClassType<FooModule> self, bool x, (Tensor, Tensor) y) -> ((bool, bool))
6+
foo(ClassType<FooModule> self, float[3] x, (Tensor, Tensor) y) -> ((float[], float[]))
7+
foo(ClassType<FooModule> self, int[2] x, (Tensor, Tensor) y) -> ((int[], int[]))
8+
foo(ClassType<FooModule> self, int[] x, (Tensor, Tensor) y) -> ((int[], int[]))
9+
foo(ClassType<FooModule> self, int? x, (Tensor, Tensor) y) -> ((int?, int?))
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
foo(Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
2-
foo(Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
3-
foo(str x, (Tensor, Tensor) y) -> (str, str)
4-
foo(int x, (Tensor, Tensor) y) -> (int, int)
5-
foo(bool x, (Tensor, Tensor) y) -> (bool, bool)
6-
foo(float[3] x, (Tensor, Tensor) y) -> (float[], float[])
7-
foo(int[2] x, (Tensor, Tensor) y) -> (int[], int[])
8-
foo(int[] x, (Tensor, Tensor) y) -> (int[], int[])
9-
foo(int? x, (Tensor, Tensor) y) -> (int?, int?)
1+
foo(Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
2+
foo(Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
3+
foo(str x, (Tensor, Tensor) y) -> ((str, str))
4+
foo(int x, (Tensor, Tensor) y) -> ((int, int))
5+
foo(bool x, (Tensor, Tensor) y) -> ((bool, bool))
6+
foo(float[3] x, (Tensor, Tensor) y) -> ((float[], float[]))
7+
foo(int[2] x, (Tensor, Tensor) y) -> ((int[], int[]))
8+
foo(int[] x, (Tensor, Tensor) y) -> ((int[], int[]))
9+
foo(int? x, (Tensor, Tensor) y) -> ((int?, int?))
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
foo(ClassType<FooModule> self, Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
2-
foo(ClassType<FooModule> self, Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
3-
foo(ClassType<FooModule> self, str x, (Tensor, Tensor) y) -> (str, str)
4-
foo(ClassType<FooModule> self, int x, (Tensor, Tensor) y) -> (int, int)
5-
foo(ClassType<FooModule> self, bool x, (Tensor, Tensor) y) -> (bool, bool)
6-
foo(ClassType<FooModule> self, float[3] x, (Tensor, Tensor) y) -> (float[], float[])
7-
foo(ClassType<FooModule> self, int[2] x, (Tensor, Tensor) y) -> (int[], int[])
8-
foo(ClassType<FooModule> self, int[] x, (Tensor, Tensor) y) -> (int[], int[])
9-
foo(ClassType<FooModule> self, int? x, (Tensor, Tensor) y) -> (int?, int?)
1+
foo(ClassType<FooModule> self, Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
2+
foo(ClassType<FooModule> self, Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
3+
foo(ClassType<FooModule> self, str x, (Tensor, Tensor) y) -> ((str, str))
4+
foo(ClassType<FooModule> self, int x, (Tensor, Tensor) y) -> ((int, int))
5+
foo(ClassType<FooModule> self, bool x, (Tensor, Tensor) y) -> ((bool, bool))
6+
foo(ClassType<FooModule> self, float[3] x, (Tensor, Tensor) y) -> ((float[], float[]))
7+
foo(ClassType<FooModule> self, int[2] x, (Tensor, Tensor) y) -> ((int[], int[]))
8+
foo(ClassType<FooModule> self, int[] x, (Tensor, Tensor) y) -> ((int[], int[]))
9+
foo(ClassType<FooModule> self, int? x, (Tensor, Tensor) y) -> ((int?, int?))
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
foo(Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
2-
foo(Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
3-
foo(str x, (Tensor, Tensor) y) -> (str, str)
4-
foo(int x, (Tensor, Tensor) y) -> (int, int)
5-
foo(bool x, (Tensor, Tensor) y) -> (bool, bool)
6-
foo(float[3] x, (Tensor, Tensor) y) -> (float[], float[])
7-
foo(int[2] x, (Tensor, Tensor) y) -> (int[], int[])
8-
foo(int[] x, (Tensor, Tensor) y) -> (int[], int[])
9-
foo(int? x, (Tensor, Tensor) y) -> (int?, int?)
1+
foo(Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
2+
foo(Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
3+
foo(str x, (Tensor, Tensor) y) -> ((str, str))
4+
foo(int x, (Tensor, Tensor) y) -> ((int, int))
5+
foo(bool x, (Tensor, Tensor) y) -> ((bool, bool))
6+
foo(float[3] x, (Tensor, Tensor) y) -> ((float[], float[]))
7+
foo(int[2] x, (Tensor, Tensor) y) -> ((int[], int[]))
8+
foo(int[] x, (Tensor, Tensor) y) -> ((int[], int[]))
9+
foo(int? x, (Tensor, Tensor) y) -> ((int?, int?))
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
foo(ClassType<TestModule> self, Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
2-
foo(ClassType<TestModule> self, Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
3-
foo(ClassType<TestModule> self, str x, (Tensor, Tensor) y) -> (str, str)
4-
foo(ClassType<TestModule> self, int x, (Tensor, Tensor) y) -> (int, int)
5-
foo(ClassType<TestModule> self, bool x, (Tensor, Tensor) y) -> (bool, bool)
6-
foo(ClassType<TestModule> self, float[3] x, (Tensor, Tensor) y) -> (float[], float[])
7-
foo(ClassType<TestModule> self, int[2] x, (Tensor, Tensor) y) -> (int[], int[])
8-
foo(ClassType<TestModule> self, int[] x, (Tensor, Tensor) y) -> (int[], int[])
9-
foo(ClassType<TestModule> self, int? x, (Tensor, Tensor) y) -> (int?, int?)
1+
foo(ClassType<TestModule> self, Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
2+
foo(ClassType<TestModule> self, Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
3+
foo(ClassType<TestModule> self, str x, (Tensor, Tensor) y) -> ((str, str))
4+
foo(ClassType<TestModule> self, int x, (Tensor, Tensor) y) -> ((int, int))
5+
foo(ClassType<TestModule> self, bool x, (Tensor, Tensor) y) -> ((bool, bool))
6+
foo(ClassType<TestModule> self, float[3] x, (Tensor, Tensor) y) -> ((float[], float[]))
7+
foo(ClassType<TestModule> self, int[2] x, (Tensor, Tensor) y) -> ((int[], int[]))
8+
foo(ClassType<TestModule> self, int[] x, (Tensor, Tensor) y) -> ((int[], int[]))
9+
foo(ClassType<TestModule> self, int? x, (Tensor, Tensor) y) -> ((int?, int?))
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
foo(Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
2-
foo(Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
3-
foo(str x, (Tensor, Tensor) y) -> (str, str)
4-
foo(int x, (Tensor, Tensor) y) -> (int, int)
5-
foo(bool x, (Tensor, Tensor) y) -> (bool, bool)
6-
foo(float[3] x, (Tensor, Tensor) y) -> (float[], float[])
7-
foo(int[2] x, (Tensor, Tensor) y) -> (int[], int[])
8-
foo(int[] x, (Tensor, Tensor) y) -> (int[], int[])
9-
foo(int? x, (Tensor, Tensor) y) -> (int?, int?)
1+
foo(Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
2+
foo(Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
3+
foo(str x, (Tensor, Tensor) y) -> ((str, str))
4+
foo(int x, (Tensor, Tensor) y) -> ((int, int))
5+
foo(bool x, (Tensor, Tensor) y) -> ((bool, bool))
6+
foo(float[3] x, (Tensor, Tensor) y) -> ((float[], float[]))
7+
foo(int[2] x, (Tensor, Tensor) y) -> ((int[], int[]))
8+
foo(int[] x, (Tensor, Tensor) y) -> ((int[], int[]))
9+
foo(int? x, (Tensor, Tensor) y) -> ((int?, int?))
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
foo(ClassType<TestModule> self, Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
2-
foo(ClassType<TestModule> self, Tensor x, (Tensor, Tensor) y) -> (Tensor, Tensor)
3-
foo(ClassType<TestModule> self, str x, (Tensor, Tensor) y) -> (str, str)
4-
foo(ClassType<TestModule> self, int x, (Tensor, Tensor) y) -> (int, int)
5-
foo(ClassType<TestModule> self, bool x, (Tensor, Tensor) y) -> (bool, bool)
6-
foo(ClassType<TestModule> self, float[3] x, (Tensor, Tensor) y) -> (float[], float[])
7-
foo(ClassType<TestModule> self, int[2] x, (Tensor, Tensor) y) -> (int[], int[])
8-
foo(ClassType<TestModule> self, int[] x, (Tensor, Tensor) y) -> (int[], int[])
9-
foo(ClassType<TestModule> self, int? x, (Tensor, Tensor) y) -> (int?, int?)
1+
foo(ClassType<TestModule> self, Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
2+
foo(ClassType<TestModule> self, Tensor x, (Tensor, Tensor) y) -> ((Tensor, Tensor))
3+
foo(ClassType<TestModule> self, str x, (Tensor, Tensor) y) -> ((str, str))
4+
foo(ClassType<TestModule> self, int x, (Tensor, Tensor) y) -> ((int, int))
5+
foo(ClassType<TestModule> self, bool x, (Tensor, Tensor) y) -> ((bool, bool))
6+
foo(ClassType<TestModule> self, float[3] x, (Tensor, Tensor) y) -> ((float[], float[]))
7+
foo(ClassType<TestModule> self, int[2] x, (Tensor, Tensor) y) -> ((int[], int[]))
8+
foo(ClassType<TestModule> self, int[] x, (Tensor, Tensor) y) -> ((int[], int[]))
9+
foo(ClassType<TestModule> self, int? x, (Tensor, Tensor) y) -> ((int?, int?))
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
foo(Tensor x, (Tensor, Tensor, Tensor) y, (Tensor, (Tensor, Tensor)) z) -> Tensor
1+
foo(Tensor x, (Tensor, Tensor, Tensor) y, (Tensor, (Tensor, Tensor)) z) -> (Tensor)
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
forward(ClassType<SM> self, (Tensor, Tensor) x, Tensor y) -> (Tensor, Tensor, Tensor)
1+
forward(ClassType<SM> self, (Tensor, Tensor) x, Tensor y) -> ((Tensor, Tensor, Tensor))
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
foo(Tensor x, ((Tensor, Tensor), Tensor) y) -> (Tensor, Tensor)
1+
foo(Tensor x, ((Tensor, Tensor), Tensor) y) -> ((Tensor, Tensor))
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
foo(Tensor x, ((Tensor, Tensor), Tensor) y) -> (Tensor, Tensor)
1+
foo(Tensor x, ((Tensor, Tensor), Tensor) y) -> ((Tensor, Tensor))

0 commit comments

Comments
 (0)