@@ -37,8 +37,15 @@ llama_kv_cache::llama_kv_cache(
 
     const uint32_t n_layer_kv = hparams.n_layer_kv();
 
+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -53,13 +60,12 @@ llama_kv_cache::llama_kv_cache(
                 return nullptr;
             }
 
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
             return ctx;
         }
 
-        return it->second;
+        return it->second.get();
     };
 
     GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
@@ -167,19 +173,16 @@ llama_kv_cache::llama_kv_cache(
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for kv cache");
         }
 
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
 
         ggml_backend_buffer_clear(buf, 0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }
 
     {
@@ -203,7 +206,7 @@ void llama_kv_cache::clear(bool data) {
     }
 
     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -472,8 +475,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -1298,7 +1301,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
 size_t llama_kv_cache::total_size() const {
     size_t size = 0;
 
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }
 
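
For context, the pattern this diff introduces can be sketched in isolation: the map is keyed by an opaque buffer-type handle, so it needs a comparator that orders entries by a stable property (the backend name) to keep iteration order, and therefore allocation order and log output, deterministic; the owning context pointers are then moved out of the map and stored next to the buffers they back, so both are released together. The sketch below is a minimal illustration using hypothetical stand-in types (`buft_t`, `context`, `buffer`, `context_ptr`) rather than the real ggml API.

```cpp
// Minimal sketch of the ctx_map / ctxs_bufs ownership pattern, with
// hypothetical stand-in types instead of the real ggml API.
#include <cstdio>
#include <cstring>
#include <map>
#include <memory>
#include <utility>
#include <vector>

struct buft_t  { const char * name; };          // stands in for ggml_backend_buffer_type_t
struct context { std::vector<int> tensors; };   // stands in for ggml_context
struct buffer  { size_t size; };                // stands in for ggml_backend_buffer_t

using context_ptr = std::unique_ptr<context>;   // stands in for ggml_context_ptr

// order map entries by backend name so that iteration order is well-defined
struct buft_comparator {
    bool operator()(const buft_t * lhs, const buft_t * rhs) const {
        return strcmp(lhs->name, rhs->name) < 0;
    }
};

int main() {
    static buft_t cpu{"CPU"}, gpu{"CUDA0"};

    std::map<const buft_t *, context_ptr, buft_comparator> ctx_map;

    // one context per buffer type, created on first use (mirrors ctx_for_buft)
    auto ctx_for_buft = [&](const buft_t * buft) -> context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            it = ctx_map.emplace(buft, std::make_unique<context>()).first;
        }
        return it->second.get();
    };

    ctx_for_buft(&gpu)->tensors.push_back(1);
    ctx_for_buft(&cpu)->tensors.push_back(2);

    // allocate one buffer per context and keep the context alive next to its
    // buffer (mirrors ctxs_bufs), so both are destroyed together
    std::vector<std::pair<context_ptr, buffer>> ctxs_bufs;
    for (auto & [buft, ctx] : ctx_map) {
        buffer buf{ctx->tensors.size() * 1024};
        printf("%s buffer size = %zu bytes\n", buft->name, buf.size);
        ctxs_bufs.emplace_back(std::move(ctx), buf);
    }
}
```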