
Commit 9af18d8

colesbury authored and facebook-github-bot committed
Fix accesses to uninitialized memory when running sum() within an OMP parallel region (pytorch#13274)
Summary:
```
The two_pass_reduction code allocates a buffer of size at::get_max_threads().
When called within a parallel region, at::parallel_for only uses one thread,
so part of this buffer is never written. This makes two changes:

1) two_pass_reduction is not called when already in a parallel region
2) two_pass_reduction fills unwritten buffer elements with the identity
   (the value in dst)
```

cc SsnL: I think this should fix the NaNs in BatchNorm when calling sum() within a parallel region.

Pull Request resolved: pytorch#13274
Differential Revision: D12840034
Pulled By: colesbury
fbshipit-source-id: d32e80909a98a0f1bb1c80689fe5089b7019ef59
1 parent f04a705 · commit 9af18d8
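To make the summary concrete: the hazard is a per-thread partial-result buffer sized for at::get_max_threads() that, inside an enclosing parallel region, gets written by only one thread, after which the second pass reads every slot. Below is a minimal standalone OpenMP sketch of that failure mode; the names (`two_pass_sum`, `partial`) are hypothetical and this is not the ATen code. It only misbehaves on machines where more than one OpenMP thread is available.

```cpp
// Standalone sketch of the bug, NOT the ATen implementation.
// Build (assumption): g++ -fopenmp sketch.cpp
#include <omp.h>
#include <cstdio>
#include <memory>
#include <vector>

double two_pass_sum(const double* data, int n) {
  int max_threads = omp_get_max_threads();
  // Like at::empty(), this buffer is NOT zero-initialized.
  std::unique_ptr<double[]> partial(new double[max_threads]);

  #pragma omp parallel
  {
    // When nested inside an enclosing parallel region, this team runs with
    // a single thread, so only partial[0] is written; the remaining slots
    // keep whatever garbage the allocation contained.
    int t = omp_get_thread_num();
    int nt = omp_get_num_threads();
    double sum = 0.0;
    for (int i = t; i < n; i += nt) sum += data[i];
    partial[t] = sum;
  }

  // Second pass reduces over ALL max_threads slots, including unwritten ones.
  double total = 0.0;
  for (int t = 0; t < max_threads; t++) total += partial[t];
  return total;
}

int main() {
  std::vector<double> data(1 << 20, 1.0);
  #pragma omp parallel num_threads(4)
  {
    #pragma omp single
    printf("nested sum = %f (expected %d)\n",
           two_pass_sum(data.data(), (int)data.size()), 1 << 20);
  }
  return 0;
}
```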

File tree

3 files changed: +40 −6 lines

- aten/src/ATen/Parallel.h
- aten/src/ATen/native/TensorIteratorReduce.cpp
- aten/src/ATen/test/test_parallel.cpp

aten/src/ATen/Parallel.h

Lines changed: 8 additions & 0 deletions

```diff
@@ -36,6 +36,14 @@ inline int get_thread_num() {
 #endif
 }
 
+inline bool in_parallel_region() {
+#ifdef _OPENMP
+  return omp_in_parallel();
+#else
+  return false;
+#endif
+}
+
 template <class F>
 inline void parallel_for(
     const int64_t begin,
```
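The new helper maps directly onto omp_in_parallel(), which reports whether the caller is enclosed by an active (multi-threaded) parallel region; without OpenMP it conservatively reports false. A small illustration of the semantics (my example, not part of the patch):

```cpp
#include <omp.h>
#include <cstdio>

int main() {
  printf("outside: %d\n", omp_in_parallel());    // 0: no enclosing region
  #pragma omp parallel num_threads(2)
  {
    #pragma omp single
    printf("inside:  %d\n", omp_in_parallel());  // 1: enclosing region is active
  }
  return 0;
}
```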

aten/src/ATen/native/TensorIteratorReduce.cpp

Lines changed: 18 additions & 4 deletions

```diff
@@ -1,5 +1,7 @@
 #include <ATen/native/TensorIterator.h>
 #include <ATen/Parallel.h>
+#include <algorithm>
+#include <memory>
 
 /// Contains the implementation of parallel reductions in TensorIterator.
 
@@ -14,7 +16,7 @@ static void parallel_dim_reduction(TensorIterator& iter, const loop2d_t& loop);
 void TensorIterator::parallel_reduce(const loop2d_t& loop) {
   AT_CHECK(ntensors() == 2, "parallel_reduce only supports one input and one output");
   int64_t numel = this->numel();
-  if (numel < at::internal::GRAIN_SIZE || at::get_max_threads() == 1) {
+  if (numel < at::internal::GRAIN_SIZE || at::get_max_threads() == 1 || at::in_parallel_region()) {
     serial_for_each(loop, {0, numel});
   } else if (use_two_pass_reduction(*this)) {
     two_pass_reduction(*this, loop);
@@ -28,21 +30,33 @@ static bool use_two_pass_reduction(TensorIterator& iter) {
 }
 
 static void two_pass_reduction(TensorIterator& iter, const loop2d_t& loop) {
-  int num_threads = at::get_max_threads();
+  int max_threads = at::get_max_threads();
 
   auto& dst = iter.tensor(0);
   auto buffer_shape = DimVector(dst.sizes());
-  buffer_shape.insert(buffer_shape.begin(), num_threads);
+  buffer_shape.insert(buffer_shape.begin(), max_threads);
   auto buffer = at::empty(buffer_shape, dst.type());
 
+  std::unique_ptr<bool[]> written(new bool[max_threads]);
+  std::fill(written.get(), written.get() + max_threads, false);
+
   at::parallel_for(0, iter.numel(), internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
-    auto slice = buffer[at::get_thread_num()];
+    int thread_num = at::get_thread_num();
+    written[thread_num] = true;
+    auto slice = buffer[thread_num];
     slice.copy_(dst);
 
     auto sub_iter = TensorIterator::reduce_op(slice, iter.tensor(1));
     sub_iter->serial_for_each(loop, {begin, end});
   });
 
+  // fill any unwritten slices of the buffer with the identity
+  for (int thread_num = 0; thread_num < max_threads; thread_num++) {
+    if (!written[thread_num]) {
+      buffer[thread_num].copy_(dst);
+    }
+  }
+
   auto unsqueezed = dst.unsqueeze(0);
   auto final_reduce = TensorIterator::reduce_op(unsqueezed, buffer);
   final_reduce->for_each(loop);
```
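Both mitigations in one picture: the serial fallback below mirrors the new at::in_parallel_region() check in parallel_reduce, and the written[] bookkeeping plus identity fill mirrors the two_pass_reduction change (copying dst into unwritten slices works because dst still holds the reduction's identity value at that point). This is a hedged standalone sketch of the pattern with hypothetical names, not the ATen code:

```cpp
#include <omp.h>
#include <algorithm>
#include <memory>

double two_pass_sum_fixed(const double* data, int n, double identity = 0.0) {
  // Mitigation 1: fall back to a serial loop when already in a parallel
  // region, instead of running a degenerate one-thread two-pass reduction.
  if (omp_in_parallel()) {
    double sum = identity;
    for (int i = 0; i < n; i++) sum += data[i];
    return sum;
  }

  int max_threads = omp_get_max_threads();
  std::unique_ptr<double[]> partial(new double[max_threads]);
  std::unique_ptr<bool[]> written(new bool[max_threads]);
  std::fill(written.get(), written.get() + max_threads, false);

  #pragma omp parallel
  {
    int t = omp_get_thread_num();
    int nt = omp_get_num_threads();
    written[t] = true;  // Mitigation 2: record which slots were written...
    double sum = identity;
    for (int i = t; i < n; i += nt) sum += data[i];
    partial[t] = sum;
  }

  // ...and fill the rest with the identity so the final pass never reads
  // uninitialized memory, even if fewer than max_threads threads ran.
  for (int t = 0; t < max_threads; t++) {
    if (!written[t]) partial[t] = identity;
  }

  double total = identity;
  for (int t = 0; t < max_threads; t++) total += partial[t];
  return total;
}
```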

aten/src/ATen/test/test_parallel.cpp

Lines changed: 14 additions & 2 deletions

```diff
@@ -1,7 +1,8 @@
 #include "gtest/gtest.h"
 
-#include "ATen/ATen.h"
-#include "ATen/DLConvertor.h"
+#include <ATen/ATen.h>
+#include <ATen/DLConvertor.h>
+#include <ATen/Parallel.h>
 
 #include <iostream>
 #include <string.h>
@@ -24,3 +25,14 @@ TEST(TestParallel, TestParallel) {
   as[2] = 0;
   ASSERT_TRUE(a.sum(0).equal(as));
 }
+
+TEST(TestParallel, NestedParallel) {
+  Tensor a = ones({1024, 1024});
+  auto expected = a.sum();
+  // check that calling sum() from within a parallel block computes the same result
+  at::parallel_for(0, 10, 1, [&](int64_t begin, int64_t end) {
+    if (begin == 0) {
+      ASSERT_TRUE(a.sum().equal(expected));
+    }
+  });
+}
```
