Skip to content

llama : add high-throughput mode #14363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft

Conversation

ggerganov
Copy link
Member

@ggerganov ggerganov commented Jun 24, 2025

target #14285

Overview

Improve multi-sequence decoding performance by avoiding the cross-sequence attention compute of the unified KV cache.

Still WIP, but initial results are promising. The functionality is temporarily gated with env LLAMA_HT and also requires the LLAMA_SET_ROWS from #14285.

Detailed description will be added when I make some more progress and am more convinced that the approach is viable.

Testing

# master
make -j && ./bin/llama-batched-bench -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -c 100000 -b 2048 -ub 512 -npp 0,0,512,1024,2048 -ntg 32 -npl 32 -fa

main: n_kv_max = 100096, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16

|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.387 |   738.54 |    1.387 |   738.51 |
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.396 |   733.74 |    1.396 |   733.69 |
|   512 |     32 |   32 |  17408 |    6.054 |  2706.51 |    2.021 |   506.62 |    8.075 |  2155.85 |
|  1024 |     32 |   32 |  33792 |   13.002 |  2520.20 |    2.668 |   383.80 |   15.670 |  2156.45 |
|  2048 |     32 |   32 |  66560 |   29.922 |  2190.20 |    3.957 |   258.81 |   33.879 |  1964.64 |
# PR
make -j && LLAMA_HT=1 LLAMA_SET_ROWS=1 ./bin/llama-batched-bench -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -c 10000 -b 2048 -ub 512 -npp 0,0,512,1024,2048 -ntg 32 -npl 32 -fa

main: n_kv_max = 10240, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16

|    PP |     TG |    B |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |      T s |    S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.322 |   774.85 |    1.322 |   774.71 |
|     0 |     32 |   32 |   1024 |    0.000 |     0.00 |    1.315 |   778.66 |    1.315 |   778.49 |
|   512 |     32 |   32 |  17408 |    5.630 |  2910.32 |    1.474 |   694.51 |    7.104 |  2450.44 |
|  1024 |     32 |   32 |  33792 |   11.577 |  2830.46 |    1.549 |   661.13 |   13.126 |  2574.48 |
|  2048 |     32 |   32 |  66560 |   24.376 |  2688.52 |    1.729 |   592.27 |   26.105 |  2549.69 |

Using a more real-world example with llama-parallel:

# master
make -j && ./bin/llama-parallel -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -np 8 -ns 128 -s 1 -c 16384 -fa

# PR
make -j && LLAMA_HT=1 LLAMA_SET_ROWS=1 ./bin/llama-parallel -m ../models/qwen2.5-3b-coder/ggml-model-q8_0.gguf -np 32 -ns 128 -s 1 -c 4096 -fa

TODO

  • FA path
  • Non-FA path
  • Metal FA
  • Metal non-FA
  • CPU FA
  • CPU non-FA
  • ggml_soft_max_ext() support for virtual sequences
  • llama_memory_seq_cp support for virtual sequences
  • iSWA
  • split_equal support sequential ids
  • CUDA
  • Vulkan
  • etc.
  • more consistent sequence/virtual sequence naming
  • better term than "virtual sequence"?
  • env LLAMA_HT become regular compute parameter
  • Fix n_ctx meaning (total vs per-sequence)
  • Check input batch for no coupled sequences when HT is on
  • Require n_embd_v_gqa(il) == const when FA is off (no longer needed)

@github-actions github-actions bot added examples ggml changes relating to the ggml tensor library for machine learning Apple Metal https://en.wikipedia.org/wiki/Metal_(API) labels Jun 24, 2025
@JohannesGaessler
Copy link
Collaborator

Right now I am comparatively less busy with my PhD so it would be a good time for me to write CUDA code that is still missing, if there is any.

@ggerganov
Copy link
Member Author

ggerganov commented Jun 24, 2025

