@@ -33,6 +33,21 @@ class TestUnion(JitTestCase):
33
33
equivalent functions to emulate `checkScript`.
34
34
"""
35
35
36
+ def test_check_union_annotation (self ):
37
+ def test_func (a : Union [int , float ], b : Optional [int ]):
38
+ return 0
39
+
40
+ scripted_func = torch .jit .script (test_func )
41
+ graph_rep = str (scripted_func .graph )
42
+ code_rep = str (scripted_func .code )
43
+ # TS graph IR for Union should be annotated as Union()
44
+ FileCheck ().check ("Union(" ).check ("int?" ).run (graph_rep )
45
+ # Serialized code for Union should be annotated as Union[]
46
+ FileCheck ().check ("Union[" ).check ("Optional[int]" ).run (code_rep )
47
+ self .checkScript (test_func , (5 , 6 ))
48
+ # this shouldn't error out
49
+ torch ._C .parse_ir (str (scripted_func .graph ))
50
+
36
51
def test_union_with_scalar_values (self ):
37
52
def fn (x : Union [int , float ]) -> str :
38
53
return "foo"
@@ -210,7 +225,7 @@ def fn(x: Union[Union[int, str], float]) -> str:
210
225
211
226
s = fn .graph
212
227
213
- FileCheck ().check ("x : Union[ float, int, str] " ) \
228
+ FileCheck ().check ("x : Union( float, int, str) " ) \
214
229
.run (s )
215
230
216
231
def test_unions_of_a_single_argument_vanish (self ):
@@ -230,7 +245,7 @@ def fn(x: Union[int, str, int]) -> str:
230
245
231
246
s = fn .graph
232
247
233
- FileCheck ().check ("x : Union[ int, str] " ) \
248
+ FileCheck ().check ("x : Union( int, str) " ) \
234
249
.run (s )
235
250
236
251
def test_union_redundant_arguments_are_skipped_optional (self ):
@@ -240,7 +255,7 @@ def fn(x: Union[int, Optional[float], Optional[int]]) -> str:
240
255
241
256
s = fn .graph
242
257
243
- FileCheck ().check ("x : Union[ float, int, NoneType] " ) \
258
+ FileCheck ().check ("x : Union( float, int, NoneType) " ) \
244
259
.run (s )
245
260
246
261
def test_union_redundant_arguments_are_skipped_subtyping (self ):
@@ -250,7 +265,7 @@ def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str:
250
265
251
266
s = fn .graph
252
267
253
- FileCheck ().check ("x : Union[( int?, int), str] " ) \
268
+ FileCheck ().check ("x : Union(( int?, int), str) " ) \
254
269
.run (s )
255
270
256
271
def test_union_redundant_arguments_are_skipped_container (self ):
@@ -260,7 +275,7 @@ def fn(x: Union[List[str], List[float], List[str]]) -> str:
260
275
261
276
s = fn .graph
262
277
263
- FileCheck ().check ("x : Union[ float[], str[]] " ) \
278
+ FileCheck ().check ("x : Union( float[], str[]) " ) \
264
279
.run (s )
265
280
266
281
def test_union_argument_order_is_ignored (self ):
@@ -273,7 +288,7 @@ def fn2(x: Union[str, int]) -> str:
273
288
return "foo"
274
289
275
290
for s in (fn1 .graph , fn2 .graph ):
276
- FileCheck ().check ("x : Union[ int, str] " ) \
291
+ FileCheck ().check ("x : Union( int, str) " ) \
277
292
.run (s )
278
293
279
294
def test_union_argument_order_is_ignored_container (self ):
@@ -286,7 +301,7 @@ def fn2(x: Union[List[int], List[str]]) -> str:
286
301
return "foo"
287
302
288
303
for s in (fn1 .graph , fn2 .graph ):
289
- FileCheck ().check ("x : Union[ int[], str[]] " ) \
304
+ FileCheck ().check ("x : Union( int[], str[]) " ) \
290
305
.run (s )
291
306
292
307
def test_union_T_None_is_equivalent_to_optional_T (self ):
0 commit comments