Skip to content

Commit

Permalink
update
Browse files (browse the repository at this point in the history)
  • Loading branch information
chhwang committed Feb 5, 2025
1 parent 7d30e2c commit d49f818
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions ark/include/kernels/kernel_template.in
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,24 @@ template <size_t ProcBegin, size_t ProcEnd, size_t ProcStep, size_t ProcCurrent,
size_t NumSlots, size_t SlotNumWarps, size_t SlotSramBytes,
void (*task)(char*, int, int, @ARG_TYPES@)>
__forceinline__ __device__ void task_seq(char *_buf, @GLOBAL_ARGS@) {
if (math::geq<ProcBegin>(blockIdx.x) && math::le<ProcEnd>(blockIdx.x) &&
((blockIdx.x - ProcBegin) % ProcStep == 0)) {
constexpr size_t SlotNumThreads = SlotNumWarps * Arch::ThreadsPerWarp;
constexpr size_t NumProcs = (ProcEnd - ProcBegin + ProcStep - 1) / ProcStep;
constexpr size_t SramBytesPerWarp = SlotSramBytes / SlotNumWarps;
size_t p = ((blockIdx.x + gridDim.x - ProcCurrent) % gridDim.x) / ProcStep;
size_t k = threadIdx.x / SlotNumThreads;
if constexpr (ARK_WARPS_PER_BLOCK > SlotNumWarps) {
if (k >= NumSlots) return;
}
size_t task_id_base = TaskBegin + p * TaskStep * TaskGranularity;
for (size_t t = k; ; t += NumSlots) {
size_t task_id = task_id_base + TaskStep *
(t % TaskGranularity + t / TaskGranularity * TaskGranularity * NumProcs);
if (task_id >= TaskEnd) break;
task(_buf, task_id, SramBytesPerWarp, @FUNCTION_ARGS@);
if constexpr (TaskBegin != TaskEnd) {
if (math::geq<ProcBegin>(blockIdx.x) && math::le<ProcEnd>(blockIdx.x) &&
((blockIdx.x - ProcBegin) % ProcStep == 0)) {
constexpr size_t SlotNumThreads = SlotNumWarps * Arch::ThreadsPerWarp;
constexpr size_t NumProcs = (ProcEnd - ProcBegin + ProcStep - 1) / ProcStep;
constexpr size_t SramBytesPerWarp = SlotSramBytes / SlotNumWarps;
size_t p = ((blockIdx.x + gridDim.x - ProcCurrent) % gridDim.x) / ProcStep;
size_t k = threadIdx.x / SlotNumThreads;
if constexpr (ARK_WARPS_PER_BLOCK > SlotNumWarps) {
if (k >= NumSlots) return;
}
size_t task_id_base = TaskBegin + p * TaskStep * TaskGranularity;
for (size_t t = k; ; t += NumSlots) {
size_t task_id = task_id_base + TaskStep *
(t % TaskGranularity + t / TaskGranularity * TaskGranularity * NumProcs);
if (task_id >= TaskEnd) break;
task(_buf, task_id, SramBytesPerWarp, @FUNCTION_ARGS@);
}
}
}
}
Expand Down

0 comments on commit d49f818

Please sign in to comment.