
Commit 3116d87

soulitzer authored and facebook-github-bot committed
Add forward AD formulas for {adaptive_,fractional_,}max_pool{2,3}d_{backward,} (pytorch#69884)
Summary: Pull Request resolved: pytorch#69884
Also fixes: pytorch#69322, pytorch#69325
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D33093039
Pulled By: soulitzer
fbshipit-source-id: b9a522a00f4e9e85974888de5058de07280f8f66
1 parent 6925576 commit 3116d87
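For readers less familiar with forward-mode AD in PyTorch, here is a minimal sketch of what the new formulas enable, using the public torch.autograd.forward_ad dual-tensor API. The input sizes and values are illustrative only and are not taken from the PR's test plan:

    import torch
    import torch.nn.functional as F
    import torch.autograd.forward_ad as fwAD

    x = torch.randn(1, 1, 4, 4, dtype=torch.double)
    tangent = torch.randn_like(x)  # direction for the Jacobian-vector product (JVP)

    with fwAD.dual_level():
        dual_x = fwAD.make_dual(x, tangent)
        out = F.max_pool2d(dual_x, kernel_size=2, stride=2)
        primal, jvp = fwAD.unpack_dual(out)

    # With the formulas added in this commit, jvp holds the tangent entries at the
    # argmax locations chosen on the primal input.
    print(primal.shape, jvp.shape)  # torch.Size([1, 1, 2, 2]) for both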

File tree

3 files changed (+66, -20 lines)


test/test_nn.py

Lines changed: 6 additions & 3 deletions
@@ -11941,18 +11941,21 @@ def test_max_unpool(self):
         self.assertEqual(F.max_unpool1d(output, indices, 2), F.max_unpool1d(output, indices, 2, stride=2))
 
         # Test list / tuple passed as argument to max_unpool1d
-        input = torch.randn([1, 1, 5])
+        input = torch.randn([1, 1, 5], requires_grad=True)
         output, indices = F.max_pool1d(input, 2, stride=2, return_indices=True)
         self.assertEqual(F.max_unpool1d(output, indices, 2, stride=2, output_size=input.shape),
                          F.max_unpool1d(output, indices, 2, stride=2, output_size=input.size()))
+        gradcheck(F.max_unpool1d, (output, indices, 2), check_forward_ad=True)
 
         # Test 2D
-        output, indices = F.max_pool2d(torch.randn([1, 1, 4, 4]), 2, stride=2, return_indices=True)
+        output, indices = F.max_pool2d(torch.randn([1, 1, 4, 4], requires_grad=True), 2, stride=2, return_indices=True)
         self.assertEqual(F.max_unpool2d(output, indices, 2), F.max_unpool2d(output, indices, 2, stride=2))
+        gradcheck(F.max_unpool2d, (output, indices, 2), check_forward_ad=True)
 
         # Test 3D
-        output, indices = F.max_pool3d(torch.randn([4, 4, 4, 4, 4]), 2, stride=2, return_indices=True)
+        output, indices = F.max_pool3d(torch.randn([4, 4, 4, 4, 4], requires_grad=True), 2, stride=2, return_indices=True)
         self.assertEqual(F.max_unpool3d(output, indices, 2), F.max_unpool3d(output, indices, 2, stride=2))
+        gradcheck(F.max_unpool3d, (output, indices, 2), check_forward_ad=True)
 
     def test_dirac_properties(self):
         for dims in [3, 4, 5]:
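The new gradcheck calls above pass check_forward_ad=True, which exercises the forward derivatives of max_unpool{1,2,3}d declared below in derivatives.yaml as result: auto_linear. Unpooling is linear in its input once the indices are fixed, so the JVP is just the op applied to the tangent. A rough sanity-check sketch of that identity (not part of the PR; illustrative shapes, public forward_ad API):

    import torch
    import torch.nn.functional as F
    import torch.autograd.forward_ad as fwAD

    x = torch.randn(1, 1, 4, 4, dtype=torch.double)
    pooled, indices = F.max_pool2d(x, 2, stride=2, return_indices=True)
    tangent = torch.randn_like(pooled)  # tangent for the unpool input

    with fwAD.dual_level():
        dual_in = fwAD.make_dual(pooled, tangent)
        dual_out = F.max_unpool2d(dual_in, indices, 2)
        _, jvp = fwAD.unpack_dual(dual_out)

    # auto_linear: the forward derivative of a linear op is the op itself,
    # applied to the tangent with the same (non-differentiable) indices.
    assert torch.allclose(jvp, F.max_unpool2d(tangent, indices, 2))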

tools/autograd/derivatives.yaml

Lines changed: 18 additions & 0 deletions
@@ -1962,9 +1962,13 @@
 
 - name: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
   self: adaptive_max_pool2d_backward(grad, self, result1)
+  result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
+  output_differentiability: [True, False]
 
 - name: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)
   self: adaptive_max_pool3d_backward(grad, self, result1)
+  result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
+  output_differentiability: [True, False]
 
 - name: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
   self: avg_pool2d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
@@ -1976,25 +1980,33 @@
 
 - name: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor)
   self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, result1)
+  result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
+  output_differentiability: [True, False]
 
 - name: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor)
   self: fractional_max_pool3d_backward(grad, self, kernel_size, output_size, result1)
+  result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
+  output_differentiability: [True, False]
 
 - name: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
   self: max_pool2d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1)
+  result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
   output_differentiability: [True, False]
 
 - name: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
   self: max_pool3d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1)
+  result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
   output_differentiability: [True, False]
 
 - name: max_unpool2d(Tensor self, Tensor indices, int[2] output_size) -> Tensor
   self: max_unpool2d_backward(grad, self, indices, output_size)
   indices: non_differentiable
+  result: auto_linear
 
 - name: max_unpool3d(Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding) -> Tensor
   self: max_unpool3d_backward(grad, self, indices, output_size, stride, padding)
   indices: non_differentiable
+  result: auto_linear
 
 - name: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor
   input, weight, bias: "grad.defined() ? convolution_backward(grad, input, weight, bias->sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
@@ -2086,10 +2098,12 @@
 - name: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
   grad_output: max_pool_double_backward(grad, indices, 2)
   self: zeros_like(self)
+  result: auto_linear
 
 - name: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
   grad_output: max_pool_double_backward(grad, indices, 3)
   self: zeros_like(self)
+  result: auto_linear
 
 - name: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
   grad_output: avg_pool2d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
@@ -2108,10 +2122,12 @@
 - name: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor
   grad_output: max_pool_double_backward(grad, indices, 2)
   self: zeros_like(self)
+  result: auto_linear
 
 - name: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor
   grad_output: max_pool_double_backward(grad, indices, 3)
   self: zeros_like(self)
+  result: auto_linear
 
 - name: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor
   grad_output: glu_double_backward_grad_output(grad, self, dim)
@@ -2148,11 +2164,13 @@
   grad_output: max_pool_double_backward(grad, indices, 2)
   self: zeros_like(self)
   indices: non_differentiable
+  result: auto_linear
 
 - name: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor
   grad_output: max_pool_double_backward(grad, indices, 3)
   self: zeros_like(self)
   indices: non_differentiable
+  result: auto_linear
 
 - name: max_unpool2d_backward(Tensor grad_output, Tensor self, Tensor indices, int[2] output_size) -> Tensor
   grad_output: max_unpool2d(grad, indices, output_size)
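The result0 formulas above all follow the same pattern: the forward tangent of the pooled output is the input tangent (self_t) gathered at the saved argmax indices (result1), with the spatial dims flattened so gather can index them. A rough eager-mode reading of the 2d formula, checked against forward AD (illustrative only; assumes the usual convention that the returned indices index into the flattened spatial dims):

    import torch
    import torch.nn.functional as F
    import torch.autograd.forward_ad as fwAD

    x = torch.randn(1, 1, 4, 4, dtype=torch.double)
    t = torch.randn_like(x)  # plays the role of self_t in the YAML formula

    out, idx = F.max_pool2d(x, 2, stride=2, return_indices=True)

    # Eager-mode equivalent of:
    #   result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
    jvp_manual = t.flatten(-2).gather(-1, idx.flatten(-2)).view_as(idx)

    with fwAD.dual_level():
        dual = fwAD.make_dual(x, t)
        _, jvp = fwAD.unpack_dual(F.max_pool2d(dual, 2, stride=2))

    assert torch.allclose(jvp, jvp_manual)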

torch/testing/_internal/common_methods_invocations.py

Lines changed: 42 additions & 17 deletions
@@ -10775,6 +10775,10 @@ def ref_pairwise_distance(input1, input2):
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
            supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # got: Batching rule not implemented for aten::flatten.using_ints
+           check_batched_forward_grad=False,
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
            sample_inputs_func=sample_inputs_adaptive_max_pool1d),
     OpInfo('nn.functional.adaptive_max_pool2d',
@@ -10792,6 +10796,10 @@ def ref_pairwise_distance(input1, input2):
                DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
            ),
            supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # got: Batching rule not implemented for aten::flatten.using_ints
+           check_batched_forward_grad=False,
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
            sample_inputs_func=sample_inputs_adaptive_max_pool2d),
     OpInfo('nn.functional.adaptive_max_pool3d',
@@ -10811,6 +10819,10 @@ def ref_pairwise_distance(input1, input2):
                DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
            ),
            supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # got: Batching rule not implemented for aten::flatten.using_ints
+           check_batched_forward_grad=False,
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
            sample_inputs_func=sample_inputs_adaptive_max_pool3d),
     OpInfo('nn.functional.avg_pool1d',
@@ -11201,49 +11213,54 @@ def ref_pairwise_distance(input1, input2):
     OpInfo('nn.functional.fractional_max_pool2d',
            supports_autograd=True,
            supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           op=lambda input, *args, **kwargs:
+               wrapper_set_seed(torch.nn.functional.fractional_max_pool2d, input, *args, **kwargs),
+           # vmap does not support random operations
+           check_batched_forward_grad=False,
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.float16),
            test_neg_view=False,
            sample_inputs_func=sample_inputs_fractional_max_pool2d,
-           decorators=[
-               # FIXME: both derivatives are implemented incorrectly
-               # https://github.com/pytorch/pytorch/issues/69322
-               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_grad'),
-               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_gradgrad'),
-               # FIXME: produces incorrect output on non-contiguous inputs
-               # https://github.com/pytorch/pytorch/issues/69325
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
+           decorators=(
               # FIXME: AssertionError: False is not true : Tensors failed to compare as equal!
               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
-               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
-           ], ),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'))),
     OpInfo('nn.functional.fractional_max_pool3d',
            supports_autograd=True,
            supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           op=lambda input, *args, **kwargs:
+               wrapper_set_seed(torch.nn.functional.fractional_max_pool3d, input, *args, **kwargs),
+           # vmap does not support random operations
+           check_batched_forward_grad=False,
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.float16),
            test_neg_view=False,
            sample_inputs_func=sample_inputs_fractional_max_pool3d,
-           decorators=[
+           decorators=(
               # FIXME: both derivatives are implemented incorrectly
               # https://github.com/pytorch/pytorch/issues/69322
-               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_grad'),
+               # RuntimeError: cannot reshape tensor of 0 elements into shape [0, 1, -1] because the
+               # unspecified dimension size -1 can be any value and is ambiguous
               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_gradgrad'),
-               # FIXME: produces incorrect output on non-contiguous inputs
-               # https://github.com/pytorch/pytorch/issues/69325
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
              # FIXME: AssertionError: False is not true : Tensors failed to compare as equal!
              DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
              # RuntimeError: input->type()->kind() == TypeKind::OptionalType
              # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
-               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
-           ], ),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),)),
     OpInfo('nn.functional.max_pool1d',
            aten_name='max_pool1d',
            supports_autograd=True,
            supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # got: Batching rule not implemented for aten::flatten.using_ints
+           check_batched_forward_grad=False,
            # TODO: add shape checks
            assert_jit_shape_analysis=False,
            dtypes=floating_types(),
@@ -11259,6 +11276,10 @@ def ref_pairwise_distance(input1, input2):
            # Vmap is not happy with non-contiguous (channels_last) inputs
            check_batched_gradgrad=False,
            supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # got: Batching rule not implemented for aten::flatten.using_ints
+           check_batched_forward_grad=False,
            assert_jit_shape_analysis=True,
            dtypes=floating_types(),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
@@ -11267,6 +11288,10 @@ def ref_pairwise_distance(input1, input2):
            aten_name='max_pool3d',
            supports_autograd=True,
            supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # got: Batching rule not implemented for aten::flatten.using_ints
+           check_batched_forward_grad=False,
            # TODO: add shape checks
            assert_jit_shape_analysis=False,
            dtypes=floating_types(),
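The new OpInfo flags above opt these ops into the shared forward-AD test suite: supports_forward_ad=True enables the forward-mode gradcheck variants, supports_fwgrad_bwgrad=True enables the forward-over-reverse second-order checks, and check_batched_forward_grad=False skips the vmap-over-forward-grad checks that currently hit the batching-rule and randomness limitations noted in the comments. Outside the OpInfo machinery, a roughly equivalent standalone check might look like this (illustrative input sizes; uses torch.autograd.gradcheck's check_forward_ad flag):

    import torch
    import torch.nn.functional as F
    from torch.autograd import gradcheck

    x = torch.randn(1, 1, 4, 4, dtype=torch.double, requires_grad=True)

    # check_forward_ad=True asks gradcheck to also verify the forward-mode (JVP)
    # formula, mirroring what supports_forward_ad=True turns on in the OpInfo tests.
    gradcheck(lambda inp: F.max_pool2d(inp, 2, stride=2), (x,), check_forward_ad=True)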
