
Commit e8d226c

peterbell10 authored and pytorchmergebot committed
Remove some unnecessary python functional wrappers (pytorch#61608)
Summary:
Pull Request resolved: pytorch#61608

See pytorch#61544 for an example of issues created by functional wrappers. In this case, these are directly wrapping the native function with no added functionality. One exception was `bilinear`, which was just missing the default argument in C++ but was otherwise the same.

I've kept the symbol `torch.functional.istft` because it looks like public API, but it could just as easily be moved to `_torch_docs.py`.

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D31401361

Pulled By: albanD

fbshipit-source-id: 162b74d0b2d4f2e5c4834687a94541960cefdd52
(cherry picked from commit 700cd73)
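For context, a minimal sketch of the kind of "direct wrapper" being removed; the wrapper body below is a hypothetical simplification, not code taken from this commit:

    import torch

    # Before: a Python function in torch.nn.functional that only forwards its
    # arguments to the already-existing native op, adding no behaviour of its own.
    def pairwise_distance(x1, x2, p=2.0, eps=1e-6, keepdim=False):
        return torch.pairwise_distance(x1, x2, p, eps, keepdim)

    # After: the public name is bound directly to the native op, so calls no
    # longer go through an extra Python frame.
    pairwise_distance = torch.pairwise_distance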
1 parent 7ea96a7 commit e8d226c

12 files changed (+103, -108 lines)

aten/src/ATen/native/native_functions.yaml
+1 -1

@@ -881,7 +881,7 @@
   device_check: NoCheck # TensorIterator
   variants: function, method
 
-- func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor
+- func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor
 
 - func: binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
   device_check: NoCheck # TensorIterator
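A hedged sketch of what the added default enables, assuming the Python binding picks up the schema default:

    import torch

    x1 = torch.randn(4, 8)
    x2 = torch.randn(4, 6)
    w = torch.randn(5, 8, 6)   # (out_features, in1_features, in2_features)

    # With `bias=None` in the schema, the native op can be called without an
    # explicit bias, matching what torch.nn.functional.bilinear already allowed.
    out = torch.bilinear(x1, x2, w)
    print(out.shape)  # torch.Size([4, 5])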

test/expect/TestTensorBoard.test_nested_nn_squential.expect
+12 -12

@@ -50,7 +50,7 @@ node {
     }
   }
 node {
-  name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[0]/bias/bias.9"
+  name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[0]/bias/bias.1"
   op: "prim::GetAttr"
   input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[0]/weight/_0.1"
   attr {
@@ -61,7 +61,7 @@ node {
     }
   }
 node {
-  name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[0]/weight/weight.9"
+  name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[0]/weight/weight.1"
   op: "prim::GetAttr"
   input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[0]/weight/_0.1"
   attr {
@@ -75,8 +75,8 @@ node {
   name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[0]/input.1"
   op: "aten::linear"
   input: "input/x"
-  input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[0]/weight/weight.9"
-  input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[0]/bias/bias.9"
+  input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[0]/weight/weight.1"
+  input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[0]/bias/bias.1"
   attr {
     key: "_output_shapes"
     value {
@@ -100,7 +100,7 @@ node {
     }
   }
 node {
-  name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[1]/bias/bias.11"
+  name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[1]/bias/bias.3"
   op: "prim::GetAttr"
   input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[1]/weight/_1.1"
   attr {
@@ -111,7 +111,7 @@ node {
     }
   }
 node {
-  name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[1]/weight/weight.11"
+  name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[1]/weight/weight.3"
   op: "prim::GetAttr"
   input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[1]/weight/_1.1"
   attr {
@@ -125,8 +125,8 @@ node {
   name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[1]/input.3"
   op: "aten::linear"
   input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[0]/input.1"
-  input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[1]/weight/weight.11"
-  input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[1]/bias/bias.11"
+  input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[1]/weight/weight.3"
+  input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[1]/bias/bias.3"
   attr {
     key: "_output_shapes"
     value {
@@ -150,7 +150,7 @@ node {
     }
   }
 node {
-  name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[1]/Sequential[inner_nn_squential]/Linear[0]/bias/bias.13"
+  name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[1]/Sequential[inner_nn_squential]/Linear[0]/bias/bias.5"
   op: "prim::GetAttr"
   input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[1]/Sequential[inner_nn_squential]/Linear[0]/weight/_0"
   attr {
@@ -161,7 +161,7 @@ node {
     }
   }
 node {
-  name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[1]/Sequential[inner_nn_squential]/Linear[0]/weight/weight.13"
+  name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[1]/Sequential[inner_nn_squential]/Linear[0]/weight/weight.5"
   op: "prim::GetAttr"
   input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[1]/Sequential[inner_nn_squential]/Linear[0]/weight/_0"
   attr {
@@ -175,8 +175,8 @@ node {
   name: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[1]/Sequential[inner_nn_squential]/Linear[0]/input"
   op: "aten::linear"
   input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[0]/Sequential[inner_nn_squential]/Linear[1]/input.3"
-  input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[1]/Sequential[inner_nn_squential]/Linear[0]/weight/weight.13"
-  input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[1]/Sequential[inner_nn_squential]/Linear[0]/bias/bias.13"
+  input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[1]/Sequential[inner_nn_squential]/Linear[0]/weight/weight.5"
+  input: "OuterNNSquential/Sequential[outer_nn_squential]/InnerNNSquential[1]/Sequential[inner_nn_squential]/Linear[0]/bias/bias.5"
   attr {
     key: "_output_shapes"
     value {

test/jit/test_optimize_for_mobile_preserve_debug_info.py
+2 -2

@@ -127,8 +127,8 @@ def forward(self, x):
             ),
         ),
         replacements={
-            "prepacked::linear_clamp_prepack": "prim::CallFunction",
-            "prepacked::linear_clamp_run": "prim::CallFunction",
+            "prepacked::linear_clamp_prepack": "aten::linear",
+            "prepacked::linear_clamp_run": "aten::linear",
             "prepacked::conv2d_clamp_prepack": "aten::conv2d",
             "prepacked::conv2d_clamp_run": "aten::conv2d",
             "prepacked::conv2d_transpose_clamp_prepack":

test/onnx/expect/TestOperators.test_gelu.expect
+3 -3

@@ -16,7 +16,7 @@ graph {
     }
   }
   node {
-    input: "x"
+    input: "onnx::Div_0"
     input: "onnx::Div_1"
     output: "onnx::Erf_2"
     name: "Div_1"
@@ -49,7 +49,7 @@ graph {
     op_type: "Add"
   }
   node {
-    input: "x"
+    input: "onnx::Div_0"
     input: "onnx::Mul_5"
     output: "onnx::Mul_6"
     name: "Mul_5"
@@ -77,7 +77,7 @@ graph {
   }
   name: "torch-jit-export"
   input {
-    name: "x"
+    name: "onnx::Div_0"
     type {
       tensor_type {
         elem_type: 1

test/onnx/expect/TestOperators.test_linear.expect
+2 -2

@@ -3,7 +3,7 @@ producer_name: "pytorch"
 producer_version: "CURRENT_VERSION"
 graph {
   node {
-    input: "input"
+    input: "onnx::Gemm_0"
     input: "weight"
     input: "bias"
     output: "3"
@@ -40,7 +40,7 @@ graph {
     raw_data: "\324BO\276@\245T>\350\377\245\275\374u\336\276&\212\304>"
   }
   input {
-    name: "input"
+    name: "onnx::Gemm_0"
     type {
       tensor_type {
         elem_type: 1

test/onnx/expect/TestOperators.test_prelu.expect
+2 -2

@@ -3,7 +3,7 @@ producer_name: "pytorch"
 producer_version: "CURRENT_VERSION"
 graph {
   node {
-    input: "input"
+    input: "onnx::PRelu_0"
     input: "onnx::PRelu_4"
     output: "3"
     name: "PRelu_0"
@@ -19,7 +19,7 @@ graph {
     raw_data: "\000\000\200>\000\000\200>"
   }
   input {
-    name: "input"
+    name: "onnx::PRelu_0"
     type {
       tensor_type {
         elem_type: 1

test/onnx/expect/TestOperators.test_retain_param_name_disabled.expect
+2 -2

@@ -3,7 +3,7 @@ producer_name: "pytorch"
 producer_version: "CURRENT_VERSION"
 graph {
   node {
-    input: "input.1"
+    input: "0"
     input: "7"
     output: "4"
     name: "MatMul_0"
@@ -32,7 +32,7 @@ graph {
     raw_data: "\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@\000\000@@"
   }
   input {
-    name: "input.1"
+    name: "0"
     type {
       tensor_type {
         elem_type: 1

test/quantization/jit/test_quantize_jit.py
+11 -9

@@ -1218,14 +1218,19 @@ def forward(self, x):
         self.assertEqual(res, ref_res)
 
     def test_swap_functional_linear(self):
+        # TODO: This pass replaces any function called "linear" with "aten::linear"
+        # No longer necessary, and also quite surprising
+        def linear(input, weight, bias):
+            return torch.nn.functional.linear(input, weight, bias)
+
         class M(torch.nn.Module):
             def __init__(self):
                 super(M, self).__init__()
 
             def forward(self, x, weight, bias):
                 x = torch.dequantize(x)
                 weight = torch.dequantize(weight)
-                x = F.linear(x, weight, bias)
+                x = linear(x, weight, bias)
                 x = torch.quantize_per_tensor(
                     x, scale=1.0, zero_point=0, dtype=torch.quint8
                 )
@@ -3314,14 +3319,11 @@ def forward(self, x):
         model = quantize_dynamic_jit(model, qconfig_dict, debug=True)
         graph_qparams = []
         for x, obs in model._modules._c.items():
-            if x == 'fc' and tracing:
-                graph_qparams.append(
-                    (obs.getattr("weight.6_scale_0"), obs.getattr("weight.6_zero_point_0"))
-                )
-            else:
-                graph_qparams.append(
-                    (obs.getattr("weight.1_scale_0"), obs.getattr("weight.1_zero_point_0"))
-                )
+            n = 2 if x == 'fc' and tracing else 1
+            graph_qparams.append(
+                (obs.getattr(f"weight.{n}_scale_0"),
+                 obs.getattr(f"weight.{n}_zero_point_0"))
+            )
         self.assertEqual(ref_qparams, graph_qparams)
 
     def test_convert_dynamic_fp16(self):

test/test_fx.py
+6 -2

@@ -3466,6 +3466,7 @@ def tearDown(self):
         "avg_pool1d": BUILT_IN_FUNC,
         "avg_pool2d": BUILT_IN_FUNC,
         "avg_pool3d": BUILT_IN_FUNC,
+        "bilinear": BUILT_IN_FUNC,
         "celu_": BUILT_IN_FUNC,
         "channel_shuffle": BUILT_IN_FUNC,
         "conv1d": BUILT_IN_FUNC,
@@ -3477,13 +3478,18 @@ def tearDown(self):
         "conv_transpose3d": BUILT_IN_FUNC,
         "cosine_similarity": BUILT_IN_FUNC,
         "elu_": BUILT_IN_FUNC,
+        "gelu": BUILT_IN_FUNC,
+        "hardshrink": BUILT_IN_FUNC,
         "hardtanh_": BUILT_IN_FUNC,
         "leaky_relu_": BUILT_IN_FUNC,
+        "linear": BUILT_IN_FUNC,
         "logsigmoid": BUILT_IN_FUNC,
         "one_hot": BUILT_IN_FUNC,
+        "pairwise_distance": BUILT_IN_FUNC,
         "pdist": BUILT_IN_FUNC,
         "pixel_shuffle": BUILT_IN_FUNC,
         "pixel_unshuffle": BUILT_IN_FUNC,
+        "prelu": BUILT_IN_FUNC,
         "relu_": BUILT_IN_FUNC,
         "rrelu_": BUILT_IN_FUNC,
         "selu_": BUILT_IN_FUNC,
@@ -3516,10 +3522,8 @@ def tearDown(self):
         "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
         "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
         "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
-        "hardshrink": ARG_TYPE_MISMATCH,
         "layer_norm": ARG_TYPE_MISMATCH,
         "lp_pool1d": ARG_TYPE_MISMATCH,
-        "pairwise_distance": ARG_TYPE_MISMATCH,
 
         "affine_grid": CONTROL_FLOW,
         "alpha_dropout": CONTROL_FLOW,

test/test_fx_experimental.py
+4 -4

@@ -6,6 +6,7 @@
 import sys
 import unittest
 from typing import Callable, Dict, Union, List, Optional
+from types import BuiltinFunctionType
 
 import torch
 import torch.fx.experimental.optimization as optimization
@@ -1085,14 +1086,13 @@ def is_leaf_module(
             check = (node.op, node.target)
             excluded_nodes = {
                 ("placeholder", "x"),
-                ("call_function", torch.conv2d),
                 # Return type differs based on boolean dispatch :(
                 ("call_function", torch.nn.functional.max_pool2d),
-                ("call_function", operator.add),
-                ("call_function", torch.flatten),
                 ("output", "output"),
             }
-            self.assertIn(check, excluded_nodes)
+            # AnnotateTypesWithSchema doesn't work with bound C++ functions
+            if not isinstance(node.target, BuiltinFunctionType):
+                self.assertIn(check, excluded_nodes)
 
         # Smoke test torchscript compilation since now we're emitting type annotations
         torch.jit.script(traced_functionals_annotated)
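The new guard can be checked directly; a small hedged sketch (which functionals are builtins can vary by PyTorch version):

    from types import BuiltinFunctionType
    import torch.nn.functional as F

    # Bound C++ functions expose no Python signature for AnnotateTypesWithSchema
    # to inspect, hence the isinstance() guard added above.
    print(isinstance(F.linear, BuiltinFunctionType))   # True after this change
    print(isinstance(F.dropout, BuiltinFunctionType))  # False: still a Python function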

tools/pyi/gen_pyi.py
-1

@@ -95,7 +95,6 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
         'norm',
         'chain_matmul',
         'stft',
-        'istft',
         'tensordot',
         'split',
         'unique_consecutive',
