@@ -1733,8 +1733,8 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         uint32_t n_seq_max,
         uint32_t n_batch,
         uint32_t n_pad) : hparams(model.hparams) {
-    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
-    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
@@ -3082,3 +3082,239 @@ int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
 float llama_kv_cache_recurrent_state::s_mask(int i) const {
     return kv->s_mask(i);
 }
+
+//
+// llama_kv_cache_hybrid_recurrent
+//
+
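+// the hybrid cache owns two child caches and routes each layer to exactly one
+// of them: layers where hparams.recurrent_layer(il) is true use the recurrent
+// cache, all remaining (attention) layers use the unified cache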
+llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
+        const llama_model & model,
+        /* attn */
+        ggml_type attn_type_k,
+        ggml_type attn_type_v,
+        bool attn_v_trans,
+        uint32_t attn_kv_size,
+        uint32_t attn_n_pad,
+        uint32_t attn_n_swa,
+        llama_swa_type attn_swa_type,
+        /* recurrent */
+        ggml_type recurrent_type_k,
+        ggml_type recurrent_type_v,
+        uint32_t recurrent_kv_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload) :
+    hparams(model.hparams),
+    kv_attn(new llama_kv_cache_unified(
+        model,
+        [&](int32_t il) { return !model.hparams.recurrent_layer(il); },
+        attn_type_k,
+        attn_type_v,
+        attn_v_trans,
+        offload,
+        attn_kv_size,
+        n_seq_max,
+        attn_n_pad,
+        attn_n_swa,
+        attn_swa_type
+    )),
+    kv_recurrent(new llama_kv_cache_recurrent(
+        model,
+        [&](int32_t il) { return model.hparams.recurrent_layer(il); },
+        recurrent_type_k,
+        recurrent_type_v,
+        offload,
+        recurrent_kv_size,
+        n_seq_max
+    )) {}
+
+void llama_kv_cache_hybrid_recurrent::clear() {
+    kv_attn     ->clear();
+    kv_recurrent->clear();
+}
+
+bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // Try removing from the recurrent cache first since it may fail. If it does
+    // fail, the cache will not have been mutated.
+    if (!kv_recurrent->seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+    return kv_attn->seq_rm(seq_id, p0, p1);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    kv_attn     ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) {
+    kv_attn     ->seq_keep(seq_id);
+    kv_recurrent->seq_keep(seq_id);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    kv_attn     ->seq_add(seq_id, p0, p1, shift);
+    kv_recurrent->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    kv_attn     ->seq_div(seq_id, p0, p1, d);
+    kv_recurrent->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const {
+    // the min of the total cache is the max of the two caches' min values
+    return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id));
+}
+
+llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const {
+    // the max of the total cache is the min of the two caches' max values
+    return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
+}
+
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+    // since this includes a recurrent cache, we cannot use split_simple
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+
+    // follow the recurrent pattern for creating the ubatch splits
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
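+            // otherwise use equal splits, per the recurrent pattern noted above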
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
+    }
+
+    // prepare the recurrent batches first
+    if (!kv_recurrent->prepare(ubatches)) {
+        // TODO: will the recurrent cache be in an undefined state at this point?
+        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    // prepare the attention cache
+    auto heads_attn = kv_attn->prepare(ubatches);
+    if (heads_attn.empty()) {
+        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
+        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+}
+
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this);
+}
+
+bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) {
+    bool res = false;
+
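+    // bitwise OR rather than || so that the second update() always runs even
+    // when the first one reports a change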
+    res = res | kv_attn     ->update(lctx);
+    res = res | kv_recurrent->update(lctx);
+
+    return res;
+}
+
+void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) {
+    kv_attn     ->defrag_sched(thold);
+    kv_recurrent->defrag_sched(thold);
+}
+
+bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
+    // TODO: Should this return true if the attention cache can shift?
+    return false;
+}
+
+void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    kv_attn     ->state_write(io, seq_id);
+    kv_recurrent->state_write(io, seq_id);
+}
+
+void llama_kv_cache_hybrid_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    kv_attn     ->state_read(io, seq_id);
+    kv_recurrent->state_read(io, seq_id);
+}
+
+llama_kv_cache_unified * llama_kv_cache_hybrid_recurrent::get_kv_attn() const {
+    return kv_attn.get();
+}
+
+llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() const {
+    return kv_recurrent.get();
+}
+
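+// failure-path constructor: propagates the error status (e.g.
+// LLAMA_MEMORY_STATUS_FAILED_PREPARE from init_batch above) to both child states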
+llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
+    : status(status), state_attn(status), state_recurrent(status) {}
+
+llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+      kv(kv),
+      state_attn(status, kv->get_kv_attn()),
+      state_recurrent(status, kv->get_kv_recurrent()) {}
+
+llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
+        llama_kv_cache_hybrid_recurrent * kv,
+        llama_sbatch sbatch,
+        std::vector<uint32_t> heads_attn,
+        std::vector<llama_ubatch> ubatches)
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+      kv(kv),
+      sbatch(std::move(sbatch)),
+      heads_attn(std::move(heads_attn)),
+      ubatches(std::move(ubatches)),
+      // NOTE: these child states are only used as wrapper APIs for the
+      // const methods, so we use the "init full" signature since the
+      // actual state is not used.
+      state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn()),
+      state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent()) {}
+
+bool llama_kv_cache_hybrid_recurrent_state::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache_hybrid_recurrent_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
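+    // the attention cache consumes the head locations computed by prepare();
+    // the recurrent cache finds its slot at apply time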
+    kv->get_kv_attn()     ->apply_ubatch(heads_attn[i_next], ubatches[i_next]);
+    kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
+
+    return true;
+}
+
+std::vector<int64_t> & llama_kv_cache_hybrid_recurrent_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_kv_cache_hybrid_recurrent_state::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    return ubatches[i_next];
+}
+
+const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const {
+    return &state_attn;
+}
+
+const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const {
+    return &state_recurrent;
+}