Skip to content

Commit 263b198

Browse files
ezyangfacebook-github-bot
authored andcommitted
Revert D15833924: [jit] Fix stdout capturing, remove some expect files
Differential Revision: D15833924 Original commit changeset: 152972b4c240 fbshipit-source-id: 1d5a2258bc134fdc7bd2cb557bcc05f2289443b6
1 parent 04f09d4 commit 263b198

4 files changed

+26
-17
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Hello, I'm a test
2+
format blank
3+
stuff before hi
4+
hi stuff after
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
0.5000
2+
0.9526
3+
0.9975
4+
0.9999
5+
[ Variable[CPUDoubleType]{4} ] 1 2 [1, 2] [1., 2.]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
1
2+
[ Variable[CPULongType]{} ] abcd 2 1.5

test/jit_utils.py

+15-17
Original file line numberDiff line numberDiff line change
@@ -258,52 +258,50 @@ def run_pass(self, name, trace):
258258
return graph
259259

260260
def checkScript(self,
261-
func,
261+
script,
262262
inputs,
263263
optimize=True,
264264
outputs=None,
265265
name='func',
266266
capture_output=False,
267267
frames_up=1,
268268
check_expected=False):
269-
if isinstance(func, str):
270-
cu = torch.jit.CompilationUnit(func, optimize, _frames_up=frames_up)
271-
scripted_fn = getattr(cu, name)
269+
if isinstance(script, str):
270+
cu = torch.jit.CompilationUnit(script, optimize, _frames_up=frames_up)
271+
ge = getattr(cu, name)
272272
else:
273273
if capture_output:
274274
with self.capture_stdout() as captured:
275-
outputs = func(*inputs)
275+
outputs = script(*inputs)
276276
else:
277-
outputs = func(*inputs)
277+
outputs = script(*inputs)
278278
# Check the string frontend first
279-
source = textwrap.dedent(inspect.getsource(func))
279+
source = textwrap.dedent(inspect.getsource(script))
280280
self.checkScript(
281281
source,
282282
inputs,
283283
optimize,
284284
outputs,
285-
func.__name__,
285+
script.__name__,
286286
capture_output,
287287
frames_up=2,
288288
check_expected=check_expected)
289289
# Continue checking the Python frontend
290-
scripted_fn = torch.jit.script(func, optimize, _frames_up=1)
290+
ge = torch.jit.script(script, optimize, _frames_up=1)
291291

292292
if capture_output:
293+
with self.capture_stdout() as captured:
294+
outputs_ge = ge(*inputs)
293295
if not IS_WINDOWS:
294-
with self.capture_stdout() as script_stdout:
295-
outputs_ge = scripted_fn(*inputs)
296-
with self.capture_stdout() as python_stdout:
297-
outputs_ge = scripted_fn(*inputs)
298-
self.assertEqual(script_stdout, python_stdout)
296+
self.assertExpected(captured[0], subname='stdout')
299297
else:
300-
outputs_ge = scripted_fn(*inputs)
298+
outputs_ge = ge(*inputs)
301299
self.assertEqual(outputs, outputs_ge)
302300

303301
if check_expected:
304-
self.assertExpectedGraph(scripted_fn.graph)
302+
self.assertExpectedGraph(ge.graph)
305303

306-
return scripted_fn
304+
return ge
307305

308306
def checkTrace(self, func, reference_tensors, input_tensors=None,
309307
optimize=True, drop=None, allow_unused=False, verbose=False,

0 commit comments

Comments
 (0)