For now, these are the necessary CUDA changes:

  • Add ggml_set_rows() support (need PR towards ggml : add ggml_set_rows #14274, can already start implementing this)
  • Extend ggml_flash_attn_ext() to support n_seq dim if it does not yet:
// old
    // q:    [n_embd_k, n_batch,     n_head,    1]
    // k:    [n_embd_k, n_kv,        n_head_kv, 1]
    // v:    [n_embd_v, n_kv,        n_head_kv, 1] !! not transposed !!
    // mask: [n_kv,     n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
    // res:  [n_embd_v, n_head,      n_batch,   1] !! permuted !!
    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
            ...);

// new - supports `n_seq` dimension:
    // q:    [n_embd_k, n_batch,     n_head,    n_seq]
    // k:    [n_embd_k, n_kv,        n_head_kv, n_seq]
    // v:    [n_embd_v, n_kv,        n_head_kv, n_seq] !! not transposed !!
    // mask: [n_kv,     n_batch_pad, n_seq,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
    // res:  [n_embd_v, n_head,      n_batch,   n_seq] !! permuted !!
    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
            ...);

CPU might also need to be extended (not sure yet)

  • Extend ggml_soft_max_ext to support n_seq dim if it does not yet in a similar way. Also not sure about the CPU state.

Edit: the CPU versions of ggml_soft_max_ext() and ggml_flash_attn_ext() are now correct and can be used as a reference.

@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from ab2a2bb to 1b74b9d Compare June 24, 2025 17:24
@ggerganov ggerganov force-pushed the gg/kv-cache-use-set-rows branch 3 times, most recently from c246784 to 06bb08a Compare June 27, 2025 14:35
@ggerganov ggerganov force-pushed the gg/kv-cache-use-set-rows branch 3 times, most recently from 82277da to 4534123 Compare June 30, 2025 14:08
@ggerganov ggerganov mentioned this pull request Jul 1, 2025
5 tasks
@ggerganov ggerganov force-pushed the gg/kv-cache-use-set-rows branch from 2f577c5 to 30b4d4e Compare July 2, 2025 12:49
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from 6179578 to dfceb01 Compare July 2, 2025 18:20
Base automatically changed from gg/kv-cache-use-set-rows to master July 3, 2025 07:53
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch 2 times, most recently from eb5856c to ee0f729 Compare July 3, 2025 08:12
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from ee0f729 to deae7cd Compare July 3, 2025 08:53
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch 2 times, most recently from 988d0cd to dbcfcaa Compare July 3, 2025 12:11
v_cells[s].resize(kv_size);
}

// by default, all sequence ids are mapped to the 0th virtual sequence
Copy link
Collaborator

@compilade compilade Jul 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to understand the purpose of virtual sequences.

  • Is it to make the unified cache not unified?
    • Should it be a separate cache type instead?
  • why is n_seq_virt a number and not a bool of whether or not the cache is unified?
    • Is it to eventually allow n_seq_max % n_seq_virt == 0 for a partially-unified cache?
  • Are virtual sequences intended to be used with other types of caches eventually (e.g. recurrent)?
    • The concept here seems specific to the self-attention KV cache (unless I'm misunderstanding).

Copy link
Member Author

@ggerganov ggerganov Jul 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Today I found a better term instead of "virtual sequences": "streams". So I'll use "streams" here and will update the code later today or tomorrow.

Is it to make the unified cache not unified?

Roughly yes. The user will be able to select between unified (i.e. single stream) or non-unified (multiple streams). Each mode has advantages in different scenarios. Single stream is good when the sequences share large common prefixes. Multiple streams are good when the sequences are mostly or completely independent from each other.

The first iteration will support 1 stream (i.e. same as master, vanilla unified KV cache) and n_seq_max streams. The latter means that each sequence id is assigned to a separate stream.

In theory, we could assign multiple sequence ids to the same stream to get a partially-unified KV cache, but this would need extra work and it might not have any useful applications. So out of scope for now.

Should it be a separate cache type instead?

There is too much similar logic. Still thinking about it, but most likely it will end up in the same cache type.

The concept here seems specific to the self-attention KV cache (unless I'm misunderstanding)

Yes.

Comment on lines 73 to 74
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are sequential seq_ids required when virtual sequences are used?

Is it because a contiguous (along the virtual sequence dimension) slice of the KV cache is used?

I wonder if there could be a way to avoid this requirement with ggml_get_rows and/or ggml_mul_mat_id. Might not be worth the extra indirection, though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are sequential seq_ids required when virtual sequences are used?

Is it because a contiguous (along the virtual sequence dimension) slice of the KV cache is used?

Yes, we make a view of the KV cache across the streams here:

ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
auto * k = layers[ikv].k;
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
const uint64_t kv_size = get_size();
return ggml_view_4d(ctx, k,
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(k->type, hparams.n_embd_head_k),
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size),
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*kv_size)*sinfo.s0);
}

