Skip to content

Commit 0b700f2

Browse files
authored
[flang][cuda] Add entry point to launch global function with cluster_dims (#113958)
1 parent 12a8f50 commit 0b700f2

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

flang/include/flang/Runtime/CUDA/kernel.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,19 @@
1515

1616
extern "C" {
1717

18-
// This function uses intptr_t instead of CUDA's unsigned int to match
18+
// These functions use intptr_t instead of CUDA's unsigned int to match
1919
// the type of MLIR's index type. This avoids the need for casts in the
2020
// generated MLIR code.
21+
2122
void RTDEF(CUFLaunchKernel)(const void *kernelName, intptr_t gridX,
2223
intptr_t gridY, intptr_t gridZ, intptr_t blockX, intptr_t blockY,
2324
intptr_t blockZ, int32_t smem, void **params, void **extra);
2425

26+
void RTDEF(CUFLaunchClusterKernel)(const void *kernelName, intptr_t clusterX,
27+
intptr_t clusterY, intptr_t clusterZ, intptr_t gridX, intptr_t gridY,
28+
intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ,
29+
int32_t smem, void **params, void **extra);
30+
2531
} // extern "C"
2632

2733
#endif // FORTRAN_RUNTIME_CUDA_KERNEL_H_

flang/runtime/CUDA/kernel.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,32 @@ void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
2525
blockDim.x = blockX;
2626
blockDim.y = blockY;
2727
blockDim.z = blockZ;
28-
cudaStream_t stream = 0;
28+
cudaStream_t stream = 0; // TODO stream management
2929
CUDA_REPORT_IF_ERROR(
3030
cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, stream));
3131
}
3232

33+
// Launches the global function `kernel` with an explicit cluster dimension
// attribute, using cudaLaunchKernelExC.
//
// The dimension arguments arrive as intptr_t and are converted to the
// unsigned members of the CUDA launch structures on assignment.
//   clusterX/Y/Z  cluster dimensions, passed as the
//                 cudaLaunchAttributeClusterDimension launch attribute
//   gridX/Y/Z     grid dimensions
//   blockX/Y/Z    block dimensions
//   smem          value stored in the config's dynamicSmemBytes member
//   params        kernel argument array, forwarded unchanged
//   extra         accepted for signature parity; not used by this function
//
// The status returned by cudaLaunchKernelExC is handed to
// CUDA_REPORT_IF_ERROR.
void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
    intptr_t clusterY, intptr_t clusterZ, intptr_t gridX, intptr_t gridY,
    intptr_t gridZ, intptr_t blockX, intptr_t blockY, intptr_t blockZ,
    int32_t smem, void **params, void **extra) {
  // Value-initialize the launch configuration so that no member is left
  // indeterminate when the structure is handed to the CUDA runtime.
  cudaLaunchConfig_t config{};
  config.gridDim.x = gridX;
  config.gridDim.y = gridY;
  config.gridDim.z = gridZ;
  config.blockDim.x = blockX;
  config.blockDim.y = blockY;
  config.blockDim.z = blockZ;
  config.dynamicSmemBytes = smem;
  config.stream = 0; // TODO stream management

  // A single launch attribute carries the cluster shape. It is
  // value-initialized as well so that the padding and the unused part of the
  // value union are zero rather than indeterminate.
  cudaLaunchAttribute launchAttr[1]{};
  launchAttr[0].id = cudaLaunchAttributeClusterDimension;
  launchAttr[0].val.clusterDim.x = clusterX;
  launchAttr[0].val.clusterDim.y = clusterY;
  launchAttr[0].val.clusterDim.z = clusterZ;
  config.numAttrs = 1;
  config.attrs = launchAttr;

  CUDA_REPORT_IF_ERROR(cudaLaunchKernelExC(&config, kernel, params));
}
55+
3356
} // extern "C"

0 commit comments

Comments
 (0)