
vulkan: optimize mul_mat_id loading row ids into shared memory #15427


Merged: 1 commit merged into ggml-org:master on Aug 23, 2025

Conversation

jeffbolznv (Collaborator) commented:

  • Spread the work across the whole workgroup. The benefit of using more threads seems to far outweigh the added synchronization overhead.
  • Specialize the code for the case where the divisor is a power of two.
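
To illustrate the idea, here is a minimal sketch of a gather loop with both changes. This is not the shader from the PR: `NUM_ROWS`, `data_ids`, and the push-constant names are placeholders, and the power-of-two case is selected with a runtime flag here where the real code would presumably use a compile-time specialization.

```glsl
#version 450
// Illustrative sketch only, not the actual mul_mat_id shader.
// NUM_ROWS, data_ids, and the push-constant names are hypothetical.

layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;

layout(std430, binding = 0) readonly buffer Ids { uint data_ids[]; };

layout(push_constant) uniform PushConstants {
    uint nei0;       // extent the flat id index is divided by
    uint nei0_shift; // log2(nei0) when nei0 is a power of two, else 0
} p;

const uint NUM_ROWS = 64;       // ids gathered per workgroup (placeholder)
shared uvec2 row_ids[NUM_ROWS]; // (i1, row id) pair per row

void main() {
    // Change 1: instead of one invocation walking all NUM_ROWS entries
    // serially, stride the loop across the whole workgroup. The barrier
    // at the end costs far less than the serial loop it replaces.
    for (uint i = gl_LocalInvocationID.x; i < NUM_ROWS; i += gl_WorkGroupSize.x) {
        uint flat = gl_WorkGroupID.y * NUM_ROWS + i;
        uint i1, i2;
        if (p.nei0_shift != 0) {
            // Change 2: when nei0 is a power of two, replace the slow
            // integer divide/modulo with a shift and mask.
            i2 = flat >> p.nei0_shift;
            i1 = flat & (p.nei0 - 1);
        } else {
            i2 = flat / p.nei0;
            i1 = flat % p.nei0;
        }
        row_ids[i] = uvec2(i1, data_ids[i2 * p.nei0 + i1]);
    }
    barrier(); // row_ids is now visible to every invocation in the workgroup

    // ... matrix-multiply tile loop consuming row_ids would go here ...
}
```

Because `p.nei0_shift` is uniform across the workgroup, the branch is coherent and cheap; the win comes from replacing the serial gather with a strided cooperative one and the divide/modulo with shift/mask.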
5090 before

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 0 -p 512 -r 100 --prio 1 -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           pp512 |      4530.43 ± 31.16 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           pp512 |     7549.93 ± 116.08 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |     7275.45 ± 115.27 |

5090 after

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 0 -p 512 -r 100 --prio 1 -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           pp512 |      6169.11 ± 71.60 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           pp512 |     8449.56 ± 125.48 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |     7554.83 ± 156.67 |

4070 before

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 0 -p 512 -r 100 --prio 1 -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           pp512 |       1547.55 ± 9.91 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           pp512 |       2188.88 ± 9.85 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |      2757.82 ± 25.81 |

4070 after

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 0 -p 512 -r 100 --prio 1 -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           pp512 |      2230.97 ± 21.65 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           pp512 |      2722.84 ± 13.74 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |      2995.57 ± 28.40 |

@jeffbolznv jeffbolznv requested a review from 0cc4m as a code owner August 19, 2025 15:51
@github-actions github-actions bot added the Vulkan (Issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels on Aug 19, 2025
@0cc4m (Collaborator) left a comment:


Wow, great improvement. This closes the gap between CUDA and Vulkan MMID significantly. In some cases Vulkan even beats CUDA in pp512 now on my RTX 3090.

Example results, on master:

| model                          |       size |     params | backend | ngl | fa |  test |            t/s |
| ------------------------------ | ---------: | ---------: | ------- | --: | -: | ----: | -------------: |
| qwen3moe 30B.A3B Q4_K - Medium |  17.28 GiB |    30.53 B | Vulkan  |  99 |  0 | pp512 | 1254.06 ± 7.48 |
| qwen3moe 30B.A3B Q4_K - Medium |  17.28 GiB |    30.53 B | Vulkan  |  99 |  0 | tg128 |  140.24 ± 0.75 |
| qwen3moe 30B.A3B Q4_K - Medium |  17.28 GiB |    30.53 B | Vulkan  |  99 |  1 | pp512 | 1284.92 ± 6.66 |
| qwen3moe 30B.A3B Q4_K - Medium |  17.28 GiB |    30.53 B | Vulkan  |  99 |  1 | tg128 |  143.79 ± 0.53 |

PR:

| model                          |       size |     params | backend | ngl | fa |  test |             t/s |
| ------------------------------ | ---------: | ---------: | ------- | --: | -: | ----: | --------------: |
| qwen3moe 30B.A3B Q4_K - Medium |  17.28 GiB |    30.53 B | Vulkan  |  99 |  0 | pp512 | 1986.57 ± 16.09 |
| qwen3moe 30B.A3B Q4_K - Medium |  17.28 GiB |    30.53 B | Vulkan  |  99 |  0 | tg128 |   137.40 ± 1.86 |
| qwen3moe 30B.A3B Q4_K - Medium |  17.28 GiB |    30.53 B | Vulkan  |  99 |  1 | pp512 | 2055.43 ± 15.58 |
| qwen3moe 30B.A3B Q4_K - Medium |  17.28 GiB |    30.53 B | Vulkan  |  99 |  1 | tg128 |   139.03 ± 0.14 |

CUDA:

| model                          |       size |     params | backend | ngl | fa |  test |             t/s |
| ------------------------------ | ---------: | ---------: | ------- | --: | -: | ----: | --------------: |
| qwen3moe 30B.A3B Q4_K - Medium |  17.28 GiB |    30.53 B | CUDA    |  99 |  0 | pp512 | 1972.23 ± 16.20 |
| qwen3moe 30B.A3B Q4_K - Medium |  17.28 GiB |    30.53 B | CUDA    |  99 |  0 | tg128 |   139.29 ± 0.31 |
| qwen3moe 30B.A3B Q4_K - Medium |  17.28 GiB |    30.53 B | CUDA    |  99 |  1 | pp512 |  2106.30 ± 9.63 |
| qwen3moe 30B.A3B Q4_K - Medium |  17.28 GiB |    30.53 B | CUDA    |  99 |  1 | tg128 |   141.42 ± 0.42 |

No change on (non-coopmat) AMD and Intel, of course.

@0cc4m merged commit 330c3d2 into ggml-org:master on Aug 23, 2025
47 checks passed