@@ -13190,6 +13190,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13198,58 +13200,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_batch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                        /* model              */ *this,
+                        /* attn_type_k        */ params.type_k,
+                        /* attn_type_v        */ params.type_v,
+                        /* attn_v_trans       */ !cparams.flash_attn,
+                        /* attn_kv_size       */ cparams.n_ctx,
+                        /* attn_n_pad         */ llama_kv_cache_unified::get_padding(cparams),
+                        /* attn_n_swa         */ hparams.n_swa,
+                        /* attn_swa_type      */ hparams.swa_type,
+                        /* recurrent_type_k   */ GGML_TYPE_F32,
+                        /* recurrent_type_v   */ GGML_TYPE_F32,
+                        /* recurrent_kv_size  */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max          */ cparams.n_seq_max,
+                        /* offload            */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_batch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
+                }
             }
     }
 }
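The change replaces the explicit recurrent `case` labels with architecture predicates, so the memory dispatch no longer has to enumerate every recurrent model in this switch. As a rough orientation, a minimal sketch of what such predicates can look like is shown below, assuming they simply classify the same `LLM_ARCH_*` enum values that were removed from the switch; the hybrid-recurrent set is a placeholder assumption here and is not taken from this hunk.

```cpp
// Minimal sketch, not the upstream implementation: helpers that classify an
// architecture as fully recurrent (Mamba/RWKV-style state caches) or as a
// hybrid of recurrent and attention layers.
enum llm_arch {
    LLM_ARCH_MAMBA,
    LLM_ARCH_RWKV6,
    LLM_ARCH_RWKV6QWEN2,
    LLM_ARCH_RWKV7,
    LLM_ARCH_ARWKV7,
    // ... attention-only and hybrid architectures elided ...
};

static bool llm_arch_is_recurrent(llm_arch arch) {
    switch (arch) {
        // The same architectures that previously selected
        // llama_kv_cache_recurrent directly in create_memory().
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_RWKV7:
        case LLM_ARCH_ARWKV7:
            return true;
        default:
            return false;
    }
}

static bool llm_arch_is_hybrid_recurrent(llm_arch arch) {
    // Placeholder assumption: hybrid architectures (recurrent layers mixed
    // with attention layers) would be enumerated here; none appear in this
    // hunk, so the set is left empty in the sketch.
    (void) arch;
    return false;
}
```

With predicates like these, adding a new recurrent or hybrid model only requires classifying its architecture once, while `create_memory()` keeps a single `default:` path that picks between the recurrent, hybrid, and unified KV-cache implementations.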