Skip to content

Commit

Permalink
Engines bn fix (#1310)
Browse files Browse the repository at this point in the history
* fix

* fix

* version update
  • Loading branch information
Scitator authored Sep 30, 2021
1 parent bd3a0d7 commit 3f18d6b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion catalyst/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "21.09rc1"
__version__ = "21.09"
9 changes: 4 additions & 5 deletions catalyst/engines/apex.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ class DistributedDataParallelAPEXEngine(DistributedDataParallelEngine):
address: address to use for backend.
port: port to use for backend.
sync_bn: boolean flag for batchnorm synchronization during distributed training.
if True, applies PyTorch `convert_sync_batchnorm`_ to the model for native torch
if True, applies Apex `convert_syncbn_model`_ to the model for native torch
distributed only. Default, False.
ddp_kwargs: parameters for `apex.parallel.DistributedDataParallel`.
More info here:
Expand Down Expand Up @@ -439,9 +439,8 @@ def get_engine(self):
stages:
...
.. _convert_sync_batchnorm:
https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#
torch.nn.SyncBatchNorm.convert_sync_batchnorm
.. _`convert_syncbn_model`:
https://nvidia.github.io/apex/parallel.html#apex.parallel.convert_syncbn_model
"""

def __init__(
Expand Down Expand Up @@ -501,7 +500,7 @@ def init_components(
model = model_fn()
model = self.sync_device(model)
if self._sync_bn:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = apex.parallel.convert_syncbn_model(model)

criterion = criterion_fn()
criterion = self.sync_device(criterion)
Expand Down
2 changes: 2 additions & 0 deletions catalyst/engines/fairscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def __init__(
self,
address: str = None,
port: Union[str, int] = None,
sync_bn: bool = False,
ddp_kwargs: Dict[str, Any] = None,
process_group_kwargs: Dict[str, Any] = None,
scaler_kwargs: Dict[str, Any] = None,
Expand All @@ -348,6 +349,7 @@ def __init__(
super().__init__(
address=address,
port=port,
sync_bn=sync_bn,
ddp_kwargs=ddp_kwargs,
process_group_kwargs=process_group_kwargs,
)
Expand Down

0 comments on commit 3f18d6b

Please sign in to comment.