Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions src/llama-kv-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <limits>
#include <map>
#include <stdexcept>
Expand Down Expand Up @@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache(

const uint32_t n_layer_kv = hparams.n_layer_kv();

// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
struct ggml_backend_buft_comparator {
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
}
};
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;

// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
Expand All @@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache(
return nullptr;
}

ctx_map[buft] = ctx;
ctxs.emplace_back(ctx);
ctx_map.emplace(buft, ctx);

return ctx;
}

return it->second;
return it->second.get();
};

GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
Expand Down Expand Up @@ -167,19 +174,16 @@ llama_kv_cache::llama_kv_cache(
}

// allocate tensors and initialize the buffers to avoid NaNs in the padding
for (auto it : ctx_map) {
auto * buft = it.first;
auto * ctx = it.second;

ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
for (auto & [buft, ctx] : ctx_map) {
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
if (!buf) {
throw std::runtime_error("failed to allocate buffer for kv cache");
}

LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);

ggml_backend_buffer_clear(buf, 0);
bufs.emplace_back(buf);
ctxs_bufs.emplace_back(std::move(ctx), buf);
}

{
Expand All @@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) {
}

if (data) {
for (auto & buf : bufs) {
for (auto & [_, buf] : ctxs_bufs) {
ggml_backend_buffer_clear(buf.get(), 0);
}
}
Expand Down Expand Up @@ -472,8 +476,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {

std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
for (const auto & [_, buf] : ctxs_bufs) {
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
}
return ret;
}
Expand Down Expand Up @@ -1298,7 +1302,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
size_t llama_kv_cache::total_size() const {
size_t size = 0;

for (const auto & buf : bufs) {
for (const auto & [_, buf] : ctxs_bufs) {
size += ggml_backend_buffer_get_size(buf.get());
}

Expand Down
4 changes: 2 additions & 2 deletions src/llama-kv-cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ class llama_kv_cache : public llama_memory_i {
// this is the SWA type of the cache - not to be confused with the model SWA type
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// ggml contexts for the KV cache along with the allocated backend buffers:
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
Expand Down
32 changes: 18 additions & 14 deletions src/llama-memory-recurrent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <algorithm>
#include <cassert>
#include <cstring>
#include <limits>
#include <map>
#include <stdexcept>
Expand All @@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent(
cells.clear();
cells.resize(mem_size);

// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
struct ggml_backend_buft_comparator {
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
}
};
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;

// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
Expand All @@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent(
return nullptr;
}

ctx_map[buft] = ctx;
ctxs.emplace_back(ctx);
ctx_map.emplace(buft, ctx);

return ctx;
}

return it->second;
return it->second.get();
};

r_l.resize(n_layer);
Expand Down Expand Up @@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent(
}

// allocate tensors and initialize the buffers to avoid NaNs in the padding
for (auto it : ctx_map) {
auto * buft = it.first;
auto * ctx = it.second;

ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
for (auto & [buft, ctx] : ctx_map) {
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
if (!buf) {
throw std::runtime_error("failed to allocate buffer for rs cache");
}
ggml_backend_buffer_clear(buf, 0);
LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
bufs.emplace_back(buf);
ctxs_bufs.emplace_back(std::move(ctx), buf);
}

{
Expand All @@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) {
used = 0;

if (data) {
for (auto & buf : bufs) {
for (auto & [_, buf] : ctxs_bufs) {
ggml_backend_buffer_clear(buf.get(), 0);
}
}
Expand Down Expand Up @@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {

std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
for (const auto & [_, buf] : ctxs_bufs) {
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
}
return ret;
}
Expand Down Expand Up @@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const {

size_t llama_memory_recurrent::total_size() const {
size_t size = 0;
for (const auto & buf : bufs) {
for (const auto & [_, buf] : ctxs_bufs) {
size += ggml_backend_buffer_get_size(buf.get());
}

Expand Down
4 changes: 2 additions & 2 deletions src/llama-memory-recurrent.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ class llama_memory_recurrent : public llama_memory_i {

const uint32_t n_seq_max = 1;

std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// ggml contexts for the KV cache along with the allocated backend buffers:
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

size_t total_size() const;

Expand Down
2 changes: 1 addition & 1 deletion src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2231,7 +2231,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
struct ggml_backend_buft_comparator {
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
}
};
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
Expand Down
Loading