diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 91b1d6078a252..c90dff8b82675 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -166,6 +166,8 @@ bool llama_batch_allocr::init( // note: tracking the other way around is not necessary for now //seq_cpl[s0][s1] = true; + + has_cpl = true; } } } @@ -459,9 +461,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { return ubatch_add(idxs, idxs.size(), false); } -llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { +llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) { + if (sequential && has_cpl) { + LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__); + + return {}; + } + std::vector cur_seq_set; + llama_seq_id last_seq_id = -1; + // determine the non-overlapping sequence sets participating in this ubatch for (int32_t i = 0; i < batch.n_tokens; ++i) { if (used[i]) { @@ -478,9 +488,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { } } + // accept only consecutive sequence ids: after the first accepted set, the token's first seq id must equal last_seq_id + 1 + if (sequential) { + add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1); + } + if (add) { cur_seq_set.push_back(seq_set[i]); + last_seq_id = batch.seq_id[i][0]; + if (cur_seq_set.size() > n_ubatch) { break; } diff --git a/src/llama-batch.h b/src/llama-batch.h index d2c5376188a0b..459a00be70a04 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -69,7 +69,8 @@ class llama_batch_allocr { llama_ubatch split_simple(uint32_t n_ubatch); // make ubatches of equal-length sequences sets - llama_ubatch split_equal(uint32_t n_ubatch); + // if sequential == true, the sequence sets in the ubatch have consecutive, increasing sequence ids; an empty ubatch is returned (and an error logged) if the batch contains coupled sequences + llama_ubatch split_equal(uint32_t n_ubatch, bool sequential); // sequence-set-wise split - each ubatch contains a single sequence-set llama_ubatch split_seq(uint32_t n_ubatch); @@ -112,6 +113,9 @@ class llama_batch_allocr { using
pos_set_t = std::set; using seq_cpl_t = std::vector; + // helper flag to quickly determine if there are any coupled sequences in the batch; defaults to false so it is never read while indeterminate (NOTE(review): confirm it is also reset between batches) + bool has_cpl = false; + std::vector seq_pos; // seq_pos[s]: the set of positions in sequence s std::vector seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index ee202cc710bd6..def5d3708d9de 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -135,7 +135,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all std::vector ubatches; while (true) { - auto ubatch = balloc.split_equal(n_ubatch); + auto ubatch = balloc.split_equal(n_ubatch, false); if (ubatch.n_tokens == 0) { break; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 03d974d852039..345e88bc175f4 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch); + ubatch = balloc.split_equal(n_ubatch, false); } if (ubatch.n_tokens == 0) { diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 6ed84057ccfe2..ab393ca5ac2f7 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -374,7 +374,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch); + ubatch = balloc.split_equal(n_ubatch, false); } if (ubatch.n_tokens == 0) {