Commit 98c7edd

llama: consistent ctx <-> buf order for KV cache
1 parent 945501f

5 files changed: +41, -33 lines

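The core of the change, at a glance: the per-buffer-type context maps are now keyed by the buffer type's name, and each context is stored next to the buffer allocated from it, so the ctx <-> buf pairing no longer depends on pointer values that can differ from run to run. Below is a minimal sketch of the name-based ordering; buft, name, and buft_comparator are illustrative stand-ins, not the ggml API.

#include <cstdio>
#include <cstring>
#include <map>

// stand-ins for ggml_backend_buffer_type_t and ggml_backend_buft_name():
struct buft { const char * name; };
using buft_t = const buft *;

// order map keys by the buffer type's name rather than by pointer value;
// pointer order can change between runs, name order cannot
struct buft_comparator {
    bool operator()(const buft_t & lhs, const buft_t & rhs) const {
        return strcmp(lhs->name, rhs->name) < 0;
    }
};

int main() {
    static const buft cpu  = {"CPU"};
    static const buft cuda = {"CUDA0"};

    std::map<buft_t, int, buft_comparator> ctx_map;
    ctx_map.emplace(&cuda, 1);
    ctx_map.emplace(&cpu,  2);

    // iteration always visits "CPU" before "CUDA0", regardless of
    // where the buft objects happen to live in memory
    for (const auto & [bt, v] : ctx_map) {
        printf("%s -> %d\n", bt->name, v);
    }
    return 0;
}

With a pointer-keyed map (the previous behavior), the iteration order, and hence which context ended up paired with which buffer, was at the mercy of the allocator.
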
src/llama-kv-cache.cpp

Lines changed: 18 additions & 14 deletions
@@ -8,6 +8,7 @@
 #include <algorithm>
 #include <cassert>
 #include <cmath>
+#include <cstring>
 #include <limits>
 #include <map>
 #include <stdexcept>
@@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache(

     const uint32_t n_layer_kv = hparams.n_layer_kv();

+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache(
                 return nullptr;
             }

-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);

             return ctx;
         }

-        return it->second;
+        return it->second.get();
     };

     GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
@@ -167,19 +174,16 @@ llama_kv_cache::llama_kv_cache(
     }

     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for kv cache");
         }

         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);

         ggml_backend_buffer_clear(buf, 0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }

     {
@@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) {
     }

     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -472,8 +476,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {

 std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -1298,7 +1302,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
 size_t llama_kv_cache::total_size() const {
     size_t size = 0;

-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }

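A note on the new allocation loop above: iterating ctx_map by mutable reference and moving each ggml_context_ptr into ctxs_bufs leaves null pointers behind in the map, which is harmless because the map goes out of scope right after the loop. A minimal sketch of that pattern, with stand-in types (an int playing the role of the buffer):

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

int main() {
    // stand-in for the buft -> ctx map, keys ordered by name (std::string)
    std::map<std::string, std::unique_ptr<int>> ctx_map;
    ctx_map.emplace("CPU",   std::make_unique<int>(1));
    ctx_map.emplace("CUDA0", std::make_unique<int>(2));

    std::vector<std::pair<std::unique_ptr<int>, int>> ctxs_bufs;

    // move each context out of the map next to its "buffer" (an int here);
    // the moved-from map entries become null, but the map is discarded anyway
    for (auto & [name, ctx] : ctx_map) {
        int buf = *ctx; // allocate-from-ctx stand-in
        ctxs_bufs.emplace_back(std::move(ctx), buf);
    }
    return 0;
}
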
src/llama-kv-cache.h

Lines changed: 2 additions & 2 deletions
@@ -217,8 +217,8 @@ class llama_kv_cache : public llama_memory_i {
     // this is the SWA type of the cache - not to be confused with the model SWA type
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // ggml contexts for the KV cache along with the allocated backend buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

     // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method

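The header-side change replaces two parallel vectors with one vector of pairs: each backend buffer is stored next to the ggml context whose tensors it backs, so an index mismatch can never pair a context with the wrong buffer. A small sketch of the layout with hypothetical ctx_t/buf_t stand-ins:

#include <memory>
#include <utility>
#include <vector>

// stand-ins: a "context" that owns tensor metadata and a "buffer"
// that backs those tensors with actual memory
struct ctx_t { /* tensor metadata */ };
struct buf_t { /* backing memory  */ };

using ctx_ptr = std::unique_ptr<ctx_t>;
using buf_ptr = std::unique_ptr<buf_t>;

int main() {
    // one vector of pairs instead of two parallel vectors: element i
    // always holds a matching context/buffer pair
    std::vector<std::pair<ctx_ptr, buf_ptr>> ctxs_bufs;
    ctxs_bufs.emplace_back(std::make_unique<ctx_t>(), std::make_unique<buf_t>());

    // consumers that only need the buffers skip the context with `_`,
    // mirroring clear()/total_size()/memory_breakdown()
    for (const auto & [_, buf] : ctxs_bufs) {
        (void)buf;
    }
    return 0;
}
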
src/llama-memory-recurrent.cpp

Lines changed: 18 additions & 14 deletions
@@ -7,6 +7,7 @@

 #include <algorithm>
 #include <cassert>
+#include <cstring>
 #include <limits>
 #include <map>
 #include <stdexcept>
@@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent(
     cells.clear();
     cells.resize(mem_size);

+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent(
                 return nullptr;
             }

-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);

             return ctx;
         }

-        return it->second;
+        return it->second.get();
     };

     r_l.resize(n_layer);
@@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent(
     }

     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for rs cache");
         }
         ggml_backend_buffer_clear(buf, 0);
         LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }

     {
@@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) {
     used = 0;

     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {

 std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const {

 size_t llama_memory_recurrent::total_size() const {
     size_t size = 0;
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }
src/llama-memory-recurrent.h

Lines changed: 2 additions & 2 deletions
@@ -109,8 +109,8 @@ class llama_memory_recurrent : public llama_memory_i {

     const uint32_t n_seq_max = 1;

-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // ggml contexts for the KV cache along with the allocated backend buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

     size_t total_size() const;
src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
@@ -2231,7 +2231,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
     struct ggml_backend_buft_comparator {
         bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
-            return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
         }
     };
     std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;

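The llama-model.cpp fix is subtle: ggml_backend_buft_name() returns a const char *, so the old comparator ordered keys by string address, which depends on allocation and can differ across runs, while strcmp(...) < 0 orders by the characters themselves. A standalone illustration of the difference, nothing ggml-specific:

#include <cstdio>
#include <cstring>
#include <functional>
#include <string>

int main() {
    // two distinct copies of the same name, living at different addresses
    std::string a = "CUDA0";
    std::string b = "CUDA0";
    const char * pa = a.c_str();
    const char * pb = b.c_str();

    // address-based order: depends on allocation, may differ run to run
    bool by_ptr = std::less<const char *>{}(pa, pb);
    // content-based order: equal names compare equal, every time
    bool by_str = strcmp(pa, pb) < 0;

    printf("by pointer: %d, by strcmp: %d\n", by_ptr, by_str);
    return 0;
}

std::less<const char *> is used here instead of a raw < only so the demonstration itself stays well-defined; the point stands either way: it compares addresses, not names.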