SplitK with Atomic Reduce Counting for Skinny GEMMs #29807
Code Review
This pull request introduces a new GEMM kernel, wvSplitKrc, optimized for skinny matrices on ROCm GPUs using a Split-K algorithm with atomic reduction. While the initiative is good, the implementation has several critical issues that must be addressed. There are correctness bugs related to GPU architecture detection and memory access patterns that will lead to incorrect results or prevent the kernel from being used at all. Additionally, there is a significant performance issue due to an inefficient runtime loop inside the kernel. I have provided detailed comments and suggestions to fix these problems.
  return out_c;
}

#if defined(__gfx950__)  // TODO: Add NAVI support
The preprocessor directive #if defined(__gfx950__) is likely incorrect as gfx950 is not a standard ROCm architecture name. For MI300 series GPUs, the architecture is gfx942. This will prevent the kernel from being compiled for the intended hardware, making this new feature dead code. You should use a correct macro, for example #if defined(__gfx942__) or the existing __HIP__MI3XX__ if it's meant for all MI300 series.
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
if (((K + kfitsPerRdc * kFit - 1) / (kfitsPerRdc * kFit)) * numCuWithFullK <=
    CuCount)
  while (true) {
    while (kFit > TUC_) {
      uint32_t kFit_ = kFit - TUC_;
      if (((K + (kfitsPerRdc * kFit_ - 1)) / (kfitsPerRdc * kFit_)) *
              numCuWithFullK >
          CuCount)
        break;
      kFit = kFit_;
    }
    if (((K + ((kfitsPerRdc - 1) * kFit - 1)) / ((kfitsPerRdc - 1) * kFit)) *
            numCuWithFullK <=
        CuCount)
      kfitsPerRdc--;
    else
      break;
  }
This while(true) loop calculates the optimal K-split configuration (kFit and kfitsPerRdc) at runtime inside the kernel. This is highly inefficient because it is executed by every thread, leading to significant redundant computation and thread divergence. This logic should be performed once on the host, and the results passed as arguments to the kernel to avoid severe performance degradation.
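A minimal host-side sketch of what this comment proposes, assuming the search depends only on values known before launch. The struct `SplitKConfig`, the helper `ceil_div`, and the function `choose_split_k` are hypothetical names introduced for illustration. The starting values of `kFit` and `kfitsPerRdc` are taken as inputs because the snippet does not show how they are initialised, and the `kfitsPerRdc > 1` guard is an addition that keeps the divisor non-zero.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical container for the two values the kernel would receive.
struct SplitKConfig {
  uint32_t kFit;
  uint32_t kfitsPerRdc;
};

// Ceiling division, matching the (K + d - 1) / d pattern in the snippet.
inline uint32_t ceil_div(uint32_t n, uint32_t d) { return (n + d - 1) / d; }

// Host-side port of the search loop shown in the review snippet.
// Assumption: all inputs are known on the host before the kernel launch.
inline SplitKConfig choose_split_k(uint32_t K, uint32_t kFit,
                                   uint32_t kfitsPerRdc, uint32_t tuc,
                                   uint32_t numCuWithFullK, uint32_t cuCount) {
  assert(kFit > 0 && kfitsPerRdc > 0 && tuc > 0);
  // If the starting split already needs more CUs than exist, keep it as is.
  if (ceil_div(K, kfitsPerRdc * kFit) * numCuWithFullK > cuCount)
    return {kFit, kfitsPerRdc};
  while (true) {
    // Shrink kFit in steps of tuc while the split still fits on the device.
    while (kFit > tuc) {
      const uint32_t smaller = kFit - tuc;
      if (ceil_div(K, kfitsPerRdc * smaller) * numCuWithFullK > cuCount) break;
      kFit = smaller;
    }
    // Added guard (not in the snippet): stop at 1 so the divisor is never 0.
    if (kfitsPerRdc > 1 &&
        ceil_div(K, (kfitsPerRdc - 1) * kFit) * numCuWithFullK <= cuCount)
      --kfitsPerRdc;
    else
      break;
  }
  return {kFit, kfitsPerRdc};
}
```

The returned pair could then be passed to the kernel as two plain `uint32_t` arguments, so each thread reads the result instead of repeating the search.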
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
  uint32_t k = k_str + k2 * THRDS * A_CHUNK;
  uint32_t k_ = k + threadIdx.x * A_CHUNK;
  const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
The calculation of the base address for loading from matrix B is incorrect. min__(k_, K - A_CHUNK) causes threads that should be processing data near the end of the matrix to instead load data from an earlier, incorrect position. This will lead to incorrect matrix multiplication results. The boundary handling should rely on the out-of-bounds checks and zero-padding already present later in the code.
const scalar_t* B_ = &B[k_];
This is intended. The clamped load does a junk calculation, but the result is ultimately discarded. This way the loop can be fully unrolled without the compiler getting confused by out-of-bounds-handling if-conditions inside the loop.
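The idea in this reply can be illustrated on the CPU: clamp the load index so every read stays in bounds, compute on whatever was loaded, and drop the result for chunks that lie past the end. This is only a sketch under stated assumptions, not the kernel's code. `kChunk` and `kUnroll` are hypothetical stand-ins for `A_CHUNK` and `UNRL`, `clamped_dot` is a made-up function, and K is assumed to be a multiple of the chunk size.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t kChunk = 4;   // hypothetical stand-in for A_CHUNK
constexpr std::size_t kUnroll = 4;  // hypothetical stand-in for UNRL

// Dot product with a fixed-trip-count inner loop. Chunks that start past the
// end of the data are redirected to the last valid chunk and then discarded.
float clamped_dot(const std::vector<float>& a, const std::vector<float>& b) {
  const std::size_t K = a.size();
  // Assumption: K is a non-zero multiple of kChunk.
  assert(b.size() == K && K >= kChunk && K % kChunk == 0);
  float acc = 0.0f;
  for (std::size_t base = 0; base < K; base += kChunk * kUnroll) {
    for (std::size_t u = 0; u < kUnroll; ++u) {  // always kUnroll iterations
      const std::size_t k = base + u * kChunk;
      // Clamp the start index so the loads below never leave the buffers.
      const std::size_t safe = std::min(k, K - kChunk);
      float partial = 0.0f;
      for (std::size_t i = 0; i < kChunk; ++i)
        partial += a[safe + i] * b[safe + i];
      // Junk from an out-of-range chunk is masked out here.
      acc += (k < K) ? partial : 0.0f;
    }
  }
  return acc;
}
```

With 12 elements the fourth unrolled step starts at index 12, reloads the last valid chunk, and contributes nothing, so the result matches a plain dot product.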
use_skinny_reduce_counting = (
    envs.VLLM_ROCM_USE_SKINNY_GEMM
    and on_gfx950()
    and x.dtype in [torch.float16, torch.bfloat16]
    and (n == 32 and k == 2880 and (m == 640 or m == 128))
)
The function on_gfx950 checks for the "gfx950" architecture, which is incorrect for MI300 series GPUs (which are gfx942). This will cause on_gfx950() to return False, and the new wvSplitKrc kernel will never be executed. This makes the new feature dead code. You should use a correct check for the target hardware. A similar issue exists in csrc/rocm/skinny_gemms.cu where the kernel is guarded by #if defined(__gfx950__).
csrc/rocm/skinny_gemms.cu
Outdated
__device__ inline int min__(int a, int b) {
  int tmp;
  asm("v_min_i32_e32 %0, %2, %3 " : "=v"(tmp) : "0"(tmp), "v"(a), "v"(b));
  return tmp;
}
The min__ function is implemented with inline assembly that has an incorrect constraint "0"(tmp). This can lead to miscompilation. It's safer and more maintainable to use the standard min() function, which the compiler will optimize to the v_min_i32_e32 instruction. Given the critical bug in its usage at line 1485, it's best to refactor this helper function.
__device__ inline int min__(int a, int b) {
  return min(a, b);
}
done.
Signed-off-by: Hashem Hashemi <[email protected]>
Signed-off-by: Johnny Yang <[email protected]> Signed-off-by: Hashem Hashemi <[email protected]>
…oject#28878) Signed-off-by: HDCharles <[email protected]> Signed-off-by: Hashem Hashemi <[email protected]>
…a building with DCP > 1 (vllm-project#29449) Signed-off-by: Matthew Bonanni <[email protected]> Signed-off-by: Hashem Hashemi <[email protected]>
…m-project#28619) Signed-off-by: Jinzhen Lin <[email protected]> Co-authored-by: youkaichao <[email protected]> Co-authored-by: Michael Goin <[email protected]> Co-authored-by: Wentao Ye <[email protected]> Signed-off-by: Hashem Hashemi <[email protected]>
Signed-off-by: Tsai, Louie <[email protected]> Signed-off-by: Hashem Hashemi <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]> Signed-off-by: Hashem Hashemi <[email protected]>
…#29576) Signed-off-by: Woosuk Kwon <[email protected]> Signed-off-by: Hashem Hashemi <[email protected]>
Signed-off-by: tjtanaa <[email protected]> Signed-off-by: Hashem Hashemi <[email protected]>
Signed-off-by: Fadi Arafeh <[email protected]> Signed-off-by: Hashem Hashemi <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]> Signed-off-by: Hashem Hashemi <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]> Signed-off-by: Hashem Hashemi <[email protected]>
This pull request has merge conflicts that must be resolved before it can be
Purpose
Test Plan
Test Result