@@ -13040,46 +13040,94 @@ struct llm_build_bailingmoe : public llm_graph_context {
     }
 };
 
-llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
+llama_memory_i * llama_model::create_memory(
+        const llama_memory_params & params,
+        llama_cparams & cparams,
+        const llama_hparams & hparams) const {
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max));
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                if (llm_arch_is_hybrid(arch)) {
+                    // make vectors of recurrent and non-recurrent layer indices
+                    std::vector<size_t> recurrent_layers;
+                    std::vector<size_t> unified_layers;
+                    for (auto il = 0u; il < hparams.n_layer; ++il) {
+                        if (hparams.recurrent_layer(il)) {
+                            recurrent_layers.push_back(il);
+                        } else {
+                            unified_layers.push_back(il);
+                        }
+                    }
+
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    // initialize the children
+                    std::vector<llama_kv_cache_hybrid::child_cache> children;
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_recurrent(
+                                *this,
+                                GGML_TYPE_F32,
+                                GGML_TYPE_F32,
+                                cparams.offload_kqv,
+                                std::max((uint32_t) 1, cparams.n_seq_max))
+                        ),
+                        std::move(recurrent_layers)
+                    );
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_unified(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                padding)
+                        ),
+                        std::move(unified_layers)
+                    );
+
+                    // initialize the hybrid cache with both children
+                    res = new llama_kv_cache_hybrid(hparams, std::move(children));
+                } else if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
+                        *this,
+                        GGML_TYPE_F32,
+                        GGML_TYPE_F32,
+                        cparams.offload_kqv,
+                        std::max((uint32_t) 1, cparams.n_seq_max));
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
 
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
 
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
-                res = new llama_kv_cache_unified(
-                        *this,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
-                        cparams.offload_kqv,
-                        cparams.n_ctx,
-                        padding);
+                    res = new llama_kv_cache_unified(
+                        *this,
+                        params.type_k,
+                        params.type_v,
+                        !cparams.flash_attn,
+                        cparams.offload_kqv,
+                        cparams.n_ctx,
+                        padding);
+                }
             }
     }
 
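The heart of the new default path is the layer partition: each layer index is assigned to either the recurrent child or the unified (attention) child before the hybrid cache is constructed over both. Below is a minimal standalone sketch of that dispatch pattern; `kv_cache`, `recurrent_cache`, `unified_cache`, `child_cache`, and `is_recurrent_layer` are hypothetical stand-ins for the llama.cpp types and for `hparams.recurrent_layer(il)`, not the real API.

```cpp
#include <cstdint>
#include <cstdio>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the cache classes; only the dispatch shape matters here.
struct kv_cache {
    virtual ~kv_cache() = default;
    virtual const char * name() const = 0;
};

struct recurrent_cache : kv_cache {
    const char * name() const override { return "recurrent"; }
};

struct unified_cache : kv_cache {
    const char * name() const override { return "unified"; }
};

// Mirrors the idea of a hybrid child cache: one cache plus the layer indices it owns.
struct child_cache {
    std::unique_ptr<kv_cache> cache;
    std::vector<size_t>       layers;
};

// Hypothetical stand-in for hparams.recurrent_layer(il): mark every other layer recurrent.
static bool is_recurrent_layer(uint32_t il) {
    return il % 2 == 0;
}

int main() {
    const uint32_t n_layer = 6;

    // partition layer indices, as the hybrid branch does
    std::vector<size_t> recurrent_layers;
    std::vector<size_t> unified_layers;
    for (uint32_t il = 0; il < n_layer; ++il) {
        (is_recurrent_layer(il) ? recurrent_layers : unified_layers).push_back(il);
    }

    // build the two children, each bound to its own layer set
    std::vector<child_cache> children;
    children.push_back({std::make_unique<recurrent_cache>(), std::move(recurrent_layers)});
    children.push_back({std::make_unique<unified_cache>(),   std::move(unified_layers)});

    for (const auto & c : children) {
        printf("%s layers:", c.cache->name());
        for (size_t il : c.layers) {
            printf(" %zu", il);
        }
        printf("\n");
    }

    return 0;
}
```

A hybrid cache built this way can route per-layer reads and writes to the child that owns that layer, which is what lets recurrent-state layers and attention KV layers coexist in one model.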