Skip to content

Commit 89aad5a

Browse files
committed
llama : expose API to retrieve devices used by model.
It's useful from the library to be able to do things like list the features being used by the loaded model.
1 parent 745aa53 commit 89aad5a

File tree

4 files changed

+37
-0
lines changed

4 files changed

+37
-0
lines changed

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,8 @@ extern "C" {
499499
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
500500
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
501501

502+
LLAMA_API size_t llama_n_backends(const struct llama_context * ctx);
503+
LLAMA_API size_t llama_get_backends(const struct llama_context * ctx, ggml_backend_t * out_buf, size_t out_len);
502504
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
503505

504506
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
@@ -510,6 +512,7 @@ extern "C" {
510512
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
511513
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
512514
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
515+
LLAMA_API const ggml_backend_dev_t * llama_model_get_devices (const struct llama_model * model, size_t * out_len);
513516

514517
// Get the model's RoPE frequency scaling factor
515518
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

src/llama-context.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "llama-mmap.h"
77
#include "llama-model.h"
88

9+
#include <algorithm>
910
#include <cinttypes>
1011
#include <cstring>
1112
#include <limits>
@@ -418,6 +419,23 @@ uint32_t llama_context::n_threads_batch() const {
418419
return cparams.n_threads_batch;
419420
}
420421

422+
size_t llama_context::n_backends() const {
423+
return backends.size();
424+
}
425+
426+
size_t llama_context::copy_backends_list(ggml_backend_t* out, size_t out_len) const {
427+
size_t copy_len;
428+
if (out_len > backends.size()) {
429+
copy_len = backends.size();
430+
} else {
431+
copy_len = out_len;
432+
}
433+
std::transform(backends.begin(), backends.begin() + copy_len, out, [](const ggml_backend_ptr& ptr) {
434+
return ptr.get();
435+
});
436+
return copy_len;
437+
}
438+
421439
llama_memory_t llama_context::get_memory() const {
422440
return memory.get();
423441
}
@@ -2428,6 +2446,14 @@ llama_memory_t llama_get_memory(const struct llama_context * ctx) {
24282446
return ctx->get_memory();
24292447
}
24302448

2449+
size_t llama_n_backends(const struct llama_context * ctx) {
2450+
return ctx->n_backends();
2451+
}
2452+
2453+
size_t llama_get_backends(const struct llama_context * ctx, ggml_backend_t * out, size_t out_len) {
2454+
return ctx->copy_backends_list(out, out_len);
2455+
}
2456+
24312457
void llama_memory_clear(llama_memory_t mem, bool data) {
24322458
if (!mem) {
24332459
return;

src/llama-context.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ struct llama_context {
4646
uint32_t n_threads() const;
4747
uint32_t n_threads_batch() const;
4848

49+
size_t n_backends() const;
50+
size_t copy_backends_list(ggml_backend_t* out, size_t out_len) const;
51+
4952
llama_memory_t get_memory() const;
5053

5154
// return true of the KV cache was updated

src/llama-model.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13628,6 +13628,11 @@ const char * llama_model_cls_label(const struct llama_model * model, uint32_t i)
1362813628
return nullptr;
1362913629
}
1363013630

13631+
const ggml_backend_dev_t * llama_model_get_devices(const struct llama_model * model, size_t * out_len) {
13632+
*out_len = model->devices.size();
13633+
return model->devices.data();
13634+
}
13635+
1363113636
// deprecated
1363213637
int32_t llama_n_ctx_train(const llama_model * model) {
1363313638
return llama_model_n_ctx_train(model);

0 commit comments

Comments
 (0)