Commit b6f41bb

Mike Ruberry authored and facebook-github-bot committed
The Jiterator (pytorch#69439)
Summary:

This PR:
- creates the "jiterator" pattern, allowing elementwise unary and binary kernels that don't accept scalars to be jit compiled when called
- ports the gcd and i1 CUDA kernels to use the jiterator
- extends elementwise binary systemic testing to be comparable to elementwise unary systemic testing
- separates one test case from test_out in test_ops.py
- updates more OpInfos to use expected failures instead of skips

The jiterator currently does not support half, bfloat16, or complex dtypes. It also, as mentioned above, does not support scalar inputs. We expect to add support for those datatypes and for scalars in the future.

Pull Request resolved: pytorch#69439
Reviewed By: ngimel
Differential Revision: D32874968
Pulled By: mruberry
fbshipit-source-id: d44bb9cde4f602703e75400ec5a0b209f085e9b3
1 parent 3202028 commit b6f41bb
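
For orientation before the file-by-file diff, this is the call pattern the jiterator introduces, taken from the GcdLcmKernel.cu hunk below. The kernel body is supplied as a string (`gcd_string`, defined elsewhere in this PR); `jitted_gpu_kernel` compiles it with NVRTC on first call and caches the resulting kernel per device.

    // Jiterator call pattern from this PR (see the GcdLcmKernel.cu hunk below).
    // `gcd_string` holds the C++ source for the gcd computation; it is compiled
    // with NVRTC the first time the kernel runs and cached afterwards.
    const char gcd_name[] = "gcd";

    void gcd_kernel_cuda(TensorIteratorBase& iter) {
      AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cuda", [&]() {
        jitted_gpu_kernel</*name=*/gcd_name,
                          /*return_dtype=*/ scalar_t,
                          /*common_dtype=*/ scalar_t,
                          /*arity=*/ 2>(iter, gcd_string);
      });
    }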

File tree

16 files changed: +2551 additions, -323 deletions


aten/src/ATen/core/Array.h

Lines changed: 4 additions & 4 deletions
@@ -8,9 +8,9 @@
 
 namespace at { namespace detail {
 
-template <typename T, int size>
+template <typename T, int size_>
 struct Array {
-  T data[size];
+  T data[size_];
 
   C10_HOST_DEVICE T operator[](int i) const {
     return data[i];
@@ -27,10 +27,10 @@ struct Array {
   Array(const Array&) = default;
   Array& operator=(const Array&) = default;
 #endif
-
+  static constexpr int size(){return size_;}
   // Fill the array with x.
   C10_HOST_DEVICE Array(T x) {
-    for (int i = 0; i < size; i++) {
+    for (int i = 0; i < size_; i++) {
       data[i] = x;
     }
   }
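
Renaming the template parameter from `size` to `size_` frees the identifier for the new `static constexpr int size()` accessor, which the jitted launchers in CUDALoops.cuh read at compile time (`constexpr int nTensors = array_t::size();`). A minimal sketch of that compile-time use; the count of 3 is illustrative, not taken from the diff:

    // Illustrative only: the element count of an at::detail::Array is now
    // recoverable as a constant expression via the new size() accessor.
    constexpr int ntensors = 3;  // e.g. two inputs plus one output
    at::detail::Array<char*, ntensors> data;
    static_assert(decltype(data)::size() == ntensors, "size() usable at compile time");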

aten/src/ATen/cudnn/Utils.h

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/Exceptions.h>
-#include <THC/THC.h>
 #include <ATen/cudnn/cudnn-wrapper.h>
 #include <ATen/cudnn/Handle.h>

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

Lines changed: 2 additions & 2 deletions
@@ -994,7 +994,7 @@ void logaddexp2_kernel(TensorIteratorBase& iter) {
 }
 
 void gcd_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "gcd_cpu", [&]() {
+  AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cpu", [&]() {
     cpu_kernel(
       iter,
       [](scalar_t a, scalar_t b) -> scalar_t {
@@ -1004,7 +1004,7 @@ void gcd_kernel(TensorIteratorBase& iter) {
 }
 
 void lcm_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lcm_cpu", [&]() {
+  AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "lcm_cpu", [&]() {
     cpu_kernel(
       iter,
       [](scalar_t a, scalar_t b) -> scalar_t {

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 209 additions & 0 deletions
@@ -30,11 +30,14 @@
 
 #include <type_traits>
 #include <tuple>
+#include <iostream>
+#include <mutex>
 
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/core/Array.h>
 #include <ATen/detail/FunctionTraits.h>
 #include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/jit_utils.h>
 #include <c10/macros/Macros.h>
 #include <c10/core/ScalarType.h>
 #include <c10/util/TypeCast.h>
@@ -120,6 +123,139 @@ static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t
   }
 }
 
+template<char const *name,
+         typename result_type,
+         typename compute_type,
+         typename array_t,
+         typename inp_calc_t,
+         typename out_calc_t,
+         typename loader_t,
+         typename storer_t>
+static inline void launch_jitted_unrolled_kernel(
+  DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data,
+  inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s, bool contiguous) {
+
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  const int64_t grid = (N + block_work_size() - 1) / block_work_size();
+
+  static std::mutex _jiterator_mutex;
+  static std::vector<at::cuda::jit::NvrtcFunction> fns(c10::cuda::device_count());
+
+  at::cuda::jit::NvrtcFunction* fn_ptr = &fns[dev_idx];
+  if (!fn_ptr->function) {
+    const std::lock_guard<std::mutex> lock{_jiterator_mutex};
+    if (!fn_ptr->function) {
+      constexpr int nTensors = array_t::size();
+      constexpr bool dynamic_casting = !std::is_same<decltype(l),
+                                                     memory::LoadWithoutCast>() || !std::is_same<decltype(s),
+                                                     memory::StoreWithoutCast>();
+      std::string string_name{name};
+      std::string compute_type_str = at::cuda::jit::typeName<compute_type>();
+      std::string result_type_str = at::cuda::jit::typeName<result_type>();
+      auto code = at::cuda::jit::generate_code(nTensors, f, string_name,
+                                               compute_type_str, result_type_str,
+                                               contiguous, dynamic_casting);
+      *fn_ptr = at::cuda::jit::jit_pwise_function(code, name);
+    }
+  }
+
+  // packs args
+  std::array<void*, 6> args = {
+    (void*)&N,
+    (void*)&data,
+    (void*)&ic,
+    (void*)&oc,
+    (void*)&l,
+    (void*)&s
+  };
+
+  at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<
+  char const *name,
+  typename result_type,
+  typename compute_type,
+  int arity,
+  typename array_t>
+static inline void launch_jitted_vectorized_kernel(DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  const int64_t grid = (N + block_work_size() - 1) / block_work_size();
+  const int vec_size = memory::jitted_can_vectorize_up_to<result_type, compute_type, arity>(data);
+
+  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
+  //   fn_ptr is set to the appropriate function based on the vec size and GPU used
+  // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
+  //   the same compute capability
+  static std::mutex _jiterator_mutex;
+  static std::vector<at::cuda::jit::NvrtcFunction> fns4(c10::cuda::device_count());
+  static std::vector<at::cuda::jit::NvrtcFunction> fns2(c10::cuda::device_count());
+  static std::vector<at::cuda::jit::NvrtcFunction> fns1(c10::cuda::device_count());
+
+
+  at::cuda::jit::NvrtcFunction* fn_ptr;
+  if (vec_size == 4) {
+    fn_ptr = &fns4[dev_idx];
+  } else if (vec_size == 2) {
+    fn_ptr = &fns2[dev_idx];
+  } else if (vec_size ==1) {
+    fn_ptr = &fns1[dev_idx];
+  } else {
+    TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
+  }
+
+  bool vectorized = vec_size > 1;
+
+  if (!fn_ptr->function) {
+    const std::lock_guard<std::mutex> lock{_jiterator_mutex};
+    if (!fn_ptr->function) {
+      constexpr int nTensors = array_t::size();
+      std::string string_name{name};
+      std::string compute_type_str = at::cuda::jit::typeName<compute_type>();
+      std::string result_type_str = at::cuda::jit::typeName<result_type>();
+      auto code = at::cuda::jit::generate_code(nTensors, f, string_name,
                                               compute_type_str, result_type_str,
+                                               /*contiguous=*/true, /*dynamic_casting=*/false,
+                                               vectorized, vec_size);
+      std::string kernel_name = vectorized ? string_name + "_vectorized" + std::to_string(vec_size) : string_name;
+      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
+    }
+  }
+
+  if (vectorized) {
+    std::array<void*, 6> args = {
+      (void*)&N,
+      (void*)&data,
+      nullptr,
+      nullptr,
+      nullptr,
+      nullptr
+    };
+
+    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else {
+    auto ic = TrivialOffsetCalculator<arity>();
+    auto oc = TrivialOffsetCalculator<1>();
+    auto l = memory::LoadWithoutCast();
+    auto s = memory::StoreWithoutCast();
+
+    std::array<void*, 6> args = {
+      (void*)&N,
+      (void*)&data,
+      (void*)&ic,
+      (void*)&oc,
+      (void*)&l,
+      (void*)&s
+    };
+
+    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+
+}
+
 template<typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t>
 static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t data,
                                           inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s)
@@ -131,6 +267,79 @@ static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t da
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
+template <char const *name, typename result_type, typename compute_type, int arity>
+void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, const bool dynamic_casting) {
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
+  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
+
+  constexpr int ntensors = arity + 1;
+  at::detail::Array<char*, ntensors> data;
+  for (auto i = decltype(ntensors){0}; i < ntensors; ++i) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  int64_t numel = iter.numel();
+  bool contiguous = iter.is_contiguous();
+
+  // Decides which of 4 kernel types to launch
+  // Variations are:
+  //   - Case 1: no dynamic casting and contiguous
+  //   - Case 2: no dynamic casting and noncontiguous
+  //   - Case 3: dynamic casting and contiguous
+  //   - Case 4: dynamic casting and noncontiguous
+  // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl
+
+  if (!dynamic_casting) {
+    if (contiguous) {
+      // Case 1: no dynamic casting and contiguous
+      launch_jitted_vectorized_kernel<name, result_type, compute_type, arity>(
+        iter.device().index(), numel, f, data);
+      return;
+    }
+
+    // Case 2: no dynamic casting and noncontiguous
+    auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
+    auto output_offset_calculator = make_output_offset_calculator(iter);
+    auto loader = memory::LoadWithoutCast();
+    auto storer = memory::StoreWithoutCast();
+    launch_jitted_unrolled_kernel<name, result_type, compute_type>(
+      iter.device().index(), numel, f, data, input_offset_calculator,
+      output_offset_calculator, loader, storer, contiguous);
+    return;
+  }
+
+  // Cases 3 and 4 are handled below
+  // Both require construction of a storer (this asserts 1 output) and one or more loaders
+
+  // Creates store cast to output (the zeroth tensor in TensorIterator)
+  auto storer = memory::StoreWithCast(iter.dtype(0));
+
+  // Creates load casts from inputs (note offset indexing into the iterators 1...n tensors)
+  at::detail::Array<ScalarType, arity> dtypes;
+  for (auto i = decltype(arity){0}; i < arity; ++i) {
+    dtypes[i] = iter.dtype(i + 1);
+  }
+  auto loader = memory::LoadWithCast<arity>(dtypes);
+
+  if (contiguous) {
+    // Case 3: dynamic casting and contiguous
+    auto input_offset_calculator = TrivialOffsetCalculator<arity>();
+    auto output_offset_calculator = TrivialOffsetCalculator<1>();
+    launch_jitted_unrolled_kernel<name, result_type, compute_type>(
+      iter.device().index(), numel, f, data, input_offset_calculator,
+      output_offset_calculator, loader, storer, contiguous);
+    return;
+  }
+
+  // Case 4: dynamic casting and noncontiguous
+  auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
+  auto output_offset_calculator = make_output_offset_calculator(iter);
+  launch_jitted_unrolled_kernel<name, result_type, compute_type>(
+    iter.device().index(), numel, f, data, input_offset_calculator,
+    output_offset_calculator, loader, storer, contiguous);
+}
+
 template <typename func_t>
 void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
   using traits = function_traits<func_t>;
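
Design note: both launchers above cache the compiled NvrtcFunction per device behind a double-checked lock, so NVRTC compilation runs at most once per kernel, vectorization width, and device, while steady-state launches skip the mutex entirely. A stripped-down sketch of that pattern; `compile_with_nvrtc` is a hypothetical stand-in for the generate_code/jit_pwise_function pair, not a function in this PR:

    // Sketch only. The fast path reads the per-device cache without locking;
    // the slow path re-checks under the lock so concurrent callers compile once.
    static std::mutex jit_mutex;
    static std::vector<at::cuda::jit::NvrtcFunction> cache(c10::cuda::device_count());

    at::cuda::jit::NvrtcFunction* fn = &cache[dev_idx];
    if (!fn->function) {
      const std::lock_guard<std::mutex> lock{jit_mutex};
      if (!fn->function) {
        *fn = compile_with_nvrtc(f);  // hypothetical helper: generate code, then compile
      }
    }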

aten/src/ATen/native/cuda/GcdLcmKernel.cu

Lines changed: 17 additions & 5 deletions
@@ -5,22 +5,34 @@
 #include <ATen/native/cuda/Math.cuh>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/BinaryOps.h>
+#include <ATen/native/cuda/jit_utils.h>
 
 // NOTE: CUDA on Windows requires that the enclosing function
 // of a __device__ lambda not have internal linkage.
 
 namespace at { namespace native {
 
+// See note [Jiterator]
+const char gcd_name[] = "gcd";
 void gcd_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "gcd_cuda", [&]() {
-    gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
-      return calc_gcd(a, b);
+  #ifdef USE_JITERATOR
+    AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cuda", [&]() {
+      jitted_gpu_kernel</*name=*/gcd_name,
+                        /*return_dtype=*/ scalar_t,
+                        /*common_dtype=*/ scalar_t,
+                        /*arity=*/ 2>(iter, gcd_string);
     });
-  });
+  #else
+    AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cuda", [&]() {
+      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
+        return calc_gcd(a, b);
+      });
+    });
+  #endif // USE_JITERATOR
 }
 
 void lcm_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lcm_cuda", [&]() {
+  AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "lcm_cuda", [&]() {
     gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
       scalar_t g = calc_gcd(a, b);
       return (g == 0) ? 0 : ::abs(a / g * b);
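
For comparison, lcm stays on the non-jitted gpu_kernel path in this commit; only the dispatch dtype changes. If it were ported the same way as gcd, the kernel would presumably look like the hypothetical sketch below; `lcm_name` and `lcm_string` are assumed names and are not added by this PR:

    // Hypothetical port of lcm to the jiterator, mirroring the gcd change above.
    const char lcm_name[] = "lcm";

    void lcm_kernel_cuda(TensorIteratorBase& iter) {
      #ifdef USE_JITERATOR
        AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "lcm_cuda", [&]() {
          jitted_gpu_kernel</*name=*/lcm_name,
                            /*return_dtype=*/ scalar_t,
                            /*common_dtype=*/ scalar_t,
                            /*arity=*/ 2>(iter, lcm_string);  // lcm_string: assumed kernel source
        });
      #else
        AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "lcm_cuda", [&]() {
          gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
            scalar_t g = calc_gcd(a, b);
            return (g == 0) ? 0 : ::abs(a / g * b);
          });
        });
      #endif // USE_JITERATOR
    }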

0 commit comments
