@@ -443,6 +443,11 @@ void llama_kv_cache_unified::set_full() {
443443 n = size;
444444}
445445
446+ bool llama_kv_cache_unified::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
447+ // Unified attention cache can always do a sequence removal
448+ return true ;
449+ }
450+
446451llama_sbatch llama_kv_cache_unified::sbatch_init (
447452 const llama_batch & batch,
448453 bool logits_all) {
@@ -1481,39 +1486,33 @@ void llama_kv_cache_recurrent::clear() {
14811486}
14821487
14831488bool llama_kv_cache_recurrent::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1484- uint32_t new_head = size;
1489+ if (!can_seq_rm (seq_id, p0, p1)) {
1490+ // could be fatal
1491+ return false ;
1492+ }
14851493
1494+ uint32_t new_head = size;
14861495 if (p0 < 0 ) {
14871496 p0 = 0 ;
14881497 }
1489-
14901498 if (p1 < 0 ) {
14911499 p1 = std::numeric_limits<llama_pos>::max ();
14921500 }
14931501
1494- // models like Mamba or RWKV can't have a state partially erased
1495- if (seq_id >= (int64_t ) size) {
1496- // could be fatal
1497- return false ;
1498- }
14991502 if (0 <= seq_id) {
15001503 int32_t & tail_id = cells[seq_id].tail ;
15011504 if (tail_id >= 0 ) {
15021505 const kv_cell & cell = cells[tail_id];
1503- // partial intersection is invalid
1504- if ((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )) {
1505- return false ;
1506- }
1506+ // already validated in can_seq_rm
1507+ GGML_ASSERT (!((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )));
15071508 // invalidate tails which will be cleared
15081509 if (p0 <= cell.pos && cell.pos < p1) {
15091510 tail_id = -1 ;
15101511 }
15111512 }
15121513 } else {
1513- // seq_id is negative, then the range should include everything or nothing
1514- if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())) {
1515- return false ;
1516- }
1514+ // already validated in can_seq_rm
1515+ GGML_ASSERT (!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())));
15171516 }
15181517
15191518 for (uint32_t i = 0 ; i < size; ++i) {
@@ -1714,6 +1713,35 @@ void llama_kv_cache_recurrent::set_full() {
17141713 n = size;
17151714}
17161715
1716+ bool llama_kv_cache_recurrent::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
1717+ if (p0 < 0 ) {
1718+ p0 = 0 ;
1719+ }
1720+
1721+ if (p1 < 0 ) {
1722+ p1 = std::numeric_limits<llama_pos>::max ();
1723+ }
1724+ // models like Mamba or RWKV can't have a state partially erased
1725+ if (seq_id >= (int64_t ) size) {
1726+ // could be fatal
1727+ return false ;
1728+ }
1729+ if (0 <= seq_id) {
1730+ const int32_t & tail_id = cells[seq_id].tail ;
1731+ if (tail_id >= 0 ) {
1732+ const kv_cell & cell = cells[tail_id];
1733+ // partial intersection is invalid
1734+ if ((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )) {
1735+ return false ;
1736+ }
1737+ }
1738+ // seq_id is negative, then the range should include everything or nothing
1739+ } else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())) {
1740+ return false ;
1741+ }
1742+ return true ;
1743+ }
1744+
17171745llama_sbatch llama_kv_cache_recurrent::sbatch_init (
17181746 const llama_batch & batch,
17191747 bool logits_all) {
@@ -2456,13 +2484,18 @@ void llama_kv_cache_hybrid::clear() {
24562484}
24572485
24582486bool llama_kv_cache_hybrid::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
2459- // TODO: Will it cause problems if some caches are able to remove the seq
2460- // but others aren't?
2461- bool removed = true ;
2487+ // First check if we can do this removal. This checks all children so that
2488+ // no mutation happens before we know if it's possible
2489+ if (!can_seq_rm (seq_id, p0, p1)) {
2490+ return false ;
2491+ }
2492+
2493+ // Do the removal from each child which should never fail
24622494 for (const auto & cache : m_children) {
2463- removed = cache->seq_rm (seq_id, p0, p1) && removed;
2495+ const bool failed = cache->seq_rm (seq_id, p0, p1);
2496+ GGML_ASSERT (!failed);
24642497 }
2465- return removed ;
2498+ return true ;
24662499}
24672500
24682501void llama_kv_cache_hybrid::seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
@@ -2529,6 +2562,15 @@ void llama_kv_cache_hybrid::set_full() {
25292562 }
25302563}
25312564
2565+ bool llama_kv_cache_hybrid::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
2566+ for (const auto & cache : m_children) {
2567+ if (!cache->can_seq_rm (seq_id, p0, p1)) {
2568+ return false ;
2569+ }
2570+ }
2571+ return true ;
2572+ }
2573+
25322574llama_sbatch llama_kv_cache_hybrid::sbatch_init (const llama_batch & batch, bool logits_all) {
25332575 // If any of the caches are recurrent, require equal split
25342576 return llama_sbatch (batch, m_hparams.n_embd , !m_has_recurrent, logits_all);
@@ -2566,7 +2608,7 @@ int32_t llama_kv_cache_hybrid::get_n_tokens() const {
25662608
25672609int32_t llama_kv_cache_hybrid::get_used_cells () const {
25682610 // TODO: Is this correct?
2569- // Return the largetst number of used cells
2611+ // Return the largest number of used cells
25702612 int32_t used_cells = -1 ;
25712613 for (const auto & cache : m_children) {
25722614 used_cells = std::max (used_cells, cache->get_used_cells ());
0 commit comments