#include <type_traits>
#include <tuple>
+#include <iostream>
+#include <mutex>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/core/Array.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/jit_utils.h>
#include <c10/macros/Macros.h>
#include <c10/core/ScalarType.h>
#include <c10/util/TypeCast.h>
@@ -120,6 +123,139 @@ static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t
  }
}

+template <char const *name,
+          typename result_type,
+          typename compute_type,
+          typename array_t,
+          typename inp_calc_t,
+          typename out_calc_t,
+          typename loader_t,
+          typename storer_t>
+static inline void launch_jitted_unrolled_kernel(
+  DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data,
+  inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s, bool contiguous) {
+
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
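+  // Each block processes block_work_size() elements, so round the grid size up to cover all N.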
+  const int64_t grid = (N + block_work_size() - 1) / block_work_size();
+
+  static std::mutex _jiterator_mutex;
+  static std::vector<at::cuda::jit::NvrtcFunction> fns(c10::cuda::device_count());
+
+  at::cuda::jit::NvrtcFunction* fn_ptr = &fns[dev_idx];
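+  // Double-checked locking: the kernel is compiled at most once per device and the
+  // resulting NvrtcFunction is cached, so later calls skip compilation entirely.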
+  if (!fn_ptr->function) {
+    const std::lock_guard<std::mutex> lock{_jiterator_mutex};
+    if (!fn_ptr->function) {
+      constexpr int nTensors = array_t::size();
+      constexpr bool dynamic_casting = !std::is_same<decltype(l),
+          memory::LoadWithoutCast>() || !std::is_same<decltype(s),
+          memory::StoreWithoutCast>();
+      std::string string_name{name};
+      std::string compute_type_str = at::cuda::jit::typeName<compute_type>();
+      std::string result_type_str = at::cuda::jit::typeName<result_type>();
+      auto code = at::cuda::jit::generate_code(nTensors, f, string_name,
+                                               compute_type_str, result_type_str,
+                                               contiguous, dynamic_casting);
+      *fn_ptr = at::cuda::jit::jit_pwise_function(code, name);
+    }
+  }
+
+  // packs args
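+  // One pointer per kernel parameter, in the order the generated kernel declares them.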
+  std::array<void*, 6> args = {
+    (void*)&N,
+    (void*)&data,
+    (void*)&ic,
+    (void*)&oc,
+    (void*)&l,
+    (void*)&s
+  };
+
+  at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <
+  char const *name,
+  typename result_type,
+  typename compute_type,
+  int arity,
+  typename array_t>
+static inline void launch_jitted_vectorized_kernel(DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  const int64_t grid = (N + block_work_size() - 1) / block_work_size();
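+  // Query the widest vectorized load/store (1, 2, or 4 elements) usable for these data pointers.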
+  const int vec_size = memory::jitted_can_vectorize_up_to<result_type, compute_type, arity>(data);
+
+  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
+  // fn_ptr is set to the appropriate function based on the vec size and GPU used
+  // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
+  // the same compute capability
+  static std::mutex _jiterator_mutex;
+  static std::vector<at::cuda::jit::NvrtcFunction> fns4(c10::cuda::device_count());
+  static std::vector<at::cuda::jit::NvrtcFunction> fns2(c10::cuda::device_count());
+  static std::vector<at::cuda::jit::NvrtcFunction> fns1(c10::cuda::device_count());
+
+  at::cuda::jit::NvrtcFunction* fn_ptr;
+  if (vec_size == 4) {
+    fn_ptr = &fns4[dev_idx];
+  } else if (vec_size == 2) {
+    fn_ptr = &fns2[dev_idx];
+  } else if (vec_size == 1) {
+    fn_ptr = &fns1[dev_idx];
+  } else {
+    TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
+  }
+
+  bool vectorized = vec_size > 1;
+
+  if (!fn_ptr->function) {
+    const std::lock_guard<std::mutex> lock{_jiterator_mutex};
+    if (!fn_ptr->function) {
+      constexpr int nTensors = array_t::size();
+      std::string string_name{name};
+      std::string compute_type_str = at::cuda::jit::typeName<compute_type>();
+      std::string result_type_str = at::cuda::jit::typeName<result_type>();
+      auto code = at::cuda::jit::generate_code(nTensors, f, string_name,
+                                               compute_type_str, result_type_str,
+                                               /*contiguous=*/true, /*dynamic_casting=*/false,
+                                               vectorized, vec_size);
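+      // Vectorized variants get a per-vec-size kernel name so the right function is looked up after compilation.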
+      std::string kernel_name = vectorized ? string_name + "_vectorized" + std::to_string(vec_size) : string_name;
+      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
+    }
+  }
+
+  if (vectorized) {
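+    // Only N and data are meaningful here; the remaining slots are nullptr placeholders for unused parameters.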
+    std::array<void*, 6> args = {
+      (void*)&N,
+      (void*)&data,
+      nullptr,
+      nullptr,
+      nullptr,
+      nullptr
+    };
+
+    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else {
+    auto ic = TrivialOffsetCalculator<arity>();
+    auto oc = TrivialOffsetCalculator<1>();
+    auto l = memory::LoadWithoutCast();
+    auto s = memory::StoreWithoutCast();
+
+    std::array<void*, 6> args = {
+      (void*)&N,
+      (void*)&data,
+      (void*)&ic,
+      (void*)&oc,
+      (void*)&l,
+      (void*)&s
+    };
+
+    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+
+}
+
template <typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t>
static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t data,
                                          inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s)
@@ -131,6 +267,79 @@ static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t da
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

+template <char const *name, typename result_type, typename compute_type, int arity>
+void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, const bool dynamic_casting) {
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
+  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
+
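+  // Tensor 0 of the TensorIterator is the output; tensors 1..arity are the inputs.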
+  constexpr int ntensors = arity + 1;
+  at::detail::Array<char*, ntensors> data;
+  for (auto i = decltype(ntensors){0}; i < ntensors; ++i) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  int64_t numel = iter.numel();
+  bool contiguous = iter.is_contiguous();
+
+  // Decides which of 4 kernel types to launch
+  // Variations are:
+  //   - Case 1: no dynamic casting and contiguous
+  //   - Case 2: no dynamic casting and noncontiguous
+  //   - Case 3: dynamic casting and contiguous
+  //   - Case 4: dynamic casting and noncontiguous
+  // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl
+
+  if (!dynamic_casting) {
+    if (contiguous) {
+      // Case 1: no dynamic casting and contiguous
+      launch_jitted_vectorized_kernel<name, result_type, compute_type, arity>(
+        iter.device().index(), numel, f, data);
+      return;
+    }
+
+    // Case 2: no dynamic casting and noncontiguous
+    auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
+    auto output_offset_calculator = make_output_offset_calculator(iter);
+    auto loader = memory::LoadWithoutCast();
+    auto storer = memory::StoreWithoutCast();
+    launch_jitted_unrolled_kernel<name, result_type, compute_type>(
+      iter.device().index(), numel, f, data, input_offset_calculator,
+      output_offset_calculator, loader, storer, contiguous);
+    return;
+  }
+
+  // Cases 3 and 4 are handled below
+  // Both require construction of a storer (this asserts 1 output) and one or more loaders
+
+  // Creates store cast to output (the zeroth tensor in TensorIterator)
+  auto storer = memory::StoreWithCast(iter.dtype(0));
+
+  // Creates load casts from inputs (note offset indexing into the iterator's 1...n tensors)
+  at::detail::Array<ScalarType, arity> dtypes;
+  for (auto i = decltype(arity){0}; i < arity; ++i) {
+    dtypes[i] = iter.dtype(i + 1);
+  }
+  auto loader = memory::LoadWithCast<arity>(dtypes);
+
+  if (contiguous) {
+    // Case 3: dynamic casting and contiguous
+    auto input_offset_calculator = TrivialOffsetCalculator<arity>();
+    auto output_offset_calculator = TrivialOffsetCalculator<1>();
+    launch_jitted_unrolled_kernel<name, result_type, compute_type>(
+      iter.device().index(), numel, f, data, input_offset_calculator,
+      output_offset_calculator, loader, storer, contiguous);
+    return;
+  }
+
+  // Case 4: dynamic casting and noncontiguous
+  auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
+  auto output_offset_calculator = make_output_offset_calculator(iter);
+  launch_jitted_unrolled_kernel<name, result_type, compute_type>(
+    iter.device().index(), numel, f, data, input_offset_calculator,
+    output_offset_calculator, loader, storer, contiguous);
+}
+
template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;