
Commit 95b1bc1

peterbell10 authored and facebook-github-bot committed
Migrate nonzero from TH to ATen (CPU) (pytorch#58811)
Summary: Closes pytorch gh-24745

The existing PR (pytorch gh-50655) has stalled because `TensorIterator` doesn't guarantee iteration order in the same way that `TH_TENSOR_APPLY` does. For contiguous test cases this isn't an issue, but it breaks down, for example, with channels-last format. I resolve this by adding a new `TensorIteratorConfig` parameter, `enforce_linear_iteration`, which disables dimension reordering. I've also added a test case for non-contiguous tensors to verify this works.

This PR also significantly improves performance by adding multithreading support to the algorithm. As part of this, I wrote a custom `count_nonzero` that gives per-thread counts, which is necessary to write the outputs in the right locations.

| Shape      |  Before | After (1 thread) | After (8 threads) |
|:----------:|--------:|-----------------:|------------------:|
| 256,128,32 | 2610 us |          2220 us |            496 us |
| 128,128,32 | 1250 us |           976 us |            175 us |
| 64,128,32  |  581 us |           486 us |             88 us |
| 32,128,32  |  292 us |           245 us |             80 us |
| 16,128,32  |  147 us |           120 us |             71 us |
| 8,128,32   |   75 us |            61 us |             61 us |
| 4,128,32   |   39 us |            32 us |             32 us |
| 2,128,32   |   20 us |            17 us |             17 us |
| 1,128,32   |   11 us |             9 us |              9 us |

Pull Request resolved: pytorch#58811

Reviewed By: anjali411

Differential Revision: D28700259

Pulled By: ngimel

fbshipit-source-id: 9b279ca7c36d8e348b7e5e4be0dd159e05aee159
1 parent 934f6dc commit 95b1bc1

19 files changed (+313, -327 lines)
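For context on the approach described in the summary, here is a minimal standalone sketch of the two-pass strategy: count nonzeros per chunk, take an exclusive prefix sum of the counts to get each thread's write offset, then let each thread fill a disjoint slice of the output. This is illustrative only (plain `std::thread` over a flat `float` array; `nonzero_indices` and all other names are invented here), not the ATen kernel added by this PR.

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <thread>
#include <vector>

// Returns the indices of the nonzero elements of `data`, in order.
// Assumes num_threads >= 1.
std::vector<int64_t> nonzero_indices(const std::vector<float>& data, int num_threads) {
  const int64_t n = static_cast<int64_t>(data.size());
  const int64_t chunk = (n + num_threads - 1) / num_threads;

  // Pass 1: per-thread (per-chunk) nonzero counts.
  std::vector<int64_t> counts(num_threads, 0);
  auto count_fn = [&](int t) {
    const int64_t begin = t * chunk, end = std::min(n, begin + chunk);
    for (int64_t i = begin; i < end; ++i) {
      counts[t] += (data[i] != 0.0f);
    }
  };
  std::vector<std::thread> workers;
  for (int t = 0; t < num_threads; ++t) workers.emplace_back(count_fn, t);
  for (auto& w : workers) w.join();
  workers.clear();

  // Exclusive prefix sum: offsets[t] is where thread t starts writing.
  std::vector<int64_t> offsets(num_threads, 0);
  std::partial_sum(counts.begin(), counts.end() - 1, offsets.begin() + 1);

  // Pass 2: each thread writes its indices into its own slice of the output.
  std::vector<int64_t> out(offsets.back() + counts.back());
  auto write_fn = [&](int t) {
    const int64_t begin = t * chunk, end = std::min(n, begin + chunk);
    int64_t pos = offsets[t];
    for (int64_t i = begin; i < end; ++i) {
      if (data[i] != 0.0f) out[pos++] = i;
    }
  };
  for (int t = 0; t < num_threads; ++t) workers.emplace_back(write_fn, t);
  for (auto& w : workers) w.join();
  return out;
}
```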

BUILD.bazel

Lines changed: 0 additions & 2 deletions
@@ -329,9 +329,7 @@ filegroup(
         "aten/src/TH/THLapack.cpp",
         "aten/src/TH/THStorageFunctions.cpp",
         "aten/src/TH/THTensor.cpp",
-        "aten/src/TH/THTensorEvenMoreMath.cpp",
         "aten/src/TH/THTensorLapack.cpp",
-        "aten/src/TH/THTensorMath.cpp",
         "aten/src/TH/THTensorMoreMath.cpp",
     ],
 )

aten/src/ATen/LegacyTHFunctionsCPU.cpp

Lines changed: 0 additions & 153 deletions
@@ -35,159 +35,6 @@ namespace {
   }
 }

-Tensor & _th_nonzero_out(const Tensor & self, Tensor & result) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-
-    switch (dispatch_scalar_type) {
-        case ScalarType::Bool: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THBoolTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Byte: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THByteTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Char: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THCharTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Double: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THDoubleTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Float: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THFloatTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Int: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THIntTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Long: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THLongTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Short: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THShortTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Half: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THHalfTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::BFloat16: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THBFloat16Tensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::ComplexDouble: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THComplexDoubleTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::ComplexFloat: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THComplexFloatTensor_nonzero(result_, self_);
-            break;
-        }
-        default:
-            AT_ERROR("_th_nonzero_out not supported on CPUType for ", dispatch_scalar_type);
-    }
-    return result;
-}
-Tensor _th_nonzero(const Tensor & self) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-    auto result_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(ScalarType::Long)).release();
-    auto result = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(result_));
-    switch (dispatch_scalar_type) {
-        case ScalarType::Bool: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THBoolTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Byte: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THByteTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Char: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THCharTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Double: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THDoubleTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Float: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THFloatTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Int: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THIntTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Long: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THLongTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Short: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THShortTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Half: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THHalfTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::BFloat16: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THBFloat16Tensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::ComplexDouble: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THComplexDoubleTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::ComplexFloat: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THComplexFloatTensor_nonzero(result_, self_);
-            break;
-        }
-        default:
-            AT_ERROR("_th_nonzero not supported on CPUType for ", dispatch_scalar_type);
-    }
-    return result;
-}
 Scalar _th_std_var(const Tensor& self, int64_t correction, bool take_sqrt) {
     // DeviceGuard omitted
     auto dispatch_scalar_type = infer_scalar_type(self);

aten/src/ATen/LegacyTHFunctionsCPU.h

Lines changed: 0 additions & 2 deletions
@@ -20,8 +20,6 @@ namespace cpu {

 Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source);
 Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source);
-Tensor& _th_nonzero_out(const Tensor& self, Tensor& result);
-Tensor _th_nonzero(const Tensor & self);
 Scalar _th_std_var(const Tensor& self, int64_t correction, bool take_sqrt);
 Tensor & _th_renorm_out(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm, Tensor & result);
 Tensor _th_renorm(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm);

aten/src/ATen/ParallelOpenMP.h

Lines changed: 14 additions & 7 deletions
@@ -17,24 +17,30 @@ inline void parallel_for(
     const int64_t end,
     const int64_t grain_size,
     const F& f) {
-  TORCH_CHECK(grain_size >= 0);
-  at::internal::lazy_init_num_threads();
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
   if (begin >= end) {
     return;
   }
-  if (end - begin == 1) {
+
+#ifdef _OPENMP
+  at::internal::lazy_init_num_threads();
+  const auto numiter = end - begin;
+  const bool use_parallel = (
+      numiter > grain_size && numiter > 1 &&
+      omp_get_max_threads() > 1 && !omp_in_parallel());
+  if (!use_parallel) {
     f(begin, end);
     return;
   }
-#ifdef _OPENMP
+
   std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
   std::exception_ptr eptr;
   // Work around memory leak when using 1 thread in nested "omp parallel"
   // caused by some buggy OpenMP versions and the fact that omp_in_parallel()
   // returns false when omp_get_max_threads() == 1 inside nested "omp parallel"
   // See issue gh-32284

-#pragma omp parallel if (omp_get_max_threads() > 1 && !omp_in_parallel() && ((end - begin) > grain_size))
+#pragma omp parallel
   {
     // choose number of tasks based on grain size and number of threads
     // can't use num_threads clause due to bugs in GOMP's thread pool (See #32008)

@@ -76,15 +82,16 @@ inline scalar_t parallel_reduce(
   at::internal::lazy_init_num_threads();
   if (begin >= end) {
     return ident;
-  } else if (in_parallel_region() || get_num_threads() == 1) {
+  } else if ((end - begin) <= grain_size || in_parallel_region() ||
+      get_num_threads() == 1) {
     return f(begin, end, ident);
   } else {
     const int64_t num_results = divup((end - begin), grain_size);
     std::vector<scalar_t> results(num_results);
     scalar_t* results_data = results.data();
     std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
     std::exception_ptr eptr;
-#pragma omp parallel for if ((end - begin) >= grain_size)
+#pragma omp parallel for
     for (int64_t id = 0; id < num_results; id++) {
       int64_t i = begin + id * grain_size;
       try {
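The revised `parallel_for` above moves the range/grain-size test out of the OpenMP `if` clause: ranges no larger than `grain_size` now take the plain serial `f(begin, end)` path before any OpenMP machinery is involved. A hedged usage sketch of the public API (`at::parallel_for` from `ATen/Parallel.h`; the `scale_inplace` helper and its assumptions about the tensor are illustrative):

```cpp
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <cstdint>

// Scales a contiguous float CPU tensor in place (illustrative helper).
void scale_inplace(at::Tensor& t, float alpha) {
  float* data = t.data_ptr<float>();
  const int64_t n = t.numel();
  // Ranges of at most `grain_size` elements run serially as f(0, n);
  // larger ranges are split across the OpenMP thread pool.
  at::parallel_for(0, n, /*grain_size=*/2048, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      data[i] *= alpha;
    }
  });
}
```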

aten/src/ATen/TensorIterator.cpp

Lines changed: 21 additions & 0 deletions
@@ -129,6 +129,12 @@ void TensorIteratorBase::reorder_dimensions() {
   // initialize perm with n-1, n-2, ..., 1, 0
   std::iota(perm_.rbegin(), perm_.rend(), 0);

+  // Reordering dimensions changes iteration order
+  if (enforce_linear_iteration_) {
+    permute_dimensions(perm_);
+    return;
+  }
+
   // returns 1 if the dim0 should come after dim1, -1 if dim0 should come
   // before dim1, and 0 if the comparison is ambiguous.
   auto should_swap = [&](size_t dim0, size_t dim1) {

@@ -1213,6 +1219,20 @@ FastSetupType TensorIteratorBase::compute_fast_setup_type(const TensorIteratorConfig& config) {
     return FastSetupType::NONE;
   }

+  // For linear iteration, only contiguous tensors can be coalesced
+  // Fast setup of any other format requires changing iteration order
+  if (enforce_linear_iteration_) {
+    for (const auto& op : operands_) {
+      if (op.tensor->defined() && !op.will_resize) {
+        auto is_contiguous = op.tensor->is_contiguous(at::MemoryFormat::Contiguous);
+        if (!is_contiguous) {
+          return FastSetupType::NONE;
+        }
+      }
+    }
+    return FastSetupType::CONTIGUOUS;
+  }
+
   bool is_contiguous = true;
   bool is_channels_last = true;
   bool is_non_overlapping_and_dense = true;

@@ -1265,6 +1285,7 @@ TensorIteratorBase::TensorIteratorBase() = default;
 void TensorIteratorBase::build(TensorIteratorConfig& config) {
   // populate some persistent configuration fields
   is_reduction_ = config.is_reduction_;
+  enforce_linear_iteration_ = config.enforce_linear_iteration_;

   // fill in operands_ based on configuration
   populate_operands(config);
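For intuition on the early return added to `reorder_dimensions()`: with a channels-last tensor, the stride-sorted traversal that `TensorIterator` would normally pick differs from linear (row-major index) order, which is what an index-producing kernel like nonzero needs to rely on. A small illustrative snippet, not part of this change, that simply prints the strides involved:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // NCHW sizes with a channels-last (NHWC) memory layout: C has the smallest
  // stride, so iterating in memory order would not match linear index order.
  auto t = at::empty({2, 3, 4, 5}, at::kFloat)
               .contiguous(at::MemoryFormat::ChannelsLast);
  std::cout << "sizes:   " << t.sizes() << "\n";    // [2, 3, 4, 5]
  std::cout << "strides: " << t.strides() << "\n";  // [60, 1, 15, 3]
  return 0;
}
```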

aten/src/ATen/TensorIterator.h

Lines changed: 16 additions & 0 deletions
@@ -426,6 +426,10 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase {
   /// been called? This is SOLELY used to check validity of perm_.
   bool has_coalesced_dimensions_ = false;

+  /// Whether iteration must be fixed. This disables dimension permuting and also
+  /// changes how for_each divides work among threads.
+  bool enforce_linear_iteration_ = false;
+
   /// The index offsets into the original tensors for each dimension.
   /// This is only non-zero when you narrow() a TensorIterator (e.g.,
   /// when you make sub-TensorIterators).

@@ -583,6 +587,17 @@ class TORCH_API TensorIteratorConfig final {
     return *this;
   }

+  // Sets the enforce_linear_iteration_ flag, which is false by default.
+  // If true, iteration goes in the same order as a C-contiguous tensor
+  // is laid out in memory, i.e. the last dimension iterates fastest.
+  //
+  // This iteration order can be less efficient and may even prevent vectorization.
+  // So only use if the correctness of your kernel depends on it.
+  TensorIteratorConfig& enforce_linear_iteration(const bool _enforce_linear_iteration = true) {
+    enforce_linear_iteration_ = _enforce_linear_iteration;
+    return *this;
+  }
+
   // Sets the promote_inputs_to_common_dtype_ flag, which is false by default
   // If true, the iterator's "common dtype" is always computed (see the
   // [Common Dtype Computation] note) and, on the CPU, temporary copies of

@@ -664,6 +679,7 @@ class TORCH_API TensorIteratorConfig final {
   bool check_all_same_dtype_ = true;
   bool check_all_same_device_ = true;
   bool enforce_safe_casting_to_output_ = false;
+  bool enforce_linear_iteration_ = false;
   bool promote_inputs_to_common_dtype_ = false;
   bool promote_integer_inputs_to_float_ = false;
   bool cast_common_dtype_to_outputs_ = false;
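The header above documents the new `enforce_linear_iteration()` knob on `TensorIteratorConfig`. A hedged sketch of how a kernel might request it when building an iterator (the helper name, tensors, and the particular combination of config calls are assumptions for illustration, not the exact setup this PR uses in the nonzero kernel):

```cpp
#include <ATen/ATen.h>
#include <ATen/TensorIterator.h>

// Builds an iterator that visits `self` in linear (row-major index) order,
// so positions derived from per-thread counts line up with element order.
// `out` is assumed to be preallocated with a suitable shape; names are illustrative.
at::TensorIterator make_linear_iter(const at::Tensor& out, const at::Tensor& self) {
  return at::TensorIteratorConfig()
      .check_all_same_dtype(false)  // e.g. a Long output over a Float input
      .add_output(out)
      .add_input(self)
      .enforce_linear_iteration()   // the flag added in this commit
      .build();
}
```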

aten/src/ATen/native/ReduceOps.cpp

Lines changed: 0 additions & 13 deletions
@@ -1746,19 +1746,6 @@ Tensor dist(const Tensor &self, const Tensor& other, const Scalar& p){
   return at::norm(self - other, p);
 }

-Tensor count_nonzero(const Tensor& self, IntArrayRef dims){
-  auto mask = (self != 0);
-  return mask.sum(dims);
-}
-
-Tensor count_nonzero(const Tensor& self, c10::optional<int64_t> dim){
-  if (dim){
-    auto wrap_dim = maybe_wrap_dim(dim.value(), self.dim());
-    return at::count_nonzero(self, IntArrayRef{wrap_dim});
-  }
-  return at::count_nonzero(self, IntArrayRef{});
-}
-
 bool cpu_equal(const Tensor& self, const Tensor& other) {
   if (!at::namedinference::are_names_equal(
       self.unsafeGetTensorImpl(), other.unsafeGetTensorImpl())) {
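These composite `count_nonzero` overloads are removed from ReduceOps.cpp; per the summary, this PR adds a custom per-thread count elsewhere (that file is not shown in this excerpt). For reference, a small hedged check that the public `at::count_nonzero` still agrees with the `(self != 0).sum()` formulation that used to live here:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // The removed fallback computed count_nonzero(self, dims) as (self != 0).sum(dims).
  auto t = at::eye(3);                        // zeros everywhere except the diagonal
  std::cout << at::count_nonzero(t) << "\n";  // 3
  std::cout << (t != 0).sum() << "\n";        // 3, the old fallback spelled out
  return 0;
}
```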
