Commit 1884d7f

mrshenli authored and pytorchmergebot committed
Avoid CPU Sync in SyncBatchNorm When Capturing CUDA Graphs
We recently updated `SyncBatchNorm` to support empty input batches. The new code removes stats from ranks with empty inputs. However, this change breaks CUDA graph capture, because it forces a CPU sync. This commit uses `is_current_stream_capturing()` to guard the new code path, running it only when no CUDA graph is being captured. To support empty inputs while capturing CUDA graphs, we might need to update the CUDA kernels for `batch_norm_backward_elemt` and `batch_norm_gather_stats_with_counts`. See pytorch#78656.

Fixes pytorch#78549

Pull Request resolved: pytorch#78666
Approved by: https://github.com/albanD
1 parent 1eab34d commit 1884d7f
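
As background, here is a minimal sketch of the guard pattern the fix relies on (not from this commit; the `drop_empty_counts` helper, tensor shapes, and graph setup are illustrative assumptions). It uses `torch.cuda.is_current_stream_capturing()` to skip value-dependent indexing while a CUDA graph is being captured:

```python
# A minimal sketch, assuming a CUDA-capable PyTorch build. The helper
# drop_empty_counts is hypothetical; only the guard mirrors the commit.
import torch

def drop_empty_counts(count_all: torch.Tensor) -> torch.Tensor:
    # count_all[mask] produces a tensor whose shape depends on tensor
    # *values*, so PyTorch must copy the mask's true-count back to the
    # CPU -- a device-to-host sync that is illegal during graph capture.
    if not torch.cuda.is_current_stream_capturing():
        mask = count_all.squeeze(-1) >= 1
        count_all = count_all[mask]
    return count_all

counts = torch.ones(4, 1, device="cuda")   # static input for the graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):                  # guard takes the sync-free branch
    total = drop_empty_counts(counts).sum()
g.replay()                                 # replays with no CPU sync
```

Outside capture the masked path still runs, so ranks with empty inputs are filtered exactly as before.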

File tree

1 file changed (+13, −5 lines)


torch/nn/modules/_functions.py

Lines changed: 13 additions & 5 deletions
```diff
@@ -67,11 +67,19 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum,
         # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
         mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
 
-        # remove stats from empty inputs
-        mask = count_all.squeeze(-1) >= 1
-        count_all = count_all[mask]
-        mean_all = mean_all[mask]
-        invstd_all = invstd_all[mask]
+        if not torch.cuda.is_current_stream_capturing():
+            # The lines below force a synchronization between CUDA and CPU, because
+            # the shape of the result count_all depends on the values in mask tensor.
+            # Such synchronizations break CUDA Graph capturing.
+            # See https://github.com/pytorch/pytorch/issues/78549
+            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
+            # a better longer-term solution.
+
+            # remove stats from empty inputs
+            mask = count_all.squeeze(-1) >= 1
+            count_all = count_all[mask]
+            mean_all = mean_all[mask]
+            invstd_all = invstd_all[mask]
 
         # calculate global mean & invstd
         mean, invstd = torch.batch_norm_gather_stats_with_counts(
```
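
To see the sync the guarded lines would otherwise trigger, a small probe (illustration only, not part of the commit; the values are made up) can use `torch.cuda.set_sync_debug_mode` to flag implicit device-to-host synchronizations:

```python
# Probe: make the implicit sync from boolean-mask indexing visible.
# Requires a CUDA-capable PyTorch build.
import torch

count_all = torch.tensor([[2.0], [0.0], [3.0]], device="cuda")
torch.cuda.set_sync_debug_mode("warn")   # warn on synchronizing CUDA ops
mask = count_all.squeeze(-1) >= 1        # elementwise compare: stays on device
count_all = count_all[mask]              # warns: sizing the output syncs with CPU
torch.cuda.set_sync_debug_mode("default")
```

During graph capture the same indexing would fail outright, which is why the commit gates it rather than trying to make it capture-safe.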
