
Commit 740fb22

XuehaiPan authored and pytorchmergebot committed
[BE][Easy][4/19] enforce style for empty lines in import segments in functorch/ (pytorch#129755)
See pytorch#129751 (comment). Most changes are auto-generated by the linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: pytorch#129755
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#129752
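For context on what the auto-generated changes look like: the rule being enforced (my reading of the hunks below; the commit message does not spell it out) removes stray blank lines inside an import segment and separates the last import from the first module-level statement with exactly two blank lines. A minimal before/after sketch in Python, modeled on the functorch/benchmarks/operator_authoring.py hunk:

```python
# Before: a stray blank line splits the import segment, and only one
# blank line separates the imports from module-level code.
import torch

from functorch.compile import pointwise_operator

WRITE_CSV = False


# After: the import segment is contiguous, followed by exactly two
# blank lines before the first module-level statement.
import torch
from functorch.compile import pointwise_operator


WRITE_CSV = False
```

Because almost every hunk only adds or removes blank lines, the suggested `git diff --ignore-all-space --ignore-blank-lines HEAD~1` shrinks the review to the few genuinely moved lines (e.g., the relocated make_fx import in functorch/__init__.py).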
1 parent a085acd commit 740fb22

36 files changed: +48 −27 lines changed

functorch/__init__.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import torch
-
 from torch._functorch.deprecated import (
     combine_state_for_ensemble,
     functionalize,
@@ -26,13 +25,15 @@
     FunctionalModuleWithBuffers,
 )
 
+# Was never documented
+from torch._functorch.python_key import make_fx
+
+
 # Top-level APIs. Please think carefully before adding something to the
 # top-level namespace:
 # - private helper functions should go into torch._functorch
 # - very experimental things should go into functorch.experimental
 # - compilation related things should go into functorch.compile
 
-# Was never documented
-from torch._functorch.python_key import make_fx
 
 __version__ = torch.__version__
```

functorch/benchmarks/chrome_trace_parser.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,13 +1,13 @@
 #!/usr/bin/env python3
 import argparse
 import logging
-
 import os
 
 import pandas as pd
 
 from torch._functorch.benchmark_utils import compute_utilization
 
+
 # process the chrome traces output by the pytorch profiler
 # require the json input file's name to be in format {model_name}_chrome_trace_*.json
 # the runtimes file should have format (model_name, runtime)
```

functorch/benchmarks/cse.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -1,8 +1,6 @@
 import torch
 import torch.fx as fx
-
 from functorch import make_fx
-
 from torch._functorch.compile_utils import fx_graph_cse
 from torch.profiler import profile, ProfilerActivity
 
```

functorch/benchmarks/operator_authoring.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -5,9 +5,9 @@
 import pandas as pd
 
 import torch
-
 from functorch.compile import pointwise_operator
 
+
 WRITE_CSV = False
 CUDA = False
 SIZES = [1, 512, 8192]
```

functorch/benchmarks/per_sample_grads.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,9 +6,9 @@
 
 import torch
 import torch.nn as nn
-
 from functorch import grad, make_functional, vmap
 
+
 device = "cuda"
 batch_size = 128
 torch.manual_seed(0)
```

functorch/benchmarks/pointwise_scorecard.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,9 +4,9 @@
 import time
 
 import torch
-
 from functorch import pointwise_operator
 
+
 torch.set_num_threads(1)
 torch._C._debug_set_fusion_group_inlining(False)
 
```

functorch/benchmarks/process_scorecard.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,6 +1,7 @@
 import matplotlib.pyplot as plt
 import pandas
 
+
 df = pandas.read_csv("perf.csv")
 
 ops = pandas.unique(df["operator"])
```

functorch/dim/__init__.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -3,12 +3,13 @@
 from typing import Sequence, Union
 
 import functorch._C
-
 import torch
 from functorch._C import dim as _C
+
 from .tree_map import tree_flatten, tree_map
 from .wrap_type import wrap_type
 
+
 _C._patch_tensor_class()
 dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
 
@@ -23,6 +24,7 @@ class DimensionBindError(Exception):
 
 from . import op_properties
 
+
 # use dict to avoid writing C++ bindings for set
 pointwise = dict.fromkeys(op_properties.pointwise, True)
 
```

functorch/dim/batch_tensor.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -7,6 +7,7 @@
 
 from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
 
+
 _enabled = False
 
 
```

functorch/dim/dim.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -5,12 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 import dis
 import inspect
-
 from dataclasses import dataclass
 from typing import Union
 
 from . import DimList
 
+
 _vmap_levels = []
 
 
```

functorch/dim/op_properties.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 
+
 # pointwise operators can go through a faster pathway
 
 tensor_magic_methods = ["add", ""]
```

functorch/dim/reference.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -6,12 +6,13 @@
 
 # reference python implementations for C ops
 import torch
-
 from functorch._C import dim as _C
+
 from . import op_properties
 from .batch_tensor import _enable_layers
 from .tree_map import tree_flatten, tree_map
 
+
 DimList = _C.DimList
 import operator
 from functools import reduce
@@ -407,7 +408,6 @@ def t__getitem__(self, input):
     # (keep track of whether we have to call super)
     # * call super if needed
     # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
-
     # this handles bool indexing handling, as well as some other simple cases.
 
     is_simple = (
```

functorch/dim/tree_map.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -6,6 +6,7 @@
 
 from functorch._C import dim
 
+
 tree_flatten = dim.tree_flatten
 
 
```

functorch/dim/wrap_type.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -14,6 +14,7 @@
 
 from functorch._C import dim as _C
 
+
 _wrap_method = _C._wrap_method
 
 FUNC_TYPES = (
```

functorch/docs/source/conf.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -16,6 +16,7 @@
 
 import functorch
 
+
 # import sys
 
 # source code directory, relative to this file, for sphinx-autobuild
@@ -27,6 +28,7 @@
 
 import pytorch_sphinx_theme
 
+
 # -- General configuration ------------------------------------------------
 
 # Required version of sphinx is set from docs/requirements.txt
@@ -274,11 +276,11 @@ def setup(app):
 
 # -- A patch that prevents Sphinx from cross-referencing ivar tags -------
 # See http://stackoverflow.com/a/41184353/3343043
-
 from docutils import nodes
 from sphinx import addnodes
 from sphinx.util.docfields import TypedField
 
+
 # Without this, doctest adds any example with a `>>>` as a test
 doctest_test_doctest_blocks = ""
 doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS
```

functorch/einops/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,3 +1,4 @@
 from .rearrange import rearrange
 
+
 __all__ = ["rearrange"]
```

functorch/einops/_parsing.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -28,6 +28,7 @@
 import warnings
 from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
 
+
 _ellipsis: str = "\u2026"  # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
 
 
```

functorch/einops/rearrange.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -4,8 +4,8 @@
 from typing import Callable, Dict, List, Sequence, Tuple, Union
 
 import torch
-
 from functorch._C import dim as _C
+
 from ._parsing import (
     _ellipsis,
     AnonymousAxis,
@@ -14,6 +14,7 @@
     validate_rearrange_expressions,
 )
 
+
 __all__ = ["rearrange"]
 
 dims = _C.dims
```

functorch/examples/compilation/eager_fusion.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -2,9 +2,9 @@
 
 import torch
 import torch.utils
-
 from functorch.compile import aot_function, tvm_compile
 
+
 a = torch.randn(2000, 1, 4, requires_grad=True)
 b = torch.randn(1, 2000, 4)
 
```

functorch/examples/compilation/fuse_module.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -2,7 +2,6 @@
 
 import torch
 import torch.nn as nn
-
 from functorch.compile import compiled_module, tvm_compile
 
 
```

functorch/examples/compilation/linear_train.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -8,10 +8,10 @@
 
 import torch
 import torch.nn as nn
-
 from functorch import make_functional
 from functorch.compile import nnc_jit
 
+
 torch._C._jit_override_can_fuse_on_cpu(True)
 
 
```

functorch/examples/compilation/simple_function.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -7,7 +7,6 @@
 import time
 
 import torch
-
 from functorch import grad, make_fx
 from functorch.compile import nnc_jit
 
```

functorch/examples/dp_cifar10/cifar10_transforms.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -21,9 +21,9 @@
 import torch.nn as nn
 import torch.optim as optim
 import torch.utils.data
-
 from torch.func import functional_call, grad_and_value, vmap
 
+
 logging.basicConfig(
     format="%(asctime)s:%(levelname)s:%(message)s",
     datefmt="%m/%d/%Y %H:%M:%S",
```

functorch/examples/ensembling/parallel_train.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -6,6 +6,7 @@
 import torch.nn.functional as F
 from torch.func import functional_call, grad_and_value, stack_module_state, vmap
 
+
 # Adapted from http://willwhitney.com/parallel-training-jax.html , which is a
 # tutorial on Model Ensembling with JAX by Will Whitney.
 #
```

functorch/examples/lennard_jones/lennard_jones.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -7,6 +7,7 @@
 from torch.func import jacrev, vmap
 from torch.nn.functional import mse_loss
 
+
 sigma = 0.5
 epsilon = 4.0
 
```

functorch/examples/maml_omniglot/maml-omniglot-higher.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -34,7 +34,6 @@
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
-
 import pandas as pd
 from support.omniglot_loaders import OmniglotNShot
 
@@ -43,6 +42,7 @@
 import torch.optim as optim
 from torch import nn
 
+
 mpl.use("Agg")
 plt.style.use("bmh")
 
```

functorch/examples/maml_omniglot/maml-omniglot-ptonly.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -33,17 +33,16 @@
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
-
 import pandas as pd
 from support.omniglot_loaders import OmniglotNShot
 
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-
 from functorch import make_functional_with_buffers
 from torch import nn
 
+
 mpl.use("Agg")
 plt.style.use("bmh")
 
```

functorch/examples/maml_omniglot/maml-omniglot-transforms.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -34,7 +34,6 @@
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
-
 import pandas as pd
 from support.omniglot_loaders import OmniglotNShot
 
@@ -44,6 +43,7 @@
 from torch import nn
 from torch.func import functional_call, grad, vmap
 
+
 mpl.use("Agg")
 plt.style.use("bmh")
 
```

functorch/examples/maml_regression/evjang.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -11,6 +11,7 @@
 import torch
 from torch.nn import functional as F
 
+
 mpl.use("Agg")
 
 
```

functorch/examples/maml_regression/evjang_transforms.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -12,6 +12,7 @@
 from torch.func import grad, vmap
 from torch.nn import functional as F
 
+
 mpl.use("Agg")
 
 
```

functorch/examples/maml_regression/evjang_transforms_module.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -9,11 +9,11 @@
 import numpy as np
 
 import torch
-
 from functorch import grad, make_functional, vmap
 from torch import nn
 from torch.nn import functional as F
 
+
 mpl.use("Agg")
 
 
```

functorch/experimental/control_flow.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -1,6 +1,5 @@
 from torch import cond  # noqa: F401
 from torch._higher_order_ops.cond import UnsupportedAliasMutationException  # noqa: F401
-
 from torch._higher_order_ops.map import (  # noqa: F401
     _stack_pytree,
     _unstack_pytree,
```

0 commit comments
