Skip to content

Commit 829b49b

Browse files
tugsbayasgalanfacebook-github-bot
authored andcommitted
Output UnionType str rep with () instead of [] (pytorch#69502)
Summary: Pull Request resolved: pytorch#69502 Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D32902781 Pulled By: tugsbayasgalan fbshipit-source-id: 67a73b209575437477cdbd3eb8f685019709e99c
1 parent a8232ee commit 829b49b

File tree

3 files changed

+32
-12
lines changed

3 files changed

+32
-12
lines changed

aten/src/ATen/core/jit_type.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ struct TORCH_API UnionType : public Type {
147147
protected:
148148
explicit UnionType(std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
149149
std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
150-
std::string unionStr(TypePrinter printer = nullptr, bool is_annotation_str = false) const;
150+
std::string unionStr(
151+
TypePrinter printer = nullptr,
152+
bool is_annotation_str = false) const;
151153
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
152154
bool has_free_variables_;
153155
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)

aten/src/ATen/core/type.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -1099,8 +1099,8 @@ bool UnionType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
10991099
});
11001100
}
11011101

1102-
1103-
std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) const {
1102+
std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str)
1103+
const {
11041104
std::stringstream ss;
11051105

11061106
bool can_hold_numbertype = this->canHoldType(*NumberType::get());
@@ -1116,7 +1116,10 @@ std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) con
11161116
return false;
11171117
};
11181118

1119-
ss << "Union[";
1119+
std::string open_delimeter = is_annotation_str ? "[" : "(";
1120+
std::string close_delimeter = is_annotation_str ? "]" : ")";
1121+
1122+
ss << "Union" + open_delimeter;
11201123
bool printed = false;
11211124
for (size_t i = 0; i < types_.size(); ++i) {
11221125
if (!can_hold_numbertype || !is_numbertype(types_[i])) {
@@ -1141,7 +1144,7 @@ std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) con
11411144
ss << NumberType::get()->str();
11421145
}
11431146
}
1144-
ss << "]";
1147+
ss << close_delimeter;
11451148
return ss.str();
11461149
}
11471150

test/jit/test_union.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,21 @@ class TestUnion(JitTestCase):
3333
equivalent functions to emulate `checkScript`.
3434
"""
3535

36+
def test_check_union_annotation(self):
37+
def test_func(a: Union[int, float], b: Optional[int]):
38+
return 0
39+
40+
scripted_func = torch.jit.script(test_func)
41+
graph_rep = str(scripted_func.graph)
42+
code_rep = str(scripted_func.code)
43+
# TS graph IR for Union should be annotated as Union()
44+
FileCheck().check("Union(").check("int?").run(graph_rep)
45+
# Serialized code for Union should be annotated as Union[]
46+
FileCheck().check("Union[").check("Optional[int]").run(code_rep)
47+
self.checkScript(test_func, (5, 6))
48+
# this shouldn't error out
49+
torch._C.parse_ir(str(scripted_func.graph))
50+
3651
def test_union_with_scalar_values(self):
3752
def fn(x: Union[int, float]) -> str:
3853
return "foo"
@@ -210,7 +225,7 @@ def fn(x: Union[Union[int, str], float]) -> str:
210225

211226
s = fn.graph
212227

213-
FileCheck().check("x : Union[float, int, str]") \
228+
FileCheck().check("x : Union(float, int, str)") \
214229
.run(s)
215230

216231
def test_unions_of_a_single_argument_vanish(self):
@@ -230,7 +245,7 @@ def fn(x: Union[int, str, int]) -> str:
230245

231246
s = fn.graph
232247

233-
FileCheck().check("x : Union[int, str]") \
248+
FileCheck().check("x : Union(int, str)") \
234249
.run(s)
235250

236251
def test_union_redundant_arguments_are_skipped_optional(self):
@@ -240,7 +255,7 @@ def fn(x: Union[int, Optional[float], Optional[int]]) -> str:
240255

241256
s = fn.graph
242257

243-
FileCheck().check("x : Union[float, int, NoneType]") \
258+
FileCheck().check("x : Union(float, int, NoneType)") \
244259
.run(s)
245260

246261
def test_union_redundant_arguments_are_skipped_subtyping(self):
@@ -250,7 +265,7 @@ def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str:
250265

251266
s = fn.graph
252267

253-
FileCheck().check("x : Union[(int?, int), str]") \
268+
FileCheck().check("x : Union((int?, int), str)") \
254269
.run(s)
255270

256271
def test_union_redundant_arguments_are_skipped_container(self):
@@ -260,7 +275,7 @@ def fn(x: Union[List[str], List[float], List[str]]) -> str:
260275

261276
s = fn.graph
262277

263-
FileCheck().check("x : Union[float[], str[]]") \
278+
FileCheck().check("x : Union(float[], str[])") \
264279
.run(s)
265280

266281
def test_union_argument_order_is_ignored(self):
@@ -273,7 +288,7 @@ def fn2(x: Union[str, int]) -> str:
273288
return "foo"
274289

275290
for s in (fn1.graph, fn2.graph):
276-
FileCheck().check("x : Union[int, str]") \
291+
FileCheck().check("x : Union(int, str)") \
277292
.run(s)
278293

279294
def test_union_argument_order_is_ignored_container(self):
@@ -286,7 +301,7 @@ def fn2(x: Union[List[int], List[str]]) -> str:
286301
return "foo"
287302

288303
for s in (fn1.graph, fn2.graph):
289-
FileCheck().check("x : Union[int[], str[]]") \
304+
FileCheck().check("x : Union(int[], str[])") \
290305
.run(s)
291306

292307
def test_union_T_None_is_equivalent_to_optional_T(self):

0 commit comments

Comments
 (0)