Commit 5ccf28d

soulitzer authored and facebook-github-bot committed
Do not use ZeroTensor for inplace ops (pytorch#69998)
Summary: Pull Request resolved: pytorch#69998. Fixes: pytorch#69855.

The check for undefined grads under forward AD was not being run, because `check_undefined_grads` was only passed as True by OpInfo for backward AD. This PR updates gradcheck to interpret `check_undefined_grads` as applying to forward AD, backward AD, or both, depending on which modes are being checked.

This PR also updates codegen to 1) not use ZeroTensor for `self` when the op is in-place, and 2) only create zeros (either through ZeroTensor or at::zeros) when the input tensor itself is defined. Previously we would error in that case when calling `.options()` on the undefined tensor.

~TODO: undo the skips that are due to the original issue~

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D33235973

Pulled By: soulitzer

fbshipit-source-id: 5769b6d6ca123b2bed31dc2bc6bc8e4701581891
1 parent 3116d87 commit 5ccf28d
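For context, a minimal sketch (not part of this commit; the function and inputs are assumed for illustration) of the scenario the gradcheck change enables: running the undefined-tangent check with forward-mode AD only, which previously tripped the assertion that check_undefined_grad=True requires check_backward_ad=True.

import torch
from torch.autograd import gradcheck

def fn(x, mask):
    # masked_select is one of the ops whose pytorch#69855 skips are removed below
    return torch.masked_select(x, mask)

x = torch.randn(3, 3, dtype=torch.double, requires_grad=True)
mask = torch.tensor([[True, False, True]] * 3)

# After this change, check_undefined_grad is honored for whichever AD modes are
# being checked, so a forward-AD-only gradcheck can exercise it as well.
gradcheck(fn, (x, mask), check_forward_ad=True, check_backward_ad=False,
          check_undefined_grad=True, check_batched_grad=False)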

4 files changed: +23 -65 lines changed


test/test_ops.py

Lines changed: 5 additions & 7 deletions
@@ -707,7 +707,7 @@ def _fn(t, *args, **kwargs):
         return _fn
 
     def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
-                      check_undefined_grad=True, check_batched_grad=None, check_batched_forward_grad=False):
+                      check_batched_grad=None, check_batched_forward_grad=False):
         assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
         # NB: check_backward_ad does not affect gradgradcheck (always True)
         if variant is None:
@@ -756,7 +756,7 @@ def fn(*inputs):
                                           fast_mode=op.gradcheck_fast_mode,
                                           check_forward_ad=check_forward_ad,
                                           check_backward_ad=check_backward_ad,
-                                          check_undefined_grad=check_undefined_grad,
+                                          check_undefined_grad=True,
                                           check_batched_forward_grad=check_batched_forward_grad))
         elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
             self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
@@ -779,10 +779,9 @@ def fn(*inputs):
             self.assertTrue(False, msg="Unknown check requested!")
 
     def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True,
-                          check_undefined_grad=True, check_batched_grad=None, check_batched_forward_grad=False):
+                          check_batched_grad=None, check_batched_forward_grad=False):
         return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad,
-                                  check_backward_ad=check_backward_ad, check_undefined_grad=check_undefined_grad,
-                                  check_batched_grad=check_batched_grad,
+                                  check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad,
                                   check_batched_forward_grad=check_batched_forward_grad)
 
     def _skip_helper(self, op, device, dtype):
@@ -864,8 +863,7 @@ def call_grad_test_helper():
             check_batched_forward_grad = ((op.check_batched_forward_grad and not is_inplace) or
                                           (op.check_inplace_batched_forward_grad and is_inplace))
             self._grad_test_helper(device, dtype, op, variant, check_forward_ad=True, check_backward_ad=False,
-                                   check_undefined_grad=False, check_batched_grad=False,
-                                   check_batched_forward_grad=check_batched_forward_grad)
+                                   check_batched_grad=False, check_batched_forward_grad=check_batched_forward_grad)
         if op.supports_forward_ad:
             call_grad_test_helper()
         else:

tools/autograd/gen_variable_type.py

Lines changed: 5 additions & 3 deletions
@@ -322,7 +322,8 @@
 FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate("""\
 auto ${inp}_t_raw = toNonOptFwGrad(${inp});
 auto ${inp}_tensor = toNonOptTensor(${inp});
-auto ${inp}_t = ${inp}_t_raw.defined() ? ${inp}_t_raw : at::_efficientzerotensor(${inp}_tensor.sizes(), ${inp}_tensor.options());
+auto ${inp}_t = (${inp}_t_raw.defined() || !${inp}_tensor.defined())
+  ? ${inp}_t_raw : at::${zeros_fn}(${inp}_tensor.sizes(), ${inp}_tensor.options());
 """)
 
 FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate("""\
@@ -908,12 +909,13 @@ def emit_fw_derivatives() -> List[str]:
 
             unpacked_arguments = ""
             for inp in differentiable_inputs:
+                zeros_fn = "zeros" if inplace and inp.name == "self" else "_efficientzerotensor"
                 if inp.name in derivative.required_inputs_fw_grad:
-                    unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(inp=inp.name)
+                    unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(inp=inp.name, zeros_fn=zeros_fn)
                 if inp.name in (derivative.required_inputs_primal or []):
                     unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(inp=inp.name)
             if derivative.required_original_self_value:
-                unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(inp="original_self")
+                unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(inp="original_self", zeros_fn=zeros_fn)
                 unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(inp="original_self")
             elif inplace and derivative.is_reusing_outplace_formula:
                 # The gradient wasn't already cloned, do it if grad mode is enabled

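In short, the template change above makes the fallback tangent for an in-place op's `self` a writable at::zeros rather than an immutable ZeroTensor, and no zeros are created at all when the input tensor is undefined (so `.sizes()`/`.options()` are never called on an undefined tensor). A small illustrative sketch of the selection logic and the expansion it produces (approximate, not the actual generated file):

# Mirrors the zeros_fn choice added to emit_fw_derivatives in this commit.
def pick_zeros_fn(inplace: bool, inp_name: str) -> str:
    return "zeros" if inplace and inp_name == "self" else "_efficientzerotensor"

assert pick_zeros_fn(True, "self") == "zeros"                  # in-place self: materialized, writable zeros
assert pick_zeros_fn(False, "self") == "_efficientzerotensor"  # out-of-place: ZeroTensor fallback is fine
assert pick_zeros_fn(True, "other") == "_efficientzerotensor"  # non-self inputs keep the ZeroTensor fallback

# For inplace=True and inp="self", FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE roughly expands to:
#   auto self_t_raw = toNonOptFwGrad(self);
#   auto self_tensor = toNonOptTensor(self);
#   auto self_t = (self_t_raw.defined() || !self_tensor.defined())
#       ? self_t_raw : at::zeros(self_tensor.sizes(), self_tensor.options());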
torch/autograd/gradcheck.py

Lines changed: 6 additions & 4 deletions
@@ -882,6 +882,7 @@ def _test_undefined_forward_mode(func, outputs, inputs):
     with fwAD.dual_level():
         fw_grads = []
         dual_inputs = []
+        tensor_indices = set()
         for i, inp in enumerate(inputs):
             if is_tensor_like(inp) and inp.requires_grad:
                 if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
@@ -891,12 +892,15 @@ def _test_undefined_forward_mode(func, outputs, inputs):
                 # If inp is a differentiable view, the dual might not be the tangent given to
                 # make_dual, so read it explicitly from the dual tensor
                 fw_grads.append(fwAD.unpack_dual(inp)[1])
+                tensor_indices.add(i)
             dual_inputs.append(inp)
 
         for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
             fw_grad.copy_(u.view_as(fw_grad))
 
-        for idx, inp in enumerate(tensor_inputs):
+        for idx, inp in enumerate(inputs):
+            if idx not in tensor_indices:
+                continue
             dual_inp_obj = dual_inputs[idx]
 
             # case 1 (Materialized Zero Tensor Tangent)
@@ -1381,8 +1385,6 @@ def gradcheck(
     """
     assert check_forward_ad or check_backward_ad, \
         "Expected at least one of check_forward_ad or check_backward_ad to be True"
-    assert not (check_undefined_grad and not check_backward_ad), \
-        "Setting check_undefined_grad=True requires check_backward_ad to be True"
     assert not (check_batched_grad and not check_backward_ad), (
         "Setting check_batched_grad=True requires check_backward_ad to be True")
     assert not (check_batched_forward_grad and not check_forward_ad), (
@@ -1427,7 +1429,7 @@ def _gradcheck_helper(func, inputs, eps, atol, rtol, check_sparse_nnz, nondet_to
 
     _test_backward_mul_by_grad_output(outputs, tupled_inputs, check_sparse_nnz)
 
-    if check_undefined_grad:
+    if check_undefined_grad and check_backward_ad:
         _test_undefined_backward_mode(func, outputs, tupled_inputs)
     return True

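The tensor_indices bookkeeping in _test_undefined_forward_mode above keeps the loop index aligned with dual_inputs, which is appended to for every input, tensor or not, whereas indexing by enumerate(tensor_inputs) can drift once non-tensor inputs are present. A standalone illustration of the pattern (not gradcheck itself; the inputs are made up):

import torch

inputs = (torch.randn(2, requires_grad=True), 3.5, torch.randn(2, requires_grad=True))

dual_inputs = []
tensor_indices = set()
for i, inp in enumerate(inputs):
    if isinstance(inp, torch.Tensor) and inp.requires_grad:
        tensor_indices.add(i)   # remember which positions get a tangent
    dual_inputs.append(inp)     # every input is mirrored, tensor or not

for idx, inp in enumerate(inputs):
    if idx not in tensor_indices:
        continue                # skip the non-tensor argument at position 1
    assert dual_inputs[idx] is inputs[idx]   # indices stay aligned with dual_inputs
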
torch/testing/_internal/common_methods_invocations.py

Lines changed: 7 additions & 51 deletions
@@ -10203,11 +10203,7 @@ def ref_pairwise_distance(input1, input2):
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
-           sample_inputs_func=sample_inputs_masked_select,
-           skips=(
-               # 69855: RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor (...)
-               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'),
-           )),
+           sample_inputs_func=sample_inputs_masked_select),
    OpInfo('matrix_exp',
           dtypes=floating_and_complex_types_and(torch.bfloat16),
           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
@@ -10257,44 +10253,23 @@ def ref_pairwise_distance(input1, input2):
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
-          sample_inputs_func=sample_inputs_max_min_reduction_no_dim,
-          skips=(
-              # 69855: RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor (...)
-              DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'),
-              # (ROCm) unexpected success
-              DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
-                           device_type='cuda', dtypes=[torch.float64], active_if=TEST_WITH_ROCM),
-          )),
+          sample_inputs_func=sample_inputs_max_min_reduction_no_dim),
    OpInfo('median',
           dtypes=all_types_and(torch.bfloat16),
           dtypesIfCUDA=all_types_and(torch.float16),
           # TODO: some signatures of median do support out
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
-          sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False),
-          skips=(
-              # 69855: RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor (...)
-              DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'),
-              # (ROCm) unexpected success
-              DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
-                           device_type='cuda', dtypes=[torch.float64], active_if=TEST_WITH_ROCM),
-          )),
+          sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
    OpInfo('nanmedian',
           dtypes=all_types_and(torch.bfloat16),
           dtypesIfCUDA=all_types_and(torch.float16),
           # TODO: some signatures of nanmedian do support out
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
-          sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False),
-          skips=(
-              # 69855: RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor (...)
-              DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'),
-              # (ROCm) unexpected success
-              DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
-                           device_type='cuda', dtypes=[torch.float64], active_if=TEST_WITH_ROCM),
-          )),
+          sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
    OpInfo('var_mean',
           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
@@ -10394,14 +10369,7 @@ def ref_pairwise_distance(input1, input2):
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
-          sample_inputs_func=sample_inputs_max_min_reduction_no_dim,
-          skips=(
-              # 69855: RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor (...)
-              DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'),
-              # (ROCm) unexpected success
-              DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
-                           device_type='cuda', dtypes=[torch.float64], active_if=TEST_WITH_ROCM),
-          )),
+          sample_inputs_func=sample_inputs_max_min_reduction_no_dim),
    OpInfo('quantile',
           dtypes=floating_types(),
           sample_inputs_func=sample_inputs_reduction_quantile,
@@ -10709,8 +10677,6 @@ def ref_pairwise_distance(input1, input2):
           check_inplace_batched_forward_grad=False,
           sample_inputs_func=sample_inputs_as_strided,
           skips=(
-              # 69855: RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor (...)
-              DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'),
              # AssertionError: False is not true : Tensors failed to compare as equal!
              DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
              # AssertionError: False is not true : Scalars failed to compare as equal!
@@ -13080,11 +13046,7 @@ def ref_pairwise_distance(input1, input2):
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           assert_jit_shape_analysis=True,
-          gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
-          skips=(
-              # 69855: RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor (...)
-              DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'),
-          )),
+          gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
    OpInfo('index_add',
           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
           # An `out=` variant exists but is not exposed to the Python API
@@ -13990,11 +13952,7 @@ def ref_pairwise_distance(input1, input2):
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
-          sample_inputs_func=sample_inputs_trace,
-          skips=(
-              # 69855: RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor (...)
-              DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'),
-          )),
+          sample_inputs_func=sample_inputs_trace),
    OpInfo('transpose',
           aliases=('swapdims', 'swapaxes'),
           assert_jit_shape_analysis=True,
@@ -15244,8 +15202,6 @@ def ref_pairwise_distance(input1, input2):
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           skips=(
-              # 69855: RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor (...)
-              DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'),
              DecorateInfo(
                  unittest.skip("Skipped!"),
                  "TestJit",
