
Commit 9ef1d55

kit1980 authored and pytorchmergebot committed
Fix non-existing parameters in docstrings in torch/nn (pytorch#90596)
This is a continuation of pytorch#90505

Pull Request resolved: pytorch#90596
Approved by: https://github.com/lezcano
1 parent 45109ec commit 9ef1d55

File tree

4 files changed: +27 -33 lines


torch/nn/parallel/distributed.py

+21-25
@@ -1412,18 +1412,14 @@ def _register_buffer_comm_hook(
         Args:
             state (Any): Optional state that is passed to the hook.
             hook (Callable): Callable with the following signature:
-                         ``hook(state: object, buffers: Dict[str, torch.Tensor])
-                         -> Optional[List[torch.futures.Future[torch.Tensor]]]``
+                         ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``
             comm_hook_location (_BufferCommHookLocation): Enum value indicating
                             where to run the hook.
                             _BufferCommHookLocation.PRE_FORWARD means that the
                             hook will run _before_ the forward pass, and
                             _BufferCommHookLocation.POST_FORWARD means that the
                             hook will run _after_ the forward pass.
 
-        hook (Callable): Callable with the following signature:
-                        ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``:
-
         NOTE: To maximize performance, users can return a
             List[torch.futures.Future] from their hook, and DDP will
             install and await these hooks appropriately at the end of
@@ -1558,26 +1554,26 @@ def _register_fused_optim(
         self, optim: Type, *args, optim_params=None, **kwargs
     ):
         r"""
-        Registers an optimizer with DDP such that the optimization for a
-        parameter will run immediately when that parameter's gradient is
-        finished with reduction, instead of waiting for all parameters'
-        gradients to finish reduction. This can result in a training speedup
-        depending on your workload since the optimizer can run while gradient
-        reduction for other parameters are still ongoing. In addition, this has
-        the potential to reduce peak memory consumption during training, as it
-        only needs to load the per-parameter optimizer states of a single
-        parameter at a time, instead of loading all per-parameter optimizer
-        states at once.
-
-        Args:
-            optim_cls (Type): a ``torch.optim.Optimizer`` class to be registered
-            as a fused optimizer.
-            *args (Sequence[Any]): Arguments to forward to `optim_cls`.
-            optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters
-            to optimize, similar to `params` argument of traditional `torch.optim`
-            Optimizers. If this is omitted, all DDP model parameters will be
-            optimized.
-            **kwargs: (Dict[str, Any]): Keyword arguments to forward to `optim_cls`.
+        Registers an optimizer with DDP such that the optimization for a
+        parameter will run immediately when that parameter's gradient is
+        finished with reduction, instead of waiting for all parameters'
+        gradients to finish reduction. This can result in a training speedup
+        depending on your workload since the optimizer can run while gradient
+        reduction for other parameters are still ongoing. In addition, this has
+        the potential to reduce peak memory consumption during training, as it
+        only needs to load the per-parameter optimizer states of a single
+        parameter at a time, instead of loading all per-parameter optimizer
+        states at once.
+
+        Args:
+            optim (Type): a ``torch.optim.Optimizer`` class to be registered
+            as a fused optimizer.
+            *args (Sequence[Any]): Arguments to forward to `optim`.
+            optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters
+            to optimize, similar to `params` argument of traditional `torch.optim`
+            Optimizers. If this is omitted, all DDP model parameters will be
+            optimized.
+            **kwargs: (Dict[str, Any]): Keyword arguments to forward to `optim`.
 
         .. warning ::
             _register_fused_optim should only be called once on a DDP instance,
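Note (not part of the commit): a minimal usage sketch of the documented signature, showing why the corrected parameter name is ``optim``. It assumes a process group has already been initialized (e.g. via torchrun); the model and hyperparameters are arbitrary placeholders, and ``_register_fused_optim`` is a private API used here exactly as its docstring describes.

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def build_ddp_with_fused_optim():
    # Assumes torch.distributed.init_process_group(...) has already run.
    model = nn.Linear(16, 4)
    ddp = DDP(model)
    # Register torch.optim.SGD as a fused optimizer: the positional 1e-2 (lr)
    # and the momentum keyword are forwarded to the optimizer constructor, and
    # each parameter's step runs as soon as its gradient reduction finishes.
    ddp._register_fused_optim(torch.optim.SGD, 1e-2, momentum=0.9)
    return ddp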

torch/nn/utils/_expanded_weights/expanded_weights_utils.py

+2-4
@@ -24,12 +24,10 @@ def forward_helper(func, expanded_args, expanded_kwargs):
 
     Args:
         func: The function to be called
-        ctx: The context from the autograd.Function object. Will be used to save
-          computed state from the forward pass
         expanded_args: Arguments to be passed to :attr:`func`. Will include arguments
           that need to be unpacked because they are ExpandedWeights
-        num_true_outs: The number of outputs seen by the user since some functions
-          return auxillary data that is only used in the backward pass
+        expanded_kwargs: Keyword arguments to be passed to :attr:`func`.
+          Similar to :attr:`expanded_args`.
     '''
     unexpanded_args, unexpanded_kwargs = _check_and_unexpand_args(func, expanded_args, expanded_kwargs)
     return func(*unexpanded_args, **unexpanded_kwargs)
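Note (not part of the commit): a rough conceptual sketch of what ``forward_helper`` does with ``expanded_args`` and ``expanded_kwargs``. The ``Wrapped`` class and ``unexpand`` helper are hypothetical stand-ins for ExpandedWeights and ``_check_and_unexpand_args``, not the real internals.

import torch

class Wrapped:
    # Hypothetical stand-in for an ExpandedWeights argument.
    def __init__(self, tensor):
        self.tensor = tensor

def unexpand(args, kwargs):
    # Swap each wrapped argument for the plain tensor it carries.
    new_args = tuple(a.tensor if isinstance(a, Wrapped) else a for a in args)
    new_kwargs = {k: v.tensor if isinstance(v, Wrapped) else v for k, v in kwargs.items()}
    return new_args, new_kwargs

def forward_helper_sketch(func, expanded_args, expanded_kwargs):
    # Mirrors the docstring: unpack the wrapped arguments, then call func on
    # the plain tensors.
    unexpanded_args, unexpanded_kwargs = unexpand(expanded_args, expanded_kwargs)
    return func(*unexpanded_args, **unexpanded_kwargs)

out = forward_helper_sketch(
    torch.nn.functional.linear,
    (torch.randn(2, 3), Wrapped(torch.randn(4, 3))),
    {"bias": None},
)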

torch/nn/utils/memory_format.py

+2-2
@@ -41,9 +41,9 @@ def convert_conv2d_weight_memory_format(module, memory_format):
     immediately before a convolution.
 
     Args:
-        module (nn.Module): ``nn.Conv2d`` & ``nn.ConvTranspose2d`` or container
+        module (nn.Module): ``nn.Conv2d`` & ``nn.ConvTranspose2d`` or container
             ``nn.Module``
-        format: user specified ``memory_format``,
+        memory_format: user specified ``memory_format``,
             e.g. ``torch.channels_last`` or ``torch.contiguous_format``
 
     Returns:
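Note (not part of the commit): a short usage sketch of the corrected ``memory_format`` parameter name. ``convert_conv2d_weight_memory_format`` is the public helper in ``torch.nn.utils``; the toy model below is only for illustration.

import torch
import torch.nn as nn
from torch.nn.utils import convert_conv2d_weight_memory_format

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
# Passing a container converts every nn.Conv2d weight inside it.
model = convert_conv2d_weight_memory_format(model, torch.channels_last)
x = torch.randn(1, 3, 32, 32).to(memory_format=torch.channels_last)
y = model(x)  # convolutions now operate on channels_last (NHWC) weights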

torch/nn/utils/parametrize.py

+2-2
@@ -120,7 +120,7 @@ def __init__(
         #     assert X.dtype == Z.dtype and X.shape == Z.shape
         #     # If it has one input, this allows to be able to use set_ to be able to
         #     # move data to/from the original tensor without changing its id (which is what the
-        #     # optimiser uses to track parameters)
+        #     # optimizer uses to track parameters)
         # if isinstance(Y, Tensor)
         #     assert X.dtype == Y.dtype
         # Below we use original = X, new = Y
@@ -591,7 +591,7 @@ def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool:
 
     Args:
         module (nn.Module): module to query
-        name (str, optional): attribute in the module to query
+        tensor_name (str, optional): attribute in the module to query
             Default: ``None``
     """
     parametrizations = getattr(module, "parametrizations", None)
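Note (not part of the commit): a brief sketch of ``is_parametrized`` with the corrected ``tensor_name`` argument; the ``Symmetric`` parametrization is an arbitrary example.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, X):
        # Build a symmetric matrix from the unconstrained weight.
        return X.triu() + X.triu(1).transpose(-1, -2)

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", Symmetric())

print(parametrize.is_parametrized(layer))                        # True: some tensor is parametrized
print(parametrize.is_parametrized(layer, tensor_name="weight"))  # True: "weight" specifically
print(parametrize.is_parametrized(layer, tensor_name="bias"))    # False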
