
Commit bff64e8

rohan-varma authored and facebook-github-bot committed
[DDP] Track models with sync bn (pytorch#66680)
Summary:
Pull Request resolved: pytorch#66680
Closes pytorch#66215.

Tracks models that use SyncBatchNorm so we can find the workflows that rely on them and target them for perf optimization.

ghstack-source-id: 140875182

Test Plan: CI

Reviewed By: pritamdamania87

Differential Revision: D31679477

fbshipit-source-id: 0e68cd1a7aabbc5b26227895c53d33b8e98bfb8e
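The new "has_sync_bn" field lands in the DDP logging data alongside the other construction-time fields. A minimal sketch of reading it back, assuming an already constructed DistributedDataParallel instance and using the same internal _get_ddp_logging_data() call that the new test below relies on (the helper name has_sync_bn_logged is illustrative, not part of this commit):

def has_sync_bn_logged(ddp_model):
    # Illustrative sketch, not part of this commit: ddp_model is assumed to be a
    # torch.nn.parallel.DistributedDataParallel instance built on an initialized
    # process group. _get_ddp_logging_data() is the internal accessor used by the
    # test added in this diff; "has_sync_bn" is truthy when the wrapped module
    # contains a torch.nn.SyncBatchNorm layer.
    return bool(ddp_model._get_ddp_logging_data().get("has_sync_bn", False))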
1 parent e0643fa commit bff64e8

6 files changed, 40 additions (+) and 2 deletions (−)

torch/_C/_distributed_c10d.pyi

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ class Logger:
         device_ids: List[int],
         output_device: int,
         broadcast_buffers: bool,
+        has_sync_bn: bool,
     ): ...
     ...

torch/csrc/distributed/c10d/init.cpp

Lines changed: 1 addition & 0 deletions
@@ -454,6 +454,7 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
           py::arg("device_ids"),
           py::arg("output_device"),
           py::arg("broadcast_buffers"),
+          py::arg("has_sync_bn"),
           py::call_guard<py::gil_scoped_release>())
       .def(
           "set_runtime_stats_and_log",

torch/csrc/distributed/c10d/logger.cpp

Lines changed: 3 additions & 1 deletion
@@ -155,7 +155,8 @@ void Logger::set_construction_data_and_log(
     const std::string& module_name,
     const std::vector<int>& device_ids,
     int output_device,
-    bool broadcast_buffers) {
+    bool broadcast_buffers,
+    bool has_sync_bn) {
   // No lock is needed, as it will be called in DistributedDataParallel
   // constructor.
   ddp_logging_data_->strs_map["module_name"] = module_name;
@@ -182,6 +183,7 @@ void Logger::set_construction_data_and_log(
   ddp_logging_data_->strs_map["device_ids"] = c10::Join(", ", device_ids);
   ddp_logging_data_->ints_map["output_device"] = output_device;
   ddp_logging_data_->ints_map["broadcast_buffers"] = broadcast_buffers;
+  ddp_logging_data_->ints_map["has_sync_bn"] = has_sync_bn;
   ddp_logging_data_->ints_map["bucket_cap_bytes"] = reducer_->bucket_bytes_cap_;
   ddp_logging_data_->ints_map["find_unused_parameters"] =
       reducer_->find_unused_parameters_;

torch/csrc/distributed/c10d/logger.hpp

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,8 @@ class TORCH_API Logger {
       const std::string& module_name,
       const std::vector<int>& device_ids,
       int output_device,
-      bool broadcast_buffers);
+      bool broadcast_buffers,
+      bool has_sync_bn);
 
   void set_static_graph();
 

torch/nn/parallel/distributed.py

Lines changed: 7 additions & 0 deletions
@@ -674,12 +674,19 @@ def _ddp_init_helper(
         # logger and reducer.
         self.reducer.set_logger(self.logger)
 
+        has_sync_bn = False
+        for submodule in self.module.modules():
+            if isinstance(submodule, torch.nn.SyncBatchNorm):
+                has_sync_bn = True
+                break
+
         # Set logging data that can be got during construction time.
         self.logger.set_construction_data_and_log(
             self.module.__class__.__name__,
             [] if self.device_ids is None else self.device_ids,
             -1 if self.output_device is None else self.output_device,
             self.broadcast_buffers,
+            has_sync_bn
         )
 
         # passing a handle to torch.nn.SyncBatchNorm layer
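For reference, the detection added to _ddp_init_helper above amounts to a single walk over the wrapped module's submodules; an equivalent standalone sketch is shown here (where model stands in for any torch.nn.Module and is not a name from this diff):

# Illustrative sketch, equivalent to the loop added in _ddp_init_helper above;
# "model" is a placeholder for any torch.nn.Module.
has_sync_bn = any(
    isinstance(submodule, torch.nn.SyncBatchNorm)
    for submodule in model.modules()
)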

torch/testing/_internal/distributed/distributed_test.py

Lines changed: 26 additions & 0 deletions
@@ -8263,6 +8263,32 @@ def forward(self, x):
             for buf in bufs[1:]:
                 self.assertEqual(rank_0_buf, buf)
 
+    @skip_if_lt_x_gpu(2)
+    @sandcastle_skip_if(
+        BACKEND != "nccl" and BACKEND != "gloo",
+        "Only Nccl & Gloo backend support DistributedDataParallel",
+    )
+    def test_sync_bn_logged(self):
+        model = BN_NET
+        rank = self.rank
+        # single gpu training setup
+        model_gpu = model.cuda(rank)
+        no_sync_bn = torch.nn.parallel.DistributedDataParallel(
+            copy.deepcopy(model_gpu),
+            device_ids=[self.rank],
+        )
+        ddp_logging_data = no_sync_bn._get_ddp_logging_data()
+        sync_bn_logged = ddp_logging_data.get("has_sync_bn", True)
+        self.assertFalse(sync_bn_logged)
+        model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(model_gpu)
+        model_DDP = torch.nn.parallel.DistributedDataParallel(
+            model_DDP,
+            device_ids=[self.rank],
+        )
+        ddp_logging_data = model_DDP._get_ddp_logging_data()
+        sync_bn_logged = ddp_logging_data.get("has_sync_bn", False)
+        self.assertTrue(sync_bn_logged)
+
     @skip_if_lt_x_gpu(2)
     @sandcastle_skip_if(
         BACKEND != "nccl" and BACKEND != "gloo",
