diff --git a/src/ATen/native/xpu/sycl/Dropout.cpp b/src/ATen/native/xpu/sycl/Dropout.cpp index acb3faee01..e9b39f9d6f 100644 --- a/src/ATen/native/xpu/sycl/Dropout.cpp +++ b/src/ATen/native/xpu/sycl/Dropout.cpp @@ -165,7 +165,7 @@ struct FusedDropoutUnrollFunctor { if (li < total_elements_) { // Convert `linearIndex` into an offset of `a` const IndexType aOffset = - IndexToOffset::get(li, a_); + IndexToOffset::get(li, a_); src[ii] = a_.data[aOffset]; } } @@ -174,7 +174,7 @@ struct FusedDropoutUnrollFunctor { if (li < total_elements_) { // Convert `linearIndex` into an offset of `b` const IndexType bOffset = - IndexToOffset::get(li, b_); + IndexToOffset::get(li, b_); b_.data[bOffset] = src[ii] * (&rand.x)[ii] * scale; c_.data[bOffset] = (mask_t)(&rand.x)[ii]; } diff --git a/src/ATen/native/xpu/sycl/Indexing.cpp b/src/ATen/native/xpu/sycl/Indexing.cpp index 04d3e292fa..890669b9c7 100644 --- a/src/ATen/native/xpu/sycl/Indexing.cpp +++ b/src/ATen/native/xpu/sycl/Indexing.cpp @@ -1218,7 +1218,14 @@ void put_kernel( }); } -template +template < + typename T, + typename IndicesType, + typename IndexType, + int DstDim, + int SrcDim, + int IdxDim, + typename func_t> struct IndexFuncSmallIndexFunctor { void operator()(sycl::nd_item<1> item) const { // In order to avoid reloading the index that we are copying, load @@ -1229,8 +1236,9 @@ struct IndexFuncSmallIndexFunctor { for (IndexType srcIndex = 0; srcIndex < indices_.sizes[0]; ++srcIndex) { // Lua indices begin at 1 IndexType dstIndex = - indices_.data[IndexToOffset::get( - srcIndex, indices_)]; + indices_ + .data[IndexToOffset::get( + srcIndex, indices_)]; SYCL_KERNEL_ASSERT(dstIndex < dstAddDimSize_); // We stride over the output ignoring the indexed dimension @@ -1240,11 +1248,11 @@ struct IndexFuncSmallIndexFunctor { linearIndex < innerSize_; linearIndex += item.get_group_range(0) * item.get_local_range(0)) { IndexType dstOffset = - IndexToOffset::get(linearIndex, dst_); + IndexToOffset::get(linearIndex, dst_); dstOffset += dstIndex * dst_.strides[dstAddDim_]; IndexType srcOffset = - IndexToOffset::get(linearIndex, src_); + IndexToOffset::get(linearIndex, src_); srcOffset += srcIndex * src_.strides[srcAddDim_]; T val = src_.data[srcOffset] * alpha_; @@ -1292,6 +1300,9 @@ template < typename T, typename IndicesType, typename IndexType, + int DstDim, + int SrcDim, + int IdxDim, bool IndexIsMajor, typename func_t> struct IndexFuncLargeIndexFunctor { @@ -1314,16 +1325,17 @@ struct IndexFuncLargeIndexFunctor { // Lua indices begin at 1 IndexType dstIndex = - indices_.data[IndexToOffset::get( - srcIndex, indices_)]; + indices_ + .data[IndexToOffset::get( + srcIndex, indices_)]; SYCL_KERNEL_ASSERT(dstIndex < dstAddDimSize_); IndexType dstOffset = - IndexToOffset::get(elementInSlice, dst_); + IndexToOffset::get(elementInSlice, dst_); dstOffset += dstIndex * dst_.strides[dstAddDim_]; IndexType srcOffset = - IndexToOffset::get(elementInSlice, src_); + IndexToOffset::get(elementInSlice, src_); srcOffset += srcIndex * src_.strides[srcAddDim_]; T val = src_.data[srcOffset] * alpha_; @@ -1444,36 +1456,55 @@ void index_reduce_func_xpu_template( } bool indContig = index.is_contiguous(); -#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, FUNC_T) \ - IndexFuncSmallIndexFunctor( \ - selfInfo, \ - sourceInfo, \ - indexInfo, \ - selfReduceDim, \ - sourceReduceDim, \ - sliceSize, \ - selfReduceDimSize, \ - selfNumel, \ - reduce_func, \ +#define SMALL_INDEX( \ + TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM, FUNC_T) \ + IndexFuncSmallIndexFunctor< 
\ + TENSOR_TYPE, \ + INDICES_TYPE, \ + TYPE, \ + SELF_DIM, \ + SOURCE_DIM, \ + IDX_DIM, \ + FUNC_T>( \ + selfInfo, \ + sourceInfo, \ + indexInfo, \ + selfReduceDim, \ + sourceReduceDim, \ + sliceSize, \ + selfReduceDimSize, \ + selfNumel, \ + reduce_func, \ alpha_value); -#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, IDX_IS_MAJOR, FUNC_T) \ - IndexFuncLargeIndexFunctor< \ - TENSOR_TYPE, \ - INDICES_TYPE, \ - TYPE, \ - IDX_IS_MAJOR, \ - FUNC_T>( \ - selfInfo, \ - sourceInfo, \ - indexInfo, \ - selfReduceDim, \ - sourceReduceDim, \ - sourceTotalSize, \ - (IDX_IS_MAJOR) ? sliceSize : numIndex, \ - selfReduceDimSize, \ - selfNumel, \ - reduce_func, \ +#define LARGE_INDEX( \ + TENSOR_TYPE, \ + INDICES_TYPE, \ + TYPE, \ + SELF_DIM, \ + SOURCE_DIM, \ + IDX_DIM, \ + IDX_IS_MAJOR, \ + FUNC_T) \ + IndexFuncLargeIndexFunctor< \ + TENSOR_TYPE, \ + INDICES_TYPE, \ + TYPE, \ + SELF_DIM, \ + SOURCE_DIM, \ + IDX_DIM, \ + IDX_IS_MAJOR, \ + FUNC_T>( \ + selfInfo, \ + sourceInfo, \ + indexInfo, \ + selfReduceDim, \ + sourceReduceDim, \ + sourceTotalSize, \ + (IDX_IS_MAJOR) ? sliceSize : numIndex, \ + selfReduceDimSize, \ + selfNumel, \ + reduce_func, \ alpha_value); int ssc = syclMaxDSSNum(); @@ -1505,21 +1536,144 @@ void index_reduce_func_xpu_template( // A reasonable choice for when to have each thread iterate // over index to choose if (numIndex <= 16) { - auto caller = - SMALL_INDEX(scalar_t, index_t, unsigned int, func_t); - size_t num_wg = std::min( - ceil_div(sliceSize, (uint64_t)128), (uint64_t)(ssc * 8)); - size_t wg_size = std::min(sliceSize, (uint64_t)128); - sycl_kernel_submit( - num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller); + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + auto caller = SMALL_INDEX( + scalar_t, index_t, unsigned int, 1, 1, -2, func_t); + size_t num_wg = std::min( + ceil_div(sliceSize, (uint64_t)128), + (uint64_t)(ssc * 8)); + size_t wg_size = std::min(sliceSize, (uint64_t)128); + sycl_kernel_submit( + num_wg * wg_size, + wg_size, + getCurrentSYCLQueue(), + caller); + } else if ( + selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + auto caller = SMALL_INDEX( + scalar_t, index_t, unsigned int, 2, 2, -2, func_t); + size_t num_wg = std::min( + ceil_div(sliceSize, (uint64_t)128), + (uint64_t)(ssc * 8)); + size_t wg_size = std::min(sliceSize, (uint64_t)128); + sycl_kernel_submit( + num_wg * wg_size, + wg_size, + getCurrentSYCLQueue(), + caller); + } else if ( + selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { + auto caller = SMALL_INDEX( + scalar_t, index_t, unsigned int, 3, 3, -2, func_t); + size_t num_wg = std::min( + ceil_div(sliceSize, (uint64_t)128), + (uint64_t)(ssc * 8)); + size_t wg_size = std::min(sliceSize, (uint64_t)128); + sycl_kernel_submit( + num_wg * wg_size, + wg_size, + getCurrentSYCLQueue(), + caller); + } else { + auto caller = SMALL_INDEX( + scalar_t, index_t, unsigned int, -1, -1, -1, func_t); + size_t num_wg = std::min( + ceil_div(sliceSize, (uint64_t)128), + (uint64_t)(ssc * 8)); + size_t wg_size = std::min(sliceSize, (uint64_t)128); + sycl_kernel_submit( + num_wg * wg_size, + wg_size, + getCurrentSYCLQueue(), + caller); + } } else { bool indexIsMajor = indexShouldBeMajor(selfInfo, selfReduceDim); - if (indContig) { + if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) { + auto caller = LARGE_INDEX( + scalar_t, + index_t, + unsigned int, + 1, + 1, + -2, + true, + func_t); + int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller); + size_t num_wg = std::min( + ceil_div(sourceTotalSize, 
(uint64_t)128), + (uint64_t)(ssc * 8)); + size_t wg_size = (sourceTotalSize < defaultMaxGroupThreads) + ? sourceTotalSize + : defaultMaxGroupThreads; + sycl_kernel_submit( + num_wg * wg_size, + wg_size, + getCurrentSYCLQueue(), + caller); + } else if ( + selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) { + if (indexIsMajor) { + auto caller = LARGE_INDEX( + scalar_t, + index_t, + unsigned int, + 2, + 2, + -2, + true, + func_t); + int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller); + size_t num_wg = std::min( + ceil_div(sourceTotalSize, (uint64_t)128), + (uint64_t)(ssc * 8)); + size_t wg_size = + (sourceTotalSize < defaultMaxGroupThreads) + ? sourceTotalSize + : defaultMaxGroupThreads; + sycl_kernel_submit( + num_wg * wg_size, + wg_size, + getCurrentSYCLQueue(), + caller); + } else { + auto caller = LARGE_INDEX( + scalar_t, + index_t, + unsigned int, + 2, + 2, + -2, + false, + func_t); + int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller); + size_t num_wg = std::min( + ceil_div(sourceTotalSize, (uint64_t)128), + (uint64_t)(ssc * 8)); + size_t wg_size = + (sourceTotalSize < defaultMaxGroupThreads) + ? sourceTotalSize + : defaultMaxGroupThreads; + sycl_kernel_submit( + num_wg * wg_size, + wg_size, + getCurrentSYCLQueue(), + caller); + } + } else if ( + selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) { if (indexIsMajor) { auto caller = LARGE_INDEX( - scalar_t, index_t, unsigned int, true, func_t); + scalar_t, + index_t, + unsigned int, + 3, + 3, + -2, + true, + func_t); int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller); size_t num_wg = std::min( ceil_div(sourceTotalSize, (uint64_t)128), @@ -1535,7 +1689,14 @@ void index_reduce_func_xpu_template( caller); } else { auto caller = LARGE_INDEX( - scalar_t, index_t, unsigned int, false, func_t); + scalar_t, + index_t, + unsigned int, + 3, + 3, + -2, + false, + func_t); int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller); size_t num_wg = std::min( ceil_div(sourceTotalSize, (uint64_t)128), @@ -1552,7 +1713,14 @@ void index_reduce_func_xpu_template( } } else { auto caller = LARGE_INDEX( - scalar_t, index_t, unsigned int, true, func_t); + scalar_t, + index_t, + unsigned int, + -1, + -1, + -1, + true, + func_t); int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller); size_t num_wg = std::min( ceil_div(sourceTotalSize, (uint64_t)128), @@ -1592,8 +1760,8 @@ void index_reduce_func_xpu_template( TensorInfo indexInfo = getTensorInfo(index); indexInfo.collapseDims(); - auto caller = - LARGE_INDEX(scalar_t, index_t, uint64_t, true, func_t); + auto caller = LARGE_INDEX( + scalar_t, index_t, uint64_t, -1, -1, -1, true, func_t); int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller); size_t num_wg = std::min( ceil_div(sourceTotalSize, (uint64_t)128), diff --git a/src/ATen/native/xpu/sycl/Indexing.h b/src/ATen/native/xpu/sycl/Indexing.h index e696c625cb..f4d30e8de4 100644 --- a/src/ATen/native/xpu/sycl/Indexing.h +++ b/src/ATen/native/xpu/sycl/Indexing.h @@ -211,10 +211,8 @@ class IndexKernel { if constexpr (TrivialOffCal) { idx_off = idx_logical_off; } else { - idx_off = IndexToOffset::get( - idx_logical_off, - cfg_.iinfo_, - IndexToOffset::NON_STRICT_CONTIGUOUS); + idx_off = IndexToOffset::get( + idx_logical_off, cfg_.iinfo_); } glb_batch_group = id.glb_batch / cfg_.index_num_; glb_batch_group_loc_off = cfg_.iinfo_.data[idx_off]; @@ -322,26 +320,18 @@ class IndexKernel { } else { if (cfg_.indexing_dst_) { // index_copy, index_add, index_fill - dst_off = IndexToOffset::get( - glb_indexing_logical_off, - cfg_.dinfo_, 
- IndexToOffset::NON_STRICT_CONTIGUOUS); + dst_off = IndexToOffset::get( + glb_indexing_logical_off, cfg_.dinfo_); if (cfg_.sinfo_.data != nullptr) { - src_off = IndexToOffset::get( - glb_fixing_logical_off, - cfg_.sinfo_, - IndexToOffset::NON_STRICT_CONTIGUOUS); + src_off = IndexToOffset::get( + glb_fixing_logical_off, cfg_.sinfo_); } } else { // index_select - src_off = IndexToOffset::get( - glb_indexing_logical_off, - cfg_.sinfo_, - IndexToOffset::NON_STRICT_CONTIGUOUS); - dst_off = IndexToOffset::get( - glb_fixing_logical_off, - cfg_.dinfo_, - IndexToOffset::NON_STRICT_CONTIGUOUS); + src_off = IndexToOffset::get( + glb_indexing_logical_off, cfg_.sinfo_); + dst_off = IndexToOffset::get( + glb_fixing_logical_off, cfg_.dinfo_); } } cfg_.func_( diff --git a/src/ATen/native/xpu/sycl/RNNKernels.cpp b/src/ATen/native/xpu/sycl/RNNKernels.cpp index bad6bdf69d..fb2d6c471b 100644 --- a/src/ATen/native/xpu/sycl/RNNKernels.cpp +++ b/src/ATen/native/xpu/sycl/RNNKernels.cpp @@ -77,12 +77,13 @@ void collapseDims(TensorInfo& info, Args&... infos) { collapseDims(infos...); } -#define DEVICE_LINEAR_GET(D_TENSOR, INDEX) \ - D_TENSOR.data[IndexToOffset::get(INDEX, D_TENSOR)] +#define DEVICE_LINEAR_GET(D_TENSOR, INDEX) \ + D_TENSOR.data[IndexToOffset::get( \ + INDEX, D_TENSOR)] // Biases are always 1D #define DEVICE_BIAS_GET(D_TENSOR, INDEX) \ - D_TENSOR.data[IndexToOffset::get(INDEX, D_TENSOR)] + D_TENSOR.data[IndexToOffset::get(INDEX, D_TENSOR)] #define H2F(input) static_cast(input) #define F2H(input) static_cast(input) @@ -93,7 +94,11 @@ inline T sigmoid(T in) { return one / (one + std::exp(-in)); } -template +template < + typename scalar_t, + typename accscalar_t, + typename index_type, + int indexing_kind> struct LstmCellForwardFunctor { void operator()(sycl::nd_item<1> item) const { bool has_bias = bias1_.data != nullptr; @@ -205,7 +210,11 @@ struct LstmCellForwardFunctor { index_type totalElements_; }; -template +template < + typename scalar_t, + typename accscalar_t, + typename index_type, + int indexing_kind> struct LstmCellBackwardFunctor { void operator()(sycl::nd_item<1> item) const { bool has_gradoutput = gradoutput_.data != nullptr; @@ -296,7 +305,11 @@ struct LstmCellBackwardFunctor { index_type totalElements_; }; -template +template < + typename scalar_t, + typename accscalar_t, + typename index_type, + int indexing_kind> struct GruCellForwardFunctor { void operator()(sycl::nd_item<1> item) const { bool has_bias = Bias1_.data != nullptr; @@ -387,7 +400,11 @@ struct GruCellForwardFunctor { const index_type totalElements_; }; -template +template < + typename scalar_t, + typename accscalar_t, + typename index_type, + int indexing_kind> struct GruCellBackwardFunctor { void operator()(sycl::nd_item<1> item) const { for (index_type linearIndex = item.get_global_id(0); @@ -469,12 +486,6 @@ void lstm_forward_impl( if (numel == 0) return; - using KernelT = LstmCellForwardFunctor; - auto max_wg_size = syclMaxWorkGroupSize(); - auto config = rnn_get_launch_config(max_wg_size, numel); - auto nwg = std::get<0>(config); - auto local_range = std::get<1>(config); - auto input_gatesI = getTensorInfo(input_gates); auto hidden_gatesI = getTensorInfo(hidden_gates); auto input_biasI = tryGetTensorInfo(input_bias); @@ -503,6 +514,12 @@ void lstm_forward_impl( hyI, cyI, workspaceI); + using KernelT = + LstmCellForwardFunctor; + auto max_wg_size = syclMaxWorkGroupSize(); + auto config = rnn_get_launch_config(max_wg_size, numel); + auto nwg = std::get<0>(config); + auto local_range = std::get<1>(config); KernelT 
kfn( input_gatesI, hidden_gatesI, @@ -517,6 +534,12 @@ void lstm_forward_impl( sycl_kernel_submit( nwg * local_range, local_range, getCurrentSYCLQueue(), kfn); } else { + using KernelT = + LstmCellForwardFunctor; + auto max_wg_size = syclMaxWorkGroupSize(); + auto config = rnn_get_launch_config(max_wg_size, numel); + auto nwg = std::get<0>(config); + auto local_range = std::get<1>(config); KernelT kfn( input_gatesI, hidden_gatesI, @@ -548,12 +571,6 @@ void lstm_backward_impl( if (numel == 0) return; - using KernelT = LstmCellBackwardFunctor; - auto max_wg_size = syclMaxWorkGroupSize(); - auto config = rnn_get_launch_config(max_wg_size, numel); - auto nwg = std::get<0>(config); - auto local_range = std::get<1>(config); - auto grad_hyI = tryGetTensorInfo(grad_hy); auto grad_cyI = tryGetTensorInfo(grad_cy); auto cxI = getTensorInfo(cx); @@ -567,6 +584,12 @@ void lstm_backward_impl( {grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx})) { collapseDims( grad_hyI, grad_cyI, cxI, cyI, workspaceI, grad_gatesI, grad_cxI); + using KernelT = + LstmCellBackwardFunctor; + auto max_wg_size = syclMaxWorkGroupSize(); + auto config = rnn_get_launch_config(max_wg_size, numel); + auto nwg = std::get<0>(config); + auto local_range = std::get<1>(config); KernelT kfn( workspaceI, grad_gatesI, @@ -580,6 +603,12 @@ void lstm_backward_impl( sycl_kernel_submit( nwg * local_range, local_range, getCurrentSYCLQueue(), kfn); } else { + using KernelT = + LstmCellBackwardFunctor; + auto max_wg_size = syclMaxWorkGroupSize(); + auto config = rnn_get_launch_config(max_wg_size, numel); + auto nwg = std::get<0>(config); + auto local_range = std::get<1>(config); KernelT kfn( workspaceI, grad_gatesI, @@ -610,12 +639,6 @@ void gru_forward_impl( if (numel == 0) return; - using KernelT = GruCellForwardFunctor; - auto max_wg_size = syclMaxWorkGroupSize(); - auto config = rnn_get_launch_config(max_wg_size, numel); - auto nwg = std::get<0>(config); - auto local_range = std::get<1>(config); - auto input_gatesI = getTensorInfo(input_gates); auto hidden_gatesI = getTensorInfo(hidden_gates); auto input_biasI = tryGetTensorInfo(input_bias); @@ -641,6 +664,11 @@ void gru_forward_impl( hxI, hyI, workspaceI); + using KernelT = GruCellForwardFunctor; + auto max_wg_size = syclMaxWorkGroupSize(); + auto config = rnn_get_launch_config(max_wg_size, numel); + auto nwg = std::get<0>(config); + auto local_range = std::get<1>(config); KernelT kfn( input_gatesI, hidden_gatesI, @@ -654,6 +682,11 @@ void gru_forward_impl( sycl_kernel_submit( nwg * local_range, local_range, getCurrentSYCLQueue(), kfn); } else { + using KernelT = GruCellForwardFunctor; + auto max_wg_size = syclMaxWorkGroupSize(); + auto config = rnn_get_launch_config(max_wg_size, numel); + auto nwg = std::get<0>(config); + auto local_range = std::get<1>(config); KernelT kfn( input_gatesI, hidden_gatesI, @@ -682,12 +715,6 @@ void gru_backward_impl( if (numel == 0) return; - using KernelT = GruCellBackwardFunctor; - auto max_wg_size = syclMaxWorkGroupSize(); - auto config = rnn_get_launch_config(max_wg_size, numel); - auto nwg = std::get<0>(config); - auto local_range = std::get<1>(config); - auto grad_hyI = getTensorInfo(grad_hy); auto workspaceI = getTensorInfo(workspace); auto grad_input_gatesI = @@ -701,6 +728,12 @@ void gru_backward_impl( {grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx})) { collapseDims( grad_hyI, workspaceI, grad_input_gatesI, grad_hidden_gatesI, grad_hxI); + using KernelT = + GruCellBackwardFunctor; + auto max_wg_size = syclMaxWorkGroupSize(); 
+ auto config = rnn_get_launch_config(max_wg_size, numel); + auto nwg = std::get<0>(config); + auto local_range = std::get<1>(config); KernelT kfn( grad_input_gatesI, grad_hidden_gatesI, @@ -712,6 +745,12 @@ void gru_backward_impl( sycl_kernel_submit( nwg * local_range, local_range, getCurrentSYCLQueue(), kfn); } else { + using KernelT = + GruCellBackwardFunctor; + auto max_wg_size = syclMaxWorkGroupSize(); + auto config = rnn_get_launch_config(max_wg_size, numel); + auto nwg = std::get<0>(config); + auto local_range = std::get<1>(config); KernelT kfn( grad_input_gatesI, grad_hidden_gatesI, diff --git a/src/ATen/native/xpu/sycl/ScanUtils.h b/src/ATen/native/xpu/sycl/ScanUtils.h index df8926d9e9..aae7da1181 100644 --- a/src/ATen/native/xpu/sycl/ScanUtils.h +++ b/src/ATen/native/xpu/sycl/ScanUtils.h @@ -71,32 +71,24 @@ T inline group_x_scan_by_uds_for_loop_scan( glb_str_off_1 = glb1; } else { glb_ldr_logical_off_0 = glb0; - glb_ldr_off_0 = IndexToOffset::get( - glb_ldr_logical_off_0, - cfg.input_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + glb_ldr_off_0 = + IndexToOffset::get( + glb_ldr_logical_off_0, cfg.input_); glb_ldr_logical_off_1 = glb1; - glb_ldr_off_1 = IndexToOffset::get( - glb_ldr_logical_off_1, - cfg.input_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + glb_ldr_off_1 = + IndexToOffset::get( + glb_ldr_logical_off_1, cfg.input_); glb_str_logical_off_0 = glb0; - glb_str_off_0 = IndexToOffset::get( - glb_str_logical_off_0, - cfg.output_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + glb_str_off_0 = + IndexToOffset::get( + glb_str_logical_off_0, cfg.output_); glb_str_logical_off_1 = glb1; - glb_str_off_1 = IndexToOffset::get( - glb_str_logical_off_1, - cfg.output_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + glb_str_off_1 = + IndexToOffset::get( + glb_str_logical_off_1, cfg.output_); } // TODO: opti for bank conflict elemination // Read data from global memory to shared local memory @@ -204,44 +196,32 @@ void inline group_x_scan_by_uds_for_loop_scan_with_indices( glb_idx_off_1 = glb1; } else { glb_ldr_logical_off_0 = glb0; - glb_ldr_off_0 = IndexToOffset::get( - glb_ldr_logical_off_0, - cfg.input_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + glb_ldr_off_0 = + IndexToOffset::get( + glb_ldr_logical_off_0, cfg.input_); glb_ldr_logical_off_1 = glb1; - glb_ldr_off_1 = IndexToOffset::get( - glb_ldr_logical_off_1, - cfg.input_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + glb_ldr_off_1 = + IndexToOffset::get( + glb_ldr_logical_off_1, cfg.input_); glb_str_logical_off_0 = glb0; - glb_str_off_0 = IndexToOffset::get( - glb_str_logical_off_0, - cfg.output_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + glb_str_off_0 = + IndexToOffset::get( + glb_str_logical_off_0, cfg.output_); glb_str_logical_off_1 = glb1; - glb_str_off_1 = IndexToOffset::get( - glb_str_logical_off_1, - cfg.output_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - glb_idx_off_0 = IndexToOffset::get( - glb0, - cfg.indices_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - glb_idx_off_1 = IndexToOffset::get( - glb1, - cfg.indices_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + glb_str_off_1 = + IndexToOffset::get( + glb_str_logical_off_1, cfg.output_); + + glb_idx_off_0 = + IndexToOffset::get( + glb0, cfg.indices_); + + glb_idx_off_1 = + IndexToOffset::get( + glb1, cfg.indices_); } // TODO: opti for bank conflict elemination // Read data from global memory to shared local memory @@ -828,22 +808,15 @@ class SegmentScanKernel : public __SYCL_KER_CONFIG_CONVENTION__ { glb_str_off = glb_str_logical_off; glb_str_off_0 = 
glb_ldr_logical_off; } else { - glb_ldr_off = IndexToOffset::get( - glb_ldr_logical_off, - cfg_.iinfo_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); - glb_str_off = IndexToOffset::get( - glb_str_logical_off, - cfg_.oinfo_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + glb_ldr_off = + IndexToOffset::get( + glb_ldr_logical_off, cfg_.iinfo_); + glb_str_off = + IndexToOffset::get( + glb_str_logical_off, cfg_.oinfo_); glb_str_off_0 = - IndexToOffset::get( - glb_ldr_logical_off, - cfg_.oinfo_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + IndexToOffset::get( + glb_ldr_logical_off, cfg_.oinfo_); } T value = cfg_.init_; if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) { @@ -941,33 +914,21 @@ class SegmentScanWithIndicesKernel : public __SYCL_KER_CONFIG_CONVENTION__ { glb_idx_off = glb_idx_logical_off; glb_idx_off_0 = glb_ldr_logical_off; } else { - glb_ldr_off = IndexToOffset::get( - glb_ldr_logical_off, - cfg_.iinfo_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); - glb_str_off = IndexToOffset::get( - glb_str_logical_off, - cfg_.oinfo_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + glb_ldr_off = + IndexToOffset::get( + glb_ldr_logical_off, cfg_.iinfo_); + glb_str_off = + IndexToOffset::get( + glb_str_logical_off, cfg_.oinfo_); glb_str_off_0 = - IndexToOffset::get( - glb_ldr_logical_off, - cfg_.oinfo_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); - glb_idx_off = IndexToOffset::get( - glb_idx_logical_off, - cfg_.idxinfo_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + IndexToOffset::get( + glb_ldr_logical_off, cfg_.oinfo_); + glb_idx_off = + IndexToOffset::get( + glb_idx_logical_off, cfg_.idxinfo_); glb_idx_off_0 = - IndexToOffset::get( - glb_ldr_logical_off, - cfg_.oinfo_, - IndexToOffset:: - NON_STRICT_CONTIGUOUS); + IndexToOffset::get( + glb_ldr_logical_off, cfg_.oinfo_); } T value = cfg_.init_; IndicesT idx = pi; diff --git a/src/ATen/native/xpu/sycl/Sorting.cpp b/src/ATen/native/xpu/sycl/Sorting.cpp index 67ce6bfd3f..f79572ea9f 100644 --- a/src/ATen/native/xpu/sycl/Sorting.cpp +++ b/src/ATen/native/xpu/sycl/Sorting.cpp @@ -168,11 +168,11 @@ struct GatherMedianKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { // Finds the start offset for our slice index_t valuesSliceStartIndex = - IndexToOffset::get(slice, values_); + IndexToOffset::get(slice, values_); index_t indicesSliceStartIndex = - IndexToOffset::get(slice, indices_); + IndexToOffset::get(slice, indices_); index_t inputSliceStartIndex = - IndexToOffset::get(slice, input_); + IndexToOffset::get(slice, input_); scalar_t* valuesSliceStart = values_data_ + valuesSliceStartIndex; int64_t* indicesSliceStart = indices_data_ + indicesSliceStartIndex; @@ -286,11 +286,11 @@ struct GatherKthValueKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { // Finds the start offset for our slice index_t valuesSliceStartIndex = - IndexToOffset::get(slice, values_); + IndexToOffset::get(slice, values_); index_t indicesSliceStartIndex = - IndexToOffset::get(slice, indices_); + IndexToOffset::get(slice, indices_); index_t inputSliceStartIndex = - IndexToOffset::get(slice, input_); + IndexToOffset::get(slice, input_); scalar_t* valuesSliceStart = values_data_ + valuesSliceStartIndex; int64_t* indicesSliceStart = indices_data_ + indicesSliceStartIndex; diff --git a/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp b/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp index ac30e93f7f..de0c3eaf57 100644 --- a/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp +++ b/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp @@ -41,6 +41,7 @@ template < typename 
input_t, typename IndexType, int ADims, + int BDims, bool has_weight, typename Op> struct Histogram1DKernelFunctor { @@ -52,14 +53,14 @@ struct Histogram1DKernelFunctor { auto linear_index = item_id.get_id(0); // Convert `linear_index` into an offset of `b` const IndexType b_offset = - IndexToOffset::get(linear_index, b_); + IndexToOffset::get(linear_index, b_); const auto b_val = in_ptr[b_offset]; if (b_val >= min_value_ && b_val <= max_value_) { // Use value at `b` as an offset of `a` const IndexType bin = get_bin(b_val, min_value_, max_value_, nbins_); const IndexType a_offset = - IndexToOffset::get(bin, a_); + IndexToOffset::get(bin, a_); atomicAdd( (sycl_global_ptr)&out_ptr[a_offset], get_op_(weight_ptr, linear_index)); @@ -102,6 +103,7 @@ template < typename input_t, typename IndexType, int ADims, + int BDims, bool has_weight, typename Op> void histogram_1d_kernel( @@ -115,28 +117,35 @@ void histogram_1d_kernel( Op get_op) { auto& sycl_queue = at::xpu::getCurrentSYCLQueue(); - Histogram1DKernelFunctor + Histogram1DKernelFunctor< + output_t, + input_t, + IndexType, + ADims, + BDims, + has_weight, + Op> kfn(a, b, c, nbins, min_value, max_value, total_elements, get_op); sycl_kernel_submit(::sycl::range<1>(total_elements), sycl_queue, kfn); } -#define HANDLE_CASE(WEIGHTS_OP, WITH_WEIGHT) \ - histogram_1d_kernel( \ - a_info, \ - b_info, \ - c_info, \ - nbins, \ - min_value, \ - max_value, \ - total_elements, \ +#define HANDLE_CASE(WEIGHTS_OP, WITH_WEIGHT) \ + histogram_1d_kernel( \ + a_info, \ + b_info, \ + c_info, \ + nbins, \ + min_value, \ + max_value, \ + total_elements, \ WEIGHTS_OP); template struct IndexingFunctor { auto operator()(output_t* c_ptr, index_type c_index) const { const index_type c_offset = - IndexToOffset::get(c_index, c_info); + IndexToOffset::get(c_index, c_info); return c_ptr[c_offset]; } diff --git a/src/ATen/native/xpu/sycl/TensorApplyUtils.h b/src/ATen/native/xpu/sycl/TensorApplyUtils.h index 20a7dca339..2e06ce72af 100644 --- a/src/ATen/native/xpu/sycl/TensorApplyUtils.h +++ b/src/ATen/native/xpu/sycl/TensorApplyUtils.h @@ -136,12 +136,12 @@ struct ApplyOp2 { Offsets... bOffsets) { // Convert `linearIndex` into an offset of `a` const IndexType aOffset = static_cast(sizeof...(Offsets)) < n - ? IndexToOffset::get(linearIndex, a) + ? IndexToOffset::get(linearIndex, a) : 0; // Convert `linearIndex` into an offset of `b` const IndexType bOffset = static_cast(sizeof...(Offsets)) < n - ? IndexToOffset::get(linearIndex, b) + ? 
IndexToOffset::get(linearIndex, b) : 0; ApplyOp2< diff --git a/src/ATen/native/xpu/sycl/TensorModeKernel.cpp b/src/ATen/native/xpu/sycl/TensorModeKernel.cpp index dcecf27ab0..4bd61d4c5c 100644 --- a/src/ATen/native/xpu/sycl/TensorModeKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorModeKernel.cpp @@ -428,7 +428,7 @@ struct ComputeModeKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { // thread to place this in the appropriate output position if (tidx == 0) { unsigned int outputOffset = - at::xpu::detail::IndexToOffset::get( + at::xpu::detail::IndexToOffset::get( groupId, values_); values_.data[outputOffset] = mode_[0]; indices_.data[outputOffset] = index; diff --git a/src/ATen/native/xpu/sycl/WeightNormKernels.cpp b/src/ATen/native/xpu/sycl/WeightNormKernels.cpp index ac67d5d346..490ad63651 100644 --- a/src/ATen/native/xpu/sycl/WeightNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/WeightNormKernels.cpp @@ -33,18 +33,14 @@ struct WeightNormReduceKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { int64_t str_pi = id.chunk; int64_t ldr_lid = si + ldr_pi * cfg_.stride_ + bi * cfg_.problem_ * cfg_.stride_; - int64_t ldr_off = at::xpu::detail::IndexToOffset::get( - ldr_lid, - iinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + int64_t ldr_off = + at::xpu::detail::IndexToOffset::get( + ldr_lid, iinfo_); int64_t str_lid = si + str_pi * cfg_.stride_ + bi * id.chunk_num * cfg_.stride_; - int64_t str_off = at::xpu::detail::IndexToOffset::get( - str_lid, - oinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + int64_t str_off = + at::xpu::detail::IndexToOffset::get( + str_lid, oinfo_); accscalar_t value = 0; if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) { @@ -170,29 +166,15 @@ struct SegmentWeightNormKernelFunctor { int64_t w_lid = si + pi * cfg_.stride_ + bi * cfg_.problem_ * cfg_.stride_; int64_t n_lid = id.glb_batch; - int64_t v_off = at::xpu::detail::IndexToOffset::get( - w_lid, - vinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - int64_t w_off = at::xpu::detail::IndexToOffset::get( - w_lid, - winfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - int64_t g_off = at::xpu::detail::IndexToOffset::get( - n_lid, - ginfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - int64_t n_off = at::xpu::detail::IndexToOffset::get( - n_lid, - ninfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + int64_t v_off = at::xpu::detail::IndexToOffset::get( + w_lid, vinfo_); + int64_t w_off = at::xpu::detail::IndexToOffset::get( + w_lid, winfo_); + int64_t g_off = at::xpu::detail::IndexToOffset::get( + n_lid, ginfo_); + int64_t n_off = + at::xpu::detail::IndexToOffset::get( + n_lid, ninfo_); if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) { winfo_.data[w_off] = @@ -257,17 +239,12 @@ struct WeightNormKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { auto id = cfg_.get_item_desc(item); int64_t n_lid = id.glb_batch; - int64_t g_off = at::xpu::detail::IndexToOffset::get( - n_lid, - ginfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + int64_t g_off = at::xpu::detail::IndexToOffset::get( + n_lid, ginfo_); - int64_t n_off = at::xpu::detail::IndexToOffset::get( - n_lid, - ninfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + int64_t n_off = + at::xpu::detail::IndexToOffset::get( + n_lid, ninfo_); int64_t si = id.glb_batch % cfg_.stride_; int64_t bi = id.glb_batch / cfg_.stride_; @@ -278,11 +255,9 @@ struct 
WeightNormKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { if (id.glb_batch < cfg_.problem_batch_) { for (int pi_ = pi; pi_ < cfg_.problem_; pi_ += cfg_.problem_wg_range_) { int64_t v_lid = bi + pi_ * cfg_.stride_; - int64_t v_off = at::xpu::detail::IndexToOffset::get( - v_lid, - vinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + int64_t v_off = + at::xpu::detail::IndexToOffset::get( + v_lid, vinfo_); accscalar_t v = (accscalar_t)vinfo_.data[v_off]; value += v * v; @@ -310,16 +285,12 @@ struct WeightNormKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { if (id.glb_batch < cfg_.problem_batch_) { for (int pi_ = pi; pi_ < cfg_.problem_; pi_ += cfg_.problem_wg_range_) { int64_t v_lid = bi + pi_ * cfg_.stride_; - int64_t v_off = at::xpu::detail::IndexToOffset::get( - v_lid, - vinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - int64_t w_off = at::xpu::detail::IndexToOffset::get( - v_lid, - winfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + int64_t v_off = + at::xpu::detail::IndexToOffset::get( + v_lid, vinfo_); + int64_t w_off = + at::xpu::detail::IndexToOffset::get( + v_lid, winfo_); winfo_.data[w_off] = (1.f / shared_[n_slid]) * vinfo_.data[v_off] * ginfo_.data[g_off]; @@ -468,26 +439,19 @@ struct WeightNormBackwardReduceKernelFunctor int64_t i_lid = si + i_pi * cfg_.stride_ + bi * cfg_.problem_ * cfg_.stride_; - int64_t i1_off = at::xpu::detail::IndexToOffset::get( - i_lid, - i1info_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + int64_t i1_off = + at::xpu::detail::IndexToOffset::get( + i_lid, i1info_); int64_t i2_off; if (is_first) { - i2_off = at::xpu::detail::IndexToOffset::get( - i_lid, - i2info_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + i2_off = at::xpu::detail::IndexToOffset::get( + i_lid, i2info_); } int64_t o_lid = si + o_pi * cfg_.stride_ + bi * id.chunk_num * cfg_.stride_; - int64_t o_off = at::xpu::detail::IndexToOffset::get( - o_lid, - oinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + int64_t o_off = + at::xpu::detail::IndexToOffset::get( + o_lid, oinfo_); accscalar_t value = 0; if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) { @@ -637,47 +601,28 @@ struct SegmentWeightNormBackwardKernelFunctor { int64_t gv_lid = si + pi * cfg_.stride_ + bi * cfg_.problem_ * cfg_.stride_; int64_t gg_lid = id.glb_batch; - int64_t v_off = at::xpu::detail::IndexToOffset::get( - gv_lid, - vinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - int64_t gw_off = at::xpu::detail::IndexToOffset::get( - gv_lid, - gwinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - int64_t gv_off = at::xpu::detail::IndexToOffset::get( - gv_lid, - gvinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - int64_t g_off = at::xpu::detail::IndexToOffset::get( - gg_lid, - ginfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - int64_t n_off = at::xpu::detail::IndexToOffset::get( - gg_lid, - ninfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - int64_t r_off = at::xpu::detail::IndexToOffset::get( - gg_lid, - rinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - int64_t gg_off = at::xpu::detail::IndexToOffset::get( - gg_lid, - gginfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + int64_t v_off = at::xpu::detail::IndexToOffset::get( + gv_lid, vinfo_); + + int64_t gw_off = at::xpu::detail::IndexToOffset::get( + gv_lid, gwinfo_); + + int64_t gv_off = 
at::xpu::detail::IndexToOffset::get( + gv_lid, gvinfo_); + + int64_t g_off = at::xpu::detail::IndexToOffset::get( + gg_lid, ginfo_); + + int64_t n_off = + at::xpu::detail::IndexToOffset::get( + gg_lid, ninfo_); + + int64_t r_off = + at::xpu::detail::IndexToOffset::get( + gg_lid, rinfo_); + + int64_t gg_off = at::xpu::detail::IndexToOffset::get( + gg_lid, gginfo_); if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) { accscalar_t g = ginfo_.data[g_off]; @@ -769,21 +714,13 @@ struct WeightNormBackwardKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { void operator()(sycl::nd_item<2> item) const { auto id = cfg_.get_item_desc(item); int64_t n_lid = id.glb_batch; - int64_t g_off = at::xpu::detail::IndexToOffset::get( - n_lid, - ginfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - int64_t gg_off = at::xpu::detail::IndexToOffset::get( - n_lid, - gginfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - int64_t n_off = at::xpu::detail::IndexToOffset::get( - n_lid, - ninfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + int64_t g_off = at::xpu::detail::IndexToOffset::get( + n_lid, ginfo_); + int64_t gg_off = at::xpu::detail::IndexToOffset::get( + n_lid, gginfo_); + int64_t n_off = + at::xpu::detail::IndexToOffset::get( + n_lid, ninfo_); int64_t si = id.glb_batch % cfg_.stride_; int64_t bi = id.glb_batch / cfg_.stride_; int64_t pi = id.chunk_off; @@ -795,17 +732,11 @@ struct WeightNormBackwardKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { int64_t v_lid, v_off, gw_off; v_lid = bi + pi_ * cfg_.stride_; - v_off = at::xpu::detail::IndexToOffset::get( - v_lid, - vinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + v_off = at::xpu::detail::IndexToOffset::get( + v_lid, vinfo_); - gw_off = at::xpu::detail::IndexToOffset::get( - v_lid, - gwinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + gw_off = at::xpu::detail::IndexToOffset::get( + v_lid, gwinfo_); accscalar_t v = (accscalar_t)vinfo_.data[v_off]; accscalar_t gw = (accscalar_t)gwinfo_.data[gw_off]; @@ -832,23 +763,14 @@ struct WeightNormBackwardKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { int64_t v_lid, v_off, gw_off, gv_off; v_lid = bi + pi_ * cfg_.stride_; - v_off = at::xpu::detail::IndexToOffset::get( - v_lid, - vinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - gw_off = at::xpu::detail::IndexToOffset::get( - v_lid, - gwinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); - - gv_off = at::xpu::detail::IndexToOffset::get( - v_lid, - gvinfo_, - at::xpu::detail::IndexToOffset:: - NON_STRICT_CONTIGUOUS); + v_off = at::xpu::detail::IndexToOffset::get( + v_lid, vinfo_); + + gw_off = at::xpu::detail::IndexToOffset::get( + v_lid, gwinfo_); + + gv_off = at::xpu::detail::IndexToOffset::get( + v_lid, gvinfo_); accscalar_t g = ginfo_.data[g_off]; accscalar_t gw = gwinfo_.data[gw_off]; diff --git a/src/comm/TensorInfo.h b/src/comm/TensorInfo.h index 67b5c5aa91..703807b239 100644 --- a/src/comm/TensorInfo.h +++ b/src/comm/TensorInfo.h @@ -151,42 +151,41 @@ IndexType TensorInfo::outerSize(const int exclusive) { } // Translate a linear index for the apply to a T* offset; -template +template struct IndexToOffset { - static constexpr bool STRICT_CONTIGUOUS = true; - static constexpr bool NON_STRICT_CONTIGUOUS = false; - static inline IndexType get( + static IndexType get( IndexType linearId, - const TensorInfo& info, - bool strict_contiguous = true) { + const TensorInfo& info) { IndexType offset = 0; 
-    if (info.isContiguousCheckStrict(strict_contiguous)) {
-      return linearId;
-    }
-
-    for (int dim = info.dims - 1; dim > 0; --dim) {
-      IndexType curDimIndex = linearId % info.sizes[dim];
-      IndexType curDimOffset = curDimIndex * info.strides[dim];
+    // Uses static dims
+    for (int i = Dims - 1; i > 0; --i) {
+      IndexType curDimIndex = linearId % info.sizes[i];
+      IndexType curDimOffset = curDimIndex * info.strides[i];
       offset += curDimOffset;
-      linearId /= info.sizes[dim];
+      linearId /= info.sizes[i];
     }
+
     return offset + linearId * info.strides[0];
   }
 };
 
-// To isolate unnecessary code, even the code is not involved in
-// contiguouse case. Additional unnecessary code impacts efficiency of
-// generated code.
+// Uses dynamic (runtime) instead of static (compiletime) dims
 template <typename T, typename IndexType>
-struct IndexToOffset {
-  static constexpr bool STRICT_CONTIGUOUS = true;
-  static constexpr bool NON_STRICT_CONTIGUOUS = false;
+struct IndexToOffset<T, IndexType, -1> {
   static inline IndexType get(
       IndexType linearId,
-      const TensorInfo<T, IndexType>& info,
-      bool strict_contiguous = true) {
-    return linearId;
+      const TensorInfo<T, IndexType>& info) {
+    IndexType offset = 0;
+
+    for (int i = info.dims - 1; i > 0; --i) {
+      IndexType curDimIndex = linearId % info.sizes[i];
+      IndexType curDimOffset = curDimIndex * info.strides[i];
+      offset += curDimOffset;
+      linearId /= info.sizes[i];
+    }
+
+    return offset + linearId * info.strides[0];
   }
 };
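Note on the IndexToOffset change above: the patch drops the runtime `strict_contiguous` check in favor of a compile-time `Dims` template parameter and makes callers (the SMALL_INDEX/LARGE_INDEX macros plus the RNN, scan, sort, histogram, and weight-norm kernels) dispatch on the collapsed dimension count, in the style of at::cuda::detail::TensorInfo. The sketch below only illustrates that pattern; `ToyTensorInfo`, `ToyIndexToOffset`, and `toy_offset` are hypothetical stand-ins, not names from this patch, and the 1/2/3/-1 dispatch mirrors the branches added in Indexing.cpp as an assumption, not the exact host code.

    #include <cstdint>

    // Hypothetical stand-in for at::xpu::detail::TensorInfo<T, IndexType>;
    // only the fields the offset arithmetic needs.
    template <typename T, typename IndexType>
    struct ToyTensorInfo {
      T* data;
      IndexType sizes[12];
      IndexType strides[12];
      int dims;
    };

    // Static-dims path: the trip count is a compile-time constant, so the
    // loop can be fully unrolled and no runtime contiguity check is needed.
    template <typename T, typename IndexType, int Dims>
    struct ToyIndexToOffset {
      static IndexType get(IndexType linearId, const ToyTensorInfo<T, IndexType>& info) {
        IndexType offset = 0;
        for (int i = Dims - 1; i > 0; --i) {
          offset += (linearId % info.sizes[i]) * info.strides[i];
          linearId /= info.sizes[i];
        }
        return offset + linearId * info.strides[0];
      }
    };

    // Dims == -1 falls back to the runtime dimension count.
    template <typename T, typename IndexType>
    struct ToyIndexToOffset<T, IndexType, -1> {
      static IndexType get(IndexType linearId, const ToyTensorInfo<T, IndexType>& info) {
        IndexType offset = 0;
        for (int i = info.dims - 1; i > 0; --i) {
          offset += (linearId % info.sizes[i]) * info.strides[i];
          linearId /= info.sizes[i];
        }
        return offset + linearId * info.strides[0];
      }
    };

    // Host-side dispatch in the spirit of the SMALL_INDEX / LARGE_INDEX
    // branches: pick a specialization from the collapsed dimension count,
    // with -1 as the catch-all for uncommon ranks.
    template <typename T, typename IndexType>
    IndexType toy_offset(IndexType linearId, const ToyTensorInfo<T, IndexType>& info) {
      switch (info.dims) {
        case 1: return ToyIndexToOffset<T, IndexType, 1>::get(linearId, info);
        case 2: return ToyIndexToOffset<T, IndexType, 2>::get(linearId, info);
        case 3: return ToyIndexToOffset<T, IndexType, 3>::get(linearId, info);
        default: return ToyIndexToOffset<T, IndexType, -1>::get(linearId, info);
      }
    }

Assuming no dedicated -2 specialization exists elsewhere in the header, the -2 the macros pass for contiguous index tensors would simply fall through the primary template: with Dims <= 0 the loop body never executes and the offset reduces to linearId * strides[0], which is linearId once the tensor has been collapsed to a contiguous layout.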