llama : add high-throughput mode #14363

Open
wants to merge 9 commits into base: master
8 changes: 8 additions & 0 deletions common/arg.cpp
@@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.swa_full = true;
}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--attn-streams", "-as"},
string_format("use multiple streams when computing the attention (default: %s)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.attn_streams ? "true" : "false"),
[](common_params & params) {
params.attn_streams = true;
}
).set_env("LLAMA_ARG_ATTN_STREAMS"));
add_opt(common_arg(
{"--no-context-shift"},
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
1 change: 1 addition & 0 deletions common/common.cpp
@@ -1157,6 +1157,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;
cparams.attn_streams = params.attn_streams;

cparams.type_k = params.cache_type_k;
cparams.type_v = params.cache_type_v;
1 change: 1 addition & 0 deletions common/common.h
@@ -330,6 +330,7 @@ struct common_params {
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on infinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
bool attn_streams = false; // multi-stream attention and KV cache buffers

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool use_mmap = true; // use mmap for faster loads
3 changes: 2 additions & 1 deletion examples/parallel/parallel.cpp
@@ -224,6 +224,7 @@ int main(int argc, char ** argv) {
auto & client = clients[i];
client.id = i;
client.smpl = common_sampler_init(model, params.sampling);
//params.sampling.seed++;
}

std::vector<llama_token> tokens_system;
@@ -345,7 +346,7 @@ int main(int argc, char ** argv) {
client.n_decoded = 0;
client.i_batch = batch.n_tokens - 1;

LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur);
LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt);

g_seq_id += 1;
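A hypothetical invocation of this example with the new flag (the other flag names are assumed from the existing parallel example; verify against --help):

```cpp
// Hypothetical benchmark run (flag names assumed, not confirmed by this diff):
//
//   llama-parallel -m model.gguf -np 32 -ns 128 --attn-streams
//
// -np: number of concurrent clients, -ns: total number of sequences to process.
```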

4 changes: 4 additions & 0 deletions include/llama.h
@@ -374,6 +374,10 @@ extern "C" {
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573

bool attn_streams; // if enabled, use multiple streams during the attention (determined by n_seq_max)
// NOTE: this requires support for the ggml_set_rows() operator
// this flag can improve the performance for parallel, multi-sequence use cases
};

// model quantization parameters
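For illustration, a minimal sketch of how a client could opt in through the public API (assumes the usual llama.h entry points; error handling omitted):

```cpp
#include "llama.h"

// Minimal sketch: create a context with multi-stream attention enabled.
static llama_context * make_streamed_ctx(llama_model * model, uint32_t n_seq_max) {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_seq_max    = n_seq_max; // one attention/KV stream per sequence
    cparams.attn_streams = true;      // requires ggml_set_rows() support in the backend
    return llama_init_from_model(model, cparams);
}
```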
17 changes: 13 additions & 4 deletions src/llama-context.cpp
@@ -101,7 +101,8 @@ llama_context::llama_context(

cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

cparams.op_offload = params.op_offload;
cparams.op_offload = params.op_offload;
cparams.attn_streams = params.attn_streams;

const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

@@ -112,6 +113,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: attn_streams = %s\n", __func__, cparams.attn_streams ? "true" : "false");
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);

@@ -267,7 +269,7 @@ llama_context::llama_context(

// reserve worst-case graph
if (!hparams.vocab_only && memory) {
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_seqs = cparams.attn_streams ? cparams.n_seq_max : 1;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -300,7 +302,7 @@ llama_context::llama_context(

// reserve with tg graph to get the number of splits and nodes
{
auto * gf = graph_reserve(1, 1, 1, mctx.get());
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers");
}
@@ -311,6 +313,10 @@ llama_context::llama_context(

// reserve again with pp graph to avoid ggml-alloc reallocations during inference
{
// TODO: not sure if the following graph would be the worst case for multi-stream KV caches:
//
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
//
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
@@ -475,7 +481,7 @@ bool llama_context::kv_self_update(bool optimize) {
throw std::runtime_error("failed to initialize memory context");
}

const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_seqs = cparams.attn_streams ? cparams.n_seq_max : 1;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -741,6 +747,8 @@ int llama_context::encode(const llama_batch & batch_inp) {

const uint32_t n_tokens = balloc->get_n_tokens();

// [TAG_NO_CACHE_PAD]
// TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
const llama_ubatch ubatch = balloc->split_simple(n_tokens);

// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
@@ -2187,6 +2195,7 @@ llama_context_params llama_context_default_params() {
/*.no_perf =*/ true,
/*.op_offload =*/ true,
/*.swa_full =*/ true,
/*.attn_streams =*/ false,
};

return result;
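To make the reservation changes concrete, a hedged numeric walkthrough (example values only, not part of the diff):

```cpp
// Illustration of the graph reservation sizes above (numbers are examples only).
// Suppose n_seq_max = 8, n_ctx = 8192, n_ubatch = 512 and attn_streams = true:
//
//   n_seqs   = attn_streams ? 8 : 1
//   n_tokens = min(8192, 512) = 512
//
//   tg graph: graph_reserve(  8, 8,   8, ...)   // previously reserved as (1, 1, 1)
//   pp graph: graph_reserve(512, 8, 512, ...)   // full ubatch spread over the streams
//
// With attn_streams = false the previous single-stream sizing is kept (n_seqs = 1).
```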
5 changes: 3 additions & 2 deletions src/llama-cparams.h
@@ -11,8 +11,8 @@ struct llama_cparams {
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
int n_threads; // number of threads to use for generation
int n_threads_batch; // number of threads to use for batch processing
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing

float rope_freq_base;
float rope_freq_scale;
@@ -33,6 +33,7 @@ struct llama_cparams {
bool no_perf;
bool warmup;
bool op_offload;
bool attn_streams;

enum llama_pooling_type pooling_type;

34 changes: 23 additions & 11 deletions src/llama-graph.cpp
@@ -1000,12 +1000,13 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");

const auto n_kv = inp->mctx->get_attn()->get_n_kv();
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
const auto n_stream = cparams.attn_streams ? ubatch.n_seqs_unq : 1;

inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);

inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask);

inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1032,13 +1033,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
float kq_scale) const {
const bool v_trans = v->nb[1] > v->nb[2];

// split the batch into streams if needed
const auto n_stream = k->ne[3];

q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);

q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);

const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
const auto n_kv = k->ne[1];
const auto n_kv = k->ne[1];

ggml_tensor * cur;

@@ -1080,7 +1084,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
#endif
}

cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
} else {
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

@@ -1125,7 +1129,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(

cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
// recombine streams
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);

if (!cparams.offload_kqv) {
// all nodes between the KV store and the attention output are run on the CPU
@@ -1172,6 +1177,10 @@ ggml_tensor * llm_graph_context::build_attn(

const auto & kq_mask = inp->get_kq_mask();

// [TAG_NO_CACHE_PAD]
// TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
assert(ubatch.equal_seqs == false);

ggml_tensor * q = q_cur;
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;
@@ -1202,12 +1211,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

const auto n_kv = mctx_cur->get_n_kv();
const auto n_kv = mctx_cur->get_n_kv();
const auto n_stream = cparams.attn_streams ? ubatch.n_seqs_unq : 1;

inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask);

inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1449,13 +1459,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);

const auto n_stream = cparams.attn_streams ? ubatch.n_seqs_unq : 1;

{
const auto n_kv = mctx_cur->get_base()->get_n_kv();

inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask);

inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1469,7 +1481,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask_swa);

inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
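A hedged shape walkthrough of the stream split in build_attn_mha() (illustrative sizes; assumes the usual [n_embd_head, n_head, n_tokens] layout of q and an equal number of tokens per sequence in the ubatch):

```cpp
// Illustrative shapes only; n_stream is taken from k->ne[3] as in the code above.
// Assume n_stream = 4, n_tokens = 8 (2 per sequence), n_head = 32, n_embd_head = 128:
//
//   q in:        [128, 32, 8]       // [n_embd_head, n_head, n_tokens]
//   reshape_4d:  [128, 32, 2, 4]    // [.., .., n_tokens/n_stream, n_stream]
//   kq mask:     [n_kv, GGML_PAD(2, GGML_KQ_MASK_PAD), 1, 4]   // one mask slice per stream
//
// Each stream attends only to its own KV cells, and the result is flattened back
// to 2D (ggml_reshape_2d / ggml_cont_2d) so the rest of the graph is unchanged.
```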
22 changes: 11 additions & 11 deletions src/llama-graph.h
@@ -255,10 +255,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]

ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

const llama_hparams & hparams;
const llama_cparams & cparams;
@@ -289,14 +289,14 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]

ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

const llama_hparams & hparams;
const llama_cparams & cparams;
@@ -343,8 +343,8 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]

ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

const llama_hparams & hparams;
const llama_cparams & cparams;
40 changes: 40 additions & 0 deletions src/llama-hparams.cpp
@@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
return n_embd_head_v * n_head_kv;
}

bool llama_hparams::is_n_embd_k_gqa_variable() const {
const uint32_t val = n_embd_k_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
if (val != n_embd_k_gqa(il)) {
return true;
}
}

return false;
}

bool llama_hparams::is_n_embd_v_gqa_variable() const {
const uint32_t val = n_embd_v_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
if (val != n_embd_v_gqa(il)) {
return true;
}
}

return false;
}

uint32_t llama_hparams::n_embd_k_gqa_max() const {
uint32_t val = n_embd_k_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
val = std::max(val, n_embd_k_gqa(il));
}

return val;
}

uint32_t llama_hparams::n_embd_v_gqa_max() const {
uint32_t val = n_embd_v_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
val = std::max(val, n_embd_v_gqa(il));
}

return val;
}

uint32_t llama_hparams::n_embd_r() const {
if (wkv_head_size != 0) {
// for RWKV models
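A brief hedged illustration of the new helpers' semantics (not from this hunk):

```cpp
// Hedged illustration: for a model whose layers all share the same K/V widths,
// the new helpers reduce to the layer-0 values; they only differ when some
// layer uses a different number of KV heads or head sizes.
static void check_gqa_helpers(const llama_hparams & hparams) {
    const uint32_t k_max = hparams.n_embd_k_gqa_max();
    const uint32_t v_max = hparams.n_embd_v_gqa_max();

    if (!hparams.is_n_embd_k_gqa_variable() && !hparams.is_n_embd_v_gqa_variable()) {
        GGML_ASSERT(k_max == hparams.n_embd_k_gqa());
        GGML_ASSERT(v_max == hparams.n_embd_v_gqa());
    }
}
```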
8 changes: 8 additions & 0 deletions src/llama-hparams.h
@@ -189,6 +189,14 @@ struct llama_hparams {
// dimension of value embeddings across all k-v heads
uint32_t n_embd_v_gqa(uint32_t il = 0) const;

// true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
bool is_n_embd_k_gqa_variable() const;
bool is_n_embd_v_gqa_variable() const;

// return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
uint32_t n_embd_k_gqa_max() const;
uint32_t n_embd_v_gqa_max() const;

// dimension of the rolling state embeddings
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
uint32_t n_embd_r() const;