The ns var is the number of streams that participate in the current ubatch. Their stream indices range from [s0, s1].

I wonder if there could be a way to avoid this requirement with ggml_get_rows and/or ggml_mul_mat_id. Might not be worth the extra indirection, though.

It should be possible. But I'm not sure if it would be worth - both in performance and in complexity. We can explore though.

@@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
/*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
/*.mem_size =*/ size_t(2u*(1 + n_seq_virt)*n_layer_cache*ggml_tensor_overhead()),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the 1 + intended? Why was it added?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the per-stream views of the KV cache:

std::vector<ggml_tensor *> k_seq;
std::vector<ggml_tensor *> v_seq;
for (uint32_t s = 0; s < n_seq_virt; ++s) {
k_seq.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
v_seq.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
}

These are used to implement the llama_memory_seq_cp(). This operation is no longer just assigning ids - it performs actual copy of the buffers in memory when we use multiple streams. Using these helper views, the operation is quite simple to implement:

bool is_full = true;
if (p0 > 0 && p0 + 1 < (int) get_size()) {
is_full = false;
}
if (p1 > 0 && p1 + 1 < (int) get_size()) {
is_full = false;
}
GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
//LLAMA_LOG_WARN("%s: copying KV buffer from %d (virt = %d) to %d (virt = %d)\n", __func__, seq_id_src, s0, seq_id_dst, s1);
for (uint32_t il = 0; il < layers.size(); ++il) {
const auto & layer = layers[il];
ggml_backend_tensor_copy(layer.k_seq[s0], layer.k_seq[s1]);
ggml_backend_tensor_copy(layer.v_seq[s0], layer.v_seq[s1]);
// TODO: do we need synchronization here?
}
// TODO: support this:
GGML_ASSERT(v_cells[s0].get_has_shift() == false && "cannot copy a KV buffer that has a pending shift");
v_cells[s1].reset();
for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
if (v_cells[s0].seq_has(i, seq_id_src)) {
v_cells[s1].pos_set(i, v_cells[s0].pos_get(i));
v_cells[s1].seq_add(i, seq_id_dst);
}
}
v_heads[s1] = v_heads[s0];
//for (uint32_t s = 0; s < n_seq_virt; ++s) {
// LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
//}
}

Though we cannot copy partial sequences when using multiple streams.

Comment on lines 498 to 501
// accept only increasing sequence ids
if (sequential) {
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about decreasing sequence ids? Is the requirement that they are increasing, or that the included seq_ids should be in a contiguous range?

(decreasing sequence ids might not really happen often in practice though)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Decreasing would also work - we just need continuous range. We can either add this, if there is an elegant way to search for this. Or we add some batch pre-processing step to move the complexity at a higher level. Or just delegate it to the user by warning when the batch is not arranged optimally.

@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from dbcfcaa to 33dcc3c Compare July 4, 2025 07:04
@ggerganov ggerganov force-pushed the gg/llama-high-throughput branch from 33dcc3c to 5363817 Compare July 4, 2025 07:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Apple Metal https://en.wikipedia.org/wiki/Metal_(API) examples ggml changes relating to the ggml tensor library for machine learning
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants