6 files changed: 40 additions, 2 deletions

@@ -48,6 +48,7 @@ class Logger:
         device_ids: List[int],
         output_device: int,
         broadcast_buffers: bool,
+        has_sync_bn: bool,
     ): ...
     ...
@@ -454,6 +454,7 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.
          py::arg("device_ids"),
          py::arg("output_device"),
          py::arg("broadcast_buffers"),
+          py::arg("has_sync_bn"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_runtime_stats_and_log",
@@ -155,7 +155,8 @@ void Logger::set_construction_data_and_log(
     const std::string& module_name,
     const std::vector<int>& device_ids,
     int output_device,
-    bool broadcast_buffers) {
+    bool broadcast_buffers,
+    bool has_sync_bn) {
   // No lock is needed, as it will be called in DistributedDataParallel
   // constructor.
   ddp_logging_data_->strs_map["module_name"] = module_name;
@@ -182,6 +183,7 @@ void Logger::set_construction_data_and_log(
   ddp_logging_data_->strs_map["device_ids"] = c10::Join(", ", device_ids);
   ddp_logging_data_->ints_map["output_device"] = output_device;
   ddp_logging_data_->ints_map["broadcast_buffers"] = broadcast_buffers;
+  ddp_logging_data_->ints_map["has_sync_bn"] = has_sync_bn;
   ddp_logging_data_->ints_map["bucket_cap_bytes"] = reducer_->bucket_bytes_cap_;
   ddp_logging_data_->ints_map["find_unused_parameters"] =
       reducer_->find_unused_parameters_;
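
Note (not part of the patch): because the flag is stored in ints_map next to the other construction-time booleans, it presumably comes back to Python as a 0/1 integer once _get_ddp_logging_data() merges the maps into a plain dict, which is why the new test below can read it with dict.get and assert on truthiness. A tiny sketch of that assumption, with a hypothetical stand-in for the merged dict:

# Hypothetical merged logging dict; the real one comes from _get_ddp_logging_data().
ddp_logging_data = {"broadcast_buffers": 1, "has_sync_bn": 0}

# Truthiness checks work directly on the stored 0/1 values.
assert not ddp_logging_data.get("has_sync_bn", True)
assert ddp_logging_data.get("broadcast_buffers", 0)
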
@@ -12,7 +12,8 @@ class TORCH_API Logger {
       const std::string& module_name,
       const std::vector<int>& device_ids,
       int output_device,
-      bool broadcast_buffers);
+      bool broadcast_buffers,
+      bool has_sync_bn);

   void set_static_graph();

@@ -674,12 +674,19 @@ def _ddp_init_helper(
         # logger and reducer.
         self.reducer.set_logger(self.logger)

+        has_sync_bn = False
+        for submodule in self.module.modules():
+            if isinstance(submodule, torch.nn.SyncBatchNorm):
+                has_sync_bn = True
+                break
+
         # Set logging data that can be got during construction time.
         self.logger.set_construction_data_and_log(
             self.module.__class__.__name__,
             [] if self.device_ids is None else self.device_ids,
             -1 if self.output_device is None else self.output_device,
             self.broadcast_buffers,
+            has_sync_bn
         )

         # passing a handle to torch.nn.SyncBatchNorm layer
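
Aside (illustration only, not part of the patch): the detection is a linear scan over self.module.modules() that stops at the first torch.nn.SyncBatchNorm instance. An equivalent short-circuiting form with any(), using a made-up helper name module_has_sync_bn:

import copy
import torch.nn as nn

def module_has_sync_bn(module: nn.Module) -> bool:
    # any() stops at the first SyncBatchNorm, matching the explicit
    # for/break loop added to _ddp_init_helper above.
    return any(isinstance(m, nn.SyncBatchNorm) for m in module.modules())

# convert_sync_batchnorm swaps the BatchNorm children of the module it is given,
# so convert a deep copy to keep the original unchanged (the new test below
# also deep-copies before wrapping).
plain = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
converted = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(plain))
print(module_has_sync_bn(plain), module_has_sync_bn(converted))  # False True

The patch keeps the explicit loop; either form yields the same logged value.
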
@@ -8263,6 +8263,32 @@ def forward(self, x):
             for buf in bufs[1:]:
                 self.assertEqual(rank_0_buf, buf)

+    @skip_if_lt_x_gpu(2)
+    @sandcastle_skip_if(
+        BACKEND != "nccl" and BACKEND != "gloo",
+        "Only Nccl & Gloo backend support DistributedDataParallel",
+    )
+    def test_sync_bn_logged(self):
+        model = BN_NET
+        rank = self.rank
+        # single gpu training setup
+        model_gpu = model.cuda(rank)
+        no_sync_bn = torch.nn.parallel.DistributedDataParallel(
+            copy.deepcopy(model_gpu),
+            device_ids=[self.rank],
+        )
+        ddp_logging_data = no_sync_bn._get_ddp_logging_data()
+        sync_bn_logged = ddp_logging_data.get("has_sync_bn", True)
+        self.assertFalse(sync_bn_logged)
+        model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(model_gpu)
+        model_DDP = torch.nn.parallel.DistributedDataParallel(
+            model_DDP,
+            device_ids=[self.rank],
+        )
+        ddp_logging_data = model_DDP._get_ddp_logging_data()
+        sync_bn_logged = ddp_logging_data.get("has_sync_bn", False)
+        self.assertTrue(sync_bn_logged)
+
     @skip_if_lt_x_gpu(2)
     @sandcastle_skip_if(
         BACKEND != "nccl" and BACKEND != "gloo",
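
Usage note (a minimal sketch, not part of the patch): outside the test suite the field can be read back the same way the test does, through the private _get_ddp_logging_data() helper. This assumes one CUDA device, the gloo backend, a free port 29500, and a toy placeholder model:

import os
import torch
import torch.distributed as dist
import torch.nn as nn

# Single-process group just so DistributedDataParallel can be constructed.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group("gloo", rank=0, world_size=1)

# Toy model with a BatchNorm layer, converted so DDP sees a SyncBatchNorm.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda(0)
ddp = torch.nn.parallel.DistributedDataParallel(model, device_ids=[0])

# With this change, the construction-time logging data should include the flag.
print(ddp._get_ddp_logging_data().get("has_sync_bn"))

dist.destroy_process_group()
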