@@ -1412,18 +1412,14 @@ def _register_buffer_comm_hook(
         Args:
             state (Any): Optional state that is passed to the hook.
             hook (Callable): Callable with the following signature:
-                ``hook(state: object, buffers: Dict[str, torch.Tensor])
-                -> Optional[List[torch.futures.Future[torch.Tensor]]]``
+                ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``
             comm_hook_location (_BufferCommHookLocation): Enum value indicating
                 where to run the hook.
                 _BufferCommHookLocation.PRE_FORWARD means that the
                 hook will run _before_ the forward pass, and
                 _BufferCommHookLocation.POST_FORWARD means that the
                 hook will run _after_ the forward pass.

-        hook (Callable): Callable with the following signature:
-            ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``:
-
         NOTE: To maximize performance, users can return a
             List[torch.futures.Future] from their hook, and DDP will
             install and await these hooks appropriately at the end of
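The added line in this hunk documents the standard DDP communication-hook signature. As a rough illustration (not part of this diff), a hook with that signature might look like the sketch below; it mirrors the allreduce-averaging behavior of DDP's default gradient hook, and treating ``state`` as an optional process group is an assumption made only for this example.

import torch
import torch.distributed as dist


def allreduce_hook(
    state: object, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    # Interpreting ``state`` as an optional process group is an assumption
    # made only for this sketch.
    group = state if state is not None else dist.group.WORLD
    tensor = bucket.buffer()
    # Pre-divide so the completed allreduce yields the average across ranks.
    tensor.div_(dist.get_world_size(group))
    fut = dist.all_reduce(tensor, group=group, async_op=True).get_future()
    # The collective's future resolves to a list of tensors; return the first.
    return fut.then(lambda f: f.value()[0])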
@@ -1558,26 +1554,26 @@ def _register_fused_optim(
         self, optim: Type, *args, optim_params=None, **kwargs
     ):
         r"""
-        Registers an optimizer with DDP such that the optimization for a
-        parameter will run immediately when that parameter's gradient is
-        finished with reduction, instead of waiting for all parameters'
-        gradients to finish reduction. This can result in a training speedup
-        depending on your workload since the optimizer can run while gradient
-        reduction for other parameters are still ongoing. In addition, this has
-        the potential to reduce peak memory consumption during training, as it
-        only needs to load the per-parameter optimizer states of a single
-        parameter at a time, instead of loading all per-parameter optimizer
-        states at once.
-
-        Args:
-            optim_cls (Type): a ``torch.optim.Optimizer`` class to be registered
-            as a fused optimizer.
-            *args (Sequence[Any]): Arguments to forward to `optim_cls`.
-            optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters
-            to optimize, similar to `params` argument of traditional `torch.optim`
-            Optimizers. If this is omitted, all DDP model parameters will be
-            optimized.
-            **kwargs: (Dict[str, Any]): Keyword arguments to forward to `optim_cls`.
+        Registers an optimizer with DDP such that the optimization for a
+        parameter will run immediately when that parameter's gradient is
+        finished with reduction, instead of waiting for all parameters'
+        gradients to finish reduction. This can result in a training speedup
+        depending on your workload since the optimizer can run while gradient
+        reduction for other parameters are still ongoing. In addition, this has
+        the potential to reduce peak memory consumption during training, as it
+        only needs to load the per-parameter optimizer states of a single
+        parameter at a time, instead of loading all per-parameter optimizer
+        states at once.
+
+        Args:
+            optim (Type): a ``torch.optim.Optimizer`` class to be registered
+            as a fused optimizer.
+            *args (Sequence[Any]): Arguments to forward to `optim`.
+            optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters
+            to optimize, similar to `params` argument of traditional `torch.optim`
+            Optimizers. If this is omitted, all DDP model parameters will be
+            optimized.
+            **kwargs: (Dict[str, Any]): Keyword arguments to forward to `optim`.

         .. warning ::
             _register_fused_optim should only be called once on a DDP instance,
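For context, a hedged usage sketch of the private API documented by this hunk is shown below; the toy module, device placement, process-group setup, and hyperparameters are illustrative assumptions rather than part of the change.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes the default process group was already initialized elsewhere,
# e.g. dist.init_process_group("nccl") under torchrun, with one GPU per process.
model = torch.nn.Linear(16, 4).to("cuda")
ddp_model = DDP(model)

# Positional args (here the learning rate) and keyword args are forwarded to
# the optimizer constructor; omitting ``optim_params`` optimizes every DDP
# parameter as soon as its gradient finishes reduction.
ddp_model._register_fused_optim(torch.optim.SGD, 0.01, momentum=0.9)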