@@ -258,52 +258,50 @@ def run_pass(self, name, trace):
258
258
return graph
259
259
260
260
def checkScript (self ,
261
- func ,
261
+ script ,
262
262
inputs ,
263
263
optimize = True ,
264
264
outputs = None ,
265
265
name = 'func' ,
266
266
capture_output = False ,
267
267
frames_up = 1 ,
268
268
check_expected = False ):
269
- if isinstance (func , str ):
270
- cu = torch .jit .CompilationUnit (func , optimize , _frames_up = frames_up )
271
- scripted_fn = getattr (cu , name )
269
+ if isinstance (script , str ):
270
+ cu = torch .jit .CompilationUnit (script , optimize , _frames_up = frames_up )
271
+ ge = getattr (cu , name )
272
272
else :
273
273
if capture_output :
274
274
with self .capture_stdout () as captured :
275
- outputs = func (* inputs )
275
+ outputs = script (* inputs )
276
276
else :
277
- outputs = func (* inputs )
277
+ outputs = script (* inputs )
278
278
# Check the string frontend first
279
- source = textwrap .dedent (inspect .getsource (func ))
279
+ source = textwrap .dedent (inspect .getsource (script ))
280
280
self .checkScript (
281
281
source ,
282
282
inputs ,
283
283
optimize ,
284
284
outputs ,
285
- func .__name__ ,
285
+ script .__name__ ,
286
286
capture_output ,
287
287
frames_up = 2 ,
288
288
check_expected = check_expected )
289
289
# Continue checking the Python frontend
290
- scripted_fn = torch .jit .script (func , optimize , _frames_up = 1 )
290
+ ge = torch .jit .script (script , optimize , _frames_up = 1 )
291
291
292
292
if capture_output :
293
+ with self .capture_stdout () as captured :
294
+ outputs_ge = ge (* inputs )
293
295
if not IS_WINDOWS :
294
- with self .capture_stdout () as script_stdout :
295
- outputs_ge = scripted_fn (* inputs )
296
- with self .capture_stdout () as python_stdout :
297
- outputs_ge = scripted_fn (* inputs )
298
- self .assertEqual (script_stdout , python_stdout )
296
+ self .assertExpected (captured [0 ], subname = 'stdout' )
299
297
else :
300
- outputs_ge = scripted_fn (* inputs )
298
+ outputs_ge = ge (* inputs )
301
299
self .assertEqual (outputs , outputs_ge )
302
300
303
301
if check_expected :
304
- self .assertExpectedGraph (scripted_fn .graph )
302
+ self .assertExpectedGraph (ge .graph )
305
303
306
- return scripted_fn
304
+ return ge
307
305
308
306
def checkTrace (self , func , reference_tensors , input_tensors = None ,
309
307
optimize = True , drop = None , allow_unused = False , verbose = False ,
0 commit comments